add CLI enhancements and scheduled cleanup

CLI commands:
- list: show user's pastes with pagination
- search: filter by type (glob), after/before timestamps
- update: modify content, password, or extend expiry
- export: save pastes to directory with optional decryption

API changes:
- PUT /<id>: update paste content and metadata
- GET /pastes: add type, after, before query params

Scheduled tasks:
- Thread-safe cleanup with per-task intervals
- Activate cleanup_expired_hashes (15min)
- Activate cleanup_rate_limits (5min)

Tests: 205 passing
This commit is contained in:
Username
2025-12-20 20:13:00 +01:00
parent cf31eab678
commit bfc238b5cf
9 changed files with 1826 additions and 18 deletions

View File

@@ -1,32 +1,65 @@
"""API blueprint registration."""
import threading
import time
from flask import Blueprint, current_app
bp = Blueprint("api", __name__)
# Throttle cleanup to run at most once per hour
_last_cleanup = 0
_CLEANUP_INTERVAL = 3600 # 1 hour
# Thread-safe cleanup scheduling
_cleanup_lock = threading.Lock()
_cleanup_times = {
"pastes": 0,
"hashes": 0,
"rate_limits": 0,
}
_CLEANUP_INTERVALS = {
"pastes": 3600, # 1 hour
"hashes": 900, # 15 minutes
"rate_limits": 300, # 5 minutes
}
def reset_cleanup_times() -> None:
"""Reset cleanup timestamps. For testing only."""
with _cleanup_lock:
for key in _cleanup_times:
_cleanup_times[key] = 0
@bp.before_request
def cleanup_expired():
"""Periodically clean up expired pastes."""
global _last_cleanup
def run_scheduled_cleanup():
"""Periodically run cleanup tasks on schedule."""
now = time.time()
if now - _last_cleanup < _CLEANUP_INTERVAL:
return
_last_cleanup = now
with _cleanup_lock:
# Cleanup expired pastes
if now - _cleanup_times["pastes"] >= _CLEANUP_INTERVALS["pastes"]:
_cleanup_times["pastes"] = now
from app.database import cleanup_expired_pastes
count = cleanup_expired_pastes()
if count > 0:
current_app.logger.info(f"Cleaned up {count} expired paste(s)")
# Cleanup expired content hashes
if now - _cleanup_times["hashes"] >= _CLEANUP_INTERVALS["hashes"]:
_cleanup_times["hashes"] = now
from app.database import cleanup_expired_hashes
count = cleanup_expired_hashes()
if count > 0:
current_app.logger.info(f"Cleaned up {count} expired hash(es)")
# Cleanup rate limit entries
if now - _cleanup_times["rate_limits"] >= _CLEANUP_INTERVALS["rate_limits"]:
_cleanup_times["rate_limits"] = now
from app.api.routes import cleanup_rate_limits
count = cleanup_rate_limits()
if count > 0:
current_app.logger.info(f"Cleaned up {count} rate limit entr(ies)")
from app.api import routes # noqa: E402, F401

View File

