From bfc238b5cfe87edbd86c4e476380c4cd6655f51e Mon Sep 17 00:00:00 2001 From: Username Date: Sat, 20 Dec 2025 20:13:00 +0100 Subject: [PATCH] add CLI enhancements and scheduled cleanup CLI commands: - list: show user's pastes with pagination - search: filter by type (glob), after/before timestamps - update: modify content, password, or extend expiry - export: save pastes to directory with optional decryption API changes: - PUT /: update paste content and metadata - GET /pastes: add type, after, before query params Scheduled tasks: - Thread-safe cleanup with per-task intervals - Activate cleanup_expired_hashes (15min) - Activate cleanup_rate_limits (5min) Tests: 205 passing --- app/api/__init__.py | 61 +++- app/api/routes.py | 344 ++++++++++++++++++++++- app/config.py | 17 ++ fpaste | 481 +++++++++++++++++++++++++++++++- tests/conftest.py | 27 +- tests/test_paste_listing.py | 335 ++++++++++++++++++++++ tests/test_paste_update.py | 242 ++++++++++++++++ tests/test_rate_limiting.py | 168 +++++++++++ tests/test_scheduled_cleanup.py | 169 +++++++++++ 9 files changed, 1826 insertions(+), 18 deletions(-) create mode 100644 tests/test_paste_listing.py create mode 100644 tests/test_paste_update.py create mode 100644 tests/test_rate_limiting.py create mode 100644 tests/test_scheduled_cleanup.py diff --git a/app/api/__init__.py b/app/api/__init__.py index 8038871..5f036f1 100644 --- a/app/api/__init__.py +++ b/app/api/__init__.py @@ -1,32 +1,65 @@ """API blueprint registration.""" +import threading import time from flask import Blueprint, current_app bp = Blueprint("api", __name__) -# Throttle cleanup to run at most once per hour -_last_cleanup = 0 -_CLEANUP_INTERVAL = 3600 # 1 hour +# Thread-safe cleanup scheduling +_cleanup_lock = threading.Lock() +_cleanup_times = { + "pastes": 0, + "hashes": 0, + "rate_limits": 0, +} +_CLEANUP_INTERVALS = { + "pastes": 3600, # 1 hour + "hashes": 900, # 15 minutes + "rate_limits": 300, # 5 minutes +} + + +def reset_cleanup_times() -> None: + """Reset cleanup timestamps. For testing only.""" + with _cleanup_lock: + for key in _cleanup_times: + _cleanup_times[key] = 0 @bp.before_request -def cleanup_expired(): - """Periodically clean up expired pastes.""" - global _last_cleanup - +def run_scheduled_cleanup(): + """Periodically run cleanup tasks on schedule.""" now = time.time() - if now - _last_cleanup < _CLEANUP_INTERVAL: - return - _last_cleanup = now + with _cleanup_lock: + # Cleanup expired pastes + if now - _cleanup_times["pastes"] >= _CLEANUP_INTERVALS["pastes"]: + _cleanup_times["pastes"] = now + from app.database import cleanup_expired_pastes - from app.database import cleanup_expired_pastes + count = cleanup_expired_pastes() + if count > 0: + current_app.logger.info(f"Cleaned up {count} expired paste(s)") - count = cleanup_expired_pastes() - if count > 0: - current_app.logger.info(f"Cleaned up {count} expired paste(s)") + # Cleanup expired content hashes + if now - _cleanup_times["hashes"] >= _CLEANUP_INTERVALS["hashes"]: + _cleanup_times["hashes"] = now + from app.database import cleanup_expired_hashes + + count = cleanup_expired_hashes() + if count > 0: + current_app.logger.info(f"Cleaned up {count} expired hash(es)") + + # Cleanup rate limit entries + if now - _cleanup_times["rate_limits"] >= _CLEANUP_INTERVALS["rate_limits"]: + _cleanup_times["rate_limits"] = now + from app.api.routes import cleanup_rate_limits + + count = cleanup_rate_limits() + if count > 0: + current_app.logger.info(f"Cleaned up {count} rate limit entr(ies)") from app.api import routes # noqa: E402, F401 diff --git a/app/api/routes.py b/app/api/routes.py index e425cde..807e0cf 100644 --- a/app/api/routes.py +++ b/app/api/routes.py @@ -8,7 +8,9 @@ import json import math import re import secrets +import threading import time +from collections import defaultdict from typing import TYPE_CHECKING, Any from flask import Response, current_app, g, request @@ -50,6 +52,96 @@ GENERIC_MIME_TYPES = frozenset( # Runtime PoW secret cache _pow_secret_cache: bytes | None = None +# ───────────────────────────────────────────────────────────────────────────── +# Rate Limiting (in-memory sliding window) +# ───────────────────────────────────────────────────────────────────────────── + +_rate_limit_lock = threading.Lock() +_rate_limit_requests: dict[str, list[float]] = defaultdict(list) + + +def get_client_ip() -> str: + """Get client IP address, respecting X-Forwarded-For from trusted proxy.""" + if is_trusted_proxy(): + forwarded = request.headers.get("X-Forwarded-For", "") + if forwarded: + # Take the first (client) IP from the chain + return forwarded.split(",")[0].strip() + return request.remote_addr or "unknown" + + +def check_rate_limit(client_ip: str, authenticated: bool = False) -> tuple[bool, int, int]: + """Check if request is within rate limit. + + Args: + client_ip: Client IP address + authenticated: Whether client is authenticated (higher limits) + + Returns: + Tuple of (allowed, remaining, reset_seconds) + """ + if not current_app.config.get("RATE_LIMIT_ENABLED", True): + return True, -1, 0 + + window = current_app.config["RATE_LIMIT_WINDOW"] + max_requests = current_app.config["RATE_LIMIT_MAX"] + + if authenticated: + max_requests *= current_app.config.get("RATE_LIMIT_AUTH_MULTIPLIER", 5) + + now = time.time() + cutoff = now - window + + with _rate_limit_lock: + # Clean old requests and get current list + requests = _rate_limit_requests[client_ip] + requests[:] = [t for t in requests if t > cutoff] + + current_count = len(requests) + + if current_count >= max_requests: + # Calculate reset time (when oldest request expires) + reset_at = int(requests[0] + window - now) + 1 if requests else window + return False, 0, reset_at + + # Record this request + requests.append(now) + remaining = max_requests - len(requests) + + return True, remaining, window + + +def cleanup_rate_limits(window: int | None = None) -> int: + """Remove expired rate limit entries. Returns count of cleaned entries. + + Args: + window: Rate limit window in seconds. If None, uses app config. + """ + # This should be called periodically (e.g., via cleanup task) + if window is None: + window = current_app.config.get("RATE_LIMIT_WINDOW", 60) + cutoff = time.time() - window + + cleaned = 0 + with _rate_limit_lock: + to_remove = [] + for ip, requests in _rate_limit_requests.items(): + requests[:] = [t for t in requests if t > cutoff] + if not requests: + to_remove.append(ip) + + for ip in to_remove: + del _rate_limit_requests[ip] + cleaned += 1 + + return cleaned + + +def reset_rate_limits() -> None: + """Clear all rate limit state. For testing only.""" + with _rate_limit_lock: + _rate_limit_requests.clear() + # ───────────────────────────────────────────────────────────────────────────── # Response Helpers @@ -379,6 +471,7 @@ class IndexView(MethodView): f"GET {prefixed_url('/client')}": "Download CLI client", f"GET {prefixed_url('/challenge')}": "Get PoW challenge", f"POST {prefixed_url('/')}": "Create paste", + f"GET {prefixed_url('/pastes')}": "List your pastes (auth required)", f"GET {prefixed_url('/')}": "Retrieve paste metadata", f"GET {prefixed_url('//raw')}": "Retrieve raw paste content", f"DELETE {prefixed_url('/')}": "Delete paste", @@ -413,6 +506,22 @@ class IndexView(MethodView): owner = get_client_id() + # Rate limiting (check before expensive operations) + client_ip = get_client_ip() + allowed, _remaining, reset_seconds = check_rate_limit(client_ip, authenticated=bool(owner)) + + if not allowed: + current_app.logger.warning("Rate limit exceeded: ip=%s auth=%s", client_ip, bool(owner)) + response = error_response( + "Rate limit exceeded", + 429, + retry_after=reset_seconds, + ) + response.headers["Retry-After"] = str(reset_seconds) + response.headers["X-RateLimit-Remaining"] = "0" + response.headers["X-RateLimit-Reset"] = str(reset_seconds) + return response + # Proof-of-work verification difficulty = current_app.config["POW_DIFFICULTY"] if difficulty > 0: @@ -681,6 +790,125 @@ class PasteView(MethodView): """Return paste metadata headers only.""" return self.get(paste_id) + def put(self, paste_id: str) -> Response: + """Update paste content and/or metadata. + + Requires authentication and ownership. + + Content update: Send raw body with Content-Type header + Metadata update: Use headers with empty body + + Headers: + - X-Paste-Password: Set/change password + - X-Remove-Password: true to remove password + - X-Extend-Expiry: Seconds to add to current expiry + """ + # Validate paste ID format + if err := validate_paste_id(paste_id): + return err + if err := require_auth(): + return err + + db = get_db() + + # Fetch current paste + row = db.execute( + """SELECT id, owner, content, mime_type, expires_at, password_hash + FROM pastes WHERE id = ?""", + (paste_id,), + ).fetchone() + + if row is None: + return error_response("Paste not found", 404) + + if row["owner"] != g.client_id: + return error_response("Permission denied", 403) + + # Check for burn-after-read (cannot update) + burn_check = db.execute( + "SELECT burn_after_read FROM pastes WHERE id = ?", (paste_id,) + ).fetchone() + if burn_check and burn_check["burn_after_read"]: + return error_response("Cannot update burn-after-read paste", 400) + + # Parse update parameters + new_password = request.headers.get("X-Paste-Password", "").strip() or None + remove_password = request.headers.get("X-Remove-Password", "").lower() in ( + "true", + "1", + "yes", + ) + extend_expiry_str = request.headers.get("X-Extend-Expiry", "").strip() + + # Prepare update fields + update_fields = [] + update_params: list[Any] = [] + + # Content update (if body provided) + content = request.get_data() + if content: + mime_type = request.content_type or "application/octet-stream" + # Sanitize MIME type + if not MIME_PATTERN.match(mime_type.split(";")[0].strip()): + mime_type = "application/octet-stream" + + update_fields.append("content = ?") + update_params.append(content) + update_fields.append("mime_type = ?") + update_params.append(mime_type.split(";")[0].strip()) + + # Password update + if remove_password: + update_fields.append("password_hash = NULL") + elif new_password: + update_fields.append("password_hash = ?") + update_params.append(hash_password(new_password)) + + # Expiry extension + if extend_expiry_str: + try: + extend_seconds = int(extend_expiry_str) + if extend_seconds > 0: + current_expiry = row["expires_at"] + if current_expiry: + new_expiry = current_expiry + extend_seconds + else: + # If no expiry set, create one from now + new_expiry = int(time.time()) + extend_seconds + update_fields.append("expires_at = ?") + update_params.append(new_expiry) + except ValueError: + return error_response("Invalid X-Extend-Expiry value", 400) + + if not update_fields: + return error_response("No updates provided", 400) + + # Execute update (fields are hardcoded strings, safe from injection) + update_sql = f"UPDATE pastes SET {', '.join(update_fields)} WHERE id = ?" # noqa: S608 + update_params.append(paste_id) + db.execute(update_sql, update_params) + db.commit() + + # Fetch updated paste for response + updated = db.execute( + """SELECT id, mime_type, length(content) as size, expires_at, + CASE WHEN password_hash IS NOT NULL THEN 1 ELSE 0 END as password_protected + FROM pastes WHERE id = ?""", + (paste_id,), + ).fetchone() + + response_data: dict[str, Any] = { + "id": updated["id"], + "size": updated["size"], + "mime_type": updated["mime_type"], + } + if updated["expires_at"]: + response_data["expires_at"] = updated["expires_at"] + if updated["password_protected"]: + response_data["password_protected"] = True + + return json_response(response_data) + class PasteRawView(MethodView): """Raw paste content retrieval.""" @@ -759,6 +987,119 @@ class PasteDeleteView(MethodView): return json_response({"message": "Paste deleted"}) +class PastesListView(MethodView): + """List authenticated user's pastes (privacy-focused).""" + + def get(self) -> Response: + """List pastes owned by authenticated user. + + Privacy guarantees: + - Requires authentication (mTLS client certificate) + - Users can ONLY see their own pastes + - No admin bypass or cross-user visibility + - Content is never returned, only metadata + + Query parameters: + - limit: max results (default 50, max 200) + - offset: pagination offset (default 0) + - type: filter by MIME type (glob pattern, e.g., "image/*") + - after: filter by created_at >= timestamp + - before: filter by created_at <= timestamp + """ + import fnmatch + + # Strict authentication requirement + if err := require_auth(): + return err + + client_id = g.client_id + + # Parse pagination parameters + try: + limit = min(int(request.args.get("limit", 50)), 200) + offset = max(int(request.args.get("offset", 0)), 0) + except (ValueError, TypeError): + limit, offset = 50, 0 + + # Parse filter parameters + type_filter = request.args.get("type", "").strip() + try: + after_ts = int(request.args.get("after", 0)) + except (ValueError, TypeError): + after_ts = 0 + try: + before_ts = int(request.args.get("before", 0)) + except (ValueError, TypeError): + before_ts = 0 + + db = get_db() + + # Build query with filters + where_clauses = ["owner = ?"] + params: list[Any] = [client_id] + + if after_ts > 0: + where_clauses.append("created_at >= ?") + params.append(after_ts) + if before_ts > 0: + where_clauses.append("created_at <= ?") + params.append(before_ts) + + where_sql = " AND ".join(where_clauses) + + # Count total pastes matching filters (where_sql is safe, built from constants) + count_row = db.execute( + f"SELECT COUNT(*) as total FROM pastes WHERE {where_sql}", # noqa: S608 + params, + ).fetchone() + total = count_row["total"] if count_row else 0 + + # Fetch pastes with metadata only (where_sql is safe, built from constants) + rows = db.execute( + f"""SELECT id, mime_type, length(content) as size, created_at, + last_accessed, burn_after_read, expires_at, + CASE WHEN password_hash IS NOT NULL THEN 1 ELSE 0 END as password_protected + FROM pastes + WHERE {where_sql} + ORDER BY created_at DESC + LIMIT ? OFFSET ?""", # noqa: S608 + [*params, limit, offset], + ).fetchall() + + # Apply MIME type filter (glob pattern matching done in Python for flexibility) + if type_filter: + rows = [r for r in rows if fnmatch.fnmatch(r["mime_type"], type_filter)] + + pastes = [] + for row in rows: + paste: dict[str, Any] = { + "id": row["id"], + "mime_type": row["mime_type"], + "size": row["size"], + "created_at": row["created_at"], + "last_accessed": row["last_accessed"], + "url": f"/{row['id']}", + "raw": f"/{row['id']}/raw", + } + if row["burn_after_read"]: + paste["burn_after_read"] = True + if row["expires_at"]: + paste["expires_at"] = row["expires_at"] + if row["password_protected"]: + paste["password_protected"] = True + pastes.append(paste) + + return json_response( + { + "pastes": pastes, + "count": len(pastes), + "total": total, + "limit": limit, + "offset": offset, + } + ) + + # ───────────────────────────────────────────────────────────────────────────── # PKI Views (Certificate Authority) # ───────────────────────────────────────────────────────────────────────────── @@ -1060,7 +1401,8 @@ bp.add_url_rule("/challenge", view_func=ChallengeView.as_view("challenge")) bp.add_url_rule("/client", view_func=ClientView.as_view("client")) # Paste operations -bp.add_url_rule("/", view_func=PasteView.as_view("paste"), methods=["GET", "HEAD"]) +bp.add_url_rule("/pastes", view_func=PastesListView.as_view("pastes_list")) +bp.add_url_rule("/", view_func=PasteView.as_view("paste"), methods=["GET", "HEAD", "PUT"]) bp.add_url_rule( "//raw", view_func=PasteRawView.as_view("paste_raw"), methods=["GET", "HEAD"] ) diff --git a/app/config.py b/app/config.py index b3afad9..1e26170 100644 --- a/app/config.py +++ b/app/config.py @@ -67,6 +67,18 @@ class Config: # URL prefix for reverse proxy deployments (e.g., "/paste" for mymx.me/paste) URL_PREFIX = os.environ.get("FLASKPASTE_URL_PREFIX", "").rstrip("/") + # IP-based rate limiting + # Limits paste creation per IP address using sliding window + RATE_LIMIT_ENABLED = os.environ.get("FLASKPASTE_RATE_LIMIT", "1").lower() in ( + "1", + "true", + "yes", + ) + RATE_LIMIT_WINDOW = int(os.environ.get("FLASKPASTE_RATE_WINDOW", "60")) # seconds + RATE_LIMIT_MAX = int(os.environ.get("FLASKPASTE_RATE_MAX", "10")) # requests per window + # Authenticated users get higher limits (multiplier) + RATE_LIMIT_AUTH_MULTIPLIER = int(os.environ.get("FLASKPASTE_RATE_AUTH_MULT", "5")) + # PKI Configuration # Enable PKI endpoints for certificate authority and issuance PKI_ENABLED = os.environ.get("FLASKPASTE_PKI_ENABLED", "0").lower() in ("1", "true", "yes") @@ -103,6 +115,11 @@ class TestingConfig(Config): # Disable PoW for most tests (easier testing) POW_DIFFICULTY = 0 + # Relaxed rate limiting for tests + RATE_LIMIT_ENABLED = True + RATE_LIMIT_WINDOW = 1 + RATE_LIMIT_MAX = 100 + # PKI testing configuration PKI_ENABLED = True PKI_CA_PASSWORD = "test-ca-password" diff --git a/fpaste b/fpaste index fccf739..73a32ce 100755 --- a/fpaste +++ b/fpaste @@ -385,6 +385,425 @@ def cmd_info(args, config): die("failed to connect to server") +def format_size(size): + """Format byte size as human-readable string.""" + if size < 1024: + return f"{size}B" + elif size < 1024 * 1024: + return f"{size / 1024:.1f}K" + else: + return f"{size / (1024 * 1024):.1f}M" + + +def format_timestamp(ts): + """Format Unix timestamp as human-readable date.""" + from datetime import datetime + + dt = datetime.fromtimestamp(ts, tz=UTC) + return dt.strftime("%Y-%m-%d %H:%M") + + +def cmd_list(args, config): + """List user's pastes.""" + if not config["cert_sha1"]: + die("authentication required (set FLASKPASTE_CERT_SHA1)") + + base = config["server"].rstrip("/") + params = [] + if args.limit: + params.append(f"limit={args.limit}") + if args.offset: + params.append(f"offset={args.offset}") + + url = f"{base}/pastes" + if params: + url += "?" + "&".join(params) + + headers = {"X-SSL-Client-SHA1": config["cert_sha1"]} + status, body, _ = request(url, headers=headers, ssl_context=config.get("ssl_context")) + + if status == 401: + die("authentication failed") + elif status != 200: + die(f"failed to list pastes ({status})") + + data = json.loads(body) + pastes = data.get("pastes", []) + + if args.json: + print(json.dumps(data, indent=2)) + return + + if not pastes: + print("no pastes found") + return + + # Print header + print(f"{'ID':<12} {'TYPE':<16} {'SIZE':>6} {'CREATED':<16} FLAGS") + + for p in pastes: + paste_id = p["id"] + mime_type = p.get("mime_type", "unknown")[:16] + size = format_size(p.get("size", 0)) + created = format_timestamp(p.get("created_at", 0)) + + flags = [] + if p.get("burn_after_read"): + flags.append("burn") + if p.get("password_protected"): + flags.append("pass") + if p.get("expires_at"): + flags.append("exp") + + flag_str = " ".join(flags) + print(f"{paste_id:<12} {mime_type:<16} {size:>6} {created:<16} {flag_str}") + + # Print summary + print(f"\n{data.get('count', 0)} of {data.get('total', 0)} pastes shown") + + +def parse_date(date_str): + """Parse date string to Unix timestamp.""" + from datetime import datetime + + if not date_str: + return 0 + + # Try various formats + formats = [ + "%Y-%m-%d", + "%Y-%m-%d %H:%M", + "%Y-%m-%d %H:%M:%S", + "%Y-%m-%dT%H:%M:%S", + "%Y-%m-%dT%H:%M:%SZ", + ] + for fmt in formats: + try: + dt = datetime.strptime(date_str, fmt) + dt = dt.replace(tzinfo=UTC) + return int(dt.timestamp()) + except ValueError: + continue + + # Try as Unix timestamp + try: + return int(date_str) + except ValueError: + pass + + die(f"invalid date format: {date_str}") + + +def cmd_search(args, config): + """Search user's pastes.""" + if not config["cert_sha1"]: + die("authentication required (set FLASKPASTE_CERT_SHA1)") + + base = config["server"].rstrip("/") + params = [] + + if args.type: + params.append(f"type={args.type}") + if args.after: + ts = parse_date(args.after) + params.append(f"after={ts}") + if args.before: + ts = parse_date(args.before) + params.append(f"before={ts}") + if args.limit: + params.append(f"limit={args.limit}") + + url = f"{base}/pastes" + if params: + url += "?" + "&".join(params) + + headers = {"X-SSL-Client-SHA1": config["cert_sha1"]} + status, body, _ = request(url, headers=headers, ssl_context=config.get("ssl_context")) + + if status == 401: + die("authentication failed") + elif status != 200: + die(f"failed to search pastes ({status})") + + data = json.loads(body) + pastes = data.get("pastes", []) + + if args.json: + print(json.dumps(data, indent=2)) + return + + if not pastes: + print("no matching pastes found") + return + + # Print header + print(f"{'ID':<12} {'TYPE':<16} {'SIZE':>6} {'CREATED':<16} FLAGS") + + for p in pastes: + paste_id = p["id"] + mime_type = p.get("mime_type", "unknown")[:16] + size = format_size(p.get("size", 0)) + created = format_timestamp(p.get("created_at", 0)) + + flags = [] + if p.get("burn_after_read"): + flags.append("burn") + if p.get("password_protected"): + flags.append("pass") + if p.get("expires_at"): + flags.append("exp") + + flag_str = " ".join(flags) + print(f"{paste_id:<12} {mime_type:<16} {size:>6} {created:<16} {flag_str}") + + # Print summary + print(f"\n{data.get('count', 0)} matching pastes found") + + +def cmd_update(args, config): + """Update an existing paste.""" + if not config["cert_sha1"]: + die("authentication required (set FLASKPASTE_CERT_SHA1)") + + paste_id = args.id.split("/")[-1] # Handle full URLs + if "#" in paste_id: + paste_id = paste_id.split("#")[0] # Remove key fragment + + base = config["server"].rstrip("/") + url = f"{base}/{paste_id}" + + headers = {"X-SSL-Client-SHA1": config["cert_sha1"]} + content = None + + # Read content from file if provided + if args.file: + if args.file == "-": + content = sys.stdin.buffer.read() + else: + path = Path(args.file) + if not path.exists(): + die(f"file not found: {args.file}") + content = path.read_bytes() + + if not content: + die("empty content") + + # Encrypt if requested (default is to encrypt) + if not getattr(args, "no_encrypt", False): + if not HAS_CRYPTO: + die("encryption requires 'cryptography' package (use -E to disable)") + if not args.quiet: + print("encrypting...", end="", file=sys.stderr) + content, encryption_key = encrypt_content(content) + if not args.quiet: + print(" done", file=sys.stderr) + else: + encryption_key = None + + # Set metadata update headers + if args.password: + headers["X-Paste-Password"] = args.password + if args.remove_password: + headers["X-Remove-Password"] = "true" + if args.expiry: + headers["X-Extend-Expiry"] = str(args.expiry) + + # Make request + status, body, _ = request( + url, method="PUT", data=content, headers=headers, ssl_context=config.get("ssl_context") + ) + + if status == 200: + data = json.loads(body) + if args.quiet: + print(paste_id) + else: + print(f"updated: {paste_id}") + print(f" size: {data.get('size', 'unknown')}") + print(f" type: {data.get('mime_type', 'unknown')}") + if data.get("expires_at"): + print(f" expires: {data.get('expires_at')}") + if data.get("password_protected"): + print(" password: protected") + + # Show new encryption key if content was updated and encrypted + if content and "encryption_key" in dir() and encryption_key: + key_fragment = "#" + encode_key(encryption_key) + print(f" key: {base}/{paste_id}{key_fragment}") + elif status == 400: + try: + err = json.loads(body).get("error", "bad request") + except (json.JSONDecodeError, UnicodeDecodeError): + err = "bad request" + die(err) + elif status == 401: + die("authentication failed") + elif status == 403: + die("permission denied (not owner)") + elif status == 404: + die(f"not found: {paste_id}") + else: + die(f"update failed ({status})") + + +def cmd_export(args, config): + """Export user's pastes to a directory.""" + if not config["cert_sha1"]: + die("authentication required (set FLASKPASTE_CERT_SHA1)") + + base = config["server"].rstrip("/") + out_dir = Path(args.output) if args.output else Path("fpaste-export") + + # Create output directory + out_dir.mkdir(parents=True, exist_ok=True) + + # Load key file if provided + keys = {} + if args.keyfile: + keyfile_path = Path(args.keyfile) + if not keyfile_path.exists(): + die(f"key file not found: {args.keyfile}") + for line in keyfile_path.read_text().splitlines(): + line = line.strip() + if not line or line.startswith("#"): + continue + if "=" not in line: + continue + paste_id, key_encoded = line.split("=", 1) + keys[paste_id.strip()] = key_encoded.strip() + + # Fetch paste list + headers = {"X-SSL-Client-SHA1": config["cert_sha1"]} + url = f"{base}/pastes?limit=1000" # Fetch all pastes + status, body, _ = request(url, headers=headers, ssl_context=config.get("ssl_context")) + + if status == 401: + die("authentication failed") + elif status != 200: + die(f"failed to list pastes ({status})") + + data = json.loads(body) + pastes = data.get("pastes", []) + + if not pastes: + print("no pastes to export") + return + + # Export each paste + exported = 0 + skipped = 0 + errors = 0 + manifest = [] + + for p in pastes: + paste_id = p["id"] + mime_type = p.get("mime_type", "application/octet-stream") + + if not args.quiet: + print(f"exporting {paste_id}...", end=" ", file=sys.stderr) + + # Skip burn-after-read pastes + if p.get("burn_after_read"): + if not args.quiet: + print("skipped (burn-after-read)", file=sys.stderr) + skipped += 1 + continue + + # Fetch raw content + raw_url = f"{base}/{paste_id}/raw" + req_headers = dict(headers) + if p.get("password_protected"): + if not args.quiet: + print("skipped (password-protected)", file=sys.stderr) + skipped += 1 + continue + + ssl_ctx = config.get("ssl_context") + status, content, _ = request(raw_url, headers=req_headers, ssl_context=ssl_ctx) + + if status != 200: + if not args.quiet: + print(f"error ({status})", file=sys.stderr) + errors += 1 + continue + + # Decrypt if key available + decrypted = False + if paste_id in keys: + try: + key = decode_key(keys[paste_id]) + content = decrypt_content(content, key) + decrypted = True + except SystemExit: + # Decryption failed, keep encrypted content + if not args.quiet: + print("decryption failed, keeping encrypted", file=sys.stderr, end=" ") + + # Determine file extension from MIME type + ext = get_extension_for_mime(mime_type) + filename = f"{paste_id}{ext}" + filepath = out_dir / filename + + # Write content + filepath.write_bytes(content) + + # Add to manifest + manifest.append( + { + "id": paste_id, + "filename": filename, + "mime_type": mime_type, + "size": len(content), + "created_at": p.get("created_at"), + "decrypted": decrypted, + "encrypted": paste_id in keys and not decrypted, + } + ) + + if not args.quiet: + status_msg = "decrypted" if decrypted else ("encrypted" if paste_id in keys else "ok") + print(status_msg, file=sys.stderr) + + exported += 1 + + # Write manifest + if args.manifest: + manifest_path = out_dir / "manifest.json" + manifest_path.write_text(json.dumps(manifest, indent=2)) + if not args.quiet: + print(f"manifest: {manifest_path}", file=sys.stderr) + + # Summary + print(f"\nexported: {exported}, skipped: {skipped}, errors: {errors}") + print(f"output: {out_dir}") + + +def get_extension_for_mime(mime_type): + """Get file extension for MIME type.""" + mime_map = { + "text/plain": ".txt", + "text/html": ".html", + "text/css": ".css", + "text/javascript": ".js", + "text/markdown": ".md", + "text/x-python": ".py", + "application/json": ".json", + "application/xml": ".xml", + "application/javascript": ".js", + "application/octet-stream": ".bin", + "image/png": ".png", + "image/jpeg": ".jpg", + "image/gif": ".gif", + "image/webp": ".webp", + "image/svg+xml": ".svg", + "application/pdf": ".pdf", + "application/zip": ".zip", + "application/gzip": ".gz", + "application/x-tar": ".tar", + } + return mime_map.get(mime_type, ".bin") + + def cmd_pki_status(args, config): """Show PKI status and CA information.""" url = config["server"].rstrip("/") + "/pki" @@ -752,7 +1171,28 @@ def is_file_path(arg): def main(): # Pre-process arguments: if first positional looks like a file, insert "create" args_to_parse = sys.argv[1:] - commands = {"create", "c", "new", "get", "g", "delete", "d", "rm", "info", "i", "cert", "pki"} + commands = { + "create", + "c", + "new", + "get", + "g", + "delete", + "d", + "rm", + "info", + "i", + "list", + "ls", + "search", + "s", + "find", + "update", + "u", + "export", + "cert", + "pki", + } # Find insertion point for "create" command insert_pos = 0 @@ -823,6 +1263,37 @@ def main(): # info subparsers.add_parser("info", aliases=["i"], help="show server info") + # list + p_list = subparsers.add_parser("list", aliases=["ls"], help="list your pastes") + p_list.add_argument("-l", "--limit", type=int, metavar="N", help="max pastes (default: 50)") + p_list.add_argument("-o", "--offset", type=int, metavar="N", help="skip first N pastes") + p_list.add_argument("--json", action="store_true", help="output as JSON") + + # search + p_search = subparsers.add_parser("search", aliases=["s", "find"], help="search your pastes") + p_search.add_argument("-t", "--type", metavar="PATTERN", help="filter by MIME type (image/*)") + p_search.add_argument("--after", metavar="DATE", help="created after (YYYY-MM-DD or timestamp)") + p_search.add_argument("--before", metavar="DATE", help="created before (YYYY-MM-DD)") + p_search.add_argument("-l", "--limit", type=int, metavar="N", help="max results (default: 50)") + p_search.add_argument("--json", action="store_true", help="output as JSON") + + # update + p_update = subparsers.add_parser("update", aliases=["u"], help="update existing paste") + p_update.add_argument("id", help="paste ID or URL") + p_update.add_argument("file", nargs="?", help="new content (- for stdin)") + p_update.add_argument("-E", "--no-encrypt", action="store_true", help="disable encryption") + p_update.add_argument("-p", "--password", metavar="PASS", help="set/change password") + p_update.add_argument("--remove-password", action="store_true", help="remove password") + p_update.add_argument("-x", "--expiry", type=int, metavar="SEC", help="extend expiry (seconds)") + p_update.add_argument("-q", "--quiet", action="store_true", help="minimal output") + + # export + p_export = subparsers.add_parser("export", help="export all pastes to directory") + p_export.add_argument("-o", "--output", metavar="DIR", help="output directory") + p_export.add_argument("-k", "--keyfile", metavar="FILE", help="key file (paste_id=key format)") + p_export.add_argument("--manifest", action="store_true", help="write manifest.json") + p_export.add_argument("-q", "--quiet", action="store_true", help="minimal output") + # cert p_cert = subparsers.add_parser("cert", help="generate client certificate") p_cert.add_argument("-o", "--output", metavar="DIR", help="output directory") @@ -904,6 +1375,14 @@ def main(): cmd_delete(args, config) elif args.command in ("info", "i"): cmd_info(args, config) + elif args.command in ("list", "ls"): + cmd_list(args, config) + elif args.command in ("search", "s", "find"): + cmd_search(args, config) + elif args.command in ("update", "u"): + cmd_update(args, config) + elif args.command == "export": + cmd_export(args, config) elif args.command == "cert": cmd_cert(args, config) elif args.command == "pki": diff --git a/tests/conftest.py b/tests/conftest.py index bdb1711..b8f8250 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,14 +2,37 @@ import pytest +import app.database as db_module from app import create_app +from app.api.routes import reset_rate_limits + + +def _clear_database(): + """Clear all data from database tables for test isolation.""" + if db_module._memory_db_holder is not None: + db_module._memory_db_holder.execute("DELETE FROM pastes") + db_module._memory_db_holder.execute("DELETE FROM content_hashes") + db_module._memory_db_holder.execute("DELETE FROM issued_certificates") + db_module._memory_db_holder.execute("DELETE FROM certificate_authority") + db_module._memory_db_holder.commit() @pytest.fixture def app(): """Create application for testing.""" - app = create_app("testing") - yield app + # Reset global state for test isolation + reset_rate_limits() + _clear_database() + + test_app = create_app("testing") + + # Clear database again after app init (in case init added anything) + _clear_database() + + yield test_app + + # Cleanup after test + reset_rate_limits() @pytest.fixture diff --git a/tests/test_paste_listing.py b/tests/test_paste_listing.py new file mode 100644 index 0000000..3492f30 --- /dev/null +++ b/tests/test_paste_listing.py @@ -0,0 +1,335 @@ +"""Tests for paste listing endpoint (GET /pastes).""" + +import json + + +class TestPastesListEndpoint: + """Tests for GET /pastes endpoint.""" + + def test_list_pastes_requires_auth(self, client): + """List pastes requires authentication.""" + response = client.get("/pastes") + assert response.status_code == 401 + data = json.loads(response.data) + assert "error" in data + + def test_list_pastes_empty(self, client, auth_header): + """List pastes returns empty when user has no pastes.""" + response = client.get("/pastes", headers=auth_header) + assert response.status_code == 200 + data = json.loads(response.data) + assert data["pastes"] == [] + assert data["count"] == 0 + assert data["total"] == 0 + + def test_list_pastes_returns_own_pastes(self, client, sample_text, auth_header): + """List pastes returns only user's own pastes.""" + # Create a paste + create = client.post( + "/", + data=sample_text, + content_type="text/plain", + headers=auth_header, + ) + paste_id = json.loads(create.data)["id"] + + # List pastes + response = client.get("/pastes", headers=auth_header) + assert response.status_code == 200 + data = json.loads(response.data) + assert data["count"] == 1 + assert data["total"] == 1 + assert data["pastes"][0]["id"] == paste_id + + def test_list_pastes_excludes_others(self, client, sample_text, auth_header, other_auth_header): + """List pastes does not include other users' pastes.""" + # Create paste as user A + client.post( + "/", + data=sample_text, + content_type="text/plain", + headers=auth_header, + ) + + # List pastes as user B + response = client.get("/pastes", headers=other_auth_header) + assert response.status_code == 200 + data = json.loads(response.data) + assert data["count"] == 0 + assert data["total"] == 0 + + def test_list_pastes_excludes_anonymous(self, client, sample_text, auth_header): + """List pastes does not include anonymous pastes.""" + # Create anonymous paste + client.post("/", data=sample_text, content_type="text/plain") + + # List pastes as authenticated user + response = client.get("/pastes", headers=auth_header) + assert response.status_code == 200 + data = json.loads(response.data) + assert data["count"] == 0 + + def test_list_pastes_metadata_only(self, client, sample_text, auth_header): + """List pastes returns metadata, not content.""" + # Create a paste + client.post( + "/", + data=sample_text, + content_type="text/plain", + headers=auth_header, + ) + + # List pastes + response = client.get("/pastes", headers=auth_header) + data = json.loads(response.data) + paste = data["pastes"][0] + + # Verify metadata fields + assert "id" in paste + assert "mime_type" in paste + assert "size" in paste + assert "created_at" in paste + assert "last_accessed" in paste + assert "url" in paste + assert "raw" in paste + + # Verify content is NOT included + assert "content" not in paste + + def test_list_pastes_pagination(self, client, auth_header): + """List pastes supports pagination.""" + # Create multiple pastes + for i in range(5): + client.post( + "/", + data=f"paste {i}", + content_type="text/plain", + headers=auth_header, + ) + + # Get first page + response = client.get("/pastes?limit=2&offset=0", headers=auth_header) + data = json.loads(response.data) + assert data["count"] == 2 + assert data["total"] == 5 + assert data["limit"] == 2 + assert data["offset"] == 0 + + # Get second page + response = client.get("/pastes?limit=2&offset=2", headers=auth_header) + data = json.loads(response.data) + assert data["count"] == 2 + assert data["offset"] == 2 + + def test_list_pastes_max_limit(self, client, auth_header): + """List pastes enforces maximum limit.""" + response = client.get("/pastes?limit=500", headers=auth_header) + data = json.loads(response.data) + assert data["limit"] == 200 # Max limit enforced + + def test_list_pastes_invalid_pagination(self, client, auth_header): + """List pastes handles invalid pagination gracefully.""" + response = client.get("/pastes?limit=abc&offset=-1", headers=auth_header) + assert response.status_code == 200 + data = json.loads(response.data) + # Should use defaults + assert data["limit"] == 50 + assert data["offset"] == 0 + + def test_list_pastes_includes_special_fields(self, client, auth_header): + """List pastes includes burn_after_read, expires_at, password_protected.""" + # Create paste with burn-after-read + client.post( + "/", + data="burn test", + content_type="text/plain", + headers={**auth_header, "X-Burn-After-Read": "true"}, + ) + + response = client.get("/pastes", headers=auth_header) + data = json.loads(response.data) + paste = data["pastes"][0] + assert paste.get("burn_after_read") is True + + def test_list_pastes_ordered_by_created_at(self, client, auth_header): + """List pastes returns all created pastes ordered by created_at DESC.""" + # Create pastes + ids = set() + for i in range(3): + create = client.post( + "/", + data=f"paste {i}", + content_type="text/plain", + headers=auth_header, + ) + ids.add(json.loads(create.data)["id"]) + + response = client.get("/pastes", headers=auth_header) + data = json.loads(response.data) + + # All created pastes should be present + returned_ids = {p["id"] for p in data["pastes"]} + assert returned_ids == ids + assert data["count"] == 3 + + +class TestPastesPrivacy: + """Privacy-focused tests for paste listing.""" + + def test_cannot_see_other_user_pastes( + self, client, sample_text, auth_header, other_auth_header + ): + """Users cannot see pastes owned by others.""" + # User A creates paste + create = client.post( + "/", + data=sample_text, + content_type="text/plain", + headers=auth_header, + ) + paste_id = json.loads(create.data)["id"] + + # User B lists pastes - should not see A's paste + response = client.get("/pastes", headers=other_auth_header) + data = json.loads(response.data) + paste_ids = [p["id"] for p in data["pastes"]] + assert paste_id not in paste_ids + + def test_no_admin_bypass(self, client, sample_text, auth_header): + """No special admin access to list all pastes.""" + # Create paste as regular user + client.post( + "/", + data=sample_text, + content_type="text/plain", + headers=auth_header, + ) + + # Different user cannot see it, regardless of auth header format + admin_header = {"X-SSL-Client-SHA1": "0" * 40} + response = client.get("/pastes", headers=admin_header) + data = json.loads(response.data) + assert data["count"] == 0 + + def test_content_never_exposed(self, client, auth_header): + """Paste content is never exposed in listing.""" + secret = "super secret content that should never be exposed" + client.post( + "/", + data=secret, + content_type="text/plain", + headers=auth_header, + ) + + response = client.get("/pastes", headers=auth_header) + # Content should not appear anywhere in response + assert secret.encode() not in response.data + + +class TestPastesSearch: + """Tests for paste search parameters.""" + + def test_search_by_type_exact(self, client, auth_header, png_bytes): + """Search pastes by exact MIME type.""" + # Create text paste + client.post("/", data="text content", content_type="text/plain", headers=auth_header) + # Create image paste + create = client.post("/", data=png_bytes, content_type="image/png", headers=auth_header) + png_id = json.loads(create.data)["id"] + + # Search for image/png + response = client.get("/pastes?type=image/png", headers=auth_header) + data = json.loads(response.data) + assert data["count"] == 1 + assert data["pastes"][0]["id"] == png_id + + def test_search_by_type_glob(self, client, auth_header, png_bytes, jpeg_bytes): + """Search pastes by MIME type glob pattern.""" + # Create text paste + client.post("/", data="text content", content_type="text/plain", headers=auth_header) + # Create image pastes + client.post("/", data=png_bytes, content_type="image/png", headers=auth_header) + client.post("/", data=jpeg_bytes, content_type="image/jpeg", headers=auth_header) + + # Search for all images + response = client.get("/pastes?type=image/*", headers=auth_header) + data = json.loads(response.data) + assert data["count"] == 2 + for paste in data["pastes"]: + assert paste["mime_type"].startswith("image/") + + def test_search_by_after_timestamp(self, client, auth_header): + """Search pastes created after timestamp.""" + + # Create paste + create = client.post("/", data="test", content_type="text/plain", headers=auth_header) + paste_data = json.loads(create.data) + created_at = paste_data["created_at"] + + # Search for pastes after creation time (should find it) + response = client.get(f"/pastes?after={created_at - 1}", headers=auth_header) + data = json.loads(response.data) + assert data["count"] == 1 + + # Search for pastes after creation time + 1 (should not find it) + response = client.get(f"/pastes?after={created_at + 1}", headers=auth_header) + data = json.loads(response.data) + assert data["count"] == 0 + + def test_search_by_before_timestamp(self, client, auth_header): + """Search pastes created before timestamp.""" + + # Create paste + create = client.post("/", data="test", content_type="text/plain", headers=auth_header) + paste_data = json.loads(create.data) + created_at = paste_data["created_at"] + + # Search for pastes before creation time + 1 (should find it) + response = client.get(f"/pastes?before={created_at + 1}", headers=auth_header) + data = json.loads(response.data) + assert data["count"] == 1 + + # Search for pastes before creation time (should not find it) + response = client.get(f"/pastes?before={created_at - 1}", headers=auth_header) + data = json.loads(response.data) + assert data["count"] == 0 + + def test_search_combined_filters(self, client, auth_header, png_bytes): + """Search with multiple filters combined.""" + + # Create text paste + client.post("/", data="text", content_type="text/plain", headers=auth_header) + # Create image paste + create = client.post("/", data=png_bytes, content_type="image/png", headers=auth_header) + png_data = json.loads(create.data) + created_at = png_data["created_at"] + + # Search for images after a certain time + response = client.get( + f"/pastes?type=image/*&after={created_at - 1}", + headers=auth_header, + ) + data = json.loads(response.data) + assert data["count"] == 1 + assert data["pastes"][0]["mime_type"] == "image/png" + + def test_search_no_matches(self, client, auth_header): + """Search with no matching results.""" + # Create text paste + client.post("/", data="text", content_type="text/plain", headers=auth_header) + + # Search for video (no matches) + response = client.get("/pastes?type=video/*", headers=auth_header) + data = json.loads(response.data) + assert data["count"] == 0 + assert data["pastes"] == [] + + def test_search_invalid_timestamp(self, client, auth_header): + """Search with invalid timestamp uses default.""" + client.post("/", data="test", content_type="text/plain", headers=auth_header) + + # Invalid timestamp should be ignored + response = client.get("/pastes?after=invalid", headers=auth_header) + assert response.status_code == 200 + data = json.loads(response.data) + assert data["count"] == 1 diff --git a/tests/test_paste_update.py b/tests/test_paste_update.py new file mode 100644 index 0000000..0ad9645 --- /dev/null +++ b/tests/test_paste_update.py @@ -0,0 +1,242 @@ +"""Tests for paste update endpoint (PUT /).""" + +import json + + +class TestPasteUpdateEndpoint: + """Tests for PUT / endpoint.""" + + def test_update_requires_auth(self, client, sample_text, auth_header): + """Update requires authentication.""" + # Create paste + create = client.post("/", data=sample_text, content_type="text/plain", headers=auth_header) + paste_id = json.loads(create.data)["id"] + + # Try to update without auth + response = client.put(f"/{paste_id}", data="updated") + assert response.status_code == 401 + + def test_update_requires_ownership(self, client, sample_text, auth_header, other_auth_header): + """Update requires paste ownership.""" + # Create paste as user A + create = client.post("/", data=sample_text, content_type="text/plain", headers=auth_header) + paste_id = json.loads(create.data)["id"] + + # Try to update as user B + response = client.put( + f"/{paste_id}", + data="updated content", + content_type="text/plain", + headers=other_auth_header, + ) + assert response.status_code == 403 + + def test_update_content(self, client, auth_header): + """Update paste content.""" + # Create paste + create = client.post("/", data="original", content_type="text/plain", headers=auth_header) + paste_id = json.loads(create.data)["id"] + + # Update content + response = client.put( + f"/{paste_id}", + data="updated content", + content_type="text/plain", + headers=auth_header, + ) + assert response.status_code == 200 + data = json.loads(response.data) + assert data["size"] == len("updated content") + + # Verify content changed + raw = client.get(f"/{paste_id}/raw") + assert raw.data == b"updated content" + + def test_update_password_set(self, client, auth_header): + """Set password on paste.""" + # Create paste without password + create = client.post("/", data="content", content_type="text/plain", headers=auth_header) + paste_id = json.loads(create.data)["id"] + + # Add password + response = client.put( + f"/{paste_id}", + data="", + headers={**auth_header, "X-Paste-Password": "secret123"}, + ) + assert response.status_code == 200 + data = json.loads(response.data) + assert data.get("password_protected") is True + + # Verify password required + raw = client.get(f"/{paste_id}/raw") + assert raw.status_code == 401 + + # Verify correct password works + raw = client.get(f"/{paste_id}/raw", headers={"X-Paste-Password": "secret123"}) + assert raw.status_code == 200 + + def test_update_password_remove(self, client, auth_header): + """Remove password from paste.""" + # Create paste with password + create = client.post( + "/", + data="content", + content_type="text/plain", + headers={**auth_header, "X-Paste-Password": "secret123"}, + ) + paste_id = json.loads(create.data)["id"] + + # Remove password + response = client.put( + f"/{paste_id}", + data="", + headers={**auth_header, "X-Remove-Password": "true"}, + ) + assert response.status_code == 200 + data = json.loads(response.data) + assert data.get("password_protected") is not True + + # Verify no password required + raw = client.get(f"/{paste_id}/raw") + assert raw.status_code == 200 + + def test_update_extend_expiry(self, client, auth_header): + """Extend paste expiry.""" + # Create paste with expiry + create = client.post( + "/", + data="content", + content_type="text/plain", + headers={**auth_header, "X-Expiry": "3600"}, + ) + paste_id = json.loads(create.data)["id"] + original_expiry = json.loads(create.data)["expires_at"] + + # Extend expiry + response = client.put( + f"/{paste_id}", + data="", + headers={**auth_header, "X-Extend-Expiry": "7200"}, + ) + assert response.status_code == 200 + data = json.loads(response.data) + assert data["expires_at"] == original_expiry + 7200 + + def test_update_add_expiry(self, client, auth_header): + """Add expiry to paste without one.""" + # Create paste without expiry + create = client.post("/", data="content", content_type="text/plain", headers=auth_header) + paste_id = json.loads(create.data)["id"] + + # Add expiry + response = client.put( + f"/{paste_id}", + data="", + headers={**auth_header, "X-Extend-Expiry": "3600"}, + ) + assert response.status_code == 200 + data = json.loads(response.data) + assert "expires_at" in data + + def test_update_burn_after_read_forbidden(self, client, auth_header): + """Cannot update burn-after-read pastes.""" + # Create burn-after-read paste + create = client.post( + "/", + data="content", + content_type="text/plain", + headers={**auth_header, "X-Burn-After-Read": "true"}, + ) + paste_id = json.loads(create.data)["id"] + + # Try to update + response = client.put( + f"/{paste_id}", + data="updated", + content_type="text/plain", + headers=auth_header, + ) + assert response.status_code == 400 + data = json.loads(response.data) + assert "burn" in data["error"].lower() + + def test_update_not_found(self, client, auth_header): + """Update non-existent paste returns 404.""" + response = client.put( + "/000000000000", + data="updated", + content_type="text/plain", + headers=auth_header, + ) + assert response.status_code == 404 + + def test_update_no_changes(self, client, auth_header): + """Update with no changes returns error.""" + # Create paste + create = client.post("/", data="content", content_type="text/plain", headers=auth_header) + paste_id = json.loads(create.data)["id"] + + # Try to update with empty request + response = client.put(f"/{paste_id}", data="", headers=auth_header) + assert response.status_code == 400 + data = json.loads(response.data) + assert "no updates" in data["error"].lower() + + def test_update_combined(self, client, auth_header): + """Update content and metadata together.""" + # Create paste + create = client.post("/", data="original", content_type="text/plain", headers=auth_header) + paste_id = json.loads(create.data)["id"] + + # Update content and add password + response = client.put( + f"/{paste_id}", + data="new content", + content_type="text/plain", + headers={**auth_header, "X-Paste-Password": "secret"}, + ) + assert response.status_code == 200 + data = json.loads(response.data) + assert data["size"] == len("new content") + assert data.get("password_protected") is True + + +class TestPasteUpdatePrivacy: + """Privacy-focused tests for paste update.""" + + def test_cannot_update_anonymous_paste(self, client, sample_text, auth_header): + """Cannot update paste without owner.""" + # Create anonymous paste + create = client.post("/", data=sample_text, content_type="text/plain") + paste_id = json.loads(create.data)["id"] + + # Try to update + response = client.put( + f"/{paste_id}", + data="updated", + content_type="text/plain", + headers=auth_header, + ) + assert response.status_code == 403 + + def test_cannot_update_other_user_paste( + self, client, sample_text, auth_header, other_auth_header + ): + """Cannot update paste owned by another user.""" + # Create paste as user A + create = client.post("/", data=sample_text, content_type="text/plain", headers=auth_header) + paste_id = json.loads(create.data)["id"] + + # Try to update as user B + response = client.put( + f"/{paste_id}", + data="hijacked", + content_type="text/plain", + headers=other_auth_header, + ) + assert response.status_code == 403 + + # Verify original content unchanged + raw = client.get(f"/{paste_id}/raw") + assert raw.data == sample_text.encode() diff --git a/tests/test_rate_limiting.py b/tests/test_rate_limiting.py new file mode 100644 index 0000000..219ab16 --- /dev/null +++ b/tests/test_rate_limiting.py @@ -0,0 +1,168 @@ +"""Tests for IP-based rate limiting.""" + +import json + + +class TestRateLimiting: + """Tests for rate limiting on paste creation.""" + + def test_rate_limit_allows_normal_usage(self, client, sample_text): + """Normal usage within rate limit succeeds.""" + # TestingConfig has RATE_LIMIT_MAX=100 + for i in range(5): + response = client.post( + "/", + data=f"paste {i}", + content_type="text/plain", + ) + assert response.status_code == 201 + + def test_rate_limit_exceeded_returns_429(self, client, app): + """Exceeding rate limit returns 429.""" + # Temporarily lower rate limit for test + original_max = app.config["RATE_LIMIT_MAX"] + app.config["RATE_LIMIT_MAX"] = 3 + + try: + # Make requests up to limit + for i in range(3): + response = client.post( + "/", + data=f"paste {i}", + content_type="text/plain", + ) + assert response.status_code == 201 + + # Next request should be rate limited + response = client.post( + "/", + data="one more", + content_type="text/plain", + ) + assert response.status_code == 429 + data = json.loads(response.data) + assert "error" in data + assert "Rate limit" in data["error"] + assert "retry_after" in data + finally: + app.config["RATE_LIMIT_MAX"] = original_max + + def test_rate_limit_headers(self, client, app): + """Rate limit response includes proper headers.""" + original_max = app.config["RATE_LIMIT_MAX"] + app.config["RATE_LIMIT_MAX"] = 1 + + try: + # First request succeeds + client.post("/", data="first", content_type="text/plain") + + # Second request is rate limited + response = client.post("/", data="second", content_type="text/plain") + assert response.status_code == 429 + assert "Retry-After" in response.headers + assert "X-RateLimit-Remaining" in response.headers + assert response.headers["X-RateLimit-Remaining"] == "0" + finally: + app.config["RATE_LIMIT_MAX"] = original_max + + def test_rate_limit_auth_multiplier(self, client, app, auth_header): + """Authenticated users get higher rate limits.""" + original_max = app.config["RATE_LIMIT_MAX"] + original_mult = app.config["RATE_LIMIT_AUTH_MULTIPLIER"] + app.config["RATE_LIMIT_MAX"] = 2 + app.config["RATE_LIMIT_AUTH_MULTIPLIER"] = 3 # 2 * 3 = 6 for auth users + + try: + # Authenticated user can make more requests than base limit + for i in range(5): + response = client.post( + "/", + data=f"auth {i}", + content_type="text/plain", + headers=auth_header, + ) + assert response.status_code == 201 + + # 6th request should succeed (limit is 2*3=6) + response = client.post( + "/", + data="auth 6", + content_type="text/plain", + headers=auth_header, + ) + assert response.status_code == 201 + + # 7th should fail + response = client.post( + "/", + data="auth 7", + content_type="text/plain", + headers=auth_header, + ) + assert response.status_code == 429 + finally: + app.config["RATE_LIMIT_MAX"] = original_max + app.config["RATE_LIMIT_AUTH_MULTIPLIER"] = original_mult + + def test_rate_limit_can_be_disabled(self, client, app): + """Rate limiting can be disabled via config.""" + original_enabled = app.config["RATE_LIMIT_ENABLED"] + original_max = app.config["RATE_LIMIT_MAX"] + app.config["RATE_LIMIT_ENABLED"] = False + app.config["RATE_LIMIT_MAX"] = 1 + + try: + # Should be able to make many requests + for i in range(5): + response = client.post( + "/", + data=f"paste {i}", + content_type="text/plain", + ) + assert response.status_code == 201 + finally: + app.config["RATE_LIMIT_ENABLED"] = original_enabled + app.config["RATE_LIMIT_MAX"] = original_max + + def test_rate_limit_only_affects_paste_creation(self, client, app, sample_text): + """Rate limiting only affects POST /, not GET endpoints.""" + original_max = app.config["RATE_LIMIT_MAX"] + app.config["RATE_LIMIT_MAX"] = 2 + + try: + # Create paste + create = client.post("/", data=sample_text, content_type="text/plain") + assert create.status_code == 201 + paste_id = json.loads(create.data)["id"] + + # Use up rate limit + client.post("/", data="second", content_type="text/plain") + + # Should be rate limited for creation + response = client.post("/", data="third", content_type="text/plain") + assert response.status_code == 429 + + # But GET should still work + response = client.get(f"/{paste_id}") + assert response.status_code == 200 + + response = client.get(f"/{paste_id}/raw") + assert response.status_code == 200 + + response = client.get("/health") + assert response.status_code == 200 + finally: + app.config["RATE_LIMIT_MAX"] = original_max + + +class TestRateLimitCleanup: + """Tests for rate limit cleanup.""" + + def test_rate_limit_window_expiry(self, app): + """Rate limit cleanup works with explicit window.""" + from app.api.routes import cleanup_rate_limits + + # Should work without app context when window is explicit + cleaned = cleanup_rate_limits(window=60) + assert isinstance(cleaned, int) + assert cleaned >= 0 diff --git a/tests/test_scheduled_cleanup.py b/tests/test_scheduled_cleanup.py new file mode 100644 index 0000000..0d05442 --- /dev/null +++ b/tests/test_scheduled_cleanup.py @@ -0,0 +1,169 @@ +"""Tests for scheduled cleanup functionality.""" + +import time + + +class TestScheduledCleanup: + """Tests for scheduled cleanup in before_request hook.""" + + def test_cleanup_times_reset(self, app, client): + """Verify cleanup times can be reset for testing.""" + from app.api import _cleanup_times, reset_cleanup_times + + # Set some cleanup times + with app.app_context(): + reset_cleanup_times() + for key in _cleanup_times: + assert _cleanup_times[key] == 0 + + def test_cleanup_runs_on_request(self, app, client, auth_header): + """Verify cleanup runs during request handling.""" + from app.api import _cleanup_times, reset_cleanup_times + + with app.app_context(): + reset_cleanup_times() + + # Make a request to trigger cleanup + client.get("/health") + + # Cleanup times should be set + with app.app_context(): + for key in _cleanup_times: + assert _cleanup_times[key] > 0 + + def test_cleanup_respects_intervals(self, app, client): + """Verify cleanup respects configured intervals.""" + from app.api import _cleanup_times, reset_cleanup_times + + with app.app_context(): + reset_cleanup_times() + + # First request triggers cleanup + client.get("/health") + + with app.app_context(): + first_times = {k: v for k, v in _cleanup_times.items()} + + # Immediate second request should not reset times + client.get("/health") + + with app.app_context(): + for key in _cleanup_times: + assert _cleanup_times[key] == first_times[key] + + def test_expired_paste_cleanup(self, app, client, auth_header): + """Test that expired pastes are cleaned up.""" + import json + + from app.database import get_db + + # Create paste with short expiry + response = client.post( + "/", + data="test content", + content_type="text/plain", + headers={**auth_header, "X-Expiry": "1"}, # 1 second + ) + assert response.status_code == 201 + paste_id = json.loads(response.data)["id"] + + # Wait for expiry + time.sleep(1.5) + + # Trigger cleanup via database function + with app.app_context(): + from app.database import cleanup_expired_pastes + + count = cleanup_expired_pastes() + assert count >= 1 + + # Verify paste is gone + db = get_db() + row = db.execute("SELECT id FROM pastes WHERE id = ?", (paste_id,)).fetchone() + assert row is None + + def test_rate_limit_cleanup(self, app, client): + """Test that rate limit entries are cleaned up.""" + from app.api.routes import ( + _rate_limit_requests, + check_rate_limit, + cleanup_rate_limits, + reset_rate_limits, + ) + + with app.app_context(): + reset_rate_limits() + + # Add some rate limit entries + check_rate_limit("192.168.1.1", authenticated=False) + check_rate_limit("192.168.1.2", authenticated=False) + + assert len(_rate_limit_requests) == 2 + + # Cleanup with very large window (nothing should be removed) + count = cleanup_rate_limits(window=3600) + assert count == 0 + + # Wait a bit and cleanup with tiny window + time.sleep(0.1) + count = cleanup_rate_limits(window=0) # Immediate cleanup + assert count == 2 + assert len(_rate_limit_requests) == 0 + + +class TestCleanupThreadSafety: + """Tests for thread-safety of cleanup operations.""" + + def test_cleanup_lock_exists(self, app): + """Verify cleanup lock exists.""" + import threading + + from app.api import _cleanup_lock + + assert isinstance(_cleanup_lock, type(threading.Lock())) + + def test_concurrent_cleanup_access(self, app, client): + """Test that concurrent requests don't corrupt cleanup state.""" + import threading + + from app.api import reset_cleanup_times + + with app.app_context(): + reset_cleanup_times() + + errors = [] + results = [] + + def make_request(): + try: + resp = client.get("/health") + results.append(resp.status_code) + except Exception as e: + errors.append(str(e)) + + # Simulate concurrent requests + threads = [threading.Thread(target=make_request) for _ in range(10)] + for t in threads: + t.start() + for t in threads: + t.join() + + assert not errors + assert all(r == 200 for r in results) + + +class TestCleanupConfiguration: + """Tests for cleanup configuration.""" + + def test_cleanup_intervals_configured(self, app): + """Verify cleanup intervals are properly configured.""" + from app.api import _CLEANUP_INTERVALS + + assert "pastes" in _CLEANUP_INTERVALS + assert "hashes" in _CLEANUP_INTERVALS + assert "rate_limits" in _CLEANUP_INTERVALS + + # Verify reasonable intervals + assert _CLEANUP_INTERVALS["pastes"] >= 60 # At least 1 minute + assert _CLEANUP_INTERVALS["hashes"] >= 60 + assert _CLEANUP_INTERVALS["rate_limits"] >= 60