Files
flaskpaste/app/pki.py
Username ca9342e92d fix: add comprehensive type annotations for mypy
- database.py: add type hints for Path, Flask, Any, BaseException
- pki.py: add assertions to narrow Optional types after has_ca() checks
- routes.py: annotate config values to avoid Any return types
- api/__init__.py: use float for cleanup timestamps (time.time())
- __init__.py: remove unused return from setup_rate_limiting
2025-12-22 19:11:11 +01:00

1131 lines
33 KiB
Python

"""Minimal PKI module for certificate authority and client certificate management.
This module provides both standalone PKI functionality and Flask-integrated
database-backed functions. Core cryptographic functions have no Flask dependencies.
Usage (standalone):
from app.pki import PKI
pki = PKI(password="your-ca-password")
ca_info = pki.generate_ca("My CA")
cert_info = pki.issue_certificate("client-name")
Usage (Flask routes):
from app.pki import generate_ca, issue_certificate, is_certificate_valid
ca_info = generate_ca("My CA", password) # Saves to database
if is_certificate_valid(fingerprint):
# Certificate is valid and not revoked
"""
import hashlib
import secrets
import time
from datetime import UTC, datetime, timedelta
from typing import Any
# Cryptography imports (required dependency)
try:
from cryptography import x509
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
from cryptography.hazmat.primitives.serialization import pkcs12
from cryptography.x509.oid import ExtendedKeyUsageOID, NameOID
HAS_CRYPTO = True
except ImportError:
HAS_CRYPTO = False
# Parameters protecting the stored CA private key (PBKDF2 + AES-GCM).
_KDF_ITERATIONS = 600_000  # PBKDF2 rounds; OWASP 2023 recommendation
_SALT_LENGTH = 32  # bytes of random salt generated per encryption
_KEY_LENGTH = 32  # derived key size in bytes -> AES-256
class PKIError(Exception):
    """Root of the exception hierarchy raised by this PKI module."""
class CANotFoundError(PKIError):
    """Raised when an operation needs a CA but none is configured."""
class CAExistsError(PKIError):
    """Raised when creating or importing a CA while one already exists."""
class CertificateNotFoundError(PKIError):
    """Raised when no issued certificate matches the given fingerprint or serial."""
class InvalidPasswordError(PKIError):
    """Raised when the CA password fails to decrypt the stored private key."""
def _require_crypto() -> None:
    """Guard for functions that need the ``cryptography`` package.

    Raises:
        PKIError: If the ``cryptography`` imports at module load failed.
    """
    if HAS_CRYPTO:
        return
    raise PKIError("PKI requires 'cryptography' package: pip install cryptography")
def derive_key(password: str, salt: bytes) -> bytes:
    """Stretch a password into a symmetric key using PBKDF2-HMAC-SHA256.

    Args:
        password: Password text; UTF-8 encoded before derivation.
        salt: Random salt bytes (stored alongside the ciphertext).

    Returns:
        ``_KEY_LENGTH`` (32) bytes, used as an AES-256 key.
    """
    _require_crypto()
    pbkdf2 = PBKDF2HMAC(
        algorithm=hashes.SHA256(),
        length=_KEY_LENGTH,
        salt=salt,
        iterations=_KDF_ITERATIONS,
    )
    secret = password.encode("utf-8")
    return pbkdf2.derive(secret)
def encrypt_private_key(private_key: Any, password: str) -> tuple[bytes, bytes]:
    """Seal a private key under a password-derived AES-256-GCM key.

    The key is serialized as unencrypted PKCS#8 PEM, then encrypted with a key
    derived from ``password`` and a freshly generated salt.

    Args:
        private_key: Private key object exposing ``private_bytes`` (EC or RSA).
        password: Password the encryption key is derived from.

    Returns:
        ``(blob, salt)`` where ``blob`` is the 12-byte random nonce followed
        by the ``AESGCM.encrypt`` output.
    """
    _require_crypto()
    plaintext_pem = private_key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )
    salt = secrets.token_bytes(_SALT_LENGTH)
    cipher = AESGCM(derive_key(password, salt))
    nonce = secrets.token_bytes(12)
    # The nonce is stored in front of the ciphertext so decryption can recover it
    blob = nonce + cipher.encrypt(nonce, plaintext_pem, None)
    return blob, salt
