"""Database connection and schema management.""" from __future__ import annotations import hashlib import secrets import sqlite3 import time from pathlib import Path from typing import TYPE_CHECKING, Any from flask import current_app, g if TYPE_CHECKING: from flask import Flask SCHEMA = """ CREATE TABLE IF NOT EXISTS pastes ( id TEXT PRIMARY KEY, content BLOB NOT NULL, mime_type TEXT NOT NULL DEFAULT 'text/plain', owner TEXT, created_at INTEGER NOT NULL, last_accessed INTEGER NOT NULL, burn_after_read INTEGER NOT NULL DEFAULT 0, expires_at INTEGER, password_hash TEXT ); CREATE INDEX IF NOT EXISTS idx_pastes_created_at ON pastes(created_at); CREATE INDEX IF NOT EXISTS idx_pastes_owner ON pastes(owner); CREATE INDEX IF NOT EXISTS idx_pastes_last_accessed ON pastes(last_accessed); -- Content hash tracking for abuse prevention CREATE TABLE IF NOT EXISTS content_hashes ( hash TEXT PRIMARY KEY, first_seen INTEGER NOT NULL, last_seen INTEGER NOT NULL, count INTEGER NOT NULL DEFAULT 1 ); CREATE INDEX IF NOT EXISTS idx_content_hashes_last_seen ON content_hashes(last_seen); -- PKI: Certificate Authority storage CREATE TABLE IF NOT EXISTS certificate_authority ( id TEXT PRIMARY KEY DEFAULT 'default', common_name TEXT NOT NULL, certificate_pem TEXT NOT NULL, private_key_encrypted BLOB NOT NULL, key_salt BLOB NOT NULL, created_at INTEGER NOT NULL, expires_at INTEGER NOT NULL, key_algorithm TEXT NOT NULL, owner TEXT ); -- PKI: Issued client certificates CREATE TABLE IF NOT EXISTS issued_certificates ( serial TEXT PRIMARY KEY, ca_id TEXT NOT NULL DEFAULT 'default', common_name TEXT NOT NULL, fingerprint_sha1 TEXT NOT NULL UNIQUE, certificate_pem TEXT NOT NULL, created_at INTEGER NOT NULL, expires_at INTEGER NOT NULL, issued_to TEXT, status TEXT NOT NULL DEFAULT 'valid', revoked_at INTEGER, is_admin INTEGER NOT NULL DEFAULT 0, FOREIGN KEY(ca_id) REFERENCES certificate_authority(id) ON DELETE CASCADE ); CREATE INDEX IF NOT EXISTS idx_certs_fingerprint ON issued_certificates(fingerprint_sha1); CREATE INDEX IF NOT EXISTS idx_certs_status ON issued_certificates(status); CREATE INDEX IF NOT EXISTS idx_certs_ca_id ON issued_certificates(ca_id); -- Audit log for security events CREATE TABLE IF NOT EXISTS audit_log ( id INTEGER PRIMARY KEY AUTOINCREMENT, timestamp INTEGER NOT NULL, event_type TEXT NOT NULL, client_id TEXT, client_ip TEXT, paste_id TEXT, request_id TEXT, outcome TEXT NOT NULL, details TEXT ); CREATE INDEX IF NOT EXISTS idx_audit_timestamp ON audit_log(timestamp); CREATE INDEX IF NOT EXISTS idx_audit_event_type ON audit_log(event_type); CREATE INDEX IF NOT EXISTS idx_audit_client_id ON audit_log(client_id); """ # Password hashing constants _HASH_ITERATIONS = 600000 # OWASP 2023 recommendation for PBKDF2-SHA256 _SALT_LENGTH = 32 def hash_password(password: str) -> str | None: """Hash password using PBKDF2-HMAC-SHA256. Returns format: $pbkdf2-sha256$iterations$salt$hash All values are hex-encoded. """ if not password: return None salt = secrets.token_bytes(_SALT_LENGTH) dk = hashlib.pbkdf2_hmac("sha256", password.encode("utf-8"), salt, _HASH_ITERATIONS) return f"$pbkdf2-sha256${_HASH_ITERATIONS}${salt.hex()}${dk.hex()}" def verify_password(password: str, password_hash: str) -> bool: """Verify password against stored hash. Uses constant-time comparison to prevent timing attacks. """ if not password or not password_hash: return False try: # Parse hash format: $pbkdf2-sha256$iterations$salt$hash parts = password_hash.split("$") if len(parts) != 5 or parts[1] != "pbkdf2-sha256": return False iterations = int(parts[2]) salt = bytes.fromhex(parts[3]) stored_hash = bytes.fromhex(parts[4]) # Compute hash of provided password dk = hashlib.pbkdf2_hmac("sha256", password.encode("utf-8"), salt, iterations) # Constant-time comparison return secrets.compare_digest(dk, stored_hash) except (ValueError, IndexError): return False # Hold reference for in-memory shared cache databases _memory_db_holder = None def _get_connection_string(db_path: str | Path) -> tuple[str, dict[str, Any]]: """Get connection string and kwargs for sqlite3.connect.""" if isinstance(db_path, Path): db_path.parent.mkdir(parents=True, exist_ok=True) return str(db_path), {} if db_path == ":memory:": return "file::memory:?cache=shared", {"uri": True} return db_path, {} def _is_file_database(db_path: str | Path) -> bool: """Check if database path refers to a file (not in-memory).""" if isinstance(db_path, Path): return True if isinstance(db_path, str) and db_path not in (":memory:", ""): return not db_path.startswith("file::memory:") return False def get_db() -> sqlite3.Connection: """Get database connection for current request context.""" if "db" not in g: db_path = current_app.config["DATABASE"] conn_str, kwargs = _get_connection_string(db_path) g.db = sqlite3.connect(conn_str, **kwargs) g.db.row_factory = sqlite3.Row g.db.execute("PRAGMA foreign_keys = ON") if _is_file_database(db_path): # WAL mode set in init_db; these optimize per-connection behavior g.db.execute("PRAGMA busy_timeout = 5000") g.db.execute("PRAGMA synchronous = NORMAL") conn: sqlite3.Connection = g.db return conn def close_db(exception: BaseException | None = None) -> None: """Close database connection at end of request.""" db = g.pop("db", None) if db is not None: db.close() def init_db() -> None: """Initialize database schema.""" global _memory_db_holder db_path = current_app.config["DATABASE"] conn_str, kwargs = _get_connection_string(db_path) # For in-memory databases, keep a connection alive if db_path == ":memory:": _memory_db_holder = sqlite3.connect(conn_str, **kwargs) _memory_db_holder.executescript(SCHEMA) _memory_db_holder.commit() else: db = get_db() # Enable WAL mode for file databases (persists with database) if _is_file_database(db_path): db.execute("PRAGMA journal_mode = WAL") db.executescript(SCHEMA) db.commit() def cleanup_expired_pastes() -> int: """Delete pastes that have expired. Pastes expire based on: - Custom expires_at timestamp if set - Default expiry from last_accessed if expires_at is NULL Returns number of deleted pastes. """ expiry_seconds = current_app.config["PASTE_EXPIRY_SECONDS"] now = int(time.time()) default_cutoff = now - expiry_seconds db = get_db() # Delete pastes with custom expiry that have passed, # OR pastes without custom expiry that exceed default window cursor = db.execute( """DELETE FROM pastes WHERE (expires_at IS NOT NULL AND expires_at < ?) OR (expires_at IS NULL AND last_accessed < ?)""", (now, default_cutoff), ) db.commit() return cursor.rowcount def cleanup_expired_hashes() -> int: """Delete content hashes outside the dedup window. Returns number of deleted hashes. """ window = current_app.config["CONTENT_DEDUP_WINDOW"] cutoff = int(time.time()) - window db = get_db() cursor = db.execute("DELETE FROM content_hashes WHERE last_seen < ?", (cutoff,)) db.commit() return cursor.rowcount def check_content_hash(content_hash: str) -> tuple[bool, int]: """Check if content hash exceeds dedup threshold. Args: content_hash: SHA256 hex digest of content Returns: Tuple of (is_allowed, current_count) is_allowed is False if threshold exceeded within window """ window = current_app.config["CONTENT_DEDUP_WINDOW"] max_count = current_app.config["CONTENT_DEDUP_MAX"] now = int(time.time()) cutoff = now - window db = get_db() # Check existing hash record row = db.execute( "SELECT count, last_seen FROM content_hashes WHERE hash = ?", (content_hash,) ).fetchone() if row is None: # First time seeing this content db.execute( "INSERT INTO content_hashes (hash, first_seen, last_seen, count) VALUES (?, ?, ?, 1)", (content_hash, now, now), ) db.commit() return True, 1 if row["last_seen"] < cutoff: # Outside window, reset counter db.execute( "UPDATE content_hashes SET first_seen = ?, last_seen = ?, count = 1 WHERE hash = ?", (now, now, content_hash), ) db.commit() return True, 1 # Within window, check threshold current_count = row["count"] + 1 if current_count > max_count: # Exceeded threshold, don't increment (prevent counter overflow) return False, row["count"] # Update counter db.execute( "UPDATE content_hashes SET last_seen = ?, count = ? WHERE hash = ?", (now, current_count, content_hash), ) db.commit() return True, current_count def init_app(app: Flask) -> None: """Register database functions with Flask app.""" app.teardown_appcontext(close_db) with app.app_context(): init_db()