Files
flaskpaste/tests/test_pki.py
Username 5849c7406f add /register endpoint for public certificate registration
Public endpoint allows anyone to obtain a client certificate for
authentication. Features:

- Higher PoW difficulty than paste creation (24 vs 20 bits)
- Auto-generates CA on first registration if not present
- Returns PKCS#12 bundle with cert, key, and CA
- Configurable via FLASKPASTE_REGISTER_POW

Endpoints:
- GET /register/challenge - Get registration PoW challenge
- POST /register - Register and receive PKCS#12 bundle
2025-12-21 10:34:02 +01:00

597 lines
21 KiB
Python

"""Tests for PKI (Certificate Authority) functionality."""
from datetime import UTC, datetime, timedelta
import pytest
from app.pki import reset_pki
@pytest.fixture(autouse=True)
def reset_pki_state(app):
    """Give every test a clean PKI slate.

    Before the test runs, the in-process PKI state is reset with ``reset_pki()``
    and both PKI tables are emptied. After the test, the in-process state is
    reset once more.
    """
    reset_pki()
    with app.app_context():
        from app.database import get_db

        connection = get_db()
        # Same order as always: issued certificates first, then the CA record.
        for statement in (
            "DELETE FROM issued_certificates",
            "DELETE FROM certificate_authority",
        ):
            connection.execute(statement)
        connection.commit()
    yield
    reset_pki()
class TestPKIStatus:
    """Exercise the GET /pki status endpoint."""

    def _status(self, client):
        """Fetch /pki, require HTTP 200, and return the decoded JSON body."""
        reply = client.get("/pki")
        assert reply.status_code == 200
        return reply.get_json()

    def test_pki_status_when_enabled(self, client):
        """Before any CA exists, PKI reports enabled, no CA, and a hint."""
        body = self._status(client)
        assert body["enabled"] is True
        assert body["ca_exists"] is False
        assert "hint" in body

    def test_pki_status_after_ca_generation(self, client):
        """After a CA is generated, status exposes its name and 40-char SHA-1 fingerprint."""
        client.post("/pki/ca", json={"common_name": "Test CA"})
        body = self._status(client)
        assert body["enabled"] is True
        assert body["ca_exists"] is True
        assert body["common_name"] == "Test CA"
        assert "fingerprint_sha1" in body
        assert len(body["fingerprint_sha1"]) == 40
class TestCAGeneration:
    """Exercise the POST /pki/ca endpoint."""

    def test_generate_ca_success(self, client):
        """A bare POST creates a CA with the default common name."""
        resp = client.post("/pki/ca")
        assert resp.status_code == 201
        payload = resp.get_json()
        assert payload["message"] == "CA generated"
        assert payload["common_name"] == "FlaskPaste CA"
        for field in ("fingerprint_sha1", "created_at", "expires_at"):
            assert field in payload
        assert payload["download"] == "/pki/ca.crt"

    def test_generate_ca_custom_name(self, client):
        """A JSON ``common_name`` overrides the default CA name."""
        resp = client.post("/pki/ca", json={"common_name": "My Custom CA"})
        assert resp.status_code == 201
        assert resp.get_json()["common_name"] == "My Custom CA"

    def test_generate_ca_twice_fails(self, client):
        """Only one CA may exist: a second generation is rejected with 409."""
        assert client.post("/pki/ca").status_code == 201
        conflict = client.post("/pki/ca")
        assert conflict.status_code == 409
        assert "already exists" in conflict.get_json()["error"]
class TestCADownload:
    """Exercise the GET /pki/ca.crt download endpoint."""

    def test_download_ca_not_initialized(self, client):
        """With no CA generated yet, the download endpoint returns 404."""
        assert client.get("/pki/ca.crt").status_code == 404

    def test_download_ca_success(self, client):
        """Once a CA is generated, its certificate is served as a PEM file."""
        client.post("/pki/ca", json={"common_name": "Test CA"})
        download = client.get("/pki/ca.crt")
        assert download.status_code == 200
        assert download.content_type == "application/x-pem-file"
        pem = download.data
        for marker in (b"-----BEGIN CERTIFICATE-----", b"-----END CERTIFICATE-----"):
            assert marker in pem