def decrypt_private_key(encrypted: bytes, salt: bytes, password: str) -> Any:
    """Recover a private key sealed by ``encrypt_private_key``.

    Args:
        encrypted: 12-byte nonce followed by the AES-GCM ciphertext.
        salt: Salt that was used when the key was encrypted.
        password: Password the decryption key is derived from.

    Returns:
        The deserialized private key object.

    Raises:
        PKIError: If ``encrypted`` is too short to hold the nonce.
        InvalidPasswordError: If GCM authentication fails (wrong password,
            wrong salt, or a corrupted/tampered blob).
    """
    _require_crypto()
    # Local import keeps the module importable when cryptography is missing
    from cryptography.exceptions import InvalidTag

    if len(encrypted) < 12:
        raise PKIError("Invalid encrypted key format")
    encryption_key = derive_key(password, salt)
    nonce, ciphertext = encrypted[:12], encrypted[12:]
    aesgcm = AESGCM(encryption_key)
    try:
        key_pem = aesgcm.decrypt(nonce, ciphertext, None)
    except InvalidTag:
        # Only an authentication failure means "bad password"; other errors
        # (e.g. wrong argument types) now propagate instead of being masked.
        raise InvalidPasswordError("Invalid CA password") from None
    return serialization.load_pem_private_key(key_pem, password=None)
def calculate_fingerprint(certificate: Any) -> str:
    """Return the SHA-1 fingerprint of a certificate's DER encoding.

    SHA-1 serves only as the conventional X.509 fingerprint identifier here,
    not as a security primitive, so it is flagged ``usedforsecurity=False``.

    Args:
        certificate: X.509 certificate object.

    Returns:
        40-character lowercase hex string.
    """
    _require_crypto()
    der_bytes = certificate.public_bytes(serialization.Encoding.DER)
    digest = hashlib.sha1(der_bytes, usedforsecurity=False)
    return digest.hexdigest()
def create_pkcs12(
    private_key: Any,
    certificate: Any,
    ca_certificate: Any,
    friendly_name: str,
    password: bytes | None = None,
) -> bytes:
    """Create a PKCS#12 bundle with the client key, its certificate and the CA cert.

    Args:
        private_key: Client private key object.
        certificate: Client certificate object.
        ca_certificate: CA certificate object, included as the only chain entry.
        friendly_name: Friendly name stored in the bundle (UTF-8 encoded).
        password: Optional bundle password; ``None`` or empty bytes produce an
            unencrypted bundle.

    Returns:
        PKCS#12 bytes (DER encoded).
    """
    _require_crypto()
    # Any non-empty password -> the library's strongest supported encryption
    encryption = (
        serialization.BestAvailableEncryption(password)
        if password
        else serialization.NoEncryption()
    )
    return pkcs12.serialize_key_and_certificates(
        name=friendly_name.encode("utf-8"),
        key=private_key,
        cert=certificate,
        cas=[ca_certificate],
        encryption_algorithm=encryption,
    )
