forked from username/flaskpaste
331 lines
10 KiB
Python
331 lines
10 KiB
Python
"""Flask application factory."""
|
|
|
|
import json
import logging
import os
import re
import sys
import time
import uuid

from flask import Flask, Response, g, request
from werkzeug.exceptions import HTTPException

from app.config import VERSION, config
|
|
|
|
|
|
def setup_logging(app: Flask) -> None:
|
|
"""Configure structured logging."""
|
|
log_level = logging.DEBUG if app.debug else logging.INFO
|
|
if app.debug:
|
|
log_format = "%(asctime)s %(levelname)s [%(name)s] %(message)s"
|
|
else:
|
|
log_format = (
|
|
'{"time":"%(asctime)s","level":"%(levelname)s",'
|
|
'"logger":"%(name)s","message":"%(message)s"}'
|
|
)
|
|
|
|
logging.basicConfig(
|
|
level=log_level,
|
|
format=log_format,
|
|
stream=sys.stdout,
|
|
)
|
|
|
|
# Reduce noise from werkzeug in production
|
|
if not app.debug:
|
|
logging.getLogger("werkzeug").setLevel(logging.WARNING)
|
|
|
|
app.logger.info("FlaskPaste starting", extra={"config": type(app.config).__name__})
|
|
|
|
|
|
def validate_security_config(app: Flask) -> None:
    """Log warnings for insecure proxy/PKI settings.

    Both checks concern ``TRUSTED_PROXY_SECRET``; when it is set (truthy) there
    is nothing to report. Otherwise:

    * outside debug and ``TESTING`` mode, warn that client-certificate headers
      can be spoofed;
    * whenever ``PKI_ENABLED`` is truthy, warn that certificate-based
      authentication can be bypassed.

    Nothing is raised; findings are only logged through ``app.logger.warning``.
    """
    if app.config.get("TRUSTED_PROXY_SECRET", ""):
        return

    # PROXY-001: without a shared secret, proxy-supplied cert headers are untrusted.
    running_in_production = not (app.debug or app.config.get("TESTING"))
    if running_in_production:
        app.logger.warning(
            "SECURITY WARNING: TRUSTED_PROXY_SECRET is not set. "
            "Client certificate headers (X-SSL-Client-SHA1) can be spoofed. "
            "Set FLASKPASTE_PROXY_SECRET to a shared secret known only by your reverse proxy."
        )

    # PKI relies on those same headers, so flag it in every mode.
    if app.config.get("PKI_ENABLED", False):
        app.logger.warning(
            "SECURITY WARNING: PKI is enabled but TRUSTED_PROXY_SECRET is not set. "
            "Certificate-based authentication can be bypassed by spoofing headers."
        )
|
|
|
|
|
|
def setup_security_headers(app: Flask) -> None:
    """Register an ``after_request`` hook that stamps hardening headers on every response.

    The header set follows OWASP recommendations for API security. Each header is
    assigned unconditionally, replacing any value a view may already have set.
    """
    hardening_headers = (
        # Stop browsers from MIME-sniffing the body into another content type.
        ("X-Content-Type-Options", "nosniff"),
        # Forbid rendering inside frames (clickjacking).
        ("X-Frame-Options", "DENY"),
        ("Referrer-Policy", "strict-origin-when-cross-origin"),
        # An API serves no active content, so deny every fetch directive.
        ("Content-Security-Policy", "default-src 'none'; frame-ancestors 'none'"),
        ("Permissions-Policy", "geolocation=(), microphone=(), camera=()"),
        # HSTS for one year, covering subdomains.
        ("Strict-Transport-Security", "max-age=31536000; includeSubDomains"),
        # Paste bodies may be sensitive: keep them out of every cache.
        ("Cache-Control", "no-store, no-cache, must-revalidate, private"),
        ("Pragma", "no-cache"),
    )

    @app.after_request
    def add_security_headers(response: Response) -> Response:
        """Copy every hardening header onto ``response`` and return it."""
        for header_name, header_value in hardening_headers:
            response.headers[header_name] = header_value
        return response
