fix: add comprehensive type annotations for mypy

- database.py: add type hints for Path, Flask, Any, BaseException
- pki.py: add assertions to narrow Optional types after has_ca() checks
- routes.py: annotate config values to avoid Any return types
- api/__init__.py: use float for cleanup timestamps (time.time())
- __init__.py: remove unused return from setup_rate_limiting
This commit is contained in:
Username
2025-12-22 19:11:11 +01:00
parent 680b068c00
commit ca9342e92d
5 changed files with 36 additions and 20 deletions

View File

@@ -174,8 +174,6 @@ def setup_rate_limiting(app: Flask) -> None:
# Store limiter on app for use in routes # Store limiter on app for use in routes
app.extensions["limiter"] = limiter app.extensions["limiter"] = limiter
return limiter
def setup_metrics(app: Flask) -> None: def setup_metrics(app: Flask) -> None:
"""Configure Prometheus metrics.""" """Configure Prometheus metrics."""

View File

@@ -9,10 +9,10 @@ bp = Blueprint("api", __name__)
# Thread-safe cleanup scheduling # Thread-safe cleanup scheduling
_cleanup_lock = threading.Lock() _cleanup_lock = threading.Lock()
_cleanup_times = { _cleanup_times: dict[str, float] = {
"pastes": 0, "pastes": 0.0,
"hashes": 0, "hashes": 0.0,
"rate_limits": 0, "rate_limits": 0.0,
} }
_CLEANUP_INTERVALS = { _CLEANUP_INTERVALS = {
"pastes": 3600, # 1 hour "pastes": 3600, # 1 hour
@@ -25,7 +25,7 @@ def reset_cleanup_times() -> None:
"""Reset cleanup timestamps. For testing only.""" """Reset cleanup timestamps. For testing only."""
with _cleanup_lock: with _cleanup_lock:
for key in _cleanup_times: for key in _cleanup_times:
_cleanup_times[key] = 0 _cleanup_times[key] = 0.0
@bp.before_request @bp.before_request

View File

@@ -64,11 +64,12 @@ _antiflood_last_increase: float = 0 # Last time difficulty was increased
def get_dynamic_difficulty() -> int: def get_dynamic_difficulty() -> int:
"""Get current PoW difficulty including anti-flood adjustment.""" """Get current PoW difficulty including anti-flood adjustment."""
base = current_app.config["POW_DIFFICULTY"] base: int = current_app.config["POW_DIFFICULTY"]
if base == 0 or not current_app.config.get("ANTIFLOOD_ENABLED", True): if base == 0 or not current_app.config.get("ANTIFLOOD_ENABLED", True):
return base return base
with _antiflood_lock: with _antiflood_lock:
return min(base + _antiflood_difficulty, current_app.config["ANTIFLOOD_MAX"]) max_diff: int = current_app.config["ANTIFLOOD_MAX"]
return min(base + _antiflood_difficulty, max_diff)
def record_antiflood_request() -> None: def record_antiflood_request() -> None:
@@ -236,7 +237,8 @@ def error_response(message: str, status: int, **extra: Any) -> Response:
def url_prefix() -> str: def url_prefix() -> str:
"""Get configured URL prefix for reverse proxy deployments.""" """Get configured URL prefix for reverse proxy deployments."""
return current_app.config.get("URL_PREFIX", "") prefix: str = current_app.config.get("URL_PREFIX", "")
return prefix
def prefixed_url(path: str) -> str: def prefixed_url(path: str) -> str:
@@ -403,7 +405,7 @@ def is_admin() -> bool:
def get_pow_secret() -> bytes: def get_pow_secret() -> bytes:
"""Get or generate PoW signing secret.""" """Get or generate PoW signing secret."""
global _pow_secret_cache global _pow_secret_cache
configured = current_app.config.get("POW_SECRET", "") configured: str = current_app.config.get("POW_SECRET", "")
if configured: if configured:
return configured.encode() return configured.encode()
if _pow_secret_cache is None: if _pow_secret_cache is None:
@@ -1056,6 +1058,7 @@ class RegisterView(MethodView):
# Load CA cert for PKCS#12 (reuse ca_info from above, or refresh if it was just generated) # Load CA cert for PKCS#12 (reuse ca_info from above, or refresh if it was just generated)
if ca_info is None or "certificate_pem" not in ca_info: if ca_info is None or "certificate_pem" not in ca_info:
ca_info = get_ca_info(skip_enabled_check=True) ca_info = get_ca_info(skip_enabled_check=True)
assert ca_info is not None # CA was just generated or exists
ca_cert = x509.load_pem_x509_certificate(ca_info["certificate_pem"].encode()) ca_cert = x509.load_pem_x509_certificate(ca_info["certificate_pem"].encode())
client_cert = x509.load_pem_x509_certificate(cert_info["certificate_pem"].encode()) client_cert = x509.load_pem_x509_certificate(cert_info["certificate_pem"].encode())
client_key = serialization.load_pem_private_key( client_key = serialization.load_pem_private_key(

View File

@@ -1,13 +1,19 @@
"""Database connection and schema management.""" """Database connection and schema management."""
from __future__ import annotations
import hashlib import hashlib
import secrets import secrets
import sqlite3 import sqlite3
import time import time
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Any
from flask import current_app, g from flask import current_app, g
if TYPE_CHECKING:
from flask import Flask
SCHEMA = """ SCHEMA = """
CREATE TABLE IF NOT EXISTS pastes ( CREATE TABLE IF NOT EXISTS pastes (
id TEXT PRIMARY KEY, id TEXT PRIMARY KEY,
@@ -74,7 +80,7 @@ _HASH_ITERATIONS = 600000 # OWASP 2023 recommendation for PBKDF2-SHA256
_SALT_LENGTH = 32 _SALT_LENGTH = 32
def hash_password(password: str) -> str: def hash_password(password: str) -> str | None:
"""Hash password using PBKDF2-HMAC-SHA256. """Hash password using PBKDF2-HMAC-SHA256.
Returns format: $pbkdf2-sha256$iterations$salt$hash Returns format: $pbkdf2-sha256$iterations$salt$hash
@@ -119,7 +125,7 @@ def verify_password(password: str, password_hash: str) -> bool:
_memory_db_holder = None _memory_db_holder = None
def _get_connection_string(db_path) -> tuple[str, dict]: def _get_connection_string(db_path: str | Path) -> tuple[str, dict[str, Any]]:
"""Get connection string and kwargs for sqlite3.connect.""" """Get connection string and kwargs for sqlite3.connect."""
if isinstance(db_path, Path): if isinstance(db_path, Path):
db_path.parent.mkdir(parents=True, exist_ok=True) db_path.parent.mkdir(parents=True, exist_ok=True)
@@ -129,7 +135,7 @@ def _get_connection_string(db_path) -> tuple[str, dict]:
return db_path, {} return db_path, {}
def _is_file_database(db_path) -> bool: def _is_file_database(db_path: str | Path) -> bool:
"""Check if database path refers to a file (not in-memory).""" """Check if database path refers to a file (not in-memory)."""
if isinstance(db_path, Path): if isinstance(db_path, Path):
return True return True
@@ -150,10 +156,11 @@ def get_db() -> sqlite3.Connection:
# WAL mode set in init_db; these optimize per-connection behavior # WAL mode set in init_db; these optimize per-connection behavior
g.db.execute("PRAGMA busy_timeout = 5000") g.db.execute("PRAGMA busy_timeout = 5000")
g.db.execute("PRAGMA synchronous = NORMAL") g.db.execute("PRAGMA synchronous = NORMAL")
return g.db conn: sqlite3.Connection = g.db
return conn
def close_db(exception=None) -> None: def close_db(exception: BaseException | None = None) -> None:
"""Close database connection at end of request.""" """Close database connection at end of request."""
db = g.pop("db", None) db = g.pop("db", None)
if db is not None: if db is not None:
@@ -280,7 +287,7 @@ def check_content_hash(content_hash: str) -> tuple[bool, int]:
return True, current_count return True, current_count
def init_app(app) -> None: def init_app(app: Flask) -> None:
"""Register database functions with Flask app.""" """Register database functions with Flask app."""
app.teardown_appcontext(close_db) app.teardown_appcontext(close_db)

View File

@@ -433,6 +433,7 @@ class PKI:
if not self.has_ca(): if not self.has_ca():
raise CANotFoundError("No CA configured") raise CANotFoundError("No CA configured")
assert self._ca_store is not None # narrowing for mypy
return { return {
"id": self._ca_store["id"], "id": self._ca_store["id"],
"common_name": self._ca_store["common_name"], "common_name": self._ca_store["common_name"],
@@ -449,7 +450,9 @@ class PKI:
""" """
if not self.has_ca(): if not self.has_ca():
raise CANotFoundError("No CA configured") raise CANotFoundError("No CA configured")
return self._ca_store["certificate_pem"] assert self._ca_store is not None # narrowing for mypy
cert_pem: str = self._ca_store["certificate_pem"]
return cert_pem
def _get_signing_key(self) -> tuple[Any, Any]: def _get_signing_key(self) -> tuple[Any, Any]:
"""Get CA private key and certificate for signing. """Get CA private key and certificate for signing.
@@ -460,6 +463,8 @@ class PKI:
if not self.has_ca(): if not self.has_ca():
raise CANotFoundError("No CA configured") raise CANotFoundError("No CA configured")
assert self._ca_store is not None # narrowing for mypy
# Use cached key if available # Use cached key if available
if "_private_key" in self._ca_store: if "_private_key" in self._ca_store:
return self._ca_store["_private_key"], self._ca_store["_certificate"] return self._ca_store["_private_key"], self._ca_store["_certificate"]
@@ -504,6 +509,7 @@ class PKI:
days = self.cert_days days = self.cert_days
ca_key, ca_cert = self._get_signing_key() ca_key, ca_cert = self._get_signing_key()
assert self._ca_store is not None # narrowing for mypy (validated in _get_signing_key)
# Generate client key # Generate client key
curves = { curves = {
@@ -636,7 +642,8 @@ class PKI:
return False return False
# Check expiry # Check expiry
return cert["expires_at"] >= int(time.time()) expires_at: int = cert["expires_at"]
return expires_at >= int(time.time())
def get_certificate(self, fingerprint: str) -> dict | None: def get_certificate(self, fingerprint: str) -> dict | None:
"""Get certificate info by fingerprint. """Get certificate info by fingerprint.
@@ -728,7 +735,8 @@ def is_certificate_valid(fingerprint: str) -> bool:
return False return False
# Check expiry # Check expiry
return row["expires_at"] >= int(time.time()) expires_at: int = row["expires_at"]
return expires_at >= int(time.time())
# ───────────────────────────────────────────────────────────────────────────── # ─────────────────────────────────────────────────────────────────────────────