forked from username/flaskpaste
1194 lines
35 KiB
Python
1194 lines
35 KiB
Python
"""Minimal PKI module for certificate authority and client certificate management.

This module provides both standalone PKI functionality and Flask-integrated
database-backed functions. Core cryptographic functions have no Flask dependencies.

Usage (standalone, in-memory state only):
    from app.pki import PKI

    pki = PKI(password="your-ca-password")
    ca_info = pki.generate_ca("My CA")
    cert_info = pki.issue_certificate("client-name")

Usage (Flask routes, database-backed):
    from app.pki import generate_ca, issue_certificate, is_certificate_valid

    ca_info = generate_ca("My CA", password)  # Saves to database
    if is_certificate_valid(fingerprint):
        # Certificate is valid and not revoked. Fingerprints unknown to this
        # PKI (external certs) are also accepted, as is everything when
        # PKI_ENABLED is off.
        ...
"""
|
|
|
|
import hashlib
|
|
import secrets
|
|
import time
|
|
from datetime import UTC, datetime, timedelta
|
|
from typing import Any
|
|
|
|
# Cryptography imports (required dependency)
|
|
try:
|
|
from cryptography import x509
|
|
from cryptography.hazmat.primitives import hashes, serialization
|
|
from cryptography.hazmat.primitives.asymmetric import ec
|
|
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
|
|
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
|
|
from cryptography.hazmat.primitives.serialization import pkcs12
|
|
from cryptography.x509.oid import ExtendedKeyUsageOID, NameOID
|
|
|
|
HAS_CRYPTO = True
|
|
except ImportError:
|
|
HAS_CRYPTO = False
|
|
|
|
|
|
# Constants
# PBKDF2-HMAC-SHA256 iteration count for CA-key encryption (OWASP 2023 recommendation)
_KDF_ITERATIONS = 600_000
# Bytes of random salt generated per encrypted private key
_SALT_LENGTH = 32
# Derived key length in bytes (AES-256)
_KEY_LENGTH = 32
# CRYPTO-001: Max attempts for unique serial generation before giving up
_SERIAL_MAX_RETRIES = 5
|
|
|
|
|
|
class PKIError(Exception):
    """Root of every exception raised by this PKI module."""


class CANotFoundError(PKIError):
    """Raised when an operation needs a CA but none has been configured."""


class CAExistsError(PKIError):
    """Raised when creating or loading a CA while one is already present."""


class CertificateNotFoundError(PKIError):
    """Raised when no issued certificate matches the given identifier."""


class InvalidPasswordError(PKIError):
    """Raised when the supplied CA password cannot decrypt the CA private key."""
|
|
|
|
|
|
def _require_crypto() -> None:
    """Fail fast with PKIError when the optional 'cryptography' import failed.

    Every public crypto entry point calls this first, so a missing dependency
    surfaces as a PKIError with an install hint rather than a NameError.
    """
    if HAS_CRYPTO:
        return
    raise PKIError("PKI requires 'cryptography' package: pip install cryptography")
|
|
|
|
|
|
def derive_key(password: str, salt: bytes) -> bytes:
    """Stretch a password into an AES-256 key using PBKDF2-HMAC-SHA256.

    Args:
        password: Password string (UTF-8 encoded before hashing)
        salt: Random salt bytes

    Returns:
        ``_KEY_LENGTH`` (32) derived bytes, suitable as an AES-256 key

    Raises:
        PKIError: If the cryptography package is unavailable
    """
    _require_crypto()
    pbkdf2 = PBKDF2HMAC(
        algorithm=hashes.SHA256(),
        length=_KEY_LENGTH,
        salt=salt,
        iterations=_KDF_ITERATIONS,
    )
    password_bytes = password.encode("utf-8")
    return pbkdf2.derive(password_bytes)
|
|
|
|
|
|
def encrypt_private_key(private_key: Any, password: str) -> tuple[bytes, bytes]:
    """Seal a private key under a password-derived AES-256-GCM key.

    The key is serialized as unencrypted PKCS#8 PEM and then encrypted with a
    key derived (via derive_key) from ``password`` and a fresh random salt.

    Args:
        private_key: EC or RSA private key object
        password: Encryption password

    Returns:
        Tuple of (blob, salt). ``blob`` is the 12-byte GCM nonce followed by
        the ciphertext, which is the layout decrypt_private_key expects.
    """
    _require_crypto()

    pkcs8_pem = private_key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )

    salt = secrets.token_bytes(_SALT_LENGTH)
    cipher = AESGCM(derive_key(password, salt))
    nonce = secrets.token_bytes(12)  # fresh random 96-bit GCM nonce

    return nonce + cipher.encrypt(nonce, pkcs8_pem, None), salt
