Files
flaskpaste/app/database.py
Username 9c5b1d9804
Some checks failed
CI / test (push) Has been cancelled
enable sqlite wal mode for file databases
2025-12-20 03:44:38 +01:00

194 lines
5.6 KiB
Python

"""Database connection and schema management."""
import sqlite3
import time
from pathlib import Path
from flask import current_app, g
# DDL executed by init_db() through executescript(). Every statement uses
# IF NOT EXISTS, so running it against an already-initialized database is a
# no-op.
#   pastes         - stored paste bodies; last_accessed is the column
#                    cleanup_expired_pastes() compares against the expiry cutoff.
#   content_hashes - per-content submission counters read/written by
#                    check_content_hash() and pruned by cleanup_expired_hashes().
SCHEMA = """
CREATE TABLE IF NOT EXISTS pastes (
id TEXT PRIMARY KEY,
content BLOB NOT NULL,
mime_type TEXT NOT NULL DEFAULT 'text/plain',
owner TEXT,
created_at INTEGER NOT NULL,
last_accessed INTEGER NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_pastes_created_at ON pastes(created_at);
CREATE INDEX IF NOT EXISTS idx_pastes_owner ON pastes(owner);
CREATE INDEX IF NOT EXISTS idx_pastes_last_accessed ON pastes(last_accessed);
-- Content hash tracking for abuse prevention
CREATE TABLE IF NOT EXISTS content_hashes (
hash TEXT PRIMARY KEY,
first_seen INTEGER NOT NULL,
last_seen INTEGER NOT NULL,
count INTEGER NOT NULL DEFAULT 1
);
CREATE INDEX IF NOT EXISTS idx_content_hashes_last_seen ON content_hashes(last_seen);
"""
# Keep-alive connection for the ":memory:" configuration, assigned by init_db().
# The per-request connections from get_db() are closed at teardown; SQLite
# drops a shared-cache in-memory database once its last connection closes, so
# this module-level reference keeps the schema and data alive between requests.
_memory_db_holder: "sqlite3.Connection | None" = None
def _get_connection_string(db_path) -> tuple[str, dict]:
"""Get connection string and kwargs for sqlite3.connect."""
if isinstance(db_path, Path):
db_path.parent.mkdir(parents=True, exist_ok=True)
return str(db_path), {}
if db_path == ":memory:":
return "file::memory:?cache=shared", {"uri": True}
return db_path, {}
def _is_file_database(db_path) -> bool:
"""Check if database path refers to a file (not in-memory)."""
if isinstance(db_path, Path):
return True
if isinstance(db_path, str) and db_path not in (":memory:", ""):
return not db_path.startswith("file::memory:")
return False
def get_db() -> sqlite3.Connection:
    """Return the connection cached on ``flask.g``, opening it on first use.

    New connections use ``sqlite3.Row`` rows and enable foreign keys. For
    file databases they also set a 5000 ms busy timeout and
    ``synchronous = NORMAL``. WAL journal mode is switched on once by
    init_db() rather than here. The connection is closed by close_db().
    """
    if "db" in g:
        return g.db

    db_path = current_app.config["DATABASE"]
    target, connect_kwargs = _get_connection_string(db_path)
    # Store on g immediately so teardown can still close it if a PRAGMA fails.
    conn = g.db = sqlite3.connect(target, **connect_kwargs)
    conn.row_factory = sqlite3.Row
    conn.execute("PRAGMA foreign_keys = ON")
    if _is_file_database(db_path):
        # Per-connection tuning; these settings do not persist in the file.
        for pragma in ("PRAGMA busy_timeout = 5000", "PRAGMA synchronous = NORMAL"):
            conn.execute(pragma)
    return conn
def close_db(exception=None) -> None:
    """Close and forget the connection stored on ``flask.g``, if any.

    Registered as an app-context teardown handler in init_app(). The
    *exception* argument is accepted for that callback signature and is not
    used.
    """
    conn = g.pop("db", None)
    if conn is None:
        return
    conn.close()
def init_db() -> None:
    """Create the tables and indexes in SCHEMA (safe to re-run).

    For ``DATABASE == ":memory:"`` a dedicated connection is opened and kept in
    the module-level ``_memory_db_holder`` so the shared in-memory database
    survives between requests. Otherwise the request connection from get_db()
    is used, and file databases are switched to WAL journal mode first.
    """
    global _memory_db_holder
    db_path = current_app.config["DATABASE"]
    # Computed up front for its side effect: a Path's parent dir gets created.
    target, connect_kwargs = _get_connection_string(db_path)

    if db_path != ":memory:":
        conn = get_db()
        if _is_file_database(db_path):
            # journal_mode=WAL is stored in the database file itself.
            conn.execute("PRAGMA journal_mode = WAL")
    else:
        conn = sqlite3.connect(target, **connect_kwargs)
        _memory_db_holder = conn

    conn.executescript(SCHEMA)
    conn.commit()
