diff --git a/.gitignore b/.gitignore index 6b47e3c..b240bb9 100644 --- a/.gitignore +++ b/.gitignore @@ -5,4 +5,5 @@ error_log.txt __pycache__ Test ?/ .venv -permissions.json \ No newline at end of file +permissions.json +local_database.sqlite \ No newline at end of file diff --git a/bot_discord.py b/bot_discord.py index 0d8bb41..bfce4e9 100644 --- a/bot_discord.py +++ b/bot_discord.py @@ -4,13 +4,26 @@ from discord.ext import commands import importlib import cmd_discord +from modules import db + class DiscordBot(commands.Bot): def __init__(self, config, log_func): super().__init__(command_prefix="!", intents=discord.Intents.all()) self.config = config self.log = log_func # Use the logging function from bots.py + self.db_conn = None # We'll set this later self.load_commands() + def set_db_connection(self, db_conn): + """ + Store the DB connection in the bot so commands can use it. + """ + self.db_conn = db_conn + try: + db.ensure_quotes_table(self.db_conn, self.log) + except Exception as e: + self.log(f"Critical: unable to ensure quotes table: {e}", "FATAL") + def load_commands(self): """ Load all commands dynamically from cmd_discord.py. @@ -24,7 +37,7 @@ class DiscordBot(commands.Bot): async def on_command(self, ctx): """Logs every command execution at DEBUG level.""" - self.log(f"Discord Command executed: {ctx.command} by {ctx.author} in #{ctx.channel}", "DEBUG") + self.log(f"Discord Command executed: {ctx.command} by {ctx.author} in #{ctx.channel}: {ctx.message.content}", "DEBUG") async def on_ready(self): self.log(f"Discord bot is online as {self.user}", "INFO") diff --git a/bot_twitch.py b/bot_twitch.py index e247561..6241021 100644 --- a/bot_twitch.py +++ b/bot_twitch.py @@ -6,6 +6,8 @@ from twitchio.ext import commands import importlib import cmd_twitch +from modules import db + class TwitchBot(commands.Bot): def __init__(self, config, log_func): self.client_id = os.getenv("TWITCH_CLIENT_ID") @@ -14,6 +16,7 @@ class TwitchBot(commands.Bot): self.refresh_token = os.getenv("TWITCH_REFRESH_TOKEN") self.log = log_func # Use the logging function from bots.py self.config = config + self.db_conn = None # We'll set this later # 1) Initialize the parent Bot FIRST super().__init__( @@ -27,6 +30,16 @@ class TwitchBot(commands.Bot): # 2) Then load commands self.load_commands() + def set_db_connection(self, db_conn): + """ + Store the DB connection so that commands can use it. + """ + self.db_conn = db_conn + try: + db.ensure_quotes_table(self.db_conn, self.log) + except Exception as e: + self.log(f"Critical: unable to ensure quotes table: {e}", "FATAL") + async def event_message(self, message): """Logs and processes incoming Twitch messages.""" if message.echo: @@ -34,7 +47,9 @@ class TwitchBot(commands.Bot): # Log the command if it's a command if message.content.startswith("!"): - self.log(f"Twitch Command executed: {message.content} by {message.author.name} in #{message.channel.name}", "DEBUG") + _cmd = message.content[1:] # Remove the leading "!" + _cmd = _cmd.split(" ", 1)[0] + self.log(f"Twitch Command executed: {_cmd} by {message.author.name} in #{message.channel.name}: {message.content}", "DEBUG") # Process the message for command execution await self.handle_commands(message) @@ -138,5 +153,7 @@ class TwitchBot(commands.Bot): await self.refresh_access_token() except Exception as e: self.log(f"Unable to refresh Twitch token! Twitch bot will be offline!", "CRITICAL") + if "'NoneType' object has no attribute 'cancel'" in str(e): + self.log(f"The Twitch bot experienced an initialization glitch. Try starting again", "FATAL") await asyncio.sleep(5) # Wait before retrying to authenticate diff --git a/bots.py b/bots.py index 6797a04..e42c280 100644 --- a/bots.py +++ b/bots.py @@ -5,12 +5,15 @@ import asyncio import sys import time import traceback +import globals + from discord.ext import commands from dotenv import load_dotenv + from bot_discord import DiscordBot from bot_twitch import TwitchBot -import globals +from modules.db import init_db_connection, run_db_operation # Load environment variables load_dotenv() @@ -54,6 +57,9 @@ def log(message, level="INFO"): try: print(log_message) # Print to terminal + if level == "FATAL": + print(f"!!! FATAL ERROR LOGGED, SHUTTING DOWN !!!") + sys.exit(1) except Exception: pass # Prevent logging failures from crashing the bot @@ -64,13 +70,29 @@ def log(message, level="INFO"): ############################### async def main(): - global discord_bot, twitch_bot + global discord_bot, twitch_bot, db_conn + + # Before creating your DiscordBot/TwitchBot, initialize DB + db_conn = init_db_connection(config_data, log) + if not db_conn: + # If we get None, it means FATAL. We might sys.exit(1) or handle it differently. + log("Terminating bot due to no DB connection.", "FATAL") + sys.exit(1) log("Initializing bots...", "INFO") + # Create both bots discord_bot = DiscordBot(config_data, log) twitch_bot = TwitchBot(config_data, log) + # Provide DB connection to both bots + try: + discord_bot.set_db_connection(db_conn) + twitch_bot.set_db_connection(db_conn) + log(f"Initialized database connection to both bots", "INFO") + except Exception as e: + log(f"Unable to initialize database connection to one or both bots: {e}", "FATAL") + log("Starting Discord and Twitch bots...", "INFO") discord_task = asyncio.create_task(discord_bot.run(os.getenv("DISCORD_BOT_TOKEN"))) diff --git a/cmd_common/common_commands.py b/cmd_common/common_commands.py index 1eb0d0c..7dfa733 100644 --- a/cmd_common/common_commands.py +++ b/cmd_common/common_commands.py @@ -4,6 +4,8 @@ import time from modules import utility import globals +from modules.db import run_db_operation + def howl(username: str) -> str: """ Generates a howl response based on a random percentage. @@ -44,4 +46,285 @@ def greet(target_display_name: str, platform_name: str) -> str: """ Returns a greeting string for the given user displayname on a given platform. """ - return f"Hello {target_display_name}, welcome to {platform_name}!" \ No newline at end of file + return f"Hello {target_display_name}, welcome to {platform_name}!" + +###################### +# Quotes +###################### + +def create_quotes_table(db_conn, log_func): + """ + Creates the 'quotes' table if it does not exist, with the columns: + ID, QUOTE_TEXT, QUOTEE, QUOTE_CHANNEL, QUOTE_DATETIME, QUOTE_GAME, QUOTE_REMOVED + Uses a slightly different CREATE statement depending on MariaDB vs SQLite. + """ + if not db_conn: + log_func("No database connection available to create quotes table!", "FATAL") + return + + # Detect if this is SQLite or MariaDB + db_name = str(type(db_conn)).lower() + if 'sqlite3' in db_name: + # SQLite + create_table_sql = """ + CREATE TABLE IF NOT EXISTS quotes ( + ID INTEGER PRIMARY KEY AUTOINCREMENT, + QUOTE_TEXT TEXT, + QUOTEE TEXT, + QUOTE_CHANNEL TEXT, + QUOTE_DATETIME TEXT, + QUOTE_GAME TEXT, + QUOTE_REMOVED BOOLEAN DEFAULT 0 + ) + """ + else: + # Assume MariaDB + # Adjust column types as appropriate for your setup + create_table_sql = """ + CREATE TABLE IF NOT EXISTS quotes ( + ID INT PRIMARY KEY AUTO_INCREMENT, + QUOTE_TEXT TEXT, + QUOTEE VARCHAR(100), + QUOTE_CHANNEL VARCHAR(100), + QUOTE_DATETIME DATETIME DEFAULT CURRENT_TIMESTAMP, + QUOTE_GAME VARCHAR(200), + QUOTE_REMOVED BOOLEAN DEFAULT FALSE + ) + """ + + run_db_operation(db_conn, "write", create_table_sql, log_func=log_func) + + +async def handle_quote_command(db_conn, log_func, is_discord: bool, ctx, args, get_twitch_game_for_channel=None): + """ + Core logic for !quote command, shared by both Discord and Twitch. + - `db_conn`: your active DB connection + - `log_func`: your log(...) function + - `is_discord`: True if this command is being called from Discord, False if from Twitch + - `ctx`: the context object (discord.py ctx or twitchio context) + - `args`: a list of arguments (e.g. ["add", "some quote text..."] or ["remove", "3"] or ["2"] etc.) + - `get_twitch_game_for_channel`: function(channel_name) -> str or None + + Behavior: + 1) `!quote add some text here` + -> Adds a new quote, stores channel=Discord or twitch channel name, game if twitch. + 2) `!quote remove N` + -> Mark quote #N as removed. + 3) `!quote N` + -> Retrieve quote #N, if not removed. + 4) `!quote` (no args) + -> Retrieve a random (not-removed) quote. + """ + # If no subcommand, treat as "random" + if len(args) == 0: + return await retrieve_random_quote(db_conn, log_func, is_discord, ctx) + + sub = args[0].lower() + + if sub == "add": + # everything after "add" is the quote text + quote_text = " ".join(args[1:]).strip() + if not quote_text: + return await send_message(ctx, "Please provide the quote text after 'add'.") + await add_new_quote(db_conn, log_func, is_discord, ctx, quote_text, get_twitch_game_for_channel) + elif sub == "remove": + if len(args) < 2: + return await send_message(ctx, "Please specify which quote ID to remove.") + await remove_quote(db_conn, log_func, ctx, args[1]) + else: + # Possibly a quote ID + if sub.isdigit(): + quote_id = int(sub) + await retrieve_specific_quote(db_conn, log_func, ctx, quote_id) + else: + # unrecognized subcommand => fallback to random + await retrieve_random_quote(db_conn, log_func, is_discord, ctx) + + +async def add_new_quote(db_conn, log_func, is_discord, ctx, quote_text, get_twitch_game_for_channel): + """ + Insert a new quote into the DB. + QUOTEE = the user who typed the command + QUOTE_CHANNEL = "Discord" or the twitch channel name + QUOTE_GAME = The current game if from Twitch, None if from Discord + QUOTE_REMOVED = false by default + QUOTE_DATETIME = current date/time (or DB default) + """ + user_name = get_author_name(ctx, is_discord) + channel_name = "Discord" if is_discord else get_channel_name(ctx) + game_name = None + if not is_discord and get_twitch_game_for_channel: + # Attempt to get the current game from the Twitch API (placeholder function) + game_name = get_twitch_game_for_channel(channel_name) # might return str or None + + # Insert quote + insert_sql = """ + INSERT INTO quotes (QUOTE_TEXT, QUOTEE, QUOTE_CHANNEL, QUOTE_DATETIME, QUOTE_GAME, QUOTE_REMOVED) + VALUES (?, ?, ?, CURRENT_TIMESTAMP, ?, 0) + """ + # For MariaDB, parameter placeholders are often %s, but if you set paramstyle='qmark', it can use ? as well. + # Adjust if needed for your environment. + params = (quote_text, user_name, channel_name, game_name) + + result = run_db_operation(db_conn, "write", insert_sql, params, log_func=log_func) + if result is not None: + await send_message(ctx, "Quote added successfully!") + else: + await send_message(ctx, "Failed to add quote.") + + +async def remove_quote(db_conn, log_func, ctx, quote_id_str): + """ + Mark quote #ID as removed (QUOTE_REMOVED=1). + """ + if not quote_id_str.isdigit(): + return await send_message(ctx, f"'{quote_id_str}' is not a valid quote ID.") + + quote_id = int(quote_id_str) + remover_user = str(ctx.author.name) + + # Mark as removed + update_sql = """ + UPDATE quotes + SET QUOTE_REMOVED = 1, + QUOTE_REMOVED_BY = ? + WHERE ID = ? + AND QUOTE_REMOVED = 0 + """ + params = (remover_user, quote_id) + rowcount = run_db_operation(db_conn, "update", update_sql, params, log_func=log_func) + + if rowcount and rowcount > 0: + await send_message(ctx, f"Removed quote #{quote_id}.") + else: + await send_message(ctx, "Could not remove that quote (maybe it's already removed or doesn't exist).") + + +async def retrieve_specific_quote(db_conn, log_func, ctx, quote_id): + """ + Retrieve a specific quote by ID, if not removed. + If not found, or removed, inform user of the valid ID range (1 - {max_id}) + If no quotes exist at all, say "No quotes are created yet." + """ + # First, see if we have any quotes at all + max_id = get_max_quote_id(db_conn, log_func) + if max_id < 1: + return await send_message(ctx, "No quotes are created yet.") + + # Query for that specific quote + select_sql = """ + SELECT + ID, + QUOTE_TEXT, + QUOTEE, + QUOTE_CHANNEL, + QUOTE_DATETIME, + QUOTE_GAME, + QUOTE_REMOVED, + QUOTE_REMOVED_BY + FROM quotes + WHERE ID = ? + """ + rows = run_db_operation(db_conn, "read", select_sql, (quote_id,), log_func=log_func) + + if not rows: + # no match + return await send_message(ctx, f"I couldn't find that quote (1-{max_id}).") + + row = rows[0] + quote_number = row[0] + quote_text = row[1] + quotee = row[2] + quote_channel = row[3] + quote_datetime = row[4] + quote_game = row[5] + quote_removed = row[6] + quote_removed_by = row[7] if row[7] else "Unknown" + + if quote_removed == 1: + # It's removed + await send_message(ctx, f"Quote {quote_number}: [REMOVED by {quote_removed_by}]") + else: + # It's not removed + await send_message(ctx, f"Quote {quote_number}: {quote_text}") + + +async def retrieve_random_quote(db_conn, log_func, is_discord, ctx): + """ + Grab a random quote (QUOTE_REMOVED=0). + If no quotes exist or all removed, respond with "No quotes are created yet." + """ + # First check if we have any quotes + max_id = get_max_quote_id(db_conn, log_func) + if max_id < 1: + return await send_message(ctx, "No quotes are created yet.") + + # We have quotes, try selecting a random one from the not-removed set + if is_sqlite(db_conn): + random_sql = """ + SELECT ID, QUOTE_TEXT + FROM quotes + WHERE QUOTE_REMOVED = 0 + ORDER BY RANDOM() + LIMIT 1 + """ + else: + # MariaDB uses RAND() + random_sql = """ + SELECT ID, QUOTE_TEXT + FROM quotes + WHERE QUOTE_REMOVED = 0 + ORDER BY RAND() + LIMIT 1 + """ + + rows = run_db_operation(db_conn, "read", random_sql, log_func=log_func) + if not rows: + return await send_message(ctx, "No quotes are created yet.") + + quote_number, quote_text = rows[0] + await send_message(ctx, f"Quote {quote_number}: {quote_text}") + + +def get_max_quote_id(db_conn, log_func): + """ + Return the highest ID in the quotes table, or 0 if empty. + """ + sql = "SELECT MAX(ID) FROM quotes" + rows = run_db_operation(db_conn, "read", sql, log_func=log_func) + if rows and rows[0] and rows[0][0] is not None: + return rows[0][0] + return 0 + + +def is_sqlite(db_conn): + return 'sqlite3' in str(type(db_conn)).lower() + + +def get_author_name(ctx, is_discord): + """ + Return the name/username of the command author. + For Discord, it's ctx.author.display_name (or ctx.author.name). + For Twitch (twitchio), it's ctx.author.name. + """ + if is_discord: + return str(ctx.author.display_name) + else: + return str(ctx.author.name) + + +def get_channel_name(ctx): + """ + Return the channel name for Twitch. For example, ctx.channel.name in twitchio. + """ + # In twitchio, ctx.channel has .name + return str(ctx.channel.name) + + +async def send_message(ctx, text): + """ + Minimal helper to send a message to either Discord or Twitch. + For discord.py: await ctx.send(text) + For twitchio: await ctx.send(text) + """ + await ctx.send(text) \ No newline at end of file diff --git a/cmd_discord.py b/cmd_discord.py index c8489c2..fff9b82 100644 --- a/cmd_discord.py +++ b/cmd_discord.py @@ -3,7 +3,19 @@ from discord.ext import commands from cmd_common import common_commands as cc from modules.permissions import has_permission -def setup(bot): + +def setup(bot, db_conn=None, log=None): + """ + Attach commands to the Discord bot, store references to db/log. + """ + + # auto-create the quotes table if it doesn't exist + if bot.db_conn and bot.log: + cc.create_quotes_table(bot.db_conn, bot.log) + + # Auto-create the quotes table if desired + if db_conn and log: + cc.create_quotes_table(db_conn, log) @bot.command() async def greet(ctx): result = cc.greet(ctx.author.display_name, "Discord") @@ -41,4 +53,25 @@ def setup(bot): await ctx.send("You don't have permission to use this command.") return - await ctx.send("Hello there!") \ No newline at end of file + await ctx.send("Hello there!") + + @bot.command(name="quote") + async def quote_command(ctx, *args): + """ + !quote + !quote add + !quote remove + !quote + """ + if not bot.db_conn: + return await ctx.send("Database is unavailable, sorry.") + + # Send to our shared logic + await cc.handle_quote_command( + db_conn=bot.db_conn, + log_func=bot.log, + is_discord=True, + ctx=ctx, + args=list(args), + get_twitch_game_for_channel=None # None for Discord + ) \ No newline at end of file diff --git a/cmd_twitch.py b/cmd_twitch.py index aa7c97e..1b884b1 100644 --- a/cmd_twitch.py +++ b/cmd_twitch.py @@ -1,9 +1,19 @@ # cmd_twitch.py + from twitchio.ext import commands from cmd_common import common_commands as cc from modules.permissions import has_permission -def setup(bot): +def setup(bot, db_conn=None, log=None): + """ + This function is called to load/attach commands to the `bot`. + We also attach the db_conn and log so the commands can use them. + """ + + # auto-create the quotes table if it doesn't exist + if bot.db_conn and bot.log: + cc.create_quotes_table(bot.db_conn, bot.log) + @bot.command(name="greet") async def greet(ctx): result = cc.greet(ctx.author.display_name, "Twitch") @@ -18,14 +28,34 @@ def setup(bot): async def howl(ctx): result = cc.howl(ctx.author.display_name) await ctx.send(result) - + @bot.command(name="hi") async def hi_command(ctx): user_id = str(ctx.author.id) # Twitch user ID - user_roles = [role.lower() for role in ctx.author.badges.keys()] # Use Twitch badges as roles + user_roles = [role.lower() for role in ctx.author.badges.keys()] # "roles" from Twitch badges if not has_permission("hi", user_id, user_roles, "twitch"): - await ctx.send("You don't have permission to use this command.") - return + return await ctx.send("You don't have permission to use this command.") - await ctx.send("Hello there!") \ No newline at end of file + await ctx.send("Hello there!") + + @bot.command(name="quote") + async def quote(ctx: commands.Context): + if not bot.db_conn: + return await ctx.send("Database is unavailable, sorry.") + + parts = ctx.message.content.strip().split() + args = parts[1:] if len(parts) > 1 else [] + + def get_twitch_game_for_channel(chan_name): + # Placeholder for your actual logic to fetch the current game + return "SomeGame" + + await cc.handle_quote_command( + db_conn=bot.db_conn, + log_func=bot.log, + is_discord=False, + ctx=ctx, + args=args, + get_twitch_game_for_channel=get_twitch_game_for_channel + ) \ No newline at end of file diff --git a/modules/db.py b/modules/db.py new file mode 100644 index 0000000..0138607 --- /dev/null +++ b/modules/db.py @@ -0,0 +1,214 @@ +# modules/db.py +import os +import re +import time +import sqlite3 + +try: + import mariadb +except ImportError: + mariadb = None # We handle gracefully if 'mariadb' isn't installed. + +def init_db_connection(config, log): + """ + Initializes a database connection based on config.json contents: + - If config says 'use_mariadb', tries connecting to MariaDB. + - If that fails (or not configured), falls back to SQLite. + - Logs FATAL if neither can be established (the bot likely depends on DB). + + :param config: (dict) The loaded config.json data + :param log: (function) Logging function (message, level="INFO") + :return: a connection object (MariaDB or sqlite3.Connection), or None on failure + """ + db_settings = config.get("database", {}) + use_mariadb = db_settings.get("use_mariadb", False) + + if use_mariadb and mariadb is not None: + # Attempt MariaDB + host = db_settings.get("mariadb_host", "localhost") + user = db_settings.get("mariadb_user", "") + password = db_settings.get("mariadb_password", "") + dbname = db_settings.get("mariadb_dbname", "") + port = int(db_settings.get("mariadb_port", 3306)) + + if user and password and dbname: + try: + conn = mariadb.connect( + host=host, + user=user, + password=password, + database=dbname, + port=port + ) + conn.autocommit = False # We'll manage commits manually + log(f"Database connection established using MariaDB (host={host}, db={dbname}).", "INFO") + return conn + except mariadb.Error as e: + log(f"Error connecting to MariaDB: {e}", "WARNING") + else: + log("MariaDB config incomplete. Falling back to SQLite...", "WARNING") + else: + if use_mariadb and mariadb is None: + log("mariadb module not installed but use_mariadb=True. Falling back to SQLite...", "WARNING") + + # Fallback to local SQLite + sqlite_path = db_settings.get("sqlite_path", "local_database.sqlite") + try: + conn = sqlite3.connect(sqlite_path) + log(f"Database connection established using local SQLite: {sqlite_path}", "INFO") + return conn + except sqlite3.Error as e: + log(f"Could not open local SQLite database '{sqlite_path}': {e}", "WARNING") + + # If neither MariaDB nor SQLite connected, that's fatal for the bot + log("No valid database connection could be established! Exiting...", "FATAL") + return None + + +def run_db_operation(conn, operation, query, params=None, log_func=None): + """ + Executes a parameterized query with basic screening for injection attempts: + - 'operation' can be "read", "write", "update", "delete", "lookup", etc. + - 'query' is the SQL statement, with placeholders (? in SQLite or %s in MariaDB both work). + - 'params' is a tuple/list of parameters for the query (preferred for security). + - 'log_func' is the logging function (message, level). + + 1) We do a minimal check for suspicious patterns, e.g. multiple statements or known bad keywords. + 2) We execute the query with parameters, and commit on write/update/delete. + 3) On read/lookup, we fetch and return rows. Otherwise, return rowcount. + + NOTE: + - This is still not a replacement for well-structured queries and security best practices. + - Always use parameterized queries wherever possible to avoid injection. + """ + if conn is None: + if log_func: + log_func("run_db_operation called but no valid DB connection!", "FATAL") + return None + + if params is None: + params = () + + # Basic screening for malicious usage (multiple statements, forced semicolons, suspicious SQL keywords, etc.) + # This is minimal and can be expanded if needed. + lowered = query.strip().lower() + + # Check for multiple statements separated by semicolons (beyond the last one) + if lowered.count(";") > 1: + if log_func: + log_func("Query blocked: multiple SQL statements detected.", "WARNING") + log_func(f"Offending query: {query}", "WARNING") + return None + + # Potentially dangerous SQL keywords + forbidden_keywords = ["drop table", "union select", "exec ", "benchmark(", "sleep("] + for kw in forbidden_keywords: + if kw in lowered: + if log_func: + log_func(f"Query blocked due to forbidden keyword: '{kw}'", "WARNING") + log_func(f"Offending query: {query}", "WARNING") + return None + + cursor = conn.cursor() + try: + cursor.execute(query, params) + + # If it's a write/update/delete, commit the changes + write_ops = ("write", "insert", "update", "delete", "change") + if operation.lower() in write_ops: + conn.commit() + if log_func: + log_func(f"DB operation '{operation}' committed.", "DEBUG") + + # If it's read/lookup, fetch results + read_ops = ("read", "lookup", "select") + if operation.lower() in read_ops: + rows = cursor.fetchall() + return rows + else: + return cursor.rowcount # for insert/update/delete, rowcount can be helpful + except Exception as e: + # Rollback on any error + conn.rollback() + if log_func: + log_func(f"Error during '{operation}' query execution: {e}", "ERROR") + return None + finally: + cursor.close() + +####################### +# Ensure quotes table exists +####################### + +def ensure_quotes_table(db_conn, log_func): + """ + Checks if 'quotes' table exists. If not, attempts to create it. + Raises an Exception or logs errors if creation fails. + """ + + # 1) Determine if DB is sqlite or mariadb for the system table check + is_sqlite = "sqlite3" in str(type(db_conn)).lower() + + # 2) Check existence + if is_sqlite: + # For SQLite: check the sqlite_master table + check_sql = """ + SELECT name + FROM sqlite_master + WHERE type='table' + AND name='quotes' + """ + else: + # For MariaDB/MySQL: check information_schema + check_sql = """ + SELECT table_name + FROM information_schema.tables + WHERE table_name = 'quotes' + AND table_schema = DATABASE() + """ + + from modules.db import run_db_operation + rows = run_db_operation(db_conn, "read", check_sql, log_func=log_func) + if rows and rows[0] and rows[0][0]: + # The table 'quotes' already exists + log_func("Table 'quotes' already exists, skipping creation.", "DEBUG") + return # We can just return + + # 3) Table does NOT exist => create it + log_func("Table 'quotes' does not exist; creating now...", "INFO") + + if is_sqlite: + create_table_sql = """ + CREATE TABLE quotes ( + ID INTEGER PRIMARY KEY AUTOINCREMENT, + QUOTE_TEXT TEXT, + QUOTEE TEXT, + QUOTE_CHANNEL TEXT, + QUOTE_DATETIME TEXT, + QUOTE_GAME TEXT, + QUOTE_REMOVED BOOLEAN DEFAULT 0, + QUOTE_REMOVED_BY TEXT + ) + """ + else: + create_table_sql = """ + CREATE TABLE quotes ( + ID INT PRIMARY KEY AUTO_INCREMENT, + QUOTE_TEXT TEXT, + QUOTEE VARCHAR(100), + QUOTE_CHANNEL VARCHAR(100), + QUOTE_DATETIME DATETIME DEFAULT CURRENT_TIMESTAMP, + QUOTE_GAME VARCHAR(200), + QUOTE_REMOVED BOOLEAN DEFAULT FALSE, + QUOTE_REMOVED_BY VARCHAR(100) + ) + """ + + result = run_db_operation(db_conn, "write", create_table_sql, log_func=log_func) + if result is None: + # If run_db_operation returns None on error, handle or raise: + error_msg = "Failed to create 'quotes' table!" + log_func(error_msg, "ERROR") + raise RuntimeError(error_msg) + + log_func("Successfully created table 'quotes'.", "INFO") diff --git a/modules/utility.py b/modules/utility.py index d729da0..7fee8ae 100644 --- a/modules/utility.py +++ b/modules/utility.py @@ -170,3 +170,4 @@ def sanitize_user_input( # 4. Prepare output reason_string = "; ".join(reasons) return (sanitized, sanitization_applied, reason_string, original_string) +