"""Flask application factory."""

import json
import logging
import os
import sys
import time
import uuid

from flask import Flask, Response, g, request
from werkzeug.exceptions import HTTPException

from app.config import VERSION, config

# Upper bound on a client-supplied X-Request-ID. Longer values are discarded
# and replaced with a generated UUID so a caller cannot bloat logs or headers.
_MAX_REQUEST_ID_LENGTH = 128

# Headers applied to every response by ``setup_security_headers``.
_SECURITY_HEADERS = {
    # Prevent MIME type sniffing
    "X-Content-Type-Options": "nosniff",
    # Prevent clickjacking
    "X-Frame-Options": "DENY",
    # Referrer policy
    "Referrer-Policy": "strict-origin-when-cross-origin",
    # Content Security Policy (restrictive for API)
    "Content-Security-Policy": "default-src 'none'; frame-ancestors 'none'",
    # Permissions policy
    "Permissions-Policy": "geolocation=(), microphone=(), camera=()",
    # HSTS - enforce HTTPS (1 year, include subdomains)
    "Strict-Transport-Security": "max-age=31536000; includeSubDomains",
    # Prevent caching of sensitive paste data
    "Cache-Control": "no-store, no-cache, must-revalidate, private",
    "Pragma": "no-cache",
}


class _JsonLogFormatter(logging.Formatter):
    """Render each log record as a single valid JSON object.

    A plain ``%``-style format string cannot escape quotes, backslashes or
    newlines inside the message, which produces unparseable output. Going
    through ``json.dumps`` guarantees well-formed JSON while keeping the same
    keys, key order and compact separators as the previous template.
    """

    def format(self, record: logging.LogRecord) -> str:
        """Return *record* serialized as a compact JSON string."""
        entry = {
            "time": self.formatTime(record, self.datefmt),
            "level": record.levelname,
            "logger": record.name,
            "message": record.getMessage(),
        }
        # Keep tracebacks inside the JSON object instead of appending raw
        # multi-line text after it.
        if record.exc_info:
            entry["exception"] = self.formatException(record.exc_info)
        if record.stack_info:
            entry["stack"] = self.formatStack(record.stack_info)
        return json.dumps(entry, separators=(",", ":"), ensure_ascii=False)


def _is_acceptable_request_id(value: str) -> bool:
    """Return True if *value* is safe to log and echo as a request ID.

    Accepts non-empty, printable ASCII strings of at most
    ``_MAX_REQUEST_ID_LENGTH`` characters.
    """
    return (
        bool(value)
        and len(value) <= _MAX_REQUEST_ID_LENGTH
        and value.isascii()
        and value.isprintable()
    )


def _json_error(payload: dict[str, object], status: int) -> Response:
    """Build a JSON error response with the given payload and status code."""
    return Response(json.dumps(payload), status=status, mimetype="application/json")


def setup_logging(app: Flask) -> None:
    """Configure structured logging.

    Debug mode uses a human-readable line format at DEBUG level; otherwise
    records are emitted as JSON objects at INFO level. Output goes to stdout.
    ``logging.basicConfig`` is a no-op when the root logger already has
    handlers, so repeated calls do not stack handlers.
    """
    handler = logging.StreamHandler(sys.stdout)
    if app.debug:
        log_level = logging.DEBUG
        handler.setFormatter(
            logging.Formatter("%(asctime)s %(levelname)s [%(name)s] %(message)s")
        )
    else:
        log_level = logging.INFO
        handler.setFormatter(_JsonLogFormatter())

    logging.basicConfig(level=log_level, handlers=[handler])

    # Reduce noise from werkzeug in production
    if not app.debug:
        logging.getLogger("werkzeug").setLevel(logging.WARNING)

    app.logger.info("FlaskPaste starting", extra={"config": type(app.config).__name__})


