# modules/db.py
"""
Database helpers for the bot: connection setup (MariaDB with SQLite fallback),
a screened query runner, bootstrap of the 'quotes' / 'users' tables, and user lookup.
"""
import os
import re
import time
import sqlite3

try:
    import mariadb
except ImportError:
    mariadb = None  # We handle gracefully if 'mariadb' isn't installed.


def init_db_connection(config, log):
    """
    Initializes a database connection based on config.json contents:
      - If config says 'use_mariadb', tries connecting to MariaDB.
      - If that fails (incomplete settings, invalid port, connection error, or
        the 'mariadb' package is not installed), falls back to SQLite.
      - Logs FATAL if neither can be established (the bot likely depends on DB).

    :param config: (dict) The loaded config.json data
    :param log: (function) Logging function (message, level="INFO")
    :return: a connection object (MariaDB or sqlite3.Connection), or None on failure
    """
    # 'or {}' also tolerates an explicit "database": null in the config.
    db_settings = config.get("database") or {}
    use_mariadb = db_settings.get("use_mariadb", False)

    if use_mariadb and mariadb is not None:
        conn = _connect_mariadb(db_settings, log)
        if conn is not None:
            return conn
    elif use_mariadb:
        # use_mariadb is set but the optional driver failed to import.
        log("mariadb module not installed but use_mariadb=True. Falling back to SQLite...", "WARNING")

    # Fallback to local SQLite
    sqlite_path = db_settings.get("sqlite_path", "local_database.sqlite")
    try:
        conn = sqlite3.connect(sqlite_path)
    except sqlite3.Error as e:
        log(f"Could not open local SQLite database '{sqlite_path}': {e}", "WARNING")
    else:
        log(f"Database connection established using local SQLite: {sqlite_path}")
        return conn

    # If neither MariaDB nor SQLite connected, that's fatal for the bot.
    # NOTE(review): this function only logs and returns None; exiting is
    # presumably left to the caller — confirm.
    log("No valid database connection could be established! Exiting...", "FATAL")
    return None


def _connect_mariadb(db_settings, log):
    """Try to open a MariaDB connection from *db_settings*; log and return None on any failure."""
    host = db_settings.get("mariadb_host", "localhost")
    user = db_settings.get("mariadb_user", "")
    password = db_settings.get("mariadb_password", "")
    dbname = db_settings.get("mariadb_dbname", "")
    try:
        port = int(db_settings.get("mariadb_port", 3306))
    except (TypeError, ValueError):
        # A malformed port must not crash startup; the contract is to fall back to SQLite.
        log(f"Invalid mariadb_port value: {db_settings.get('mariadb_port')!r}. Falling back to SQLite...", "WARNING")
        return None

    if not (user and password and dbname):
        log("MariaDB config incomplete. Falling back to SQLite...", "WARNING")
        return None

    try:
        conn = mariadb.connect(
            host=host,
            user=user,
            password=password,
            database=dbname,
            port=port,
        )
        conn.autocommit = False  # We'll manage commits manually
    except mariadb.Error as e:
        log(f"Error connecting to MariaDB: {e}", "WARNING")
        return None

    log(f"Database connection established using MariaDB (host={host}, db={dbname}).")
    return conn


# Operations whose changes are committed after execution.
_WRITE_OPS = ("write", "insert", "update", "delete", "change")
# Operations whose result rows are fetched and returned.
_READ_OPS = ("read", "lookup", "select")
# Substrings (lowercase, single-spaced) that cause a query to be rejected outright.
_FORBIDDEN_KEYWORDS = ("drop table", "union select", "exec ", "benchmark(", "sleep(")


def _screen_query(query):
    """Return a log message explaining why *query* is blocked, or None if it passes screening."""
    # Collapse whitespace runs so e.g. "DROP\n  TABLE" still matches "drop table".
    normalized = re.sub(r"\s+", " ", query.strip().lower())

    # A single trailing terminator is fine; any other ';' implies a second statement.
    body = normalized.rstrip()
    if body.endswith(";"):
        body = body[:-1]
    if ";" in body:
        return "Query blocked: multiple SQL statements detected."

    for kw in _FORBIDDEN_KEYWORDS:
        if kw in normalized:
            return f"Query blocked due to forbidden keyword: '{kw}'"
    return None