|
|
|
|
|
|
def decrypt_private_key(encrypted: bytes, salt: bytes, password: str) -> Any:
    """Decrypt a private key produced by encrypt_private_key.

    Args:
        encrypted: Encrypted key bytes (12-byte nonce + AES-GCM ciphertext/tag)
        salt: Salt used for key derivation
        password: Decryption password

    Returns:
        Private key object

    Raises:
        PKIError: If ``encrypted`` is too short to be a valid nonce+ciphertext
            blob (checked before running the expensive KDF)
        InvalidPasswordError: If GCM authentication fails, i.e. the password
            is wrong or the blob was tampered with
    """
    _require_crypto()
    from cryptography.exceptions import InvalidTag

    nonce_len = 12  # nonce prepended by encrypt_private_key
    tag_len = 16  # AESGCM appends a 16-byte authentication tag to the ciphertext

    # A blob shorter than nonce + tag can never authenticate; report it as a
    # format problem instead of misreporting it as a bad password.
    if len(encrypted) < nonce_len + tag_len:
        raise PKIError("Invalid encrypted key format")

    nonce = encrypted[:nonce_len]
    ciphertext = encrypted[nonce_len:]
    aesgcm = AESGCM(derive_key(password, salt))

    try:
        key_pem = aesgcm.decrypt(nonce, ciphertext, None)
    except InvalidTag:
        # Only an authentication failure means "wrong password"; other errors
        # are genuine bugs and should propagate unchanged.
        raise InvalidPasswordError("Invalid CA password") from None

    return serialization.load_pem_private_key(key_pem, password=None)
|
|
|
|
|
|
def calculate_fingerprint(certificate: Any) -> str:
    """Return the SHA1 fingerprint of a certificate's DER encoding.

    SHA1 is the customary X.509 fingerprint algorithm. The fingerprint is used
    here as an identifier, not as a security primitive, so the hash is created
    with ``usedforsecurity=False``.

    Args:
        certificate: X.509 certificate object

    Returns:
        Lowercase hex fingerprint (40 characters)
    """
    _require_crypto()
    digest = hashlib.sha1(usedforsecurity=False)
    digest.update(certificate.public_bytes(serialization.Encoding.DER))
    return digest.hexdigest()
|
|
|
|
|
|
def create_pkcs12(
    private_key: Any,
    certificate: Any,
    ca_certificate: Any,
    friendly_name: str,
    password: bytes | None = None,
) -> bytes:
    """Create a PKCS#12 bundle containing certificate, private key, and CA cert.

    Args:
        private_key: Client private key object
        certificate: Client certificate object
        ca_certificate: CA certificate object to include in the bundle; None
            omits the CA chain
        friendly_name: Friendly name for the certificate (UTF-8 encoded)
        password: Optional password for PKCS#12 encryption; None or b""
            produces an unencrypted bundle

    Returns:
        PKCS#12 bytes (DER encoded)
    """
    _require_crypto()

    # Strong encryption when a non-empty password is supplied, otherwise none.
    encryption = (
        serialization.BestAvailableEncryption(password)
        if password
        else serialization.NoEncryption()
    )

    return pkcs12.serialize_key_and_certificates(
        name=friendly_name.encode("utf-8"),
        key=private_key,
        cert=certificate,
        cas=None if ca_certificate is None else [ca_certificate],
        encryption_algorithm=encryption,
    )
