CRYPTO-001: Certificate serial collision detection. Adds a _generate_unique_serial() helper for the database-backed PKI and a _generate_unique_serial() method on the in-memory PKI class; before a certificate is issued, existing serials are checked and a new random serial is drawn if a collision is detected (max 5 attempts). TIMING-001: Constant-time database lookups for sensitive queries. Adds a dummy PBKDF2 verification when a paste is not found, which prevents timing-based enumeration (attackers can't distinguish 'not found' from 'wrong password' by measuring response time).
1188 lines
34 KiB
Python
1188 lines
34 KiB
Python
"""Minimal PKI module for certificate authority and client certificate management.
|
|
|
|
This module provides both standalone PKI functionality and Flask-integrated
|
|
database-backed functions. Core cryptographic functions have no Flask dependencies.
|
|
|
|
Usage (standalone):
|
|
from app.pki import PKI
|
|
|
|
pki = PKI(password="your-ca-password")
|
|
ca_info = pki.generate_ca("My CA")
|
|
cert_info = pki.issue_certificate("client-name")
|
|
|
|
Usage (Flask routes):
|
|
from app.pki import generate_ca, issue_certificate, is_certificate_valid
|
|
|
|
ca_info = generate_ca("My CA", password) # Saves to database
|
|
if is_certificate_valid(fingerprint):
|
|
# Certificate is valid and not revoked
|
|
"""
|
|
|
|
import hashlib
|
|
import secrets
|
|
import time
|
|
from datetime import UTC, datetime, timedelta
|
|
from typing import Any
|
|
|
|
# Cryptography imports (required dependency)
|
|
try:
|
|
from cryptography import x509
|
|
from cryptography.hazmat.primitives import hashes, serialization
|
|
from cryptography.hazmat.primitives.asymmetric import ec
|
|
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
|
|
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
|
|
from cryptography.hazmat.primitives.serialization import pkcs12
|
|
from cryptography.x509.oid import ExtendedKeyUsageOID, NameOID
|
|
|
|
HAS_CRYPTO = True
|
|
except ImportError:
|
|
HAS_CRYPTO = False
|
|
|
|
|
|
# Constants
_KDF_ITERATIONS = 600_000  # PBKDF2-HMAC-SHA256 rounds (OWASP 2023 recommendation)
_SALT_LENGTH = 32  # bytes of random salt generated per encrypted key
_KEY_LENGTH = 32  # derived key size in bytes -> AES-256
_SERIAL_MAX_RETRIES = 5  # CRYPTO-001: attempts before giving up on a unique serial
|
|
|
|
|
|
class PKIError(Exception):
    """Root of the PKI exception hierarchy; catch this to handle any PKI failure."""
|
|
|
|
|
|
class CANotFoundError(PKIError):
    """Raised when an operation needs a certificate authority but none is configured."""
|
|
|
|
|
|
class CAExistsError(PKIError):
    """Raised when creating or importing a CA while one is already configured."""
|
|
|
|
|
|
class CertificateNotFoundError(PKIError):
    """Raised when no issued certificate matches the requested fingerprint or serial."""
|
|
|
|
|
|
class InvalidPasswordError(PKIError):
    """Raised when the CA private key cannot be decrypted with the supplied password."""
|
|
|
|
|
|
def _require_crypto() -> None:
    """Fail fast with an actionable message when `cryptography` is unavailable.

    Raises:
        PKIError: If the optional import of the cryptography package failed.
    """
    if HAS_CRYPTO:
        return
    raise PKIError("PKI requires 'cryptography' package: pip install cryptography")
|
|
|
|
|
|
def derive_key(password: str, salt: bytes) -> bytes:
    """Stretch a password into an AES-256 key with PBKDF2-HMAC-SHA256.

    Args:
        password: Secret text; encoded as UTF-8 before derivation.
        salt: Random per-key salt bytes.

    Returns:
        ``_KEY_LENGTH`` (32) bytes of key material, run through
        ``_KDF_ITERATIONS`` rounds.
    """
    _require_crypto()
    stretcher = PBKDF2HMAC(
        algorithm=hashes.SHA256(),
        length=_KEY_LENGTH,
        salt=salt,
        iterations=_KDF_ITERATIONS,
    )
    secret = password.encode("utf-8")
    return stretcher.derive(secret)
|
|
|
|
|
|
def encrypt_private_key(private_key: Any, password: str) -> tuple[bytes, bytes]:
    """Seal a private key under a password-derived AES-256-GCM key.

    The key is serialized as unencrypted PKCS#8 PEM, then encrypted with a key
    obtained from ``password`` and a fresh random salt via :func:`derive_key`.

    Args:
        private_key: Private key object exposing ``private_bytes`` (EC or RSA).
        password: Password that will be needed to decrypt the key.

    Returns:
        ``(blob, salt)``: ``blob`` is the 12-byte random nonce followed by the
        AES-GCM output; ``salt`` is the KDF salt to store alongside it.
    """
    _require_crypto()

    plaintext_pem = private_key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )

    kdf_salt = secrets.token_bytes(_SALT_LENGTH)
    cipher = AESGCM(derive_key(password, kdf_salt))

    nonce = secrets.token_bytes(12)
    sealed = cipher.encrypt(nonce, plaintext_pem, None)

    # Layout expected by decrypt_private_key: nonce || AES-GCM output.
    return nonce + sealed, kdf_salt