def run_db_operation(conn, operation, query, params=None, log_func=None):
    """
    Executes a parameterized query after basic screening for injection attempts.

    - 'operation' can be "read", "write", "update", "delete", "lookup", etc.
      Write-type operations ("write", "insert", "update", "delete", "change")
      are committed; read-type operations ("read", "lookup", "select") return rows.
    - 'query' is the SQL statement with placeholders. Use '?' placeholders: sqlite3
      accepts them, and this module uses them with MariaDB too. ('%s' is NOT
      accepted by sqlite3.)
    - 'params' is a tuple/list of parameters for the query (preferred for security).
    - 'log_func' is an optional logging function (message, level).

    1) Screening: the query is blocked if it contains a ';' other than a single
       trailing terminator, or a known-bad keyword (matched case-insensitively,
       with whitespace runs collapsed).
    2) The query is executed with parameters; write-type operations are committed.
    3) Read-type operations return the fetched rows; anything else returns cursor.rowcount.

    Returns None if there is no connection, if the query is blocked, or if
    execution fails (in which case a best-effort rollback is attempted).

    NOTE:
    - This is still not a replacement for well-structured queries and security best practices.
    - Always use parameterized queries wherever possible to avoid injection.
    """
    if conn is None:
        if log_func:
            log_func("run_db_operation called but no valid DB connection!", "FATAL")
        return None

    if params is None:
        params = ()

    blocked_reason = _screen_query(query)
    if blocked_reason is not None:
        if log_func:
            log_func(blocked_reason, "WARNING")
            log_func(f"Offending query: {query}", "WARNING")
        return None

    cursor = conn.cursor()
    try:
        cursor.execute(query, params)
        op = operation.lower()

        # If it's a write/update/delete, commit the changes
        if op in _WRITE_OPS:
            conn.commit()
            if log_func:
                log_func(f"DB operation '{operation}' committed.", "DEBUG")

        # If it's read/lookup, fetch results
        if op in _READ_OPS:
            return cursor.fetchall()
        return cursor.rowcount  # for insert/update/delete, rowcount can be helpful
    except Exception as e:
        if log_func:
            log_func(f"Error during '{operation}' query execution: {e}", "ERROR")
        # Roll back on any error. Rollback itself can fail (e.g. on a dead
        # connection); don't let that mask the original error or escape.
        try:
            conn.rollback()
        except Exception as rollback_err:
            if log_func:
                log_func(f"Rollback after failed '{operation}' query also failed: {rollback_err}", "ERROR")
        return None
    finally:
        cursor.close()


def _is_sqlite_connection(db_conn):
    """True if *db_conn* is a sqlite3 connection (including Connection subclasses created via factory=)."""
    return isinstance(db_conn, sqlite3.Connection) or "sqlite3" in str(type(db_conn)).lower()


def _table_exists(db_conn, table_name, log_func):
    """Return True if *table_name* exists, using the backend-appropriate system catalog."""
    if _is_sqlite_connection(db_conn):
        check_sql = "SELECT name FROM sqlite_master WHERE type='table' AND name = ?"
    else:
        # MariaDB/MySQL: restrict to the currently selected schema.
        check_sql = (
            "SELECT table_name FROM information_schema.tables "
            "WHERE table_name = ? AND table_schema = DATABASE()"
        )
    rows = run_db_operation(db_conn, "read", check_sql, params=(table_name,), log_func=log_func)
    # rows is None on query error and [] when absent; both count as "does not exist".
    return bool(rows and rows[0] and rows[0][0])


def _create_table(db_conn, table_name, create_table_sql, log_func):
    """Run *create_table_sql*; raise RuntimeError (after logging) if it fails."""
    log_func(f"Table '{table_name}' does not exist; creating now...")
    result = run_db_operation(db_conn, "write", create_table_sql, log_func=log_func)
    if result is None:
        error_msg = f"Failed to create '{table_name}' table!"
        log_func(error_msg, "ERROR")
        raise RuntimeError(error_msg)
    log_func(f"Successfully created table '{table_name}'.")


#######################
# Ensure quotes table exists
#######################

def ensure_quotes_table(db_conn, log_func):
    """
    Checks if the 'quotes' table exists and creates it if not.

    :param db_conn: a sqlite3 or MariaDB connection
    :param log_func: logging function (message, level="INFO")
    :raises RuntimeError: if the table has to be created and creation fails
    """
    if _table_exists(db_conn, "quotes", log_func):
        log_func("Table 'quotes' already exists, skipping creation.", "DEBUG")
        return

    # IF NOT EXISTS guards against another process creating the table between
    # the existence check and this statement.
    if _is_sqlite_connection(db_conn):
        create_table_sql = """
            CREATE TABLE IF NOT EXISTS quotes (
                ID INTEGER PRIMARY KEY AUTOINCREMENT,
                QUOTE_TEXT TEXT,
                QUOTEE TEXT,
                QUOTE_CHANNEL TEXT,
                QUOTE_DATETIME TEXT,
                QUOTE_GAME TEXT,
                QUOTE_REMOVED BOOLEAN DEFAULT 0,
                QUOTE_REMOVED_BY TEXT
            )
        """
    else:
        create_table_sql = """
            CREATE TABLE IF NOT EXISTS quotes (
                ID INT PRIMARY KEY AUTO_INCREMENT,
                QUOTE_TEXT TEXT,
                QUOTEE VARCHAR(100),
                QUOTE_CHANNEL VARCHAR(100),
                QUOTE_DATETIME DATETIME DEFAULT CURRENT_TIMESTAMP,
                QUOTE_GAME VARCHAR(200),
                QUOTE_REMOVED BOOLEAN DEFAULT FALSE,
                QUOTE_REMOVED_BY VARCHAR(100)
            )
        """
    _create_table(db_conn, "quotes", create_table_sql, log_func)


