forked from username/flaskpaste
Authenticated users can tag pastes with a human-readable label via the X-Display-Name header. Creating, updating, removing, and listing labels are supported. Labels are limited to 128 characters, and control characters are rejected.
362 lines
11 KiB
Python
362 lines
11 KiB
Python
"""Database connection and schema management."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import hashlib
|
|
import secrets
|
|
import sqlite3
|
|
import threading
|
|
import time
|
|
from pathlib import Path
|
|
from typing import TYPE_CHECKING, Any
|
|
|
|
from flask import current_app, g
|
|
|
|
# HASH-001: Lock for content hash deduplication to prevent race conditions
|
|
_content_hash_lock = threading.Lock()
|
|
|
|
if TYPE_CHECKING:
|
|
from flask import Flask
|
|
|
|
# Complete DDL for the application database, executed by init_db() via
# executescript(). Every statement uses IF NOT EXISTS, so running the script
# against an already-initialized database is a no-op for existing objects.
SCHEMA = """
CREATE TABLE IF NOT EXISTS pastes (
    id TEXT PRIMARY KEY,
    content BLOB NOT NULL,
    mime_type TEXT NOT NULL DEFAULT 'text/plain',
    owner TEXT,
    created_at INTEGER NOT NULL,
    last_accessed INTEGER NOT NULL,
    burn_after_read INTEGER NOT NULL DEFAULT 0,
    expires_at INTEGER,
    password_hash TEXT,
    display_name TEXT
);

CREATE INDEX IF NOT EXISTS idx_pastes_created_at ON pastes(created_at);
CREATE INDEX IF NOT EXISTS idx_pastes_owner ON pastes(owner);
CREATE INDEX IF NOT EXISTS idx_pastes_last_accessed ON pastes(last_accessed);

-- Content hash tracking for abuse prevention
CREATE TABLE IF NOT EXISTS content_hashes (
    hash TEXT PRIMARY KEY,
    first_seen INTEGER NOT NULL,
    last_seen INTEGER NOT NULL,
    count INTEGER NOT NULL DEFAULT 1
);

CREATE INDEX IF NOT EXISTS idx_content_hashes_last_seen ON content_hashes(last_seen);

-- PKI: Certificate Authority storage
CREATE TABLE IF NOT EXISTS certificate_authority (
    id TEXT PRIMARY KEY DEFAULT 'default',
    common_name TEXT NOT NULL,
    certificate_pem TEXT NOT NULL,
    private_key_encrypted BLOB NOT NULL,
    key_salt BLOB NOT NULL,
    created_at INTEGER NOT NULL,
    expires_at INTEGER NOT NULL,
    key_algorithm TEXT NOT NULL,
    owner TEXT
);

-- PKI: Issued client certificates
CREATE TABLE IF NOT EXISTS issued_certificates (
    serial TEXT PRIMARY KEY,
    ca_id TEXT NOT NULL DEFAULT 'default',
    common_name TEXT NOT NULL,
    fingerprint_sha1 TEXT NOT NULL UNIQUE,
    certificate_pem TEXT NOT NULL,
    created_at INTEGER NOT NULL,
    expires_at INTEGER NOT NULL,
    issued_to TEXT,
    status TEXT NOT NULL DEFAULT 'valid',
    revoked_at INTEGER,
    is_admin INTEGER NOT NULL DEFAULT 0,
    FOREIGN KEY(ca_id) REFERENCES certificate_authority(id) ON DELETE CASCADE
);

CREATE INDEX IF NOT EXISTS idx_certs_fingerprint ON issued_certificates(fingerprint_sha1);
CREATE INDEX IF NOT EXISTS idx_certs_status ON issued_certificates(status);
CREATE INDEX IF NOT EXISTS idx_certs_ca_id ON issued_certificates(ca_id);

-- Audit log for security events
CREATE TABLE IF NOT EXISTS audit_log (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    timestamp INTEGER NOT NULL,
    event_type TEXT NOT NULL,
    client_id TEXT,
    client_ip TEXT,
    paste_id TEXT,
    request_id TEXT,
    outcome TEXT NOT NULL,
    details TEXT
);

