"""Tests for PKI (Certificate Authority) functionality."""

from datetime import UTC, datetime, timedelta

import pytest

from app.pki import reset_pki


@pytest.fixture(autouse=True)
def reset_pki_state(app):
    """Reset PKI state and clear PKI database tables before each test.

    ``reset_pki()`` runs again on teardown; the database tables are only
    cleared on setup (the next test clears them again before it runs).
    """
    reset_pki()

    # Clear PKI tables in database
    with app.app_context():
        from app.database import get_db

        db = get_db()
        db.execute("DELETE FROM issued_certificates")
        db.execute("DELETE FROM certificate_authority")
        db.commit()

    yield

    reset_pki()


class TestPKIStatus:
    """Test GET /pki endpoint."""

    def test_pki_status_when_enabled(self, client):
        """PKI status shows enabled with no CA initially."""
        response = client.get("/pki")
        assert response.status_code == 200
        data = response.get_json()
        assert data["enabled"] is True
        assert data["ca_exists"] is False
        assert "hint" in data

    def test_pki_status_after_ca_generation(self, client):
        """PKI status shows CA info after generation."""
        # Generate CA first
        client.post("/pki/ca", json={"common_name": "Test CA"})

        response = client.get("/pki")
        assert response.status_code == 200
        data = response.get_json()
        assert data["enabled"] is True
        assert data["ca_exists"] is True
        assert data["common_name"] == "Test CA"
        assert "fingerprint_sha1" in data
        assert len(data["fingerprint_sha1"]) == 40


class TestCAGeneration:
    """Test POST /pki/ca endpoint."""

    def test_generate_ca_success(self, client):
        """CA can be generated with default name."""
        response = client.post("/pki/ca")
        assert response.status_code == 201
        data = response.get_json()
        assert data["message"] == "CA generated"
        assert data["common_name"] == "FlaskPaste CA"
        assert "fingerprint_sha1" in data
        assert "created_at" in data
        assert "expires_at" in data
        assert data["download"] == "/pki/ca.crt"

    def test_generate_ca_custom_name(self, client):
        """CA can be generated with custom name."""
        response = client.post("/pki/ca", json={"common_name": "My Custom CA"})
        assert response.status_code == 201
        data = response.get_json()
        assert data["common_name"] == "My Custom CA"

    def test_generate_ca_twice_fails(self, client):
        """CA cannot be generated twice."""
        # First generation succeeds
        response = client.post("/pki/ca")
        assert response.status_code == 201

        # Second generation fails
        response = client.post("/pki/ca")
        assert response.status_code == 409
        data = response.get_json()
        assert "already exists" in data["error"]


class TestCADownload:
    """Test GET /pki/ca.crt endpoint."""

    def test_download_ca_not_initialized(self, client):
        """Download fails when no CA exists."""
        response = client.get("/pki/ca.crt")
        assert response.status_code == 404

    def test_download_ca_success(self, client):
        """CA certificate can be downloaded."""
        # Generate CA first
        client.post("/pki/ca", json={"common_name": "Test CA"})

        response = client.get("/pki/ca.crt")
        assert response.status_code == 200
        assert response.content_type == "application/x-pem-file"
        assert b"-----BEGIN CERTIFICATE-----" in response.data
        assert b"-----END CERTIFICATE-----" in response.data