|
|
|
|
|
|
def setup_request_id(app: Flask) -> None:
|
|
"""Add request ID tracking for log correlation and tracing."""
|
|
|
|
@app.before_request
|
|
def assign_request_id() -> None:
|
|
"""Assign unique request ID from header or generate new UUID."""
|
|
request_id = request.headers.get("X-Request-ID", "").strip()
|
|
if not request_id:
|
|
request_id = str(uuid.uuid4())
|
|
g.request_id = request_id
|
|
|
|
@app.after_request
|
|
def add_request_id_header(response: Response) -> Response:
|
|
"""Echo request ID in response header and log access."""
|
|
request_id = getattr(g, "request_id", None)
|
|
if request_id:
|
|
response.headers["X-Request-ID"] = request_id
|
|
|
|
# Access logging with request ID
|
|
app.logger.info(
|
|
"%s %s %s [rid=%s]",
|
|
request.method,
|
|
request.path,
|
|
response.status_code,
|
|
request_id or "-",
|
|
)
|
|
return response
|
|
|
|
|
|
def setup_request_metrics(app: Flask) -> None:
    """Record per-request latency via ``app.metrics.observe_request_duration``.

    Durations are measured with :func:`time.perf_counter`, a monotonic clock.
    System clock adjustments therefore cannot produce negative or inflated
    samples, which was possible with wall-clock ``time.time()`` differences.
    """
    from app.metrics import observe_request_duration

    @app.before_request
    def record_request_start() -> None:
        """Stamp the request start times on ``g`` for duration metrics."""
        # Wall-clock start kept under its original name in case other code reads it.
        g.request_start_time = time.time()
        # Monotonic start actually used for the duration calculation.
        g.request_start_perf = time.perf_counter()

    @app.after_request
    def record_request_duration(response: Response) -> Response:
        """Report elapsed time for this request, if its start was recorded."""
        start = getattr(g, "request_start_perf", None)
        if start is not None:
            observe_request_duration(
                method=request.method,
                # NOTE(review): if observe_request_duration uses `endpoint` as a
                # metric label, raw paths (e.g. ones containing paste IDs) give
                # unbounded label cardinality — confirm it normalizes them.
                endpoint=request.path,
                status=response.status_code,
                duration=time.perf_counter() - start,
            )
        return response
|
|
|
|
|
|
def setup_error_handlers(app: Flask) -> None:
    """Register global error handlers that answer with JSON bodies.

    Every response body is a JSON object with an ``error`` key, sent with the
    ``application/json`` mimetype. Headers contributed by the raised HTTP
    exception are preserved, for example ``Allow`` on 405, which HTTP requires.
    Only the exception's ``Content-Type`` is dropped, because it describes
    Werkzeug's default HTML body.
    """
    import json

    def _json_response(
        payload: dict[str, object], status: int, error: BaseException | None = None
    ) -> Response:
        """Serialize ``payload`` into a JSON response, keeping headers from ``error``."""
        headers: list[tuple[str, str]] = []
        get_headers = getattr(error, "get_headers", None)
        if callable(get_headers):
            headers = [
                (name, value)
                for name, value in get_headers()
                if name.lower() != "content-type"
            ]
        return Response(
            json.dumps(payload),
            status=status,
            mimetype="application/json",
            headers=headers,
        )

    @app.errorhandler(400)
    def bad_request(error: HTTPException) -> Response:
        """Handle 400 Bad Request errors."""
        app.logger.warning("Bad request: %s [rid=%s]", request.path, getattr(g, "request_id", "-"))
        return _json_response({"error": "Bad request"}, 400, error)

    @app.errorhandler(404)
    def not_found(error: HTTPException) -> Response:
        """Handle 404 Not Found errors."""
        return _json_response({"error": "Not found"}, 404, error)

    @app.errorhandler(429)
    def rate_limit_exceeded(error: HTTPException) -> Response:
        """Handle 429 Too Many Requests errors."""
        app.logger.warning(
            "Rate limit exceeded: %s from %s [rid=%s]",
            request.path,
            request.remote_addr,
            getattr(g, "request_id", "-"),
        )
        # NOTE(review): with Flask-Limiter, `description` is typically the breached
        # limit (e.g. "60 per 1 hour"), not a number of seconds — confirm before
        # clients rely on "retry_after".
        return _json_response(
            {"error": "Rate limit exceeded", "retry_after": error.description}, 429, error
        )

    @app.errorhandler(405)
    def method_not_allowed(error: HTTPException) -> Response:
        """Handle 405 Method Not Allowed errors, keeping the ``Allow`` header."""
        return _json_response({"error": "Method not allowed"}, 405, error)

    @app.errorhandler(500)
    def internal_error(error: HTTPException) -> Response:
        """Handle 500 Internal Server errors."""
        app.logger.error(
            "Internal error: %s - %s [rid=%s]",
            request.path,
            str(error),
            getattr(g, "request_id", "-"),
        )
        return _json_response({"error": "Internal server error"}, 500, error)

    @app.errorhandler(Exception)
    def handle_exception(error: Exception) -> Response:
        """Turn any remaining exception into a JSON error response.

        Flask looks up status-code handlers before class-based ones. HTTP errors
        that reach this handler are therefore the codes without a dedicated handler
        above (e.g. 401, 403, 413). They keep their status code and headers but get
        a JSON body instead of Werkzeug's HTML page. Any other exception is logged
        with its traceback and reported as a generic 500.
        """
        if isinstance(error, HTTPException):
            if error.code is None or error.response is not None:
                # Carries its own ready-made response (e.g. abort(response)): pass through.
                return error  # type: ignore[return-value]
            return _json_response({"error": error.name}, error.code, error)

        app.logger.exception(
            "Unhandled exception: %s [rid=%s]", str(error), getattr(g, "request_id", "-")
        )
        return _json_response({"error": "Internal server error"}, 500)