|
|
|
|
|
|
def decrypt_private_key(encrypted: bytes, salt: bytes, password: str) -> Any:
    """Recover a private key sealed by :func:`encrypt_private_key`.

    Args:
        encrypted: ``nonce || ciphertext`` blob. The first 12 bytes are the
            AES-GCM nonce; the rest is the AES-GCM output, which ends with a
            16-byte authentication tag.
        salt: KDF salt stored alongside the blob at encryption time.
        password: Password to derive the decryption key from.

    Returns:
        Loaded private key object.

    Raises:
        PKIError: If the blob is too short to hold a nonce plus a GCM tag.
        InvalidPasswordError: If GCM authentication fails (wrong password or
            altered/corrupted data).
    """
    _require_crypto()
    from cryptography.exceptions import InvalidTag

    nonce_len = 12
    tag_len = 16

    # Structural check before the expensive KDF. A blob shorter than
    # nonce + tag can never authenticate, so report it as malformed
    # rather than as a wrong password.
    if len(encrypted) < nonce_len + tag_len:
        raise PKIError("Invalid encrypted key format")

    aesgcm = AESGCM(derive_key(password, salt))
    nonce = encrypted[:nonce_len]
    ciphertext = encrypted[nonce_len:]

    try:
        key_pem = aesgcm.decrypt(nonce, ciphertext, None)
    except InvalidTag:
        # Authentication failure: wrong password or tampered ciphertext.
        raise InvalidPasswordError("Invalid CA password") from None

    return serialization.load_pem_private_key(key_pem, password=None)
|
|
|
|
|
|
def calculate_fingerprint(certificate: Any) -> str:
    """Return the SHA-1 fingerprint of a certificate's DER encoding.

    SHA-1 serves only as an identifier here (the conventional X.509
    fingerprint format), hence ``usedforsecurity=False``.

    Args:
        certificate: X.509 certificate object.

    Returns:
        40-character lowercase hex digest.
    """
    _require_crypto()
    der_bytes = certificate.public_bytes(serialization.Encoding.DER)
    digest = hashlib.sha1(der_bytes, usedforsecurity=False)
    return digest.hexdigest()
|
|
|
|
|
|
def create_pkcs12(
    private_key: Any,
    certificate: Any,
    ca_certificate: Any,
    friendly_name: str,
    password: bytes | None = None,
) -> bytes:
    """Bundle a client key, its certificate and the CA certificate as PKCS#12.

    Args:
        private_key: Client private key object
        certificate: Client certificate object
        ca_certificate: CA certificate object (added as the only chain cert)
        friendly_name: Friendly name for the certificate (UTF-8 encoded)
        password: Optional bundle password. ``None`` or ``b""`` produces an
            unencrypted bundle.

    Returns:
        PKCS#12 bytes (DER encoded)
    """
    _require_crypto()

    # One conditional is enough: any non-empty password gets the library's
    # strongest available encryption; otherwise the bundle is left unencrypted.
    encryption = (
        serialization.BestAvailableEncryption(password)
        if password
        else serialization.NoEncryption()
    )

    return pkcs12.serialize_key_and_certificates(
        name=friendly_name.encode("utf-8"),
        key=private_key,
        cert=certificate,
        cas=[ca_certificate],
        encryption_algorithm=encryption,
    )