class TestCertificateIssuance:
    """Exercise the POST /pki/issue endpoint."""

    def test_issue_without_ca_fails(self, client):
        """Issuing before any CA exists returns 404."""
        reply = client.post("/pki/issue", json={"common_name": "alice"})
        assert reply.status_code == 404

    def test_issue_without_name_fails(self, client):
        """A JSON body without ``common_name`` is rejected with 400."""
        client.post("/pki/ca")
        reply = client.post("/pki/issue", json={})
        assert reply.status_code == 400
        assert "common_name required" in reply.get_json()["error"]

    def test_issue_certificate_success(self, client):
        """A successful issuance returns metadata plus a PEM certificate and key."""
        client.post("/pki/ca")
        reply = client.post("/pki/issue", json={"common_name": "alice"})
        assert reply.status_code == 201
        issued = reply.get_json()
        assert issued["message"] == "Certificate issued"
        assert issued["common_name"] == "alice"
        for key in ("serial", "fingerprint_sha1", "certificate_pem", "private_key_pem"):
            assert key in issued
        assert len(issued["fingerprint_sha1"]) == 40
        assert "-----BEGIN CERTIFICATE-----" in issued["certificate_pem"]
        assert "-----BEGIN PRIVATE KEY-----" in issued["private_key_pem"]

    def test_issue_multiple_certificates(self, client):
        """Two issuances yield distinct serials and fingerprints."""
        client.post("/pki/ca")
        replies = [
            client.post("/pki/issue", json={"common_name": name}) for name in ("alice", "bob")
        ]
        assert [r.status_code for r in replies] == [201, 201]
        alice, bob = (r.get_json() for r in replies)
        assert alice["serial"] != bob["serial"]
        assert alice["fingerprint_sha1"] != bob["fingerprint_sha1"]
class TestCertificateListing:
    """Exercise the GET /pki/certs listing endpoint."""

    def test_list_anonymous_empty(self, client):
        """A request without a client fingerprint header gets an empty list."""
        client.post("/pki/ca")
        listing = client.get("/pki/certs")
        assert listing.status_code == 200
        body = listing.get_json()
        assert body["certificates"] == []
        assert body["count"] == 0

    def test_list_authenticated_sees_own(self, client):
        """A caller identified by fingerprint sees the certificate it issued."""
        client.post("/pki/ca")
        auth = {"X-SSL-Client-SHA1": "a" * 40}
        client.post("/pki/issue", json={"common_name": "alice"}, headers=auth)
        # List using the same fingerprint that issued the certificate.
        listing = client.get("/pki/certs", headers=auth)
        assert listing.status_code == 200
        body = listing.get_json()
        assert body["count"] == 1
        assert body["certificates"][0]["common_name"] == "alice"
class TestCertificateRevocation:
    """Test POST /pki/revoke/<serial> endpoint."""

    @staticmethod
    def _issue_serial(client, issuer=None):
        """Issue a certificate for ``alice`` and return its serial.

        When ``issuer`` is given it is sent as the X-SSL-Client-SHA1 header,
        so the certificate is issued on behalf of that fingerprint; otherwise
        the request is anonymous. Fails fast if issuance itself does not
        succeed, so revocation assertions are not misattributed.
        """
        headers = {"X-SSL-Client-SHA1": issuer} if issuer else {}
        response = client.post("/pki/issue", json={"common_name": "alice"}, headers=headers)
        assert response.status_code == 201
        return response.get_json()["serial"]

    def test_revoke_unauthenticated_fails(self, client):
        """Revocation requires authentication."""
        client.post("/pki/ca")
        serial = self._issue_serial(client)
        response = client.post(f"/pki/revoke/{serial}")
        assert response.status_code == 401

    def test_revoke_unauthorized_fails(self, client):
        """Revocation requires ownership."""
        client.post("/pki/ca")
        serial = self._issue_serial(client, issuer="a" * 40)
        # A different fingerprint than the issuer must be refused.
        response = client.post(f"/pki/revoke/{serial}", headers={"X-SSL-Client-SHA1": "b" * 40})
        assert response.status_code == 403

    def test_revoke_as_issuer_succeeds(self, client):
        """Issuer can revoke certificate."""
        client.post("/pki/ca")
        issuer = "a" * 40
        serial = self._issue_serial(client, issuer=issuer)
        response = client.post(f"/pki/revoke/{serial}", headers={"X-SSL-Client-SHA1": issuer})
        assert response.status_code == 200
        assert response.get_json()["message"] == "Certificate revoked"

    def test_revoke_nonexistent_fails(self, client):
        """Revoking an unknown serial returns 404."""
        client.post("/pki/ca")
        # Build the URL from the serial (not by repeating the path string) so the
        # request actually reaches the revoke handler instead of 404-ing in routing.
        unknown_serial = "0" * 32
        response = client.post(
            f"/pki/revoke/{unknown_serial}", headers={"X-SSL-Client-SHA1": "a" * 40}
        )
        assert response.status_code == 404

    def test_revoke_twice_fails(self, client):
        """Certificate cannot be revoked twice."""
        client.post("/pki/ca")
        issuer = "a" * 40
        serial = self._issue_serial(client, issuer=issuer)
        # First revocation succeeds
        response = client.post(f"/pki/revoke/{serial}", headers={"X-SSL-Client-SHA1": issuer})
        assert response.status_code == 200
        # Second revocation conflicts with the already-revoked state
        response = client.post(f"/pki/revoke/{serial}", headers={"X-SSL-Client-SHA1": issuer})
        assert response.status_code == 409