class TestCertificateIssuance:
    """Test POST /pki/issue endpoint."""

    def test_issue_without_ca_fails(self, client):
        """Issuance fails when no CA exists."""
        response = client.post("/pki/issue", json={"common_name": "alice"})
        assert response.status_code == 404

    def test_issue_without_name_fails(self, client):
        """Issuance fails without common_name."""
        client.post("/pki/ca")
        response = client.post("/pki/issue", json={})
        assert response.status_code == 400
        assert "common_name required" in response.get_json()["error"]

    def test_issue_certificate_success(self, client):
        """Certificate issuance succeeds."""
        client.post("/pki/ca")

        response = client.post("/pki/issue", json={"common_name": "alice"})
        assert response.status_code == 201
        data = response.get_json()
        assert data["message"] == "Certificate issued"
        assert data["common_name"] == "alice"
        assert "serial" in data
        assert "fingerprint_sha1" in data
        assert len(data["fingerprint_sha1"]) == 40
        assert "certificate_pem" in data
        assert "private_key_pem" in data
        assert "-----BEGIN CERTIFICATE-----" in data["certificate_pem"]
        assert "-----BEGIN PRIVATE KEY-----" in data["private_key_pem"]

    def test_issue_multiple_certificates(self, client):
        """Multiple certificates can be issued."""
        client.post("/pki/ca")

        response1 = client.post("/pki/issue", json={"common_name": "alice"})
        response2 = client.post("/pki/issue", json={"common_name": "bob"})

        assert response1.status_code == 201
        assert response2.status_code == 201

        data1 = response1.get_json()
        data2 = response2.get_json()

        # Different serials and fingerprints
        assert data1["serial"] != data2["serial"]
        assert data1["fingerprint_sha1"] != data2["fingerprint_sha1"]

    def test_first_user_is_admin(self, app, client):
        """First issued certificate gets admin rights."""
        from app.pki import is_admin_certificate

        client.post("/pki/ca")

        # First user becomes admin
        response1 = client.post("/pki/issue", json={"common_name": "admin"})
        assert response1.status_code == 201
        data1 = response1.get_json()
        assert data1.get("is_admin") is True

        with app.app_context():
            assert is_admin_certificate(data1["fingerprint_sha1"]) is True

        # Second user is not admin
        response2 = client.post("/pki/issue", json={"common_name": "user"})
        assert response2.status_code == 201
        data2 = response2.get_json()
        assert data2.get("is_admin") is False

        with app.app_context():
            assert is_admin_certificate(data2["fingerprint_sha1"]) is False


class TestCertificateListing:
    """Test GET /pki/certs endpoint."""

    def test_list_anonymous_empty(self, client):
        """Anonymous users see empty list."""
        client.post("/pki/ca")

        response = client.get("/pki/certs")
        assert response.status_code == 200
        data = response.get_json()
        assert data["certificates"] == []
        assert data["count"] == 0

    def test_list_authenticated_sees_own(self, client):
        """Authenticated users see certificates they issued."""
        client.post("/pki/ca")

        # Issue certificate as authenticated user
        issuer_fingerprint = "a" * 40
        client.post(
            "/pki/issue",
            json={"common_name": "alice"},
            headers={"X-SSL-Client-SHA1": issuer_fingerprint},
        )

        # List as same user
        response = client.get("/pki/certs", headers={"X-SSL-Client-SHA1": issuer_fingerprint})
        assert response.status_code == 200
        data = response.get_json()
        assert data["count"] == 1
        assert data["certificates"][0]["common_name"] == "alice"


class TestCertificateRevocation:
    """Test POST /pki/revoke/<serial> endpoint."""

    def test_revoke_unauthenticated_fails(self, client):
        """Revocation requires authentication."""
        client.post("/pki/ca")
        issue_resp = client.post("/pki/issue", json={"common_name": "alice"})
        serial = issue_resp.get_json()["serial"]

        response = client.post(f"/pki/revoke/{serial}")
        assert response.status_code == 401

    def test_revoke_unauthorized_fails(self, client):
        """Revocation requires ownership."""
        client.post("/pki/ca")

        # Issue as one user
        issue_resp = client.post(
            "/pki/issue", json={"common_name": "alice"}, headers={"X-SSL-Client-SHA1": "a" * 40}
        )
        serial = issue_resp.get_json()["serial"]

        # Try to revoke as different user
        response = client.post(f"/pki/revoke/{serial}", headers={"X-SSL-Client-SHA1": "b" * 40})
        assert response.status_code == 403

    def test_revoke_as_issuer_succeeds(self, client):
        """Issuer can revoke certificate."""
        client.post("/pki/ca")

        issuer = "a" * 40
        issue_resp = client.post(
            "/pki/issue", json={"common_name": "alice"}, headers={"X-SSL-Client-SHA1": issuer}
        )
        serial = issue_resp.get_json()["serial"]

        response = client.post(f"/pki/revoke/{serial}", headers={"X-SSL-Client-SHA1": issuer})
        assert response.status_code == 200
        assert response.get_json()["message"] == "Certificate revoked"

    def test_revoke_nonexistent_fails(self, client):
        """Revoking nonexistent certificate fails."""
        client.post("/pki/ca")

        # Serial that was never issued. Build it separately: multiplying the
        # whole path string ("/pki/revoke/0" * 32) would produce an unrouted
        # URL whose 404 never reaches the revoke endpoint.
        missing_serial = "0" * 32
        response = client.post(
            f"/pki/revoke/{missing_serial}", headers={"X-SSL-Client-SHA1": "a" * 40}
        )
        assert response.status_code == 404

    def test_revoke_twice_fails(self, client):
        """Certificate cannot be revoked twice."""
        client.post("/pki/ca")

        issuer = "a" * 40
        issue_resp = client.post(
            "/pki/issue", json={"common_name": "alice"}, headers={"X-SSL-Client-SHA1": issuer}
        )
        serial = issue_resp.get_json()["serial"]

        # First revocation succeeds
        response = client.post(f"/pki/revoke/{serial}", headers={"X-SSL-Client-SHA1": issuer})
        assert response.status_code == 200

        # Second revocation fails
        response = client.post(f"/pki/revoke/{serial}", headers={"X-SSL-Client-SHA1": issuer})
        assert response.status_code == 409


