forked from claw/flaskpaste
fix: add comprehensive type annotations for mypy
- database.py: add type hints for Path, Flask, Any, BaseException - pki.py: add assertions to narrow Optional types after has_ca() checks - routes.py: annotate config values to avoid Any return types - api/__init__.py: use float for cleanup timestamps (time.time()) - __init__.py: remove unused return from setup_rate_limiting
This commit is contained in:
@@ -174,8 +174,6 @@ def setup_rate_limiting(app: Flask) -> None:
|
|||||||
# Store limiter on app for use in routes
|
# Store limiter on app for use in routes
|
||||||
app.extensions["limiter"] = limiter
|
app.extensions["limiter"] = limiter
|
||||||
|
|
||||||
return limiter
|
|
||||||
|
|
||||||
|
|
||||||
def setup_metrics(app: Flask) -> None:
|
def setup_metrics(app: Flask) -> None:
|
||||||
"""Configure Prometheus metrics."""
|
"""Configure Prometheus metrics."""
|
||||||
|
|||||||
@@ -9,10 +9,10 @@ bp = Blueprint("api", __name__)
|
|||||||
|
|
||||||
# Thread-safe cleanup scheduling
|
# Thread-safe cleanup scheduling
|
||||||
_cleanup_lock = threading.Lock()
|
_cleanup_lock = threading.Lock()
|
||||||
_cleanup_times = {
|
_cleanup_times: dict[str, float] = {
|
||||||
"pastes": 0,
|
"pastes": 0.0,
|
||||||
"hashes": 0,
|
"hashes": 0.0,
|
||||||
"rate_limits": 0,
|
"rate_limits": 0.0,
|
||||||
}
|
}
|
||||||
_CLEANUP_INTERVALS = {
|
_CLEANUP_INTERVALS = {
|
||||||
"pastes": 3600, # 1 hour
|
"pastes": 3600, # 1 hour
|
||||||
@@ -25,7 +25,7 @@ def reset_cleanup_times() -> None:
|
|||||||
"""Reset cleanup timestamps. For testing only."""
|
"""Reset cleanup timestamps. For testing only."""
|
||||||
with _cleanup_lock:
|
with _cleanup_lock:
|
||||||
for key in _cleanup_times:
|
for key in _cleanup_times:
|
||||||
_cleanup_times[key] = 0
|
_cleanup_times[key] = 0.0
|
||||||
|
|
||||||
|
|
||||||
@bp.before_request
|
@bp.before_request
|
||||||
|
|||||||
@@ -64,11 +64,12 @@ _antiflood_last_increase: float = 0 # Last time difficulty was increased
|
|||||||
|
|
||||||
def get_dynamic_difficulty() -> int:
|
def get_dynamic_difficulty() -> int:
|
||||||
"""Get current PoW difficulty including anti-flood adjustment."""
|
"""Get current PoW difficulty including anti-flood adjustment."""
|
||||||
base = current_app.config["POW_DIFFICULTY"]
|
base: int = current_app.config["POW_DIFFICULTY"]
|
||||||
if base == 0 or not current_app.config.get("ANTIFLOOD_ENABLED", True):
|
if base == 0 or not current_app.config.get("ANTIFLOOD_ENABLED", True):
|
||||||
return base
|
return base
|
||||||
with _antiflood_lock:
|
with _antiflood_lock:
|
||||||
return min(base + _antiflood_difficulty, current_app.config["ANTIFLOOD_MAX"])
|
max_diff: int = current_app.config["ANTIFLOOD_MAX"]
|
||||||
|
return min(base + _antiflood_difficulty, max_diff)
|
||||||
|
|
||||||
|
|
||||||
def record_antiflood_request() -> None:
|
def record_antiflood_request() -> None:
|
||||||
@@ -236,7 +237,8 @@ def error_response(message: str, status: int, **extra: Any) -> Response:
|
|||||||
|
|
||||||
def url_prefix() -> str:
|
def url_prefix() -> str:
|
||||||
"""Get configured URL prefix for reverse proxy deployments."""
|
"""Get configured URL prefix for reverse proxy deployments."""
|
||||||
return current_app.config.get("URL_PREFIX", "")
|
prefix: str = current_app.config.get("URL_PREFIX", "")
|
||||||
|
return prefix
|
||||||
|
|
||||||
|
|
||||||
def prefixed_url(path: str) -> str:
|
def prefixed_url(path: str) -> str:
|
||||||
@@ -403,7 +405,7 @@ def is_admin() -> bool:
|
|||||||
def get_pow_secret() -> bytes:
|
def get_pow_secret() -> bytes:
|
||||||
"""Get or generate PoW signing secret."""
|
"""Get or generate PoW signing secret."""
|
||||||
global _pow_secret_cache
|
global _pow_secret_cache
|
||||||
configured = current_app.config.get("POW_SECRET", "")
|
configured: str = current_app.config.get("POW_SECRET", "")
|
||||||
if configured:
|
if configured:
|
||||||
return configured.encode()
|
return configured.encode()
|
||||||
if _pow_secret_cache is None:
|
if _pow_secret_cache is None:
|
||||||
@@ -1056,6 +1058,7 @@ class RegisterView(MethodView):
|
|||||||
# Load CA cert for PKCS#12 (reuse ca_info from above, or refresh if it was just generated)
|
# Load CA cert for PKCS#12 (reuse ca_info from above, or refresh if it was just generated)
|
||||||
if ca_info is None or "certificate_pem" not in ca_info:
|
if ca_info is None or "certificate_pem" not in ca_info:
|
||||||
ca_info = get_ca_info(skip_enabled_check=True)
|
ca_info = get_ca_info(skip_enabled_check=True)
|
||||||
|
assert ca_info is not None # CA was just generated or exists
|
||||||
ca_cert = x509.load_pem_x509_certificate(ca_info["certificate_pem"].encode())
|
ca_cert = x509.load_pem_x509_certificate(ca_info["certificate_pem"].encode())
|
||||||
client_cert = x509.load_pem_x509_certificate(cert_info["certificate_pem"].encode())
|
client_cert = x509.load_pem_x509_certificate(cert_info["certificate_pem"].encode())
|
||||||
client_key = serialization.load_pem_private_key(
|
client_key = serialization.load_pem_private_key(
|
||||||
|
|||||||
@@ -1,13 +1,19 @@
|
|||||||
"""Database connection and schema management."""
|
"""Database connection and schema management."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import hashlib
|
import hashlib
|
||||||
import secrets
|
import secrets
|
||||||
import sqlite3
|
import sqlite3
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
from flask import current_app, g
|
from flask import current_app, g
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from flask import Flask
|
||||||
|
|
||||||
SCHEMA = """
|
SCHEMA = """
|
||||||
CREATE TABLE IF NOT EXISTS pastes (
|
CREATE TABLE IF NOT EXISTS pastes (
|
||||||
id TEXT PRIMARY KEY,
|
id TEXT PRIMARY KEY,
|
||||||
@@ -74,7 +80,7 @@ _HASH_ITERATIONS = 600000 # OWASP 2023 recommendation for PBKDF2-SHA256
|
|||||||
_SALT_LENGTH = 32
|
_SALT_LENGTH = 32
|
||||||
|
|
||||||
|
|
||||||
def hash_password(password: str) -> str:
|
def hash_password(password: str) -> str | None:
|
||||||
"""Hash password using PBKDF2-HMAC-SHA256.
|
"""Hash password using PBKDF2-HMAC-SHA256.
|
||||||
|
|
||||||
Returns format: $pbkdf2-sha256$iterations$salt$hash
|
Returns format: $pbkdf2-sha256$iterations$salt$hash
|
||||||
@@ -119,7 +125,7 @@ def verify_password(password: str, password_hash: str) -> bool:
|
|||||||
_memory_db_holder = None
|
_memory_db_holder = None
|
||||||
|
|
||||||
|
|
||||||
def _get_connection_string(db_path) -> tuple[str, dict]:
|
def _get_connection_string(db_path: str | Path) -> tuple[str, dict[str, Any]]:
|
||||||
"""Get connection string and kwargs for sqlite3.connect."""
|
"""Get connection string and kwargs for sqlite3.connect."""
|
||||||
if isinstance(db_path, Path):
|
if isinstance(db_path, Path):
|
||||||
db_path.parent.mkdir(parents=True, exist_ok=True)
|
db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
@@ -129,7 +135,7 @@ def _get_connection_string(db_path) -> tuple[str, dict]:
|
|||||||
return db_path, {}
|
return db_path, {}
|
||||||
|
|
||||||
|
|
||||||
def _is_file_database(db_path) -> bool:
|
def _is_file_database(db_path: str | Path) -> bool:
|
||||||
"""Check if database path refers to a file (not in-memory)."""
|
"""Check if database path refers to a file (not in-memory)."""
|
||||||
if isinstance(db_path, Path):
|
if isinstance(db_path, Path):
|
||||||
return True
|
return True
|
||||||
@@ -150,10 +156,11 @@ def get_db() -> sqlite3.Connection:
|
|||||||
# WAL mode set in init_db; these optimize per-connection behavior
|
# WAL mode set in init_db; these optimize per-connection behavior
|
||||||
g.db.execute("PRAGMA busy_timeout = 5000")
|
g.db.execute("PRAGMA busy_timeout = 5000")
|
||||||
g.db.execute("PRAGMA synchronous = NORMAL")
|
g.db.execute("PRAGMA synchronous = NORMAL")
|
||||||
return g.db
|
conn: sqlite3.Connection = g.db
|
||||||
|
return conn
|
||||||
|
|
||||||
|
|
||||||
def close_db(exception=None) -> None:
|
def close_db(exception: BaseException | None = None) -> None:
|
||||||
"""Close database connection at end of request."""
|
"""Close database connection at end of request."""
|
||||||
db = g.pop("db", None)
|
db = g.pop("db", None)
|
||||||
if db is not None:
|
if db is not None:
|
||||||
@@ -280,7 +287,7 @@ def check_content_hash(content_hash: str) -> tuple[bool, int]:
|
|||||||
return True, current_count
|
return True, current_count
|
||||||
|
|
||||||
|
|
||||||
def init_app(app) -> None:
|
def init_app(app: Flask) -> None:
|
||||||
"""Register database functions with Flask app."""
|
"""Register database functions with Flask app."""
|
||||||
app.teardown_appcontext(close_db)
|
app.teardown_appcontext(close_db)
|
||||||
|
|
||||||
|
|||||||
14
app/pki.py
14
app/pki.py
@@ -433,6 +433,7 @@ class PKI:
|
|||||||
if not self.has_ca():
|
if not self.has_ca():
|
||||||
raise CANotFoundError("No CA configured")
|
raise CANotFoundError("No CA configured")
|
||||||
|
|
||||||
|
assert self._ca_store is not None # narrowing for mypy
|
||||||
return {
|
return {
|
||||||
"id": self._ca_store["id"],
|
"id": self._ca_store["id"],
|
||||||
"common_name": self._ca_store["common_name"],
|
"common_name": self._ca_store["common_name"],
|
||||||
@@ -449,7 +450,9 @@ class PKI:
|
|||||||
"""
|
"""
|
||||||
if not self.has_ca():
|
if not self.has_ca():
|
||||||
raise CANotFoundError("No CA configured")
|
raise CANotFoundError("No CA configured")
|
||||||
return self._ca_store["certificate_pem"]
|
assert self._ca_store is not None # narrowing for mypy
|
||||||
|
cert_pem: str = self._ca_store["certificate_pem"]
|
||||||
|
return cert_pem
|
||||||
|
|
||||||
def _get_signing_key(self) -> tuple[Any, Any]:
|
def _get_signing_key(self) -> tuple[Any, Any]:
|
||||||
"""Get CA private key and certificate for signing.
|
"""Get CA private key and certificate for signing.
|
||||||
@@ -460,6 +463,8 @@ class PKI:
|
|||||||
if not self.has_ca():
|
if not self.has_ca():
|
||||||
raise CANotFoundError("No CA configured")
|
raise CANotFoundError("No CA configured")
|
||||||
|
|
||||||
|
assert self._ca_store is not None # narrowing for mypy
|
||||||
|
|
||||||
# Use cached key if available
|
# Use cached key if available
|
||||||
if "_private_key" in self._ca_store:
|
if "_private_key" in self._ca_store:
|
||||||
return self._ca_store["_private_key"], self._ca_store["_certificate"]
|
return self._ca_store["_private_key"], self._ca_store["_certificate"]
|
||||||
@@ -504,6 +509,7 @@ class PKI:
|
|||||||
days = self.cert_days
|
days = self.cert_days
|
||||||
|
|
||||||
ca_key, ca_cert = self._get_signing_key()
|
ca_key, ca_cert = self._get_signing_key()
|
||||||
|
assert self._ca_store is not None # narrowing for mypy (validated in _get_signing_key)
|
||||||
|
|
||||||
# Generate client key
|
# Generate client key
|
||||||
curves = {
|
curves = {
|
||||||
@@ -636,7 +642,8 @@ class PKI:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
# Check expiry
|
# Check expiry
|
||||||
return cert["expires_at"] >= int(time.time())
|
expires_at: int = cert["expires_at"]
|
||||||
|
return expires_at >= int(time.time())
|
||||||
|
|
||||||
def get_certificate(self, fingerprint: str) -> dict | None:
|
def get_certificate(self, fingerprint: str) -> dict | None:
|
||||||
"""Get certificate info by fingerprint.
|
"""Get certificate info by fingerprint.
|
||||||
@@ -728,7 +735,8 @@ def is_certificate_valid(fingerprint: str) -> bool:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
# Check expiry
|
# Check expiry
|
||||||
return row["expires_at"] >= int(time.time())
|
expires_at: int = row["expires_at"]
|
||||||
|
return expires_at >= int(time.time())
|
||||||
|
|
||||||
|
|
||||||
# ─────────────────────────────────────────────────────────────────────────────
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|||||||
Reference in New Issue
Block a user