Files
flaskpaste/app/__init__.py

331 lines
10 KiB
Python

"""Flask application factory."""
import logging
import os
import sys
import time
import uuid
from flask import Flask, Response, g, request
from werkzeug.exceptions import HTTPException
from app.config import VERSION, config
def setup_logging(app: Flask) -> None:
    """Configure root logging for the application.

    In debug mode records are written to stdout as plain text lines at DEBUG
    level. Otherwise records are written at INFO level as one JSON object per
    line with the keys ``time``, ``level``, ``logger`` and ``message``, and
    the ``werkzeug`` logger is raised to WARNING to cut request noise.

    The JSON line is built with ``json.dumps`` rather than ``%``-interpolation
    into a JSON-looking template, so quotes, backslashes and newlines inside a
    message (for example a client-controlled request path) are escaped and
    cannot break the line or forge extra fields.

    Args:
        app: Application whose ``debug`` flag selects the level and format.
    """
    import json  # local import: only this function needs it

    class _JsonFormatter(logging.Formatter):
        """Render a log record as a single, properly escaped JSON object."""

        def format(self, record: logging.LogRecord) -> str:
            payload = {
                "time": self.formatTime(record, self.datefmt),
                "level": record.levelname,
                "logger": record.name,
                "message": record.getMessage(),
            }
            # Keep tracebacks inside the JSON object instead of appending
            # raw multi-line text after it.
            if record.exc_info:
                payload["exception"] = self.formatException(record.exc_info)
            if record.stack_info:
                payload["stack"] = self.formatStack(record.stack_info)
            return json.dumps(payload, ensure_ascii=False, separators=(",", ":"))

    if app.debug:
        log_level = logging.DEBUG
        formatter: logging.Formatter = logging.Formatter(
            "%(asctime)s %(levelname)s [%(name)s] %(message)s"
        )
    else:
        log_level = logging.INFO
        formatter = _JsonFormatter()

    handler = logging.StreamHandler(sys.stdout)
    handler.setFormatter(formatter)
    # basicConfig is a no-op when the root logger already has handlers.
    logging.basicConfig(level=log_level, handlers=[handler])

    # Reduce noise from werkzeug in production
    if not app.debug:
        logging.getLogger("werkzeug").setLevel(logging.WARNING)

    app.logger.info("FlaskPaste starting", extra={"config": type(app.config).__name__})
def validate_security_config(app: Flask) -> None:
    """Log warnings for risky proxy / PKI settings.

    Nothing is raised or changed; the function only emits warnings through
    ``app.logger``:

    * when the app is neither in debug nor in testing mode and
      ``TRUSTED_PROXY_SECRET`` is empty, and
    * when ``PKI_ENABLED`` is truthy and ``TRUSTED_PROXY_SECRET`` is empty
      (checked regardless of debug/testing mode).

    Args:
        app: Application whose configuration is inspected.
    """
    settings = app.config
    secret_missing = not settings.get("TRUSTED_PROXY_SECRET", "")
    production = not (app.debug or settings.get("TESTING"))

    # PROXY-001: client-certificate headers are only trustworthy when the
    # reverse proxy authenticates itself with a shared secret.
    if production and secret_missing:
        app.logger.warning(
            "SECURITY WARNING: TRUSTED_PROXY_SECRET is not set. "
            "Client certificate headers (X-SSL-Client-SHA1) can be spoofed. "
            "Set FLASKPASTE_PROXY_SECRET to a shared secret known only by your reverse proxy."
        )

    # PKI without a proxy secret is flagged in every mode.
    if settings.get("PKI_ENABLED", False) and secret_missing:
        app.logger.warning(
            "SECURITY WARNING: PKI is enabled but TRUSTED_PROXY_SECRET is not set. "
            "Certificate-based authentication can be bypassed by spoofing headers."
        )
def setup_security_headers(app: Flask) -> None:
    """Register an ``after_request`` hook that stamps fixed security headers.

    Every response gets the headers listed below; a value already present
    under the same name is overwritten.

    Args:
        app: Application on which the hook is registered.
    """
    fixed_headers = (
        # Prevent MIME type sniffing
        ("X-Content-Type-Options", "nosniff"),
        # Prevent clickjacking
        ("X-Frame-Options", "DENY"),
        # Referrer policy
        ("Referrer-Policy", "strict-origin-when-cross-origin"),
        # Content Security Policy (restrictive for API)
        ("Content-Security-Policy", "default-src 'none'; frame-ancestors 'none'"),
        # Permissions policy
        ("Permissions-Policy", "geolocation=(), microphone=(), camera=()"),
        # HSTS - enforce HTTPS (1 year, include subdomains)
        ("Strict-Transport-Security", "max-age=31536000; includeSubDomains"),
        # Prevent caching of sensitive paste data
        ("Cache-Control", "no-store, no-cache, must-revalidate, private"),
        ("Pragma", "no-cache"),
    )

    @app.after_request
    def add_security_headers(response: Response) -> Response:
        """Set each fixed security header on the outgoing response."""
        for header_name, header_value in fixed_headers:
            response.headers[header_name] = header_value
        return response
