"""Database connection and schema management."""
import sqlite3
import time
from pathlib import Path

from flask import current_app, g

SCHEMA = """
CREATE TABLE IF NOT EXISTS pastes (
    id TEXT PRIMARY KEY,
    content BLOB NOT NULL,
    mime_type TEXT NOT NULL DEFAULT 'text/plain',
    owner TEXT,
    created_at INTEGER NOT NULL,
    last_accessed INTEGER NOT NULL
);

CREATE INDEX IF NOT EXISTS idx_pastes_created_at ON pastes(created_at);
CREATE INDEX IF NOT EXISTS idx_pastes_owner ON pastes(owner);
CREATE INDEX IF NOT EXISTS idx_pastes_last_accessed ON pastes(last_accessed);

-- Content hash tracking for abuse prevention
CREATE TABLE IF NOT EXISTS content_hashes (
    hash TEXT PRIMARY KEY,
    first_seen INTEGER NOT NULL,
    last_seen INTEGER NOT NULL,
    count INTEGER NOT NULL DEFAULT 1
);

CREATE INDEX IF NOT EXISTS idx_content_hashes_last_seen ON content_hashes(last_seen);
"""

# Hold reference for in-memory shared cache databases: init_db() stores a
# connection here so the shared-cache ":memory:" database stays alive while
# per-request connections (opened by get_db) come and go.
_memory_db_holder = None

# Atomic "insert or bump" for a content hash, used by check_content_hash().
# Inside the DO UPDATE clause, unqualified column names refer to the row that
# is already stored, and SQLite evaluates every SET expression against those
# pre-update values, so the CASE tests below all see the old last_seen.
# The trailing WHERE turns the update into a no-op (0 changed rows) once the
# threshold is reached inside the window.
# Requires SQLite >= 3.24.0 (UPSERT support).
_CONTENT_HASH_UPSERT = (
    "INSERT INTO content_hashes (hash, first_seen, last_seen, count) "
    "VALUES (:hash, :now, :now, 1) "
    "ON CONFLICT(hash) DO UPDATE SET "
    "first_seen = CASE WHEN last_seen < :cutoff THEN :now ELSE first_seen END, "
    "count = CASE WHEN last_seen < :cutoff THEN 1 ELSE count + 1 END, "
    "last_seen = :now "
    "WHERE last_seen < :cutoff OR count < :max_count"
)


def _get_connection_string(db_path) -> tuple[str, dict]:
    """Get connection string and kwargs for sqlite3.connect.

    - ``Path``: its parent directory is created if missing; the path is
      returned as a string with no extra kwargs.
    - ``":memory:"``: mapped to a shared-cache in-memory URI (with
      ``uri=True``) so every connection in the process sees the same database.
    - Any other value is returned unchanged with no extra kwargs.
    """
    if isinstance(db_path, Path):
        db_path.parent.mkdir(parents=True, exist_ok=True)
        return str(db_path), {}
    if db_path == ":memory:":
        return "file::memory:?cache=shared", {"uri": True}
    return db_path, {}


def _is_file_database(db_path) -> bool:
    """Check if database path refers to a file (not in-memory).

    ``Path`` objects always count as files. Strings count as files unless they
    are ``":memory:"``, empty, or a ``file::memory:`` URI.
    """
    if isinstance(db_path, Path):
        return True
    if isinstance(db_path, str) and db_path not in (":memory:", ""):
        return not db_path.startswith("file::memory:")
    return False


def get_db() -> sqlite3.Connection:
    """Get database connection for current request context.

    The connection is opened lazily on first call, cached on ``flask.g`` and
    reused for the rest of the context. Rows are returned as ``sqlite3.Row``
    so columns can be accessed by name.
    """
    if "db" not in g:
        db_path = current_app.config["DATABASE"]
        conn_str, kwargs = _get_connection_string(db_path)
        g.db = sqlite3.connect(conn_str, **kwargs)
        g.db.row_factory = sqlite3.Row
        g.db.execute("PRAGMA foreign_keys = ON")
        if _is_file_database(db_path):
            # WAL mode set in init_db; these optimize per-connection behavior
            # (wait up to 5 s on a locked database instead of failing at once,
            # and use NORMAL fsync behaviour, which is commonly paired with WAL).
            g.db.execute("PRAGMA busy_timeout = 5000")
            g.db.execute("PRAGMA synchronous = NORMAL")
    return g.db