|
|
|
|
|
|
class PKI:
    """Standalone, in-memory PKI manager for CA and certificate operations.

    This class provides core PKI functionality without Flask dependencies. The
    CA and all issued-certificate records live only in instance attributes.
    For database-backed usage, call the module-level functions instead
    (``generate_ca``, ``issue_certificate``, ``is_certificate_valid``, ...).

    Args:
        password: CA password; encrypts the CA private key at rest
        ca_days: CA certificate validity in days (default: 3650)
        cert_days: Client certificate validity in days (default: 365)

    Raises:
        PKIError: If the cryptography package is not installed
    """

    def __init__(
        self,
        password: str,
        ca_days: int = 3650,
        cert_days: int = 365,
    ):
        _require_crypto()
        self.password = password
        self.ca_days = ca_days
        self.cert_days = cert_days

        # In-memory storage (override for database backing)
        self._ca_store: dict | None = None
        # Issued-certificate records keyed by lowercase SHA1 fingerprint
        self._certificates: dict[str, dict] = {}

    # ── Internal helpers ─────────────────────────────────────────────────

    @staticmethod
    def _check_algorithm(algorithm: str) -> None:
        """Raise PKIError unless ``algorithm`` is "ec" (the only one implemented)."""
        if algorithm != "ec":
            raise PKIError(f"Unsupported algorithm: {algorithm} (only 'ec' supported)")

    @staticmethod
    def _resolve_curve(curve: str) -> Any:
        """Return a fresh EC curve object for the named curve, or raise PKIError."""
        factories = {
            "secp256r1": ec.SECP256R1,
            "secp384r1": ec.SECP384R1,
            "secp521r1": ec.SECP521R1,
        }
        factory = factories.get(curve)
        if factory is None:
            raise PKIError(f"Unsupported curve: {curve}")
        return factory()

    @staticmethod
    def _spki_der(public_key: Any) -> bytes:
        """DER-encode a public key as SubjectPublicKeyInfo (for equality checks)."""
        return public_key.public_bytes(
            encoding=serialization.Encoding.DER,
            format=serialization.PublicFormat.SubjectPublicKeyInfo,
        )

    def _require_ca_store(self) -> dict:
        """Return the CA record, raising CANotFoundError if no CA is configured."""
        if self._ca_store is None:
            raise CANotFoundError("No CA configured")
        return self._ca_store

    def _public_ca_info(self) -> dict:
        """Return the public (non-secret) subset of the CA record."""
        store = self._require_ca_store()
        return {
            "id": store["id"],
            "common_name": store["common_name"],
            "fingerprint": store["fingerprint"],
            "expires_at": store["expires_at"],
            "certificate_pem": store["certificate_pem"],
        }

    # ── CA management ────────────────────────────────────────────────────

    def has_ca(self) -> bool:
        """Check if CA exists."""
        return self._ca_store is not None

    def _generate_unique_serial(self) -> int:
        """Generate a unique certificate serial number.

        CRYPTO-001: Checks existing certificates for collision.

        Returns:
            Unique serial number as integer

        Raises:
            PKIError: If unable to generate unique serial after max retries
        """
        existing_serials = {cert["serial"] for cert in self._certificates.values()}

        for _ in range(_SERIAL_MAX_RETRIES):
            serial = x509.random_serial_number()
            if format(serial, "032x") not in existing_serials:
                return serial

        raise PKIError("Failed to generate unique serial after max retries")

    def generate_ca(
        self,
        common_name: str,
        algorithm: str = "ec",
        curve: str = "secp384r1",
    ) -> dict:
        """Generate a new self-signed Certificate Authority.

        Args:
            common_name: CA common name (e.g., "FlaskPaste CA")
            algorithm: Key algorithm ("ec" only for now)
            curve: EC curve name (secp256r1, secp384r1, secp521r1)

        Returns:
            Dict with CA info: id, common_name, fingerprint, expires_at, certificate_pem

        Raises:
            CAExistsError: If CA already exists
            PKIError: If the algorithm or curve is unsupported
        """
        if self.has_ca():
            raise CAExistsError("CA already exists")

        self._check_algorithm(algorithm)
        private_key = ec.generate_private_key(self._resolve_curve(curve))

        now = datetime.now(UTC)
        not_after = now + timedelta(days=self.ca_days)
        subject = issuer = x509.Name(
            [
                x509.NameAttribute(NameOID.COMMON_NAME, common_name),
                x509.NameAttribute(NameOID.ORGANIZATION_NAME, "FlaskPaste PKI"),
            ]
        )

        certificate = (
            x509.CertificateBuilder()
            .subject_name(subject)
            .issuer_name(issuer)
            .public_key(private_key.public_key())
            .serial_number(x509.random_serial_number())
            .not_valid_before(now)
            .not_valid_after(not_after)
            # path_length=0: this CA may only sign end-entity certificates
            .add_extension(
                x509.BasicConstraints(ca=True, path_length=0),
                critical=True,
            )
            .add_extension(
                x509.KeyUsage(
                    digital_signature=True,
                    key_cert_sign=True,
                    crl_sign=True,
                    key_encipherment=False,
                    content_commitment=False,
                    data_encipherment=False,
                    key_agreement=False,
                    encipher_only=False,
                    decipher_only=False,
                ),
                critical=True,
            )
            .sign(private_key, hashes.SHA256())
        )

        encrypted_key, salt = encrypt_private_key(private_key, self.password)

        self._ca_store = {
            "id": "default",
            "common_name": common_name,
            "certificate_pem": certificate.public_bytes(serialization.Encoding.PEM).decode("utf-8"),
            "private_key_encrypted": encrypted_key,
            "key_salt": salt,
            "created_at": int(now.timestamp()),
            "expires_at": int(not_after.timestamp()),
            "key_algorithm": f"ec:{curve}",
            "fingerprint": calculate_fingerprint(certificate),
            "_private_key": private_key,  # Cached for signing
            "_certificate": certificate,
        }
        return self._public_ca_info()

    def load_ca(self, cert_pem: str, key_pem: str, key_password: str | None = None) -> dict:
        """Load an existing CA from PEM data.

        Args:
            cert_pem: CA certificate PEM string
            key_pem: CA private key PEM string
            key_password: Password for encrypted private key (if any)

        Returns:
            Dict with CA info

        Raises:
            CAExistsError: If CA already exists
            PKIError: If the private key does not match the certificate, or
                the certificate has no common name
        """
        if self.has_ca():
            raise CAExistsError("CA already exists")

        certificate = x509.load_pem_x509_certificate(cert_pem.encode("utf-8"))

        pwd = key_password.encode("utf-8") if key_password else None
        private_key = serialization.load_pem_private_key(key_pem.encode("utf-8"), password=pwd)

        # A mismatched key would sign certificates that never verify against
        # this CA certificate, so reject the pair up front.
        if self._spki_der(private_key.public_key()) != self._spki_der(certificate.public_key()):
            raise PKIError("CA private key does not match CA certificate")

        cn_attributes = certificate.subject.get_attributes_for_oid(NameOID.COMMON_NAME)
        if not cn_attributes:
            raise PKIError("CA certificate has no common name")
        common_name = cn_attributes[0].value

        # Re-encrypt with our password
        encrypted_key, salt = encrypt_private_key(private_key, self.password)
        now = datetime.now(UTC)

        self._ca_store = {
            "id": "default",
            "common_name": common_name,
            "certificate_pem": cert_pem,
            "private_key_encrypted": encrypted_key,
            "key_salt": salt,
            "created_at": int(now.timestamp()),
            "expires_at": int(certificate.not_valid_after_utc.timestamp()),
            "key_algorithm": "imported",
            "fingerprint": calculate_fingerprint(certificate),
            "_private_key": private_key,
            "_certificate": certificate,
        }
        return self._public_ca_info()

    def get_ca(self) -> dict:
        """Get CA information.

        Returns:
            Dict with CA info (without private key)

        Raises:
            CANotFoundError: If no CA exists
        """
        return self._public_ca_info()

    def get_ca_certificate_pem(self) -> str:
        """Get CA certificate in PEM format.

        Returns:
            PEM-encoded CA certificate

        Raises:
            CANotFoundError: If no CA exists
        """
        cert_pem: str = self._require_ca_store()["certificate_pem"]
        return cert_pem

    def _get_signing_key(self) -> tuple[Any, Any]:
        """Get CA private key and certificate for signing.

        Uses the cached objects when present; otherwise decrypts the stored
        key with ``self.password`` and caches the result.

        Returns:
            Tuple of (private_key, certificate)

        Raises:
            CANotFoundError: If no CA exists
            InvalidPasswordError: If the stored key cannot be decrypted
        """
        store = self._require_ca_store()

        if "_private_key" in store:
            return store["_private_key"], store["_certificate"]

        private_key = decrypt_private_key(
            store["private_key_encrypted"],
            store["key_salt"],
            self.password,
        )
        certificate = x509.load_pem_x509_certificate(store["certificate_pem"].encode("utf-8"))

        # Cache for future use
        store["_private_key"] = private_key
        store["_certificate"] = certificate
        return private_key, certificate

    # ── Certificate lifecycle ────────────────────────────────────────────

    def issue_certificate(
        self,
        common_name: str,
        days: int | None = None,
        algorithm: str = "ec",
        curve: str = "secp384r1",
    ) -> dict:
        """Issue a new client (TLS client-auth) certificate signed by the CA.

        Args:
            common_name: Client common name
            days: Validity period (default: self.cert_days)
            algorithm: Key algorithm ("ec")
            curve: EC curve name

        Returns:
            Dict with certificate info including private key PEM

        Raises:
            PKIError: If the algorithm/curve is unsupported or no unique serial
                can be generated
            CANotFoundError: If no CA exists
        """
        if days is None:
            days = self.cert_days

        self._check_algorithm(algorithm)
        ca_key, ca_cert = self._get_signing_key()
        store = self._require_ca_store()

        client_key = ec.generate_private_key(self._resolve_curve(curve))

        now = datetime.now(UTC)
        not_after = now + timedelta(days=days)
        # CRYPTO-001: Use collision-safe serial generation
        serial = self._generate_unique_serial()

        certificate = (
            x509.CertificateBuilder()
            .subject_name(x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, common_name)]))
            .issuer_name(ca_cert.subject)
            .public_key(client_key.public_key())
            .serial_number(serial)
            .not_valid_before(now)
            .not_valid_after(not_after)
            .add_extension(
                x509.BasicConstraints(ca=False, path_length=None),
                critical=True,
            )
            .add_extension(
                # RFC 5480 §3: keyEncipherment/dataEncipherment MUST NOT be
                # asserted for EC keys; ECDSA client auth needs digitalSignature.
                x509.KeyUsage(
                    digital_signature=True,
                    key_encipherment=False,
                    content_commitment=False,
                    data_encipherment=False,
                    key_agreement=False,
                    key_cert_sign=False,
                    crl_sign=False,
                    encipher_only=False,
                    decipher_only=False,
                ),
                critical=True,
            )
            .add_extension(
                x509.ExtendedKeyUsage([ExtendedKeyUsageOID.CLIENT_AUTH]),
                critical=False,
            )
            .sign(ca_key, hashes.SHA256())
        )

        cert_pem = certificate.public_bytes(serialization.Encoding.PEM).decode("utf-8")
        key_pem = client_key.private_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PrivateFormat.PKCS8,
            encryption_algorithm=serialization.NoEncryption(),
        ).decode("utf-8")

        fingerprint = calculate_fingerprint(certificate)
        serial_hex = format(serial, "032x")
        expires_at = int(not_after.timestamp())

        self._certificates[fingerprint] = {
            "serial": serial_hex,
            "ca_id": "default",
            "common_name": common_name,
            "fingerprint_sha1": fingerprint,
            "certificate_pem": cert_pem,
            "created_at": int(now.timestamp()),
            "expires_at": expires_at,
            "status": "valid",
            "revoked_at": None,
        }

        return {
            "serial": serial_hex,
            "common_name": common_name,
            "fingerprint": fingerprint,
            "expires_at": expires_at,
            "certificate_pem": cert_pem,
            "private_key_pem": key_pem,
            "ca_certificate_pem": store["certificate_pem"],
        }

    def revoke_certificate(self, fingerprint: str) -> bool:
        """Revoke a certificate by fingerprint (case-insensitive).

        Args:
            fingerprint: SHA1 fingerprint (40 hex chars)

        Returns:
            True if revoked, False if already revoked

        Raises:
            CertificateNotFoundError: If no certificate has this fingerprint
        """
        fingerprint = fingerprint.lower()
        cert = self._certificates.get(fingerprint)
        if cert is None:
            raise CertificateNotFoundError(f"Certificate not found: {fingerprint}")

        if cert["status"] == "revoked":
            return False

        cert["status"] = "revoked"
        cert["revoked_at"] = int(time.time())
        return True

    def is_valid(self, fingerprint: str) -> bool:
        """Check if fingerprint is valid (unknown, or known, unrevoked and unexpired).

        Args:
            fingerprint: SHA1 fingerprint

        Returns:
            True if valid, False otherwise
        """
        cert = self._certificates.get(fingerprint.lower())
        if cert is None:
            return True  # Unknown fingerprints are allowed (external certs)

        if cert["status"] != "valid":
            return False

        expires_at: int = cert["expires_at"]
        return expires_at >= int(time.time())

    def get_certificate(self, fingerprint: str) -> dict | None:
        """Get certificate info by fingerprint (case-insensitive).

        Args:
            fingerprint: SHA1 fingerprint

        Returns:
            Certificate info dict or None if not found
        """
        cert = self._certificates.get(fingerprint.lower())
        if cert is None:
            return None

        return {
            "serial": cert["serial"],
            "common_name": cert["common_name"],
            "fingerprint": cert["fingerprint_sha1"],
            "expires_at": cert["expires_at"],
            "status": cert["status"],
            "created_at": cert["created_at"],
            "revoked_at": cert["revoked_at"],
        }

    def list_certificates(self) -> list[dict]:
        """List all issued certificates in issue order.

        Returns:
            List of certificate info dicts
        """
        return [
            {
                "serial": c["serial"],
                "common_name": c["common_name"],
                "fingerprint": c["fingerprint_sha1"],
                "expires_at": c["expires_at"],
                "status": c["status"],
                "created_at": c["created_at"],
            }
            for c in self._certificates.values()
        ]