|
|
|
|
|
|
class PKI:
|
|
"""Standalone PKI manager for CA and certificate operations.
|
|
|
|
This class provides core PKI functionality without Flask dependencies.
|
|
Use get_pki() for Flask-integrated usage.
|
|
|
|
Args:
|
|
password: CA password for signing operations
|
|
ca_days: CA certificate validity in days (default: 3650)
|
|
cert_days: Client certificate validity in days (default: 365)
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
password: str,
|
|
ca_days: int = 3650,
|
|
cert_days: int = 365,
|
|
):
|
|
_require_crypto()
|
|
self.password = password
|
|
self.ca_days = ca_days
|
|
self.cert_days = cert_days
|
|
|
|
# In-memory storage (override for database backing)
|
|
self._ca_store: dict | None = None
|
|
self._certificates: dict[str, dict] = {}
|
|
|
|
def has_ca(self) -> bool:
|
|
"""Check if CA exists."""
|
|
return self._ca_store is not None
|
|
|
|
def _generate_unique_serial(self) -> int:
|
|
"""Generate a unique certificate serial number.
|
|
|
|
CRYPTO-001: Checks existing certificates for collision.
|
|
|
|
Returns:
|
|
Unique serial number as integer
|
|
|
|
Raises:
|
|
PKIError: If unable to generate unique serial after max retries
|
|
"""
|
|
existing_serials = {
|
|
cert["serial"] for cert in self._certificates.values()
|
|
}
|
|
|
|
for _ in range(_SERIAL_MAX_RETRIES):
|
|
serial = x509.random_serial_number()
|
|
serial_hex = format(serial, "032x")
|
|
if serial_hex not in existing_serials:
|
|
return serial
|
|
|
|
raise PKIError("Failed to generate unique serial after max retries")
|
|
|
|
def generate_ca(
|
|
self,
|
|
common_name: str,
|
|
algorithm: str = "ec",
|
|
curve: str = "secp384r1",
|
|
) -> dict:
|
|
"""Generate a new Certificate Authority.
|
|
|
|
Args:
|
|
common_name: CA common name (e.g., "FlaskPaste CA")
|
|
algorithm: Key algorithm ("ec" only for now)
|
|
curve: EC curve name (secp256r1, secp384r1, secp521r1)
|
|
|
|
Returns:
|
|
Dict with CA info: id, common_name, fingerprint, expires_at, certificate_pem
|
|
|
|
Raises:
|
|
CAExistsError: If CA already exists
|
|
"""
|
|
if self.has_ca():
|
|
raise CAExistsError("CA already exists")
|
|
|
|
# Generate EC key
|
|
curves = {
|
|
"secp256r1": ec.SECP256R1(),
|
|
"secp384r1": ec.SECP384R1(),
|
|
"secp521r1": ec.SECP521R1(),
|
|
}
|
|
if curve not in curves:
|
|
raise PKIError(f"Unsupported curve: {curve}")
|
|
|
|
private_key = ec.generate_private_key(curves[curve])
|
|
|
|
# Build CA certificate
|
|
now = datetime.now(UTC)
|
|
subject = issuer = x509.Name(
|
|
[
|
|
x509.NameAttribute(NameOID.COMMON_NAME, common_name),
|
|
x509.NameAttribute(NameOID.ORGANIZATION_NAME, "FlaskPaste PKI"),
|
|
]
|
|
)
|
|
|
|
cert_builder = (
|
|
x509.CertificateBuilder()
|
|
.subject_name(subject)
|
|
.issuer_name(issuer)
|
|
.public_key(private_key.public_key())
|
|
.serial_number(x509.random_serial_number())
|
|
.not_valid_before(now)
|
|
.not_valid_after(now + timedelta(days=self.ca_days))
|
|
.add_extension(
|
|
x509.BasicConstraints(ca=True, path_length=0),
|
|
critical=True,
|
|
)
|
|
.add_extension(
|
|
x509.KeyUsage(
|
|
digital_signature=True,
|
|
key_cert_sign=True,
|
|
crl_sign=True,
|
|
key_encipherment=False,
|
|
content_commitment=False,
|
|
data_encipherment=False,
|
|
key_agreement=False,
|
|
encipher_only=False,
|
|
decipher_only=False,
|
|
),
|
|
critical=True,
|
|
)
|
|
)
|
|
|
|
certificate = cert_builder.sign(private_key, hashes.SHA256())
|
|
|
|
# Encrypt private key
|
|
encrypted_key, salt = encrypt_private_key(private_key, self.password)
|
|
|
|
# Calculate fingerprint
|
|
fingerprint = calculate_fingerprint(certificate)
|
|
|
|
# Store CA
|
|
expires_at = int((now + timedelta(days=self.ca_days)).timestamp())
|
|
cert_pem = certificate.public_bytes(serialization.Encoding.PEM).decode("utf-8")
|
|
|
|
self._ca_store = {
|
|
"id": "default",
|
|
"common_name": common_name,
|
|
"certificate_pem": cert_pem,
|
|
"private_key_encrypted": encrypted_key,
|
|
"key_salt": salt,
|
|
"created_at": int(now.timestamp()),
|
|
"expires_at": expires_at,
|
|
"key_algorithm": f"ec:{curve}",
|
|
"fingerprint": fingerprint,
|
|
"_private_key": private_key, # Cached for signing
|
|
"_certificate": certificate,
|
|
}
|
|
|
|
return {
|
|
"id": "default",
|
|
"common_name": common_name,
|
|
"fingerprint": fingerprint,
|
|
"expires_at": expires_at,
|
|
"certificate_pem": cert_pem,
|
|
}
|
|
|
|
def load_ca(self, cert_pem: str, key_pem: str, key_password: str | None = None) -> dict:
|
|
"""Load existing CA from PEM files.
|
|
|
|
Args:
|
|
cert_pem: CA certificate PEM string
|
|
key_pem: CA private key PEM string
|
|
key_password: Password for encrypted private key (if any)
|
|
|
|
Returns:
|
|
Dict with CA info
|
|
"""
|
|
if self.has_ca():
|
|
raise CAExistsError("CA already exists")
|
|
|
|
# Load certificate
|
|
certificate = x509.load_pem_x509_certificate(cert_pem.encode("utf-8"))
|
|
|
|
# Load private key
|
|
pwd = key_password.encode("utf-8") if key_password else None
|
|
private_key = serialization.load_pem_private_key(key_pem.encode("utf-8"), password=pwd)
|
|
|
|
# Extract info
|
|
common_name = certificate.subject.get_attributes_for_oid(NameOID.COMMON_NAME)[0].value
|
|
fingerprint = calculate_fingerprint(certificate)
|
|
|
|
# Re-encrypt with our password
|
|
encrypted_key, salt = encrypt_private_key(private_key, self.password)
|
|
|
|
now = datetime.now(UTC)
|
|
expires_at = int(certificate.not_valid_after_utc.timestamp())
|
|
|
|
self._ca_store = {
|
|
"id": "default",
|
|
"common_name": common_name,
|
|
"certificate_pem": cert_pem,
|
|
"private_key_encrypted": encrypted_key,
|
|
"key_salt": salt,
|
|
"created_at": int(now.timestamp()),
|
|
"expires_at": expires_at,
|
|
"key_algorithm": "imported",
|
|
"fingerprint": fingerprint,
|
|
"_private_key": private_key,
|
|
"_certificate": certificate,
|
|
}
|
|
|
|
return {
|
|
"id": "default",
|
|
"common_name": common_name,
|
|
"fingerprint": fingerprint,
|
|
"expires_at": expires_at,
|
|
"certificate_pem": cert_pem,
|
|
}
|
|
|
|
def get_ca(self) -> dict:
|
|
"""Get CA information.
|
|
|
|
Returns:
|
|
Dict with CA info (without private key)
|
|
|
|
Raises:
|
|
CANotFoundError: If no CA exists
|
|
"""
|
|
if not self.has_ca():
|
|
raise CANotFoundError("No CA configured")
|
|
|
|
assert self._ca_store is not None # narrowing for mypy
|
|
return {
|
|
"id": self._ca_store["id"],
|
|
"common_name": self._ca_store["common_name"],
|
|
"fingerprint": self._ca_store["fingerprint"],
|
|
"expires_at": self._ca_store["expires_at"],
|
|
"certificate_pem": self._ca_store["certificate_pem"],
|
|
}
|
|
|
|
def get_ca_certificate_pem(self) -> str:
|
|
"""Get CA certificate in PEM format.
|
|
|
|
Returns:
|
|
PEM-encoded CA certificate
|
|
"""
|
|
if not self.has_ca():
|
|
raise CANotFoundError("No CA configured")
|
|
assert self._ca_store is not None # narrowing for mypy
|
|
cert_pem: str = self._ca_store["certificate_pem"]
|
|
return cert_pem
|
|
|
|
def _get_signing_key(self) -> tuple[Any, Any]:
|
|
"""Get CA private key and certificate for signing.
|
|
|
|
Returns:
|
|
Tuple of (private_key, certificate)
|
|
"""
|
|
if not self.has_ca():
|
|
raise CANotFoundError("No CA configured")
|
|
|
|
assert self._ca_store is not None # narrowing for mypy
|
|
|
|
# Use cached key if available
|
|
if "_private_key" in self._ca_store:
|
|
return self._ca_store["_private_key"], self._ca_store["_certificate"]
|
|
|
|
# Decrypt private key
|
|
private_key = decrypt_private_key(
|
|
self._ca_store["private_key_encrypted"],
|
|
self._ca_store["key_salt"],
|
|
self.password,
|
|
)
|
|
|
|
# Load certificate
|
|
certificate = x509.load_pem_x509_certificate(
|
|
self._ca_store["certificate_pem"].encode("utf-8")
|
|
)
|
|
|
|
# Cache for future use
|
|
self._ca_store["_private_key"] = private_key
|
|
self._ca_store["_certificate"] = certificate
|
|
|
|
return private_key, certificate
|
|
|
|
def issue_certificate(
|
|
self,
|
|
common_name: str,
|
|
days: int | None = None,
|
|
algorithm: str = "ec",
|
|
curve: str = "secp384r1",
|
|
) -> dict:
|
|
"""Issue a new client certificate.
|
|
|
|
Args:
|
|
common_name: Client common name
|
|
days: Validity period (default: self.cert_days)
|
|
algorithm: Key algorithm ("ec")
|
|
curve: EC curve name
|
|
|
|
Returns:
|
|
Dict with certificate info including private key PEM
|
|
"""
|
|
if days is None:
|
|
days = self.cert_days
|
|
|
|
ca_key, ca_cert = self._get_signing_key()
|
|
assert self._ca_store is not None # narrowing for mypy (validated in _get_signing_key)
|
|
|
|
# Generate client key
|
|
curves = {
|
|
"secp256r1": ec.SECP256R1(),
|
|
"secp384r1": ec.SECP384R1(),
|
|
"secp521r1": ec.SECP521R1(),
|
|
}
|
|
if curve not in curves:
|
|
raise PKIError(f"Unsupported curve: {curve}")
|
|
|
|
client_key = ec.generate_private_key(curves[curve])
|
|
|
|
# Build certificate
|
|
now = datetime.now(UTC)
|
|
# CRYPTO-001: Use collision-safe serial generation
|
|
serial = self._generate_unique_serial()
|
|
|
|
subject = x509.Name(
|
|
[
|
|
x509.NameAttribute(NameOID.COMMON_NAME, common_name),
|
|
]
|
|
)
|
|
|
|
cert_builder = (
|
|
x509.CertificateBuilder()
|
|
.subject_name(subject)
|
|
.issuer_name(ca_cert.subject)
|
|
.public_key(client_key.public_key())
|
|
.serial_number(serial)
|
|
.not_valid_before(now)
|
|
.not_valid_after(now + timedelta(days=days))
|
|
.add_extension(
|
|
x509.BasicConstraints(ca=False, path_length=None),
|
|
critical=True,
|
|
)
|
|
.add_extension(
|
|
x509.KeyUsage(
|
|
digital_signature=True,
|
|
key_encipherment=True,
|
|
content_commitment=False,
|
|
data_encipherment=False,
|
|
key_agreement=False,
|
|
key_cert_sign=False,
|
|
crl_sign=False,
|
|
encipher_only=False,
|
|
decipher_only=False,
|
|
),
|
|
critical=True,
|
|
)
|
|
.add_extension(
|
|
x509.ExtendedKeyUsage([ExtendedKeyUsageOID.CLIENT_AUTH]),
|
|
critical=False,
|
|
)
|
|
)
|
|
|
|
certificate = cert_builder.sign(ca_key, hashes.SHA256())
|
|
|
|
# Serialize
|
|
cert_pem = certificate.public_bytes(serialization.Encoding.PEM).decode("utf-8")
|
|
key_pem = client_key.private_bytes(
|
|
encoding=serialization.Encoding.PEM,
|
|
format=serialization.PrivateFormat.PKCS8,
|
|
encryption_algorithm=serialization.NoEncryption(),
|
|
).decode("utf-8")
|
|
|
|
# Calculate fingerprint
|
|
fingerprint = calculate_fingerprint(certificate)
|
|
serial_hex = format(serial, "032x")
|
|
expires_at = int((now + timedelta(days=days)).timestamp())
|
|
|
|
# Store certificate record
|
|
cert_record = {
|
|
"serial": serial_hex,
|
|
"ca_id": "default",
|
|
"common_name": common_name,
|
|
"fingerprint_sha1": fingerprint,
|
|
"certificate_pem": cert_pem,
|
|
"created_at": int(now.timestamp()),
|
|
"expires_at": expires_at,
|
|
"status": "valid",
|
|
"revoked_at": None,
|
|
}
|
|
self._certificates[fingerprint] = cert_record
|
|
|
|
return {
|
|
"serial": serial_hex,
|
|
"common_name": common_name,
|
|
"fingerprint": fingerprint,
|
|
"expires_at": expires_at,
|
|
"certificate_pem": cert_pem,
|
|
"private_key_pem": key_pem,
|
|
"ca_certificate_pem": self._ca_store["certificate_pem"],
|
|
}
|
|
|
|
def revoke_certificate(self, fingerprint: str) -> bool:
|
|
"""Revoke a certificate by fingerprint.
|
|
|
|
Args:
|
|
fingerprint: SHA1 fingerprint (40 hex chars)
|
|
|
|
Returns:
|
|
True if revoked, False if already revoked
|
|
"""
|
|
fingerprint = fingerprint.lower()
|
|
if fingerprint not in self._certificates:
|
|
raise CertificateNotFoundError(f"Certificate not found: {fingerprint}")
|
|
|
|
cert = self._certificates[fingerprint]
|
|
if cert["status"] == "revoked":
|
|
return False
|
|
|
|
cert["status"] = "revoked"
|
|
cert["revoked_at"] = int(time.time())
|
|
return True
|
|
|
|
def is_valid(self, fingerprint: str) -> bool:
|
|
"""Check if fingerprint is valid (exists and not revoked).
|
|
|
|
Args:
|
|
fingerprint: SHA1 fingerprint
|
|
|
|
Returns:
|
|
True if valid, False otherwise
|
|
"""
|
|
fingerprint = fingerprint.lower()
|
|
if fingerprint not in self._certificates:
|
|
return True # Unknown fingerprints are allowed (external certs)
|
|
|
|
cert = self._certificates[fingerprint]
|
|
if cert["status"] != "valid":
|
|
return False
|
|
|
|
# Check expiry
|
|
expires_at: int = cert["expires_at"]
|
|
return expires_at >= int(time.time())
|
|
|
|
def get_certificate(self, fingerprint: str) -> dict | None:
|
|
"""Get certificate info by fingerprint.
|
|
|
|
Args:
|
|
fingerprint: SHA1 fingerprint
|
|
|
|
Returns:
|
|
Certificate info dict or None if not found
|
|
"""
|
|
fingerprint = fingerprint.lower()
|
|
cert = self._certificates.get(fingerprint)
|
|
if cert is None:
|
|
return None
|
|
|
|
return {
|
|
"serial": cert["serial"],
|
|
"common_name": cert["common_name"],
|
|
"fingerprint": cert["fingerprint_sha1"],
|
|
"expires_at": cert["expires_at"],
|
|
"status": cert["status"],
|
|
"created_at": cert["created_at"],
|
|
"revoked_at": cert["revoked_at"],
|
|
}
|
|
|
|
def list_certificates(self) -> list[dict]:
|
|
"""List all issued certificates.
|
|
|
|
Returns:
|
|
List of certificate info dicts
|
|
"""
|
|
return [
|
|
{
|
|
"serial": c["serial"],
|
|
"common_name": c["common_name"],
|
|
"fingerprint": c["fingerprint_sha1"],
|
|
"expires_at": c["expires_at"],
|
|
"status": c["status"],
|
|
"created_at": c["created_at"],
|
|
}
|
|
for c in self._certificates.values()
|
|
]
|
|
|
|
|
|
def reset_pki() -> None:
    """Historical hook for clearing cached PKI state; intentionally does nothing.

    Routes call the database-backed helpers directly, so there is no cached
    state to clear. The function is kept so existing tests keep working.
    """
    return None
|
|
|
|
|
|
def is_certificate_valid(fingerprint: str) -> bool:
    """Decide whether a client-certificate fingerprint may authenticate.

    This is the main integration point for routes.py. When ``PKI_ENABLED`` is
    falsy, every fingerprint is accepted. Otherwise the database is queried
    on every call, so revocations take effect immediately. Fingerprints not
    present in ``issued_certificates`` (external certs) are accepted.

    Args:
        fingerprint: SHA1 fingerprint (case-insensitive)

    Returns:
        True if valid or unknown, False if revoked/expired
    """
    from flask import current_app

    if not current_app.config.get("PKI_ENABLED"):
        return True

    from app.database import get_db

    normalized = fingerprint.lower()
    connection = get_db()

    record = connection.execute(
        "SELECT status, expires_at FROM issued_certificates WHERE fingerprint_sha1 = ?",
        (normalized,),
    ).fetchone()

    # Unknown fingerprint (external cert or not issued by us) - allow
    if record is None:
        return True

    # Known certificate: it must be unrevoked and not past its expiry.
    return record["status"] == "valid" and record["expires_at"] >= int(time.time())
|
|
|
|
|
|
# ─────────────────────────────────────────────────────────────────────────────
|
|
# Flask Route Helper Functions
|
|
# ─────────────────────────────────────────────────────────────────────────────
|
|
|
|
|
|
def get_ca_info(skip_enabled_check: bool = False) -> dict | None:
    """Get CA information for status endpoint.

    Args:
        skip_enabled_check: If True, skip the PKI_ENABLED check (for registration)

    Returns:
        Dict with id, common_name, certificate_pem, fingerprint_sha1,
        created_at, expires_at and key_algorithm. Returns None if PKI is
        disabled (and the check is not skipped) or no ``'default'`` CA row
        exists.

    Raises:
        PKIError: If a CA row exists but the ``cryptography`` package is missing.
    """
    from flask import current_app

    from app.database import get_db

    if not skip_enabled_check and not current_app.config.get("PKI_ENABLED"):
        return None

    db = get_db()
    row = db.execute(
        """SELECT id, common_name, certificate_pem, created_at, expires_at, key_algorithm
        FROM certificate_authority WHERE id = 'default'"""
    ).fetchone()

    if row is None:
        return None

    # Parsing the certificate needs `cryptography`. Without it, fail with a
    # clear PKIError instead of a NameError on the unbound `x509` name.
    _require_crypto()
    cert = x509.load_pem_x509_certificate(row["certificate_pem"].encode("utf-8"))
    fingerprint = calculate_fingerprint(cert)

    return {
        "id": row["id"],
        "common_name": row["common_name"],
        "certificate_pem": row["certificate_pem"],
        "fingerprint_sha1": fingerprint,
        "created_at": row["created_at"],
        "expires_at": row["expires_at"],
        "key_algorithm": row["key_algorithm"],
    }
|
|
|
|
|
|
def _generate_unique_serial(db: Any) -> int:
    """Draw a random certificate serial not yet present in ``issued_certificates``.

    CRYPTO-001: each candidate is looked up by its zero-padded lowercase hex
    form (the format stored by issue_certificate) before being returned.

    Args:
        db: Database connection exposing ``execute(...).fetchone()``

    Returns:
        Unique serial number as integer

    Raises:
        PKIError: If every one of ``_SERIAL_MAX_RETRIES`` candidates collided
    """
    _require_crypto()

    attempts = 0
    while attempts < _SERIAL_MAX_RETRIES:
        attempts += 1
        candidate = x509.random_serial_number()
        clash = db.execute(
            "SELECT 1 FROM issued_certificates WHERE serial = ?",
            (format(candidate, "032x"),),
        ).fetchone()
        if clash is None:
            return candidate

    raise PKIError("Failed to generate unique serial after max retries")
|
|
|
|
|
|
def generate_ca(
    common_name: str,
    password: str,
    days: int = 3650,
    owner: str | None = None,
) -> dict:
    """Create the ``'default'`` CA and persist it to the database.

    A fresh SECP384R1 key is generated and a self-signed CA certificate
    (``BasicConstraints(ca=True, path_length=0)``) valid for ``days`` days is
    issued. The private key is stored encrypted under ``password``.

    Args:
        common_name: Subject/issuer common name of the CA
        password: Password used to encrypt the CA private key at rest
        days: Validity period in days
        owner: Optional owner fingerprint stored with the CA row

    Returns:
        Dict with id, common_name, fingerprint_sha1, certificate_pem,
        created_at and expires_at (Unix seconds)

    Raises:
        CAExistsError: If a ``'default'`` CA row already exists
    """
    _require_crypto()
    from app.database import get_db

    db = get_db()

    # Only one CA (id 'default') is supported.
    if db.execute("SELECT id FROM certificate_authority WHERE id = 'default'").fetchone():
        raise CAExistsError("CA already exists")

    signing_key = ec.generate_private_key(ec.SECP384R1())

    issued = datetime.now(UTC)
    expiry = issued + timedelta(days=days)
    ca_name = x509.Name(
        [
            x509.NameAttribute(NameOID.COMMON_NAME, common_name),
            x509.NameAttribute(NameOID.ORGANIZATION_NAME, "FlaskPaste PKI"),
        ]
    )
    ca_key_usage = x509.KeyUsage(
        digital_signature=True,
        content_commitment=False,
        key_encipherment=False,
        data_encipherment=False,
        key_agreement=False,
        key_cert_sign=True,
        crl_sign=True,
        encipher_only=False,
        decipher_only=False,
    )

    builder = x509.CertificateBuilder()
    builder = builder.subject_name(ca_name)
    builder = builder.issuer_name(ca_name)  # self-signed: issuer == subject
    builder = builder.public_key(signing_key.public_key())
    builder = builder.serial_number(x509.random_serial_number())
    builder = builder.not_valid_before(issued)
    builder = builder.not_valid_after(expiry)
    builder = builder.add_extension(x509.BasicConstraints(ca=True, path_length=0), critical=True)
    builder = builder.add_extension(ca_key_usage, critical=True)
    ca_cert = builder.sign(signing_key, hashes.SHA256())

    key_blob, key_salt = encrypt_private_key(signing_key, password)
    fingerprint = calculate_fingerprint(ca_cert)

    cert_pem = ca_cert.public_bytes(serialization.Encoding.PEM).decode("utf-8")
    created_at = int(issued.timestamp())
    expires_at = int(expiry.timestamp())

    db.execute(
        """INSERT INTO certificate_authority
        (id, common_name, certificate_pem, private_key_encrypted,
        key_salt, created_at, expires_at, key_algorithm, owner)
        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""",
        (
            "default",
            common_name,
            cert_pem,
            key_blob,
            key_salt,
            created_at,
            expires_at,
            "ec:secp384r1",
            owner,
        ),
    )
    db.commit()

    # Currently a no-op (see reset_pki); kept for compatibility.
    reset_pki()

    return {
        "id": "default",
        "common_name": common_name,
        "fingerprint_sha1": fingerprint,
        "certificate_pem": cert_pem,
        "created_at": created_at,
        "expires_at": expires_at,
    }
|
|
|
|
|
|
def issue_certificate(
    common_name: str,
    password: str,
    days: int = 365,
    issued_to: str | None = None,
    is_admin: bool | None = None,
) -> dict:
    """Sign a new client certificate with the stored CA and record it.

    A fresh SECP384R1 client key is generated and a CLIENT_AUTH certificate
    valid for ``days`` days is signed by the ``'default'`` CA. The
    certificate (not the client private key) is inserted into
    ``issued_certificates`` with status ``'valid'``.

    Args:
        common_name: Subject common name of the client certificate
        password: CA password used to decrypt the CA signing key
        days: Validity period in days
        issued_to: Optional fingerprint of the issuing user
        is_admin: Admin flag. ``None`` means "admin only if no certificate with
            status 'valid' exists yet" (first registered user becomes admin).

    Returns:
        Dict with serial, common_name, fingerprint_sha1, certificate_pem,
        private_key_pem (unencrypted PKCS#8), created_at, expires_at, is_admin

    Raises:
        CANotFoundError: If no CA exists
        InvalidPasswordError: If ``password`` cannot decrypt the CA key
    """
    _require_crypto()
    from app.database import get_db

    db = get_db()

    if is_admin is None:
        valid_count = db.execute(
            "SELECT COUNT(*) FROM issued_certificates WHERE status = 'valid'"
        ).fetchone()[0]
        is_admin = valid_count == 0

    authority = db.execute(
        """SELECT certificate_pem, private_key_encrypted, key_salt
        FROM certificate_authority WHERE id = 'default'"""
    ).fetchone()
    if authority is None:
        raise CANotFoundError("No CA configured")

    signer_key = decrypt_private_key(
        authority["private_key_encrypted"],
        authority["key_salt"],
        password,
    )
    signer_cert = x509.load_pem_x509_certificate(authority["certificate_pem"].encode("utf-8"))

    holder_key = ec.generate_private_key(ec.SECP384R1())

    issued = datetime.now(UTC)
    expiry = issued + timedelta(days=days)
    # CRYPTO-001: serial is checked against issued_certificates before use
    serial = _generate_unique_serial(db)

    holder_usage = x509.KeyUsage(
        digital_signature=True,
        content_commitment=False,
        key_encipherment=True,
        data_encipherment=False,
        key_agreement=False,
        key_cert_sign=False,
        crl_sign=False,
        encipher_only=False,
        decipher_only=False,
    )

    builder = x509.CertificateBuilder()
    builder = builder.subject_name(x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, common_name)]))
    builder = builder.issuer_name(signer_cert.subject)
    builder = builder.public_key(holder_key.public_key())
    builder = builder.serial_number(serial)
    builder = builder.not_valid_before(issued)
    builder = builder.not_valid_after(expiry)
    builder = builder.add_extension(x509.BasicConstraints(ca=False, path_length=None), critical=True)
    builder = builder.add_extension(holder_usage, critical=True)
    builder = builder.add_extension(
        x509.ExtendedKeyUsage([ExtendedKeyUsageOID.CLIENT_AUTH]),
        critical=False,
    )
    client_cert = builder.sign(signer_key, hashes.SHA256())

    cert_pem = client_cert.public_bytes(serialization.Encoding.PEM).decode("utf-8")
    key_pem = holder_key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    ).decode("utf-8")

    fingerprint = calculate_fingerprint(client_cert)
    serial_hex = format(serial, "032x")
    created_at = int(issued.timestamp())
    expires_at = int(expiry.timestamp())

    db.execute(
        """INSERT INTO issued_certificates
        (serial, ca_id, common_name, fingerprint_sha1, certificate_pem,
        created_at, expires_at, issued_to, status, revoked_at, is_admin)
        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
        (
            serial_hex,
            "default",
            common_name,
            fingerprint,
            cert_pem,
            created_at,
            expires_at,
            issued_to,
            "valid",
            None,
            1 if is_admin else 0,
        ),
    )
    db.commit()

    return {
        "serial": serial_hex,
        "common_name": common_name,
        "fingerprint_sha1": fingerprint,
        "certificate_pem": cert_pem,
        "private_key_pem": key_pem,
        "created_at": created_at,
        "expires_at": expires_at,
        "is_admin": is_admin,
    }