class TestRevocationIntegration:
    """Test revocation affects authentication."""

    def test_revoked_cert_treated_as_anonymous(self, client):
        """A revoked certificate no longer authorizes deleting the paste it created."""
        client.post("/pki/ca")
        issuer = "a" * 40
        issue_resp = client.post(
            "/pki/issue", json={"common_name": "alice"}, headers={"X-SSL-Client-SHA1": issuer}
        )
        assert issue_resp.status_code == 201
        issued = issue_resp.get_json()
        cert_fingerprint = issued["fingerprint_sha1"]
        serial = issued["serial"]

        # Create paste as authenticated user (the freshly issued certificate).
        create_resp = client.post(
            "/", data=b"test content", headers={"X-SSL-Client-SHA1": cert_fingerprint}
        )
        assert create_resp.status_code == 201
        created = create_resp.get_json()
        paste_id = created["id"]
        assert "owner" in created

        # Revoke, and confirm it worked so the 401 below is attributable to revocation.
        revoke_resp = client.post(f"/pki/revoke/{serial}", headers={"X-SSL-Client-SHA1": issuer})
        assert revoke_resp.status_code == 200

        # Deleting with the revoked cert must now be treated as unauthenticated.
        delete_resp = client.delete(
            f"/{paste_id}", headers={"X-SSL-Client-SHA1": cert_fingerprint}
        )
        assert delete_resp.status_code == 401
class TestPKICryptoFunctions:
    """Test standalone PKI cryptographic functions."""

    def test_derive_key_consistency(self):
        """Key derivation is deterministic and yields a 32-byte key."""
        from app.pki import derive_key

        password = "test-password"
        salt = b"x" * 32
        key1 = derive_key(password, salt)
        key2 = derive_key(password, salt)
        assert key1 == key2
        assert len(key1) == 32

    def test_encrypt_decrypt_roundtrip(self):
        """Private key encryption/decryption roundtrip."""
        from cryptography.hazmat.primitives.asymmetric import ec

        from app.pki import decrypt_private_key, encrypt_private_key

        private_key = ec.generate_private_key(ec.SECP384R1())
        password = "test-password"
        # encrypt_private_key returns the ciphertext and the salt; both are
        # handed back to decrypt_private_key together with the password.
        encrypted, salt = encrypt_private_key(private_key, password)
        decrypted = decrypt_private_key(encrypted, salt, password)
        # Identical private numbers means the same key was recovered.
        assert private_key.private_numbers() == decrypted.private_numbers()

    def test_wrong_password_fails(self):
        """Decryption with wrong password fails."""
        from cryptography.hazmat.primitives.asymmetric import ec

        from app.pki import (
            InvalidPasswordError,
            decrypt_private_key,
            encrypt_private_key,
        )

        private_key = ec.generate_private_key(ec.SECP384R1())
        encrypted, salt = encrypt_private_key(private_key, "correct")
        with pytest.raises(InvalidPasswordError):
            decrypt_private_key(encrypted, salt, "wrong")

    def test_fingerprint_calculation(self):
        """Certificate fingerprint is a 40-character lowercase hex string."""
        # datetime/timedelta/UTC come from the module-level import; re-importing
        # them here only shadowed the same names.
        from cryptography import x509
        from cryptography.hazmat.primitives import hashes
        from cryptography.hazmat.primitives.asymmetric import ec
        from cryptography.x509.oid import NameOID

        from app.pki import calculate_fingerprint

        # Minimal self-signed cert for testing
        key = ec.generate_private_key(ec.SECP256R1())
        subject = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, "test")])
        now = datetime.now(UTC)
        cert = (
            x509.CertificateBuilder()
            .subject_name(subject)
            .issuer_name(subject)
            .public_key(key.public_key())
            .serial_number(1)
            .not_valid_before(now)
            .not_valid_after(now + timedelta(days=1))
            .sign(key, hashes.SHA256())
        )
        fingerprint = calculate_fingerprint(cert)
        assert len(fingerprint) == 40
        assert all(c in "0123456789abcdef" for c in fingerprint)