class PKI:
    """Standalone, in-memory PKI manager for CA and certificate operations.

    This class has no Flask dependencies: the CA and issued-certificate
    records live on the instance. Flask routes should use the database-backed
    module-level functions (``generate_ca``, ``issue_certificate``,
    ``is_certificate_valid``, ...) instead.

    Args:
        password: CA password used to encrypt/decrypt the CA private key.
        ca_days: CA certificate validity in days (default: 3650).
        cert_days: Client certificate validity in days (default: 365).

    Raises:
        PKIError: If the ``cryptography`` package is unavailable.
    """

    def __init__(
        self,
        password: str,
        ca_days: int = 3650,
        cert_days: int = 365,
    ):
        _require_crypto()
        self.password = password
        self.ca_days = ca_days
        self.cert_days = cert_days
        # In-memory storage (override for database backing). Keys starting
        # with "_" cache the live key/certificate objects used for signing.
        self._ca_store: dict[str, Any] | None = None
        # Issued-certificate records keyed by lowercase SHA-1 fingerprint
        self._certificates: dict[str, dict[str, Any]] = {}

    def has_ca(self) -> bool:
        """Return True once a CA has been generated or loaded."""
        return self._ca_store is not None

    @staticmethod
    def _ec_curve(algorithm: str, curve: str) -> Any:
        """Validate key parameters and return a fresh EC curve instance.

        Raises:
            PKIError: If ``algorithm`` is not "ec" or ``curve`` is unsupported.
        """
        if algorithm != "ec":
            raise PKIError(f"Unsupported algorithm: {algorithm}")
        curves = {
            "secp256r1": ec.SECP256R1,
            "secp384r1": ec.SECP384R1,
            "secp521r1": ec.SECP521R1,
        }
        if curve not in curves:
            raise PKIError(f"Unsupported curve: {curve}")
        return curves[curve]()

    def generate_ca(
        self,
        common_name: str,
        algorithm: str = "ec",
        curve: str = "secp384r1",
    ) -> dict:
        """Generate a new self-signed Certificate Authority.

        Args:
            common_name: CA common name (e.g., "FlaskPaste CA").
            algorithm: Key algorithm; only "ec" is supported.
            curve: EC curve name (secp256r1, secp384r1, secp521r1).

        Returns:
            Dict with CA info: id, common_name, fingerprint, expires_at,
            certificate_pem.

        Raises:
            CAExistsError: If a CA already exists.
            PKIError: If ``algorithm`` or ``curve`` is unsupported.
        """
        if self.has_ca():
            raise CAExistsError("CA already exists")
        private_key = ec.generate_private_key(self._ec_curve(algorithm, curve))

        now = datetime.now(UTC)
        not_after = now + timedelta(days=self.ca_days)
        subject = issuer = x509.Name(
            [
                x509.NameAttribute(NameOID.COMMON_NAME, common_name),
                x509.NameAttribute(NameOID.ORGANIZATION_NAME, "FlaskPaste PKI"),
            ]
        )
        cert_builder = (
            x509.CertificateBuilder()
            .subject_name(subject)
            .issuer_name(issuer)
            .public_key(private_key.public_key())
            .serial_number(x509.random_serial_number())
            .not_valid_before(now)
            .not_valid_after(not_after)
            # path_length=0: this CA may sign end-entity certs only
            .add_extension(
                x509.BasicConstraints(ca=True, path_length=0),
                critical=True,
            )
            .add_extension(
                x509.KeyUsage(
                    digital_signature=True,
                    key_cert_sign=True,
                    crl_sign=True,
                    key_encipherment=False,
                    content_commitment=False,
                    data_encipherment=False,
                    key_agreement=False,
                    encipher_only=False,
                    decipher_only=False,
                ),
                critical=True,
            )
        )
        certificate = cert_builder.sign(private_key, hashes.SHA256())

        encrypted_key, salt = encrypt_private_key(private_key, self.password)
        fingerprint = calculate_fingerprint(certificate)
        expires_at = int(not_after.timestamp())
        cert_pem = certificate.public_bytes(serialization.Encoding.PEM).decode("utf-8")
        self._ca_store = {
            "id": "default",
            "common_name": common_name,
            "certificate_pem": cert_pem,
            "private_key_encrypted": encrypted_key,
            "key_salt": salt,
            "created_at": int(now.timestamp()),
            "expires_at": expires_at,
            "key_algorithm": f"ec:{curve}",
            "fingerprint": fingerprint,
            "_private_key": private_key,  # Cached for signing
            "_certificate": certificate,
        }
        return {
            "id": "default",
            "common_name": common_name,
            "fingerprint": fingerprint,
            "expires_at": expires_at,
            "certificate_pem": cert_pem,
        }

    def load_ca(self, cert_pem: str, key_pem: str, key_password: str | None = None) -> dict:
        """Import an existing CA from PEM-encoded certificate and private key.

        The private key is re-encrypted with this instance's password.

        Args:
            cert_pem: CA certificate PEM string.
            key_pem: CA private key PEM string.
            key_password: Password for an encrypted private key (if any).

        Returns:
            Dict with CA info: id, common_name, fingerprint, expires_at,
            certificate_pem.

        Raises:
            CAExistsError: If a CA already exists.
            PKIError: If the private key does not match the certificate, or the
                certificate has no common name.
        """
        if self.has_ca():
            raise CAExistsError("CA already exists")

        certificate = x509.load_pem_x509_certificate(cert_pem.encode("utf-8"))
        pwd = key_password.encode("utf-8") if key_password else None
        private_key = serialization.load_pem_private_key(key_pem.encode("utf-8"), password=pwd)

        # A mismatched key would sign certificates that never verify against this CA
        spki_format = serialization.PublicFormat.SubjectPublicKeyInfo
        key_spki = private_key.public_key().public_bytes(
            encoding=serialization.Encoding.DER, format=spki_format
        )
        cert_spki = certificate.public_key().public_bytes(
            encoding=serialization.Encoding.DER, format=spki_format
        )
        if key_spki != cert_spki:
            raise PKIError("CA private key does not match CA certificate")

        cn_attributes = certificate.subject.get_attributes_for_oid(NameOID.COMMON_NAME)
        if not cn_attributes:
            raise PKIError("CA certificate has no common name")
        common_name = cn_attributes[0].value
        fingerprint = calculate_fingerprint(certificate)

        # Re-encrypt with our password
        encrypted_key, salt = encrypt_private_key(private_key, self.password)
        now = datetime.now(UTC)
        expires_at = int(certificate.not_valid_after_utc.timestamp())
        self._ca_store = {
            "id": "default",
            "common_name": common_name,
            "certificate_pem": cert_pem,
            "private_key_encrypted": encrypted_key,
            "key_salt": salt,
            "created_at": int(now.timestamp()),
            "expires_at": expires_at,
            "key_algorithm": "imported",
            "fingerprint": fingerprint,
            "_private_key": private_key,
            "_certificate": certificate,
        }
        return {
            "id": "default",
            "common_name": common_name,
            "fingerprint": fingerprint,
            "expires_at": expires_at,
            "certificate_pem": cert_pem,
        }

    def get_ca(self) -> dict:
        """Return public CA information (no key material).

        Raises:
            CANotFoundError: If no CA exists.
        """
        if not self.has_ca():
            raise CANotFoundError("No CA configured")
        assert self._ca_store is not None  # narrowing for mypy
        return {
            "id": self._ca_store["id"],
            "common_name": self._ca_store["common_name"],
            "fingerprint": self._ca_store["fingerprint"],
            "expires_at": self._ca_store["expires_at"],
            "certificate_pem": self._ca_store["certificate_pem"],
        }

    def get_ca_certificate_pem(self) -> str:
        """Return the CA certificate in PEM format.

        Raises:
            CANotFoundError: If no CA exists.
        """
        if not self.has_ca():
            raise CANotFoundError("No CA configured")
        assert self._ca_store is not None  # narrowing for mypy
        cert_pem: str = self._ca_store["certificate_pem"]
        return cert_pem

    def _get_signing_key(self) -> tuple[Any, Any]:
        """Return ``(private_key, certificate)`` for signing, caching them.

        Decrypts the stored key with ``self.password`` when no cached key is present.

        Raises:
            CANotFoundError: If no CA exists.
            InvalidPasswordError: If the stored key cannot be decrypted.
        """
        if not self.has_ca():
            raise CANotFoundError("No CA configured")
        assert self._ca_store is not None  # narrowing for mypy

        if "_private_key" in self._ca_store:
            return self._ca_store["_private_key"], self._ca_store["_certificate"]

        private_key = decrypt_private_key(
            self._ca_store["private_key_encrypted"],
            self._ca_store["key_salt"],
            self.password,
        )
        certificate = x509.load_pem_x509_certificate(
            self._ca_store["certificate_pem"].encode("utf-8")
        )
        self._ca_store["_private_key"] = private_key
        self._ca_store["_certificate"] = certificate
        return private_key, certificate

    def issue_certificate(
        self,
        common_name: str,
        days: int | None = None,
        algorithm: str = "ec",
        curve: str = "secp384r1",
    ) -> dict:
        """Issue a new client-auth certificate signed by the CA.

        Args:
            common_name: Client common name.
            days: Validity period (default: ``self.cert_days``).
            algorithm: Key algorithm; only "ec" is supported.
            curve: EC curve name (secp256r1, secp384r1, secp521r1).

        Returns:
            Dict with serial, common_name, fingerprint, expires_at,
            certificate_pem, private_key_pem and ca_certificate_pem.

        Raises:
            CANotFoundError: If no CA exists.
            PKIError: If ``algorithm`` or ``curve`` is unsupported.
        """
        if days is None:
            days = self.cert_days
        ca_key, ca_cert = self._get_signing_key()
        assert self._ca_store is not None  # narrowing for mypy (validated in _get_signing_key)

        client_key = ec.generate_private_key(self._ec_curve(algorithm, curve))

        now = datetime.now(UTC)
        not_after = now + timedelta(days=days)
        serial = x509.random_serial_number()
        subject = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, common_name)])
        cert_builder = (
            x509.CertificateBuilder()
            .subject_name(subject)
            .issuer_name(ca_cert.subject)
            .public_key(client_key.public_key())
            .serial_number(serial)
            .not_valid_before(now)
            .not_valid_after(not_after)
            .add_extension(
                x509.BasicConstraints(ca=False, path_length=None),
                critical=True,
            )
            .add_extension(
                x509.KeyUsage(
                    digital_signature=True,
                    key_encipherment=True,
                    content_commitment=False,
                    data_encipherment=False,
                    key_agreement=False,
                    key_cert_sign=False,
                    crl_sign=False,
                    encipher_only=False,
                    decipher_only=False,
                ),
                critical=True,
            )
            .add_extension(
                x509.ExtendedKeyUsage([ExtendedKeyUsageOID.CLIENT_AUTH]),
                critical=False,
            )
        )
        certificate = cert_builder.sign(ca_key, hashes.SHA256())

        cert_pem = certificate.public_bytes(serialization.Encoding.PEM).decode("utf-8")
        key_pem = client_key.private_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PrivateFormat.PKCS8,
            encryption_algorithm=serialization.NoEncryption(),
        ).decode("utf-8")
        fingerprint = calculate_fingerprint(certificate)
        serial_hex = format(serial, "032x")
        expires_at = int(not_after.timestamp())

        self._certificates[fingerprint] = {
            "serial": serial_hex,
            "ca_id": "default",
            "common_name": common_name,
            "fingerprint_sha1": fingerprint,
            "certificate_pem": cert_pem,
            "created_at": int(now.timestamp()),
            "expires_at": expires_at,
            "status": "valid",
            "revoked_at": None,
        }
        return {
            "serial": serial_hex,
            "common_name": common_name,
            "fingerprint": fingerprint,
            "expires_at": expires_at,
            "certificate_pem": cert_pem,
            "private_key_pem": key_pem,
            "ca_certificate_pem": self._ca_store["certificate_pem"],
        }

    def revoke_certificate(self, fingerprint: str) -> bool:
        """Revoke a certificate by fingerprint (case-insensitive).

        Args:
            fingerprint: SHA1 fingerprint (40 hex chars).

        Returns:
            True if newly revoked, False if it was already revoked.

        Raises:
            CertificateNotFoundError: If no certificate has this fingerprint.
        """
        fingerprint = fingerprint.lower()
        if fingerprint not in self._certificates:
            raise CertificateNotFoundError(f"Certificate not found: {fingerprint}")
        cert = self._certificates[fingerprint]
        if cert["status"] == "revoked":
            return False
        cert["status"] = "revoked"
        cert["revoked_at"] = int(time.time())
        return True

    def is_valid(self, fingerprint: str) -> bool:
        """Return whether a fingerprint may be used for authentication.

        Unknown fingerprints are allowed (external certs). Known ones must have
        status "valid" and must not be past ``expires_at``.
        """
        fingerprint = fingerprint.lower()
        if fingerprint not in self._certificates:
            return True  # Unknown fingerprints are allowed (external certs)
        cert = self._certificates[fingerprint]
        if cert["status"] != "valid":
            return False
        expires_at: int = cert["expires_at"]
        return expires_at >= int(time.time())

    def get_certificate(self, fingerprint: str) -> dict | None:
        """Return certificate info for a fingerprint, or None if unknown."""
        cert = self._certificates.get(fingerprint.lower())
        if cert is None:
            return None
        return {
            "serial": cert["serial"],
            "common_name": cert["common_name"],
            "fingerprint": cert["fingerprint_sha1"],
            "expires_at": cert["expires_at"],
            "status": cert["status"],
            "created_at": cert["created_at"],
            "revoked_at": cert["revoked_at"],
        }

    def list_certificates(self) -> list[dict]:
        """Return summary info for every certificate issued by this instance."""
        return [
            {
                "serial": c["serial"],
                "common_name": c["common_name"],
                "fingerprint": c["fingerprint_sha1"],
                "expires_at": c["expires_at"],
                "status": c["status"],
                "created_at": c["created_at"],
            }
            for c in self._certificates.values()
        ]
