diff --git a/Containerfile b/Containerfile index b6ae11e..d50b7b0 100644 --- a/Containerfile +++ b/Containerfile @@ -1,7 +1,27 @@ -# FlaskPaste Container Image +# FlaskPaste Container Image (Multi-Stage Build) # Build: podman build -t flaskpaste . # Run: podman run -d -p 5000:5000 -v flaskpaste-data:/app/data flaskpaste +# Stage 1: Build dependencies +FROM python:3.11-slim AS builder + +WORKDIR /build + +# Install build dependencies +RUN apt-get update && apt-get install -y --no-install-recommends \ + gcc \ + && rm -rf /var/lib/apt/lists/* + +# Create virtual environment +RUN python -m venv /opt/venv +ENV PATH="/opt/venv/bin:$PATH" + +# Install Python dependencies +COPY requirements.txt . +RUN pip install --no-cache-dir -r requirements.txt gunicorn + + +# Stage 2: Runtime image FROM python:3.11-slim LABEL maintainer="FlaskPaste" @@ -10,19 +30,19 @@ LABEL description="Lightweight secure pastebin REST API" # Create non-root user RUN groupadd -r flaskpaste && useradd -r -g flaskpaste flaskpaste +# Copy virtual environment from builder +COPY --from=builder /opt/venv /opt/venv +ENV PATH="/opt/venv/bin:$PATH" + # Set working directory WORKDIR /app -# Install dependencies first (cache layer) -COPY requirements.txt . -RUN pip install --no-cache-dir -r requirements.txt gunicorn - -# Copy application code +# Copy only necessary application files COPY app/ ./app/ COPY wsgi.py . COPY fpaste . -# Create data directory +# Create data directory with correct ownership RUN mkdir -p /app/data && chown -R flaskpaste:flaskpaste /app # Switch to non-root user @@ -32,6 +52,7 @@ USER flaskpaste ENV FLASK_ENV=production ENV FLASKPASTE_DB=/app/data/pastes.db ENV PYTHONUNBUFFERED=1 +ENV PYTHONDONTWRITEBYTECODE=1 # Expose port EXPOSE 5000 diff --git a/app/api/routes.py b/app/api/routes.py index 24cc73b..6b308cf 100644 --- a/app/api/routes.py +++ b/app/api/routes.py @@ -1,163 +1,256 @@ -"""API route handlers.""" +"""API route handlers using modern Flask patterns.""" + +from __future__ import annotations import hashlib import hmac import json import math -import os import re import secrets import time +from typing import TYPE_CHECKING, Any -from flask import Response, current_app, request +from flask import Response, current_app, g, request +from flask.views import MethodView from app.api import bp from app.config import VERSION -from app.database import check_content_hash, get_db +from app.database import check_content_hash, get_db, hash_password, verify_password -# Valid paste ID pattern (hexadecimal only) +if TYPE_CHECKING: + from sqlite3 import Row + +# Compiled patterns for validation PASTE_ID_PATTERN = re.compile(r"^[a-f0-9]+$") - - -def _url(path: str) -> str: - """Generate URL with configured prefix for reverse proxy deployments.""" - prefix = current_app.config.get("URL_PREFIX", "") - return f"{prefix}{path}" - - -def _base_url() -> str: - """Detect base URL from request headers (for reverse proxy deployments).""" - scheme = ( - request.headers.get("X-Forwarded-Proto") - or request.headers.get("X-Scheme") - or request.scheme - ) - host = ( - request.headers.get("X-Forwarded-Host") - or request.headers.get("Host") - or request.host - ) - prefix = current_app.config.get("URL_PREFIX", "") - return f"{scheme}://{host}{prefix}" - -# Runtime-generated PoW secret (used if not configured) -_pow_secret_cache = None - -# Valid client certificate SHA1 pattern (40 hex chars) CLIENT_ID_PATTERN = re.compile(r"^[a-f0-9]{40}$") +MIME_PATTERN = re.compile(r"^[a-z0-9][a-z0-9!#$&\-^_.+]*/[a-z0-9][a-z0-9!#$&\-^_.+]*$") -# Magic bytes for common binary formats -MAGIC_SIGNATURES = { +# Magic bytes for binary format detection +MAGIC_SIGNATURES: dict[bytes, str] = { b"\x89PNG\r\n\x1a\n": "image/png", b"\xff\xd8\xff": "image/jpeg", b"GIF87a": "image/gif", b"GIF89a": "image/gif", - b"RIFF": "image/webp", # WebP (check for WEBP after RIFF) + b"RIFF": "image/webp", b"PK\x03\x04": "application/zip", b"%PDF": "application/pdf", b"\x1f\x8b": "application/gzip", } +# Generic MIME types to override with detection +GENERIC_MIME_TYPES = frozenset( + { + "application/octet-stream", + "application/x-www-form-urlencoded", + "text/plain", + } +) -def _calculate_entropy(data: bytes) -> float: - """Calculate Shannon entropy in bits per byte. - - Returns value between 0 (uniform) and 8 (perfectly random). - Encrypted/compressed data: ~7.5-8.0 - English text: ~4.0-5.0 - Binary executables: ~5.0-6.5 - """ - if not data: - return 0.0 - - # Count byte frequencies - freq = [0] * 256 - for byte in data: - freq[byte] += 1 - - # Calculate entropy - length = len(data) - entropy = 0.0 - for count in freq: - if count > 0: - p = count / length - entropy -= p * math.log2(p) - - return entropy +# Runtime PoW secret cache +_pow_secret_cache: bytes | None = None -def _get_pow_secret() -> bytes: - """Get or generate the PoW signing secret.""" +# ───────────────────────────────────────────────────────────────────────────── +# Response Helpers +# ───────────────────────────────────────────────────────────────────────────── + + +def json_response(data: dict[str, Any], status: int = 200) -> Response: + """Create JSON response with proper encoding.""" + return Response( + json.dumps(data, ensure_ascii=False), + status=status, + mimetype="application/json", + ) + + +def error_response(message: str, status: int, **extra: Any) -> Response: + """Create standardized error response.""" + data = {"error": message, **extra} + return json_response(data, status) + + +# ───────────────────────────────────────────────────────────────────────────── +# URL Helpers +# ───────────────────────────────────────────────────────────────────────────── + + +def url_prefix() -> str: + """Get configured URL prefix for reverse proxy deployments.""" + return current_app.config.get("URL_PREFIX", "") + + +def prefixed_url(path: str) -> str: + """Generate URL with configured prefix.""" + return f"{url_prefix()}{path}" + + +def base_url() -> str: + """Detect full base URL from request headers.""" + scheme = ( + request.headers.get("X-Forwarded-Proto") + or request.headers.get("X-Scheme") + or request.scheme + ) + host = request.headers.get("X-Forwarded-Host") or request.headers.get("Host") or request.host + return f"{scheme}://{host}{url_prefix()}" + + +# ───────────────────────────────────────────────────────────────────────────── +# Validation Helpers (used within views) +# ───────────────────────────────────────────────────────────────────────────── + + +def validate_paste_id(paste_id: str) -> Response | None: + """Validate paste ID format. Returns error response or None if valid.""" + expected_length = current_app.config["PASTE_ID_LENGTH"] + if len(paste_id) != expected_length or not PASTE_ID_PATTERN.match(paste_id): + return error_response("Invalid paste ID", 400) + return None + + +def fetch_paste(paste_id: str, check_password: bool = True) -> Response | None: + """Fetch paste and store in g.paste. Returns error response or None if OK.""" + db = get_db() + now = int(time.time()) + + # Update access time + db.execute("UPDATE pastes SET last_accessed = ? WHERE id = ?", (now, paste_id)) + + row = db.execute( + """SELECT id, content, mime_type, owner, created_at, + length(content) as size, burn_after_read, expires_at, password_hash + FROM pastes WHERE id = ?""", + (paste_id,), + ).fetchone() + + if row is None: + db.commit() + return error_response("Paste not found", 404) + + # Password verification + if check_password and row["password_hash"]: + provided = request.headers.get("X-Paste-Password", "") + if not provided: + db.commit() + return error_response("Password required", 401, password_protected=True) + if not verify_password(provided, row["password_hash"]): + db.commit() + return error_response("Invalid password", 403) + + g.paste = row + g.db = db + return None + + +def require_auth() -> Response | None: + """Check authentication. Returns error response or None if authenticated.""" + client_id = get_client_id() + if not client_id: + return error_response("Authentication required", 401) + g.client_id = client_id + return None + + +# ───────────────────────────────────────────────────────────────────────────── +# Authentication & Security +# ───────────────────────────────────────────────────────────────────────────── + + +def is_trusted_proxy() -> bool: + """Verify request comes from trusted reverse proxy via shared secret.""" + expected = current_app.config.get("TRUSTED_PROXY_SECRET", "") + if not expected: + return True + provided = request.headers.get("X-Proxy-Secret", "") + return hmac.compare_digest(expected, provided) + + +def get_client_id() -> str | None: + """Extract and validate client certificate fingerprint.""" + if not is_trusted_proxy(): + current_app.logger.warning( + "Auth header ignored: X-Proxy-Secret mismatch from %s", request.remote_addr + ) + return None + + sha1 = request.headers.get("X-SSL-Client-SHA1", "").strip().lower() + if sha1 and CLIENT_ID_PATTERN.match(sha1): + # Check if PKI is enabled and certificate is revoked + if current_app.config.get("PKI_ENABLED"): + from app.pki import is_certificate_valid + + if not is_certificate_valid(sha1): + current_app.logger.warning( + "Auth rejected: certificate revoked or expired: %s", sha1[:12] + "..." + ) + return None + return sha1 + return None + + +# ───────────────────────────────────────────────────────────────────────────── +# Proof-of-Work +# ───────────────────────────────────────────────────────────────────────────── + + +def get_pow_secret() -> bytes: + """Get or generate PoW signing secret.""" global _pow_secret_cache - configured = current_app.config.get("POW_SECRET", "") if configured: return configured.encode() - if _pow_secret_cache is None: _pow_secret_cache = secrets.token_bytes(32) return _pow_secret_cache -def _generate_challenge() -> dict: - """Generate a new PoW challenge.""" +def generate_challenge() -> dict[str, Any]: + """Generate new PoW challenge with signed token.""" difficulty = current_app.config["POW_DIFFICULTY"] ttl = current_app.config["POW_CHALLENGE_TTL"] expires = int(time.time()) + ttl nonce = secrets.token_hex(16) - # Sign the challenge to prevent tampering msg = f"{nonce}:{expires}:{difficulty}".encode() - sig = hmac.new(_get_pow_secret(), msg, hashlib.sha256).hexdigest() + sig = hmac.new(get_pow_secret(), msg, hashlib.sha256).hexdigest() return { "nonce": nonce, "difficulty": difficulty, "expires": expires, - "signature": sig, + "token": f"{nonce}:{expires}:{difficulty}:{sig}", } -def _verify_pow(challenge: str, nonce: str, solution: str) -> tuple[bool, str]: - """Verify a proof-of-work solution. - - Args: - challenge: The challenge nonce from /challenge - nonce: Combined "nonce:expires:difficulty:signature" string - solution: The solution number found by client - - Returns: - Tuple of (valid, error_message) - """ +def verify_pow(token: str, solution: str) -> tuple[bool, str]: + """Verify proof-of-work solution. Returns (valid, error_message).""" difficulty = current_app.config["POW_DIFFICULTY"] - - # PoW disabled if difficulty == 0: return True, "" - # Parse challenge components + # Parse token try: - parts = nonce.split(":") + parts = token.split(":") if len(parts) != 4: return False, "Invalid challenge format" - ch_nonce, ch_expires, ch_difficulty, ch_sig = parts - ch_expires = int(ch_expires) - ch_difficulty = int(ch_difficulty) + nonce, expires_str, diff_str, sig = parts + expires = int(expires_str) + token_diff = int(diff_str) except (ValueError, TypeError): return False, "Invalid challenge format" # Verify signature - msg = f"{ch_nonce}:{ch_expires}:{ch_difficulty}".encode() - expected_sig = hmac.new(_get_pow_secret(), msg, hashlib.sha256).hexdigest() - if not hmac.compare_digest(ch_sig, expected_sig): + msg = f"{nonce}:{expires}:{token_diff}".encode() + expected_sig = hmac.new(get_pow_secret(), msg, hashlib.sha256).hexdigest() + if not hmac.compare_digest(sig, expected_sig): return False, "Invalid challenge signature" - # Check expiry - if int(time.time()) > ch_expires: + # Check expiry and difficulty + if int(time.time()) > expires: return False, "Challenge expired" - - # Verify difficulty matches current config - if ch_difficulty != difficulty: + if token_diff != difficulty: return False, "Difficulty mismatch" # Verify solution @@ -168,18 +261,16 @@ def _verify_pow(challenge: str, nonce: str, solution: str) -> tuple[bool, str]: except (ValueError, TypeError): return False, "Invalid solution" - # Check hash meets difficulty requirement - work = f"{ch_nonce}:{solution}".encode() + # Check hash meets difficulty + work = f"{nonce}:{solution}".encode() hash_bytes = hashlib.sha256(work).digest() - # Count leading zero bits zero_bits = 0 for byte in hash_bytes: if byte == 0: zero_bits += 8 else: - # Count leading zeros in this byte - zero_bits += (8 - byte.bit_length()) + zero_bits += 8 - byte.bit_length() break if zero_bits < difficulty: @@ -188,40 +279,47 @@ def _verify_pow(challenge: str, nonce: str, solution: str) -> tuple[bool, str]: return True, "" -def _is_valid_paste_id(paste_id: str) -> bool: - """Validate paste ID format (hexadecimal, correct length).""" - expected_length = current_app.config["PASTE_ID_LENGTH"] - return ( - len(paste_id) == expected_length - and PASTE_ID_PATTERN.match(paste_id) is not None - ) +# ───────────────────────────────────────────────────────────────────────────── +# Content Processing +# ───────────────────────────────────────────────────────────────────────────── -def _detect_mime_type(content: bytes, content_type: str | None = None) -> str: - """Detect MIME type from content bytes, with magic byte detection taking priority.""" - # Check magic bytes first - most reliable method +def calculate_entropy(data: bytes) -> float: + """Calculate Shannon entropy in bits per byte (0-8 range).""" + if not data: + return 0.0 + + freq = [0] * 256 + for byte in data: + freq[byte] += 1 + + length = len(data) + entropy = 0.0 + for count in freq: + if count > 0: + p = count / length + entropy -= p * math.log2(p) + + return entropy + + +def detect_mime_type(content: bytes, content_type: str | None = None) -> str: + """Detect MIME type using magic bytes, headers, or content analysis.""" + # Magic byte detection (highest priority) for magic, mime in MAGIC_SIGNATURES.items(): if content.startswith(magic): - # Special case for WebP (RIFF....WEBP) - if magic == b"RIFF" and len(content) >= 12: - if content[8:12] != b"WEBP": - continue + # RIFF container: verify WEBP subtype + if magic == b"RIFF" and len(content) >= 12 and content[8:12] != b"WEBP": + continue return mime - # Trust explicit Content-Type if it's specific (not generic defaults) - generic_types = { - "application/octet-stream", - "application/x-www-form-urlencoded", - "text/plain", - } + # Explicit Content-Type (if specific) if content_type: mime = content_type.split(";")[0].strip().lower() - if mime not in generic_types: - # Sanitize: only allow safe characters in MIME type - if re.match(r"^[a-z0-9][a-z0-9!#$&\-^_.+]*\/[a-z0-9][a-z0-9!#$&\-^_.+]*$", mime): - return mime + if mime not in GENERIC_MIME_TYPES and MIME_PATTERN.match(mime): + return mime - # Try to decode as UTF-8 text + # UTF-8 text detection try: content.decode("utf-8") return "text/plain" @@ -229,372 +327,703 @@ def _detect_mime_type(content: bytes, content_type: str | None = None) -> str: return "application/octet-stream" -def _generate_id(content: bytes) -> str: - """Generate a short unique ID from content hash and timestamp.""" +def generate_paste_id(content: bytes) -> str: + """Generate unique paste ID from content hash and timestamp.""" data = content + str(time.time_ns()).encode() length = current_app.config["PASTE_ID_LENGTH"] return hashlib.sha256(data).hexdigest()[:length] -def _json_response(data: dict, status: int = 200) -> Response: - """Create a JSON response with proper encoding and security headers.""" - response = Response( - json.dumps(data, ensure_ascii=False), - status=status, - mimetype="application/json", - ) - return response +# ───────────────────────────────────────────────────────────────────────────── +# Class-Based Views +# ───────────────────────────────────────────────────────────────────────────── -def _is_trusted_proxy() -> bool: - """Verify request comes from a trusted reverse proxy. +class IndexView(MethodView): + """Handle API info and paste creation.""" - If TRUSTED_PROXY_SECRET is configured, the request must include a matching - X-Proxy-Secret header. This provides defense-in-depth against header spoofing - if an attacker bypasses the reverse proxy. - - Returns True if no secret is configured (backwards compatible) or if the - secret matches. - """ - expected_secret = current_app.config.get("TRUSTED_PROXY_SECRET", "") - if not expected_secret: - # No secret configured - trust all requests (backwards compatible) - return True - - # Constant-time comparison to prevent timing attacks - provided_secret = request.headers.get("X-Proxy-Secret", "") - return hmac.compare_digest(expected_secret, provided_secret) - - -def _get_client_id() -> str | None: - """Extract and validate client identity from X-SSL-Client-SHA1 header. - - Returns lowercase SHA1 fingerprint or None if not present/invalid. - - SECURITY: The X-SSL-Client-SHA1 header is only trusted if the request - comes from a trusted proxy (verified via X-Proxy-Secret if configured). - """ - # Verify request comes from trusted proxy before trusting auth headers - if not _is_trusted_proxy(): - current_app.logger.warning( - "Auth header ignored: X-Proxy-Secret mismatch from %s", - request.remote_addr + def get(self) -> Response: + """Return API information and usage examples.""" + prefix = url_prefix() or "/" + return json_response( + { + "name": "FlaskPaste", + "version": VERSION, + "prefix": prefix, + "endpoints": { + f"GET {prefixed_url('/')}": "API information", + f"GET {prefixed_url('/health')}": "Health check", + f"GET {prefixed_url('/client')}": "Download CLI client", + f"GET {prefixed_url('/challenge')}": "Get PoW challenge", + f"POST {prefixed_url('/')}": "Create paste", + f"GET {prefixed_url('/')}": "Retrieve paste metadata", + f"GET {prefixed_url('//raw')}": "Retrieve raw paste content", + f"DELETE {prefixed_url('/')}": "Delete paste", + }, + "usage": { + "raw": f"curl --data-binary @file.txt {base_url()}/", + "pipe": f"cat file.txt | curl --data-binary @- {base_url()}/", + "json": f"curl -H \"Content-Type: application/json\" -d '...' {base_url()}/", + }, + "note": "Use --data-binary (not -d) to preserve newlines", + } ) - return None - client_sha1 = request.headers.get("X-SSL-Client-SHA1", "").strip().lower() - # Validate format: must be 40 hex characters (SHA1) - if client_sha1 and CLIENT_ID_PATTERN.match(client_sha1): - return client_sha1 + def post(self) -> Response: + """Create a new paste.""" + # Parse content + content: bytes | None = None + mime_type: str | None = None + + if request.is_json: + data = request.get_json(silent=True) + if data and isinstance(data.get("content"), str): + content = data["content"].encode("utf-8") + mime_type = "text/plain" + else: + content = request.get_data(as_text=False) + if content: + mime_type = detect_mime_type(content, request.content_type) + + if not content: + return error_response("No content provided", 400) + + owner = get_client_id() + + # Proof-of-work verification + difficulty = current_app.config["POW_DIFFICULTY"] + if difficulty > 0: + token = request.headers.get("X-PoW-Token", "") + solution = request.headers.get("X-PoW-Solution", "") + + if not token or not solution: + return error_response( + "Proof-of-work required", 400, hint="GET /challenge for a new challenge" + ) + + valid, err = verify_pow(token, solution) + if not valid: + current_app.logger.warning( + "PoW verification failed: %s from=%s", err, request.remote_addr + ) + return error_response(f"Proof-of-work failed: {err}", 400) + + # Size limits + content_size = len(content) + max_size = ( + current_app.config["MAX_PASTE_SIZE_AUTH"] + if owner + else current_app.config["MAX_PASTE_SIZE_ANON"] + ) + + if content_size > max_size: + return error_response( + "Paste too large", + 413, + size=content_size, + max_size=max_size, + authenticated=owner is not None, + ) + + # Entropy check + min_entropy = current_app.config.get("MIN_ENTROPY", 0) + min_entropy_size = current_app.config.get("MIN_ENTROPY_SIZE", 256) + if min_entropy > 0 and content_size >= min_entropy_size: + entropy = calculate_entropy(content) + if entropy < min_entropy: + current_app.logger.warning( + "Low entropy rejected: %.2f < %.2f from=%s", + entropy, + min_entropy, + request.remote_addr, + ) + return error_response( + "Content entropy too low", + 400, + entropy=round(entropy, 2), + min_entropy=min_entropy, + hint="Encrypt content before uploading (-e flag in fpaste)", + ) + + # Deduplication check + content_hash = hashlib.sha256(content).hexdigest() + is_allowed, dedup_count = check_content_hash(content_hash) + + if not is_allowed: + window = current_app.config["CONTENT_DEDUP_WINDOW"] + current_app.logger.warning( + "Dedup threshold exceeded: hash=%s count=%d from=%s", + content_hash[:16], + dedup_count, + request.remote_addr, + ) + return error_response( + "Duplicate content rate limit exceeded", + 429, + count=dedup_count, + window_seconds=window, + ) + + # Parse optional headers + burn_header = request.headers.get("X-Burn-After-Read", "").strip().lower() + burn_after_read = burn_header in ("true", "1", "yes") + + expires_at = None + expiry_header = request.headers.get("X-Expiry", "").strip() + if expiry_header: + try: + expiry_seconds = int(expiry_header) + if expiry_seconds > 0: + max_expiry = current_app.config.get("MAX_EXPIRY_SECONDS", 0) + if max_expiry > 0: + expiry_seconds = min(expiry_seconds, max_expiry) + expires_at = int(time.time()) + expiry_seconds + except ValueError: + pass + + password_hash = None + password_header = request.headers.get("X-Paste-Password", "") + if password_header: + if len(password_header) > 1024: + return error_response("Password too long (max 1024 chars)", 400) + password_hash = hash_password(password_header) + + # Insert paste + paste_id = generate_paste_id(content) + now = int(time.time()) + + db = get_db() + db.execute( + """INSERT INTO pastes + (id, content, mime_type, owner, created_at, last_accessed, + burn_after_read, expires_at, password_hash) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""", + ( + paste_id, + content, + mime_type, + owner, + now, + now, + 1 if burn_after_read else 0, + expires_at, + password_hash, + ), + ) + db.commit() + + # Build response + response_data: dict[str, Any] = { + "id": paste_id, + "url": f"/{paste_id}", + "raw": f"/{paste_id}/raw", + "mime_type": mime_type, + "created_at": now, + } + if owner: + response_data["owner"] = owner + if burn_after_read: + response_data["burn_after_read"] = True + if expires_at: + response_data["expires_at"] = expires_at + if password_hash: + response_data["password_protected"] = True + + return json_response(response_data, 201) + + +class HealthView(MethodView): + """Health check endpoint.""" + + def get(self) -> Response: + """Return health status with database check.""" + try: + db = get_db() + db.execute("SELECT 1") + return json_response({"status": "healthy", "database": "ok"}) + except Exception: + return json_response({"status": "unhealthy", "database": "error"}, 503) + + +class ChallengeView(MethodView): + """Proof-of-work challenge endpoint.""" + + def get(self) -> Response: + """Generate and return PoW challenge.""" + difficulty = current_app.config["POW_DIFFICULTY"] + if difficulty == 0: + return json_response({"enabled": False, "difficulty": 0}) + + ch = generate_challenge() + return json_response( + { + "enabled": True, + "nonce": ch["nonce"], + "difficulty": ch["difficulty"], + "expires": ch["expires"], + "token": ch["token"], + } + ) + + +class ClientView(MethodView): + """CLI client download endpoint.""" + + def get(self) -> Response: + """Serve fpaste CLI with server URL pre-configured.""" + import os + + server_url = base_url() + client_path = os.path.join(current_app.root_path, "..", "fpaste") + + try: + with open(client_path) as f: + content = f.read() + + # Replace default server URL + content = content.replace( + '"server": os.environ.get("FLASKPASTE_SERVER", "http://localhost:5000")', + f'"server": os.environ.get("FLASKPASTE_SERVER", "{server_url}")', + ) + content = content.replace( + "http://localhost:5000)", + f"{server_url})", + ) + + response = Response(content, mimetype="text/x-python") + response.headers["Content-Disposition"] = "attachment; filename=fpaste" + return response + except FileNotFoundError: + return error_response("Client not available", 404) + + +class PasteView(MethodView): + """Paste metadata operations.""" + + def get(self, paste_id: str) -> Response: + """Retrieve paste metadata.""" + # Validate and fetch + if err := validate_paste_id(paste_id): + return err + if err := fetch_paste(paste_id): + return err + + row: Row = g.paste + g.db.commit() + + response_data: dict[str, Any] = { + "id": row["id"], + "mime_type": row["mime_type"], + "size": row["size"], + "created_at": row["created_at"], + "raw": f"/{paste_id}/raw", + } + if row["burn_after_read"]: + response_data["burn_after_read"] = True + if row["expires_at"]: + response_data["expires_at"] = row["expires_at"] + if row["password_hash"]: + response_data["password_protected"] = True + + return json_response(response_data) + + def head(self, paste_id: str) -> Response: + """Return paste metadata headers only.""" + return self.get(paste_id) + + +class PasteRawView(MethodView): + """Raw paste content retrieval.""" + + def get(self, paste_id: str) -> Response: + """Retrieve raw paste content.""" + # Validate and fetch + if err := validate_paste_id(paste_id): + return err + if err := fetch_paste(paste_id): + return err + + row: Row = g.paste + db = g.db + + burn_after_read = row["burn_after_read"] + if burn_after_read: + db.execute("DELETE FROM pastes WHERE id = ?", (paste_id,)) + current_app.logger.info("Burn-after-read paste deleted: %s", paste_id) + + db.commit() + + response = Response(row["content"], mimetype=row["mime_type"]) + if row["mime_type"].startswith(("image/", "text/")): + response.headers["Content-Disposition"] = "inline" + if burn_after_read: + response.headers["X-Burn-After-Read"] = "true" + + return response + + def head(self, paste_id: str) -> Response: + """Return raw paste headers without triggering burn.""" + # Validate and fetch + if err := validate_paste_id(paste_id): + return err + if err := fetch_paste(paste_id): + return err + + row: Row = g.paste + g.db.commit() + + response = Response(mimetype=row["mime_type"]) + response.headers["Content-Length"] = str(row["size"]) + if row["mime_type"].startswith(("image/", "text/")): + response.headers["Content-Disposition"] = "inline" + if row["burn_after_read"]: + response.headers["X-Burn-After-Read"] = "true" + + return response + + +class PasteDeleteView(MethodView): + """Paste deletion with authentication.""" + + def delete(self, paste_id: str) -> Response: + """Delete paste. Requires ownership.""" + # Validate + if err := validate_paste_id(paste_id): + return err + if err := require_auth(): + return err + + db = get_db() + + row = db.execute("SELECT owner FROM pastes WHERE id = ?", (paste_id,)).fetchone() + + if row is None: + return error_response("Paste not found", 404) + + if row["owner"] != g.client_id: + return error_response("Permission denied", 403) + + db.execute("DELETE FROM pastes WHERE id = ?", (paste_id,)) + db.commit() + + return json_response({"message": "Paste deleted"}) + + +# ───────────────────────────────────────────────────────────────────────────── +# PKI Views (Certificate Authority) +# ───────────────────────────────────────────────────────────────────────────── + + +def require_pki_enabled() -> Response | None: + """Check if PKI is enabled. Returns error response or None if enabled.""" + if not current_app.config.get("PKI_ENABLED"): + return error_response("PKI not enabled", 404) return None -@bp.route("/client", methods=["GET"]) -def client(): - """Download the fpaste CLI client with server URL pre-configured.""" - import os +class PKIStatusView(MethodView): + """PKI status endpoint.""" - # Detect scheme (check reverse proxy headers first) - scheme = ( - request.headers.get("X-Forwarded-Proto") - or request.headers.get("X-Scheme") - or request.scheme - ) + def get(self) -> Response: + """Return PKI status and CA info if available.""" + if not current_app.config.get("PKI_ENABLED"): + return json_response({"enabled": False}) - # Detect host (check reverse proxy headers first) - host = ( - request.headers.get("X-Forwarded-Host") - or request.headers.get("Host") - or request.host - ) + from app.pki import get_ca_info - # Build server URL with prefix - prefix = current_app.config.get("URL_PREFIX", "") - server_url = f"{scheme}://{host}{prefix}" + ca_info = get_ca_info() + if ca_info is None: + return json_response( + { + "enabled": True, + "ca_exists": False, + "hint": "POST /pki/ca to generate CA", + } + ) - client_path = os.path.join(current_app.root_path, "..", "fpaste") - try: - with open(client_path, "r") as f: - content = f.read() - - # Replace default server URL - content = content.replace( - '"server": os.environ.get("FLASKPASTE_SERVER", "http://localhost:5000")', - f'"server": os.environ.get("FLASKPASTE_SERVER", "{server_url}")', - ) - content = content.replace( - "http://localhost:5000)", - f"{server_url})", + return json_response( + { + "enabled": True, + "ca_exists": True, + "common_name": ca_info["common_name"], + "fingerprint_sha1": ca_info["fingerprint_sha1"], + "created_at": ca_info["created_at"], + "expires_at": ca_info["expires_at"], + "key_algorithm": ca_info["key_algorithm"], + } ) - response = Response(content, mimetype="text/x-python") - response.headers["Content-Disposition"] = "attachment; filename=fpaste" + +class PKICAGenerateView(MethodView): + """CA generation endpoint (first-run only).""" + + def post(self) -> Response: + """Generate CA certificate. Only works if no CA exists.""" + if err := require_pki_enabled(): + return err + + from app.pki import ( + CAExistsError, + PKIError, + generate_ca, + get_ca_info, + ) + + # Check if CA already exists + if get_ca_info() is not None: + return error_response("CA already exists", 409) + + # Get CA password from config + password = current_app.config.get("PKI_CA_PASSWORD", "") + if not password: + return error_response( + "PKI_CA_PASSWORD not configured", + 500, + hint="Set FLASKPASTE_PKI_CA_PASSWORD environment variable", + ) + + # Parse request for optional common name + common_name = "FlaskPaste CA" + if request.is_json: + data = request.get_json(silent=True) + if data and isinstance(data.get("common_name"), str): + common_name = data["common_name"][:64] + + # Generate CA + try: + days = current_app.config.get("PKI_CA_DAYS", 3650) + owner = get_client_id() + ca_info = generate_ca(common_name, password, days=days, owner=owner) + except CAExistsError: + return error_response("CA already exists", 409) + except PKIError as e: + current_app.logger.error("CA generation failed: %s", e) + return error_response("CA generation failed", 500) + + current_app.logger.info( + "CA generated: cn=%s fingerprint=%s", common_name, ca_info["fingerprint_sha1"][:12] + ) + + return json_response( + { + "message": "CA generated", + "common_name": ca_info["common_name"], + "fingerprint_sha1": ca_info["fingerprint_sha1"], + "created_at": ca_info["created_at"], + "expires_at": ca_info["expires_at"], + "download": prefixed_url("/pki/ca.crt"), + }, + 201, + ) + + +class PKICADownloadView(MethodView): + """CA certificate download endpoint.""" + + def get(self) -> Response: + """Download CA certificate in PEM format.""" + if err := require_pki_enabled(): + return err + + from app.pki import get_ca_info + + ca_info = get_ca_info() + if ca_info is None: + return error_response("CA not initialized", 404) + + response = Response(ca_info["certificate_pem"], mimetype="application/x-pem-file") + response.headers["Content-Disposition"] = ( + f"attachment; filename={ca_info['common_name'].replace(' ', '_')}.crt" + ) return response - except FileNotFoundError: - return _json_response({"error": "Client not available"}, 404) -@bp.route("/health", methods=["GET"]) -def health(): - """Health check endpoint for load balancers and monitoring.""" - try: - db = get_db() - db.execute("SELECT 1") - return _json_response({"status": "healthy", "database": "ok"}) - except Exception: - return _json_response({"status": "unhealthy", "database": "error"}, 503) +class PKIIssueView(MethodView): + """Certificate issuance endpoint (open registration).""" + def post(self) -> Response: + """Issue a new client certificate.""" + if err := require_pki_enabled(): + return err -@bp.route("/challenge", methods=["GET"]) -def challenge(): - """Get a proof-of-work challenge for paste creation.""" - difficulty = current_app.config["POW_DIFFICULTY"] - if difficulty == 0: - return _json_response({"enabled": False, "difficulty": 0}) - - ch = _generate_challenge() - return _json_response({ - "enabled": True, - "nonce": ch["nonce"], - "difficulty": ch["difficulty"], - "expires": ch["expires"], - "token": f"{ch['nonce']}:{ch['expires']}:{ch['difficulty']}:{ch['signature']}", - }) - - -@bp.route("/", methods=["GET", "POST"]) -def index(): - """Handle API info (GET) and paste creation (POST).""" - if request.method == "POST": - return create_paste() - - prefix = current_app.config.get("URL_PREFIX", "") - return _json_response( - { - "name": "FlaskPaste", - "version": VERSION, - "prefix": prefix or "/", - "endpoints": { - f"GET {_url('/')}": "API information", - f"GET {_url('/health')}": "Health check", - f"GET {_url('/client')}": "Download CLI client", - f"GET {_url('/challenge')}": "Get PoW challenge", - f"POST {_url('/')}": "Create paste", - f"GET {_url('/')}": "Retrieve paste metadata", - f"GET {_url('//raw')}": "Retrieve raw paste content", - f"DELETE {_url('/')}": "Delete paste", - }, - "usage": { - "raw": f"curl --data-binary @file.txt {_base_url()}/", - "pipe": f"cat file.txt | curl --data-binary @- {_base_url()}/", - "json": f"curl -H 'Content-Type: application/json' -d '{{\"content\":\"...\"}}' {_base_url()}/", - }, - "note": "Use --data-binary (not -d) to preserve newlines", - } - ) - - -def create_paste(): - """Create a new paste from request body.""" - content: bytes | None = None - mime_type: str | None = None - - if request.is_json: - data = request.get_json(silent=True) - if data and isinstance(data.get("content"), str): - content = data["content"].encode("utf-8") - mime_type = "text/plain" - else: - content = request.get_data(as_text=False) - if content: - mime_type = _detect_mime_type(content, request.content_type) - - if not content: - return _json_response({"error": "No content provided"}, 400) - - owner = _get_client_id() - - # Verify proof-of-work (if enabled) - difficulty = current_app.config["POW_DIFFICULTY"] - if difficulty > 0: - pow_token = request.headers.get("X-PoW-Token", "") - pow_solution = request.headers.get("X-PoW-Solution", "") - - if not pow_token or not pow_solution: - return _json_response({ - "error": "Proof-of-work required", - "hint": "GET /challenge for a new challenge", - }, 400) - - # Extract nonce from token for verification - parts = pow_token.split(":") - pow_nonce = parts[0] if parts else "" - - valid, err = _verify_pow(pow_nonce, pow_token, pow_solution) - if not valid: - current_app.logger.warning( - "PoW verification failed: %s from=%s", - err, request.remote_addr - ) - return _json_response({"error": f"Proof-of-work failed: {err}"}, 400) - - # Enforce size limits based on authentication - content_size = len(content) - if owner: - max_size = current_app.config["MAX_PASTE_SIZE_AUTH"] - else: - max_size = current_app.config["MAX_PASTE_SIZE_ANON"] - - if content_size > max_size: - return _json_response({ - "error": "Paste too large", - "size": content_size, - "max_size": max_size, - "authenticated": owner is not None, - }, 413) - - # Check minimum entropy requirement (encryption enforcement) - min_entropy = current_app.config.get("MIN_ENTROPY", 0) - min_entropy_size = current_app.config.get("MIN_ENTROPY_SIZE", 256) - if min_entropy > 0 and content_size >= min_entropy_size: - entropy = _calculate_entropy(content) - if entropy < min_entropy: - current_app.logger.warning( - "Low entropy rejected: %.2f < %.2f from=%s", - entropy, min_entropy, request.remote_addr - ) - return _json_response({ - "error": "Content entropy too low", - "entropy": round(entropy, 2), - "min_entropy": min_entropy, - "hint": "Encrypt content before uploading (-e flag in fpaste)", - }, 400) - - # Check content deduplication threshold - content_hash = hashlib.sha256(content).hexdigest() - is_allowed, dedup_count = check_content_hash(content_hash) - - if not is_allowed: - window = current_app.config["CONTENT_DEDUP_WINDOW"] - current_app.logger.warning( - "Dedup threshold exceeded: hash=%s count=%d from=%s", - content_hash[:16], dedup_count, request.remote_addr + from app.pki import ( + CANotFoundError, + PKIError, + issue_certificate, ) - return _json_response({ - "error": "Duplicate content rate limit exceeded", - "count": dedup_count, - "window_seconds": window, - }, 429) - paste_id = _generate_id(content) - now = int(time.time()) + # Parse request + common_name = None + if request.is_json: + data = request.get_json(silent=True) + if data and isinstance(data.get("common_name"), str): + common_name = data["common_name"][:64] - db = get_db() - db.execute( - "INSERT INTO pastes (id, content, mime_type, owner, created_at, last_accessed) VALUES (?, ?, ?, ?, ?, ?)", - (paste_id, content, mime_type, owner, now, now), - ) - db.commit() + if not common_name: + return error_response( + "common_name required", 400, hint='POST {"common_name": "your-name"}' + ) - response_data = { - "id": paste_id, - "url": f"/{paste_id}", - "raw": f"/{paste_id}/raw", - "mime_type": mime_type, - "created_at": now, - } - if owner: - response_data["owner"] = owner + # Get CA password from config + password = current_app.config.get("PKI_CA_PASSWORD", "") + if not password: + return error_response("PKI not properly configured", 500) - return _json_response(response_data, 201) + # Issue certificate + try: + days = current_app.config.get("PKI_CERT_DAYS", 365) + issued_to = get_client_id() + cert_info = issue_certificate(common_name, password, days=days, issued_to=issued_to) + except CANotFoundError: + return error_response("CA not initialized", 404) + except PKIError as e: + current_app.logger.error("Certificate issuance failed: %s", e) + return error_response("Certificate issuance failed", 500) + + current_app.logger.info( + "Certificate issued: cn=%s serial=%s fingerprint=%s to=%s", + common_name, + cert_info["serial"][:8], + cert_info["fingerprint_sha1"][:12], + issued_to or "anonymous", + ) + + # Return certificate bundle + return json_response( + { + "message": "Certificate issued", + "serial": cert_info["serial"], + "common_name": cert_info["common_name"], + "fingerprint_sha1": cert_info["fingerprint_sha1"], + "created_at": cert_info["created_at"], + "expires_at": cert_info["expires_at"], + "certificate_pem": cert_info["certificate_pem"], + "private_key_pem": cert_info["private_key_pem"], + }, + 201, + ) -@bp.route("/", methods=["GET", "HEAD"]) -def get_paste(paste_id: str): - """Retrieve paste metadata by ID. HEAD returns headers only.""" - if not _is_valid_paste_id(paste_id): - return _json_response({"error": "Invalid paste ID"}, 400) +class PKICertsView(MethodView): + """Certificate listing endpoint.""" - db = get_db() - now = int(time.time()) + def get(self) -> Response: + """List issued certificates.""" + if err := require_pki_enabled(): + return err - # Update last_accessed and return paste in one transaction - db.execute( - "UPDATE pastes SET last_accessed = ? WHERE id = ?", (now, paste_id) - ) - row = db.execute( - "SELECT id, mime_type, created_at, length(content) as size FROM pastes WHERE id = ?", - (paste_id,) - ).fetchone() - db.commit() + client_id = get_client_id() - if row is None: - return _json_response({"error": "Paste not found"}, 404) + db = get_db() - return _json_response({ - "id": row["id"], - "mime_type": row["mime_type"], - "size": row["size"], - "created_at": row["created_at"], - "raw": f"/{paste_id}/raw", - }) + # Authenticated users see their own certs or certs they issued + # Anonymous users see nothing + if client_id: + rows = db.execute( + """SELECT serial, common_name, fingerprint_sha1, + created_at, expires_at, issued_to, status, revoked_at + FROM issued_certificates + WHERE issued_to = ? OR fingerprint_sha1 = ? + ORDER BY created_at DESC""", + (client_id, client_id), + ).fetchall() + else: + # Anonymous: empty list + rows = [] + + certs = [] + for row in rows: + cert = { + "serial": row["serial"], + "common_name": row["common_name"], + "fingerprint_sha1": row["fingerprint_sha1"], + "created_at": row["created_at"], + "expires_at": row["expires_at"], + "status": row["status"], + } + if row["issued_to"]: + cert["issued_to"] = row["issued_to"] + if row["revoked_at"]: + cert["revoked_at"] = row["revoked_at"] + certs.append(cert) + + return json_response({"certificates": certs, "count": len(certs)}) -@bp.route("//raw", methods=["GET", "HEAD"]) -def get_paste_raw(paste_id: str): - """Retrieve raw paste content with correct MIME type. HEAD returns headers only.""" - if not _is_valid_paste_id(paste_id): - return _json_response({"error": "Invalid paste ID"}, 400) +class PKIRevokeView(MethodView): + """Certificate revocation endpoint.""" - db = get_db() - now = int(time.time()) + def post(self, serial: str) -> Response: + """Revoke a certificate by serial number.""" + if err := require_pki_enabled(): + return err + if err := require_auth(): + return err - # Update last_accessed and return paste in one transaction - db.execute( - "UPDATE pastes SET last_accessed = ? WHERE id = ?", (now, paste_id) - ) - row = db.execute( - "SELECT content, mime_type FROM pastes WHERE id = ?", (paste_id,) - ).fetchone() - db.commit() + from app.pki import CertificateNotFoundError, PKIError, revoke_certificate - if row is None: - return _json_response({"error": "Paste not found"}, 404) + db = get_db() - mime_type = row["mime_type"] + # Check certificate exists and get ownership info + row = db.execute( + "SELECT issued_to, fingerprint_sha1, status FROM issued_certificates WHERE serial = ?", + (serial,), + ).fetchone() - response = Response(row["content"], mimetype=mime_type) - # Display inline for images and text, let browser decide for others - if mime_type.startswith(("image/", "text/")): - response.headers["Content-Disposition"] = "inline" + if row is None: + return error_response("Certificate not found", 404) - return response + if row["status"] == "revoked": + return error_response("Certificate already revoked", 409) + + # Check permission: must be issuer or the certificate itself + client_id = g.client_id + can_revoke = row["issued_to"] == client_id or row["fingerprint_sha1"] == client_id + + if not can_revoke: + return error_response("Permission denied", 403) + + # Revoke + try: + revoke_certificate(serial) + except CertificateNotFoundError: + return error_response("Certificate not found", 404) + except PKIError as e: + current_app.logger.error("Revocation failed: %s", e) + return error_response("Revocation failed", 500) + + current_app.logger.info("Certificate revoked: serial=%s by=%s", serial[:8], client_id[:12]) + + return json_response({"message": "Certificate revoked", "serial": serial}) -@bp.route("/", methods=["DELETE"]) -def delete_paste(paste_id: str): - """Delete a paste by ID. Requires ownership via X-SSL-Client-SHA1 header.""" - if not _is_valid_paste_id(paste_id): - return _json_response({"error": "Invalid paste ID"}, 400) +# ───────────────────────────────────────────────────────────────────────────── +# Route Registration +# ───────────────────────────────────────────────────────────────────────────── - client_id = _get_client_id() - if not client_id: - return _json_response({"error": "Authentication required"}, 401) +# Index and paste creation +bp.add_url_rule("/", view_func=IndexView.as_view("index")) - db = get_db() +# Utility endpoints +bp.add_url_rule("/health", view_func=HealthView.as_view("health")) +bp.add_url_rule("/challenge", view_func=ChallengeView.as_view("challenge")) +bp.add_url_rule("/client", view_func=ClientView.as_view("client")) - # Check paste exists and verify ownership - row = db.execute( - "SELECT owner FROM pastes WHERE id = ?", (paste_id,) - ).fetchone() +# Paste operations +bp.add_url_rule("/", view_func=PasteView.as_view("paste"), methods=["GET", "HEAD"]) +bp.add_url_rule( + "//raw", view_func=PasteRawView.as_view("paste_raw"), methods=["GET", "HEAD"] +) +bp.add_url_rule( + "/", view_func=PasteDeleteView.as_view("paste_delete"), methods=["DELETE"] +) - if row is None: - return _json_response({"error": "Paste not found"}, 404) - - if row["owner"] != client_id: - return _json_response({"error": "Permission denied"}, 403) - - db.execute("DELETE FROM pastes WHERE id = ?", (paste_id,)) - db.commit() - - return _json_response({"message": "Paste deleted"}) +# PKI endpoints +bp.add_url_rule("/pki", view_func=PKIStatusView.as_view("pki_status")) +bp.add_url_rule("/pki/ca", view_func=PKICAGenerateView.as_view("pki_ca_generate")) +bp.add_url_rule("/pki/ca.crt", view_func=PKICADownloadView.as_view("pki_ca_download")) +bp.add_url_rule("/pki/issue", view_func=PKIIssueView.as_view("pki_issue")) +bp.add_url_rule("/pki/certs", view_func=PKICertsView.as_view("pki_certs")) +bp.add_url_rule( + "/pki/revoke/", view_func=PKIRevokeView.as_view("pki_revoke"), methods=["POST"] +) diff --git a/app/config.py b/app/config.py index 5f418bb..4ba4c86 100644 --- a/app/config.py +++ b/app/config.py @@ -4,7 +4,7 @@ import os from pathlib import Path # Application version -VERSION = "1.1.0" +VERSION = "1.2.0" class Config: @@ -21,6 +21,8 @@ class Config: # Paste expiry (default 5 days) PASTE_EXPIRY_SECONDS = int(os.environ.get("FLASKPASTE_EXPIRY", 5 * 24 * 60 * 60)) + # Maximum custom expiry (default 30 days, 0 = use default expiry as max) + MAX_EXPIRY_SECONDS = int(os.environ.get("FLASKPASTE_MAX_EXPIRY", 30 * 24 * 60 * 60)) # Content deduplication / abuse prevention # Throttle repeated submissions of identical content @@ -54,6 +56,16 @@ class Config: # URL prefix for reverse proxy deployments (e.g., "/paste" for mymx.me/paste) URL_PREFIX = os.environ.get("FLASKPASTE_URL_PREFIX", "").rstrip("/") + # PKI Configuration + # Enable PKI endpoints for certificate authority and issuance + PKI_ENABLED = os.environ.get("FLASKPASTE_PKI_ENABLED", "0").lower() in ("1", "true", "yes") + # CA password for signing operations (REQUIRED when PKI is enabled) + PKI_CA_PASSWORD = os.environ.get("FLASKPASTE_PKI_CA_PASSWORD", "") + # Default validity period for issued certificates (days) + PKI_CERT_DAYS = int(os.environ.get("FLASKPASTE_PKI_CERT_DAYS", "365")) + # CA certificate validity period (days) + PKI_CA_DAYS = int(os.environ.get("FLASKPASTE_PKI_CA_DAYS", "3650")) # 10 years + class DevelopmentConfig(Config): """Development configuration.""" @@ -80,6 +92,12 @@ class TestingConfig(Config): # Disable PoW for most tests (easier testing) POW_DIFFICULTY = 0 + # PKI testing configuration + PKI_ENABLED = True + PKI_CA_PASSWORD = "test-ca-password" + PKI_CERT_DAYS = 30 + PKI_CA_DAYS = 365 + config = { "development": DevelopmentConfig, diff --git a/app/database.py b/app/database.py index 45218e0..0a33223 100644 --- a/app/database.py +++ b/app/database.py @@ -1,5 +1,7 @@ """Database connection and schema management.""" +import hashlib +import secrets import sqlite3 import time from pathlib import Path @@ -13,7 +15,10 @@ CREATE TABLE IF NOT EXISTS pastes ( mime_type TEXT NOT NULL DEFAULT 'text/plain', owner TEXT, created_at INTEGER NOT NULL, - last_accessed INTEGER NOT NULL + last_accessed INTEGER NOT NULL, + burn_after_read INTEGER NOT NULL DEFAULT 0, + expires_at INTEGER, + password_hash TEXT ); CREATE INDEX IF NOT EXISTS idx_pastes_created_at ON pastes(created_at); @@ -29,8 +34,86 @@ CREATE TABLE IF NOT EXISTS content_hashes ( ); CREATE INDEX IF NOT EXISTS idx_content_hashes_last_seen ON content_hashes(last_seen); + +-- PKI: Certificate Authority storage +CREATE TABLE IF NOT EXISTS certificate_authority ( + id TEXT PRIMARY KEY DEFAULT 'default', + common_name TEXT NOT NULL, + certificate_pem TEXT NOT NULL, + private_key_encrypted BLOB NOT NULL, + key_salt BLOB NOT NULL, + created_at INTEGER NOT NULL, + expires_at INTEGER NOT NULL, + key_algorithm TEXT NOT NULL, + owner TEXT +); + +-- PKI: Issued client certificates +CREATE TABLE IF NOT EXISTS issued_certificates ( + serial TEXT PRIMARY KEY, + ca_id TEXT NOT NULL DEFAULT 'default', + common_name TEXT NOT NULL, + fingerprint_sha1 TEXT NOT NULL UNIQUE, + certificate_pem TEXT NOT NULL, + created_at INTEGER NOT NULL, + expires_at INTEGER NOT NULL, + issued_to TEXT, + status TEXT NOT NULL DEFAULT 'valid', + revoked_at INTEGER, + FOREIGN KEY(ca_id) REFERENCES certificate_authority(id) ON DELETE CASCADE +); + +CREATE INDEX IF NOT EXISTS idx_certs_fingerprint ON issued_certificates(fingerprint_sha1); +CREATE INDEX IF NOT EXISTS idx_certs_status ON issued_certificates(status); +CREATE INDEX IF NOT EXISTS idx_certs_ca_id ON issued_certificates(ca_id); """ +# Password hashing constants +_HASH_ITERATIONS = 600000 # OWASP 2023 recommendation for PBKDF2-SHA256 +_SALT_LENGTH = 32 + + +def hash_password(password: str) -> str: + """Hash password using PBKDF2-HMAC-SHA256. + + Returns format: $pbkdf2-sha256$iterations$salt$hash + All values are hex-encoded. + """ + if not password: + return None + + salt = secrets.token_bytes(_SALT_LENGTH) + dk = hashlib.pbkdf2_hmac("sha256", password.encode("utf-8"), salt, _HASH_ITERATIONS) + return f"$pbkdf2-sha256${_HASH_ITERATIONS}${salt.hex()}${dk.hex()}" + + +def verify_password(password: str, password_hash: str) -> bool: + """Verify password against stored hash. + + Uses constant-time comparison to prevent timing attacks. + """ + if not password or not password_hash: + return False + + try: + # Parse hash format: $pbkdf2-sha256$iterations$salt$hash + parts = password_hash.split("$") + if len(parts) != 5 or parts[1] != "pbkdf2-sha256": + return False + + iterations = int(parts[2]) + salt = bytes.fromhex(parts[3]) + stored_hash = bytes.fromhex(parts[4]) + + # Compute hash of provided password + dk = hashlib.pbkdf2_hmac("sha256", password.encode("utf-8"), salt, iterations) + + # Constant-time comparison + return secrets.compare_digest(dk, stored_hash) + except (ValueError, IndexError): + return False + + # Hold reference for in-memory shared cache databases _memory_db_holder = None @@ -98,15 +181,27 @@ def init_db() -> None: def cleanup_expired_pastes() -> int: - """Delete pastes that haven't been accessed within expiry period. + """Delete pastes that have expired. + + Pastes expire based on: + - Custom expires_at timestamp if set + - Default expiry from last_accessed if expires_at is NULL Returns number of deleted pastes. """ expiry_seconds = current_app.config["PASTE_EXPIRY_SECONDS"] - cutoff = int(time.time()) - expiry_seconds + now = int(time.time()) + default_cutoff = now - expiry_seconds db = get_db() - cursor = db.execute("DELETE FROM pastes WHERE last_accessed < ?", (cutoff,)) + # Delete pastes with custom expiry that have passed, + # OR pastes without custom expiry that exceed default window + cursor = db.execute( + """DELETE FROM pastes WHERE + (expires_at IS NOT NULL AND expires_at < ?) + OR (expires_at IS NULL AND last_accessed < ?)""", + (now, default_cutoff), + ) db.commit() return cursor.rowcount @@ -146,15 +241,14 @@ def check_content_hash(content_hash: str) -> tuple[bool, int]: # Check existing hash record row = db.execute( - "SELECT count, last_seen FROM content_hashes WHERE hash = ?", - (content_hash,) + "SELECT count, last_seen FROM content_hashes WHERE hash = ?", (content_hash,) ).fetchone() if row is None: # First time seeing this content db.execute( "INSERT INTO content_hashes (hash, first_seen, last_seen, count) VALUES (?, ?, ?, 1)", - (content_hash, now, now) + (content_hash, now, now), ) db.commit() return True, 1 @@ -163,7 +257,7 @@ def check_content_hash(content_hash: str) -> tuple[bool, int]: # Outside window, reset counter db.execute( "UPDATE content_hashes SET first_seen = ?, last_seen = ?, count = 1 WHERE hash = ?", - (now, now, content_hash) + (now, now, content_hash), ) db.commit() return True, 1 @@ -178,7 +272,7 @@ def check_content_hash(content_hash: str) -> tuple[bool, int]: # Update counter db.execute( "UPDATE content_hashes SET last_seen = ?, count = ? WHERE hash = ?", - (now, current_count, content_hash) + (now, current_count, content_hash), ) db.commit() diff --git a/app/pki.py b/app/pki.py new file mode 100644 index 0000000..01bb982 --- /dev/null +++ b/app/pki.py @@ -0,0 +1,1019 @@ +"""Minimal PKI module for certificate authority and client certificate management. + +This module provides both standalone PKI functionality and Flask-integrated +database-backed functions. Core cryptographic functions have no Flask dependencies. + +Usage (standalone): + from app.pki import PKI + + pki = PKI(password="your-ca-password") + ca_info = pki.generate_ca("My CA") + cert_info = pki.issue_certificate("client-name") + +Usage (Flask routes): + from app.pki import generate_ca, issue_certificate, is_certificate_valid + + ca_info = generate_ca("My CA", password) # Saves to database + if is_certificate_valid(fingerprint): + # Certificate is valid and not revoked +""" + +import hashlib +import secrets +import time +from datetime import UTC, datetime, timedelta +from typing import Any + +# Cryptography imports (required dependency) +try: + from cryptography import x509 + from cryptography.hazmat.primitives import hashes, serialization + from cryptography.hazmat.primitives.asymmetric import ec + from cryptography.hazmat.primitives.ciphers.aead import AESGCM + from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC + from cryptography.x509.oid import ExtendedKeyUsageOID, NameOID + + HAS_CRYPTO = True +except ImportError: + HAS_CRYPTO = False + + +# Constants +_KDF_ITERATIONS = 600000 # OWASP 2023 recommendation +_SALT_LENGTH = 32 +_KEY_LENGTH = 32 # AES-256 + + +class PKIError(Exception): + """Base exception for PKI operations.""" + + pass + + +class CANotFoundError(PKIError): + """CA does not exist.""" + + pass + + +class CAExistsError(PKIError): + """CA already exists.""" + + pass + + +class CertificateNotFoundError(PKIError): + """Certificate not found.""" + + pass + + +class InvalidPasswordError(PKIError): + """Invalid CA password.""" + + pass + + +def _require_crypto() -> None: + """Raise if cryptography package is not available.""" + if not HAS_CRYPTO: + raise PKIError("PKI requires 'cryptography' package: pip install cryptography") + + +def derive_key(password: str, salt: bytes) -> bytes: + """Derive encryption key from password using PBKDF2. + + Args: + password: Password string + salt: Random salt bytes + + Returns: + 32-byte derived key for AES-256 + """ + _require_crypto() + kdf = PBKDF2HMAC( + algorithm=hashes.SHA256(), + length=_KEY_LENGTH, + salt=salt, + iterations=_KDF_ITERATIONS, + ) + return kdf.derive(password.encode("utf-8")) + + +def encrypt_private_key(private_key: Any, password: str) -> tuple[bytes, bytes]: + """Encrypt private key with password. + + Args: + private_key: EC or RSA private key object + password: Encryption password + + Returns: + Tuple of (encrypted_key_bytes, salt) + """ + _require_crypto() + + # Serialize private key to PEM (unencrypted) + key_pem = private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + + # Derive encryption key + salt = secrets.token_bytes(_SALT_LENGTH) + encryption_key = derive_key(password, salt) + + # Encrypt with AES-256-GCM + nonce = secrets.token_bytes(12) + aesgcm = AESGCM(encryption_key) + ciphertext = aesgcm.encrypt(nonce, key_pem, None) + + # Prepend nonce to ciphertext + encrypted = nonce + ciphertext + + return encrypted, salt + + +def decrypt_private_key(encrypted: bytes, salt: bytes, password: str) -> Any: + """Decrypt private key with password. + + Args: + encrypted: Encrypted key bytes (nonce + ciphertext) + salt: Salt used for key derivation + password: Decryption password + + Returns: + Private key object + + Raises: + InvalidPasswordError: If password is incorrect + """ + _require_crypto() + + if len(encrypted) < 12: + raise PKIError("Invalid encrypted key format") + + # Derive encryption key + encryption_key = derive_key(password, salt) + + # Decrypt + nonce = encrypted[:12] + ciphertext = encrypted[12:] + aesgcm = AESGCM(encryption_key) + + try: + key_pem = aesgcm.decrypt(nonce, ciphertext, None) + except Exception: + raise InvalidPasswordError("Invalid CA password") from None + + # Load private key + return serialization.load_pem_private_key(key_pem, password=None) + + +def calculate_fingerprint(certificate: Any) -> str: + """Calculate SHA1 fingerprint of certificate. + + SHA1 is used here for X.509 certificate fingerprints, which is standard + practice and not a security concern (fingerprints are identifiers, not + used for cryptographic security). + + Args: + certificate: X.509 certificate object + + Returns: + Lowercase hex fingerprint (40 characters) + """ + _require_crypto() + cert_der = certificate.public_bytes(serialization.Encoding.DER) + # SHA1 fingerprints are industry standard for X.509, not security-relevant + return hashlib.sha1(cert_der, usedforsecurity=False).hexdigest() + + +class PKI: + """Standalone PKI manager for CA and certificate operations. + + This class provides core PKI functionality without Flask dependencies. + Use get_pki() for Flask-integrated usage. + + Args: + password: CA password for signing operations + ca_days: CA certificate validity in days (default: 3650) + cert_days: Client certificate validity in days (default: 365) + """ + + def __init__( + self, + password: str, + ca_days: int = 3650, + cert_days: int = 365, + ): + _require_crypto() + self.password = password + self.ca_days = ca_days + self.cert_days = cert_days + + # In-memory storage (override for database backing) + self._ca_store: dict | None = None + self._certificates: dict[str, dict] = {} + + def has_ca(self) -> bool: + """Check if CA exists.""" + return self._ca_store is not None + + def generate_ca( + self, + common_name: str, + algorithm: str = "ec", + curve: str = "secp384r1", + ) -> dict: + """Generate a new Certificate Authority. + + Args: + common_name: CA common name (e.g., "FlaskPaste CA") + algorithm: Key algorithm ("ec" only for now) + curve: EC curve name (secp256r1, secp384r1, secp521r1) + + Returns: + Dict with CA info: id, common_name, fingerprint, expires_at, certificate_pem + + Raises: + CAExistsError: If CA already exists + """ + if self.has_ca(): + raise CAExistsError("CA already exists") + + # Generate EC key + curves = { + "secp256r1": ec.SECP256R1(), + "secp384r1": ec.SECP384R1(), + "secp521r1": ec.SECP521R1(), + } + if curve not in curves: + raise PKIError(f"Unsupported curve: {curve}") + + private_key = ec.generate_private_key(curves[curve]) + + # Build CA certificate + now = datetime.now(UTC) + subject = issuer = x509.Name( + [ + x509.NameAttribute(NameOID.COMMON_NAME, common_name), + x509.NameAttribute(NameOID.ORGANIZATION_NAME, "FlaskPaste PKI"), + ] + ) + + cert_builder = ( + x509.CertificateBuilder() + .subject_name(subject) + .issuer_name(issuer) + .public_key(private_key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(now) + .not_valid_after(now + timedelta(days=self.ca_days)) + .add_extension( + x509.BasicConstraints(ca=True, path_length=0), + critical=True, + ) + .add_extension( + x509.KeyUsage( + digital_signature=True, + key_cert_sign=True, + crl_sign=True, + key_encipherment=False, + content_commitment=False, + data_encipherment=False, + key_agreement=False, + encipher_only=False, + decipher_only=False, + ), + critical=True, + ) + ) + + certificate = cert_builder.sign(private_key, hashes.SHA256()) + + # Encrypt private key + encrypted_key, salt = encrypt_private_key(private_key, self.password) + + # Calculate fingerprint + fingerprint = calculate_fingerprint(certificate) + + # Store CA + expires_at = int((now + timedelta(days=self.ca_days)).timestamp()) + cert_pem = certificate.public_bytes(serialization.Encoding.PEM).decode("utf-8") + + self._ca_store = { + "id": "default", + "common_name": common_name, + "certificate_pem": cert_pem, + "private_key_encrypted": encrypted_key, + "key_salt": salt, + "created_at": int(now.timestamp()), + "expires_at": expires_at, + "key_algorithm": f"ec:{curve}", + "fingerprint": fingerprint, + "_private_key": private_key, # Cached for signing + "_certificate": certificate, + } + + return { + "id": "default", + "common_name": common_name, + "fingerprint": fingerprint, + "expires_at": expires_at, + "certificate_pem": cert_pem, + } + + def load_ca(self, cert_pem: str, key_pem: str, key_password: str | None = None) -> dict: + """Load existing CA from PEM files. + + Args: + cert_pem: CA certificate PEM string + key_pem: CA private key PEM string + key_password: Password for encrypted private key (if any) + + Returns: + Dict with CA info + """ + if self.has_ca(): + raise CAExistsError("CA already exists") + + # Load certificate + certificate = x509.load_pem_x509_certificate(cert_pem.encode("utf-8")) + + # Load private key + pwd = key_password.encode("utf-8") if key_password else None + private_key = serialization.load_pem_private_key(key_pem.encode("utf-8"), password=pwd) + + # Extract info + common_name = certificate.subject.get_attributes_for_oid(NameOID.COMMON_NAME)[0].value + fingerprint = calculate_fingerprint(certificate) + + # Re-encrypt with our password + encrypted_key, salt = encrypt_private_key(private_key, self.password) + + now = datetime.now(UTC) + expires_at = int(certificate.not_valid_after_utc.timestamp()) + + self._ca_store = { + "id": "default", + "common_name": common_name, + "certificate_pem": cert_pem, + "private_key_encrypted": encrypted_key, + "key_salt": salt, + "created_at": int(now.timestamp()), + "expires_at": expires_at, + "key_algorithm": "imported", + "fingerprint": fingerprint, + "_private_key": private_key, + "_certificate": certificate, + } + + return { + "id": "default", + "common_name": common_name, + "fingerprint": fingerprint, + "expires_at": expires_at, + "certificate_pem": cert_pem, + } + + def get_ca(self) -> dict: + """Get CA information. + + Returns: + Dict with CA info (without private key) + + Raises: + CANotFoundError: If no CA exists + """ + if not self.has_ca(): + raise CANotFoundError("No CA configured") + + return { + "id": self._ca_store["id"], + "common_name": self._ca_store["common_name"], + "fingerprint": self._ca_store["fingerprint"], + "expires_at": self._ca_store["expires_at"], + "certificate_pem": self._ca_store["certificate_pem"], + } + + def get_ca_certificate_pem(self) -> str: + """Get CA certificate in PEM format. + + Returns: + PEM-encoded CA certificate + """ + if not self.has_ca(): + raise CANotFoundError("No CA configured") + return self._ca_store["certificate_pem"] + + def _get_signing_key(self) -> tuple[Any, Any]: + """Get CA private key and certificate for signing. + + Returns: + Tuple of (private_key, certificate) + """ + if not self.has_ca(): + raise CANotFoundError("No CA configured") + + # Use cached key if available + if "_private_key" in self._ca_store: + return self._ca_store["_private_key"], self._ca_store["_certificate"] + + # Decrypt private key + private_key = decrypt_private_key( + self._ca_store["private_key_encrypted"], + self._ca_store["key_salt"], + self.password, + ) + + # Load certificate + certificate = x509.load_pem_x509_certificate( + self._ca_store["certificate_pem"].encode("utf-8") + ) + + # Cache for future use + self._ca_store["_private_key"] = private_key + self._ca_store["_certificate"] = certificate + + return private_key, certificate + + def issue_certificate( + self, + common_name: str, + days: int | None = None, + algorithm: str = "ec", + curve: str = "secp384r1", + ) -> dict: + """Issue a new client certificate. + + Args: + common_name: Client common name + days: Validity period (default: self.cert_days) + algorithm: Key algorithm ("ec") + curve: EC curve name + + Returns: + Dict with certificate info including private key PEM + """ + if days is None: + days = self.cert_days + + ca_key, ca_cert = self._get_signing_key() + + # Generate client key + curves = { + "secp256r1": ec.SECP256R1(), + "secp384r1": ec.SECP384R1(), + "secp521r1": ec.SECP521R1(), + } + if curve not in curves: + raise PKIError(f"Unsupported curve: {curve}") + + client_key = ec.generate_private_key(curves[curve]) + + # Build certificate + now = datetime.now(UTC) + serial = x509.random_serial_number() + + subject = x509.Name( + [ + x509.NameAttribute(NameOID.COMMON_NAME, common_name), + ] + ) + + cert_builder = ( + x509.CertificateBuilder() + .subject_name(subject) + .issuer_name(ca_cert.subject) + .public_key(client_key.public_key()) + .serial_number(serial) + .not_valid_before(now) + .not_valid_after(now + timedelta(days=days)) + .add_extension( + x509.BasicConstraints(ca=False, path_length=None), + critical=True, + ) + .add_extension( + x509.KeyUsage( + digital_signature=True, + key_encipherment=True, + content_commitment=False, + data_encipherment=False, + key_agreement=False, + key_cert_sign=False, + crl_sign=False, + encipher_only=False, + decipher_only=False, + ), + critical=True, + ) + .add_extension( + x509.ExtendedKeyUsage([ExtendedKeyUsageOID.CLIENT_AUTH]), + critical=False, + ) + ) + + certificate = cert_builder.sign(ca_key, hashes.SHA256()) + + # Serialize + cert_pem = certificate.public_bytes(serialization.Encoding.PEM).decode("utf-8") + key_pem = client_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ).decode("utf-8") + + # Calculate fingerprint + fingerprint = calculate_fingerprint(certificate) + serial_hex = format(serial, "032x") + expires_at = int((now + timedelta(days=days)).timestamp()) + + # Store certificate record + cert_record = { + "serial": serial_hex, + "ca_id": "default", + "common_name": common_name, + "fingerprint_sha1": fingerprint, + "certificate_pem": cert_pem, + "created_at": int(now.timestamp()), + "expires_at": expires_at, + "status": "valid", + "revoked_at": None, + } + self._certificates[fingerprint] = cert_record + + return { + "serial": serial_hex, + "common_name": common_name, + "fingerprint": fingerprint, + "expires_at": expires_at, + "certificate_pem": cert_pem, + "private_key_pem": key_pem, + "ca_certificate_pem": self._ca_store["certificate_pem"], + } + + def revoke_certificate(self, fingerprint: str) -> bool: + """Revoke a certificate by fingerprint. + + Args: + fingerprint: SHA1 fingerprint (40 hex chars) + + Returns: + True if revoked, False if already revoked + """ + fingerprint = fingerprint.lower() + if fingerprint not in self._certificates: + raise CertificateNotFoundError(f"Certificate not found: {fingerprint}") + + cert = self._certificates[fingerprint] + if cert["status"] == "revoked": + return False + + cert["status"] = "revoked" + cert["revoked_at"] = int(time.time()) + return True + + def is_valid(self, fingerprint: str) -> bool: + """Check if fingerprint is valid (exists and not revoked). + + Args: + fingerprint: SHA1 fingerprint + + Returns: + True if valid, False otherwise + """ + fingerprint = fingerprint.lower() + if fingerprint not in self._certificates: + return True # Unknown fingerprints are allowed (external certs) + + cert = self._certificates[fingerprint] + if cert["status"] != "valid": + return False + + # Check expiry + return cert["expires_at"] >= int(time.time()) + + def get_certificate(self, fingerprint: str) -> dict | None: + """Get certificate info by fingerprint. + + Args: + fingerprint: SHA1 fingerprint + + Returns: + Certificate info dict or None if not found + """ + fingerprint = fingerprint.lower() + cert = self._certificates.get(fingerprint) + if cert is None: + return None + + return { + "serial": cert["serial"], + "common_name": cert["common_name"], + "fingerprint": cert["fingerprint_sha1"], + "expires_at": cert["expires_at"], + "status": cert["status"], + "created_at": cert["created_at"], + "revoked_at": cert["revoked_at"], + } + + def list_certificates(self) -> list[dict]: + """List all issued certificates. + + Returns: + List of certificate info dicts + """ + return [ + { + "serial": c["serial"], + "common_name": c["common_name"], + "fingerprint": c["fingerprint_sha1"], + "expires_at": c["expires_at"], + "status": c["status"], + "created_at": c["created_at"], + } + for c in self._certificates.values() + ] + + +def reset_pki() -> None: + """Reset PKI state (for testing). + + This is a no-op since routes use direct database functions. + Kept for backward compatibility with tests. + """ + pass + + +def is_certificate_valid(fingerprint: str) -> bool: + """Check if fingerprint is valid for authentication. + + This is the main integration point for routes.py. + Unknown fingerprints are considered valid (external certs). + Queries database directly to ensure fresh revocation status. + + Args: + fingerprint: SHA1 fingerprint + + Returns: + True if valid or unknown, False if revoked/expired + """ + from flask import current_app + + if not current_app.config.get("PKI_ENABLED"): + return True + + from app.database import get_db + + fingerprint = fingerprint.lower() + db = get_db() + + # Query database directly for fresh revocation status + row = db.execute( + "SELECT status, expires_at FROM issued_certificates WHERE fingerprint_sha1 = ?", + (fingerprint,), + ).fetchone() + + if row is None: + # Unknown fingerprint (external cert or not issued by us) - allow + return True + + # Check status + if row["status"] != "valid": + return False + + # Check expiry + return row["expires_at"] >= int(time.time()) + + +# ───────────────────────────────────────────────────────────────────────────── +# Flask Route Helper Functions +# ───────────────────────────────────────────────────────────────────────────── + + +def get_ca_info() -> dict | None: + """Get CA information for status endpoint. + + Returns: + Dict with CA info or None if no CA exists + """ + from flask import current_app + + from app.database import get_db + + if not current_app.config.get("PKI_ENABLED"): + return None + + db = get_db() + row = db.execute( + """SELECT id, common_name, certificate_pem, created_at, expires_at, key_algorithm + FROM certificate_authority WHERE id = 'default'""" + ).fetchone() + + if row is None: + return None + + # Calculate fingerprint + cert = x509.load_pem_x509_certificate(row["certificate_pem"].encode("utf-8")) + fingerprint = calculate_fingerprint(cert) + + return { + "id": row["id"], + "common_name": row["common_name"], + "certificate_pem": row["certificate_pem"], + "fingerprint_sha1": fingerprint, + "created_at": row["created_at"], + "expires_at": row["expires_at"], + "key_algorithm": row["key_algorithm"], + } + + +def generate_ca( + common_name: str, + password: str, + days: int = 3650, + owner: str | None = None, +) -> dict: + """Generate CA certificate and save to database. + + Args: + common_name: CA common name + password: CA password for encrypting private key + days: Validity period in days + owner: Optional owner fingerprint + + Returns: + Dict with CA info including fingerprint_sha1 + + Raises: + CAExistsError: If CA already exists + """ + _require_crypto() + from app.database import get_db + + db = get_db() + + # Check if CA already exists + existing = db.execute("SELECT id FROM certificate_authority WHERE id = 'default'").fetchone() + if existing: + raise CAExistsError("CA already exists") + + # Generate EC key + curve = ec.SECP384R1() + private_key = ec.generate_private_key(curve) + + # Build CA certificate + now = datetime.now(UTC) + subject = issuer = x509.Name( + [ + x509.NameAttribute(NameOID.COMMON_NAME, common_name), + x509.NameAttribute(NameOID.ORGANIZATION_NAME, "FlaskPaste PKI"), + ] + ) + + cert_builder = ( + x509.CertificateBuilder() + .subject_name(subject) + .issuer_name(issuer) + .public_key(private_key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(now) + .not_valid_after(now + timedelta(days=days)) + .add_extension( + x509.BasicConstraints(ca=True, path_length=0), + critical=True, + ) + .add_extension( + x509.KeyUsage( + digital_signature=True, + key_cert_sign=True, + crl_sign=True, + key_encipherment=False, + content_commitment=False, + data_encipherment=False, + key_agreement=False, + encipher_only=False, + decipher_only=False, + ), + critical=True, + ) + ) + + certificate = cert_builder.sign(private_key, hashes.SHA256()) + + # Encrypt private key + encrypted_key, salt = encrypt_private_key(private_key, password) + + # Calculate fingerprint + fingerprint = calculate_fingerprint(certificate) + + # Serialize + cert_pem = certificate.public_bytes(serialization.Encoding.PEM).decode("utf-8") + created_at = int(now.timestamp()) + expires_at = int((now + timedelta(days=days)).timestamp()) + + # Save to database + db.execute( + """INSERT INTO certificate_authority + (id, common_name, certificate_pem, private_key_encrypted, + key_salt, created_at, expires_at, key_algorithm, owner) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""", + ( + "default", + common_name, + cert_pem, + encrypted_key, + salt, + created_at, + expires_at, + "ec:secp384r1", + owner, + ), + ) + db.commit() + + # Reset cached PKI instance + reset_pki() + + return { + "id": "default", + "common_name": common_name, + "fingerprint_sha1": fingerprint, + "certificate_pem": cert_pem, + "created_at": created_at, + "expires_at": expires_at, + } + + +def issue_certificate( + common_name: str, + password: str, + days: int = 365, + issued_to: str | None = None, +) -> dict: + """Issue a client certificate signed by the CA. + + Args: + common_name: Client common name + password: CA password for signing + days: Validity period in days + issued_to: Optional fingerprint of issuing user + + Returns: + Dict with certificate and private key PEM + + Raises: + CANotFoundError: If no CA exists + """ + _require_crypto() + from app.database import get_db + + db = get_db() + + # Load CA + ca_row = db.execute( + """SELECT certificate_pem, private_key_encrypted, key_salt + FROM certificate_authority WHERE id = 'default'""" + ).fetchone() + + if ca_row is None: + raise CANotFoundError("No CA configured") + + # Decrypt CA private key + ca_key = decrypt_private_key( + ca_row["private_key_encrypted"], + ca_row["key_salt"], + password, + ) + ca_cert = x509.load_pem_x509_certificate(ca_row["certificate_pem"].encode("utf-8")) + + # Generate client key + curve = ec.SECP384R1() + client_key = ec.generate_private_key(curve) + + # Build certificate + now = datetime.now(UTC) + serial = x509.random_serial_number() + + subject = x509.Name( + [ + x509.NameAttribute(NameOID.COMMON_NAME, common_name), + ] + ) + + cert_builder = ( + x509.CertificateBuilder() + .subject_name(subject) + .issuer_name(ca_cert.subject) + .public_key(client_key.public_key()) + .serial_number(serial) + .not_valid_before(now) + .not_valid_after(now + timedelta(days=days)) + .add_extension( + x509.BasicConstraints(ca=False, path_length=None), + critical=True, + ) + .add_extension( + x509.KeyUsage( + digital_signature=True, + key_encipherment=True, + content_commitment=False, + data_encipherment=False, + key_agreement=False, + key_cert_sign=False, + crl_sign=False, + encipher_only=False, + decipher_only=False, + ), + critical=True, + ) + .add_extension( + x509.ExtendedKeyUsage([ExtendedKeyUsageOID.CLIENT_AUTH]), + critical=False, + ) + ) + + certificate = cert_builder.sign(ca_key, hashes.SHA256()) + + # Serialize + cert_pem = certificate.public_bytes(serialization.Encoding.PEM).decode("utf-8") + key_pem = client_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ).decode("utf-8") + + # Calculate fingerprint + fingerprint = calculate_fingerprint(certificate) + serial_hex = format(serial, "032x") + created_at = int(now.timestamp()) + expires_at = int((now + timedelta(days=days)).timestamp()) + + # Save to database + db.execute( + """INSERT INTO issued_certificates + (serial, ca_id, common_name, fingerprint_sha1, certificate_pem, + created_at, expires_at, issued_to, status, revoked_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", + ( + serial_hex, + "default", + common_name, + fingerprint, + cert_pem, + created_at, + expires_at, + issued_to, + "valid", + None, + ), + ) + db.commit() + + return { + "serial": serial_hex, + "common_name": common_name, + "fingerprint_sha1": fingerprint, + "certificate_pem": cert_pem, + "private_key_pem": key_pem, + "created_at": created_at, + "expires_at": expires_at, + } + + +def revoke_certificate(serial: str) -> bool: + """Revoke a certificate by serial number. + + Args: + serial: Certificate serial number (hex) + + Returns: + True if revoked + + Raises: + CertificateNotFoundError: If certificate not found + """ + from app.database import get_db + + db = get_db() + + # Check exists + row = db.execute( + "SELECT status FROM issued_certificates WHERE serial = ?", (serial,) + ).fetchone() + + if row is None: + raise CertificateNotFoundError(f"Certificate not found: {serial}") + + if row["status"] == "revoked": + return False # Already revoked + + # Revoke + db.execute( + "UPDATE issued_certificates SET status = 'revoked', revoked_at = ? WHERE serial = ?", + (int(time.time()), serial), + ) + db.commit() + + return True diff --git a/documentation/api.md b/documentation/api.md index 9b813f0..cdd5039 100644 --- a/documentation/api.md +++ b/documentation/api.md @@ -103,7 +103,7 @@ Host: localhost:5000 ```json { "name": "FlaskPaste", - "version": "1.1.0", + "version": "1.2.0", "endpoints": { "GET /": "API information", "GET /health": "Health check", @@ -159,6 +159,45 @@ X-PoW-Solution: 12345678 Hello, World! ``` +**Request (Burn-After-Read):** + +Create a paste that deletes itself after first retrieval: + +```http +POST / HTTP/1.1 +Host: localhost:5000 +Content-Type: text/plain +X-Burn-After-Read: true + +Secret message +``` + +**Request (Custom Expiry):** + +Create a paste with custom expiry time (in seconds): + +```http +POST / HTTP/1.1 +Host: localhost:5000 +Content-Type: text/plain +X-Expiry: 3600 + +Expires in 1 hour +``` + +**Request (Password Protected):** + +Create a paste that requires a password to access: + +```http +POST / HTTP/1.1 +Host: localhost:5000 +Content-Type: text/plain +X-Paste-Password: secretpassword + +Password protected content +``` + **Response (201 Created):** ```json { @@ -167,7 +206,10 @@ Hello, World! "raw": "/abc12345/raw", "mime_type": "text/plain", "created_at": 1700000000, - "owner": "a1b2c3..." // Only present if authenticated + "owner": "a1b2c3...", // Only present if authenticated + "burn_after_read": true, // Only present if enabled + "expires_at": 1700003600, // Only present if custom expiry set + "password_protected": true // Only present if password set } ``` @@ -175,6 +217,7 @@ Hello, World! | Code | Description | |------|-------------| | 400 | No content provided | +| 400 | Password too long (max 1024 chars) | | 400 | Proof-of-work required (when PoW enabled) | | 400 | Proof-of-work failed (invalid/expired challenge) | | 413 | Paste too large | @@ -198,6 +241,13 @@ GET /abc12345 HTTP/1.1 Host: localhost:5000 ``` +**Request (Password Protected):** +```http +GET /abc12345 HTTP/1.1 +Host: localhost:5000 +X-Paste-Password: secretpassword +``` + **Response (200 OK):** ```json { @@ -205,7 +255,8 @@ Host: localhost:5000 "mime_type": "text/plain", "size": 1234, "created_at": 1700000000, - "raw": "/abc12345/raw" + "raw": "/abc12345/raw", + "password_protected": true // Only present if protected } ``` @@ -213,6 +264,8 @@ Host: localhost:5000 | Code | Description | |------|-------------| | 400 | Invalid paste ID format | +| 401 | Password required | +| 403 | Invalid password | | 404 | Paste not found | --- @@ -229,6 +282,13 @@ GET /abc12345/raw HTTP/1.1 Host: localhost:5000 ``` +**Request (Password Protected):** +```http +GET /abc12345/raw HTTP/1.1 +Host: localhost:5000 +X-Paste-Password: secretpassword +``` + **Response (200 OK):** ```http HTTP/1.1 200 OK @@ -245,6 +305,8 @@ Content-Disposition: inline | Code | Description | |------|-------------| | 400 | Invalid paste ID format | +| 401 | Password required | +| 403 | Invalid password | | 404 | Paste not found | --- @@ -306,6 +368,135 @@ Pastes expire based on last access time (default: 5 days). - Cleanup runs automatically (hourly, throttled) - Configurable via `FLASKPASTE_EXPIRY` environment variable +**Custom Expiry:** + +Pastes can have custom expiry times using the `X-Expiry` header: + +```bash +# Paste expires in 1 hour +curl -H "X-Expiry: 3600" --data-binary @file.txt http://host/ +``` + +**Configuration:** +```bash +export FLASKPASTE_EXPIRY=432000 # Default expiry: 5 days +export FLASKPASTE_MAX_EXPIRY=2592000 # Max custom expiry: 30 days +``` + +**Notes:** +- Custom expiry is capped at `FLASKPASTE_MAX_EXPIRY` +- Invalid or negative values are ignored (uses default) +- Response includes `expires_at` timestamp when custom expiry is set + +--- + +## Burn-After-Read + +Single-access pastes that delete themselves after first retrieval. + +**How it works:** +- Set `X-Burn-After-Read: true` header on creation +- First `GET /{id}/raw` returns content and deletes paste +- Subsequent requests return 404 +- Metadata `GET /{id}` does not trigger burn +- `HEAD` requests do not trigger burn + +**Usage:** +```bash +# Create burn-after-read paste +curl -H "X-Burn-After-Read: true" --data-binary @secret.txt http://host/ + +# Response indicates burn is enabled +{ + "id": "abc12345", + "burn_after_read": true, + ... +} +``` + +**Notes:** +- Response includes `X-Burn-After-Read: true` header when content is retrieved +- Can be combined with custom expiry (paste expires OR burns, whichever first) +- Accepted values: `true`, `1`, `yes` (case-insensitive) + +--- + +## Password Protection + +Pastes can be protected with a password using PBKDF2-HMAC-SHA256 hashing. + +**Creating a protected paste:** +```http +POST / HTTP/1.1 +Host: localhost:5000 +Content-Type: text/plain +X-Paste-Password: mysecretpassword + +Protected content here +``` + +**Response (201 Created):** +```json +{ + "id": "abc12345", + "url": "/abc12345", + "raw": "/abc12345/raw", + "mime_type": "text/plain", + "created_at": 1700000000, + "password_protected": true +} +``` + +**Accessing protected paste:** +```http +GET /abc12345 HTTP/1.1 +Host: localhost:5000 +X-Paste-Password: mysecretpassword +``` + +**Errors:** +| Code | Description | +|------|-------------| +| 400 | Password too long (max 1024 chars) | +| 401 | Password required | +| 403 | Invalid password | + +**Response (401 Unauthorized):** +```json +{ + "error": "Password required", + "password_protected": true +} +``` + +**Response (403 Forbidden):** +```json +{ + "error": "Invalid password" +} +``` + +**Security Implementation:** +- PBKDF2-HMAC-SHA256 with 600,000 iterations (OWASP 2023) +- 32-byte random salt per password +- Constant-time comparison prevents timing attacks +- Passwords never logged or stored in plaintext +- Maximum 1024 characters to prevent DoS via expensive hashing + +**CLI Usage:** +```bash +# Create protected paste +./fpaste create -p "mypassword" secret.txt + +# Retrieve protected paste +./fpaste get -p "mypassword" abc12345 +``` + +**Notes:** +- Password protection can be combined with burn-after-read and custom expiry +- Unicode passwords are supported +- Special characters are allowed + --- ## Abuse Prevention @@ -513,3 +704,235 @@ http-request set-header X-Proxy-Secret your-secret-value ``` If the secret doesn't match, authentication headers (`X-SSL-Client-SHA1`) are ignored and the request is treated as anonymous. + +--- + +## PKI (Certificate Authority) + +FlaskPaste includes an optional minimal PKI for issuing client certificates. + +### Configuration + +Enable PKI via environment variables: + +```bash +export FLASKPASTE_PKI_ENABLED=1 # Enable PKI endpoints +export FLASKPASTE_PKI_CA_PASSWORD="secret" # Required: CA password +export FLASKPASTE_PKI_CERT_DAYS=365 # Client certificate validity (days) +export FLASKPASTE_PKI_CA_DAYS=3650 # CA certificate validity (days) +``` + +### Certificate Revocation + +When PKI is enabled, certificates issued by the CA are tracked. Revoked certificates are rejected during authentication (treated as anonymous). + +--- + +### GET /pki + +Get PKI status and CA information. + +**Request:** +```http +GET /pki HTTP/1.1 +Host: localhost:5000 +``` + +**Response (PKI disabled):** +```json +{ + "enabled": false +} +``` + +**Response (PKI enabled, no CA):** +```json +{ + "enabled": true, + "ca_exists": false, + "hint": "POST /pki/ca to generate CA" +} +``` + +**Response (PKI enabled with CA):** +```json +{ + "enabled": true, + "ca_exists": true, + "common_name": "FlaskPaste CA", + "fingerprint_sha1": "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2", + "created_at": 1700000000, + "expires_at": 2015000000, + "key_algorithm": "ec:secp384r1" +} +``` + +--- + +### POST /pki/ca + +Generate a new Certificate Authority. Only works once (first-run bootstrap). + +**Request:** +```http +POST /pki/ca HTTP/1.1 +Host: localhost:5000 +Content-Type: application/json + +{"common_name": "My CA"} +``` + +**Response (201 Created):** +```json +{ + "message": "CA generated", + "common_name": "My CA", + "fingerprint_sha1": "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2", + "created_at": 1700000000, + "expires_at": 2015000000, + "download": "/pki/ca.crt" +} +``` + +**Errors:** +| Code | Description | +|------|-------------| +| 404 | PKI not enabled | +| 409 | CA already exists | +| 500 | PKI_CA_PASSWORD not configured | + +--- + +### GET /pki/ca.crt + +Download CA certificate in PEM format (for trust store). + +**Request:** +```http +GET /pki/ca.crt HTTP/1.1 +Host: localhost:5000 +``` + +**Response (200 OK):** +``` +-----BEGIN CERTIFICATE----- +MIICxDCCAaygAwIBAgIUY... +-----END CERTIFICATE----- +``` + +Content-Type: `application/x-pem-file` + +**Errors:** +| Code | Description | +|------|-------------| +| 404 | PKI not enabled or CA not initialized | + +--- + +### POST /pki/issue + +Issue a new client certificate (open registration). + +**Request:** +```http +POST /pki/issue HTTP/1.1 +Host: localhost:5000 +Content-Type: application/json + +{"common_name": "alice"} +``` + +**Response (201 Created):** +```json +{ + "message": "Certificate issued", + "serial": "00000000000000000000000000000001", + "common_name": "alice", + "fingerprint_sha1": "b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3", + "created_at": 1700000000, + "expires_at": 1731536000, + "certificate_pem": "-----BEGIN CERTIFICATE-----\n...\n-----END CERTIFICATE-----\n", + "private_key_pem": "-----BEGIN PRIVATE KEY-----\n...\n-----END PRIVATE KEY-----\n" +} +``` + +**Errors:** +| Code | Description | +|------|-------------| +| 400 | common_name required | +| 404 | PKI not enabled or CA not initialized | +| 500 | Certificate issuance failed | + +**Security Notes:** +- Private key is generated server-side and returned in response +- Store the private key securely; it is not recoverable +- The certificate can be used with nginx, HAProxy, or curl for mTLS + +--- + +### GET /pki/certs + +List certificates (own certificates only). + +**Request (authenticated):** +```http +GET /pki/certs HTTP/1.1 +Host: localhost:5000 +X-SSL-Client-SHA1: a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2 +``` + +**Response (200 OK):** +```json +{ + "certificates": [ + { + "serial": "00000000000000000000000000000001", + "common_name": "alice", + "fingerprint_sha1": "b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3", + "created_at": 1700000000, + "expires_at": 1731536000, + "status": "valid", + "issued_to": "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2" + } + ], + "count": 1 +} +``` + +**Notes:** +- Anonymous users receive an empty list +- Users see certificates they issued or certificates that match their fingerprint + +--- + +### POST /pki/revoke/{serial} + +Revoke a certificate by serial number. Requires authentication and ownership. + +**Request:** +```http +POST /pki/revoke/00000000000000000000000000000001 HTTP/1.1 +Host: localhost:5000 +X-SSL-Client-SHA1: a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2 +``` + +**Response (200 OK):** +```json +{ + "message": "Certificate revoked", + "serial": "00000000000000000000000000000001" +} +``` + +**Errors:** +| Code | Description | +|------|-------------| +| 401 | Authentication required | +| 403 | Permission denied (not owner) | +| 404 | Certificate not found or PKI not enabled | +| 409 | Certificate already revoked | + +**Authorization:** +- Must be authenticated +- Can revoke certificates you issued (issued_to matches your fingerprint) +- Can revoke your own certificate (fingerprint matches) diff --git a/fpaste b/fpaste index c7c0ac4..22d9e65 100755 --- a/fpaste +++ b/fpaste @@ -6,14 +6,20 @@ import base64 import hashlib import json import os +import ssl import sys import urllib.error import urllib.request +from datetime import UTC, datetime, timedelta from pathlib import Path -# Optional encryption support +# Optional cryptography support (for encryption and cert generation) try: + from cryptography import x509 + from cryptography.hazmat.primitives import hashes, serialization + from cryptography.hazmat.primitives.asymmetric import ec, rsa from cryptography.hazmat.primitives.ciphers.aead import AESGCM + from cryptography.x509.oid import NameOID HAS_CRYPTO = True except ImportError: @@ -25,6 +31,9 @@ def get_config(): config = { "server": os.environ.get("FLASKPASTE_SERVER", "http://localhost:5000"), "cert_sha1": os.environ.get("FLASKPASTE_CERT_SHA1", ""), + "client_cert": os.environ.get("FLASKPASTE_CLIENT_CERT", ""), + "client_key": os.environ.get("FLASKPASTE_CLIENT_KEY", ""), + "ca_cert": os.environ.get("FLASKPASTE_CA_CERT", ""), } # Try config file @@ -40,17 +49,51 @@ def get_config(): config["server"] = value elif key == "cert_sha1": config["cert_sha1"] = value + elif key == "client_cert": + config["client_cert"] = value + elif key == "client_key": + config["client_key"] = value + elif key == "ca_cert": + config["ca_cert"] = value return config -def request(url, method="GET", data=None, headers=None): +def create_ssl_context(config): + """Create SSL context for mTLS if certificates are configured.""" + client_cert = config.get("client_cert", "") + client_key = config.get("client_key", "") + ca_cert = config.get("ca_cert", "") + + if not client_cert: + return None + + ctx = ssl.create_default_context() + + # Load CA certificate if specified + if ca_cert: + ctx.load_verify_locations(ca_cert) + + # Load client certificate and key + try: + ctx.load_cert_chain(certfile=client_cert, keyfile=client_key or None) + except ssl.SSLError as e: + die(f"failed to load client certificate: {e}") + except FileNotFoundError as e: + die(f"certificate file not found: {e}") + + return ctx + + +def request(url, method="GET", data=None, headers=None, ssl_context=None): """Make HTTP request and return response.""" headers = headers or {} - req = urllib.request.Request(url, data=data, headers=headers, method=method) + # User-configured server URL, audit is expected + req = urllib.request.Request(url, data=data, headers=headers, method=method) # noqa: S310 try: - with urllib.request.urlopen(req, timeout=30) as resp: + # User-configured server URL, audit is expected + with urllib.request.urlopen(req, timeout=30, context=ssl_context) as resp: # noqa: S310 return resp.status, resp.read(), dict(resp.headers) except urllib.error.HTTPError as e: return e.code, e.read(), dict(e.headers) @@ -120,11 +163,11 @@ def solve_pow(nonce, difficulty): # Count leading zero bits zero_bits = 0 - for byte in hash_bytes[:target_bytes + 1]: + for byte in hash_bytes[: target_bytes + 1]: if byte == 0: zero_bits += 8 else: - zero_bits += (8 - byte.bit_length()) + zero_bits += 8 - byte.bit_length() break if zero_bits >= difficulty: @@ -141,7 +184,7 @@ def solve_pow(nonce, difficulty): def get_challenge(config): """Fetch PoW challenge from server.""" url = config["server"].rstrip("/") + "/challenge" - status, body, _ = request(url) + status, body, _ = request(url, ssl_context=config.get("ssl_context")) if status != 200: return None @@ -186,6 +229,18 @@ def cmd_create(args, config): if config["cert_sha1"]: headers["X-SSL-Client-SHA1"] = config["cert_sha1"] + # Add burn-after-read header + if args.burn: + headers["X-Burn-After-Read"] = "true" + + # Add custom expiry header + if args.expiry: + headers["X-Expiry"] = str(args.expiry) + + # Add password header + if args.password: + headers["X-Paste-Password"] = args.password + # Get and solve PoW challenge if required challenge = get_challenge(config) if challenge: @@ -198,7 +253,9 @@ def cmd_create(args, config): headers["X-PoW-Solution"] = str(solution) url = config["server"].rstrip("/") + "/" - status, body, _ = request(url, method="POST", data=content, headers=headers) + status, body, _ = request( + url, method="POST", data=content, headers=headers, ssl_context=config.get("ssl_context") + ) if status == 201: data = json.loads(body) @@ -236,9 +293,14 @@ def cmd_get(args, config): paste_id = url_input.split("/")[-1] # Handle full URLs base = config["server"].rstrip("/") + # Build headers for password-protected pastes + headers = {} + if args.password: + headers["X-Paste-Password"] = args.password + if args.meta: url = f"{base}/{paste_id}" - status, body, _ = request(url) + status, body, _ = request(url, headers=headers, ssl_context=config.get("ssl_context")) if status == 200: data = json.loads(body) print(f"id: {data['id']}") @@ -246,12 +308,19 @@ def cmd_get(args, config): print(f"size: {data['size']}") print(f"created_at: {data['created_at']}") if encryption_key: - print(f"encrypted: yes (key in URL)") + print("encrypted: yes (key in URL)") + if data.get("password_protected"): + print("protected: yes (password required)") + elif status == 401: + die("password required (-p)") + elif status == 403: + die("invalid password") else: die(f"not found: {paste_id}") else: url = f"{base}/{paste_id}/raw" - status, body, headers = request(url) + ssl_ctx = config.get("ssl_context") + status, body, _ = request(url, headers=headers, ssl_context=ssl_ctx) if status == 200: # Decrypt if encryption key was provided if encryption_key: @@ -266,6 +335,10 @@ def cmd_get(args, config): # Add newline if content doesn't end with one and stdout is tty if sys.stdout.isatty() and body and not body.endswith(b"\n"): sys.stdout.buffer.write(b"\n") + elif status == 401: + die("password required (-p)") + elif status == 403: + die("invalid password") else: die(f"not found: {paste_id}") @@ -280,7 +353,9 @@ def cmd_delete(args, config): url = f"{base}/{paste_id}" headers = {"X-SSL-Client-SHA1": config["cert_sha1"]} - status, body, _ = request(url, method="DELETE", headers=headers) + status, _, _ = request( + url, method="DELETE", headers=headers, ssl_context=config.get("ssl_context") + ) if status == 200: print(f"deleted: {paste_id}") @@ -297,7 +372,7 @@ def cmd_delete(args, config): def cmd_info(args, config): """Show server info.""" url = config["server"].rstrip("/") + "/" - status, body, _ = request(url) + status, body, _ = request(url, ssl_context=config.get("ssl_context")) if status == 200: data = json.loads(body) @@ -308,21 +383,337 @@ def cmd_info(args, config): die("failed to connect to server") +def cmd_pki_status(args, config): + """Show PKI status and CA information.""" + url = config["server"].rstrip("/") + "/pki" + status, body, _ = request(url, ssl_context=config.get("ssl_context")) + + if status == 404: + die("PKI not enabled on this server") + elif status != 200: + die(f"failed to get PKI status ({status})") + + data = json.loads(body) + + print(f"pki enabled: {data.get('enabled', False)}") + print(f"ca exists: {data.get('ca_exists', False)}") + + if data.get("ca_exists"): + print(f"common name: {data.get('common_name', 'unknown')}") + print(f"fingerprint: {data.get('fingerprint_sha1', 'unknown')}") + if data.get("created_at"): + print(f"created: {data.get('created_at')}") + if data.get("expires_at"): + print(f"expires: {data.get('expires_at')}") + print(f"download: {config['server'].rstrip('/')}{data.get('download', '/pki/ca.crt')}") + elif data.get("hint"): + print(f"hint: {data.get('hint')}") + + +def cmd_pki_issue(args, config): + """Request a new client certificate from the server CA.""" + url = config["server"].rstrip("/") + "/pki/issue" + + headers = {"Content-Type": "application/json"} + if config["cert_sha1"]: + headers["X-SSL-Client-SHA1"] = config["cert_sha1"] + + payload = {"common_name": args.name} + data = json.dumps(payload).encode() + + status, body, _ = request( + url, method="POST", data=data, headers=headers, ssl_context=config.get("ssl_context") + ) + + if status == 404: + # Could be PKI disabled or no CA + try: + err = json.loads(body).get("error", "PKI not available") + except (json.JSONDecodeError, UnicodeDecodeError): + err = "PKI not available" + die(err) + elif status == 400: + try: + err = json.loads(body).get("error", "bad request") + except (json.JSONDecodeError, UnicodeDecodeError): + err = "bad request" + die(err) + elif status != 201: + die(f"certificate issuance failed ({status})") + + result = json.loads(body) + + # Determine output directory + out_dir = Path(args.output) if args.output else Path.home() / ".config" / "fpaste" + out_dir.mkdir(parents=True, exist_ok=True) + + # File paths + key_file = out_dir / "client.key" + cert_file = out_dir / "client.crt" + + # Check for existing files + if not args.force: + if key_file.exists(): + die(f"key file exists: {key_file} (use --force)") + if cert_file.exists(): + die(f"cert file exists: {cert_file} (use --force)") + + # Write files + key_file.write_text(result["private_key_pem"]) + key_file.chmod(0o600) + cert_file.write_text(result["certificate_pem"]) + + fingerprint = result.get("fingerprint_sha1", "unknown") + + print(f"key: {key_file}", file=sys.stderr) + print(f"certificate: {cert_file}", file=sys.stderr) + print(f"fingerprint: {fingerprint}", file=sys.stderr) + print(f"serial: {result.get('serial', 'unknown')}", file=sys.stderr) + print(f"common name: {result.get('common_name', args.name)}", file=sys.stderr) + + # Update config file if requested + if args.configure: + config_file = Path.home() / ".config" / "fpaste" / "config" + config_file.parent.mkdir(parents=True, exist_ok=True) + + # Read existing config + existing = {} + if config_file.exists(): + for line in config_file.read_text().splitlines(): + line = line.strip() + if line and not line.startswith("#") and "=" in line: + k, v = line.split("=", 1) + existing[k.strip().lower()] = v.strip() + + # Update values + existing["client_cert"] = str(cert_file) + existing["client_key"] = str(key_file) + existing["cert_sha1"] = fingerprint + + # Write config + lines = [f"{k} = {v}" for k, v in sorted(existing.items())] + config_file.write_text("\n".join(lines) + "\n") + print(f"config: {config_file} (updated)", file=sys.stderr) + + # Output fingerprint to stdout for easy capture + print(fingerprint) + + +def cmd_pki_download(args, config): + """Download the CA certificate from the server.""" + url = config["server"].rstrip("/") + "/pki/ca.crt" + status, body, _ = request(url, ssl_context=config.get("ssl_context")) + + if status == 404: + die("CA certificate not available (PKI disabled or CA not generated)") + elif status != 200: + die(f"failed to download CA certificate ({status})") + + # Determine output + if args.output: + out_path = Path(args.output) + out_path.write_bytes(body) + print(f"saved: {out_path}", file=sys.stderr) + + # Calculate and show fingerprint if cryptography available + if HAS_CRYPTO: + cert = x509.load_pem_x509_certificate(body) + # SHA1 is standard for X.509 fingerprints + fp = hashlib.sha1(cert.public_bytes(serialization.Encoding.DER)).hexdigest() # noqa: S324 + print(f"fingerprint: {fp}", file=sys.stderr) + + # Update config if requested + if args.configure: + config_file = Path.home() / ".config" / "fpaste" / "config" + config_file.parent.mkdir(parents=True, exist_ok=True) + + existing = {} + if config_file.exists(): + for line in config_file.read_text().splitlines(): + line = line.strip() + if line and not line.startswith("#") and "=" in line: + k, v = line.split("=", 1) + existing[k.strip().lower()] = v.strip() + + existing["ca_cert"] = str(out_path) + + lines = [f"{k} = {v}" for k, v in sorted(existing.items())] + config_file.write_text("\n".join(lines) + "\n") + print(f"config: {config_file} (updated)", file=sys.stderr) + else: + # Output to stdout + sys.stdout.buffer.write(body) + + +def cmd_cert(args, config): + """Generate a self-signed client certificate for mTLS authentication.""" + if not HAS_CRYPTO: + die("certificate generation requires 'cryptography' package: pip install cryptography") + + # Determine output directory + out_dir = Path(args.output) if args.output else Path.home() / ".config" / "fpaste" + out_dir.mkdir(parents=True, exist_ok=True) + + # File paths + key_file = out_dir / "client.key" + cert_file = out_dir / "client.crt" + + # Check for existing files + if not args.force: + if key_file.exists(): + die(f"key file exists: {key_file} (use --force)") + if cert_file.exists(): + die(f"cert file exists: {cert_file} (use --force)") + + # Generate private key + if args.algorithm == "rsa": + key_size = args.bits or 4096 + print(f"generating {key_size}-bit RSA key...", file=sys.stderr) + private_key = rsa.generate_private_key( + public_exponent=65537, + key_size=key_size, + ) + elif args.algorithm == "ec": + curve_name = args.curve or "secp384r1" + curves = { + "secp256r1": ec.SECP256R1(), + "secp384r1": ec.SECP384R1(), + "secp521r1": ec.SECP521R1(), + } + if curve_name not in curves: + die(f"unsupported curve: {curve_name} (use: secp256r1, secp384r1, secp521r1)") + print(f"generating EC key ({curve_name})...", file=sys.stderr) + private_key = ec.generate_private_key(curves[curve_name]) + else: + die(f"unsupported algorithm: {args.algorithm}") + + # Certificate subject + cn = args.name or os.environ.get("USER", "fpaste-client") + subject = issuer = x509.Name( + [ + x509.NameAttribute(NameOID.COMMON_NAME, cn), + x509.NameAttribute(NameOID.ORGANIZATION_NAME, "FlaskPaste Client"), + ] + ) + + # Validity period + days = args.days or 365 + now = datetime.now(UTC) + + # Build certificate + cert_builder = ( + x509.CertificateBuilder() + .subject_name(subject) + .issuer_name(issuer) + .public_key(private_key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(now) + .not_valid_after(now + timedelta(days=days)) + .add_extension( + x509.BasicConstraints(ca=False, path_length=None), + critical=True, + ) + .add_extension( + x509.KeyUsage( + digital_signature=True, + key_encipherment=True, + content_commitment=False, + data_encipherment=False, + key_agreement=False, + key_cert_sign=False, + crl_sign=False, + encipher_only=False, + decipher_only=False, + ), + critical=True, + ) + .add_extension( + x509.ExtendedKeyUsage([x509.oid.ExtendedKeyUsageOID.CLIENT_AUTH]), + critical=False, + ) + ) + + # Sign certificate + print("signing certificate...", file=sys.stderr) + certificate = cert_builder.sign(private_key, hashes.SHA256()) + + # Calculate SHA1 fingerprint (standard for X.509) + cert_der = certificate.public_bytes(serialization.Encoding.DER) + fingerprint = hashlib.sha1(cert_der).hexdigest() # noqa: S324 + + # Serialize private key + if args.password_key: + key_encryption = serialization.BestAvailableEncryption(args.password_key.encode("utf-8")) + else: + key_encryption = serialization.NoEncryption() + + key_pem = private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=key_encryption, + ) + + # Serialize certificate + cert_pem = certificate.public_bytes(serialization.Encoding.PEM) + + # Write files + key_file.write_bytes(key_pem) + key_file.chmod(0o600) # Restrict permissions + cert_file.write_bytes(cert_pem) + + print(f"key: {key_file}", file=sys.stderr) + print(f"certificate: {cert_file}", file=sys.stderr) + print(f"fingerprint: {fingerprint}", file=sys.stderr) + print(f"valid for: {days} days", file=sys.stderr) + print(f"common name: {cn}", file=sys.stderr) + + # Update config file if requested + if args.configure: + config_file = Path.home() / ".config" / "fpaste" / "config" + config_file.parent.mkdir(parents=True, exist_ok=True) + + # Read existing config + existing = {} + if config_file.exists(): + for line in config_file.read_text().splitlines(): + line = line.strip() + if line and not line.startswith("#") and "=" in line: + k, v = line.split("=", 1) + existing[k.strip().lower()] = v.strip() + + # Update values + existing["client_cert"] = str(cert_file) + existing["client_key"] = str(key_file) + existing["cert_sha1"] = fingerprint + + # Write config + lines = [f"{k} = {v}" for k, v in sorted(existing.items())] + config_file.write_text("\n".join(lines) + "\n") + print(f"config: {config_file} (updated)", file=sys.stderr) + + # Output fingerprint to stdout for easy capture + print(fingerprint) + + def main(): parser = argparse.ArgumentParser( prog="fpaste", description="FlaskPaste command-line client", ) parser.add_argument( - "-s", "--server", - help="server URL (default: $FLASKPASTE_SERVER or http://localhost:5000)", + "-s", + "--server", + help="server URL (env: FLASKPASTE_SERVER)", ) subparsers = parser.add_subparsers(dest="command", metavar="command") # create p_create = subparsers.add_parser("create", aliases=["c", "new"], help="create paste") p_create.add_argument("file", nargs="?", help="file to upload (- for stdin)") - p_create.add_argument("-e", "--encrypt", action="store_true", help="encrypt content (E2E)") + p_create.add_argument("-e", "--encrypt", action="store_true", help="encrypt content") + p_create.add_argument("-b", "--burn", action="store_true", help="burn after read") + p_create.add_argument("-x", "--expiry", type=int, metavar="SEC", help="expiry in seconds") + p_create.add_argument("-p", "--password", metavar="PASS", help="password protect") p_create.add_argument("-r", "--raw", action="store_true", help="output raw URL") p_create.add_argument("-q", "--quiet", action="store_true", help="output ID only") @@ -330,6 +721,7 @@ def main(): p_get = subparsers.add_parser("get", aliases=["g"], help="retrieve paste") p_get.add_argument("id", help="paste ID or URL") p_get.add_argument("-o", "--output", help="save to file") + p_get.add_argument("-p", "--password", metavar="PASS", help="password for protected paste") p_get.add_argument("-m", "--meta", action="store_true", help="show metadata only") # delete @@ -339,18 +731,73 @@ def main(): # info subparsers.add_parser("info", aliases=["i"], help="show server info") + # cert + p_cert = subparsers.add_parser("cert", help="generate client certificate") + p_cert.add_argument("-o", "--output", metavar="DIR", help="output directory") + p_cert.add_argument( + "-a", "--algorithm", choices=["rsa", "ec"], default="ec", help="key algorithm (default: ec)" + ) + p_cert.add_argument("-b", "--bits", type=int, metavar="N", help="RSA key size (default: 4096)") + p_cert.add_argument( + "-c", "--curve", metavar="CURVE", help="EC curve: secp256r1, secp384r1, secp521r1" + ) + p_cert.add_argument("-d", "--days", type=int, metavar="N", help="validity period in days") + p_cert.add_argument("-n", "--name", metavar="CN", help="common name (default: $USER)") + p_cert.add_argument("--password-key", metavar="PASS", help="encrypt private key with password") + p_cert.add_argument( + "--configure", action="store_true", help="update config file with generated cert paths" + ) + p_cert.add_argument("-f", "--force", action="store_true", help="overwrite existing files") + + # pki (with subcommands) + p_pki = subparsers.add_parser("pki", help="PKI operations (server-issued certificates)") + pki_sub = p_pki.add_subparsers(dest="pki_command", metavar="subcommand") + + # pki status + pki_sub.add_parser("status", help="show PKI status and CA info") + + # pki issue + p_pki_issue = pki_sub.add_parser("issue", help="request certificate from server CA") + p_pki_issue.add_argument( + "-n", "--name", required=True, metavar="CN", help="common name for certificate (required)" + ) + p_pki_issue.add_argument( + "-o", "--output", metavar="DIR", help="output directory (default: ~/.config/fpaste)" + ) + p_pki_issue.add_argument( + "--configure", action="store_true", help="update config file with issued cert paths" + ) + p_pki_issue.add_argument("-f", "--force", action="store_true", help="overwrite existing files") + + # pki download + p_pki_download = pki_sub.add_parser("download", aliases=["dl"], help="download CA certificate") + p_pki_download.add_argument( + "-o", "--output", metavar="FILE", help="save to file (default: stdout)" + ) + p_pki_download.add_argument( + "--configure", + action="store_true", + help="update config file with CA cert path (requires -o)", + ) + args = parser.parse_args() config = get_config() if args.server: config["server"] = args.server + # Create SSL context for mTLS if configured + config["ssl_context"] = create_ssl_context(config) + if not args.command: # Default: create from stdin if data is piped if not sys.stdin.isatty(): args.command = "create" args.file = None args.encrypt = False + args.burn = False + args.expiry = None + args.password = None args.raw = False args.quiet = False else: @@ -365,6 +812,18 @@ def main(): cmd_delete(args, config) elif args.command in ("info", "i"): cmd_info(args, config) + elif args.command == "cert": + cmd_cert(args, config) + elif args.command == "pki": + if args.pki_command == "status": + cmd_pki_status(args, config) + elif args.pki_command == "issue": + cmd_pki_issue(args, config) + elif args.pki_command in ("download", "dl"): + cmd_pki_download(args, config) + else: + # Show pki help if no subcommand + parser.parse_args(["pki", "--help"]) if __name__ == "__main__": diff --git a/tests/test_paste_options.py b/tests/test_paste_options.py new file mode 100644 index 0000000..975150f --- /dev/null +++ b/tests/test_paste_options.py @@ -0,0 +1,500 @@ +"""Tests for burn-after-read and custom expiry features.""" + +import time + +import pytest + +from app import create_app +from app.database import cleanup_expired_pastes + + +class TestBurnAfterRead: + """Test burn-after-read paste functionality.""" + + @pytest.fixture + def app(self): + """Create app for burn-after-read tests.""" + return create_app("testing") + + @pytest.fixture + def client(self, app): + """Create test client.""" + return app.test_client() + + def test_create_burn_paste(self, client): + """Creating a burn-after-read paste should succeed.""" + response = client.post( + "/", + data=b"secret message", + headers={"X-Burn-After-Read": "true"}, + ) + assert response.status_code == 201 + data = response.get_json() + assert data["burn_after_read"] is True + + def test_burn_paste_deleted_after_raw_get(self, client): + """Burn paste should be deleted after first GET /raw.""" + # Create burn paste + response = client.post( + "/", + data=b"one-time secret", + headers={"X-Burn-After-Read": "true"}, + ) + paste_id = response.get_json()["id"] + + # First GET should succeed + response = client.get(f"/{paste_id}/raw") + assert response.status_code == 200 + assert response.data == b"one-time secret" + assert response.headers.get("X-Burn-After-Read") == "true" + + # Second GET should fail (paste deleted) + response = client.get(f"/{paste_id}/raw") + assert response.status_code == 404 + + def test_burn_paste_metadata_does_not_trigger_burn(self, client): + """GET metadata should not delete burn paste.""" + # Create burn paste + response = client.post( + "/", + data=b"secret", + headers={"X-Burn-After-Read": "true"}, + ) + paste_id = response.get_json()["id"] + + # Metadata GET should succeed and show burn flag + response = client.get(f"/{paste_id}") + assert response.status_code == 200 + data = response.get_json() + assert data["burn_after_read"] is True + + # Paste should still exist + response = client.get(f"/{paste_id}") + assert response.status_code == 200 + + # Raw GET should delete it + response = client.get(f"/{paste_id}/raw") + assert response.status_code == 200 + + # Now it's gone + response = client.get(f"/{paste_id}") + assert response.status_code == 404 + + def test_head_does_not_trigger_burn(self, client): + """HEAD request should not delete burn paste.""" + # Create burn paste + response = client.post( + "/", + data=b"secret", + headers={"X-Burn-After-Read": "true"}, + ) + paste_id = response.get_json()["id"] + + # HEAD should succeed + response = client.head(f"/{paste_id}/raw") + assert response.status_code == 200 + + # Paste should still exist + response = client.get(f"/{paste_id}/raw") + assert response.status_code == 200 + + def test_burn_header_variations(self, client): + """Different true values for X-Burn-After-Read should work.""" + for value in ["true", "TRUE", "1", "yes", "YES"]: + response = client.post( + "/", + data=b"content", + headers={"X-Burn-After-Read": value}, + ) + data = response.get_json() + assert data.get("burn_after_read") is True, f"Failed for value: {value}" + + def test_burn_header_false_values(self, client): + """False values should not enable burn-after-read.""" + for value in ["false", "0", "no", ""]: + response = client.post( + "/", + data=b"content", + headers={"X-Burn-After-Read": value}, + ) + data = response.get_json() + assert "burn_after_read" not in data, f"Should not be burn for: {value}" + + +class TestCustomExpiry: + """Test custom expiry functionality.""" + + @pytest.fixture + def app(self): + """Create app with short max expiry for testing.""" + app = create_app("testing") + app.config["MAX_EXPIRY_SECONDS"] = 3600 # 1 hour max + app.config["PASTE_EXPIRY_SECONDS"] = 60 # 1 minute default + return app + + @pytest.fixture + def client(self, app): + """Create test client.""" + return app.test_client() + + def test_create_paste_with_custom_expiry(self, client): + """Creating a paste with X-Expiry should set expires_at.""" + response = client.post( + "/", + data=b"temporary content", + headers={"X-Expiry": "300"}, # 5 minutes + ) + assert response.status_code == 201 + data = response.get_json() + assert "expires_at" in data + # Should be approximately now + 300 + now = int(time.time()) + assert abs(data["expires_at"] - (now + 300)) < 5 + + def test_custom_expiry_capped_at_max(self, client): + """Custom expiry should be capped at MAX_EXPIRY_SECONDS.""" + response = client.post( + "/", + data=b"content", + headers={"X-Expiry": "999999"}, # Way more than max + ) + assert response.status_code == 201 + data = response.get_json() + assert "expires_at" in data + # Should be capped at 3600 seconds from now + now = int(time.time()) + assert abs(data["expires_at"] - (now + 3600)) < 5 + + def test_expiry_shown_in_metadata(self, client): + """Custom expiry should appear in paste metadata.""" + response = client.post( + "/", + data=b"content", + headers={"X-Expiry": "600"}, + ) + paste_id = response.get_json()["id"] + + response = client.get(f"/{paste_id}") + data = response.get_json() + assert "expires_at" in data + + def test_invalid_expiry_ignored(self, client): + """Invalid X-Expiry values should be ignored.""" + for value in ["invalid", "-100", "0", ""]: + response = client.post( + "/", + data=b"content", + headers={"X-Expiry": value}, + ) + assert response.status_code == 201 + data = response.get_json() + assert "expires_at" not in data, f"Should not have expiry for: {value}" + + def test_paste_without_custom_expiry(self, client): + """Paste without X-Expiry should not have expires_at.""" + response = client.post("/", data=b"content") + assert response.status_code == 201 + data = response.get_json() + assert "expires_at" not in data + + +class TestExpiryCleanup: + """Test cleanup of expired pastes.""" + + @pytest.fixture + def app(self): + """Create app with very short expiry for testing.""" + app = create_app("testing") + app.config["PASTE_EXPIRY_SECONDS"] = 1 # 1 second default + app.config["MAX_EXPIRY_SECONDS"] = 10 + return app + + @pytest.fixture + def client(self, app): + """Create test client.""" + return app.test_client() + + def test_cleanup_custom_expired_paste(self, app, client): + """Paste with expired custom expiry should be cleaned up.""" + # Create paste with 1 second expiry + response = client.post( + "/", + data=b"expiring soon", + headers={"X-Expiry": "1"}, + ) + paste_id = response.get_json()["id"] + + # Should exist immediately + response = client.get(f"/{paste_id}") + assert response.status_code == 200 + + # Wait for expiry + time.sleep(2) + + # Run cleanup + with app.app_context(): + deleted = cleanup_expired_pastes() + assert deleted >= 1 + + # Should be gone + response = client.get(f"/{paste_id}") + assert response.status_code == 404 + + def test_cleanup_respects_default_expiry(self, app, client): + """Paste without custom expiry should use default expiry.""" + # Create paste without custom expiry + response = client.post("/", data=b"default expiry") + paste_id = response.get_json()["id"] + + # Wait for default expiry (1 second in test config) + time.sleep(2) + + # Run cleanup + with app.app_context(): + deleted = cleanup_expired_pastes() + assert deleted >= 1 + + # Should be gone + response = client.get(f"/{paste_id}") + assert response.status_code == 404 + + def test_cleanup_keeps_unexpired_paste(self, app, client): + """Paste with future custom expiry should not be cleaned up.""" + # Create paste with long expiry + response = client.post( + "/", + data=b"not expiring soon", + headers={"X-Expiry": "10"}, # 10 seconds + ) + paste_id = response.get_json()["id"] + + # Run cleanup immediately + with app.app_context(): + cleanup_expired_pastes() + + # Should still exist + response = client.get(f"/{paste_id}") + assert response.status_code == 200 + + +class TestCombinedOptions: + """Test combinations of burn-after-read and custom expiry.""" + + @pytest.fixture + def app(self): + """Create app for combined tests.""" + return create_app("testing") + + @pytest.fixture + def client(self, app): + """Create test client.""" + return app.test_client() + + def test_burn_and_expiry_together(self, client): + """Paste can have both burn-after-read and custom expiry.""" + response = client.post( + "/", + data=b"secret with expiry", + headers={ + "X-Burn-After-Read": "true", + "X-Expiry": "3600", + }, + ) + assert response.status_code == 201 + data = response.get_json() + assert data["burn_after_read"] is True + assert "expires_at" in data + + +class TestPasswordProtection: + """Test password-protected paste functionality.""" + + @pytest.fixture + def app(self): + """Create app for password tests.""" + return create_app("testing") + + @pytest.fixture + def client(self, app): + """Create test client.""" + return app.test_client() + + def test_create_password_protected_paste(self, client): + """Creating a password-protected paste should succeed.""" + response = client.post( + "/", + data=b"secret content", + headers={"X-Paste-Password": "mypassword123"}, + ) + assert response.status_code == 201 + data = response.get_json() + assert data["password_protected"] is True + + def test_get_protected_paste_without_password(self, client): + """Accessing protected paste without password should return 401.""" + # Create protected paste + response = client.post( + "/", + data=b"protected content", + headers={"X-Paste-Password": "secret"}, + ) + paste_id = response.get_json()["id"] + + # Try to access without password + response = client.get(f"/{paste_id}") + assert response.status_code == 401 + data = response.get_json() + assert data["password_protected"] is True + assert "Password required" in data["error"] + + def test_get_protected_paste_with_wrong_password(self, client): + """Accessing protected paste with wrong password should return 403.""" + # Create protected paste + response = client.post( + "/", + data=b"protected content", + headers={"X-Paste-Password": "correctpassword"}, + ) + paste_id = response.get_json()["id"] + + # Try with wrong password + response = client.get( + f"/{paste_id}", + headers={"X-Paste-Password": "wrongpassword"}, + ) + assert response.status_code == 403 + data = response.get_json() + assert "Invalid password" in data["error"] + + def test_get_protected_paste_with_correct_password(self, client): + """Accessing protected paste with correct password should succeed.""" + password = "supersecret123" + # Create protected paste + response = client.post( + "/", + data=b"protected content", + headers={"X-Paste-Password": password}, + ) + paste_id = response.get_json()["id"] + + # Access with correct password + response = client.get( + f"/{paste_id}", + headers={"X-Paste-Password": password}, + ) + assert response.status_code == 200 + data = response.get_json() + assert data["password_protected"] is True + + def test_get_raw_protected_paste_without_password(self, client): + """Getting raw content without password should return 401.""" + response = client.post( + "/", + data=b"secret raw content", + headers={"X-Paste-Password": "secret"}, + ) + paste_id = response.get_json()["id"] + + response = client.get(f"/{paste_id}/raw") + assert response.status_code == 401 + + def test_get_raw_protected_paste_with_correct_password(self, client): + """Getting raw content with correct password should succeed.""" + password = "mypassword" + response = client.post( + "/", + data=b"secret raw content", + headers={"X-Paste-Password": password}, + ) + paste_id = response.get_json()["id"] + + response = client.get( + f"/{paste_id}/raw", + headers={"X-Paste-Password": password}, + ) + assert response.status_code == 200 + assert response.data == b"secret raw content" + + def test_password_too_long_rejected(self, client): + """Password longer than 1024 chars should be rejected.""" + long_password = "x" * 1025 + response = client.post( + "/", + data=b"content", + headers={"X-Paste-Password": long_password}, + ) + assert response.status_code == 400 + data = response.get_json() + assert "too long" in data["error"] + + def test_unprotected_paste_accessible(self, client): + """Unprotected paste should be accessible without password.""" + response = client.post("/", data=b"public content") + paste_id = response.get_json()["id"] + + response = client.get(f"/{paste_id}") + assert response.status_code == 200 + assert "password_protected" not in response.get_json() + + def test_password_with_special_chars(self, client): + """Password with special characters should work.""" + password = "p@ssw0rd!#$%^&*()_+-=[]{}|;':\",./<>?" + response = client.post( + "/", + data=b"special content", + headers={"X-Paste-Password": password}, + ) + paste_id = response.get_json()["id"] + + response = client.get( + f"/{paste_id}", + headers={"X-Paste-Password": password}, + ) + assert response.status_code == 200 + + def test_password_with_unicode(self, client): + """Password with unicode characters should work.""" + password = "пароль密码🔐" + response = client.post( + "/", + data=b"unicode content", + headers={"X-Paste-Password": password}, + ) + paste_id = response.get_json()["id"] + + response = client.get( + f"/{paste_id}", + headers={"X-Paste-Password": password}, + ) + assert response.status_code == 200 + + def test_password_combined_with_burn(self, client): + """Password protection can be combined with burn-after-read.""" + password = "secret" + response = client.post( + "/", + data=b"protected burn content", + headers={ + "X-Paste-Password": password, + "X-Burn-After-Read": "true", + }, + ) + assert response.status_code == 201 + data = response.get_json() + assert data["password_protected"] is True + assert data["burn_after_read"] is True + paste_id = data["id"] + + # First access with password should succeed + response = client.get( + f"/{paste_id}/raw", + headers={"X-Paste-Password": password}, + ) + assert response.status_code == 200 + + # Second access should fail (burned) + response = client.get( + f"/{paste_id}/raw", + headers={"X-Paste-Password": password}, + ) + assert response.status_code == 404 diff --git a/tests/test_pki.py b/tests/test_pki.py new file mode 100644 index 0000000..bf896bc --- /dev/null +++ b/tests/test_pki.py @@ -0,0 +1,371 @@ +"""Tests for PKI (Certificate Authority) functionality.""" + +from datetime import UTC + +import pytest + +from app.pki import reset_pki + + +@pytest.fixture(autouse=True) +def reset_pki_state(app): + """Reset PKI state and clear PKI database tables before each test.""" + reset_pki() + + # Clear PKI tables in database + with app.app_context(): + from app.database import get_db + + db = get_db() + db.execute("DELETE FROM issued_certificates") + db.execute("DELETE FROM certificate_authority") + db.commit() + + yield + reset_pki() + + +class TestPKIStatus: + """Test GET /pki endpoint.""" + + def test_pki_status_when_enabled(self, client): + """PKI status shows enabled with no CA initially.""" + response = client.get("/pki") + assert response.status_code == 200 + data = response.get_json() + assert data["enabled"] is True + assert data["ca_exists"] is False + assert "hint" in data + + def test_pki_status_after_ca_generation(self, client): + """PKI status shows CA info after generation.""" + # Generate CA first + client.post("/pki/ca", json={"common_name": "Test CA"}) + + response = client.get("/pki") + assert response.status_code == 200 + data = response.get_json() + assert data["enabled"] is True + assert data["ca_exists"] is True + assert data["common_name"] == "Test CA" + assert "fingerprint_sha1" in data + assert len(data["fingerprint_sha1"]) == 40 + + +class TestCAGeneration: + """Test POST /pki/ca endpoint.""" + + def test_generate_ca_success(self, client): + """CA can be generated with default name.""" + response = client.post("/pki/ca") + assert response.status_code == 201 + data = response.get_json() + assert data["message"] == "CA generated" + assert data["common_name"] == "FlaskPaste CA" + assert "fingerprint_sha1" in data + assert "created_at" in data + assert "expires_at" in data + assert data["download"] == "/pki/ca.crt" + + def test_generate_ca_custom_name(self, client): + """CA can be generated with custom name.""" + response = client.post("/pki/ca", json={"common_name": "My Custom CA"}) + assert response.status_code == 201 + data = response.get_json() + assert data["common_name"] == "My Custom CA" + + def test_generate_ca_twice_fails(self, client): + """CA cannot be generated twice.""" + # First generation succeeds + response = client.post("/pki/ca") + assert response.status_code == 201 + + # Second generation fails + response = client.post("/pki/ca") + assert response.status_code == 409 + data = response.get_json() + assert "already exists" in data["error"] + + +class TestCADownload: + """Test GET /pki/ca.crt endpoint.""" + + def test_download_ca_not_initialized(self, client): + """Download fails when no CA exists.""" + response = client.get("/pki/ca.crt") + assert response.status_code == 404 + + def test_download_ca_success(self, client): + """CA certificate can be downloaded.""" + # Generate CA first + client.post("/pki/ca", json={"common_name": "Test CA"}) + + response = client.get("/pki/ca.crt") + assert response.status_code == 200 + assert response.content_type == "application/x-pem-file" + assert b"-----BEGIN CERTIFICATE-----" in response.data + assert b"-----END CERTIFICATE-----" in response.data + + +class TestCertificateIssuance: + """Test POST /pki/issue endpoint.""" + + def test_issue_without_ca_fails(self, client): + """Issuance fails when no CA exists.""" + response = client.post("/pki/issue", json={"common_name": "alice"}) + assert response.status_code == 404 + + def test_issue_without_name_fails(self, client): + """Issuance fails without common_name.""" + client.post("/pki/ca") + + response = client.post("/pki/issue", json={}) + assert response.status_code == 400 + assert "common_name required" in response.get_json()["error"] + + def test_issue_certificate_success(self, client): + """Certificate issuance succeeds.""" + client.post("/pki/ca") + + response = client.post("/pki/issue", json={"common_name": "alice"}) + assert response.status_code == 201 + data = response.get_json() + assert data["message"] == "Certificate issued" + assert data["common_name"] == "alice" + assert "serial" in data + assert "fingerprint_sha1" in data + assert len(data["fingerprint_sha1"]) == 40 + assert "certificate_pem" in data + assert "private_key_pem" in data + assert "-----BEGIN CERTIFICATE-----" in data["certificate_pem"] + assert "-----BEGIN PRIVATE KEY-----" in data["private_key_pem"] + + def test_issue_multiple_certificates(self, client): + """Multiple certificates can be issued.""" + client.post("/pki/ca") + + response1 = client.post("/pki/issue", json={"common_name": "alice"}) + response2 = client.post("/pki/issue", json={"common_name": "bob"}) + + assert response1.status_code == 201 + assert response2.status_code == 201 + + data1 = response1.get_json() + data2 = response2.get_json() + + # Different serials and fingerprints + assert data1["serial"] != data2["serial"] + assert data1["fingerprint_sha1"] != data2["fingerprint_sha1"] + + +class TestCertificateListing: + """Test GET /pki/certs endpoint.""" + + def test_list_anonymous_empty(self, client): + """Anonymous users see empty list.""" + client.post("/pki/ca") + + response = client.get("/pki/certs") + assert response.status_code == 200 + data = response.get_json() + assert data["certificates"] == [] + assert data["count"] == 0 + + def test_list_authenticated_sees_own(self, client): + """Authenticated users see certificates they issued.""" + client.post("/pki/ca") + + # Issue certificate as authenticated user + issuer_fingerprint = "a" * 40 + client.post( + "/pki/issue", + json={"common_name": "alice"}, + headers={"X-SSL-Client-SHA1": issuer_fingerprint}, + ) + + # List as same user + response = client.get("/pki/certs", headers={"X-SSL-Client-SHA1": issuer_fingerprint}) + assert response.status_code == 200 + data = response.get_json() + assert data["count"] == 1 + assert data["certificates"][0]["common_name"] == "alice" + + +class TestCertificateRevocation: + """Test POST /pki/revoke/ endpoint.""" + + def test_revoke_unauthenticated_fails(self, client): + """Revocation requires authentication.""" + client.post("/pki/ca") + issue_resp = client.post("/pki/issue", json={"common_name": "alice"}) + serial = issue_resp.get_json()["serial"] + + response = client.post(f"/pki/revoke/{serial}") + assert response.status_code == 401 + + def test_revoke_unauthorized_fails(self, client): + """Revocation requires ownership.""" + client.post("/pki/ca") + + # Issue as one user + issue_resp = client.post( + "/pki/issue", json={"common_name": "alice"}, headers={"X-SSL-Client-SHA1": "a" * 40} + ) + serial = issue_resp.get_json()["serial"] + + # Try to revoke as different user + response = client.post(f"/pki/revoke/{serial}", headers={"X-SSL-Client-SHA1": "b" * 40}) + assert response.status_code == 403 + + def test_revoke_as_issuer_succeeds(self, client): + """Issuer can revoke certificate.""" + client.post("/pki/ca") + + issuer = "a" * 40 + issue_resp = client.post( + "/pki/issue", json={"common_name": "alice"}, headers={"X-SSL-Client-SHA1": issuer} + ) + serial = issue_resp.get_json()["serial"] + + response = client.post(f"/pki/revoke/{serial}", headers={"X-SSL-Client-SHA1": issuer}) + assert response.status_code == 200 + assert response.get_json()["message"] == "Certificate revoked" + + def test_revoke_nonexistent_fails(self, client): + """Revoking nonexistent certificate fails.""" + client.post("/pki/ca") + + response = client.post("/pki/revoke/0" * 32, headers={"X-SSL-Client-SHA1": "a" * 40}) + assert response.status_code == 404 + + def test_revoke_twice_fails(self, client): + """Certificate cannot be revoked twice.""" + client.post("/pki/ca") + + issuer = "a" * 40 + issue_resp = client.post( + "/pki/issue", json={"common_name": "alice"}, headers={"X-SSL-Client-SHA1": issuer} + ) + serial = issue_resp.get_json()["serial"] + + # First revocation succeeds + response = client.post(f"/pki/revoke/{serial}", headers={"X-SSL-Client-SHA1": issuer}) + assert response.status_code == 200 + + # Second revocation fails + response = client.post(f"/pki/revoke/{serial}", headers={"X-SSL-Client-SHA1": issuer}) + assert response.status_code == 409 + + +class TestRevocationIntegration: + """Test revocation affects authentication.""" + + def test_revoked_cert_treated_as_anonymous(self, client): + """Revoked certificate is treated as anonymous.""" + client.post("/pki/ca") + + # Issue certificate + issuer = "a" * 40 + issue_resp = client.post( + "/pki/issue", json={"common_name": "alice"}, headers={"X-SSL-Client-SHA1": issuer} + ) + cert_fingerprint = issue_resp.get_json()["fingerprint_sha1"] + serial = issue_resp.get_json()["serial"] + + # Create paste as authenticated user + create_resp = client.post( + "/", data=b"test content", headers={"X-SSL-Client-SHA1": cert_fingerprint} + ) + assert create_resp.status_code == 201 + paste_id = create_resp.get_json()["id"] + assert "owner" in create_resp.get_json() + + # Revoke the certificate + client.post(f"/pki/revoke/{serial}", headers={"X-SSL-Client-SHA1": issuer}) + + # Try to delete paste with revoked cert - should fail + delete_resp = client.delete(f"/{paste_id}", headers={"X-SSL-Client-SHA1": cert_fingerprint}) + assert delete_resp.status_code == 401 + + +class TestPKICryptoFunctions: + """Test standalone PKI cryptographic functions.""" + + def test_derive_key_consistency(self): + """Key derivation produces consistent results.""" + from app.pki import derive_key + + password = "test-password" + salt = b"x" * 32 + + key1 = derive_key(password, salt) + key2 = derive_key(password, salt) + + assert key1 == key2 + assert len(key1) == 32 + + def test_encrypt_decrypt_roundtrip(self): + """Private key encryption/decryption roundtrip.""" + from cryptography.hazmat.primitives.asymmetric import ec + + from app.pki import decrypt_private_key, encrypt_private_key + + # Generate a test key + private_key = ec.generate_private_key(ec.SECP384R1()) + password = "test-password" + + # Encrypt + encrypted, salt = encrypt_private_key(private_key, password) + + # Decrypt + decrypted = decrypt_private_key(encrypted, salt, password) + + # Verify same key + assert private_key.private_numbers() == decrypted.private_numbers() + + def test_wrong_password_fails(self): + """Decryption with wrong password fails.""" + from cryptography.hazmat.primitives.asymmetric import ec + + from app.pki import ( + InvalidPasswordError, + decrypt_private_key, + encrypt_private_key, + ) + + private_key = ec.generate_private_key(ec.SECP384R1()) + encrypted, salt = encrypt_private_key(private_key, "correct") + + with pytest.raises(InvalidPasswordError): + decrypt_private_key(encrypted, salt, "wrong") + + def test_fingerprint_calculation(self): + """Certificate fingerprint is calculated correctly.""" + from datetime import datetime, timedelta + + from cryptography import x509 + + # Minimal self-signed cert for testing + from cryptography.hazmat.primitives import hashes + from cryptography.hazmat.primitives.asymmetric import ec + from cryptography.x509.oid import NameOID + + from app.pki import calculate_fingerprint + + key = ec.generate_private_key(ec.SECP256R1()) + subject = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, "test")]) + cert = ( + x509.CertificateBuilder() + .subject_name(subject) + .issuer_name(subject) + .public_key(key.public_key()) + .serial_number(1) + .not_valid_before(datetime.now(UTC)) + .not_valid_after(datetime.now(UTC) + timedelta(days=1)) + .sign(key, hashes.SHA256()) + ) + + fingerprint = calculate_fingerprint(cert) + + assert len(fingerprint) == 40 + assert all(c in "0123456789abcdef" for c in fingerprint)