#######################
# Ensure 'users' table
#######################

def ensure_users_table(db_conn, log_func):
    """
    Checks if the 'users' table exists and creates it if not.

    The 'users' table tracks user linkage across platforms:
      - UUID: (PK) The universal ID for the user
      - discord_user_id, discord_username, discord_user_display_name
      - twitch_user_id, twitch_username, twitch_user_display_name
      - datetime_linked (defaults to CURRENT_TIMESTAMP at row creation)
      - user_is_banned (BOOLEAN)
      - user_is_bot (BOOLEAN)

    This helps unify data for a single 'person' across Discord & Twitch.

    :param db_conn: a sqlite3 or MariaDB connection
    :param log_func: logging function (message, level="INFO")
    :raises RuntimeError: if the table has to be created and creation fails
    """
    if _table_exists(db_conn, "users", log_func):
        log_func("Table 'users' already exists, skipping creation.", "DEBUG")
        return

    if _is_sqlite_connection(db_conn):
        create_table_sql = """
            CREATE TABLE IF NOT EXISTS users (
                UUID TEXT PRIMARY KEY,
                discord_user_id TEXT,
                discord_username TEXT,
                discord_user_display_name TEXT,
                twitch_user_id TEXT,
                twitch_username TEXT,
                twitch_user_display_name TEXT,
                datetime_linked TEXT DEFAULT CURRENT_TIMESTAMP,
                user_is_banned BOOLEAN DEFAULT 0,
                user_is_bot BOOLEAN DEFAULT 0
            )
        """
    else:
        create_table_sql = """
            CREATE TABLE IF NOT EXISTS users (
                UUID VARCHAR(36) PRIMARY KEY,
                discord_user_id VARCHAR(100),
                discord_username VARCHAR(100),
                discord_user_display_name VARCHAR(100),
                twitch_user_id VARCHAR(100),
                twitch_username VARCHAR(100),
                twitch_user_display_name VARCHAR(100),
                datetime_linked DATETIME DEFAULT CURRENT_TIMESTAMP,
                user_is_banned BOOLEAN DEFAULT FALSE,
                user_is_bot BOOLEAN DEFAULT FALSE
            )
        """
    _create_table(db_conn, "users", create_table_sql, log_func)


########################
# Lookup user function
########################

# Columns returned by lookup_user, in SELECT order; also the keys of the result dict.
_USER_COLUMNS = (
    "UUID",
    "discord_user_id",
    "discord_username",
    "discord_user_display_name",
    "twitch_user_id",
    "twitch_username",
    "twitch_user_display_name",
    "datetime_linked",
    "user_is_banned",
    "user_is_bot",
)

# Whitelist of columns that may be used as the lookup key (lowercase).
_LOOKUP_COLUMNS = ("uuid", "discord_user_id", "discord_username", "twitch_user_id", "twitch_username")


def lookup_user(db_conn, log_func, identifier, identifier_type="discord_user_id"):
    """
    Looks up a user in the 'users' table based on the given identifier_type
    (matched case-insensitively):
      - "uuid"
      - "discord_user_id"
      - "discord_username"
      - "twitch_user_id"
      - "twitch_username"
    You can add more by extending _LOOKUP_COLUMNS.

    Returns a dictionary with all columns:
      {
        "UUID": str,
        "discord_user_id": str or None,
        "discord_username": str or None,
        "discord_user_display_name": str or None,
        "twitch_user_id": str or None,
        "twitch_username": str or None,
        "twitch_user_display_name": str or None,
        "datetime_linked": str (or datetime in MariaDB),
        "user_is_banned": bool or int,
        "user_is_bot": bool or int
      }
    Returns None if not found, if identifier_type is not supported, or if the query fails.
    """
    column = identifier_type.lower() if isinstance(identifier_type, str) else None
    if column not in _LOOKUP_COLUMNS:
        if log_func:
            log_func(f"lookup_user error: invalid identifier_type={identifier_type}", "WARNING")
        return None

    # Interpolating the column name is safe only because it came from the
    # whitelist above; the identifier value itself stays a bound parameter.
    query = f"SELECT {', '.join(_USER_COLUMNS)} FROM users WHERE {column} = ? LIMIT 1"
    rows = run_db_operation(db_conn, "read", query, params=(identifier,), log_func=log_func)
    if not rows:
        return None

    return dict(zip(_USER_COLUMNS, rows[0]))