def reset_pki() -> None:
    """Compatibility hook for tests; intentionally does nothing.

    Routes call the database-backed helpers in this module directly, so this
    function has nothing to reset. It is kept so existing callers keep working.
    """
def is_certificate_valid(fingerprint: str) -> bool:
    """Decide whether a client-certificate fingerprint may authenticate.

    This is the main integration point for routes.py. The database is read on
    every call, so revocations take effect immediately. When PKI is disabled,
    every fingerprint is accepted; fingerprints we never issued (external
    certificates) are accepted as well.

    Args:
        fingerprint: SHA1 fingerprint (matched case-insensitively).

    Returns:
        False only for a known certificate that is revoked or expired;
        True otherwise.
    """
    from flask import current_app

    if not current_app.config.get("PKI_ENABLED"):
        return True

    from app.database import get_db

    normalized = fingerprint.lower()
    db = get_db()
    record = db.execute(
        "SELECT status, expires_at FROM issued_certificates WHERE fingerprint_sha1 = ?",
        (normalized,),
    ).fetchone()

    # Not issued by us (external cert) -> allowed
    if record is None:
        return True
    return record["status"] == "valid" and record["expires_at"] >= int(time.time())
# ─────────────────────────────────────────────────────────────────────────────
# Flask Route Helper Functions
# ─────────────────────────────────────────────────────────────────────────────
def get_ca_info(skip_enabled_check: bool = False) -> dict | None:
    """Fetch the stored CA's public details for the status endpoint.

    Args:
        skip_enabled_check: When True, ignore the PKI_ENABLED setting (used
            during registration).

    Returns:
        Dict with id, common_name, certificate_pem, fingerprint_sha1,
        created_at, expires_at and key_algorithm; or None when PKI is
        disabled (and not skipped) or no CA row exists.
    """
    from flask import current_app
    from app.database import get_db

    # Short-circuits exactly like the original: config is only read when not skipping
    if not (skip_enabled_check or current_app.config.get("PKI_ENABLED")):
        return None

    ca_row = get_db().execute(
        """SELECT id, common_name, certificate_pem, created_at, expires_at, key_algorithm
        FROM certificate_authority WHERE id = 'default'"""
    ).fetchone()
    if ca_row is None:
        return None

    pem_text = ca_row["certificate_pem"]
    ca_cert = x509.load_pem_x509_certificate(pem_text.encode("utf-8"))
    return {
        "id": ca_row["id"],
        "common_name": ca_row["common_name"],
        "certificate_pem": pem_text,
        "fingerprint_sha1": calculate_fingerprint(ca_cert),
        "created_at": ca_row["created_at"],
        "expires_at": ca_row["expires_at"],
        "key_algorithm": ca_row["key_algorithm"],
    }