def cleanup_expired_pastes() -> int:
    """Delete every paste not accessed within ``PASTE_EXPIRY_SECONDS``.

    A paste is removed when its ``last_accessed`` timestamp is strictly older
    than ``now - PASTE_EXPIRY_SECONDS``. The deletion is committed.

    Returns:
        The number of pastes deleted.
    """
    max_age = current_app.config["PASTE_EXPIRY_SECONDS"]
    oldest_allowed = int(time.time()) - max_age
    conn = get_db()
    deleted = conn.execute(
        "DELETE FROM pastes WHERE last_accessed < ?", (oldest_allowed,)
    ).rowcount
    conn.commit()
    return deleted
def cleanup_expired_hashes() -> int:
    """Delete content-hash records last seen outside the dedup window.

    A record is removed when its ``last_seen`` timestamp is strictly older
    than ``now - CONTENT_DEDUP_WINDOW``. The deletion is committed.

    Returns:
        The number of hash records deleted.
    """
    window = current_app.config["CONTENT_DEDUP_WINDOW"]
    oldest_allowed = int(time.time()) - window
    conn = get_db()
    deleted = conn.execute(
        "DELETE FROM content_hashes WHERE last_seen < ?", (oldest_allowed,)
    ).rowcount
    conn.commit()
    return deleted
def check_content_hash(content_hash: str) -> tuple[bool, int]:
    """Check if content hash exceeds dedup threshold, recording the submission.

    Within ``CONTENT_DEDUP_WINDOW`` seconds of the last accepted submission the
    same content may be accepted at most ``CONTENT_DEDUP_MAX`` times. Each
    accepted submission increments the stored counter and refreshes
    ``last_seen``. Once ``last_seen`` falls outside the window, the counter
    restarts at 1.

    Args:
        content_hash: SHA256 hex digest of content

    Returns:
        Tuple of (is_allowed, current_count).
        is_allowed is False if the threshold was exceeded within the window;
        the stored record is then left untouched and current_count is the
        stored count.
    """
    window = current_app.config["CONTENT_DEDUP_WINDOW"]
    max_count = current_app.config["CONTENT_DEDUP_MAX"]
    now = int(time.time())
    cutoff = now - window
    db = get_db()

    # The SELECT followed by INSERT/UPDATE is a read-modify-write. Run in
    # autocommit mode, two concurrent requests for the same new content both
    # see "no row" and the second INSERT fails with IntegrityError. Concurrent
    # increments can also overwrite each other and undercount.
    # BEGIN IMMEDIATE takes the write lock before the SELECT. If the caller
    # already has a transaction open, SQLite forbids nested BEGIN, so we run
    # inside the caller's transaction instead.
    own_txn = not db.in_transaction
    if own_txn:
        db.execute("BEGIN IMMEDIATE")
    try:
        is_allowed, count, wrote = _record_content_hash(
            db, content_hash, now, cutoff, max_count
        )
    except Exception:
        if own_txn:
            db.rollback()
        raise
    # Commit after a write, as before. Also commit our own read-only
    # transaction to release the write lock taken by BEGIN IMMEDIATE.
    if wrote or own_txn:
        db.commit()
    return is_allowed, count


def _record_content_hash(
    db: sqlite3.Connection, content_hash: str, now: int, cutoff: int, max_count: int
) -> tuple[bool, int, bool]:
    """Apply one dedup check to ``content_hashes``; the caller owns the transaction.

    Returns (is_allowed, current_count, wrote), where *wrote* is True when a
    row was inserted or updated. Relies on ``db.row_factory`` yielding rows
    indexable by column name (get_db() uses ``sqlite3.Row``).
    """
    row = db.execute(
        "SELECT count, last_seen FROM content_hashes WHERE hash = ?",
        (content_hash,)
    ).fetchone()
    if row is None:
        # First time seeing this content
        db.execute(
            "INSERT INTO content_hashes (hash, first_seen, last_seen, count) VALUES (?, ?, ?, 1)",
            (content_hash, now, now)
        )
        return True, 1, True
    if row["last_seen"] < cutoff:
        # Outside window, reset counter
        db.execute(
            "UPDATE content_hashes SET first_seen = ?, last_seen = ?, count = 1 WHERE hash = ?",
            (now, now, content_hash)
        )
        return True, 1, True
    # Within window, check threshold
    current_count = row["count"] + 1
    if current_count > max_count:
        # Exceeded threshold, don't increment (prevent counter overflow)
        return False, row["count"], False
    db.execute(
        "UPDATE content_hashes SET last_seen = ?, count = ? WHERE hash = ?",
        (now, current_count, content_hash)
    )
    return True, current_count, True
def init_app(app) -> None:
    """Hook this module into a Flask *app*.

    Registers close_db() to run when each application context tears down.
    Then creates the schema inside a temporary app context, because init_db()
    reads ``current_app``.
    """
    app.teardown_appcontext(close_db)
    setup_ctx = app.app_context()
    with setup_ctx:
        init_db()