|
|
|
|
|
|
def reset_pki() -> None:
    """Reset PKI state (for testing).

    Intentionally does nothing: routes call the database-backed functions
    directly, so no module-level PKI state is cached. The function is kept so
    existing tests and callers continue to work.
    """
    return None
|
|
|
|
|
|
def is_certificate_valid(fingerprint: str) -> bool:
    """Decide whether a client-certificate fingerprint may authenticate.

    This is the main integration point for routes.py. Revocation status is read
    from the database on every call, so it is never stale.

    Outcomes:
        * PKI_ENABLED is falsy -> True (no checks apply)
        * fingerprint not in issued_certificates -> True (external cert)
        * status other than 'valid' -> False
        * otherwise -> True only while expires_at >= now

    Args:
        fingerprint: SHA1 fingerprint (case-insensitive)

    Returns:
        True if valid or unknown, False if revoked/expired
    """
    from flask import current_app

    if not current_app.config.get("PKI_ENABLED"):
        return True

    from app.database import get_db

    needle = fingerprint.lower()
    connection = get_db()

    record = connection.execute(
        "SELECT status, expires_at FROM issued_certificates WHERE fingerprint_sha1 = ?",
        (needle,),
    ).fetchone()

    if record is None:
        # Not issued by us (external cert) - allow
        return True

    return record["status"] == "valid" and record["expires_at"] >= int(time.time())