class TestRegistration:
    """Test public certificate registration via /register endpoint.

    Config overrides go through ``monkeypatch.setitem`` instead of assigning to
    ``app.config`` directly, so every override is reverted when the test ends,
    even if the ``app`` fixture is shared between tests.
    """

    def test_register_challenge_returns_token(self, client, app, monkeypatch):
        """Registration challenge endpoint returns PoW token."""
        monkeypatch.setitem(app.config, "REGISTER_POW_DIFFICULTY", 10)
        response = client.get("/register/challenge")
        assert response.status_code == 200
        data = response.get_json()
        assert data["enabled"] is True
        assert "nonce" in data
        assert "token" in data
        assert data["difficulty"] == 10
        assert data["purpose"] == "registration"

    def test_register_challenge_disabled(self, client, app, monkeypatch):
        """Registration challenge shows disabled when difficulty is 0."""
        monkeypatch.setitem(app.config, "REGISTER_POW_DIFFICULTY", 0)
        response = client.get("/register/challenge")
        assert response.status_code == 200
        assert response.get_json()["enabled"] is False

    def test_register_requires_pow(self, client, app, monkeypatch):
        """Registration fails without PoW when difficulty > 0."""
        monkeypatch.setitem(app.config, "REGISTER_POW_DIFFICULTY", 10)
        response = client.post("/register", json={"common_name": "test"})
        assert response.status_code == 400
        assert "Proof-of-work required" in response.get_json()["error"]

    def test_register_with_pow_disabled_succeeds(self, client, app, monkeypatch):
        """Registration succeeds without PoW when difficulty is 0."""
        monkeypatch.setitem(app.config, "REGISTER_POW_DIFFICULTY", 0)
        response = client.post("/register", json={"common_name": "test-client"})
        assert response.status_code == 200
        assert response.content_type == "application/x-pkcs12"
        assert "X-Fingerprint-SHA1" in response.headers
        assert len(response.headers["X-Fingerprint-SHA1"]) == 40

    def test_register_auto_generates_ca(self, client, app, monkeypatch):
        """Registration auto-generates CA if not present."""
        from app.pki import get_ca_info

        monkeypatch.setitem(app.config, "REGISTER_POW_DIFFICULTY", 0)
        # get_ca_info reads the database, so it needs an application context.
        with app.app_context():
            assert get_ca_info() is None
        response = client.post("/register", json={"common_name": "first-client"})
        assert response.status_code == 200
        with app.app_context():
            ca_info = get_ca_info()
        assert ca_info is not None
        assert ca_info["common_name"] == "FlaskPaste CA"

    def test_register_returns_pkcs12(self, client, app, monkeypatch):
        """Registration returns valid PKCS#12 bundle."""
        from cryptography.hazmat.primitives.serialization import pkcs12

        monkeypatch.setitem(app.config, "REGISTER_POW_DIFFICULTY", 0)
        response = client.post("/register", json={"common_name": "my-client"})
        assert response.status_code == 200
        # The bundle must load without a password.
        private_key, certificate, additional_certs = pkcs12.load_key_and_certificates(
            response.data, password=None
        )
        assert private_key is not None
        assert certificate is not None
        # Should include CA certificate
        assert additional_certs is not None
        assert len(additional_certs) == 1

    def test_register_generates_common_name(self, client, app, monkeypatch):
        """Registration generates random CN if not provided."""
        monkeypatch.setitem(app.config, "REGISTER_POW_DIFFICULTY", 0)
        response = client.post("/register")
        assert response.status_code == 200
        # CN is in the Content-Disposition header
        disposition = response.headers["Content-Disposition"]
        assert "client-" in disposition
        assert ".p12" in disposition

    def test_register_respects_custom_common_name(self, client, app, monkeypatch):
        """Registration uses provided common name."""
        monkeypatch.setitem(app.config, "REGISTER_POW_DIFFICULTY", 0)
        response = client.post("/register", json={"common_name": "custom-name"})
        assert response.status_code == 200
        disposition = response.headers["Content-Disposition"]
        assert "custom-name.p12" in disposition

    def test_register_without_pki_password_fails(self, client, app, monkeypatch):
        """Registration fails when PKI_CA_PASSWORD not configured."""
        # Blanking the CA password must not leak into later tests.
        monkeypatch.setitem(app.config, "PKI_CA_PASSWORD", "")
        monkeypatch.setitem(app.config, "REGISTER_POW_DIFFICULTY", 0)
        response = client.post("/register", json={"common_name": "test"})
        assert response.status_code == 503
        assert "not available" in response.get_json()["error"]