class TestRevocationIntegration:
    """Test revocation affects authentication."""

    def test_revoked_cert_treated_as_anonymous(self, client):
        """Revoked certificate is treated as anonymous."""
        client.post("/pki/ca")

        # Issue certificate
        issuer = "a" * 40
        issue_resp = client.post(
            "/pki/issue", json={"common_name": "alice"}, headers={"X-SSL-Client-SHA1": issuer}
        )
        cert_fingerprint = issue_resp.get_json()["fingerprint_sha1"]
        serial = issue_resp.get_json()["serial"]

        # Create paste as authenticated user
        create_resp = client.post(
            "/", data=b"test content", headers={"X-SSL-Client-SHA1": cert_fingerprint}
        )
        assert create_resp.status_code == 201
        paste_id = create_resp.get_json()["id"]
        assert "owner" in create_resp.get_json()

        # Revoke the certificate. Check that it actually happened; otherwise the
        # assertions below would run against a still-valid certificate.
        revoke_resp = client.post(f"/pki/revoke/{serial}", headers={"X-SSL-Client-SHA1": issuer})
        assert revoke_resp.status_code == 200

        # Revoked cert can still delete their own paste (ownership by fingerprint)
        # They just lose elevated rate/size limits
        delete_resp = client.delete(f"/{paste_id}", headers={"X-SSL-Client-SHA1": cert_fingerprint})
        assert delete_resp.status_code == 200


class TestPKICryptoFunctions:
    """Test standalone PKI cryptographic functions."""

    def test_derive_key_consistency(self):
        """Key derivation produces consistent results."""
        from app.pki import derive_key

        password = "test-password"
        salt = b"x" * 32

        key1 = derive_key(password, salt)
        key2 = derive_key(password, salt)

        assert key1 == key2
        assert len(key1) == 32

    def test_encrypt_decrypt_roundtrip(self):
        """Private key encryption/decryption roundtrip."""
        from cryptography.hazmat.primitives.asymmetric import ec

        from app.pki import decrypt_private_key, encrypt_private_key

        # Generate a test key
        private_key = ec.generate_private_key(ec.SECP384R1())
        password = "test-password"

        # Encrypt
        encrypted, salt = encrypt_private_key(private_key, password)

        # Decrypt
        decrypted = decrypt_private_key(encrypted, salt, password)

        # Verify same key
        assert private_key.private_numbers() == decrypted.private_numbers()

    def test_wrong_password_fails(self):
        """Decryption with wrong password fails."""
        from cryptography.hazmat.primitives.asymmetric import ec

        from app.pki import (
            InvalidPasswordError,
            decrypt_private_key,
            encrypt_private_key,
        )

        private_key = ec.generate_private_key(ec.SECP384R1())
        encrypted, salt = encrypt_private_key(private_key, "correct")

        with pytest.raises(InvalidPasswordError):
            decrypt_private_key(encrypted, salt, "wrong")

    def test_fingerprint_calculation(self):
        """Certificate fingerprint is calculated correctly."""
        from cryptography import x509
        from cryptography.hazmat.primitives import hashes
        from cryptography.hazmat.primitives.asymmetric import ec
        from cryptography.x509.oid import NameOID

        from app.pki import calculate_fingerprint

        # Minimal self-signed cert for testing
        key = ec.generate_private_key(ec.SECP256R1())
        subject = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, "test")])
        cert = (
            x509.CertificateBuilder()
            .subject_name(subject)
            .issuer_name(subject)
            .public_key(key.public_key())
            .serial_number(1)
            .not_valid_before(datetime.now(UTC))
            .not_valid_after(datetime.now(UTC) + timedelta(days=1))
            .sign(key, hashes.SHA256())
        )

        fingerprint = calculate_fingerprint(cert)
        assert len(fingerprint) == 40
        assert all(c in "0123456789abcdef" for c in fingerprint)
        # Check the actual value, not only its format: lowercase hex of the
        # certificate's SHA-1 digest as computed by cryptography.
        assert fingerprint == cert.fingerprint(hashes.SHA1()).hex()


