add CLI enhancements and scheduled cleanup

CLI commands:
- list: show user's pastes with pagination
- search: filter by type (glob), after/before timestamps
- update: modify content, password, or extend expiry
- export: save pastes to directory with optional decryption

API changes:
- PUT /<id>: update paste content and metadata
- GET /pastes: add type, after, before query params

Scheduled tasks:
- Thread-safe cleanup with per-task intervals
- Activate cleanup_expired_hashes (15min)
- Activate cleanup_rate_limits (5min)

Tests: 205 passing
This commit is contained in:
Username
2025-12-20 20:13:00 +01:00
parent cf31eab678
commit bfc238b5cf
9 changed files with 1826 additions and 18 deletions

View File

@@ -8,7 +8,9 @@ import json
import math
import re
import secrets
import threading
import time
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from flask import Response, current_app, g, request
@@ -50,6 +52,96 @@ GENERIC_MIME_TYPES = frozenset(
# Runtime PoW secret cache
_pow_secret_cache: bytes | None = None
# ─────────────────────────────────────────────────────────────────────────────
# Rate Limiting (in-memory sliding window)
# ─────────────────────────────────────────────────────────────────────────────
# Lock held by the rate-limit functions in this section whenever they read or
# mutate _rate_limit_requests.
_rate_limit_lock = threading.Lock()
# Per-client request history: key is the client IP string, value is the list of
# time.time() timestamps recorded for that client (appended in arrival order).
_rate_limit_requests: dict[str, list[float]] = defaultdict(list)
def get_client_ip() -> str:
    """Return the address to attribute the current request to.

    When is_trusted_proxy() reports true and the request carries a non-empty
    X-Forwarded-For header, the first comma-separated entry of that header is
    returned (whitespace-stripped). In every other case the socket peer
    address is used, or the literal "unknown" if Flask has none.
    """
    if is_trusted_proxy():
        chain = request.headers.get("X-Forwarded-For", "")
        if chain:
            # The left-most hop is treated as the originating client.
            first_hop, _, _ = chain.partition(",")
            return first_hop.strip()
    return request.remote_addr or "unknown"
def check_rate_limit(client_ip: str, authenticated: bool = False) -> tuple[bool, int, int]:
    """Apply the in-memory sliding-window limit to one request.

    Args:
        client_ip: Key under which the request is counted.
        authenticated: When true, the configured maximum is multiplied by
            RATE_LIMIT_AUTH_MULTIPLIER (5 if unset).

    Returns:
        (allowed, remaining, reset_seconds). With RATE_LIMIT_ENABLED falsy the
        result is (True, -1, 0) and nothing is recorded. A permitted request
        is recorded and reports the window length as reset_seconds; a refused
        one reports 0 remaining and the whole seconds until the oldest
        recorded request leaves the window.
    """
    config = current_app.config
    if not config.get("RATE_LIMIT_ENABLED", True):
        return True, -1, 0

    window = config["RATE_LIMIT_WINDOW"]
    limit = config["RATE_LIMIT_MAX"]
    if authenticated:
        limit *= config.get("RATE_LIMIT_AUTH_MULTIPLIER", 5)

    now = time.time()
    oldest_allowed = now - window

    with _rate_limit_lock:
        # Prune in place so the list stored in the shared dict stays the same object.
        history = _rate_limit_requests[client_ip]
        history[:] = [stamp for stamp in history if stamp > oldest_allowed]

        if len(history) < limit:
            # Under the limit: count this request and report what is left.
            history.append(now)
            return True, limit - len(history), window

        # At or over the limit: the slot frees up when the oldest entry expires.
        if history:
            retry_in = int(history[0] + window - now) + 1
        else:
            retry_in = window
        return False, 0, retry_in
def cleanup_rate_limits(window: int | None = None) -> int:
    """Prune expired timestamps and drop clients with no remaining history.

    Intended to be run periodically (e.g. from a cleanup task).

    Args:
        window: Rate limit window in seconds. When None, RATE_LIMIT_WINDOW is
            read from the app config (60 if unset).

    Returns:
        Number of client entries removed from the rate-limit table.
    """
    if window is None:
        effective_window = current_app.config.get("RATE_LIMIT_WINDOW", 60)
    else:
        effective_window = window
    threshold = time.time() - effective_window

    with _rate_limit_lock:
        stale_ips = []
        for client_ip, timestamps in _rate_limit_requests.items():
            # In-place prune keeps the list object that the dict references.
            timestamps[:] = [stamp for stamp in timestamps if stamp > threshold]
            if not timestamps:
                stale_ips.append(client_ip)
        # Deletion happens after iteration so the dict is not resized mid-loop.
        for client_ip in stale_ips:
            del _rate_limit_requests[client_ip]

    return len(stale_ips)
def reset_rate_limits() -> None:
    """Forget every recorded request for every client.

    For testing only: empties the shared rate-limit table while holding
    the rate-limit lock.
    """
    with _rate_limit_lock:
        _rate_limit_requests.clear()
# ─────────────────────────────────────────────────────────────────────────────
# Response Helpers
@@ -379,6 +471,7 @@ class IndexView(MethodView):
f"GET {prefixed_url('/client')}": "Download CLI client",
f"GET {prefixed_url('/challenge')}": "Get PoW challenge",
f"POST {prefixed_url('/')}": "Create paste",
f"GET {prefixed_url('/pastes')}": "List your pastes (auth required)",
f"GET {prefixed_url('/<id>')}": "Retrieve paste metadata",
f"GET {prefixed_url('/<id>/raw')}": "Retrieve raw paste content",
f"DELETE {prefixed_url('/<id>')}": "Delete paste",
@@ -413,6 +506,22 @@ class IndexView(MethodView):
owner = get_client_id()
# Rate limiting (check before expensive operations)
client_ip = get_client_ip()
allowed, _remaining, reset_seconds = check_rate_limit(client_ip, authenticated=bool(owner))
if not allowed:
current_app.logger.warning("Rate limit exceeded: ip=%s auth=%s", client_ip, bool(owner))
response = error_response(
"Rate limit exceeded",
429,
retry_after=reset_seconds,
)
response.headers["Retry-After"] = str(reset_seconds)
response.headers["X-RateLimit-Remaining"] = "0"
response.headers["X-RateLimit-Reset"] = str(reset_seconds)
return response
# Proof-of-work verification
difficulty = current_app.config["POW_DIFFICULTY"]
if difficulty > 0:
@@ -681,6 +790,125 @@ class PasteView(MethodView):
"""Return paste metadata headers only."""
return self.get(paste_id)
def put(self, paste_id: str) -> Response:
    """Update paste content and/or metadata.

    Requires authentication and ownership. Burn-after-read pastes cannot be
    updated.

    Content update: Send raw body with Content-Type header
    Metadata update: Use headers with empty body

    Headers:
    - X-Paste-Password: Set/change password
    - X-Remove-Password: true/1/yes to remove password (wins over
      X-Paste-Password when both are sent)
    - X-Extend-Expiry: Seconds to add to current expiry; a non-integer value
      is rejected with 400, a non-positive integer is ignored

    Returns:
        JSON with id, size and mime_type of the updated paste, plus
        expires_at / password_protected when set. Error responses: whatever
        validate_paste_id / require_auth produce, 404 (unknown paste),
        403 (not the owner), 400 (burn-after-read paste, bad expiry value,
        or nothing to update).
    """
    # Validate paste ID format
    if err := validate_paste_id(paste_id):
        return err
    if err := require_auth():
        return err

    db = get_db()

    # Single metadata-only lookup: the content blob is not needed for any
    # check below, and burn_after_read lives on the same row.
    row = db.execute(
        """SELECT id, owner, expires_at, burn_after_read
           FROM pastes WHERE id = ?""",
        (paste_id,),
    ).fetchone()
    if row is None:
        return error_response("Paste not found", 404)
    if row["owner"] != g.client_id:
        return error_response("Permission denied", 403)
    if row["burn_after_read"]:
        return error_response("Cannot update burn-after-read paste", 400)

    # Parse update parameters
    new_password = request.headers.get("X-Paste-Password", "").strip() or None
    remove_password = request.headers.get("X-Remove-Password", "").lower() in (
        "true",
        "1",
        "yes",
    )
    extend_expiry_str = request.headers.get("X-Extend-Expiry", "").strip()

    # Validate the expiry header up front so a malformed request is rejected
    # before any password hashing is done.
    extend_seconds = 0
    if extend_expiry_str:
        try:
            extend_seconds = int(extend_expiry_str)
        except ValueError:
            return error_response("Invalid X-Extend-Expiry value", 400)

    # Prepare update fields
    update_fields: list[str] = []
    update_params: list[Any] = []

    # Content update (if body provided)
    content = request.get_data()
    if content:
        # Keep only the base type (drop parameters such as charset) and fall
        # back to a generic type when it does not look like a MIME type.
        mime_type = (request.content_type or "application/octet-stream").split(";")[0].strip()
        if not MIME_PATTERN.match(mime_type):
            mime_type = "application/octet-stream"
        update_fields.append("content = ?")
        update_params.append(content)
        update_fields.append("mime_type = ?")
        update_params.append(mime_type)

    # Password update (removal takes precedence over a new password)
    if remove_password:
        update_fields.append("password_hash = NULL")
    elif new_password:
        update_fields.append("password_hash = ?")
        update_params.append(hash_password(new_password))

    # Expiry extension
    if extend_seconds > 0:
        current_expiry = row["expires_at"]
        if current_expiry:
            new_expiry = current_expiry + extend_seconds
        else:
            # If no expiry set, create one from now
            new_expiry = int(time.time()) + extend_seconds
        update_fields.append("expires_at = ?")
        update_params.append(new_expiry)

    if not update_fields:
        return error_response("No updates provided", 400)

    # Execute update (fields are hardcoded strings, safe from injection)
    update_sql = f"UPDATE pastes SET {', '.join(update_fields)} WHERE id = ?"  # noqa: S608
    update_params.append(paste_id)
    db.execute(update_sql, update_params)
    db.commit()

    # Fetch updated paste for response
    updated = db.execute(
        """SELECT id, mime_type, length(content) as size, expires_at,
                  CASE WHEN password_hash IS NOT NULL THEN 1 ELSE 0 END as password_protected
           FROM pastes WHERE id = ?""",
        (paste_id,),
    ).fetchone()
    if updated is None:
        # The row disappeared between the UPDATE and this read (e.g. it was
        # deleted concurrently); report that instead of failing on None.
        return error_response("Paste not found", 404)

    response_data: dict[str, Any] = {
        "id": updated["id"],
        "size": updated["size"],
        "mime_type": updated["mime_type"],
    }
    if updated["expires_at"]:
        response_data["expires_at"] = updated["expires_at"]
    if updated["password_protected"]:
        response_data["password_protected"] = True
    return json_response(response_data)
class PasteRawView(MethodView):
"""Raw paste content retrieval."""
@@ -759,6 +987,119 @@ class PasteDeleteView(MethodView):
return json_response({"message": "Paste deleted"})
class PastesListView(MethodView):
    """List authenticated user's pastes (privacy-focused)."""

    @staticmethod
    def _int_arg(name: str, default: int) -> int:
        """Return query parameter *name* as an int, or *default* if absent/invalid."""
        try:
            return int(request.args.get(name, default))
        except (ValueError, TypeError):
            return default

    def get(self) -> Response:
        """List pastes owned by authenticated user.

        Privacy guarantees:
        - Requires authentication (mTLS client certificate)
        - Users can ONLY see their own pastes
        - No admin bypass or cross-user visibility
        - Content is never returned, only metadata

        Query parameters:
        - limit: max results (default 50, clamped to 0..200)
        - offset: pagination offset (default 0, negative values become 0)
        - type: filter by MIME type (glob pattern, e.g., "image/*")
        - after: filter by created_at >= timestamp
        - before: filter by created_at <= timestamp

        Each numeric parameter falls back to its own default when it is not
        a valid integer. "total" in the response counts every paste that
        matches all filters (including type); "count" is the size of the
        returned page.
        """
        import fnmatch

        # Strict authentication requirement
        if err := require_auth():
            return err
        client_id = g.client_id

        # Parse pagination parameters. The lower bound on limit matters:
        # a negative value would otherwise reach the SQL LIMIT clause, where
        # SQLite (presumably the backing store — confirm) treats it as
        # "no limit", bypassing the 200 cap.
        limit = max(0, min(self._int_arg("limit", 50), 200))
        offset = max(self._int_arg("offset", 0), 0)

        # Parse filter parameters
        type_filter = request.args.get("type", "").strip()
        after_ts = self._int_arg("after", 0)
        before_ts = self._int_arg("before", 0)

        db = get_db()

        # Build query with filters
        where_clauses = ["owner = ?"]
        params: list[Any] = [client_id]
        if after_ts > 0:
            where_clauses.append("created_at >= ?")
            params.append(after_ts)
        if before_ts > 0:
            where_clauses.append("created_at <= ?")
            params.append(before_ts)
        where_sql = " AND ".join(where_clauses)

        # Metadata-only select (where_sql is safe, built from constants)
        select_sql = f"""SELECT id, mime_type, length(content) as size, created_at,
                      last_accessed, burn_after_read, expires_at,
                      CASE WHEN password_hash IS NOT NULL THEN 1 ELSE 0 END as password_protected
               FROM pastes
               WHERE {where_sql}
               ORDER BY created_at DESC"""  # noqa: S608

        if type_filter:
            # Glob matching is done in Python, so it has to happen BEFORE
            # pagination: filtering an already-limited page would return
            # short pages, skip matches, and report a total that ignores
            # the type filter.
            matching = [
                r
                for r in db.execute(select_sql, params).fetchall()
                if fnmatch.fnmatch(r["mime_type"] or "", type_filter)
            ]
            total = len(matching)
            rows = matching[offset : offset + limit]
        else:
            # No Python-side filter: let the database count and paginate.
            count_row = db.execute(
                f"SELECT COUNT(*) as total FROM pastes WHERE {where_sql}",  # noqa: S608
                params,
            ).fetchone()
            total = count_row["total"] if count_row else 0
            rows = db.execute(
                f"{select_sql} LIMIT ? OFFSET ?",
                [*params, limit, offset],
            ).fetchall()

        pastes = []
        for row in rows:
            paste: dict[str, Any] = {
                "id": row["id"],
                "mime_type": row["mime_type"],
                "size": row["size"],
                "created_at": row["created_at"],
                "last_accessed": row["last_accessed"],
                "url": f"/{row['id']}",
                "raw": f"/{row['id']}/raw",
            }
            # Optional flags are only included when set.
            if row["burn_after_read"]:
                paste["burn_after_read"] = True
            if row["expires_at"]:
                paste["expires_at"] = row["expires_at"]
            if row["password_protected"]:
                paste["password_protected"] = True
            pastes.append(paste)

        return json_response(
            {
                "pastes": pastes,
                "count": len(pastes),
                "total": total,
                "limit": limit,
                "offset": offset,
            }
        )
# ─────────────────────────────────────────────────────────────────────────────
# PKI Views (Certificate Authority)
# ─────────────────────────────────────────────────────────────────────────────
@@ -1060,7 +1401,8 @@ bp.add_url_rule("/challenge", view_func=ChallengeView.as_view("challenge"))
bp.add_url_rule("/client", view_func=ClientView.as_view("client"))
# Paste operations
bp.add_url_rule("/<paste_id>", view_func=PasteView.as_view("paste"), methods=["GET", "HEAD"])
bp.add_url_rule("/pastes", view_func=PastesListView.as_view("pastes_list"))
bp.add_url_rule("/<paste_id>", view_func=PasteView.as_view("paste"), methods=["GET", "HEAD", "PUT"])
bp.add_url_rule(
"/<paste_id>/raw", view_func=PasteRawView.as_view("paste_raw"), methods=["GET", "HEAD"]
)