class TestPKCS12Creation:
    """Exercise the ``create_pkcs12`` bundle builder directly, without HTTP."""

    @staticmethod
    def _make_cert(*, subject_cn, issuer_cn, public_key, signing_key, serial, not_before, days):
        """Build a certificate valid for ``days`` from ``not_before``, signed with SHA-256."""
        from cryptography import x509
        from cryptography.hazmat.primitives import hashes
        from cryptography.x509.oid import NameOID

        def cn_name(common_name):
            return x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, common_name)])

        builder = (
            x509.CertificateBuilder()
            .subject_name(cn_name(subject_cn))
            .issuer_name(cn_name(issuer_cn))
            .public_key(public_key)
            .serial_number(serial)
            .not_valid_before(not_before)
            .not_valid_after(not_before + timedelta(days=days))
        )
        return builder.sign(signing_key, hashes.SHA256())

    def test_create_pkcs12_without_password(self):
        """An unencrypted bundle loads back with its key, leaf cert, and one CA cert."""
        from cryptography.hazmat.primitives.asymmetric import ec
        from cryptography.hazmat.primitives.serialization import pkcs12
        from cryptography.x509.oid import NameOID

        from app.pki import create_pkcs12

        ca_key = ec.generate_private_key(ec.SECP384R1())
        client_key = ec.generate_private_key(ec.SECP384R1())
        # One timestamp shared by both certificates.
        issued_at = datetime.now(UTC)
        ca_cert = self._make_cert(
            subject_cn="Test CA",
            issuer_cn="Test CA",
            public_key=ca_key.public_key(),
            signing_key=ca_key,
            serial=1,
            not_before=issued_at,
            days=365,
        )
        client_cert = self._make_cert(
            subject_cn="Client",
            issuer_cn="Test CA",
            public_key=client_key.public_key(),
            signing_key=ca_key,
            serial=2,
            not_before=issued_at,
            days=30,
        )

        bundle = create_pkcs12(
            private_key=client_key,
            certificate=client_cert,
            ca_certificate=ca_cert,
            friendly_name="test-client",
            password=None,
        )

        key_back, cert_back, cas_back = pkcs12.load_key_and_certificates(bundle, password=None)
        assert key_back is not None
        assert cert_back is not None
        cn_attr = cert_back.subject.get_attributes_for_oid(NameOID.COMMON_NAME)[0]
        assert cn_attr.value == "Client"
        assert cas_back is not None
        assert len(cas_back) == 1

    def test_create_pkcs12_with_password(self):
        """A password-protected bundle refuses to load without that password."""
        from cryptography.hazmat.primitives.asymmetric import ec
        from cryptography.hazmat.primitives.serialization import pkcs12

        from app.pki import create_pkcs12

        key = ec.generate_private_key(ec.SECP384R1())
        self_signed = self._make_cert(
            subject_cn="Test",
            issuer_cn="Test",
            public_key=key.public_key(),
            signing_key=key,
            serial=1,
            not_before=datetime.now(UTC),
            days=1,
        )
        secret = b"secret123"
        bundle = create_pkcs12(
            private_key=key,
            certificate=self_signed,
            ca_certificate=self_signed,
            friendly_name="test",
            password=secret,
        )

        # Loading without the password must be rejected...
        with pytest.raises(ValueError):
            pkcs12.load_key_and_certificates(bundle, password=None)
        # ...while the correct password recovers the key.
        key_back, _, _ = pkcs12.load_key_and_certificates(bundle, password=secret)
        assert key_back is not None