def validate_security_config(app: Flask) -> None:
    """Validate security configuration and log warnings.

    Checks for common security misconfigurations that could lead to
    vulnerabilities in production deployments. Nothing is raised; problems
    are reported through ``app.logger.warning`` only.
    """
    is_production = not (app.debug or app.config.get("TESTING"))
    proxy_secret = app.config.get("TRUSTED_PROXY_SECRET", "")

    # PROXY-001: Check TRUSTED_PROXY_SECRET
    if is_production and not proxy_secret:
        app.logger.warning(
            "SECURITY WARNING: TRUSTED_PROXY_SECRET is not set. "
            "Client certificate headers (X-SSL-Client-SHA1) can be spoofed. "
            "Set FLASKPASTE_PROXY_SECRET to a shared secret known only by your reverse proxy."
        )

    # Warn if PKI is enabled without proxy secret (checked in every mode)
    if app.config.get("PKI_ENABLED", False) and not proxy_secret:
        app.logger.warning(
            "SECURITY WARNING: PKI is enabled but TRUSTED_PROXY_SECRET is not set. "
            "Certificate-based authentication can be bypassed by spoofing headers."
        )


def setup_security_headers(app: Flask) -> None:
    """Add security headers to all responses."""

    @app.after_request
    def add_security_headers(response: Response) -> Response:
        """Apply security headers to response.

        Headers follow OWASP recommendations for API security. Each header
        is assigned (not appended), so any value set earlier is overwritten.
        """
        for header_name, header_value in _SECURITY_HEADERS.items():
            response.headers[header_name] = header_value
        return response


def setup_request_id(app: Flask) -> None:
    """Add request ID tracking for log correlation and tracing."""

    @app.before_request
    def assign_request_id() -> None:
        """Assign request ID from a valid header value or generate a new UUID.

        The header is client-controlled, so it is only trusted when it is
        printable ASCII and within the length cap; anything else is replaced.
        """
        candidate = request.headers.get("X-Request-ID", "").strip()
        if not _is_acceptable_request_id(candidate):
            candidate = str(uuid.uuid4())
        g.request_id = candidate

    @app.after_request
    def add_request_id_header(response: Response) -> Response:
        """Echo request ID in response header and log access."""
        request_id = getattr(g, "request_id", None)
        if request_id:
            response.headers["X-Request-ID"] = request_id

        # Access logging with request ID
        app.logger.info(
            "%s %s %s [rid=%s]",
            request.method,
            request.path,
            response.status_code,
            request_id or "-",
        )
        return response


def setup_request_metrics(app: Flask) -> None:
    """Record request duration metrics for Prometheus."""
    from app.metrics import observe_request_duration

    @app.before_request
    def record_request_start() -> None:
        """Record request start time for duration metrics."""
        # perf_counter is monotonic, so system clock adjustments cannot
        # produce negative or inflated durations.
        g.request_start_time = time.perf_counter()

    @app.after_request
    def record_request_duration(response: Response) -> Response:
        """Record request duration to Prometheus histogram."""
        started_at = getattr(g, "request_start_time", None)
        if started_at is not None:
            # NOTE(review): request.path is used as the endpoint label; if
            # paths embed per-resource IDs this may be high-cardinality -
            # confirm how observe_request_duration normalizes it.
            observe_request_duration(
                method=request.method,
                endpoint=request.path,
                status=response.status_code,
                duration=time.perf_counter() - started_at,
            )
        return response


def setup_error_handlers(app: Flask) -> None:
    """Register global error handlers with JSON responses."""

    @app.errorhandler(400)
    def bad_request(error: HTTPException) -> Response:
        """Handle 400 Bad Request errors."""
        app.logger.warning("Bad request: %s [rid=%s]", request.path, getattr(g, "request_id", "-"))
        return _json_error({"error": "Bad request"}, 400)

    @app.errorhandler(404)
    def not_found(error: HTTPException) -> Response:
        """Handle 404 Not Found errors."""
        return _json_error({"error": "Not found"}, 404)

    @app.errorhandler(429)
    def rate_limit_exceeded(error: HTTPException) -> Response:
        """Handle 429 Too Many Requests errors."""
        app.logger.warning(
            "Rate limit exceeded: %s from %s [rid=%s]",
            request.path,
            request.remote_addr,
            getattr(g, "request_id", "-"),
        )
        return _json_error(
            {"error": "Rate limit exceeded", "retry_after": error.description}, 429
        )

    @app.errorhandler(405)
    def method_not_allowed(error: HTTPException) -> Response:
        """Handle 405 Method Not Allowed errors."""
        return _json_error({"error": "Method not allowed"}, 405)

    @app.errorhandler(500)
    def internal_error(error: HTTPException) -> Response:
        """Handle 500 Internal Server errors."""
        app.logger.error(
            "Internal error: %s - %s [rid=%s]",
            request.path,
            str(error),
            getattr(g, "request_id", "-"),
        )
        return _json_error({"error": "Internal server error"}, 500)

    @app.errorhandler(Exception)
    def handle_exception(error: Exception) -> Response:
        """Handle unhandled exceptions with generic 500 response."""
        # Pass HTTP exceptions through unchanged; they are valid responses
        if isinstance(error, HTTPException):
            return error  # type: ignore[return-value]

        app.logger.exception(
            "Unhandled exception: %s [rid=%s]", str(error), getattr(g, "request_id", "-")
        )
        return _json_error({"error": "Internal server error"}, 500)