CREATE INDEX IF NOT EXISTS idx_audit_timestamp ON audit_log(timestamp);
CREATE INDEX IF NOT EXISTS idx_audit_event_type ON audit_log(event_type);
CREATE INDEX IF NOT EXISTS idx_audit_client_id ON audit_log(client_id);

-- URL shortener
CREATE TABLE IF NOT EXISTS short_urls (
    id TEXT PRIMARY KEY,
    target_url TEXT NOT NULL,
    url_hash TEXT NOT NULL,
    owner TEXT,
    created_at INTEGER NOT NULL,
    last_accessed INTEGER NOT NULL,
    access_count INTEGER NOT NULL DEFAULT 0,
    expires_at INTEGER
);

CREATE INDEX IF NOT EXISTS idx_short_urls_owner ON short_urls(owner);
CREATE INDEX IF NOT EXISTS idx_short_urls_created_at ON short_urls(created_at);
CREATE INDEX IF NOT EXISTS idx_short_urls_url_hash ON short_urls(url_hash);
"""
|
|
|
|
# Password hashing constants
|
|
_HASH_ITERATIONS = 600000 # OWASP 2023 recommendation for PBKDF2-SHA256
|
|
_SALT_LENGTH = 32
|
|
|
|
|
|
def hash_password(password: str) -> str | None:
    """Derive a salted PBKDF2-HMAC-SHA256 hash suitable for storage.

    A fresh random salt of ``_SALT_LENGTH`` bytes is generated per call and
    the key is derived with ``_HASH_ITERATIONS`` rounds.

    Returns:
        ``$pbkdf2-sha256$<iterations>$<salt>$<hash>`` with salt and hash
        hex-encoded, or ``None`` when ``password`` is empty/falsy.
    """
    if password:
        salt = secrets.token_bytes(_SALT_LENGTH)
        derived_key = hashlib.pbkdf2_hmac(
            "sha256",
            password.encode("utf-8"),
            salt,
            _HASH_ITERATIONS,
        )
        salt_hex = salt.hex()
        key_hex = derived_key.hex()
        return f"$pbkdf2-sha256${_HASH_ITERATIONS}${salt_hex}${key_hex}"

    return None
|
|
|
|
|
|
def verify_password(password: str, password_hash: str) -> bool:
    """Verify a password against a stored PBKDF2 hash.

    The stored value must have the layout
    ``$pbkdf2-sha256$<iterations>$<salt>$<hash>`` with hex-encoded salt and
    hash. The iteration count is taken from the stored value itself.

    Args:
        password: Candidate plaintext password.
        password_hash: Stored hash string to check against.

    Returns:
        True only when the derived key matches the stored hash. An empty
        password, an empty hash, or a hash that cannot be parsed (wrong
        layout, non-hex fields, non-positive or oversized iteration count)
        yields False rather than raising.
    """
    if not password or not password_hash:
        return False

    # Splitting "$a$b$c$d" on "$" gives a leading empty field plus four values.
    parts = password_hash.split("$")
    if len(parts) != 5 or parts[1] != "pbkdf2-sha256":
        return False

    try:
        iterations = int(parts[2])
        salt = bytes.fromhex(parts[3])
        stored_hash = bytes.fromhex(parts[4])
        # pbkdf2_hmac raises ValueError for iterations < 1 and OverflowError
        # when the count does not fit the underlying C integer type.
        dk = hashlib.pbkdf2_hmac("sha256", password.encode("utf-8"), salt, iterations)
    except (ValueError, OverflowError):
        return False

    # Constant-time comparison to avoid leaking how many bytes matched.
    return secrets.compare_digest(dk, stored_hash)
|
|
|
|
|
|
# Hold reference for in-memory shared cache databases
|
|
_memory_db_holder = None
|
|
|
|
|
|
def _get_connection_string(db_path: str | Path) -> tuple[str, dict[str, Any]]:
|
|
"""Get connection string and kwargs for sqlite3.connect."""
|
|
if isinstance(db_path, Path):
|
|
db_path.parent.mkdir(parents=True, exist_ok=True)
|
|
return str(db_path), {}
|
|
if db_path == ":memory:":
|
|
return "file::memory:?cache=shared", {"uri": True}
|
|
return db_path, {}
|
|
|
|
|
|
def _is_file_database(db_path: str | Path) -> bool:
|
|
"""Check if database path refers to a file (not in-memory)."""
|
|
if isinstance(db_path, Path):
|
|
return True
|
|
if isinstance(db_path, str) and db_path not in (":memory:", ""):
|
|
return not db_path.startswith("file::memory:")
|
|
return False
|
|
|
|
|
|
def get_db() -> sqlite3.Connection:
    """Return the SQLite connection bound to the current application context.

    The first call in a context opens a connection to the configured
    ``DATABASE``, stores it on ``flask.g`` as ``db`` and configures it
    (``sqlite3.Row`` rows, foreign keys on). Later calls in the same
    context return the stored connection.
    """
    if "db" in g:
        cached: sqlite3.Connection = g.db
        return cached

    location = current_app.config["DATABASE"]
    target, connect_kwargs = _get_connection_string(location)

    # Publish the connection on g before configuring it, so it is already
    # registered there if one of the statements below raises.
    connection: sqlite3.Connection = sqlite3.connect(target, **connect_kwargs)
    g.db = connection
    connection.row_factory = sqlite3.Row
    connection.execute("PRAGMA foreign_keys = ON")

    if _is_file_database(location):
        # WAL journal mode is selected in init_db(); these settings are
        # applied per connection.
        connection.execute("PRAGMA busy_timeout = 5000")
        connection.execute("PRAGMA synchronous = NORMAL")

    return connection
|
|
|
|
|
|
def close_db(exception: BaseException | None = None) -> None:
    """Close and forget the connection stored on ``flask.g``, if any.

    Args:
        exception: Accepted so the function can be used as a teardown
            callback; it is not inspected.
    """
    connection = g.pop("db", None)
    if connection is None:
        return
    connection.close()
|
|
|
|
|
|
def init_db() -> None:
    """Create the database schema for the configured ``DATABASE``.

    For ``":memory:"`` a dedicated connection is opened, stored in the
    module-level ``_memory_db_holder`` and used to run ``SCHEMA``. For any
    other location the context connection from ``get_db()`` is used, and
    file databases are switched to WAL journal mode first.

    NOTE(review): SCHEMA relies on ``CREATE TABLE IF NOT EXISTS``, which
    does not add new columns to tables that already exist — confirm that
    upgrades of existing databases are handled elsewhere.
    """
    global _memory_db_holder

    location = current_app.config["DATABASE"]
    target, connect_kwargs = _get_connection_string(location)

    if location != ":memory:":
        connection = get_db()
        if _is_file_database(location):
            # Journal mode is stored in the database file, so setting it
            # once here is enough.
            connection.execute("PRAGMA journal_mode = WAL")
        connection.executescript(SCHEMA)
        connection.commit()
        return

    # In-memory database: hold a connection at module level so the shared
    # database stays alive between requests.
    _memory_db_holder = sqlite3.connect(target, **connect_kwargs)
    holder = _memory_db_holder
    holder.executescript(SCHEMA)
    holder.commit()
|
|
|
|
|
|
def cleanup_expired_pastes() -> int:
    """Remove every paste whose lifetime has ended.

    A paste is removed when either:
    - it has an ``expires_at`` timestamp that lies in the past, or
    - it has no ``expires_at`` and ``last_accessed`` is older than
      ``PASTE_EXPIRY_SECONDS`` ago.

    Returns:
        Number of rows deleted.
    """
    ttl = current_app.config["PASTE_EXPIRY_SECONDS"]
    current_ts = int(time.time())
    idle_cutoff = current_ts - ttl

    conn = get_db()
    result = conn.execute(
        """DELETE FROM pastes WHERE
           (expires_at IS NOT NULL AND expires_at < ?)
           OR (expires_at IS NULL AND last_accessed < ?)""",
        (current_ts, idle_cutoff),
    )
    conn.commit()

    return result.rowcount
|
|
|
|
|
|
def cleanup_expired_hashes() -> int:
    """Purge content hash rows last seen before the dedup window.

    Rows whose ``last_seen`` is older than ``CONTENT_DEDUP_WINDOW`` seconds
    ago are deleted.

    Returns:
        Number of rows deleted.
    """
    dedup_window = current_app.config["CONTENT_DEDUP_WINDOW"]
    oldest_allowed = int(time.time()) - dedup_window

    conn = get_db()
    result = conn.execute(
        "DELETE FROM content_hashes WHERE last_seen < ?",
        (oldest_allowed,),
    )
    conn.commit()

    return result.rowcount
|
|
|
|
|
|
def cleanup_expired_short_urls() -> int:
    """Remove every short URL whose lifetime has ended.

    A short URL is removed when either:
    - it has an ``expires_at`` timestamp that lies in the past, or
    - it has no ``expires_at`` and ``last_accessed`` is older than the
      default window, which is read from ``PASTE_EXPIRY_SECONDS``.

    Returns:
        Number of rows deleted.
    """
    ttl = current_app.config["PASTE_EXPIRY_SECONDS"]
    current_ts = int(time.time())
    idle_cutoff = current_ts - ttl

    conn = get_db()
    result = conn.execute(
        """DELETE FROM short_urls WHERE
           (expires_at IS NOT NULL AND expires_at < ?)
           OR (expires_at IS NULL AND last_accessed < ?)""",
        (current_ts, idle_cutoff),
    )
    conn.commit()

    return result.rowcount
|
|
|
|
|
|
def check_content_hash(content_hash: str) -> tuple[bool, int]:
    """Record a submission of ``content_hash`` and enforce the dedup limit.

    Args:
        content_hash: SHA256 hex digest of content

    Returns:
        Tuple of (is_allowed, current_count).
        - Unknown hash: a row is inserted and ``(True, 1)`` is returned.
        - Known hash last seen before the ``CONTENT_DEDUP_WINDOW``: the row
          is reset and ``(True, 1)`` is returned.
        - Known hash inside the window: the counter is incremented and
          ``(True, new_count)`` is returned, unless the new count would
          exceed ``CONTENT_DEDUP_MAX``; then nothing is written and
          ``(False, stored_count)`` is returned.
    """
    config = current_app.config
    dedup_window = config["CONTENT_DEDUP_WINDOW"]
    threshold = config["CONTENT_DEDUP_MAX"]
    timestamp = int(time.time())
    window_start = timestamp - dedup_window

    conn = get_db()

    # HASH-001: hold the lock across the SELECT and the write that follows
    # so threads in this process cannot interleave between them.
    with _content_hash_lock:
        existing = conn.execute(
            "SELECT count, last_seen FROM content_hashes WHERE hash = ?", (content_hash,)
        ).fetchone()

        if existing is None:
            # Never seen before: start a new counter.
            statement = (
                "INSERT INTO content_hashes (hash, first_seen, last_seen, count) "
                "VALUES (?, ?, ?, 1)"
            )
            params = (content_hash, timestamp, timestamp)
            seen = 1
        elif existing["last_seen"] < window_start:
            # Previous sighting fell out of the window: restart the counter.
            statement = (
                "UPDATE content_hashes SET first_seen = ?, last_seen = ?, count = 1 WHERE hash = ?"
            )
            params = (timestamp, timestamp, content_hash)
            seen = 1
        else:
            seen = existing["count"] + 1
            if seen > threshold:
                # Over the limit: leave the row untouched so the stored
                # counter does not keep growing.
                return False, existing["count"]
            statement = "UPDATE content_hashes SET last_seen = ?, count = ? WHERE hash = ?"
            params = (timestamp, seen, content_hash)

        conn.execute(statement, params)
        conn.commit()

    return True, seen
|
|
|
|
|
|
def init_app(app: Flask) -> None:
    """Hook the database helpers into ``app``.

    Registers ``close_db`` as an app-context teardown callback, then runs
    ``init_db()`` inside an application context of ``app``.
    """
    app.teardown_appcontext(close_db)

    # init_db() reads current_app, so it must run inside an app context.
    app_context = app.app_context()
    with app_context:
        init_db()
|