|
|
|
|
|
|
def is_admin_certificate(fingerprint: str) -> bool:
    """Check if a certificate fingerprint belongs to an admin.

    Only certificates that are unrevoked (status ``'valid'``) and not expired
    grant admin rights.

    Args:
        fingerprint: SHA1 fingerprint of the certificate (case-insensitive)

    Returns:
        True if the certificate holder is an admin
    """
    from app.database import get_db

    db = get_db()
    # Fingerprints are stored lowercase (hexdigest). Normalize the same way
    # is_certificate_valid() does, and apply its ">=" expiry rule.
    row = db.execute(
        """SELECT is_admin FROM issued_certificates
        WHERE fingerprint_sha1 = ? AND status = 'valid' AND expires_at >= ?""",
        (fingerprint.lower(), int(time.time())),
    ).fetchone()

    return bool(row and row["is_admin"])
|
|
|
|
|
|
def is_trusted_certificate(fingerprint: str) -> bool:
    """Check if a certificate is trusted (registered in PKI system).

    Trusted certificates are those issued by our PKI system that are still
    valid: status ``'valid'`` (not revoked) and not yet expired.
    External certificates (valid for auth but not issued by us) are not trusted.

    Args:
        fingerprint: SHA1 fingerprint of the certificate (case-insensitive)

    Returns:
        True if the certificate is registered, unrevoked and unexpired in our PKI
    """
    from app.database import get_db

    db = get_db()
    # Fingerprints are stored lowercase (hexdigest). Normalize the same way
    # is_certificate_valid() does, and apply its ">=" expiry rule.
    row = db.execute(
        """SELECT status FROM issued_certificates
        WHERE fingerprint_sha1 = ? AND status = 'valid' AND expires_at >= ?""",
        (fingerprint.lower(), int(time.time())),
    ).fetchone()

    return row is not None
|
|
|
|
|
|
def revoke_certificate(serial: str) -> bool:
    """Revoke a certificate by serial number.

    Args:
        serial: Certificate serial number (hex, case-insensitive). Serials are
            stored as lowercase hex, while tools such as openssl print them in
            uppercase.

    Returns:
        True if the certificate was revoked by this call, False if it was
        already revoked

    Raises:
        CertificateNotFoundError: If certificate not found
    """
    from app.database import get_db

    serial = serial.lower()
    db = get_db()

    row = db.execute(
        "SELECT status FROM issued_certificates WHERE serial = ?", (serial,)
    ).fetchone()

    if row is None:
        raise CertificateNotFoundError(f"Certificate not found: {serial}")

    if row["status"] == "revoked":
        return False  # Already revoked

    db.execute(
        "UPDATE issued_certificates SET status = 'revoked', revoked_at = ? WHERE serial = ?",
        (int(time.time()), serial),
    )
    db.commit()

    return True
|