OokamiPupV2/modules/db.py

376 lines
13 KiB
Python

# modules/db.py
import os
import re
import time
import sqlite3
try:
import mariadb
except ImportError:
mariadb = None # We handle gracefully if 'mariadb' isn't installed.
def init_db_connection(config, log):
    """
    Initializes a database connection based on config.json contents:
    - If config says 'use_mariadb', tries connecting to MariaDB.
    - If that fails (or is not configured / misconfigured), falls back to SQLite.
    - Logs FATAL if neither can be established (the bot likely depends on DB).

    :param config: (dict) The loaded config.json data; settings are read from
                   its "database" sub-dict.
    :param log: (function) Logging function (message, level="INFO")
    :return: a connection object (MariaDB or sqlite3.Connection), or None on failure
    """
    db_settings = config.get("database", {})
    use_mariadb = db_settings.get("use_mariadb", False)

    # `mariadb` is only evaluated when MariaDB was actually requested.
    if use_mariadb:
        if mariadb is None:
            log("mariadb module not installed but use_mariadb=True. Falling back to SQLite...", "WARNING")
        else:
            conn = _try_mariadb_connection(db_settings, log)
            if conn is not None:
                return conn

    # Fallback to local SQLite
    sqlite_path = db_settings.get("sqlite_path", "local_database.sqlite")
    try:
        conn = sqlite3.connect(sqlite_path)
    except sqlite3.Error as e:
        log(f"Could not open local SQLite database '{sqlite_path}': {e}", "WARNING")
    else:
        log(f"Database connection established using local SQLite: {sqlite_path}")
        return conn

    # If neither MariaDB nor SQLite connected, that's fatal for the bot
    log("No valid database connection could be established! Exiting...", "FATAL")
    return None


def _try_mariadb_connection(db_settings, log):
    """Try to open a MariaDB connection from *db_settings*; return it, or None after logging why not."""
    host = db_settings.get("mariadb_host", "localhost")
    user = db_settings.get("mariadb_user", "")
    password = db_settings.get("mariadb_password", "")
    dbname = db_settings.get("mariadb_dbname", "")

    if not (user and password and dbname):
        log("MariaDB config incomplete. Falling back to SQLite...", "WARNING")
        return None

    # A malformed or null port is a config error: fall back instead of crashing.
    raw_port = db_settings.get("mariadb_port", 3306)
    try:
        port = int(raw_port)
    except (TypeError, ValueError):
        log(f"Invalid mariadb_port {raw_port!r}. Falling back to SQLite...", "WARNING")
        return None

    try:
        conn = mariadb.connect(
            host=host,
            user=user,
            password=password,
            database=dbname,
            port=port,
        )
    except mariadb.Error as e:
        log(f"Error connecting to MariaDB: {e}", "WARNING")
        return None

    conn.autocommit = False  # We'll manage commits manually
    log(f"Database connection established using MariaDB (host={host}, db={dbname}).")
    return conn
def run_db_operation(conn, operation, query, params=None, log_func=None):
    """
    Executes a parameterized query with basic screening for injection attempts:
    - 'operation' can be "read", "write", "update", "delete", "lookup", etc.
    - 'query' is the SQL statement, with placeholders.
    - 'params' is a tuple/list of parameters for the query (preferred for security).
    - 'log_func' is the logging function (message, level).

    1) We do a minimal check for suspicious patterns: more than one statement
       (any semicolon other than a single trailing one) or known bad keywords.
    2) We execute the query with parameters, and commit on write/insert/update/delete/change.
    3) On read/lookup/select, we fetch and return rows. Otherwise, return rowcount.

    Returns None if there is no connection, the query is blocked, or execution fails
    (in which case the transaction is rolled back).

    NOTE:
    - This is still not a replacement for well-structured queries and security best practices.
    - Always use parameterized queries wherever possible to avoid injection.
    """
    if conn is None:
        if log_func:
            log_func("run_db_operation called but no valid DB connection!", "FATAL")
        return None
    if params is None:
        params = ()

    # Basic screening for malicious usage. This is minimal and can be expanded if needed.
    lowered = query.strip().lower()

    # Only a single trailing semicolon is allowed; any other semicolon means the
    # string contains more than one statement (e.g. "SELECT 1; DELETE FROM t").
    body = lowered[:-1] if lowered.endswith(";") else lowered
    if ";" in body:
        if log_func:
            log_func("Query blocked: multiple SQL statements detected.", "WARNING")
            log_func(f"Offending query: {query}", "WARNING")
        return None

    # Potentially dangerous SQL keywords
    forbidden_keywords = ["drop table", "union select", "exec ", "benchmark(", "sleep("]
    for kw in forbidden_keywords:
        if kw in lowered:
            if log_func:
                log_func(f"Query blocked due to forbidden keyword: '{kw}'", "WARNING")
                log_func(f"Offending query: {query}", "WARNING")
            return None

    op = operation.lower()
    write_ops = ("write", "insert", "update", "delete", "change")
    read_ops = ("read", "lookup", "select")

    cursor = conn.cursor()
    try:
        cursor.execute(query, params)
        if op in write_ops:
            conn.commit()
            if log_func:
                log_func(f"DB operation '{operation}' committed.", "DEBUG")
        # NOTE(review): reads are never committed. On a non-autocommit MariaDB
        # connection this may leave a transaction (and its read snapshot) open
        # until the next write commits — confirm that is intended.
        if op in read_ops:
            return cursor.fetchall()
        return cursor.rowcount  # for insert/update/delete, rowcount can be helpful
    except Exception as e:
        # Log the original failure first so it is not lost if rollback also fails.
        if log_func:
            log_func(f"Error during '{operation}' query execution: {e}", "ERROR")
        try:
            conn.rollback()
        except Exception as rollback_err:
            if log_func:
                log_func(f"Rollback failed after '{operation}' error: {rollback_err}", "ERROR")
        return None
    finally:
        cursor.close()
