forked from username/flaskpaste
add CLI enhancements and scheduled cleanup
CLI commands: - list: show user's pastes with pagination - search: filter by type (glob), after/before timestamps - update: modify content, password, or extend expiry - export: save pastes to directory with optional decryption API changes: - PUT /<id>: update paste content and metadata - GET /pastes: add type, after, before query params Scheduled tasks: - Thread-safe cleanup with per-task intervals - Activate cleanup_expired_hashes (15min) - Activate cleanup_rate_limits (5min) Tests: 205 passing
This commit is contained in:
@@ -1,32 +1,65 @@
|
||||
"""API blueprint registration."""
|
||||
|
||||
import threading
|
||||
import time
|
||||
|
||||
from flask import Blueprint, current_app
|
||||
|
||||
bp = Blueprint("api", __name__)
|
||||
|
||||
# Throttle cleanup to run at most once per hour
|
||||
_last_cleanup = 0
|
||||
_CLEANUP_INTERVAL = 3600 # 1 hour
|
||||
# Thread-safe cleanup scheduling
|
||||
_cleanup_lock = threading.Lock()
|
||||
_cleanup_times = {
|
||||
"pastes": 0,
|
||||
"hashes": 0,
|
||||
"rate_limits": 0,
|
||||
}
|
||||
_CLEANUP_INTERVALS = {
|
||||
"pastes": 3600, # 1 hour
|
||||
"hashes": 900, # 15 minutes
|
||||
"rate_limits": 300, # 5 minutes
|
||||
}
|
||||
|
||||
|
||||
def reset_cleanup_times() -> None:
|
||||
"""Reset cleanup timestamps. For testing only."""
|
||||
with _cleanup_lock:
|
||||
for key in _cleanup_times:
|
||||
_cleanup_times[key] = 0
|
||||
|
||||
|
||||
@bp.before_request
|
||||
def cleanup_expired():
|
||||
"""Periodically clean up expired pastes."""
|
||||
global _last_cleanup
|
||||
|
||||
def run_scheduled_cleanup():
|
||||
"""Periodically run cleanup tasks on schedule."""
|
||||
now = time.time()
|
||||
if now - _last_cleanup < _CLEANUP_INTERVAL:
|
||||
return
|
||||
|
||||
_last_cleanup = now
|
||||
|
||||
with _cleanup_lock:
|
||||
# Cleanup expired pastes
|
||||
if now - _cleanup_times["pastes"] >= _CLEANUP_INTERVALS["pastes"]:
|
||||
_cleanup_times["pastes"] = now
|
||||
from app.database import cleanup_expired_pastes
|
||||
|
||||
count = cleanup_expired_pastes()
|
||||
if count > 0:
|
||||
current_app.logger.info(f"Cleaned up {count} expired paste(s)")
|
||||
|
||||
# Cleanup expired content hashes
|
||||
if now - _cleanup_times["hashes"] >= _CLEANUP_INTERVALS["hashes"]:
|
||||
_cleanup_times["hashes"] = now
|
||||
from app.database import cleanup_expired_hashes
|
||||
|
||||
count = cleanup_expired_hashes()
|
||||
if count > 0:
|
||||
current_app.logger.info(f"Cleaned up {count} expired hash(es)")
|
||||
|
||||
# Cleanup rate limit entries
|
||||
if now - _cleanup_times["rate_limits"] >= _CLEANUP_INTERVALS["rate_limits"]:
|
||||
_cleanup_times["rate_limits"] = now
|
||||
from app.api.routes import cleanup_rate_limits
|
||||
|
||||
count = cleanup_rate_limits()
|
||||
if count > 0:
|
||||
current_app.logger.info(f"Cleaned up {count} rate limit entr(ies)")
|
||||
|
||||
|
||||
from app.api import routes # noqa: E402, F401
|
||||
|
||||
@@ -8,7 +8,9 @@ import json
|
||||
import math
|
||||
import re
|
||||
import secrets
|
||||
import threading
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from flask import Response, current_app, g, request
|
||||
@@ -50,6 +52,96 @@ GENERIC_MIME_TYPES = frozenset(
|
||||
# Runtime PoW secret cache
|
||||
_pow_secret_cache: bytes | None = None
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
# Rate Limiting (in-memory sliding window)
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
_rate_limit_lock = threading.Lock()
|
||||
_rate_limit_requests: dict[str, list[float]] = defaultdict(list)
|
||||
|
||||
|
||||
def get_client_ip() -> str:
|
||||
"""Get client IP address, respecting X-Forwarded-For from trusted proxy."""
|
||||
if is_trusted_proxy():
|
||||
forwarded = request.headers.get("X-Forwarded-For", "")
|
||||
if forwarded:
|
||||
# Take the first (client) IP from the chain
|
||||
return forwarded.split(",")[0].strip()
|
||||
return request.remote_addr or "unknown"
|
||||
|
||||
|
||||
def check_rate_limit(client_ip: str, authenticated: bool = False) -> tuple[bool, int, int]:
|
||||
"""Check if request is within rate limit.
|
||||
|
||||
Args:
|
||||
client_ip: Client IP address
|
||||
authenticated: Whether client is authenticated (higher limits)
|
||||
|
||||
Returns:
|
||||
Tuple of (allowed, remaining, reset_seconds)
|
||||
"""
|
||||
if not current_app.config.get("RATE_LIMIT_ENABLED", True):
|
||||
return True, -1, 0
|
||||
|
||||
window = current_app.config["RATE_LIMIT_WINDOW"]
|
||||
max_requests = current_app.config["RATE_LIMIT_MAX"]
|
||||
|
||||
if authenticated:
|
||||
max_requests *= current_app.config.get("RATE_LIMIT_AUTH_MULTIPLIER", 5)
|
||||
|
||||
now = time.time()
|
||||
cutoff = now - window
|
||||
|
||||
with _rate_limit_lock:
|
||||
# Clean old requests and get current list
|
||||
requests = _rate_limit_requests[client_ip]
|
||||
requests[:] = [t for t in requests if t > cutoff]
|
||||
|
||||
current_count = len(requests)
|
||||
|
||||
if current_count >= max_requests:
|
||||
# Calculate reset time (when oldest request expires)
|
||||
reset_at = int(requests[0] + window - now) + 1 if requests else window
|
||||
return False, 0, reset_at
|
||||
|
||||
# Record this request
|
||||
requests.append(now)
|
||||
remaining = max_requests - len(requests)
|
||||
|
||||
return True, remaining, window
|
||||
|
||||
|
||||
def cleanup_rate_limits(window: int | None = None) -> int:
|
||||
"""Remove expired rate limit entries. Returns count of cleaned entries.
|
||||
|
||||
Args:
|
||||
window: Rate limit window in seconds. If None, uses app config.
|
||||
"""
|
||||
# This should be called periodically (e.g., via cleanup task)
|
||||
if window is None:
|
||||
window = current_app.config.get("RATE_LIMIT_WINDOW", 60)
|
||||
cutoff = time.time() - window
|
||||
|
||||
cleaned = 0
|
||||
with _rate_limit_lock:
|
||||
to_remove = []
|
||||
for ip, requests in _rate_limit_requests.items():
|
||||
requests[:] = [t for t in requests if t > cutoff]
|
||||
if not requests:
|
||||
to_remove.append(ip)
|
||||
|
||||
for ip in to_remove:
|
||||
del _rate_limit_requests[ip]
|
||||
cleaned += 1
|
||||
|
||||
return cleaned
|
||||
|
||||
|
||||
def reset_rate_limits() -> None:
|
||||
"""Clear all rate limit state. For testing only."""
|
||||
with _rate_limit_lock:
|
||||
_rate_limit_requests.clear()
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
# Response Helpers
|
||||
@@ -379,6 +471,7 @@ class IndexView(MethodView):
|
||||
f"GET {prefixed_url('/client')}": "Download CLI client",
|
||||
f"GET {prefixed_url('/challenge')}": "Get PoW challenge",
|
||||
f"POST {prefixed_url('/')}": "Create paste",
|
||||
f"GET {prefixed_url('/pastes')}": "List your pastes (auth required)",
|
||||
f"GET {prefixed_url('/<id>')}": "Retrieve paste metadata",
|
||||
f"GET {prefixed_url('/<id>/raw')}": "Retrieve raw paste content",
|
||||
f"DELETE {prefixed_url('/<id>')}": "Delete paste",
|
||||
@@ -413,6 +506,22 @@ class IndexView(MethodView):
|
||||
|
||||
owner = get_client_id()
|
||||
|
||||
# Rate limiting (check before expensive operations)
|
||||
client_ip = get_client_ip()
|
||||
allowed, _remaining, reset_seconds = check_rate_limit(client_ip, authenticated=bool(owner))
|
||||
|
||||
if not allowed:
|
||||
current_app.logger.warning("Rate limit exceeded: ip=%s auth=%s", client_ip, bool(owner))
|
||||
response = error_response(
|
||||
"Rate limit exceeded",
|
||||
429,
|
||||
retry_after=reset_seconds,
|
||||
)
|
||||
response.headers["Retry-After"] = str(reset_seconds)
|
||||
response.headers["X-RateLimit-Remaining"] = "0"
|
||||
response.headers["X-RateLimit-Reset"] = str(reset_seconds)
|
||||
return response
|
||||
|
||||
# Proof-of-work verification
|
||||
difficulty = current_app.config["POW_DIFFICULTY"]
|
||||
if difficulty > 0:
|
||||
@@ -681,6 +790,125 @@ class PasteView(MethodView):
|
||||
"""Return paste metadata headers only."""
|
||||
return self.get(paste_id)
|
||||
|
||||
def put(self, paste_id: str) -> Response:
|
||||
"""Update paste content and/or metadata.
|
||||
|
||||
Requires authentication and ownership.
|
||||
|
||||
Content update: Send raw body with Content-Type header
|
||||
Metadata update: Use headers with empty body
|
||||
|
||||
Headers:
|
||||
- X-Paste-Password: Set/change password
|
||||
- X-Remove-Password: true to remove password
|
||||
- X-Extend-Expiry: Seconds to add to current expiry
|
||||
"""
|
||||
# Validate paste ID format
|
||||
if err := validate_paste_id(paste_id):
|
||||
return err
|
||||
if err := require_auth():
|
||||
return err
|
||||
|
||||
db = get_db()
|
||||
|
||||
# Fetch current paste
|
||||
row = db.execute(
|
||||
"""SELECT id, owner, content, mime_type, expires_at, password_hash
|
||||
FROM pastes WHERE id = ?""",
|
||||
(paste_id,),
|
||||
).fetchone()
|
||||
|
||||
if row is None:
|
||||
return error_response("Paste not found", 404)
|
||||
|
||||
if row["owner"] != g.client_id:
|
||||
return error_response("Permission denied", 403)
|
||||
|
||||
# Check for burn-after-read (cannot update)
|
||||
burn_check = db.execute(
|
||||
"SELECT burn_after_read FROM pastes WHERE id = ?", (paste_id,)
|
||||
).fetchone()
|
||||
if burn_check and burn_check["burn_after_read"]:
|
||||
return error_response("Cannot update burn-after-read paste", 400)
|
||||
|
||||
# Parse update parameters
|
||||
new_password = request.headers.get("X-Paste-Password", "").strip() or None
|
||||
remove_password = request.headers.get("X-Remove-Password", "").lower() in (
|
||||
"true",
|
||||
"1",
|
||||
"yes",
|
||||
)
|
||||
extend_expiry_str = request.headers.get("X-Extend-Expiry", "").strip()
|
||||
|
||||
# Prepare update fields
|
||||
update_fields = []
|
||||
update_params: list[Any] = []
|
||||
|
||||
# Content update (if body provided)
|
||||
content = request.get_data()
|
||||
if content:
|
||||
mime_type = request.content_type or "application/octet-stream"
|
||||
# Sanitize MIME type
|
||||
if not MIME_PATTERN.match(mime_type.split(";")[0].strip()):
|
||||
mime_type = "application/octet-stream"
|
||||
|
||||
update_fields.append("content = ?")
|
||||
update_params.append(content)
|
||||
update_fields.append("mime_type = ?")
|
||||
update_params.append(mime_type.split(";")[0].strip())
|
||||
|
||||
# Password update
|
||||
if remove_password:
|
||||
update_fields.append("password_hash = NULL")
|
||||
elif new_password:
|
||||
update_fields.append("password_hash = ?")
|
||||
update_params.append(hash_password(new_password))
|
||||
|
||||
# Expiry extension
|
||||
if extend_expiry_str:
|
||||
try:
|
||||
extend_seconds = int(extend_expiry_str)
|
||||
if extend_seconds > 0:
|
||||
current_expiry = row["expires_at"]
|
||||
if current_expiry:
|
||||
new_expiry = current_expiry + extend_seconds
|
||||
else:
|
||||
# If no expiry set, create one from now
|
||||
new_expiry = int(time.time()) + extend_seconds
|
||||
update_fields.append("expires_at = ?")
|
||||
update_params.append(new_expiry)
|
||||
except ValueError:
|
||||
return error_response("Invalid X-Extend-Expiry value", 400)
|
||||
|
||||
if not update_fields:
|
||||
return error_response("No updates provided", 400)
|
||||
|
||||
# Execute update (fields are hardcoded strings, safe from injection)
|
||||
update_sql = f"UPDATE pastes SET {', '.join(update_fields)} WHERE id = ?" # noqa: S608
|
||||
update_params.append(paste_id)
|
||||
db.execute(update_sql, update_params)
|
||||
db.commit()
|
||||
|
||||
# Fetch updated paste for response
|
||||
updated = db.execute(
|
||||
"""SELECT id, mime_type, length(content) as size, expires_at,
|
||||
CASE WHEN password_hash IS NOT NULL THEN 1 ELSE 0 END as password_protected
|
||||
FROM pastes WHERE id = ?""",
|
||||
(paste_id,),
|
||||
).fetchone()
|
||||
|
||||
response_data: dict[str, Any] = {
|
||||
"id": updated["id"],
|
||||
"size": updated["size"],
|
||||
"mime_type": updated["mime_type"],
|
||||
}
|
||||
if updated["expires_at"]:
|
||||
response_data["expires_at"] = updated["expires_at"]
|
||||
if updated["password_protected"]:
|
||||
response_data["password_protected"] = True
|
||||
|
||||
return json_response(response_data)
|
||||
|
||||
|
||||
class PasteRawView(MethodView):
|
||||
"""Raw paste content retrieval."""
|
||||
@@ -759,6 +987,119 @@ class PasteDeleteView(MethodView):
|
||||
return json_response({"message": "Paste deleted"})
|
||||
|
||||
|
||||
class PastesListView(MethodView):
|
||||
"""List authenticated user's pastes (privacy-focused)."""
|
||||
|
||||
def get(self) -> Response:
|
||||
"""List pastes owned by authenticated user.
|
||||
|
||||
Privacy guarantees:
|
||||
- Requires authentication (mTLS client certificate)
|
||||
- Users can ONLY see their own pastes
|
||||
- No admin bypass or cross-user visibility
|
||||
- Content is never returned, only metadata
|
||||
|
||||
Query parameters:
|
||||
- limit: max results (default 50, max 200)
|
||||
- offset: pagination offset (default 0)
|
||||
- type: filter by MIME type (glob pattern, e.g., "image/*")
|
||||
- after: filter by created_at >= timestamp
|
||||
- before: filter by created_at <= timestamp
|
||||
"""
|
||||
import fnmatch
|
||||
|
||||
# Strict authentication requirement
|
||||
if err := require_auth():
|
||||
return err
|
||||
|
||||
client_id = g.client_id
|
||||
|
||||
# Parse pagination parameters
|
||||
try:
|
||||
limit = min(int(request.args.get("limit", 50)), 200)
|
||||
offset = max(int(request.args.get("offset", 0)), 0)
|
||||
except (ValueError, TypeError):
|
||||
limit, offset = 50, 0
|
||||
|
||||
# Parse filter parameters
|
||||
type_filter = request.args.get("type", "").strip()
|
||||
try:
|
||||
after_ts = int(request.args.get("after", 0))
|
||||
except (ValueError, TypeError):
|
||||
after_ts = 0
|
||||
try:
|
||||
before_ts = int(request.args.get("before", 0))
|
||||
except (ValueError, TypeError):
|
||||
before_ts = 0
|
||||
|
||||
db = get_db()
|
||||
|
||||
# Build query with filters
|
||||
where_clauses = ["owner = ?"]
|
||||
params: list[Any] = [client_id]
|
||||
|
||||
if after_ts > 0:
|
||||
where_clauses.append("created_at >= ?")
|
||||
params.append(after_ts)
|
||||
if before_ts > 0:
|
||||
where_clauses.append("created_at <= ?")
|
||||
params.append(before_ts)
|
||||
|
||||
where_sql = " AND ".join(where_clauses)
|
||||
|
||||
# Count total pastes matching filters (where_sql is safe, built from constants)
|
||||
count_row = db.execute(
|
||||
f"SELECT COUNT(*) as total FROM pastes WHERE {where_sql}", # noqa: S608
|
||||
params,
|
||||
).fetchone()
|
||||
total = count_row["total"] if count_row else 0
|
||||
|
||||
# Fetch pastes with metadata only (where_sql is safe, built from constants)
|
||||
rows = db.execute(
|
||||
f"""SELECT id, mime_type, length(content) as size, created_at,
|
||||
last_accessed, burn_after_read, expires_at,
|
||||
CASE WHEN password_hash IS NOT NULL THEN 1 ELSE 0 END as password_protected
|
||||
FROM pastes
|
||||
WHERE {where_sql}
|
||||
ORDER BY created_at DESC
|
||||
LIMIT ? OFFSET ?""", # noqa: S608
|
||||
[*params, limit, offset],
|
||||
).fetchall()
|
||||
|
||||
# Apply MIME type filter (glob pattern matching done in Python for flexibility)
|
||||
if type_filter:
|
||||
rows = [r for r in rows if fnmatch.fnmatch(r["mime_type"], type_filter)]
|
||||
|
||||
pastes = []
|
||||
for row in rows:
|
||||
paste: dict[str, Any] = {
|
||||
"id": row["id"],
|
||||
"mime_type": row["mime_type"],
|
||||
"size": row["size"],
|
||||
"created_at": row["created_at"],
|
||||
"last_accessed": row["last_accessed"],
|
||||
"url": f"/{row['id']}",
|
||||
"raw": f"/{row['id']}/raw",
|
||||
}
|
||||
if row["burn_after_read"]:
|
||||
paste["burn_after_read"] = True
|
||||
if row["expires_at"]:
|
||||
paste["expires_at"] = row["expires_at"]
|
||||
if row["password_protected"]:
|
||||
paste["password_protected"] = True
|
||||
pastes.append(paste)
|
||||
|
||||
return json_response(
|
||||
{
|
||||
"pastes": pastes,
|
||||
"count": len(pastes),
|
||||
"total": total,
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
# PKI Views (Certificate Authority)
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
@@ -1060,7 +1401,8 @@ bp.add_url_rule("/challenge", view_func=ChallengeView.as_view("challenge"))
|
||||
bp.add_url_rule("/client", view_func=ClientView.as_view("client"))
|
||||
|
||||
# Paste operations
|
||||
bp.add_url_rule("/<paste_id>", view_func=PasteView.as_view("paste"), methods=["GET", "HEAD"])
|
||||
bp.add_url_rule("/pastes", view_func=PastesListView.as_view("pastes_list"))
|
||||
bp.add_url_rule("/<paste_id>", view_func=PasteView.as_view("paste"), methods=["GET", "HEAD", "PUT"])
|
||||
bp.add_url_rule(
|
||||
"/<paste_id>/raw", view_func=PasteRawView.as_view("paste_raw"), methods=["GET", "HEAD"]
|
||||
)
|
||||
|
||||
@@ -67,6 +67,18 @@ class Config:
|
||||
# URL prefix for reverse proxy deployments (e.g., "/paste" for mymx.me/paste)
|
||||
URL_PREFIX = os.environ.get("FLASKPASTE_URL_PREFIX", "").rstrip("/")
|
||||
|
||||
# IP-based rate limiting
|
||||
# Limits paste creation per IP address using sliding window
|
||||
RATE_LIMIT_ENABLED = os.environ.get("FLASKPASTE_RATE_LIMIT", "1").lower() in (
|
||||
"1",
|
||||
"true",
|
||||
"yes",
|
||||
)
|
||||
RATE_LIMIT_WINDOW = int(os.environ.get("FLASKPASTE_RATE_WINDOW", "60")) # seconds
|
||||
RATE_LIMIT_MAX = int(os.environ.get("FLASKPASTE_RATE_MAX", "10")) # requests per window
|
||||
# Authenticated users get higher limits (multiplier)
|
||||
RATE_LIMIT_AUTH_MULTIPLIER = int(os.environ.get("FLASKPASTE_RATE_AUTH_MULT", "5"))
|
||||
|
||||
# PKI Configuration
|
||||
# Enable PKI endpoints for certificate authority and issuance
|
||||
PKI_ENABLED = os.environ.get("FLASKPASTE_PKI_ENABLED", "0").lower() in ("1", "true", "yes")
|
||||
@@ -103,6 +115,11 @@ class TestingConfig(Config):
|
||||
# Disable PoW for most tests (easier testing)
|
||||
POW_DIFFICULTY = 0
|
||||
|
||||
# Relaxed rate limiting for tests
|
||||
RATE_LIMIT_ENABLED = True
|
||||
RATE_LIMIT_WINDOW = 1
|
||||
RATE_LIMIT_MAX = 100
|
||||
|
||||
# PKI testing configuration
|
||||
PKI_ENABLED = True
|
||||
PKI_CA_PASSWORD = "test-ca-password"
|
||||
|
||||
481
fpaste
481
fpaste
@@ -385,6 +385,425 @@ def cmd_info(args, config):
|
||||
die("failed to connect to server")
|
||||
|
||||
|
||||
def format_size(size):
|
||||
"""Format byte size as human-readable string."""
|
||||
if size < 1024:
|
||||
return f"{size}B"
|
||||
elif size < 1024 * 1024:
|
||||
return f"{size / 1024:.1f}K"
|
||||
else:
|
||||
return f"{size / (1024 * 1024):.1f}M"
|
||||
|
||||
|
||||
def format_timestamp(ts):
|
||||
"""Format Unix timestamp as human-readable date."""
|
||||
from datetime import datetime
|
||||
|
||||
dt = datetime.fromtimestamp(ts, tz=UTC)
|
||||
return dt.strftime("%Y-%m-%d %H:%M")
|
||||
|
||||
|
||||
def cmd_list(args, config):
|
||||
"""List user's pastes."""
|
||||
if not config["cert_sha1"]:
|
||||
die("authentication required (set FLASKPASTE_CERT_SHA1)")
|
||||
|
||||
base = config["server"].rstrip("/")
|
||||
params = []
|
||||
if args.limit:
|
||||
params.append(f"limit={args.limit}")
|
||||
if args.offset:
|
||||
params.append(f"offset={args.offset}")
|
||||
|
||||
url = f"{base}/pastes"
|
||||
if params:
|
||||
url += "?" + "&".join(params)
|
||||
|
||||
headers = {"X-SSL-Client-SHA1": config["cert_sha1"]}
|
||||
status, body, _ = request(url, headers=headers, ssl_context=config.get("ssl_context"))
|
||||
|
||||
if status == 401:
|
||||
die("authentication failed")
|
||||
elif status != 200:
|
||||
die(f"failed to list pastes ({status})")
|
||||
|
||||
data = json.loads(body)
|
||||
pastes = data.get("pastes", [])
|
||||
|
||||
if args.json:
|
||||
print(json.dumps(data, indent=2))
|
||||
return
|
||||
|
||||
if not pastes:
|
||||
print("no pastes found")
|
||||
return
|
||||
|
||||
# Print header
|
||||
print(f"{'ID':<12} {'TYPE':<16} {'SIZE':>6} {'CREATED':<16} FLAGS")
|
||||
|
||||
for p in pastes:
|
||||
paste_id = p["id"]
|
||||
mime_type = p.get("mime_type", "unknown")[:16]
|
||||
size = format_size(p.get("size", 0))
|
||||
created = format_timestamp(p.get("created_at", 0))
|
||||
|
||||
flags = []
|
||||
if p.get("burn_after_read"):
|
||||
flags.append("burn")
|
||||
if p.get("password_protected"):
|
||||
flags.append("pass")
|
||||
if p.get("expires_at"):
|
||||
flags.append("exp")
|
||||
|
||||
flag_str = " ".join(flags)
|
||||
print(f"{paste_id:<12} {mime_type:<16} {size:>6} {created:<16} {flag_str}")
|
||||
|
||||
# Print summary
|
||||
print(f"\n{data.get('count', 0)} of {data.get('total', 0)} pastes shown")
|
||||
|
||||
|
||||
def parse_date(date_str):
|
||||
"""Parse date string to Unix timestamp."""
|
||||
from datetime import datetime
|
||||
|
||||
if not date_str:
|
||||
return 0
|
||||
|
||||
# Try various formats
|
||||
formats = [
|
||||
"%Y-%m-%d",
|
||||
"%Y-%m-%d %H:%M",
|
||||
"%Y-%m-%d %H:%M:%S",
|
||||
"%Y-%m-%dT%H:%M:%S",
|
||||
"%Y-%m-%dT%H:%M:%SZ",
|
||||
]
|
||||
for fmt in formats:
|
||||
try:
|
||||
dt = datetime.strptime(date_str, fmt)
|
||||
dt = dt.replace(tzinfo=UTC)
|
||||
return int(dt.timestamp())
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
# Try as Unix timestamp
|
||||
try:
|
||||
return int(date_str)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
die(f"invalid date format: {date_str}")
|
||||
|
||||
|
||||
def cmd_search(args, config):
|
||||
"""Search user's pastes."""
|
||||
if not config["cert_sha1"]:
|
||||
die("authentication required (set FLASKPASTE_CERT_SHA1)")
|
||||
|
||||
base = config["server"].rstrip("/")
|
||||
params = []
|
||||
|
||||
if args.type:
|
||||
params.append(f"type={args.type}")
|
||||
if args.after:
|
||||
ts = parse_date(args.after)
|
||||
params.append(f"after={ts}")
|
||||
if args.before:
|
||||
ts = parse_date(args.before)
|
||||
params.append(f"before={ts}")
|
||||
if args.limit:
|
||||
params.append(f"limit={args.limit}")
|
||||
|
||||
url = f"{base}/pastes"
|
||||
if params:
|
||||
url += "?" + "&".join(params)
|
||||
|
||||
headers = {"X-SSL-Client-SHA1": config["cert_sha1"]}
|
||||
status, body, _ = request(url, headers=headers, ssl_context=config.get("ssl_context"))
|
||||
|
||||
if status == 401:
|
||||
die("authentication failed")
|
||||
elif status != 200:
|
||||
die(f"failed to search pastes ({status})")
|
||||
|
||||
data = json.loads(body)
|
||||
pastes = data.get("pastes", [])
|
||||
|
||||
if args.json:
|
||||
print(json.dumps(data, indent=2))
|
||||
return
|
||||
|
||||
if not pastes:
|
||||
print("no matching pastes found")
|
||||
return
|
||||
|
||||
# Print header
|
||||
print(f"{'ID':<12} {'TYPE':<16} {'SIZE':>6} {'CREATED':<16} FLAGS")
|
||||
|
||||
for p in pastes:
|
||||
paste_id = p["id"]
|
||||
mime_type = p.get("mime_type", "unknown")[:16]
|
||||
size = format_size(p.get("size", 0))
|
||||
created = format_timestamp(p.get("created_at", 0))
|
||||
|
||||
flags = []
|
||||
if p.get("burn_after_read"):
|
||||
flags.append("burn")
|
||||
if p.get("password_protected"):
|
||||
flags.append("pass")
|
||||
if p.get("expires_at"):
|
||||
flags.append("exp")
|
||||
|
||||
flag_str = " ".join(flags)
|
||||
print(f"{paste_id:<12} {mime_type:<16} {size:>6} {created:<16} {flag_str}")
|
||||
|
||||
# Print summary
|
||||
print(f"\n{data.get('count', 0)} matching pastes found")
|
||||
|
||||
|
||||
def cmd_update(args, config):
|
||||
"""Update an existing paste."""
|
||||
if not config["cert_sha1"]:
|
||||
die("authentication required (set FLASKPASTE_CERT_SHA1)")
|
||||
|
||||
paste_id = args.id.split("/")[-1] # Handle full URLs
|
||||
if "#" in paste_id:
|
||||
paste_id = paste_id.split("#")[0] # Remove key fragment
|
||||
|
||||
base = config["server"].rstrip("/")
|
||||
url = f"{base}/{paste_id}"
|
||||
|
||||
headers = {"X-SSL-Client-SHA1": config["cert_sha1"]}
|
||||
content = None
|
||||
|
||||
# Read content from file if provided
|
||||
if args.file:
|
||||
if args.file == "-":
|
||||
content = sys.stdin.buffer.read()
|
||||
else:
|
||||
path = Path(args.file)
|
||||
if not path.exists():
|
||||
die(f"file not found: {args.file}")
|
||||
content = path.read_bytes()
|
||||
|
||||
if not content:
|
||||
die("empty content")
|
||||
|
||||
# Encrypt if requested (default is to encrypt)
|
||||
if not getattr(args, "no_encrypt", False):
|
||||
if not HAS_CRYPTO:
|
||||
die("encryption requires 'cryptography' package (use -E to disable)")
|
||||
if not args.quiet:
|
||||
print("encrypting...", end="", file=sys.stderr)
|
||||
content, encryption_key = encrypt_content(content)
|
||||
if not args.quiet:
|
||||
print(" done", file=sys.stderr)
|
||||
else:
|
||||
encryption_key = None
|
||||
|
||||
# Set metadata update headers
|
||||
if args.password:
|
||||
headers["X-Paste-Password"] = args.password
|
||||
if args.remove_password:
|
||||
headers["X-Remove-Password"] = "true"
|
||||
if args.expiry:
|
||||
headers["X-Extend-Expiry"] = str(args.expiry)
|
||||
|
||||
# Make request
|
||||
status, body, _ = request(
|
||||
url, method="PUT", data=content, headers=headers, ssl_context=config.get("ssl_context")
|
||||
)
|
||||
|
||||
if status == 200:
|
||||
data = json.loads(body)
|
||||
if args.quiet:
|
||||
print(paste_id)
|
||||
else:
|
||||
print(f"updated: {paste_id}")
|
||||
print(f" size: {data.get('size', 'unknown')}")
|
||||
print(f" type: {data.get('mime_type', 'unknown')}")
|
||||
if data.get("expires_at"):
|
||||
print(f" expires: {data.get('expires_at')}")
|
||||
if data.get("password_protected"):
|
||||
print(" password: protected")
|
||||
|
||||
# Show new encryption key if content was updated and encrypted
|
||||
if content and "encryption_key" in dir() and encryption_key:
|
||||
key_fragment = "#" + encode_key(encryption_key)
|
||||
print(f" key: {base}/{paste_id}{key_fragment}")
|
||||
elif status == 400:
|
||||
try:
|
||||
err = json.loads(body).get("error", "bad request")
|
||||
except (json.JSONDecodeError, UnicodeDecodeError):
|
||||
err = "bad request"
|
||||
die(err)
|
||||
elif status == 401:
|
||||
die("authentication failed")
|
||||
elif status == 403:
|
||||
die("permission denied (not owner)")
|
||||
elif status == 404:
|
||||
die(f"not found: {paste_id}")
|
||||
else:
|
||||
die(f"update failed ({status})")
|
||||
|
||||
|
||||
def cmd_export(args, config):
|
||||
"""Export user's pastes to a directory."""
|
||||
if not config["cert_sha1"]:
|
||||
die("authentication required (set FLASKPASTE_CERT_SHA1)")
|
||||
|
||||
base = config["server"].rstrip("/")
|
||||
out_dir = Path(args.output) if args.output else Path("fpaste-export")
|
||||
|
||||
# Create output directory
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Load key file if provided
|
||||
keys = {}
|
||||
if args.keyfile:
|
||||
keyfile_path = Path(args.keyfile)
|
||||
if not keyfile_path.exists():
|
||||
die(f"key file not found: {args.keyfile}")
|
||||
for line in keyfile_path.read_text().splitlines():
|
||||
line = line.strip()
|
||||
if not line or line.startswith("#"):
|
||||
continue
|
||||
if "=" not in line:
|
||||
continue
|
||||
paste_id, key_encoded = line.split("=", 1)
|
||||
keys[paste_id.strip()] = key_encoded.strip()
|
||||
|
||||
# Fetch paste list
|
||||
headers = {"X-SSL-Client-SHA1": config["cert_sha1"]}
|
||||
url = f"{base}/pastes?limit=1000" # Fetch all pastes
|
||||
status, body, _ = request(url, headers=headers, ssl_context=config.get("ssl_context"))
|
||||
|
||||
if status == 401:
|
||||
die("authentication failed")
|
||||
elif status != 200:
|
||||
die(f"failed to list pastes ({status})")
|
||||
|
||||
data = json.loads(body)
|
||||
pastes = data.get("pastes", [])
|
||||
|
||||
if not pastes:
|
||||
print("no pastes to export")
|
||||
return
|
||||
|
||||
# Export each paste
|
||||
exported = 0
|
||||
skipped = 0
|
||||
errors = 0
|
||||
manifest = []
|
||||
|
||||
for p in pastes:
|
||||
paste_id = p["id"]
|
||||
mime_type = p.get("mime_type", "application/octet-stream")
|
||||
|
||||
if not args.quiet:
|
||||
print(f"exporting {paste_id}...", end=" ", file=sys.stderr)
|
||||
|
||||
# Skip burn-after-read pastes
|
||||
if p.get("burn_after_read"):
|
||||
if not args.quiet:
|
||||
print("skipped (burn-after-read)", file=sys.stderr)
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
# Fetch raw content
|
||||
raw_url = f"{base}/{paste_id}/raw"
|
||||
req_headers = dict(headers)
|
||||
if p.get("password_protected"):
|
||||
if not args.quiet:
|
||||
print("skipped (password-protected)", file=sys.stderr)
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
ssl_ctx = config.get("ssl_context")
|
||||
status, content, _ = request(raw_url, headers=req_headers, ssl_context=ssl_ctx)
|
||||
|
||||
if status != 200:
|
||||
if not args.quiet:
|
||||
print(f"error ({status})", file=sys.stderr)
|
||||
errors += 1
|
||||
continue
|
||||
|
||||
# Decrypt if key available
|
||||
decrypted = False
|
||||
if paste_id in keys:
|
||||
try:
|
||||
key = decode_key(keys[paste_id])
|
||||
content = decrypt_content(content, key)
|
||||
decrypted = True
|
||||
except SystemExit:
|
||||
# Decryption failed, keep encrypted content
|
||||
if not args.quiet:
|
||||
print("decryption failed, keeping encrypted", file=sys.stderr, end=" ")
|
||||
|
||||
# Determine file extension from MIME type
|
||||
ext = get_extension_for_mime(mime_type)
|
||||
filename = f"{paste_id}{ext}"
|
||||
filepath = out_dir / filename
|
||||
|
||||
# Write content
|
||||
filepath.write_bytes(content)
|
||||
|
||||
# Add to manifest
|
||||
manifest.append(
|
||||
{
|
||||
"id": paste_id,
|
||||
"filename": filename,
|
||||
"mime_type": mime_type,
|
||||
"size": len(content),
|
||||
"created_at": p.get("created_at"),
|
||||
"decrypted": decrypted,
|
||||
"encrypted": paste_id in keys and not decrypted,
|
||||
}
|
||||
)
|
||||
|
||||
if not args.quiet:
|
||||
status_msg = "decrypted" if decrypted else ("encrypted" if paste_id in keys else "ok")
|
||||
print(status_msg, file=sys.stderr)
|
||||
|
||||
exported += 1
|
||||
|
||||
# Write manifest
|
||||
if args.manifest:
|
||||
manifest_path = out_dir / "manifest.json"
|
||||
manifest_path.write_text(json.dumps(manifest, indent=2))
|
||||
if not args.quiet:
|
||||
print(f"manifest: {manifest_path}", file=sys.stderr)
|
||||
|
||||
# Summary
|
||||
print(f"\nexported: {exported}, skipped: {skipped}, errors: {errors}")
|
||||
print(f"output: {out_dir}")
|
||||
|
||||
|
||||
def get_extension_for_mime(mime_type):
|
||||
"""Get file extension for MIME type."""
|
||||
mime_map = {
|
||||
"text/plain": ".txt",
|
||||
"text/html": ".html",
|
||||
"text/css": ".css",
|
||||
"text/javascript": ".js",
|
||||
"text/markdown": ".md",
|
||||
"text/x-python": ".py",
|
||||
"application/json": ".json",
|
||||
"application/xml": ".xml",
|
||||
"application/javascript": ".js",
|
||||
"application/octet-stream": ".bin",
|
||||
"image/png": ".png",
|
||||
"image/jpeg": ".jpg",
|
||||
"image/gif": ".gif",
|
||||
"image/webp": ".webp",
|
||||
"image/svg+xml": ".svg",
|
||||
"application/pdf": ".pdf",
|
||||
"application/zip": ".zip",
|
||||
"application/gzip": ".gz",
|
||||
"application/x-tar": ".tar",
|
||||
}
|
||||
return mime_map.get(mime_type, ".bin")
|
||||
|
||||
|
||||
def cmd_pki_status(args, config):
|
||||
"""Show PKI status and CA information."""
|
||||
url = config["server"].rstrip("/") + "/pki"
|
||||
@@ -752,7 +1171,28 @@ def is_file_path(arg):
|
||||
def main():
|
||||
# Pre-process arguments: if first positional looks like a file, insert "create"
|
||||
args_to_parse = sys.argv[1:]
|
||||
commands = {"create", "c", "new", "get", "g", "delete", "d", "rm", "info", "i", "cert", "pki"}
|
||||
commands = {
|
||||
"create",
|
||||
"c",
|
||||
"new",
|
||||
"get",
|
||||
"g",
|
||||
"delete",
|
||||
"d",
|
||||
"rm",
|
||||
"info",
|
||||
"i",
|
||||
"list",
|
||||
"ls",
|
||||
"search",
|
||||
"s",
|
||||
"find",
|
||||
"update",
|
||||
"u",
|
||||
"export",
|
||||
"cert",
|
||||
"pki",
|
||||
}
|
||||
|
||||
# Find insertion point for "create" command
|
||||
insert_pos = 0
|
||||
@@ -823,6 +1263,37 @@ def main():
|
||||
# info
|
||||
subparsers.add_parser("info", aliases=["i"], help="show server info")
|
||||
|
||||
# list
|
||||
p_list = subparsers.add_parser("list", aliases=["ls"], help="list your pastes")
|
||||
p_list.add_argument("-l", "--limit", type=int, metavar="N", help="max pastes (default: 50)")
|
||||
p_list.add_argument("-o", "--offset", type=int, metavar="N", help="skip first N pastes")
|
||||
p_list.add_argument("--json", action="store_true", help="output as JSON")
|
||||
|
||||
# search
|
||||
p_search = subparsers.add_parser("search", aliases=["s", "find"], help="search your pastes")
|
||||
p_search.add_argument("-t", "--type", metavar="PATTERN", help="filter by MIME type (image/*)")
|
||||
p_search.add_argument("--after", metavar="DATE", help="created after (YYYY-MM-DD or timestamp)")
|
||||
p_search.add_argument("--before", metavar="DATE", help="created before (YYYY-MM-DD)")
|
||||
p_search.add_argument("-l", "--limit", type=int, metavar="N", help="max results (default: 50)")
|
||||
p_search.add_argument("--json", action="store_true", help="output as JSON")
|
||||
|
||||
# update
|
||||
p_update = subparsers.add_parser("update", aliases=["u"], help="update existing paste")
|
||||
p_update.add_argument("id", help="paste ID or URL")
|
||||
p_update.add_argument("file", nargs="?", help="new content (- for stdin)")
|
||||
p_update.add_argument("-E", "--no-encrypt", action="store_true", help="disable encryption")
|
||||
p_update.add_argument("-p", "--password", metavar="PASS", help="set/change password")
|
||||
p_update.add_argument("--remove-password", action="store_true", help="remove password")
|
||||
p_update.add_argument("-x", "--expiry", type=int, metavar="SEC", help="extend expiry (seconds)")
|
||||
p_update.add_argument("-q", "--quiet", action="store_true", help="minimal output")
|
||||
|
||||
# export
|
||||
p_export = subparsers.add_parser("export", help="export all pastes to directory")
|
||||
p_export.add_argument("-o", "--output", metavar="DIR", help="output directory")
|
||||
p_export.add_argument("-k", "--keyfile", metavar="FILE", help="key file (paste_id=key format)")
|
||||
p_export.add_argument("--manifest", action="store_true", help="write manifest.json")
|
||||
p_export.add_argument("-q", "--quiet", action="store_true", help="minimal output")
|
||||
|
||||
# cert
|
||||
p_cert = subparsers.add_parser("cert", help="generate client certificate")
|
||||
p_cert.add_argument("-o", "--output", metavar="DIR", help="output directory")
|
||||
@@ -904,6 +1375,14 @@ def main():
|
||||
cmd_delete(args, config)
|
||||
elif args.command in ("info", "i"):
|
||||
cmd_info(args, config)
|
||||
elif args.command in ("list", "ls"):
|
||||
cmd_list(args, config)
|
||||
elif args.command in ("search", "s", "find"):
|
||||
cmd_search(args, config)
|
||||
elif args.command in ("update", "u"):
|
||||
cmd_update(args, config)
|
||||
elif args.command == "export":
|
||||
cmd_export(args, config)
|
||||
elif args.command == "cert":
|
||||
cmd_cert(args, config)
|
||||
elif args.command == "pki":
|
||||
|
||||
@@ -2,14 +2,37 @@
|
||||
|
||||
import pytest
|
||||
|
||||
import app.database as db_module
|
||||
from app import create_app
|
||||
from app.api.routes import reset_rate_limits
|
||||
|
||||
|
||||
def _clear_database():
|
||||
"""Clear all data from database tables for test isolation."""
|
||||
if db_module._memory_db_holder is not None:
|
||||
db_module._memory_db_holder.execute("DELETE FROM pastes")
|
||||
db_module._memory_db_holder.execute("DELETE FROM content_hashes")
|
||||
db_module._memory_db_holder.execute("DELETE FROM issued_certificates")
|
||||
db_module._memory_db_holder.execute("DELETE FROM certificate_authority")
|
||||
db_module._memory_db_holder.commit()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
"""Create application for testing."""
|
||||
app = create_app("testing")
|
||||
yield app
|
||||
# Reset global state for test isolation
|
||||
reset_rate_limits()
|
||||
_clear_database()
|
||||
|
||||
test_app = create_app("testing")
|
||||
|
||||
# Clear database again after app init (in case init added anything)
|
||||
_clear_database()
|
||||
|
||||
yield test_app
|
||||
|
||||
# Cleanup after test
|
||||
reset_rate_limits()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
335
tests/test_paste_listing.py
Normal file
335
tests/test_paste_listing.py
Normal file
@@ -0,0 +1,335 @@
|
||||
"""Tests for paste listing endpoint (GET /pastes)."""
|
||||
|
||||
import json
|
||||
|
||||
|
||||
class TestPastesListEndpoint:
|
||||
"""Tests for GET /pastes endpoint."""
|
||||
|
||||
def test_list_pastes_requires_auth(self, client):
|
||||
"""List pastes requires authentication."""
|
||||
response = client.get("/pastes")
|
||||
assert response.status_code == 401
|
||||
data = json.loads(response.data)
|
||||
assert "error" in data
|
||||
|
||||
def test_list_pastes_empty(self, client, auth_header):
|
||||
"""List pastes returns empty when user has no pastes."""
|
||||
response = client.get("/pastes", headers=auth_header)
|
||||
assert response.status_code == 200
|
||||
data = json.loads(response.data)
|
||||
assert data["pastes"] == []
|
||||
assert data["count"] == 0
|
||||
assert data["total"] == 0
|
||||
|
||||
def test_list_pastes_returns_own_pastes(self, client, sample_text, auth_header):
|
||||
"""List pastes returns only user's own pastes."""
|
||||
# Create a paste
|
||||
create = client.post(
|
||||
"/",
|
||||
data=sample_text,
|
||||
content_type="text/plain",
|
||||
headers=auth_header,
|
||||
)
|
||||
paste_id = json.loads(create.data)["id"]
|
||||
|
||||
# List pastes
|
||||
response = client.get("/pastes", headers=auth_header)
|
||||
assert response.status_code == 200
|
||||
data = json.loads(response.data)
|
||||
assert data["count"] == 1
|
||||
assert data["total"] == 1
|
||||
assert data["pastes"][0]["id"] == paste_id
|
||||
|
||||
def test_list_pastes_excludes_others(self, client, sample_text, auth_header, other_auth_header):
|
||||
"""List pastes does not include other users' pastes."""
|
||||
# Create paste as user A
|
||||
client.post(
|
||||
"/",
|
||||
data=sample_text,
|
||||
content_type="text/plain",
|
||||
headers=auth_header,
|
||||
)
|
||||
|
||||
# List pastes as user B
|
||||
response = client.get("/pastes", headers=other_auth_header)
|
||||
assert response.status_code == 200
|
||||
data = json.loads(response.data)
|
||||
assert data["count"] == 0
|
||||
assert data["total"] == 0
|
||||
|
||||
def test_list_pastes_excludes_anonymous(self, client, sample_text, auth_header):
|
||||
"""List pastes does not include anonymous pastes."""
|
||||
# Create anonymous paste
|
||||
client.post("/", data=sample_text, content_type="text/plain")
|
||||
|
||||
# List pastes as authenticated user
|
||||
response = client.get("/pastes", headers=auth_header)
|
||||
assert response.status_code == 200
|
||||
data = json.loads(response.data)
|
||||
assert data["count"] == 0
|
||||
|
||||
def test_list_pastes_metadata_only(self, client, sample_text, auth_header):
|
||||
"""List pastes returns metadata, not content."""
|
||||
# Create a paste
|
||||
client.post(
|
||||
"/",
|
||||
data=sample_text,
|
||||
content_type="text/plain",
|
||||
headers=auth_header,
|
||||
)
|
||||
|
||||
# List pastes
|
||||
response = client.get("/pastes", headers=auth_header)
|
||||
data = json.loads(response.data)
|
||||
paste = data["pastes"][0]
|
||||
|
||||
# Verify metadata fields
|
||||
assert "id" in paste
|
||||
assert "mime_type" in paste
|
||||
assert "size" in paste
|
||||
assert "created_at" in paste
|
||||
assert "last_accessed" in paste
|
||||
assert "url" in paste
|
||||
assert "raw" in paste
|
||||
|
||||
# Verify content is NOT included
|
||||
assert "content" not in paste
|
||||
|
||||
def test_list_pastes_pagination(self, client, auth_header):
|
||||
"""List pastes supports pagination."""
|
||||
# Create multiple pastes
|
||||
for i in range(5):
|
||||
client.post(
|
||||
"/",
|
||||
data=f"paste {i}",
|
||||
content_type="text/plain",
|
||||
headers=auth_header,
|
||||
)
|
||||
|
||||
# Get first page
|
||||
response = client.get("/pastes?limit=2&offset=0", headers=auth_header)
|
||||
data = json.loads(response.data)
|
||||
assert data["count"] == 2
|
||||
assert data["total"] == 5
|
||||
assert data["limit"] == 2
|
||||
assert data["offset"] == 0
|
||||
|
||||
# Get second page
|
||||
response = client.get("/pastes?limit=2&offset=2", headers=auth_header)
|
||||
data = json.loads(response.data)
|
||||
assert data["count"] == 2
|
||||
assert data["offset"] == 2
|
||||
|
||||
def test_list_pastes_max_limit(self, client, auth_header):
|
||||
"""List pastes enforces maximum limit."""
|
||||
response = client.get("/pastes?limit=500", headers=auth_header)
|
||||
data = json.loads(response.data)
|
||||
assert data["limit"] == 200 # Max limit enforced
|
||||
|
||||
def test_list_pastes_invalid_pagination(self, client, auth_header):
|
||||
"""List pastes handles invalid pagination gracefully."""
|
||||
response = client.get("/pastes?limit=abc&offset=-1", headers=auth_header)
|
||||
assert response.status_code == 200
|
||||
data = json.loads(response.data)
|
||||
# Should use defaults
|
||||
assert data["limit"] == 50
|
||||
assert data["offset"] == 0
|
||||
|
||||
def test_list_pastes_includes_special_fields(self, client, auth_header):
|
||||
"""List pastes includes burn_after_read, expires_at, password_protected."""
|
||||
# Create paste with burn-after-read
|
||||
client.post(
|
||||
"/",
|
||||
data="burn test",
|
||||
content_type="text/plain",
|
||||
headers={**auth_header, "X-Burn-After-Read": "true"},
|
||||
)
|
||||
|
||||
response = client.get("/pastes", headers=auth_header)
|
||||
data = json.loads(response.data)
|
||||
paste = data["pastes"][0]
|
||||
assert paste.get("burn_after_read") is True
|
||||
|
||||
def test_list_pastes_ordered_by_created_at(self, client, auth_header):
|
||||
"""List pastes returns all created pastes ordered by created_at DESC."""
|
||||
# Create pastes
|
||||
ids = set()
|
||||
for i in range(3):
|
||||
create = client.post(
|
||||
"/",
|
||||
data=f"paste {i}",
|
||||
content_type="text/plain",
|
||||
headers=auth_header,
|
||||
)
|
||||
ids.add(json.loads(create.data)["id"])
|
||||
|
||||
response = client.get("/pastes", headers=auth_header)
|
||||
data = json.loads(response.data)
|
||||
|
||||
# All created pastes should be present
|
||||
returned_ids = {p["id"] for p in data["pastes"]}
|
||||
assert returned_ids == ids
|
||||
assert data["count"] == 3
|
||||
|
||||
|
||||
class TestPastesPrivacy:
|
||||
"""Privacy-focused tests for paste listing."""
|
||||
|
||||
def test_cannot_see_other_user_pastes(
|
||||
self, client, sample_text, auth_header, other_auth_header
|
||||
):
|
||||
"""Users cannot see pastes owned by others."""
|
||||
# User A creates paste
|
||||
create = client.post(
|
||||
"/",
|
||||
data=sample_text,
|
||||
content_type="text/plain",
|
||||
headers=auth_header,
|
||||
)
|
||||
paste_id = json.loads(create.data)["id"]
|
||||
|
||||
# User B lists pastes - should not see A's paste
|
||||
response = client.get("/pastes", headers=other_auth_header)
|
||||
data = json.loads(response.data)
|
||||
paste_ids = [p["id"] for p in data["pastes"]]
|
||||
assert paste_id not in paste_ids
|
||||
|
||||
def test_no_admin_bypass(self, client, sample_text, auth_header):
|
||||
"""No special admin access to list all pastes."""
|
||||
# Create paste as regular user
|
||||
client.post(
|
||||
"/",
|
||||
data=sample_text,
|
||||
content_type="text/plain",
|
||||
headers=auth_header,
|
||||
)
|
||||
|
||||
# Different user cannot see it, regardless of auth header format
|
||||
admin_header = {"X-SSL-Client-SHA1": "0" * 40}
|
||||
response = client.get("/pastes", headers=admin_header)
|
||||
data = json.loads(response.data)
|
||||
assert data["count"] == 0
|
||||
|
||||
def test_content_never_exposed(self, client, auth_header):
|
||||
"""Paste content is never exposed in listing."""
|
||||
secret = "super secret content that should never be exposed"
|
||||
client.post(
|
||||
"/",
|
||||
data=secret,
|
||||
content_type="text/plain",
|
||||
headers=auth_header,
|
||||
)
|
||||
|
||||
response = client.get("/pastes", headers=auth_header)
|
||||
# Content should not appear anywhere in response
|
||||
assert secret.encode() not in response.data
|
||||
|
||||
|
||||
class TestPastesSearch:
|
||||
"""Tests for paste search parameters."""
|
||||
|
||||
def test_search_by_type_exact(self, client, auth_header, png_bytes):
|
||||
"""Search pastes by exact MIME type."""
|
||||
# Create text paste
|
||||
client.post("/", data="text content", content_type="text/plain", headers=auth_header)
|
||||
# Create image paste
|
||||
create = client.post("/", data=png_bytes, content_type="image/png", headers=auth_header)
|
||||
png_id = json.loads(create.data)["id"]
|
||||
|
||||
# Search for image/png
|
||||
response = client.get("/pastes?type=image/png", headers=auth_header)
|
||||
data = json.loads(response.data)
|
||||
assert data["count"] == 1
|
||||
assert data["pastes"][0]["id"] == png_id
|
||||
|
||||
def test_search_by_type_glob(self, client, auth_header, png_bytes, jpeg_bytes):
|
||||
"""Search pastes by MIME type glob pattern."""
|
||||
# Create text paste
|
||||
client.post("/", data="text content", content_type="text/plain", headers=auth_header)
|
||||
# Create image pastes
|
||||
client.post("/", data=png_bytes, content_type="image/png", headers=auth_header)
|
||||
client.post("/", data=jpeg_bytes, content_type="image/jpeg", headers=auth_header)
|
||||
|
||||
# Search for all images
|
||||
response = client.get("/pastes?type=image/*", headers=auth_header)
|
||||
data = json.loads(response.data)
|
||||
assert data["count"] == 2
|
||||
for paste in data["pastes"]:
|
||||
assert paste["mime_type"].startswith("image/")
|
||||
|
||||
def test_search_by_after_timestamp(self, client, auth_header):
|
||||
"""Search pastes created after timestamp."""
|
||||
|
||||
# Create paste
|
||||
create = client.post("/", data="test", content_type="text/plain", headers=auth_header)
|
||||
paste_data = json.loads(create.data)
|
||||
created_at = paste_data["created_at"]
|
||||
|
||||
# Search for pastes after creation time (should find it)
|
||||
response = client.get(f"/pastes?after={created_at - 1}", headers=auth_header)
|
||||
data = json.loads(response.data)
|
||||
assert data["count"] == 1
|
||||
|
||||
# Search for pastes after creation time + 1 (should not find it)
|
||||
response = client.get(f"/pastes?after={created_at + 1}", headers=auth_header)
|
||||
data = json.loads(response.data)
|
||||
assert data["count"] == 0
|
||||
|
||||
def test_search_by_before_timestamp(self, client, auth_header):
|
||||
"""Search pastes created before timestamp."""
|
||||
|
||||
# Create paste
|
||||
create = client.post("/", data="test", content_type="text/plain", headers=auth_header)
|
||||
paste_data = json.loads(create.data)
|
||||
created_at = paste_data["created_at"]
|
||||
|
||||
# Search for pastes before creation time + 1 (should find it)
|
||||
response = client.get(f"/pastes?before={created_at + 1}", headers=auth_header)
|
||||
data = json.loads(response.data)
|
||||
assert data["count"] == 1
|
||||
|
||||
# Search for pastes before creation time (should not find it)
|
||||
response = client.get(f"/pastes?before={created_at - 1}", headers=auth_header)
|
||||
data = json.loads(response.data)
|
||||
assert data["count"] == 0
|
||||
|
||||
def test_search_combined_filters(self, client, auth_header, png_bytes):
|
||||
"""Search with multiple filters combined."""
|
||||
|
||||
# Create text paste
|
||||
client.post("/", data="text", content_type="text/plain", headers=auth_header)
|
||||
# Create image paste
|
||||
create = client.post("/", data=png_bytes, content_type="image/png", headers=auth_header)
|
||||
png_data = json.loads(create.data)
|
||||
created_at = png_data["created_at"]
|
||||
|
||||
# Search for images after a certain time
|
||||
response = client.get(
|
||||
f"/pastes?type=image/*&after={created_at - 1}",
|
||||
headers=auth_header,
|
||||
)
|
||||
data = json.loads(response.data)
|
||||
assert data["count"] == 1
|
||||
assert data["pastes"][0]["mime_type"] == "image/png"
|
||||
|
||||
def test_search_no_matches(self, client, auth_header):
|
||||
"""Search with no matching results."""
|
||||
# Create text paste
|
||||
client.post("/", data="text", content_type="text/plain", headers=auth_header)
|
||||
|
||||
# Search for video (no matches)
|
||||
response = client.get("/pastes?type=video/*", headers=auth_header)
|
||||
data = json.loads(response.data)
|
||||
assert data["count"] == 0
|
||||
assert data["pastes"] == []
|
||||
|
||||
def test_search_invalid_timestamp(self, client, auth_header):
|
||||
"""Search with invalid timestamp uses default."""
|
||||
client.post("/", data="test", content_type="text/plain", headers=auth_header)
|
||||
|
||||
# Invalid timestamp should be ignored
|
||||
response = client.get("/pastes?after=invalid", headers=auth_header)
|
||||
assert response.status_code == 200
|
||||
data = json.loads(response.data)
|
||||
assert data["count"] == 1
|
||||
242
tests/test_paste_update.py
Normal file
242
tests/test_paste_update.py
Normal file
@@ -0,0 +1,242 @@
|
||||
"""Tests for paste update endpoint (PUT /<id>)."""
|
||||
|
||||
import json
|
||||
|
||||
|
||||
class TestPasteUpdateEndpoint:
|
||||
"""Tests for PUT /<id> endpoint."""
|
||||
|
||||
def test_update_requires_auth(self, client, sample_text, auth_header):
|
||||
"""Update requires authentication."""
|
||||
# Create paste
|
||||
create = client.post("/", data=sample_text, content_type="text/plain", headers=auth_header)
|
||||
paste_id = json.loads(create.data)["id"]
|
||||
|
||||
# Try to update without auth
|
||||
response = client.put(f"/{paste_id}", data="updated")
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_update_requires_ownership(self, client, sample_text, auth_header, other_auth_header):
|
||||
"""Update requires paste ownership."""
|
||||
# Create paste as user A
|
||||
create = client.post("/", data=sample_text, content_type="text/plain", headers=auth_header)
|
||||
paste_id = json.loads(create.data)["id"]
|
||||
|
||||
# Try to update as user B
|
||||
response = client.put(
|
||||
f"/{paste_id}",
|
||||
data="updated content",
|
||||
content_type="text/plain",
|
||||
headers=other_auth_header,
|
||||
)
|
||||
assert response.status_code == 403
|
||||
|
||||
def test_update_content(self, client, auth_header):
|
||||
"""Update paste content."""
|
||||
# Create paste
|
||||
create = client.post("/", data="original", content_type="text/plain", headers=auth_header)
|
||||
paste_id = json.loads(create.data)["id"]
|
||||
|
||||
# Update content
|
||||
response = client.put(
|
||||
f"/{paste_id}",
|
||||
data="updated content",
|
||||
content_type="text/plain",
|
||||
headers=auth_header,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = json.loads(response.data)
|
||||
assert data["size"] == len("updated content")
|
||||
|
||||
# Verify content changed
|
||||
raw = client.get(f"/{paste_id}/raw")
|
||||
assert raw.data == b"updated content"
|
||||
|
||||
def test_update_password_set(self, client, auth_header):
|
||||
"""Set password on paste."""
|
||||
# Create paste without password
|
||||
create = client.post("/", data="content", content_type="text/plain", headers=auth_header)
|
||||
paste_id = json.loads(create.data)["id"]
|
||||
|
||||
# Add password
|
||||
response = client.put(
|
||||
f"/{paste_id}",
|
||||
data="",
|
||||
headers={**auth_header, "X-Paste-Password": "secret123"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = json.loads(response.data)
|
||||
assert data.get("password_protected") is True
|
||||
|
||||
# Verify password required
|
||||
raw = client.get(f"/{paste_id}/raw")
|
||||
assert raw.status_code == 401
|
||||
|
||||
# Verify correct password works
|
||||
raw = client.get(f"/{paste_id}/raw", headers={"X-Paste-Password": "secret123"})
|
||||
assert raw.status_code == 200
|
||||
|
||||
def test_update_password_remove(self, client, auth_header):
|
||||
"""Remove password from paste."""
|
||||
# Create paste with password
|
||||
create = client.post(
|
||||
"/",
|
||||
data="content",
|
||||
content_type="text/plain",
|
||||
headers={**auth_header, "X-Paste-Password": "secret123"},
|
||||
)
|
||||
paste_id = json.loads(create.data)["id"]
|
||||
|
||||
# Remove password
|
||||
response = client.put(
|
||||
f"/{paste_id}",
|
||||
data="",
|
||||
headers={**auth_header, "X-Remove-Password": "true"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = json.loads(response.data)
|
||||
assert data.get("password_protected") is not True
|
||||
|
||||
# Verify no password required
|
||||
raw = client.get(f"/{paste_id}/raw")
|
||||
assert raw.status_code == 200
|
||||
|
||||
def test_update_extend_expiry(self, client, auth_header):
|
||||
"""Extend paste expiry."""
|
||||
# Create paste with expiry
|
||||
create = client.post(
|
||||
"/",
|
||||
data="content",
|
||||
content_type="text/plain",
|
||||
headers={**auth_header, "X-Expiry": "3600"},
|
||||
)
|
||||
paste_id = json.loads(create.data)["id"]
|
||||
original_expiry = json.loads(create.data)["expires_at"]
|
||||
|
||||
# Extend expiry
|
||||
response = client.put(
|
||||
f"/{paste_id}",
|
||||
data="",
|
||||
headers={**auth_header, "X-Extend-Expiry": "7200"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = json.loads(response.data)
|
||||
assert data["expires_at"] == original_expiry + 7200
|
||||
|
||||
def test_update_add_expiry(self, client, auth_header):
|
||||
"""Add expiry to paste without one."""
|
||||
# Create paste without expiry
|
||||
create = client.post("/", data="content", content_type="text/plain", headers=auth_header)
|
||||
paste_id = json.loads(create.data)["id"]
|
||||
|
||||
# Add expiry
|
||||
response = client.put(
|
||||
f"/{paste_id}",
|
||||
data="",
|
||||
headers={**auth_header, "X-Extend-Expiry": "3600"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = json.loads(response.data)
|
||||
assert "expires_at" in data
|
||||
|
||||
def test_update_burn_after_read_forbidden(self, client, auth_header):
|
||||
"""Cannot update burn-after-read pastes."""
|
||||
# Create burn-after-read paste
|
||||
create = client.post(
|
||||
"/",
|
||||
data="content",
|
||||
content_type="text/plain",
|
||||
headers={**auth_header, "X-Burn-After-Read": "true"},
|
||||
)
|
||||
paste_id = json.loads(create.data)["id"]
|
||||
|
||||
# Try to update
|
||||
response = client.put(
|
||||
f"/{paste_id}",
|
||||
data="updated",
|
||||
content_type="text/plain",
|
||||
headers=auth_header,
|
||||
)
|
||||
assert response.status_code == 400
|
||||
data = json.loads(response.data)
|
||||
assert "burn" in data["error"].lower()
|
||||
|
||||
def test_update_not_found(self, client, auth_header):
|
||||
"""Update non-existent paste returns 404."""
|
||||
response = client.put(
|
||||
"/000000000000",
|
||||
data="updated",
|
||||
content_type="text/plain",
|
||||
headers=auth_header,
|
||||
)
|
||||
assert response.status_code == 404
|
||||
|
||||
def test_update_no_changes(self, client, auth_header):
|
||||
"""Update with no changes returns error."""
|
||||
# Create paste
|
||||
create = client.post("/", data="content", content_type="text/plain", headers=auth_header)
|
||||
paste_id = json.loads(create.data)["id"]
|
||||
|
||||
# Try to update with empty request
|
||||
response = client.put(f"/{paste_id}", data="", headers=auth_header)
|
||||
assert response.status_code == 400
|
||||
data = json.loads(response.data)
|
||||
assert "no updates" in data["error"].lower()
|
||||
|
||||
def test_update_combined(self, client, auth_header):
|
||||
"""Update content and metadata together."""
|
||||
# Create paste
|
||||
create = client.post("/", data="original", content_type="text/plain", headers=auth_header)
|
||||
paste_id = json.loads(create.data)["id"]
|
||||
|
||||
# Update content and add password
|
||||
response = client.put(
|
||||
f"/{paste_id}",
|
||||
data="new content",
|
||||
content_type="text/plain",
|
||||
headers={**auth_header, "X-Paste-Password": "secret"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = json.loads(response.data)
|
||||
assert data["size"] == len("new content")
|
||||
assert data.get("password_protected") is True
|
||||
|
||||
|
||||
class TestPasteUpdatePrivacy:
|
||||
"""Privacy-focused tests for paste update."""
|
||||
|
||||
def test_cannot_update_anonymous_paste(self, client, sample_text, auth_header):
|
||||
"""Cannot update paste without owner."""
|
||||
# Create anonymous paste
|
||||
create = client.post("/", data=sample_text, content_type="text/plain")
|
||||
paste_id = json.loads(create.data)["id"]
|
||||
|
||||
# Try to update
|
||||
response = client.put(
|
||||
f"/{paste_id}",
|
||||
data="updated",
|
||||
content_type="text/plain",
|
||||
headers=auth_header,
|
||||
)
|
||||
assert response.status_code == 403
|
||||
|
||||
def test_cannot_update_other_user_paste(
|
||||
self, client, sample_text, auth_header, other_auth_header
|
||||
):
|
||||
"""Cannot update paste owned by another user."""
|
||||
# Create paste as user A
|
||||
create = client.post("/", data=sample_text, content_type="text/plain", headers=auth_header)
|
||||
paste_id = json.loads(create.data)["id"]
|
||||
|
||||
# Try to update as user B
|
||||
response = client.put(
|
||||
f"/{paste_id}",
|
||||
data="hijacked",
|
||||
content_type="text/plain",
|
||||
headers=other_auth_header,
|
||||
)
|
||||
assert response.status_code == 403
|
||||
|
||||
# Verify original content unchanged
|
||||
raw = client.get(f"/{paste_id}/raw")
|
||||
assert raw.data == sample_text.encode()
|
||||
168
tests/test_rate_limiting.py
Normal file
168
tests/test_rate_limiting.py
Normal file
@@ -0,0 +1,168 @@
|
||||
"""Tests for IP-based rate limiting."""
|
||||
|
||||
import json
|
||||
|
||||
|
||||
class TestRateLimiting:
|
||||
"""Tests for rate limiting on paste creation."""
|
||||
|
||||
def test_rate_limit_allows_normal_usage(self, client, sample_text):
|
||||
"""Normal usage within rate limit succeeds."""
|
||||
# TestingConfig has RATE_LIMIT_MAX=100
|
||||
for i in range(5):
|
||||
response = client.post(
|
||||
"/",
|
||||
data=f"paste {i}",
|
||||
content_type="text/plain",
|
||||
)
|
||||
assert response.status_code == 201
|
||||
|
||||
def test_rate_limit_exceeded_returns_429(self, client, app):
|
||||
"""Exceeding rate limit returns 429."""
|
||||
# Temporarily lower rate limit for test
|
||||
original_max = app.config["RATE_LIMIT_MAX"]
|
||||
app.config["RATE_LIMIT_MAX"] = 3
|
||||
|
||||
try:
|
||||
# Make requests up to limit
|
||||
for i in range(3):
|
||||
response = client.post(
|
||||
"/",
|
||||
data=f"paste {i}",
|
||||
content_type="text/plain",
|
||||
)
|
||||
assert response.status_code == 201
|
||||
|
||||
# Next request should be rate limited
|
||||
response = client.post(
|
||||
"/",
|
||||
data="one more",
|
||||
content_type="text/plain",
|
||||
)
|
||||
assert response.status_code == 429
|
||||
data = json.loads(response.data)
|
||||
assert "error" in data
|
||||
assert "Rate limit" in data["error"]
|
||||
assert "retry_after" in data
|
||||
finally:
|
||||
app.config["RATE_LIMIT_MAX"] = original_max
|
||||
|
||||
def test_rate_limit_headers(self, client, app):
|
||||
"""Rate limit response includes proper headers."""
|
||||
original_max = app.config["RATE_LIMIT_MAX"]
|
||||
app.config["RATE_LIMIT_MAX"] = 1
|
||||
|
||||
try:
|
||||
# First request succeeds
|
||||
client.post("/", data="first", content_type="text/plain")
|
||||
|
||||
# Second request is rate limited
|
||||
response = client.post("/", data="second", content_type="text/plain")
|
||||
assert response.status_code == 429
|
||||
assert "Retry-After" in response.headers
|
||||
assert "X-RateLimit-Remaining" in response.headers
|
||||
assert response.headers["X-RateLimit-Remaining"] == "0"
|
||||
finally:
|
||||
app.config["RATE_LIMIT_MAX"] = original_max
|
||||
|
||||
def test_rate_limit_auth_multiplier(self, client, app, auth_header):
|
||||
"""Authenticated users get higher rate limits."""
|
||||
original_max = app.config["RATE_LIMIT_MAX"]
|
||||
original_mult = app.config["RATE_LIMIT_AUTH_MULTIPLIER"]
|
||||
app.config["RATE_LIMIT_MAX"] = 2
|
||||
app.config["RATE_LIMIT_AUTH_MULTIPLIER"] = 3 # 2 * 3 = 6 for auth users
|
||||
|
||||
try:
|
||||
# Authenticated user can make more requests than base limit
|
||||
for i in range(5):
|
||||
response = client.post(
|
||||
"/",
|
||||
data=f"auth {i}",
|
||||
content_type="text/plain",
|
||||
headers=auth_header,
|
||||
)
|
||||
assert response.status_code == 201
|
||||
|
||||
# 6th request should succeed (limit is 2*3=6)
|
||||
response = client.post(
|
||||
"/",
|
||||
data="auth 6",
|
||||
content_type="text/plain",
|
||||
headers=auth_header,
|
||||
)
|
||||
assert response.status_code == 201
|
||||
|
||||
# 7th should fail
|
||||
response = client.post(
|
||||
"/",
|
||||
data="auth 7",
|
||||
content_type="text/plain",
|
||||
headers=auth_header,
|
||||
)
|
||||
assert response.status_code == 429
|
||||
finally:
|
||||
app.config["RATE_LIMIT_MAX"] = original_max
|
||||
app.config["RATE_LIMIT_AUTH_MULTIPLIER"] = original_mult
|
||||
|
||||
def test_rate_limit_can_be_disabled(self, client, app):
|
||||
"""Rate limiting can be disabled via config."""
|
||||
original_enabled = app.config["RATE_LIMIT_ENABLED"]
|
||||
original_max = app.config["RATE_LIMIT_MAX"]
|
||||
app.config["RATE_LIMIT_ENABLED"] = False
|
||||
app.config["RATE_LIMIT_MAX"] = 1
|
||||
|
||||
try:
|
||||
# Should be able to make many requests
|
||||
for i in range(5):
|
||||
response = client.post(
|
||||
"/",
|
||||
data=f"paste {i}",
|
||||
content_type="text/plain",
|
||||
)
|
||||
assert response.status_code == 201
|
||||
finally:
|
||||
app.config["RATE_LIMIT_ENABLED"] = original_enabled
|
||||
app.config["RATE_LIMIT_MAX"] = original_max
|
||||
|
||||
def test_rate_limit_only_affects_paste_creation(self, client, app, sample_text):
|
||||
"""Rate limiting only affects POST /, not GET endpoints."""
|
||||
original_max = app.config["RATE_LIMIT_MAX"]
|
||||
app.config["RATE_LIMIT_MAX"] = 2
|
||||
|
||||
try:
|
||||
# Create paste
|
||||
create = client.post("/", data=sample_text, content_type="text/plain")
|
||||
assert create.status_code == 201
|
||||
paste_id = json.loads(create.data)["id"]
|
||||
|
||||
# Use up rate limit
|
||||
client.post("/", data="second", content_type="text/plain")
|
||||
|
||||
# Should be rate limited for creation
|
||||
response = client.post("/", data="third", content_type="text/plain")
|
||||
assert response.status_code == 429
|
||||
|
||||
# But GET should still work
|
||||
response = client.get(f"/{paste_id}")
|
||||
assert response.status_code == 200
|
||||
|
||||
response = client.get(f"/{paste_id}/raw")
|
||||
assert response.status_code == 200
|
||||
|
||||
response = client.get("/health")
|
||||
assert response.status_code == 200
|
||||
finally:
|
||||
app.config["RATE_LIMIT_MAX"] = original_max
|
||||
|
||||
|
||||
class TestRateLimitCleanup:
|
||||
"""Tests for rate limit cleanup."""
|
||||
|
||||
def test_rate_limit_window_expiry(self, app):
|
||||
"""Rate limit cleanup works with explicit window."""
|
||||
from app.api.routes import cleanup_rate_limits
|
||||
|
||||
# Should work without app context when window is explicit
|
||||
cleaned = cleanup_rate_limits(window=60)
|
||||
assert isinstance(cleaned, int)
|
||||
assert cleaned >= 0
|
||||
169
tests/test_scheduled_cleanup.py
Normal file
169
tests/test_scheduled_cleanup.py
Normal file
@@ -0,0 +1,169 @@
|
||||
"""Tests for scheduled cleanup functionality."""
|
||||
|
||||
import time
|
||||
|
||||
|
||||
class TestScheduledCleanup:
|
||||
"""Tests for scheduled cleanup in before_request hook."""
|
||||
|
||||
def test_cleanup_times_reset(self, app, client):
|
||||
"""Verify cleanup times can be reset for testing."""
|
||||
from app.api import _cleanup_times, reset_cleanup_times
|
||||
|
||||
# Set some cleanup times
|
||||
with app.app_context():
|
||||
reset_cleanup_times()
|
||||
for key in _cleanup_times:
|
||||
assert _cleanup_times[key] == 0
|
||||
|
||||
def test_cleanup_runs_on_request(self, app, client, auth_header):
|
||||
"""Verify cleanup runs during request handling."""
|
||||
from app.api import _cleanup_times, reset_cleanup_times
|
||||
|
||||
with app.app_context():
|
||||
reset_cleanup_times()
|
||||
|
||||
# Make a request to trigger cleanup
|
||||
client.get("/health")
|
||||
|
||||
# Cleanup times should be set
|
||||
with app.app_context():
|
||||
for key in _cleanup_times:
|
||||
assert _cleanup_times[key] > 0
|
||||
|
||||
def test_cleanup_respects_intervals(self, app, client):
|
||||
"""Verify cleanup respects configured intervals."""
|
||||
from app.api import _cleanup_times, reset_cleanup_times
|
||||
|
||||
with app.app_context():
|
||||
reset_cleanup_times()
|
||||
|
||||
# First request triggers cleanup
|
||||
client.get("/health")
|
||||
|
||||
with app.app_context():
|
||||
first_times = {k: v for k, v in _cleanup_times.items()}
|
||||
|
||||
# Immediate second request should not reset times
|
||||
client.get("/health")
|
||||
|
||||
with app.app_context():
|
||||
for key in _cleanup_times:
|
||||
assert _cleanup_times[key] == first_times[key]
|
||||
|
||||
def test_expired_paste_cleanup(self, app, client, auth_header):
|
||||
"""Test that expired pastes are cleaned up."""
|
||||
import json
|
||||
|
||||
from app.database import get_db
|
||||
|
||||
# Create paste with short expiry
|
||||
response = client.post(
|
||||
"/",
|
||||
data="test content",
|
||||
content_type="text/plain",
|
||||
headers={**auth_header, "X-Expiry": "1"}, # 1 second
|
||||
)
|
||||
assert response.status_code == 201
|
||||
paste_id = json.loads(response.data)["id"]
|
||||
|
||||
# Wait for expiry
|
||||
time.sleep(1.5)
|
||||
|
||||
# Trigger cleanup via database function
|
||||
with app.app_context():
|
||||
from app.database import cleanup_expired_pastes
|
||||
|
||||
count = cleanup_expired_pastes()
|
||||
assert count >= 1
|
||||
|
||||
# Verify paste is gone
|
||||
db = get_db()
|
||||
row = db.execute("SELECT id FROM pastes WHERE id = ?", (paste_id,)).fetchone()
|
||||
assert row is None
|
||||
|
||||
def test_rate_limit_cleanup(self, app, client):
|
||||
"""Test that rate limit entries are cleaned up."""
|
||||
from app.api.routes import (
|
||||
_rate_limit_requests,
|
||||
check_rate_limit,
|
||||
cleanup_rate_limits,
|
||||
reset_rate_limits,
|
||||
)
|
||||
|
||||
with app.app_context():
|
||||
reset_rate_limits()
|
||||
|
||||
# Add some rate limit entries
|
||||
check_rate_limit("192.168.1.1", authenticated=False)
|
||||
check_rate_limit("192.168.1.2", authenticated=False)
|
||||
|
||||
assert len(_rate_limit_requests) == 2
|
||||
|
||||
# Cleanup with very large window (nothing should be removed)
|
||||
count = cleanup_rate_limits(window=3600)
|
||||
assert count == 0
|
||||
|
||||
# Wait a bit and cleanup with tiny window
|
||||
time.sleep(0.1)
|
||||
count = cleanup_rate_limits(window=0) # Immediate cleanup
|
||||
assert count == 2
|
||||
assert len(_rate_limit_requests) == 0
|
||||
|
||||
|
||||
class TestCleanupThreadSafety:
|
||||
"""Tests for thread-safety of cleanup operations."""
|
||||
|
||||
def test_cleanup_lock_exists(self, app):
|
||||
"""Verify cleanup lock exists."""
|
||||
import threading
|
||||
|
||||
from app.api import _cleanup_lock
|
||||
|
||||
assert isinstance(_cleanup_lock, type(threading.Lock()))
|
||||
|
||||
def test_concurrent_cleanup_access(self, app, client):
|
||||
"""Test that concurrent requests don't corrupt cleanup state."""
|
||||
import threading
|
||||
|
||||
from app.api import reset_cleanup_times
|
||||
|
||||
with app.app_context():
|
||||
reset_cleanup_times()
|
||||
|
||||
errors = []
|
||||
results = []
|
||||
|
||||
def make_request():
|
||||
try:
|
||||
resp = client.get("/health")
|
||||
results.append(resp.status_code)
|
||||
except Exception as e:
|
||||
errors.append(str(e))
|
||||
|
||||
# Simulate concurrent requests
|
||||
threads = [threading.Thread(target=make_request) for _ in range(10)]
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
assert not errors
|
||||
assert all(r == 200 for r in results)
|
||||
|
||||
|
||||
class TestCleanupConfiguration:
|
||||
"""Tests for cleanup configuration."""
|
||||
|
||||
def test_cleanup_intervals_configured(self, app):
|
||||
"""Verify cleanup intervals are properly configured."""
|
||||
from app.api import _CLEANUP_INTERVALS
|
||||
|
||||
assert "pastes" in _CLEANUP_INTERVALS
|
||||
assert "hashes" in _CLEANUP_INTERVALS
|
||||
assert "rate_limits" in _CLEANUP_INTERVALS
|
||||
|
||||
# Verify reasonable intervals
|
||||
assert _CLEANUP_INTERVALS["pastes"] >= 60 # At least 1 minute
|
||||
assert _CLEANUP_INTERVALS["hashes"] >= 60
|
||||
assert _CLEANUP_INTERVALS["rate_limits"] >= 60
|
||||
Reference in New Issue
Block a user