class TestRegistration:
    """Test public certificate registration via /register endpoint.

    Config overrides use ``monkeypatch.setitem`` so they are undone after each
    test instead of lingering in ``app.config``.
    """

    def test_register_challenge_returns_token(self, client, app, monkeypatch):
        """Registration challenge endpoint returns PoW token."""
        monkeypatch.setitem(app.config, "REGISTER_POW_DIFFICULTY", 10)

        response = client.get("/register/challenge")
        assert response.status_code == 200
        data = response.get_json()
        assert data["enabled"] is True
        assert "nonce" in data
        assert "token" in data
        assert data["difficulty"] == 10
        assert data["purpose"] == "registration"

    def test_register_challenge_disabled(self, client, app, monkeypatch):
        """Registration challenge shows disabled when difficulty is 0."""
        monkeypatch.setitem(app.config, "REGISTER_POW_DIFFICULTY", 0)

        response = client.get("/register/challenge")
        assert response.status_code == 200
        data = response.get_json()
        assert data["enabled"] is False

    def test_register_requires_pow(self, client, app, monkeypatch):
        """Registration fails without PoW when difficulty > 0."""
        monkeypatch.setitem(app.config, "REGISTER_POW_DIFFICULTY", 10)

        response = client.post("/register", json={"common_name": "test"})
        assert response.status_code == 400
        assert "Proof-of-work required" in response.get_json()["error"]

    def test_register_with_pow_disabled_succeeds(self, client, app, monkeypatch):
        """Registration succeeds without PoW when difficulty is 0."""
        monkeypatch.setitem(app.config, "REGISTER_POW_DIFFICULTY", 0)

        response = client.post("/register", json={"common_name": "test-client"})
        assert response.status_code == 200
        assert response.content_type == "application/x-pkcs12"
        assert "X-Fingerprint-SHA1" in response.headers
        assert len(response.headers["X-Fingerprint-SHA1"]) == 40

    def test_register_auto_generates_ca(self, client, app, monkeypatch):
        """Registration auto-generates CA if not present."""
        from app.pki import get_ca_info

        monkeypatch.setitem(app.config, "REGISTER_POW_DIFFICULTY", 0)

        # Verify no CA exists
        with app.app_context():
            assert get_ca_info() is None

        response = client.post("/register", json={"common_name": "first-client"})
        assert response.status_code == 200

        # Verify CA now exists
        with app.app_context():
            ca_info = get_ca_info()
            assert ca_info is not None
            assert ca_info["common_name"] == "FlaskPaste CA"

    def test_register_returns_pkcs12(self, client, app, monkeypatch):
        """Registration returns valid PKCS#12 bundle."""
        from cryptography.hazmat.primitives.serialization import pkcs12

        monkeypatch.setitem(app.config, "REGISTER_POW_DIFFICULTY", 0)

        response = client.post("/register", json={"common_name": "my-client"})
        assert response.status_code == 200

        # Verify PKCS#12 can be loaded
        p12_data = response.data
        private_key, certificate, additional_certs = pkcs12.load_key_and_certificates(
            p12_data, password=None
        )
        assert private_key is not None
        assert certificate is not None
        # Should include CA certificate
        assert additional_certs is not None
        assert len(additional_certs) == 1

    def test_register_generates_common_name(self, client, app, monkeypatch):
        """Registration generates random CN if not provided."""
        monkeypatch.setitem(app.config, "REGISTER_POW_DIFFICULTY", 0)

        response = client.post("/register")
        assert response.status_code == 200
        # CN is in the Content-Disposition header
        disposition = response.headers["Content-Disposition"]
        assert "client-" in disposition
        assert ".p12" in disposition

    def test_register_respects_custom_common_name(self, client, app, monkeypatch):
        """Registration uses provided common name."""
        monkeypatch.setitem(app.config, "REGISTER_POW_DIFFICULTY", 0)

        response = client.post("/register", json={"common_name": "custom-name"})
        assert response.status_code == 200
        disposition = response.headers["Content-Disposition"]
        assert "custom-name.p12" in disposition

    def test_register_without_pki_password_fails(self, client, app, monkeypatch):
        """Registration fails when PKI_CA_PASSWORD not configured."""
        # Restored by monkeypatch afterwards; an empty CA password must not
        # leak into later tests that need a working CA.
        monkeypatch.setitem(app.config, "PKI_CA_PASSWORD", "")
        monkeypatch.setitem(app.config, "REGISTER_POW_DIFFICULTY", 0)

        response = client.post("/register", json={"common_name": "test"})
        assert response.status_code == 503
        assert "not available" in response.get_json()["error"]