#######################
# Ensure quotes table exists
#######################
def ensure_quotes_table(db_conn, log_func):
    """
    Checks if 'quotes' table exists. If not, attempts to create it.

    The existence check and CREATE TABLE statement are chosen per backend
    (SQLite vs. MariaDB/MySQL) from the connection object's type.

    :param db_conn: an open connection (sqlite3.Connection or MariaDB connection)
    :param log_func: logging function (message, level)
    :raises RuntimeError: if the CREATE TABLE statement fails
    """
    # 1) Determine if DB is sqlite or mariadb for the system table check
    is_sqlite = "sqlite3" in str(type(db_conn)).lower()

    # 2) Check existence
    if is_sqlite:
        # For SQLite: check the sqlite_master table
        check_sql = """
            SELECT name
            FROM sqlite_master
            WHERE type='table'
            AND name='quotes'
        """
    else:
        # For MariaDB/MySQL: check information_schema
        check_sql = """
            SELECT table_name
            FROM information_schema.tables
            WHERE table_name = 'quotes'
            AND table_schema = DATABASE()
        """

    # run_db_operation is defined in this module; call it directly.
    rows = run_db_operation(db_conn, "read", check_sql, log_func=log_func)
    if rows and rows[0] and rows[0][0]:
        # The table 'quotes' already exists
        log_func("Table 'quotes' already exists, skipping creation.", "DEBUG")
        return

    # 3) Table does NOT exist => create it
    log_func("Table 'quotes' does not exist; creating now...")
    if is_sqlite:
        # NOTE(review): unlike the MariaDB schema, QUOTE_DATETIME has no DEFAULT
        # here, so rows inserted without it are NULL on SQLite — confirm intended.
        create_table_sql = """
            CREATE TABLE quotes (
                ID INTEGER PRIMARY KEY AUTOINCREMENT,
                QUOTE_TEXT TEXT,
                QUOTEE TEXT,
                QUOTE_CHANNEL TEXT,
                QUOTE_DATETIME TEXT,
                QUOTE_GAME TEXT,
                QUOTE_REMOVED BOOLEAN DEFAULT 0,
                QUOTE_REMOVED_BY TEXT
            )
        """
    else:
        create_table_sql = """
            CREATE TABLE quotes (
                ID INT PRIMARY KEY AUTO_INCREMENT,
                QUOTE_TEXT TEXT,
                QUOTEE VARCHAR(100),
                QUOTE_CHANNEL VARCHAR(100),
                QUOTE_DATETIME DATETIME DEFAULT CURRENT_TIMESTAMP,
                QUOTE_GAME VARCHAR(200),
                QUOTE_REMOVED BOOLEAN DEFAULT FALSE,
                QUOTE_REMOVED_BY VARCHAR(100)
            )
        """

    result = run_db_operation(db_conn, "write", create_table_sql, log_func=log_func)
    if result is None:
        # run_db_operation returns None on error (already logged there)
        error_msg = "Failed to create 'quotes' table!"
        log_func(error_msg, "ERROR")
        raise RuntimeError(error_msg)
    log_func("Successfully created table 'quotes'.")