@@ -8,7 +8,9 @@ import json
import math
import re
import secrets
import threading
import time
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from flask import Response, current_app, g, request
@@ -50,6 +52,96 @@ GENERIC_MIME_TYPES = frozenset(
# Runtime PoW secret cache
_pow_secret_cache: bytes | None = None
# ─────────────────────────────────────────────────────────────────────────────
# Rate Limiting (in-memory sliding window)
# ─────────────────────────────────────────────────────────────────────────────
_rate_limit_lock = threading.Lock()
_rate_limit_requests: dict[str, list[float]] = defaultdict(list)
def get_client_ip() -> str:
"""Get client IP address, respecting X-Forwarded-For from trusted proxy."""
if is_trusted_proxy():
forwarded = request.headers.get("X-Forwarded-For", "")
if forwarded:
# Take the first (client) IP from the chain
return forwarded.split(",")[0].strip()
return request.remote_addr or "unknown"
def check_rate_limit(client_ip: str, authenticated: bool = False) -> tuple[bool, int, int]:
"""Check if request is within rate limit.
Args:
client_ip: Client IP address
authenticated: Whether client is authenticated (higher limits)
Returns:
Tuple of (allowed, remaining, reset_seconds)
"""
if not current_app.config.get("RATE_LIMIT_ENABLED", True):
return True, -1, 0
window = current_app.config["RATE_LIMIT_WINDOW"]
max_requests = current_app.config["RATE_LIMIT_MAX"]
if authenticated:
max_requests *= current_app.config.get("RATE_LIMIT_AUTH_MULTIPLIER", 5)
now = time.time()
cutoff = now - window
with _rate_limit_lock:
# Clean old requests and get current list
requests = _rate_limit_requests[client_ip]
requests[:] = [t for t in requests if t > cutoff]
current_count = len(requests)
if current_count >= max_requests:
# Calculate reset time (when oldest request expires)
reset_at = int(requests[0] + window - now) + 1 if requests else window
return False, 0, reset_at
# Record this request
requests.append(now)
remaining = max_requests - len(requests)
return True, remaining, window
def cleanup_rate_limits(window: int | None = None) -> int:
"""Remove expired rate limit entries. Returns count of cleaned entries.
Args:
window: Rate limit window in seconds. If None, uses app config.
"""
# This should be called periodically (e.g., via cleanup task)
if window is None:
window = current_app.config.get("RATE_LIMIT_WINDOW", 60)
cutoff = time.time() - window
cleaned = 0
with _rate_limit_lock:
to_remove = []
for ip, requests in _rate_limit_requests.items():
requests[:] = [t for t in requests if t > cutoff]
if not requests:
to_remove.append(ip)
for ip in to_remove:
del _rate_limit_requests[ip]
cleaned += 1
return cleaned
def reset_rate_limits() -> None:
"""Clear all rate limit state. For testing only."""
with _rate_limit_lock:
_rate_limit_requests.clear()
# ─────────────────────────────────────────────────────────────────────────────
# Response Helpers
@@ -379,6 +471,7 @@ class IndexView(MethodView):
f"GET {prefixed_url('/client')}": "Download CLI client",
f"GET {prefixed_url('/challenge')}": "Get PoW challenge",
f"POST {prefixed_url('/')}": "Create paste",
f"GET {prefixed_url('/pastes')}": "List your pastes (auth required)",
f"GET {prefixed_url('/<id>')}": "Retrieve paste metadata",
f"GET {prefixed_url('/<id>/raw')}": "Retrieve raw paste content",
f"DELETE {prefixed_url('/<id>')}": "Delete paste",
@@ -413,6 +506,22 @@ class IndexView(MethodView):
owner = get_client_id()
# Rate limiting (check before expensive operations)
client_ip = get_client_ip()
allowed, _remaining, reset_seconds = check_rate_limit(client_ip, authenticated=bool(owner))
if not allowed:
current_app.logger.warning("Rate limit exceeded: ip=%s auth=%s", client_ip, bool(owner))
response = error_response(
"Rate limit exceeded",
429,
retry_after=reset_seconds,
)
response.headers["Retry-After"] = str(reset_seconds)
response.headers["X-RateLimit-Remaining"] = "0"
response.headers["X-RateLimit-Reset"] = str(reset_seconds)
return response
# Proof-of-work verification
difficulty = current_app.config["POW_DIFFICULTY"]
if difficulty > 0:
@@ -681,6 +790,125 @@ class PasteView(MethodView):
"""Return paste metadata headers only."""
return self.get(paste_id)
def put(self, paste_id: str) -> Response:
"""Update paste content and/or metadata.
Requires authentication and ownership.
Content update: Send raw body with Content-Type header
Metadata update: Use headers with empty body
Headers:
- X-Paste-Password: Set/change password
- X-Remove-Password: true to remove password
- X-Extend-Expiry: Seconds to add to current expiry
"""
# Validate paste ID format
if err := validate_paste_id(paste_id):
return err
if err := require_auth():
return err
db = get_db()
# Fetch current paste
row = db.execute(
"""SELECT id, owner, content, mime_type, expires_at, password_hash
FROM pastes WHERE id = ?""",
(paste_id,),
).fetchone()
if row is None:
return error_response("Paste not found", 404)
if row["owner"] != g.client_id:
return error_response("Permission denied", 403)
# Check for burn-after-read (cannot update)
burn_check = db.execute(
"SELECT burn_after_read FROM pastes WHERE id = ?", (paste_id,)
).fetchone()
if burn_check and burn_check["burn_after_read"]:
return error_response("Cannot update burn-after-read paste", 400)
# Parse update parameters
new_password = request.headers.get("X-Paste-Password", "").strip() or None
remove_password = request.headers.get("X-Remove-Password", "").lower() in (
"true",
"1",
"yes",
)
extend_expiry_str = request.headers.get("X-Extend-Expiry", "").strip()
# Prepare update fields
update_fields = []
update_params: list[Any] = []
# Content update (if body provided)
content = request.get_data()
if content:
mime_type = request.content_type or "application/octet-stream"
# Sanitize MIME type
if not MIME_PATTERN.match(mime_type.split(";")[0].strip()):
mime_type = "application/octet-stream"
update_fields.append("content = ?")
update_params.append(content)
update_fields.append("mime_type = ?")
update_params.append(mime_type.split(";")[0].strip())
# Password update
if remove_password:
update_fields.append("password_hash = NULL")
elif new_password:
update_fields.append("password_hash = ?")
update_params.append(hash_password(new_password))
# Expiry extension
if extend_expiry_str:
try:
extend_seconds = int(extend_expiry_str)
if extend_seconds > 0:
current_expiry = row["expires_at"]
if current_expiry:
new_expiry = current_expiry + extend_seconds
else:
# If no expiry set, create one from now
new_expiry = int(time.time()) + extend_seconds
update_fields.append("expires_at = ?")
update_params.append(new_expiry)
except ValueError:
return error_response("Invalid X-Extend-Expiry value", 400)
if not update_fields:
return error_response("No updates provided", 400)
# Execute update (fields are hardcoded strings, safe from injection)
update_sql = f"UPDATE pastes SET {', '.join(update_fields)} WHERE id = ?" # noqa: S608
update_params.append(paste_id)
db.execute(update_sql, update_params)
db.commit()
# Fetch updated paste for response
updated = db.execute(
"""SELECT id, mime_type, length(content) as size, expires_at,
CASE WHEN password_hash IS NOT NULL THEN 1 ELSE 0 END as password_protected
FROM pastes WHERE id = ?""",
(paste_id,),
).fetchone()
response_data: dict[str, Any] = {
"id": updated["id"],
"size": updated["size"],
"mime_type": updated["mime_type"],
}
if updated["expires_at"]:
response_data["expires_at"] = updated["expires_at"]
if updated["password_protected"]:
response_data["password_protected"] = True
return json_response(response_data)
class PasteRawView(MethodView):
"""Raw paste content retrieval."""
@@ -759,6 +987,119 @@ class PasteDeleteView(MethodView):
return json_response({"message": "Paste deleted"})
class PastesListView(MethodView):
"""List authenticated user's pastes (privacy-focused)."""
def get(self) -> Response:
"""List pastes owned by authenticated user.
Privacy guarantees:
- Requires authentication (mTLS client certificate)
- Users can ONLY see their own pastes
- No admin bypass or cross-user visibility
- Content is never returned, only metadata
Query parameters:
- limit: max results (default 50, max 200)
- offset: pagination offset (default 0)
- type: filter by MIME type (glob pattern, e.g., "image/*")
- after: filter by created_at >= timestamp
- before: filter by created_at <= timestamp
"""
import fnmatch
# Strict authentication requirement
if err := require_auth():
return err
client_id = g.client_id
# Parse pagination parameters
try:
limit = min(int(request.args.get("limit", 50)), 200)
offset = max(int(request.args.get("offset", 0)), 0)
except (ValueError, TypeError):
limit, offset = 50, 0
# Parse filter parameters
type_filter = request.args.get("type", "").strip()
try:
after_ts = int(request.args.get("after", 0))
except (ValueError, TypeError):
after_ts = 0
try:
before_ts = int(request.args.get("before", 0))
except (ValueError, TypeError):
before_ts = 0
db = get_db()
# Build query with filters
where_clauses = ["owner = ?"]
params: list[Any] = [client_id]
if after_ts > 0:
where_clauses.append("created_at >= ?")
params.append(after_ts)
if before_ts > 0:
where_clauses.append("created_at <= ?")
params.append(before_ts)
where_sql = " AND ".join(where_clauses)
# Count total pastes matching filters (where_sql is safe, built from constants)
count_row = db.execute(
f"SELECT COUNT(*) as total FROM pastes WHERE {where_sql}", # noqa: S608
params,
).fetchone()
total = count_row["total"] if count_row else 0
# Fetch pastes with metadata only (where_sql is safe, built from constants)
rows = db.execute(
f"""SELECT id, mime_type, length(content) as size, created_at,
last_accessed, burn_after_read, expires_at,
CASE WHEN password_hash IS NOT NULL THEN 1 ELSE 0 END as password_protected
FROM pastes
WHERE {where_sql}
ORDER BY created_at DESC
LIMIT ? OFFSET ?""", # noqa: S608
[*params, limit, offset],
).fetchall()
# Apply MIME type filter (glob pattern matching done in Python for flexibility)
if type_filter:
rows = [r for r in rows if fnmatch.fnmatch(r["mime_type"], type_filter)]
pastes = []
for row in rows:
paste: dict[str, Any] = {
"id": row["id"],
"mime_type": row["mime_type"],
"size": row["size"],
"created_at": row["created_at"],
"last_accessed": row["last_accessed"],
"url": f"/{row['id']}",
"raw": f"/{row['id']}/raw",
}
if row["burn_after_read"]:
paste["burn_after_read"] = True
if row["expires_at"]:
paste["expires_at"] = row["expires_at"]
if row["password_protected"]:
paste["password_protected"] = True
pastes.append(paste)
return json_response(
{
"pastes": pastes,
"count": len(pastes),
"total": total,
"limit": limit,
"offset": offset,
}
)
# ─────────────────────────────────────────────────────────────────────────────
# PKI Views (Certificate Authority)
# ─────────────────────────────────────────────────────────────────────────────
@@ -1060,7 +1401,8 @@ bp.add_url_rule("/challenge", view_func=ChallengeView.as_view("challenge"))
bp.add_url_rule("/client", view_func=ClientView.as_view("client"))
# Paste operations
bp.add_url_rule("/<paste_id>", view_func=PasteView.as_view("paste"), methods=["GET", "HEAD"])
bp.add_url_rule("/pastes", view_func=PastesListView.as_view("pastes_list"))
bp.add_url_rule("/<paste_id>", view_func=PasteView.as_view("paste"), methods=["GET", "HEAD", "PUT"])
bp.add_url_rule(
"/<paste_id>/raw", view_func=PasteRawView.as_view("paste_raw"), methods=["GET", "HEAD"]
)

View File

@@ -67,6 +67,18 @@ class Config:
# URL prefix for reverse proxy deployments (e.g., "/paste" for mymx.me/paste)
URL_PREFIX = os.environ.get("FLASKPASTE_URL_PREFIX", "").rstrip("/")
# IP-based rate limiting
# Limits paste creation per IP address using sliding window
RATE_LIMIT_ENABLED = os.environ.get("FLASKPASTE_RATE_LIMIT", "1").lower() in (
"1",
"true",
"yes",
)
RATE_LIMIT_WINDOW = int(os.environ.get("FLASKPASTE_RATE_WINDOW", "60")) # seconds
RATE_LIMIT_MAX = int(os.environ.get("FLASKPASTE_RATE_MAX", "10")) # requests per window
# Authenticated users get higher limits (multiplier)
RATE_LIMIT_AUTH_MULTIPLIER = int(os.environ.get("FLASKPASTE_RATE_AUTH_MULT", "5"))
# PKI Configuration
# Enable PKI endpoints for certificate authority and issuance
PKI_ENABLED = os.environ.get("FLASKPASTE_PKI_ENABLED", "0").lower() in ("1", "true", "yes")
@@ -103,6 +115,11 @@ class TestingConfig(Config):
# Disable PoW for most tests (easier testing)
POW_DIFFICULTY = 0
# Relaxed rate limiting for tests
RATE_LIMIT_ENABLED = True
RATE_LIMIT_WINDOW = 1
RATE_LIMIT_MAX = 100
# PKI testing configuration
PKI_ENABLED = True
PKI_CA_PASSWORD = "test-ca-password"

481
fpaste
View File

@@ -385,6 +385,425 @@ def cmd_info(args, config):
die("failed to connect to server")
def format_size(size):
"""Format byte size as human-readable string."""
if size < 1024:
return f"{size}B"
elif size < 1024 * 1024:
return f"{size / 1024:.1f}K"
else:
return f"{size / (1024 * 1024):.1f}M"
def format_timestamp(ts):
"""Format Unix timestamp as human-readable date."""
from datetime import datetime
dt = datetime.fromtimestamp(ts, tz=UTC)
return dt.strftime("%Y-%m-%d %H:%M")
def cmd_list(args, config):
"""List user's pastes."""
if not config["cert_sha1"]:
die("authentication required (set FLASKPASTE_CERT_SHA1)")
base = config["server"].rstrip("/")
params = []
if args.limit:
params.append(f"limit={args.limit}")
if args.offset:
params.append(f"offset={args.offset}")
url = f"{base}/pastes"
if params:
url += "?" + "&".join(params)
headers = {"X-SSL-Client-SHA1": config["cert_sha1"]}
status, body, _ = request(url, headers=headers, ssl_context=config.get("ssl_context"))
if status == 401:
die("authentication failed")
elif status != 200:
die(f"failed to list pastes ({status})")
data = json.loads(body)
pastes = data.get("pastes", [])
if args.json:
print(json.dumps(data, indent=2))
return
if not pastes:
print("no pastes found")
return
# Print header
print(f"{'ID':<12} {'TYPE':<16} {'SIZE':>6} {'CREATED':<16} FLAGS")
for p in pastes:
paste_id = p["id"]
mime_type = p.get("mime_type", "unknown")[:16]
size = format_size(p.get("size", 0))
created = format_timestamp(p.get("created_at", 0))
flags = []
if p.get("burn_after_read"):
flags.append("burn")
if p.get("password_protected"):
flags.append("pass")
if p.get("expires_at"):
flags.append("exp")
flag_str = " ".join(flags)
print(f"{paste_id:<12} {mime_type:<16} {size:>6} {created:<16} {flag_str}")
# Print summary
print(f"\n{data.get('count', 0)} of {data.get('total', 0)} pastes shown")
def parse_date(date_str):
"""Parse date string to Unix timestamp."""
from datetime import datetime
if not date_str:
return 0
# Try various formats
formats = [
"%Y-%m-%d",
"%Y-%m-%d %H:%M",
"%Y-%m-%d %H:%M:%S",
"%Y-%m-%dT%H:%M:%S",
"%Y-%m-%dT%H:%M:%SZ",
]
for fmt in formats:
try:
dt = datetime.strptime(date_str, fmt)
dt = dt.replace(tzinfo=UTC)
return int(dt.timestamp())
except ValueError:
continue
# Try as Unix timestamp
try:
return int(date_str)
except ValueError:
pass
die(f"invalid date format: {date_str}")
def cmd_search(args, config):
"""Search user's pastes."""
if not config["cert_sha1"]:
die("authentication required (set FLASKPASTE_CERT_SHA1)")
base = config["server"].rstrip("/")
params = []
if args.type:
params.append(f"type={args.type}")
if args.after:
ts = parse_date(args.after)
params.append(f"after={ts}")
if args.before:
ts = parse_date(args.before)
params.append(f"before={ts}")
if args.limit:
params.append(f"limit={args.limit}")
url = f"{base}/pastes"
if params:
url += "?" + "&".join(params)
headers = {"X-SSL-Client-SHA1": config["cert_sha1"]}
status, body, _ = request(url, headers=headers, ssl_context=config.get("ssl_context"))
if status == 401:
die("authentication failed")
elif status != 200:
die(f"failed to search pastes ({status})")
data = json.loads(body)
pastes = data.get("pastes", [])
if args.json:
print(json.dumps(data, indent=2))
return
if not pastes:
print("no matching pastes found")
return
# Print header
print(f"{'ID':<12} {'TYPE':<16} {'SIZE':>6} {'CREATED':<16} FLAGS")
for p in pastes:
paste_id = p["id"]
mime_type = p.get("mime_type", "unknown")[:16]
size = format_size(p.get("size", 0))
created = format_timestamp(p.get("created_at", 0))
flags = []
if p.get("burn_after_read"):
flags.append("burn")
if p.get("password_protected"):
flags.append("pass")
if p.get("expires_at"):
flags.append("exp")
flag_str = " ".join(flags)
print(f"{paste_id:<12} {mime_type:<16} {size:>6} {created:<16} {flag_str}")
# Print summary
print(f"\n{data.get('count', 0)} matching pastes found")
def cmd_update(args, config):
"""Update an existing paste."""
if not config["cert_sha1"]:
die("authentication required (set FLASKPASTE_CERT_SHA1)")
paste_id = args.id.split("/")[-1] # Handle full URLs
if "#" in paste_id:
paste_id = paste_id.split("#")[0] # Remove key fragment
base = config["server"].rstrip("/")
url = f"{base}/{paste_id}"
headers = {"X-SSL-Client-SHA1": config["cert_sha1"]}
content = None
# Read content from file if provided
if args.file:
if args.file == "-":
content = sys.stdin.buffer.read()
else:
path = Path(args.file)
if not path.exists():
die(f"file not found: {args.file}")
content = path.read_bytes()
if not content:
die("empty content")
# Encrypt if requested (default is to encrypt)
if not getattr(args, "no_encrypt", False):
if not HAS_CRYPTO:
die("encryption requires 'cryptography' package (use -E to disable)")
if not args.quiet:
print("encrypting...", end="", file=sys.stderr)
content, encryption_key = encrypt_content(content)
if not args.quiet:
print(" done", file=sys.stderr)
else:
encryption_key = None
# Set metadata update headers
if args.password:
headers["X-Paste-Password"] = args.password
if args.remove_password:
headers["X-Remove-Password"] = "true"
if args.expiry:
headers["X-Extend-Expiry"] = str(args.expiry)
# Make request
status, body, _ = request(
url, method="PUT", data=content, headers=headers, ssl_context=config.get("ssl_context")
)
if status == 200:
data = json.loads(body)
if args.quiet:
print(paste_id)
else:
print(f"updated: {paste_id}")
print(f" size: {data.get('size', 'unknown')}")
print(f" type: {data.get('mime_type', 'unknown')}")
if data.get("expires_at"):
print(f" expires: {data.get('expires_at')}")
if data.get("password_protected"):
print(" password: protected")
# Show new encryption key if content was updated and encrypted
if content and "encryption_key" in dir() and encryption_key:
key_fragment = "#" + encode_key(encryption_key)
print(f" key: {base}/{paste_id}{key_fragment}")
elif status == 400:
try:
err = json.loads(body).get("error", "bad request")
except (json.JSONDecodeError, UnicodeDecodeError):
err = "bad request"
die(err)
elif status == 401:
die("authentication failed")
elif status == 403:
die("permission denied (not owner)")
elif status == 404:
die(f"not found: {paste_id}")
else:
die(f"update failed ({status})")
def cmd_export(args, config):
"""Export user's pastes to a directory."""
if not config["cert_sha1"]:
die("authentication required (set FLASKPASTE_CERT_SHA1)")
base = config["server"].rstrip("/")
out_dir = Path(args.output) if args.output else Path("fpaste-export")
# Create output directory
out_dir.mkdir(parents=True, exist_ok=True)
# Load key file if provided
keys = {}
if args.keyfile:
keyfile_path = Path(args.keyfile)
if not keyfile_path.exists():
die(f"key file not found: {args.keyfile}")
for line in keyfile_path.read_text().splitlines():
line = line.strip()
if not line or line.startswith("#"):
continue
if "=" not in line:
continue
paste_id, key_encoded = line.split("=", 1)
keys[paste_id.strip()] = key_encoded.strip()
# Fetch paste list
headers = {"X-SSL-Client-SHA1": config["cert_sha1"]}
url = f"{base}/pastes?limit=1000" # Fetch all pastes
status, body, _ = request(url, headers=headers, ssl_context=config.get("ssl_context"))
if status == 401:
die("authentication failed")
elif status != 200:
die(f"failed to list pastes ({status})")
data = json.loads(body)
pastes = data.get("pastes", [])
if not pastes:
print("no pastes to export")
return
# Export each paste
exported = 0
skipped = 0
errors = 0
manifest = []
for p in pastes:
paste_id = p["id"]
mime_type = p.get("mime_type", "application/octet-stream")
if not args.quiet:
print(f"exporting {paste_id}...", end=" ", file=sys.stderr)
# Skip burn-after-read pastes
if p.get("burn_after_read"):
if not args.quiet:
print("skipped (burn-after-read)", file=sys.stderr)
skipped += 1
continue
# Fetch raw content
raw_url = f"{base}/{paste_id}/raw"
req_headers = dict(headers)
if p.get("password_protected"):
if not args.quiet:
print("skipped (password-protected)", file=sys.stderr)
skipped += 1
continue
ssl_ctx = config.get("ssl_context")
status, content, _ = request(raw_url, headers=req_headers, ssl_context=ssl_ctx)
if status != 200:
if not args.quiet:
print(f"error ({status})", file=sys.stderr)
errors += 1
continue
# Decrypt if key available
decrypted = False
if paste_id in keys:
try:
key = decode_key(keys[paste_id])
content = decrypt_content(content, key)
decrypted = True
except SystemExit:
# Decryption failed, keep encrypted content
if not args.quiet:
print("decryption failed, keeping encrypted", file=sys.stderr, end=" ")
# Determine file extension from MIME type
ext = get_extension_for_mime(mime_type)
filename = f"{paste_id}{ext}"
filepath = out_dir / filename
# Write content
filepath.write_bytes(content)
# Add to manifest
manifest.append(
{
"id": paste_id,
"filename": filename,
"mime_type": mime_type,
"size": len(content),
"created_at": p.get("created_at"),
"decrypted": decrypted,
"encrypted": paste_id in keys and not decrypted,
}
)
if not args.quiet:
status_msg = "decrypted" if decrypted else ("encrypted" if paste_id in keys else "ok")
print(status_msg, file=sys.stderr)
exported += 1
# Write manifest
if args.manifest:
manifest_path = out_dir / "manifest.json"
manifest_path.write_text(json.dumps(manifest, indent=2))
if not args.quiet:
print(f"manifest: {manifest_path}", file=sys.stderr)
# Summary
print(f"\nexported: {exported}, skipped: {skipped}, errors: {errors}")
print(f"output: {out_dir}")
def get_extension_for_mime(mime_type):
"""Get file extension for MIME type."""
mime_map = {
"text/plain": ".txt",
"text/html": ".html",
"text/css": ".css",
"text/javascript": ".js",
"text/markdown": ".md",
"text/x-python": ".py",
"application/json": ".json",
"application/xml": ".xml",
"application/javascript": ".js",
"application/octet-stream": ".bin",
"image/png": ".png",
"image/jpeg": ".jpg",
"image/gif": ".gif",
"image/webp": ".webp",
"image/svg+xml": ".svg",
"application/pdf": ".pdf",
"application/zip": ".zip",
"application/gzip": ".gz",
"application/x-tar": ".tar",
}
return mime_map.get(mime_type, ".bin")
def cmd_pki_status(args, config):
"""Show PKI status and CA information."""
url = config["server"].rstrip("/") + "/pki"
@@ -752,7 +1171,28 @@ def is_file_path(arg):
def main():
# Pre-process arguments: if first positional looks like a file, insert "create"
args_to_parse = sys.argv[1:]
commands = {"create", "c", "new", "get", "g", "delete", "d", "rm", "info", "i", "cert", "pki"}
commands = {
"create",
"c",
"new",
"get",
"g",
"delete",
"d",
"rm",
"info",
"i",
"list",
"ls",
"search",
"s",
"find",
"update",
"u",
"export",
"cert",
"pki",
}
# Find insertion point for "create" command
insert_pos = 0
@@ -823,6 +1263,37 @@ def main():
# info
subparsers.add_parser("info", aliases=["i"], help="show server info")
# list
p_list = subparsers.add_parser("list", aliases=["ls"], help="list your pastes")
p_list.add_argument("-l", "--limit", type=int, metavar="N", help="max pastes (default: 50)")
p_list.add_argument("-o", "--offset", type=int, metavar="N", help="skip first N pastes")
p_list.add_argument("--json", action="store_true", help="output as JSON")
# search
p_search = subparsers.add_parser("search", aliases=["s", "find"], help="search your pastes")
p_search.add_argument("-t", "--type", metavar="PATTERN", help="filter by MIME type (image/*)")
p_search.add_argument("--after", metavar="DATE", help="created after (YYYY-MM-DD or timestamp)")
p_search.add_argument("--before", metavar="DATE", help="created before (YYYY-MM-DD)")
p_search.add_argument("-l", "--limit", type=int, metavar="N", help="max results (default: 50)")
p_search.add_argument("--json", action="store_true", help="output as JSON")
# update
p_update = subparsers.add_parser("update", aliases=["u"], help="update existing paste")
p_update.add_argument("id", help="paste ID or URL")
p_update.add_argument("file", nargs="?", help="new content (- for stdin)")
p_update.add_argument("-E", "--no-encrypt", action="store_true", help="disable encryption")
p_update.add_argument("-p", "--password", metavar="PASS", help="set/change password")
p_update.add_argument("--remove-password", action="store_true", help="remove password")
p_update.add_argument("-x", "--expiry", type=int, metavar="SEC", help="extend expiry (seconds)")
p_update.add_argument("-q", "--quiet", action="store_true", help="minimal output")
# export
p_export = subparsers.add_parser("export", help="export all pastes to directory")
p_export.add_argument("-o", "--output", metavar="DIR", help="output directory")
p_export.add_argument("-k", "--keyfile", metavar="FILE", help="key file (paste_id=key format)")
p_export.add_argument("--manifest", action="store_true", help="write manifest.json")
p_export.add_argument("-q", "--quiet", action="store_true", help="minimal output")
# cert
p_cert = subparsers.add_parser("cert", help="generate client certificate")
p_cert.add_argument("-o", "--output", metavar="DIR", help="output directory")
@@ -904,6 +1375,14 @@ def main():
cmd_delete(args, config)
elif args.command in ("info", "i"):
cmd_info(args, config)
elif args.command in ("list", "ls"):
cmd_list(args, config)
elif args.command in ("search", "s", "find"):
cmd_search(args, config)
elif args.command in ("update", "u"):
cmd_update(args, config)
elif args.command == "export":
cmd_export(args, config)
elif args.command == "cert":
cmd_cert(args, config)
elif args.command == "pki":

View File

@@ -2,14 +2,37 @@
import pytest
import app.database as db_module
from app import create_app
from app.api.routes import reset_rate_limits
def _clear_database():
"""Clear all data from database tables for test isolation."""
if db_module._memory_db_holder is not None:
db_module._memory_db_holder.execute("DELETE FROM pastes")
db_module._memory_db_holder.execute("DELETE FROM content_hashes")
db_module._memory_db_holder.execute("DELETE FROM issued_certificates")
db_module._memory_db_holder.execute("DELETE FROM certificate_authority")
db_module._memory_db_holder.commit()
@pytest.fixture
def app():
"""Create application for testing."""
app = create_app("testing")
yield app
# Reset global state for test isolation
reset_rate_limits()
_clear_database()
test_app = create_app("testing")
# Clear database again after app init (in case init added anything)
_clear_database()
yield test_app
# Cleanup after test
reset_rate_limits()
@pytest.fixture

335
tests/test_paste_listing.py Normal file
View File

@@ -0,0 +1,335 @@
"""Tests for paste listing endpoint (GET /pastes)."""
import json
class TestPastesListEndpoint:
"""Tests for GET /pastes endpoint."""
def test_list_pastes_requires_auth(self, client):
"""List pastes requires authentication."""
response = client.get("/pastes")
assert response.status_code == 401
data = json.loads(response.data)
assert "error" in data
def test_list_pastes_empty(self, client, auth_header):
"""List pastes returns empty when user has no pastes."""
response = client.get("/pastes", headers=auth_header)
assert response.status_code == 200
data = json.loads(response.data)
assert data["pastes"] == []
assert data["count"] == 0
assert data["total"] == 0
def test_list_pastes_returns_own_pastes(self, client, sample_text, auth_header):
"""List pastes returns only user's own pastes."""
# Create a paste
create = client.post(
"/",
data=sample_text,
content_type="text/plain",
headers=auth_header,
)
paste_id = json.loads(create.data)["id"]
# List pastes
response = client.get("/pastes", headers=auth_header)
assert response.status_code == 200
data = json.loads(response.data)
assert data["count"] == 1
assert data["total"] == 1
assert data["pastes"][0]["id"] == paste_id
def test_list_pastes_excludes_others(self, client, sample_text, auth_header, other_auth_header):
"""List pastes does not include other users' pastes."""
# Create paste as user A
client.post(
"/",
data=sample_text,
content_type="text/plain",
headers=auth_header,
)
# List pastes as user B
response = client.get("/pastes", headers=other_auth_header)
assert response.status_code == 200
data = json.loads(response.data)
assert data["count"] == 0
assert data["total"] == 0
def test_list_pastes_excludes_anonymous(self, client, sample_text, auth_header):
"""List pastes does not include anonymous pastes."""
# Create anonymous paste
client.post("/", data=sample_text, content_type="text/plain")
# List pastes as authenticated user
response = client.get("/pastes", headers=auth_header)
assert response.status_code == 200
data = json.loads(response.data)
assert data["count"] == 0
def test_list_pastes_metadata_only(self, client, sample_text, auth_header):
"""List pastes returns metadata, not content."""
# Create a paste
client.post(
"/",
data=sample_text,
content_type="text/plain",
headers=auth_header,
)
# List pastes
response = client.get("/pastes", headers=auth_header)
data = json.loads(response.data)
paste = data["pastes"][0]
# Verify metadata fields
assert "id" in paste
assert "mime_type" in paste
assert "size" in paste
assert "created_at" in paste
assert "last_accessed" in paste
assert "url" in paste
assert "raw" in paste
# Verify content is NOT included
assert "content" not in paste
def test_list_pastes_pagination(self, client, auth_header):
"""List pastes supports pagination."""
# Create multiple pastes
for i in range(5):
client.post(
"/",
data=f"paste {i}",
content_type="text/plain",
headers=auth_header,
)
# Get first page
response = client.get("/pastes?limit=2&offset=0", headers=auth_header)
data = json.loads(response.data)
assert data["count"] == 2
assert data["total"] == 5
assert data["limit"] == 2
assert data["offset"] == 0
# Get second page
response = client.get("/pastes?limit=2&offset=2", headers=auth_header)
data = json.loads(response.data)
assert data["count"] == 2
assert data["offset"] == 2
def test_list_pastes_max_limit(self, client, auth_header):
"""List pastes enforces maximum limit."""
response = client.get("/pastes?limit=500", headers=auth_header)
data = json.loads(response.data)
assert data["limit"] == 200 # Max limit enforced
def test_list_pastes_invalid_pagination(self, client, auth_header):
"""List pastes handles invalid pagination gracefully."""
response = client.get("/pastes?limit=abc&offset=-1", headers=auth_header)
assert response.status_code == 200
data = json.loads(response.data)
# Should use defaults
assert data["limit"] == 50
assert data["offset"] == 0
def test_list_pastes_includes_special_fields(self, client, auth_header):
"""List pastes includes burn_after_read, expires_at, password_protected."""
# Create paste with burn-after-read
client.post(
"/",
data="burn test",
content_type="text/plain",
headers={**auth_header, "X-Burn-After-Read": "true"},
)
response = client.get("/pastes", headers=auth_header)
data = json.loads(response.data)
paste = data["pastes"][0]
assert paste.get("burn_after_read") is True
def test_list_pastes_ordered_by_created_at(self, client, auth_header):
"""List pastes returns all created pastes ordered by created_at DESC."""
# Create pastes
ids = set()
for i in range(3):
create = client.post(
"/",
data=f"paste {i}",
content_type="text/plain",
headers=auth_header,
)
ids.add(json.loads(create.data)["id"])
response = client.get("/pastes", headers=auth_header)
data = json.loads(response.data)
# All created pastes should be present
returned_ids = {p["id"] for p in data["pastes"]}
assert returned_ids == ids
assert data["count"] == 3
class TestPastesPrivacy:
"""Privacy-focused tests for paste listing."""
def test_cannot_see_other_user_pastes(
self, client, sample_text, auth_header, other_auth_header
):
"""Users cannot see pastes owned by others."""
# User A creates paste
create = client.post(
"/",
data=sample_text,
content_type="text/plain",
headers=auth_header,
)
paste_id = json.loads(create.data)["id"]
# User B lists pastes - should not see A's paste
response = client.get("/pastes", headers=other_auth_header)
data = json.loads(response.data)
paste_ids = [p["id"] for p in data["pastes"]]
assert paste_id not in paste_ids
def test_no_admin_bypass(self, client, sample_text, auth_header):
"""No special admin access to list all pastes."""
# Create paste as regular user
client.post(
"/",
data=sample_text,
content_type="text/plain",
headers=auth_header,
)
# Different user cannot see it, regardless of auth header format
admin_header = {"X-SSL-Client-SHA1": "0" * 40}
response = client.get("/pastes", headers=admin_header)
data = json.loads(response.data)
assert data["count"] == 0
def test_content_never_exposed(self, client, auth_header):
"""Paste content is never exposed in listing."""
secret = "super secret content that should never be exposed"
client.post(
"/",
data=secret,
content_type="text/plain",
headers=auth_header,
)
response = client.get("/pastes", headers=auth_header)
# Content should not appear anywhere in response
assert secret.encode() not in response.data
class TestPastesSearch:
"""Tests for paste search parameters."""
def test_search_by_type_exact(self, client, auth_header, png_bytes):
"""Search pastes by exact MIME type."""
# Create text paste
client.post("/", data="text content", content_type="text/plain", headers=auth_header)
# Create image paste
create = client.post("/", data=png_bytes, content_type="image/png", headers=auth_header)
png_id = json.loads(create.data)["id"]
# Search for image/png
response = client.get("/pastes?type=image/png", headers=auth_header)
data = json.loads(response.data)
assert data["count"] == 1
assert data["pastes"][0]["id"] == png_id
def test_search_by_type_glob(self, client, auth_header, png_bytes, jpeg_bytes):
"""Search pastes by MIME type glob pattern."""
# Create text paste
client.post("/", data="text content", content_type="text/plain", headers=auth_header)
# Create image pastes
client.post("/", data=png_bytes, content_type="image/png", headers=auth_header)
client.post("/", data=jpeg_bytes, content_type="image/jpeg", headers=auth_header)
# Search for all images
response = client.get("/pastes?type=image/*", headers=auth_header)
data = json.loads(response.data)
assert data["count"] == 2
for paste in data["pastes"]:
assert paste["mime_type"].startswith("image/")
def test_search_by_after_timestamp(self, client, auth_header):
"""Search pastes created after timestamp."""
# Create paste
create = client.post("/", data="test", content_type="text/plain", headers=auth_header)
paste_data = json.loads(create.data)
created_at = paste_data["created_at"]
# Search for pastes after creation time (should find it)
response = client.get(f"/pastes?after={created_at - 1}", headers=auth_header)
data = json.loads(response.data)
assert data["count"] == 1
# Search for pastes after creation time + 1 (should not find it)
response = client.get(f"/pastes?after={created_at + 1}", headers=auth_header)
data = json.loads(response.data)
assert data["count"] == 0
def test_search_by_before_timestamp(self, client, auth_header):
"""Search pastes created before timestamp."""
# Create paste
create = client.post("/", data="test", content_type="text/plain", headers=auth_header)
paste_data = json.loads(create.data)
created_at = paste_data["created_at"]
# Search for pastes before creation time + 1 (should find it)
response = client.get(f"/pastes?before={created_at + 1}", headers=auth_header)
data = json.loads(response.data)
assert data["count"] == 1
# Search for pastes before creation time (should not find it)
response = client.get(f"/pastes?before={created_at - 1}", headers=auth_header)
data = json.loads(response.data)
assert data["count"] == 0
def test_search_combined_filters(self, client, auth_header, png_bytes):
"""Search with multiple filters combined."""
# Create text paste
client.post("/", data="text", content_type="text/plain", headers=auth_header)
# Create image paste
create = client.post("/", data=png_bytes, content_type="image/png", headers=auth_header)
png_data = json.loads(create.data)
created_at = png_data["created_at"]
# Search for images after a certain time
response = client.get(
f"/pastes?type=image/*&after={created_at - 1}",
headers=auth_header,
)
data = json.loads(response.data)
assert data["count"] == 1
assert data["pastes"][0]["mime_type"] == "image/png"
def test_search_no_matches(self, client, auth_header):
"""Search with no matching results."""
# Create text paste
client.post("/", data="text", content_type="text/plain", headers=auth_header)
# Search for video (no matches)
response = client.get("/pastes?type=video/*", headers=auth_header)
data = json.loads(response.data)
assert data["count"] == 0
assert data["pastes"] == []
def test_search_invalid_timestamp(self, client, auth_header):
"""Search with invalid timestamp uses default."""
client.post("/", data="test", content_type="text/plain", headers=auth_header)
# Invalid timestamp should be ignored
response = client.get("/pastes?after=invalid", headers=auth_header)
assert response.status_code == 200
data = json.loads(response.data)
assert data["count"] == 1

242
tests/test_paste_update.py Normal file
View File

@@ -0,0 +1,242 @@
"""Tests for paste update endpoint (PUT /<id>)."""
import json
class TestPasteUpdateEndpoint:
"""Tests for PUT /<id> endpoint."""
def test_update_requires_auth(self, client, sample_text, auth_header):
"""Update requires authentication."""
# Create paste
create = client.post("/", data=sample_text, content_type="text/plain", headers=auth_header)
paste_id = json.loads(create.data)["id"]
# Try to update without auth
response = client.put(f"/{paste_id}", data="updated")
assert response.status_code == 401
def test_update_requires_ownership(self, client, sample_text, auth_header, other_auth_header):
"""Update requires paste ownership."""
# Create paste as user A
create = client.post("/", data=sample_text, content_type="text/plain", headers=auth_header)
paste_id = json.loads(create.data)["id"]
# Try to update as user B
response = client.put(
f"/{paste_id}",
data="updated content",
content_type="text/plain",
headers=other_auth_header,
)
assert response.status_code == 403
def test_update_content(self, client, auth_header):
"""Update paste content."""
# Create paste
create = client.post("/", data="original", content_type="text/plain", headers=auth_header)
paste_id = json.loads(create.data)["id"]
# Update content
response = client.put(
f"/{paste_id}",
data="updated content",
content_type="text/plain",
headers=auth_header,
)
assert response.status_code == 200
data = json.loads(response.data)
assert data["size"] == len("updated content")
# Verify content changed
raw = client.get(f"/{paste_id}/raw")
assert raw.data == b"updated content"
def test_update_password_set(self, client, auth_header):
"""Set password on paste."""
# Create paste without password
create = client.post("/", data="content", content_type="text/plain", headers=auth_header)
paste_id = json.loads(create.data)["id"]
# Add password
response = client.put(
f"/{paste_id}",
data="",
headers={**auth_header, "X-Paste-Password": "secret123"},
)
assert response.status_code == 200
data = json.loads(response.data)
assert data.get("password_protected") is True
# Verify password required
raw = client.get(f"/{paste_id}/raw")
assert raw.status_code == 401
# Verify correct password works
raw = client.get(f"/{paste_id}/raw", headers={"X-Paste-Password": "secret123"})
assert raw.status_code == 200
def test_update_password_remove(self, client, auth_header):
"""Remove password from paste."""
# Create paste with password
create = client.post(
"/",
data="content",
content_type="text/plain",
headers={**auth_header, "X-Paste-Password": "secret123"},
)
paste_id = json.loads(create.data)["id"]
# Remove password
response = client.put(
f"/{paste_id}",
data="",
headers={**auth_header, "X-Remove-Password": "true"},
)
assert response.status_code == 200
data = json.loads(response.data)
assert data.get("password_protected") is not True
# Verify no password required
raw = client.get(f"/{paste_id}/raw")
assert raw.status_code == 200
def test_update_extend_expiry(self, client, auth_header):
"""Extend paste expiry."""
# Create paste with expiry
create = client.post(
"/",
data="content",
content_type="text/plain",
headers={**auth_header, "X-Expiry": "3600"},
)
paste_id = json.loads(create.data)["id"]
original_expiry = json.loads(create.data)["expires_at"]
# Extend expiry
response = client.put(
f"/{paste_id}",
data="",
headers={**auth_header, "X-Extend-Expiry": "7200"},
)
assert response.status_code == 200
data = json.loads(response.data)
assert data["expires_at"] == original_expiry + 7200
def test_update_add_expiry(self, client, auth_header):
"""Add expiry to paste without one."""
# Create paste without expiry
create = client.post("/", data="content", content_type="text/plain", headers=auth_header)
paste_id = json.loads(create.data)["id"]
# Add expiry
response = client.put(
f"/{paste_id}",
data="",
headers={**auth_header, "X-Extend-Expiry": "3600"},
)
assert response.status_code == 200
data = json.loads(response.data)
assert "expires_at" in data
def test_update_burn_after_read_forbidden(self, client, auth_header):
"""Cannot update burn-after-read pastes."""
# Create burn-after-read paste
create = client.post(
"/",
data="content",
content_type="text/plain",
headers={**auth_header, "X-Burn-After-Read": "true"},
)
paste_id = json.loads(create.data)["id"]
# Try to update
response = client.put(
f"/{paste_id}",
data="updated",
content_type="text/plain",
headers=auth_header,
)
assert response.status_code == 400
data = json.loads(response.data)
assert "burn" in data["error"].lower()
def test_update_not_found(self, client, auth_header):
"""Update non-existent paste returns 404."""
response = client.put(
"/000000000000",
data="updated",
content_type="text/plain",
headers=auth_header,
)
assert response.status_code == 404
def test_update_no_changes(self, client, auth_header):
"""Update with no changes returns error."""
# Create paste
create = client.post("/", data="content", content_type="text/plain", headers=auth_header)
paste_id = json.loads(create.data)["id"]
# Try to update with empty request
response = client.put(f"/{paste_id}", data="", headers=auth_header)
assert response.status_code == 400
data = json.loads(response.data)
assert "no updates" in data["error"].lower()
def test_update_combined(self, client, auth_header):
"""Update content and metadata together."""
# Create paste
create = client.post("/", data="original", content_type="text/plain", headers=auth_header)
paste_id = json.loads(create.data)["id"]
# Update content and add password
response = client.put(
f"/{paste_id}",
data="new content",
content_type="text/plain",
headers={**auth_header, "X-Paste-Password": "secret"},
)
assert response.status_code == 200
data = json.loads(response.data)
assert data["size"] == len("new content")
assert data.get("password_protected") is True
class TestPasteUpdatePrivacy:
"""Privacy-focused tests for paste update."""
def test_cannot_update_anonymous_paste(self, client, sample_text, auth_header):
"""Cannot update paste without owner."""
# Create anonymous paste
create = client.post("/", data=sample_text, content_type="text/plain")
paste_id = json.loads(create.data)["id"]
# Try to update
response = client.put(
f"/{paste_id}",
data="updated",
content_type="text/plain",
headers=auth_header,
)
assert response.status_code == 403
def test_cannot_update_other_user_paste(
self, client, sample_text, auth_header, other_auth_header
):
"""Cannot update paste owned by another user."""
# Create paste as user A
create = client.post("/", data=sample_text, content_type="text/plain", headers=auth_header)
paste_id = json.loads(create.data)["id"]
# Try to update as user B
response = client.put(
f"/{paste_id}",
data="hijacked",
content_type="text/plain",
headers=other_auth_header,
)
assert response.status_code == 403
# Verify original content unchanged
raw = client.get(f"/{paste_id}/raw")
assert raw.data == sample_text.encode()

168
tests/test_rate_limiting.py Normal file
View File

@@ -0,0 +1,168 @@
"""Tests for IP-based rate limiting."""
import json
class TestRateLimiting:
"""Tests for rate limiting on paste creation."""
def test_rate_limit_allows_normal_usage(self, client, sample_text):
"""Normal usage within rate limit succeeds."""
# TestingConfig has RATE_LIMIT_MAX=100
for i in range(5):
response = client.post(
"/",
data=f"paste {i}",
content_type="text/plain",
)
assert response.status_code == 201
def test_rate_limit_exceeded_returns_429(self, client, app):
"""Exceeding rate limit returns 429."""
# Temporarily lower rate limit for test
original_max = app.config["RATE_LIMIT_MAX"]
app.config["RATE_LIMIT_MAX"] = 3
try:
# Make requests up to limit
for i in range(3):
response = client.post(
"/",
data=f"paste {i}",
content_type="text/plain",
)
assert response.status_code == 201
# Next request should be rate limited
response = client.post(
"/",
data="one more",
content_type="text/plain",
)
assert response.status_code == 429
data = json.loads(response.data)
assert "error" in data
assert "Rate limit" in data["error"]
assert "retry_after" in data
finally:
app.config["RATE_LIMIT_MAX"] = original_max
def test_rate_limit_headers(self, client, app):
"""Rate limit response includes proper headers."""
original_max = app.config["RATE_LIMIT_MAX"]
app.config["RATE_LIMIT_MAX"] = 1
try:
# First request succeeds
client.post("/", data="first", content_type="text/plain")
# Second request is rate limited
response = client.post("/", data="second", content_type="text/plain")
assert response.status_code == 429
assert "Retry-After" in response.headers
assert "X-RateLimit-Remaining" in response.headers
assert response.headers["X-RateLimit-Remaining"] == "0"
finally:
app.config["RATE_LIMIT_MAX"] = original_max
def test_rate_limit_auth_multiplier(self, client, app, auth_header):
"""Authenticated users get higher rate limits."""
original_max = app.config["RATE_LIMIT_MAX"]
original_mult = app.config["RATE_LIMIT_AUTH_MULTIPLIER"]
app.config["RATE_LIMIT_MAX"] = 2
app.config["RATE_LIMIT_AUTH_MULTIPLIER"] = 3 # 2 * 3 = 6 for auth users
try:
# Authenticated user can make more requests than base limit
for i in range(5):
response = client.post(
"/",
data=f"auth {i}",
content_type="text/plain",
headers=auth_header,
)
assert response.status_code == 201
# 6th request should succeed (limit is 2*3=6)
response = client.post(
"/",
data="auth 6",
content_type="text/plain",
headers=auth_header,
)
assert response.status_code == 201
# 7th should fail
response = client.post(
"/",
data="auth 7",
content_type="text/plain",
headers=auth_header,
)
assert response.status_code == 429
finally:
app.config["RATE_LIMIT_MAX"] = original_max
app.config["RATE_LIMIT_AUTH_MULTIPLIER"] = original_mult
def test_rate_limit_can_be_disabled(self, client, app):
"""Rate limiting can be disabled via config."""
original_enabled = app.config["RATE_LIMIT_ENABLED"]
original_max = app.config["RATE_LIMIT_MAX"]
app.config["RATE_LIMIT_ENABLED"] = False
app.config["RATE_LIMIT_MAX"] = 1
try:
# Should be able to make many requests
for i in range(5):
response = client.post(
"/",
data=f"paste {i}",
content_type="text/plain",
)
assert response.status_code == 201
finally:
app.config["RATE_LIMIT_ENABLED"] = original_enabled
app.config["RATE_LIMIT_MAX"] = original_max
def test_rate_limit_only_affects_paste_creation(self, client, app, sample_text):
"""Rate limiting only affects POST /, not GET endpoints."""
original_max = app.config["RATE_LIMIT_MAX"]
app.config["RATE_LIMIT_MAX"] = 2
try:
# Create paste
create = client.post("/", data=sample_text, content_type="text/plain")
assert create.status_code == 201
paste_id = json.loads(create.data)["id"]
# Use up rate limit
client.post("/", data="second", content_type="text/plain")
# Should be rate limited for creation
response = client.post("/", data="third", content_type="text/plain")
assert response.status_code == 429
# But GET should still work
response = client.get(f"/{paste_id}")
assert response.status_code == 200
response = client.get(f"/{paste_id}/raw")
assert response.status_code == 200
response = client.get("/health")
assert response.status_code == 200
finally:
app.config["RATE_LIMIT_MAX"] = original_max
class TestRateLimitCleanup:
"""Tests for rate limit cleanup."""
def test_rate_limit_window_expiry(self, app):
"""Rate limit cleanup works with explicit window."""
from app.api.routes import cleanup_rate_limits
# Should work without app context when window is explicit
cleaned = cleanup_rate_limits(window=60)
assert isinstance(cleaned, int)
assert cleaned >= 0

View File

@@ -0,0 +1,169 @@
"""Tests for scheduled cleanup functionality."""
import time
class TestScheduledCleanup:
"""Tests for scheduled cleanup in before_request hook."""
def test_cleanup_times_reset(self, app, client):
"""Verify cleanup times can be reset for testing."""
from app.api import _cleanup_times, reset_cleanup_times
# Set some cleanup times
with app.app_context():
reset_cleanup_times()
for key in _cleanup_times:
assert _cleanup_times[key] == 0
def test_cleanup_runs_on_request(self, app, client, auth_header):
"""Verify cleanup runs during request handling."""
from app.api import _cleanup_times, reset_cleanup_times
with app.app_context():
reset_cleanup_times()
# Make a request to trigger cleanup
client.get("/health")
# Cleanup times should be set
with app.app_context():
for key in _cleanup_times:
assert _cleanup_times[key] > 0
def test_cleanup_respects_intervals(self, app, client):
"""Verify cleanup respects configured intervals."""
from app.api import _cleanup_times, reset_cleanup_times
with app.app_context():
reset_cleanup_times()
# First request triggers cleanup
client.get("/health")
with app.app_context():
first_times = {k: v for k, v in _cleanup_times.items()}
# Immediate second request should not reset times
client.get("/health")
with app.app_context():
for key in _cleanup_times:
assert _cleanup_times[key] == first_times[key]
def test_expired_paste_cleanup(self, app, client, auth_header):
"""Test that expired pastes are cleaned up."""
import json
from app.database import get_db
# Create paste with short expiry
response = client.post(
"/",
data="test content",
content_type="text/plain",
headers={**auth_header, "X-Expiry": "1"}, # 1 second
)
assert response.status_code == 201
paste_id = json.loads(response.data)["id"]
# Wait for expiry
time.sleep(1.5)
# Trigger cleanup via database function
with app.app_context():
from app.database import cleanup_expired_pastes
count = cleanup_expired_pastes()
assert count >= 1
# Verify paste is gone
db = get_db()
row = db.execute("SELECT id FROM pastes WHERE id = ?", (paste_id,)).fetchone()
assert row is None
def test_rate_limit_cleanup(self, app, client):
"""Test that rate limit entries are cleaned up."""
from app.api.routes import (
_rate_limit_requests,
check_rate_limit,
cleanup_rate_limits,
reset_rate_limits,
)
with app.app_context():
reset_rate_limits()
# Add some rate limit entries
check_rate_limit("192.168.1.1", authenticated=False)
check_rate_limit("192.168.1.2", authenticated=False)
assert len(_rate_limit_requests) == 2
# Cleanup with very large window (nothing should be removed)
count = cleanup_rate_limits(window=3600)
assert count == 0
# Wait a bit and cleanup with tiny window
time.sleep(0.1)
count = cleanup_rate_limits(window=0) # Immediate cleanup
assert count == 2
assert len(_rate_limit_requests) == 0
class TestCleanupThreadSafety:
"""Tests for thread-safety of cleanup operations."""
def test_cleanup_lock_exists(self, app):
"""Verify cleanup lock exists."""
import threading
from app.api import _cleanup_lock
assert isinstance(_cleanup_lock, type(threading.Lock()))
def test_concurrent_cleanup_access(self, app, client):
"""Test that concurrent requests don't corrupt cleanup state."""
import threading
from app.api import reset_cleanup_times
with app.app_context():
reset_cleanup_times()
errors = []
results = []
def make_request():
try:
resp = client.get("/health")
results.append(resp.status_code)
except Exception as e:
errors.append(str(e))
# Simulate concurrent requests
threads = [threading.Thread(target=make_request) for _ in range(10)]
for t in threads:
t.start()
for t in threads:
t.join()
assert not errors
assert all(r == 200 for r in results)
class TestCleanupConfiguration:
"""Tests for cleanup configuration."""
def test_cleanup_intervals_configured(self, app):
"""Verify cleanup intervals are properly configured."""
from app.api import _CLEANUP_INTERVALS
assert "pastes" in _CLEANUP_INTERVALS
assert "hashes" in _CLEANUP_INTERVALS
assert "rate_limits" in _CLEANUP_INTERVALS
# Verify reasonable intervals
assert _CLEANUP_INTERVALS["pastes"] >= 60 # At least 1 minute
assert _CLEANUP_INTERVALS["hashes"] >= 60
assert _CLEANUP_INTERVALS["rate_limits"] >= 60