|
|
|
|
|
|
# ─────────────────────────────────────────────────────────────────────────────
|
|
# Flask Route Helper Functions
|
|
# ─────────────────────────────────────────────────────────────────────────────
|
|
|
|
|
|
def get_ca_info(skip_enabled_check: bool = False) -> dict | None:
    """Return public information about the stored "default" CA.

    Args:
        skip_enabled_check: If True, ignore PKI_ENABLED (used by registration)

    Returns:
        Dict with id, common_name, certificate_pem, fingerprint_sha1,
        created_at, expires_at and key_algorithm; or None when PKI is disabled
        (and not skipped) or no CA row exists
    """
    from flask import current_app

    from app.database import get_db

    # De Morgan of "not skip and not enabled"; the config is only consulted
    # when the check is not skipped.
    if not (skip_enabled_check or current_app.config.get("PKI_ENABLED")):
        return None

    ca_row = get_db().execute(
        """SELECT id, common_name, certificate_pem, created_at, expires_at, key_algorithm
        FROM certificate_authority WHERE id = 'default'"""
    ).fetchone()

    if ca_row is None:
        return None

    pem_text = ca_row["certificate_pem"]
    # The fingerprint is not stored on the row, so derive it from the PEM
    ca_fingerprint = calculate_fingerprint(x509.load_pem_x509_certificate(pem_text.encode("utf-8")))

    return {
        "id": ca_row["id"],
        "common_name": ca_row["common_name"],
        "certificate_pem": pem_text,
        "fingerprint_sha1": ca_fingerprint,
        "created_at": ca_row["created_at"],
        "expires_at": ca_row["expires_at"],
        "key_algorithm": ca_row["key_algorithm"],
    }