#######################
# Ensure 'users' table
#######################
def ensure_users_table(db_conn, log_func):
    """
    Make sure the 'users' table exists, creating it when it is missing.

    The 'users' table links one person's identities across platforms:
      - UUID: primary key, the universal ID for the user
      - discord_user_id, discord_username, discord_user_display_name
      - twitch_user_id, twitch_username, twitch_user_display_name
      - datetime_linked: defaults to the row-creation timestamp
      - user_is_banned, user_is_bot: boolean flags (default false)

    :param db_conn: an open connection (sqlite3.Connection or MariaDB connection)
    :param log_func: logging function (message, level)
    :raises RuntimeError: if the CREATE TABLE statement fails
    """
    sqlite_backend = "sqlite3" in str(type(db_conn)).lower()

    # Pick both backend-specific statements up front.
    if sqlite_backend:
        exists_query = """
            SELECT name
            FROM sqlite_master
            WHERE type='table'
            AND name='users'
        """
        ddl = """
            CREATE TABLE users (
                UUID TEXT PRIMARY KEY,
                discord_user_id TEXT,
                discord_username TEXT,
                discord_user_display_name TEXT,
                twitch_user_id TEXT,
                twitch_username TEXT,
                twitch_user_display_name TEXT,
                datetime_linked TEXT DEFAULT CURRENT_TIMESTAMP,
                user_is_banned BOOLEAN DEFAULT 0,
                user_is_bot BOOLEAN DEFAULT 0
            )
        """
    else:
        exists_query = """
            SELECT table_name
            FROM information_schema.tables
            WHERE table_name = 'users'
            AND table_schema = DATABASE()
        """
        ddl = """
            CREATE TABLE users (
                UUID VARCHAR(36) PRIMARY KEY,
                discord_user_id VARCHAR(100),
                discord_username VARCHAR(100),
                discord_user_display_name VARCHAR(100),
                twitch_user_id VARCHAR(100),
                twitch_username VARCHAR(100),
                twitch_user_display_name VARCHAR(100),
                datetime_linked DATETIME DEFAULT CURRENT_TIMESTAMP,
                user_is_banned BOOLEAN DEFAULT FALSE,
                user_is_bot BOOLEAN DEFAULT FALSE
            )
        """

    found = run_db_operation(db_conn, "read", exists_query, log_func=log_func)
    if found and found[0] and found[0][0]:
        log_func("Table 'users' already exists, skipping creation.", "DEBUG")
        return

    log_func("Table 'users' does not exist; creating now...")
    # run_db_operation signals failure by returning None.
    if run_db_operation(db_conn, "write", ddl, log_func=log_func) is None:
        error_msg = "Failed to create 'users' table!"
        log_func(error_msg, "ERROR")
        raise RuntimeError(error_msg)
    log_func("Successfully created table 'users'.")
########################
# Lookup user function
########################
def lookup_user(db_conn, log_func, identifier, identifier_type="discord_user_id"):
    """
    Fetch one row from the 'users' table, matched on a chosen identifier column.

    Accepted identifier_type values (case-insensitive):
      "uuid", "discord_user_id", "discord_username",
      "twitch_user_id", "twitch_username"

    Returns a dict keyed by column name:
      UUID, discord_user_id, discord_username, discord_user_display_name,
      twitch_user_id, twitch_username, twitch_user_display_name,
      datetime_linked (str on SQLite, possibly datetime on MariaDB),
      user_is_banned, user_is_bot (bool or int, depending on backend).

    Returns None when identifier_type is not accepted, or when no row matches
    or the query fails.
    """
    # Whitelist guards the column name, which is interpolated into the SQL.
    allowed_columns = {
        "uuid",
        "discord_user_id",
        "discord_username",
        "twitch_user_id",
        "twitch_username",
    }
    if identifier_type.lower() not in allowed_columns:
        if log_func:
            log_func(f"lookup_user error: invalid identifier_type={identifier_type}", "WARNING")
        return None

    query = f"""
        SELECT
            UUID,
            discord_user_id,
            discord_username,
            discord_user_display_name,
            twitch_user_id,
            twitch_username,
            twitch_user_display_name,
            datetime_linked,
            user_is_banned,
            user_is_bot
        FROM users
        WHERE {identifier_type} = ?
        LIMIT 1
    """
    rows = run_db_operation(db_conn, "read", query, params=(identifier,), log_func=log_func)
    if not rows:
        return None

    # Result keys, in the same order as the SELECT list above.
    field_names = (
        "UUID",
        "discord_user_id",
        "discord_username",
        "discord_user_display_name",
        "twitch_user_id",
        "twitch_username",
        "twitch_user_display_name",
        "datetime_linked",
        "user_is_banned",
        "user_is_bot",
    )
    return dict(zip(field_names, rows[0]))