def close_db(exception=None) -> None:
    """Close database connection at end of request.

    Registered with ``teardown_appcontext`` by :func:`init_app`; safe to call
    when no connection was opened.
    """
    db = g.pop("db", None)
    if db is not None:
        db.close()


def init_db() -> None:
    """Initialize database schema.

    For ``":memory:"`` a dedicated connection is opened and kept in the
    module-level ``_memory_db_holder`` so the shared in-memory database
    outlives individual request connections. Otherwise the request connection
    from :func:`get_db` is used, and file databases are switched to WAL
    journal mode before the schema is applied.
    """
    global _memory_db_holder
    db_path = current_app.config["DATABASE"]
    conn_str, kwargs = _get_connection_string(db_path)

    # For in-memory databases, keep a connection alive
    if db_path == ":memory:":
        _memory_db_holder = sqlite3.connect(conn_str, **kwargs)
        _memory_db_holder.executescript(SCHEMA)
        _memory_db_holder.commit()
    else:
        db = get_db()
        # Enable WAL mode for file databases (persists with database)
        if _is_file_database(db_path):
            db.execute("PRAGMA journal_mode = WAL")
        db.executescript(SCHEMA)
        db.commit()


def cleanup_expired_pastes() -> int:
    """Delete pastes that haven't been accessed within expiry period.

    A paste is deleted when its ``last_accessed`` timestamp is older than
    ``now - PASTE_EXPIRY_SECONDS``.

    Returns number of deleted pastes.
    """
    expiry_seconds = current_app.config["PASTE_EXPIRY_SECONDS"]
    cutoff = int(time.time()) - expiry_seconds
    db = get_db()
    cursor = db.execute("DELETE FROM pastes WHERE last_accessed < ?", (cutoff,))
    db.commit()
    return cursor.rowcount


def cleanup_expired_hashes() -> int:
    """Delete content hashes outside the dedup window.

    A hash record is deleted when its ``last_seen`` timestamp is older than
    ``now - CONTENT_DEDUP_WINDOW``.

    Returns number of deleted hashes.
    """
    window = current_app.config["CONTENT_DEDUP_WINDOW"]
    cutoff = int(time.time()) - window
    db = get_db()
    cursor = db.execute("DELETE FROM content_hashes WHERE last_seen < ?", (cutoff,))
    db.commit()
    return cursor.rowcount


def check_content_hash(content_hash: str) -> tuple[bool, int]:
    """Check if content hash exceeds dedup threshold, recording this sighting.

    Behaviour per hash:
      - never seen before: a record with count 1 is created (allowed);
      - last seen outside the window: the record is reset to count 1
        (allowed);
      - seen within the window and ``count < CONTENT_DEDUP_MAX``: the count is
        incremented and ``last_seen`` refreshed (allowed);
      - otherwise the record is left untouched (rejected), so the counter
        cannot grow without bound.

    The check and the update happen in one SQL statement, so concurrent
    requests for the same content can neither collide on the primary key nor
    lose increments and slip past the threshold.

    Args:
        content_hash: SHA256 hex digest of content

    Returns:
        Tuple of (is_allowed, current_count)
        is_allowed is False if threshold exceeded within window;
        current_count is the stored count after this call.

    Raises:
        sqlite3.Error: propagated from the database; the transaction is
            rolled back before the exception escapes.
    """
    window = current_app.config["CONTENT_DEDUP_WINDOW"]
    max_count = current_app.config["CONTENT_DEDUP_MAX"]
    now = int(time.time())
    cutoff = now - window
    db = get_db()
    params = {
        "hash": content_hash,
        "now": now,
        "cutoff": cutoff,
        "max_count": max_count,
    }

    # ``with db`` commits on success and rolls back if anything raises, so the
    # write lock taken by the upsert is always released.
    with db:
        cursor = db.execute(_CONTENT_HASH_UPSERT, params)
        # 1 row changed -> inserted, reset or incremented (allowed);
        # 0 rows changed -> the DO UPDATE guard refused (threshold reached).
        is_allowed = cursor.rowcount > 0
        # Read back inside the same transaction: either the value just written
        # or, when rejected, the untouched stored count.
        current_count = db.execute(
            "SELECT count FROM content_hashes WHERE hash = ?", (content_hash,)
        ).fetchone()[0]
    return is_allowed, current_count


def init_app(app) -> None:
    """Register database functions with Flask app.

    Registers :func:`close_db` as an app-context teardown handler and creates
    the schema immediately inside a temporary application context.
    """
    app.teardown_appcontext(close_db)
    with app.app_context():
        init_db()