def setup_request_id(app: Flask) -> None:
    """Add request ID tracking for log correlation and tracing.

    Registers two hooks:

    * ``before_request``: stores an ID in ``g.request_id``. A client-supplied
      ``X-Request-ID`` header is reused only when it is short and made of
      safe characters; otherwise a fresh UUID4 is generated.
    * ``after_request``: echoes the ID in the ``X-Request-ID`` response header
      and writes one access-log line at INFO level.

    Args:
        app: Application on which the hooks are registered.
    """
    # The header is untrusted input that ends up in log lines and in a
    # response header, so bound its size and restrict its alphabet.
    max_length = 128
    allowed_punctuation = "-_."

    def _is_safe_request_id(candidate: str) -> bool:
        """Return True for non-empty, short, ASCII alphanumeric/``-_.`` IDs."""
        return (
            0 < len(candidate) <= max_length
            and candidate.isascii()
            and all(ch.isalnum() or ch in allowed_punctuation for ch in candidate)
        )

    @app.before_request
    def assign_request_id() -> None:
        """Assign request ID from a valid header or generate a new UUID."""
        supplied = request.headers.get("X-Request-ID", "").strip()
        if _is_safe_request_id(supplied):
            g.request_id = supplied
        else:
            g.request_id = str(uuid.uuid4())

    @app.after_request
    def add_request_id_header(response: Response) -> Response:
        """Echo request ID in response header and log access."""
        # g.request_id may be absent if an earlier hook ended the request
        # before assign_request_id ran.
        request_id = getattr(g, "request_id", None)
        if request_id:
            response.headers["X-Request-ID"] = request_id
        # Access logging with request ID
        app.logger.info(
            "%s %s %s [rid=%s]",
            request.method,
            request.path,
            response.status_code,
            request_id or "-",
        )
        return response
def setup_request_metrics(app: Flask) -> None:
    """Record request duration metrics for Prometheus.

    Registers a ``before_request`` hook that stamps the start of the request
    and an ``after_request`` hook that passes the elapsed time, together with
    the request method, path and response status code, to
    ``app.metrics.observe_request_duration``.

    Elapsed time is measured with ``time.perf_counter()`` so that wall-clock
    adjustments during a request cannot produce negative or inflated
    durations.

    Args:
        app: Application on which the hooks are registered.
    """
    from app.metrics import observe_request_duration

    @app.before_request
    def record_request_start() -> None:
        """Record request start time for duration metrics."""
        # Wall-clock stamp kept under its original name in case other code
        # reads g.request_start_time - TODO confirm whether anything does.
        g.request_start_time = time.time()
        # Monotonic stamp actually used to compute the duration.
        g.request_start_perf = time.perf_counter()

    @app.after_request
    def record_request_duration(response: Response) -> Response:
        """Record request duration to Prometheus histogram."""
        # Absent when an earlier hook ended the request before the start
        # stamp was taken; in that case nothing is observed.
        started = getattr(g, "request_start_perf", None)
        if started is not None:
            # NOTE(review): request.path is passed as the "endpoint" label;
            # if the helper does not normalise it, label cardinality grows
            # with every distinct URL - confirm against app.metrics.
            observe_request_duration(
                method=request.method,
                endpoint=request.path,
                status=response.status_code,
                duration=time.perf_counter() - started,
            )
        return response
def setup_error_handlers(app: Flask) -> None:
    """Register global error handlers that answer with JSON bodies.

    Handlers are installed for HTTP 400, 404, 429, 405 and 500 plus a
    catch-all for any other ``Exception``. Each builds a
    ``application/json`` response; 400, 429, 500 and the catch-all also log
    the event together with the current request ID.

    Args:
        app: Application on which the handlers are registered.
    """
    import json

    def json_error(payload: dict, status: int) -> Response:
        """Serialize *payload* into a JSON response with the given status."""
        return Response(json.dumps(payload), status=status, mimetype="application/json")

    def current_rid() -> str:
        """Return the request ID stored on ``g``, or ``"-"`` when unset."""
        return getattr(g, "request_id", "-")

    @app.errorhandler(400)
    def bad_request(error: HTTPException) -> Response:
        """Handle 400 Bad Request errors."""
        app.logger.warning("Bad request: %s [rid=%s]", request.path, current_rid())
        return json_error({"error": "Bad request"}, 400)

    @app.errorhandler(404)
    def not_found(error: HTTPException) -> Response:
        """Handle 404 Not Found errors."""
        return json_error({"error": "Not found"}, 404)

    @app.errorhandler(429)
    def rate_limit_exceeded(error: HTTPException) -> Response:
        """Handle 429 Too Many Requests errors."""
        app.logger.warning(
            "Rate limit exceeded: %s from %s [rid=%s]",
            request.path,
            request.remote_addr,
            current_rid(),
        )
        # The exception's description is forwarded under "retry_after".
        body = {"error": "Rate limit exceeded", "retry_after": error.description}
        return json_error(body, 429)

    @app.errorhandler(405)
    def method_not_allowed(error: HTTPException) -> Response:
        """Handle 405 Method Not Allowed errors."""
        return json_error({"error": "Method not allowed"}, 405)

    @app.errorhandler(500)
    def internal_error(error: HTTPException) -> Response:
        """Handle 500 Internal Server errors."""
        app.logger.error(
            "Internal error: %s - %s [rid=%s]",
            request.path,
            str(error),
            current_rid(),
        )
        return json_error({"error": "Internal server error"}, 500)

    @app.errorhandler(Exception)
    def handle_exception(error: Exception) -> Response:
        """Handle unhandled exceptions with generic 500 response."""
        # HTTP exceptions are handed back unchanged (returned, not re-raised).
        if isinstance(error, HTTPException):
            return error  # type: ignore[return-value]

        app.logger.exception("Unhandled exception: %s [rid=%s]", str(error), current_rid())
        return json_error({"error": "Internal server error"}, 500)