def setup_rate_limiting(app: Flask) -> None:
    """Configure rate limiting.

    Creates a ``Limiter`` keyed on the remote address with in-memory storage
    and stores it in ``app.extensions["limiter"]`` for use in routes.
    """
    from flask_limiter import Limiter
    from flask_limiter.util import get_remote_address

    def is_rate_limit_exempt() -> bool:
        """Check if request is exempt from global rate limiting.

        Exempt: health endpoint, trusted certificate holders.
        """
        prefix = app.config.get("URL_PREFIX", "")
        health_path = f"{prefix}/health" if prefix else "/health"
        if request.path == health_path:
            return True

        # Trusted certificate holders bypass rate limiting
        try:
            from app.api.routes import get_client_id

            return get_client_id() is not None
        except Exception:
            # Best effort: on any failure the request is simply not exempt,
            # but leave a trace so the cause can be diagnosed.
            app.logger.debug("Rate limit exemption check failed", exc_info=True)
            return False

    limiter = Limiter(
        key_func=get_remote_address,
        app=app,
        default_limits=["200 per day", "60 per hour"],
        storage_uri="memory://",
        strategy="fixed-window",
        default_limits_exempt_when=is_rate_limit_exempt,
    )

    # Store limiter on app for use in routes
    app.extensions["limiter"] = limiter


def setup_metrics(app: Flask) -> None:
    """Configure Prometheus metrics.

    Does nothing when ``TESTING`` is set. If ``prometheus_flask_exporter``
    cannot be imported, a warning is logged and metrics stay disabled.
    """
    # Metrics are skipped under TESTING
    if app.config.get("TESTING"):
        return

    try:
        from prometheus_flask_exporter import PrometheusMetrics

        metrics = PrometheusMetrics(app)

        # Add app info
        metrics.info("flaskpaste_info", "FlaskPaste application info", version=VERSION)

        app.extensions["metrics"] = metrics

        # Setup custom metrics
        from app.metrics import setup_custom_metrics

        setup_custom_metrics(app)
    except ImportError:
        app.logger.warning("prometheus_flask_exporter not available, metrics disabled")


def create_app(config_name: str | None = None) -> Flask:
    """Create and configure the Flask application.

    Args:
        config_name: Key into ``app.config.config``. When ``None``, the
            ``FLASK_ENV`` environment variable is used, falling back to
            ``"default"``.

    Returns:
        The configured Flask application.

    Raises:
        KeyError: If ``config_name`` is not a known configuration.
    """
    if config_name is None:
        config_name = os.environ.get("FLASK_ENV", "default")

    try:
        config_object = config[config_name]
    except KeyError:
        # Same exception type as before, but with an actionable message.
        raise KeyError(
            f"Unknown configuration {config_name!r} (from argument or FLASK_ENV)"
        ) from None

    app = Flask(__name__)
    app.config.from_object(config_object)

    # Setup logging first
    setup_logging(app)

    # Validate security configuration
    validate_security_config(app)

    # Setup request ID tracking
    setup_request_id(app)

    # Setup security headers
    setup_security_headers(app)

    # Setup error handlers
    setup_error_handlers(app)

    testing = app.config.get("TESTING")

    # Setup rate limiting (skip in testing)
    if not testing:
        setup_rate_limiting(app)

    # Setup metrics (skips itself in testing)
    setup_metrics(app)

    # Setup request duration metrics (skip in testing)
    if not testing:
        setup_request_metrics(app)

    # Initialize database
    from app import database

    database.init_app(app)

    # Register blueprints
    from app.api import bp as api_bp

    app.register_blueprint(api_bp)

    app.logger.info("FlaskPaste initialized successfully")
    return app