|
|
|
|
|
|
def _generate_unique_serial(db: Any) -> int:
    """Pick a random certificate serial that is not yet in issued_certificates.

    CRYPTO-001: each candidate is checked against the database before it is
    returned.

    Args:
        db: Database connection

    Returns:
        Unique serial number as integer

    Raises:
        PKIError: If no free serial is found within _SERIAL_MAX_RETRIES attempts
    """
    _require_crypto()

    attempts = 0
    while attempts < _SERIAL_MAX_RETRIES:
        attempts += 1
        candidate = x509.random_serial_number()
        # Serials are stored as zero-padded lowercase hex
        clash = db.execute(
            "SELECT 1 FROM issued_certificates WHERE serial = ?", (format(candidate, "032x"),)
        ).fetchone()
        if clash is None:
            return candidate

    raise PKIError("Failed to generate unique serial after max retries")
|
|
|
|
|
|
def generate_ca(
    common_name: str,
    password: str,
    days: int = 3650,
    owner: str | None = None,
) -> dict:
    """Create the singleton "default" CA (EC secp384r1) and persist it.

    Args:
        common_name: CA common name
        password: CA password for encrypting private key
        days: Validity period in days
        owner: Optional owner fingerprint

    Returns:
        Dict with CA info including fingerprint_sha1

    Raises:
        CAExistsError: If CA already exists
    """
    _require_crypto()
    from app.database import get_db

    db = get_db()

    if db.execute("SELECT id FROM certificate_authority WHERE id = 'default'").fetchone():
        raise CAExistsError("CA already exists")

    signing_key = ec.generate_private_key(ec.SECP384R1())

    issued_at = datetime.now(UTC)
    valid_until = issued_at + timedelta(days=days)
    ca_name = x509.Name(
        [
            x509.NameAttribute(NameOID.COMMON_NAME, common_name),
            x509.NameAttribute(NameOID.ORGANIZATION_NAME, "FlaskPaste PKI"),
        ]
    )
    # path_length=0: the CA may sign end-entity certificates only
    constraints = x509.BasicConstraints(ca=True, path_length=0)
    usage = x509.KeyUsage(
        digital_signature=True,
        key_cert_sign=True,
        crl_sign=True,
        key_encipherment=False,
        content_commitment=False,
        data_encipherment=False,
        key_agreement=False,
        encipher_only=False,
        decipher_only=False,
    )

    builder = x509.CertificateBuilder()
    builder = builder.subject_name(ca_name).issuer_name(ca_name)  # self-signed
    builder = builder.public_key(signing_key.public_key())
    builder = builder.serial_number(x509.random_serial_number())
    builder = builder.not_valid_before(issued_at).not_valid_after(valid_until)
    builder = builder.add_extension(constraints, critical=True)
    builder = builder.add_extension(usage, critical=True)
    ca_certificate = builder.sign(signing_key, hashes.SHA256())

    sealed_key, key_salt = encrypt_private_key(signing_key, password)
    ca_fingerprint = calculate_fingerprint(ca_certificate)

    pem_text = ca_certificate.public_bytes(serialization.Encoding.PEM).decode("utf-8")
    created_ts = int(issued_at.timestamp())
    expires_ts = int(valid_until.timestamp())

    db.execute(
        """INSERT INTO certificate_authority
        (id, common_name, certificate_pem, private_key_encrypted,
        key_salt, created_at, expires_at, key_algorithm, owner)
        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""",
        (
            "default",
            common_name,
            pem_text,
            sealed_key,
            key_salt,
            created_ts,
            expires_ts,
            "ec:secp384r1",
            owner,
        ),
    )
    db.commit()

    # No-op kept for compatibility (see reset_pki)
    reset_pki()

    return {
        "id": "default",
        "common_name": common_name,
        "fingerprint_sha1": ca_fingerprint,
        "certificate_pem": pem_text,
        "created_at": created_ts,
        "expires_at": expires_ts,
    }