def setup_rate_limiting(app: Flask) -> None:
    """Attach a ``flask_limiter.Limiter`` with default limits to the app.

    The limiter keys on the remote address, keeps its counters in memory,
    uses the fixed-window strategy and applies "200 per day" / "60 per hour"
    by default. The instance is stored in ``app.extensions["limiter"]`` so
    routes can reach it.

    Args:
        app: Application to protect.
    """
    from flask_limiter import Limiter
    from flask_limiter.util import get_remote_address

    def is_rate_limit_exempt() -> bool:
        """Tell whether the current request skips the default limits.

        Exempt are the health endpoint (under the configured ``URL_PREFIX``)
        and requests for which ``get_client_id()`` returns a value. Any
        exception raised while resolving the client ID counts as not exempt.
        """
        url_prefix = app.config.get("URL_PREFIX", "")
        if request.path == f"{url_prefix or ''}/health":
            return True

        # Trusted certificate holders bypass rate limiting
        try:
            from app.api.routes import get_client_id

            client_id = get_client_id()
        except Exception:
            return False
        return client_id is not None

    # Store limiter on app for use in routes
    app.extensions["limiter"] = Limiter(
        key_func=get_remote_address,
        app=app,
        default_limits=["200 per day", "60 per hour"],
        storage_uri="memory://",
        strategy="fixed-window",
        default_limits_exempt_when=is_rate_limit_exempt,
    )
def setup_metrics(app: Flask) -> None:
    """Enable the Prometheus exporter unless the app is in testing mode.

    When ``TESTING`` is falsy, a ``PrometheusMetrics`` exporter is created
    for the app, an info metric carrying ``VERSION`` is registered, the
    exporter is stored in ``app.extensions["metrics"]`` and the project's
    custom metrics are set up. An ``ImportError`` raised anywhere in that
    sequence is logged as a warning and otherwise ignored.

    Args:
        app: Application to instrument.
    """
    if not app.config.get("TESTING"):
        try:
            from prometheus_flask_exporter import PrometheusMetrics

            exporter = PrometheusMetrics(app)
            # Add app info
            exporter.info("flaskpaste_info", "FlaskPaste application info", version=VERSION)
            app.extensions["metrics"] = exporter

            # Setup custom metrics
            from app.metrics import setup_custom_metrics

            setup_custom_metrics(app)
        except ImportError:
            app.logger.warning("prometheus_flask_exporter not available, metrics disabled")
def create_app(config_name: str | None = None) -> Flask:
    """Build a fully wired Flask application.

    Args:
        config_name: Key into the ``config`` mapping. When ``None`` the
            ``FLASK_ENV`` environment variable is used, falling back to
            ``"default"``.

    Returns:
        The configured application with logging, request hooks, error
        handlers, database and API blueprint installed.
    """
    selected = os.environ.get("FLASK_ENV", "default") if config_name is None else config_name

    app = Flask(__name__)
    app.config.from_object(config[selected])

    # Core setup, in registration order: logging first, then the security
    # check, request IDs, security headers and error handlers.
    core_steps = (
        setup_logging,
        validate_security_config,
        setup_request_id,
        setup_security_headers,
        setup_error_handlers,
    )
    for install in core_steps:
        install(app)

    in_testing = app.config.get("TESTING")

    # Rate limiting is skipped in testing.
    if not in_testing:
        setup_rate_limiting(app)

    # Called unconditionally, exactly as before.
    setup_metrics(app)

    # Request duration metrics are skipped in testing.
    if not in_testing:
        setup_request_metrics(app)

    # Initialize database
    from app import database

    database.init_app(app)

    # Register blueprints
    from app.api import bp as api_bp

    app.register_blueprint(api_bp)

    app.logger.info("FlaskPaste initialized successfully")
    return app