def generate_ca(
    common_name: str,
    password: str,
    days: int = 3650,
    owner: str | None = None,
) -> dict:
    """Create a self-signed secp384r1 CA and persist it as the 'default' CA row.

    Args:
        common_name: CA common name.
        password: Password used to encrypt the CA private key at rest.
        days: Validity period in days.
        owner: Optional owner fingerprint, stored as-is.

    Returns:
        Dict with id, common_name, fingerprint_sha1, certificate_pem,
        created_at and expires_at.

    Raises:
        CAExistsError: If a 'default' CA row already exists.
    """
    _require_crypto()
    from app.database import get_db

    db = get_db()
    if db.execute("SELECT id FROM certificate_authority WHERE id = 'default'").fetchone():
        raise CAExistsError("CA already exists")

    ca_key = ec.generate_private_key(ec.SECP384R1())
    issued_at = datetime.now(UTC)
    valid_until = issued_at + timedelta(days=days)

    ca_name = x509.Name(
        [
            x509.NameAttribute(NameOID.COMMON_NAME, common_name),
            x509.NameAttribute(NameOID.ORGANIZATION_NAME, "FlaskPaste PKI"),
        ]
    )
    ca_usage = x509.KeyUsage(
        digital_signature=True,
        key_cert_sign=True,
        crl_sign=True,
        key_encipherment=False,
        content_commitment=False,
        data_encipherment=False,
        key_agreement=False,
        encipher_only=False,
        decipher_only=False,
    )
    builder = x509.CertificateBuilder()
    builder = builder.subject_name(ca_name).issuer_name(ca_name)  # self-signed
    builder = builder.public_key(ca_key.public_key())
    builder = builder.serial_number(x509.random_serial_number())
    builder = builder.not_valid_before(issued_at).not_valid_after(valid_until)
    builder = builder.add_extension(x509.BasicConstraints(ca=True, path_length=0), critical=True)
    builder = builder.add_extension(ca_usage, critical=True)
    ca_cert = builder.sign(ca_key, hashes.SHA256())

    sealed_key, salt = encrypt_private_key(ca_key, password)
    fingerprint = calculate_fingerprint(ca_cert)
    cert_pem = ca_cert.public_bytes(serialization.Encoding.PEM).decode("utf-8")
    created_at = int(issued_at.timestamp())
    expires_at = int(valid_until.timestamp())

    row_values = (
        "default",
        common_name,
        cert_pem,
        sealed_key,
        salt,
        created_at,
        expires_at,
        "ec:secp384r1",
        owner,
    )
    db.execute(
        """INSERT INTO certificate_authority
        (id, common_name, certificate_pem, private_key_encrypted,
        key_salt, created_at, expires_at, key_algorithm, owner)
        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""",
        row_values,
    )
    db.commit()

    # Kept for compatibility; reset_pki() is currently a no-op
    reset_pki()

    return {
        "id": "default",
        "common_name": common_name,
        "fingerprint_sha1": fingerprint,
        "certificate_pem": cert_pem,
        "created_at": created_at,
        "expires_at": expires_at,
    }