class TestPKCS12Creation:
    """Test PKCS#12 bundle creation function."""

    def test_create_pkcs12_without_password(self):
        """PKCS#12 created without password can be loaded."""
        from cryptography import x509
        from cryptography.hazmat.primitives import hashes
        from cryptography.hazmat.primitives.asymmetric import ec
        from cryptography.hazmat.primitives.serialization import pkcs12
        from cryptography.x509.oid import NameOID

        from app.pki import create_pkcs12

        # Generate test keys and certs
        ca_key = ec.generate_private_key(ec.SECP384R1())
        client_key = ec.generate_private_key(ec.SECP384R1())

        now = datetime.now(UTC)
        ca_cert = (
            x509.CertificateBuilder()
            .subject_name(x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, "Test CA")]))
            .issuer_name(x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, "Test CA")]))
            .public_key(ca_key.public_key())
            .serial_number(1)
            .not_valid_before(now)
            .not_valid_after(now + timedelta(days=365))
            .sign(ca_key, hashes.SHA256())
        )

        client_cert = (
            x509.CertificateBuilder()
            .subject_name(x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, "Client")]))
            .issuer_name(x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, "Test CA")]))
            .public_key(client_key.public_key())
            .serial_number(2)
            .not_valid_before(now)
            .not_valid_after(now + timedelta(days=30))
            .sign(ca_key, hashes.SHA256())
        )

        # Create PKCS#12
        p12_data = create_pkcs12(
            private_key=client_key,
            certificate=client_cert,
            ca_certificate=ca_cert,
            friendly_name="test-client",
            password=None,
        )

        # Load and verify
        loaded_key, loaded_cert, loaded_cas = pkcs12.load_key_and_certificates(
            p12_data, password=None
        )
        assert loaded_key is not None
        assert loaded_cert is not None
        assert loaded_cert.subject.get_attributes_for_oid(NameOID.COMMON_NAME)[0].value == "Client"
        assert loaded_cas is not None
        assert len(loaded_cas) == 1

    def test_create_pkcs12_with_password(self):
        """PKCS#12 created with password requires password to load."""
        from cryptography import x509
        from cryptography.hazmat.primitives import hashes
        from cryptography.hazmat.primitives.asymmetric import ec
        from cryptography.hazmat.primitives.serialization import pkcs12
        from cryptography.x509.oid import NameOID

        from app.pki import create_pkcs12

        key = ec.generate_private_key(ec.SECP384R1())
        now = datetime.now(UTC)
        cert = (
            x509.CertificateBuilder()
            .subject_name(x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, "Test")]))
            .issuer_name(x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, "Test")]))
            .public_key(key.public_key())
            .serial_number(1)
            .not_valid_before(now)
            .not_valid_after(now + timedelta(days=1))
            .sign(key, hashes.SHA256())
        )

        p12_data = create_pkcs12(
            private_key=key,
            certificate=cert,
            ca_certificate=cert,
            friendly_name="test",
            password=b"secret123",
        )

        # Should fail without password
        with pytest.raises(ValueError):
            pkcs12.load_key_and_certificates(p12_data, password=None)

        # Should succeed with correct password
        loaded_key, _, _ = pkcs12.load_key_and_certificates(p12_data, password=b"secret123")
        assert loaded_key is not None