|
|
|
|
|
|
def setup_rate_limiting(app: Flask) -> None:
    """Install Flask-Limiter with global default limits keyed by client address.

    Defaults are 200 requests per day and 60 per hour, tracked in process memory
    (``memory://``) with a fixed-window strategy. The health endpoint is exempt
    from these defaults. So is any request for which ``get_client_id()`` returns
    a client; those are trusted certificate holders, per the original comments.
    The limiter is stored in ``app.extensions["limiter"]`` for use in routes.
    """
    from flask_limiter import Limiter
    from flask_limiter.util import get_remote_address

    def is_rate_limit_exempt() -> bool:
        """Return True when the current request should bypass the default limits."""
        url_prefix = app.config.get("URL_PREFIX", "")
        if request.path == f"{url_prefix or ''}/health":
            return True

        # Any failure while resolving the client identity counts as "not exempt",
        # so limiting stays on by default.
        try:
            from app.api.routes import get_client_id

            client_id = get_client_id()
        except Exception:
            return False
        return client_id is not None

    # NOTE(review): memory:// counters live in each process, so with several
    # workers the effective limits multiply — confirm this matches deployment.
    limiter = Limiter(
        key_func=get_remote_address,
        app=app,
        default_limits=["200 per day", "60 per hour"],
        default_limits_exempt_when=is_rate_limit_exempt,
        strategy="fixed-window",
        storage_uri="memory://",
    )

    app.extensions["limiter"] = limiter
|
|
|
|
|
|
def setup_metrics(app: Flask) -> None:
    """Configure Prometheus metrics via ``prometheus_flask_exporter``.

    Skipped entirely when ``TESTING`` is set; debug apps still get metrics. The
    exporter is an optional dependency. If it is not installed, a warning is
    logged and the app runs without metrics. Errors raised by the project's own
    metrics code are not caught.
    """
    # Skip under TESTING only.
    if app.config.get("TESTING"):
        return

    # Keep the try narrow: only the optional third-party import is best-effort.
    # Previously an ImportError from app.metrics was misreported as a missing
    # exporter, leaving extensions["metrics"] half set up.
    try:
        from prometheus_flask_exporter import PrometheusMetrics
    except ImportError:
        app.logger.warning("prometheus_flask_exporter not available, metrics disabled")
        return

    metrics = PrometheusMetrics(app)

    # Info metric labelled with the application version.
    metrics.info("flaskpaste_info", "FlaskPaste application info", version=VERSION)

    app.extensions["metrics"] = metrics

    # Setup custom metrics
    from app.metrics import setup_custom_metrics

    setup_custom_metrics(app)
|
|
|
|
|
|
def create_app(config_name: str | None = None) -> Flask:
    """Build and return a fully configured FlaskPaste application.

    Args:
        config_name: Key looked up in the ``config`` mapping imported from
            ``app.config``. When ``None``, it is read from the ``FLASK_ENV``
            environment variable, defaulting to ``"default"``.

    Returns:
        The configured :class:`flask.Flask` instance, with logging, request IDs,
        security headers, error handlers, rate limiting and metrics (the last two
        skipped under ``TESTING``), the database and the API blueprint set up.

    Raises:
        Whatever ``config[name]`` raises for an unknown name (presumably
        ``KeyError`` if ``config`` is a dict — TODO confirm).
    """
    name = os.environ.get("FLASK_ENV", "default") if config_name is None else config_name

    app = Flask(__name__)
    app.config.from_object(config[name])

    # Logging comes first so every later step can log.
    setup_logging(app)
    validate_security_config(app)

    # Registration order is kept as-is: Flask runs after_request hooks in
    # reverse registration order.
    setup_request_id(app)
    setup_security_headers(app)
    setup_error_handlers(app)

    testing = app.config.get("TESTING")
    if not testing:
        setup_rate_limiting(app)
    # setup_metrics performs its own TESTING check.
    setup_metrics(app)
    if not testing:
        setup_request_metrics(app)

    # Local imports, presumably to avoid circular imports with the app package.
    from app import database

    database.init_app(app)

    from app.api import bp as api_bp

    app.register_blueprint(api_bp)

    app.logger.info("FlaskPaste initialized successfully")
    return app
|