|
|
|
|
|
|
def issue_certificate(
    common_name: str,
    password: str,
    days: int = 365,
    issued_to: str | None = None,
    is_admin: bool | None = None,
) -> dict:
    """Issue a client certificate signed by the CA and record it in the database.

    Args:
        common_name: Client common name
        password: CA password for signing
        days: Validity period in days
        issued_to: Optional fingerprint of issuing user
        is_admin: Admin flag (None = admin only if no certificate currently
            has status 'valid', i.e. the first registered user)

    Returns:
        Dict with serial, common_name, fingerprint_sha1, certificate_pem,
        private_key_pem, created_at, expires_at and is_admin

    Raises:
        CANotFoundError: If no CA exists
        InvalidPasswordError: If ``password`` cannot decrypt the CA key
        PKIError: If cryptography is missing or no unique serial can be found
    """
    _require_crypto()
    from app.database import get_db

    db = get_db()

    # Auto-detect: first registered user becomes admin
    if is_admin is None:
        count = db.execute(
            "SELECT COUNT(*) FROM issued_certificates WHERE status = 'valid'"
        ).fetchone()[0]
        is_admin = count == 0

    # Load CA
    ca_row = db.execute(
        """SELECT certificate_pem, private_key_encrypted, key_salt
        FROM certificate_authority WHERE id = 'default'"""
    ).fetchone()

    if ca_row is None:
        raise CANotFoundError("No CA configured")

    ca_key = decrypt_private_key(
        ca_row["private_key_encrypted"],
        ca_row["key_salt"],
        password,
    )
    ca_cert = x509.load_pem_x509_certificate(ca_row["certificate_pem"].encode("utf-8"))

    # Generate client key
    client_key = ec.generate_private_key(ec.SECP384R1())

    now = datetime.now(UTC)
    not_after = now + timedelta(days=days)
    # CRYPTO-001: Use collision-safe serial generation
    serial = _generate_unique_serial(db)

    certificate = (
        x509.CertificateBuilder()
        .subject_name(x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, common_name)]))
        .issuer_name(ca_cert.subject)
        .public_key(client_key.public_key())
        .serial_number(serial)
        .not_valid_before(now)
        .not_valid_after(not_after)
        .add_extension(
            x509.BasicConstraints(ca=False, path_length=None),
            critical=True,
        )
        .add_extension(
            # RFC 5480 §3: keyEncipherment/dataEncipherment MUST NOT be asserted
            # for EC keys; ECDSA client auth only needs digitalSignature.
            x509.KeyUsage(
                digital_signature=True,
                key_encipherment=False,
                content_commitment=False,
                data_encipherment=False,
                key_agreement=False,
                key_cert_sign=False,
                crl_sign=False,
                encipher_only=False,
                decipher_only=False,
            ),
            critical=True,
        )
        .add_extension(
            x509.ExtendedKeyUsage([ExtendedKeyUsageOID.CLIENT_AUTH]),
            critical=False,
        )
        .sign(ca_key, hashes.SHA256())
    )

    cert_pem = certificate.public_bytes(serialization.Encoding.PEM).decode("utf-8")
    key_pem = client_key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    ).decode("utf-8")

    fingerprint = calculate_fingerprint(certificate)
    serial_hex = format(serial, "032x")
    created_at = int(now.timestamp())
    expires_at = int(not_after.timestamp())

    db.execute(
        """INSERT INTO issued_certificates
        (serial, ca_id, common_name, fingerprint_sha1, certificate_pem,
        created_at, expires_at, issued_to, status, revoked_at, is_admin)
        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
        (
            serial_hex,
            "default",
            common_name,
            fingerprint,
            cert_pem,
            created_at,
            expires_at,
            issued_to,
            "valid",
            None,
            1 if is_admin else 0,
        ),
    )
    db.commit()

    return {
        "serial": serial_hex,
        "common_name": common_name,
        "fingerprint_sha1": fingerprint,
        "certificate_pem": cert_pem,
        "private_key_pem": key_pem,
        "created_at": created_at,
        "expires_at": expires_at,
        "is_admin": is_admin,
    }
|
|
|
|
|
|
def is_admin_certificate(fingerprint: str) -> bool:
    """Check if a certificate fingerprint belongs to an admin.

    The fingerprint is lowercased before lookup, matching is_certificate_valid:
    fingerprints are stored as lowercase hex (from calculate_fingerprint).
    Only rows with status 'valid' count; expiry is not checked here.

    Args:
        fingerprint: SHA1 fingerprint of the certificate (case-insensitive)

    Returns:
        True if the certificate holder is an admin
    """
    from app.database import get_db

    db = get_db()
    row = db.execute(
        """SELECT is_admin FROM issued_certificates
        WHERE fingerprint_sha1 = ? AND status = 'valid'""",
        (fingerprint.lower(),),
    ).fetchone()

    return bool(row and row["is_admin"])
|
|
|
|
|
|
def is_trusted_certificate(fingerprint: str) -> bool:
    """Check if a certificate is trusted (registered in the PKI and not revoked).

    Trusted certificates are those issued by our PKI system with status
    'valid'. External certificates (accepted for auth but not issued by us) are
    not trusted. The fingerprint is lowercased before lookup, matching
    is_certificate_valid, because stored fingerprints are lowercase hex.
    Expiry is not checked here.

    Args:
        fingerprint: SHA1 fingerprint of the certificate (case-insensitive)

    Returns:
        True if the certificate is registered with status 'valid' in our PKI
    """
    from app.database import get_db

    db = get_db()
    row = db.execute(
        """SELECT status FROM issued_certificates
        WHERE fingerprint_sha1 = ? AND status = 'valid'""",
        (fingerprint.lower(),),
    ).fetchone()

    return row is not None
|
|
|
|
|
|
def revoke_certificate(serial: str) -> bool:
    """Revoke a certificate by serial number.

    Args:
        serial: Certificate serial number (hex). Matched case-insensitively,
            because issue_certificate stores serials as lowercase hex.

    Returns:
        True if the certificate was revoked by this call, False if it was
        already revoked

    Raises:
        CertificateNotFoundError: If certificate not found
    """
    from app.database import get_db

    db = get_db()
    serial_key = serial.lower()

    row = db.execute(
        "SELECT status FROM issued_certificates WHERE serial = ?", (serial_key,)
    ).fetchone()

    if row is None:
        raise CertificateNotFoundError(f"Certificate not found: {serial}")

    if row["status"] == "revoked":
        return False  # Already revoked

    db.execute(
        "UPDATE issued_certificates SET status = 'revoked', revoked_at = ? WHERE serial = ?",
        (int(time.time()), serial_key),
    )
    db.commit()

    return True
|