class TestAdminPrivileges:
    """Test admin privileges for list/delete operations."""

    def test_admin_can_list_all_pastes(self, client):
        """Admin can list all pastes with ?all=1."""
        # Setup: Create CA and issue admin cert (first user)
        client.post("/pki/ca")
        admin_resp = client.post("/pki/issue", json={"common_name": "admin"})
        admin_fp = admin_resp.get_json()["fingerprint_sha1"]

        # Issue non-admin cert (second user)
        user_resp = client.post("/pki/issue", json={"common_name": "user"})
        user_fp = user_resp.get_json()["fingerprint_sha1"]

        # User creates a paste
        paste_resp = client.post("/", data=b"user content", headers={"X-SSL-Client-SHA1": user_fp})
        assert paste_resp.status_code == 201

        # User can only see their own pastes
        user_list = client.get("/pastes", headers={"X-SSL-Client-SHA1": user_fp})
        assert user_list.status_code == 200
        assert user_list.get_json()["count"] == 1
        assert "is_admin" not in user_list.get_json()

        # Admin lists only their own by default (no pastes)
        admin_list = client.get("/pastes", headers={"X-SSL-Client-SHA1": admin_fp})
        assert admin_list.status_code == 200
        assert admin_list.get_json()["count"] == 0
        assert admin_list.get_json()["is_admin"] is True

        # Admin lists all with ?all=1
        admin_all = client.get("/pastes?all=1", headers={"X-SSL-Client-SHA1": admin_fp})
        assert admin_all.status_code == 200
        assert admin_all.get_json()["count"] == 1
        # Includes owner info
        assert "owner" in admin_all.get_json()["pastes"][0]
        assert admin_all.get_json()["pastes"][0]["owner"] == user_fp

    def test_non_admin_cannot_use_all_param(self, client):
        """Non-admin ?all=1 is ignored."""
        client.post("/pki/ca")
        # First user is admin
        admin_resp = client.post("/pki/issue", json={"common_name": "admin"})
        admin_fp = admin_resp.get_json()["fingerprint_sha1"]
        # Second user is not
        user_resp = client.post("/pki/issue", json={"common_name": "user"})
        user_fp = user_resp.get_json()["fingerprint_sha1"]

        # A paste owned by someone else: if ?all=1 were honoured, the user
        # would see it in their listing.
        paste_resp = client.post(
            "/", data=b"admin content", headers={"X-SSL-Client-SHA1": admin_fp}
        )
        assert paste_resp.status_code == 201

        # User tries ?all=1, should be ignored
        resp = client.get("/pastes?all=1", headers={"X-SSL-Client-SHA1": user_fp})
        assert resp.status_code == 200
        assert "is_admin" not in resp.get_json()
        assert resp.get_json()["count"] == 0

    def test_admin_can_delete_any_paste(self, client):
        """Admin can delete pastes owned by others."""
        client.post("/pki/ca")
        admin_resp = client.post("/pki/issue", json={"common_name": "admin"})
        admin_fp = admin_resp.get_json()["fingerprint_sha1"]

        user_resp = client.post("/pki/issue", json={"common_name": "user"})
        user_fp = user_resp.get_json()["fingerprint_sha1"]

        # User creates a paste
        paste_resp = client.post("/", data=b"user content", headers={"X-SSL-Client-SHA1": user_fp})
        paste_id = paste_resp.get_json()["id"]

        # Admin deletes it
        delete_resp = client.delete(f"/{paste_id}", headers={"X-SSL-Client-SHA1": admin_fp})
        assert delete_resp.status_code == 200
        assert delete_resp.get_json()["message"] == "Paste deleted"

        # Verify gone
        get_resp = client.get(f"/{paste_id}")
        assert get_resp.status_code == 404

    def test_non_admin_cannot_delete_others_paste(self, client):
        """Non-admin cannot delete pastes owned by others."""
        client.post("/pki/ca")
        # First user is admin (create so second is not)
        admin_resp = client.post("/pki/issue", json={"common_name": "admin"})
        admin_fp = admin_resp.get_json()["fingerprint_sha1"]

        user_resp = client.post("/pki/issue", json={"common_name": "user"})
        user_fp = user_resp.get_json()["fingerprint_sha1"]

        # Admin creates a paste
        paste_resp = client.post(
            "/", data=b"admin content", headers={"X-SSL-Client-SHA1": admin_fp}
        )
        paste_id = paste_resp.get_json()["id"]

        # User tries to delete it
        delete_resp = client.delete(f"/{paste_id}", headers={"X-SSL-Client-SHA1": user_fp})
        assert delete_resp.status_code == 403