def issue_certificate(
    common_name: str,
    password: str,
    days: int = 365,
    issued_to: str | None = None,
    is_admin: bool | None = None,
) -> dict:
    """Sign a new client certificate with the stored CA and record it.

    A fresh secp384r1 key pair is generated for the client. The certificate
    carries the clientAuth extended key usage and is stored in
    ``issued_certificates`` with status 'valid'.

    Args:
        common_name: Subject common name for the client certificate.
        password: CA password, used to decrypt the CA private key.
        days: Validity period in days.
        issued_to: Optional fingerprint of the issuing user, stored as-is.
        is_admin: Admin flag; None means "admin if no certificate with status
            'valid' exists yet".

    Returns:
        Dict with serial, common_name, fingerprint_sha1, certificate_pem,
        private_key_pem, created_at, expires_at and is_admin.

    Raises:
        CANotFoundError: If no CA row exists.
        InvalidPasswordError: If ``password`` cannot decrypt the CA key.
    """
    _require_crypto()
    from app.database import get_db

    db = get_db()

    if is_admin is None:
        # NOTE(review): only status='valid' rows are counted (expiry ignored), so
        # once every earlier certificate is revoked the next one becomes admin.
        active_count = db.execute(
            "SELECT COUNT(*) FROM issued_certificates WHERE status = 'valid'"
        ).fetchone()[0]
        is_admin = active_count == 0

    ca_row = db.execute(
        """SELECT certificate_pem, private_key_encrypted, key_salt
        FROM certificate_authority WHERE id = 'default'"""
    ).fetchone()
    if ca_row is None:
        raise CANotFoundError("No CA configured")

    signer_key = decrypt_private_key(
        ca_row["private_key_encrypted"],
        ca_row["key_salt"],
        password,
    )
    signer_cert = x509.load_pem_x509_certificate(ca_row["certificate_pem"].encode("utf-8"))

    client_key = ec.generate_private_key(ec.SECP384R1())
    issued_at = datetime.now(UTC)
    valid_until = issued_at + timedelta(days=days)
    serial = x509.random_serial_number()

    leaf_usage = x509.KeyUsage(
        digital_signature=True,
        key_encipherment=True,
        content_commitment=False,
        data_encipherment=False,
        key_agreement=False,
        key_cert_sign=False,
        crl_sign=False,
        encipher_only=False,
        decipher_only=False,
    )
    builder = x509.CertificateBuilder()
    builder = builder.subject_name(
        x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, common_name)])
    )
    builder = builder.issuer_name(signer_cert.subject)
    builder = builder.public_key(client_key.public_key())
    builder = builder.serial_number(serial)
    builder = builder.not_valid_before(issued_at).not_valid_after(valid_until)
    builder = builder.add_extension(
        x509.BasicConstraints(ca=False, path_length=None), critical=True
    )
    builder = builder.add_extension(leaf_usage, critical=True)
    builder = builder.add_extension(
        x509.ExtendedKeyUsage([ExtendedKeyUsageOID.CLIENT_AUTH]), critical=False
    )
    client_cert = builder.sign(signer_key, hashes.SHA256())

    cert_pem = client_cert.public_bytes(serialization.Encoding.PEM).decode("utf-8")
    key_pem = client_key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    ).decode("utf-8")
    fingerprint = calculate_fingerprint(client_cert)
    serial_hex = format(serial, "032x")
    created_at = int(issued_at.timestamp())
    expires_at = int(valid_until.timestamp())

    db.execute(
        """INSERT INTO issued_certificates
        (serial, ca_id, common_name, fingerprint_sha1, certificate_pem,
        created_at, expires_at, issued_to, status, revoked_at, is_admin)
        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
        (
            serial_hex,
            "default",
            common_name,
            fingerprint,
            cert_pem,
            created_at,
            expires_at,
            issued_to,
            "valid",
            None,
            1 if is_admin else 0,
        ),
    )
    db.commit()

    return {
        "serial": serial_hex,
        "common_name": common_name,
        "fingerprint_sha1": fingerprint,
        "certificate_pem": cert_pem,
        "private_key_pem": key_pem,
        "created_at": created_at,
        "expires_at": expires_at,
        "is_admin": is_admin,
    }
def is_admin_certificate(fingerprint: str) -> bool:
    """Check whether a fingerprint belongs to a non-revoked admin certificate.

    Stored fingerprints are lowercase hex (``calculate_fingerprint`` uses
    ``hexdigest()``), so the input is lowercased before lookup. This matches
    ``is_certificate_valid``, so mixed-case input is handled the same way.

    Args:
        fingerprint: SHA1 fingerprint of the certificate (any case).

    Returns:
        True if an issued certificate with status 'valid' has this
        fingerprint and its is_admin flag set.
    """
    from app.database import get_db

    db = get_db()
    row = db.execute(
        """SELECT is_admin FROM issued_certificates
        WHERE fingerprint_sha1 = ? AND status = 'valid'""",
        (fingerprint.lower(),),
    ).fetchone()
    return bool(row and row["is_admin"])
def is_trusted_certificate(fingerprint: str) -> bool:
    """Check whether a certificate was issued by our PKI and is still valid.

    Certificates we did not issue (external certs) are never trusted, even
    though ``is_certificate_valid`` lets them authenticate. A trusted
    certificate must have status 'valid' and must not be past its
    ``expires_at``. This uses the same expiry rule as ``is_certificate_valid``.

    Args:
        fingerprint: SHA1 fingerprint of the certificate (any case; stored
            fingerprints are lowercase hex).

    Returns:
        True if the certificate is registered, not revoked and not expired.
    """
    from app.database import get_db

    db = get_db()
    row = db.execute(
        """SELECT status FROM issued_certificates
        WHERE fingerprint_sha1 = ? AND status = 'valid' AND expires_at >= ?""",
        (fingerprint.lower(), int(time.time())),
    ).fetchone()
    return row is not None
def revoke_certificate(serial: str) -> bool:
    """Revoke an issued certificate by serial number.

    Serials are stored as 32-char lowercase hex (``format(serial, "032x")`` in
    ``issue_certificate``), so the input is lowercased before lookup.

    Args:
        serial: Certificate serial number in hex (any case).

    Returns:
        True if the certificate was revoked by this call, False if it was
        already revoked.

    Raises:
        CertificateNotFoundError: If no certificate has this serial.
    """
    from app.database import get_db

    serial = serial.lower()
    db = get_db()
    row = db.execute(
        "SELECT status FROM issued_certificates WHERE serial = ?", (serial,)
    ).fetchone()
    if row is None:
        raise CertificateNotFoundError(f"Certificate not found: {serial}")
    if row["status"] == "revoked":
        return False  # Already revoked; keep the original revoked_at

    db.execute(
        "UPDATE issued_certificates SET status = 'revoked', revoked_at = ? WHERE serial = ?",
        (int(time.time()), serial),
    )
    db.commit()
    return True