"""Tests for PKI (Certificate Authority) functionality.""" from datetime import UTC import pytest from app.pki import reset_pki @pytest.fixture(autouse=True) def reset_pki_state(app): """Reset PKI state and clear PKI database tables before each test.""" reset_pki() # Clear PKI tables in database with app.app_context(): from app.database import get_db db = get_db() db.execute("DELETE FROM issued_certificates") db.execute("DELETE FROM certificate_authority") db.commit() yield reset_pki() class TestPKIStatus: """Test GET /pki endpoint.""" def test_pki_status_when_enabled(self, client): """PKI status shows enabled with no CA initially.""" response = client.get("/pki") assert response.status_code == 200 data = response.get_json() assert data["enabled"] is True assert data["ca_exists"] is False assert "hint" in data def test_pki_status_after_ca_generation(self, client): """PKI status shows CA info after generation.""" # Generate CA first client.post("/pki/ca", json={"common_name": "Test CA"}) response = client.get("/pki") assert response.status_code == 200 data = response.get_json() assert data["enabled"] is True assert data["ca_exists"] is True assert data["common_name"] == "Test CA" assert "fingerprint_sha1" in data assert len(data["fingerprint_sha1"]) == 40 class TestCAGeneration: """Test POST /pki/ca endpoint.""" def test_generate_ca_success(self, client): """CA can be generated with default name.""" response = client.post("/pki/ca") assert response.status_code == 201 data = response.get_json() assert data["message"] == "CA generated" assert data["common_name"] == "FlaskPaste CA" assert "fingerprint_sha1" in data assert "created_at" in data assert "expires_at" in data assert data["download"] == "/pki/ca.crt" def test_generate_ca_custom_name(self, client): """CA can be generated with custom name.""" response = client.post("/pki/ca", json={"common_name": "My Custom CA"}) assert response.status_code == 201 data = response.get_json() assert data["common_name"] == "My Custom CA" def test_generate_ca_twice_fails(self, client): """CA cannot be generated twice.""" # First generation succeeds response = client.post("/pki/ca") assert response.status_code == 201 # Second generation fails response = client.post("/pki/ca") assert response.status_code == 409 data = response.get_json() assert "already exists" in data["error"] class TestCADownload: """Test GET /pki/ca.crt endpoint.""" def test_download_ca_not_initialized(self, client): """Download fails when no CA exists.""" response = client.get("/pki/ca.crt") assert response.status_code == 404 def test_download_ca_success(self, client): """CA certificate can be downloaded.""" # Generate CA first client.post("/pki/ca", json={"common_name": "Test CA"}) response = client.get("/pki/ca.crt") assert response.status_code == 200 assert response.content_type == "application/x-pem-file" assert b"-----BEGIN CERTIFICATE-----" in response.data assert b"-----END CERTIFICATE-----" in response.data class TestCertificateIssuance: """Test POST /pki/issue endpoint.""" def test_issue_without_ca_fails(self, client): """Issuance fails when no CA exists.""" response = client.post("/pki/issue", json={"common_name": "alice"}) assert response.status_code == 404 def test_issue_without_name_fails(self, client): """Issuance fails without common_name.""" client.post("/pki/ca") response = client.post("/pki/issue", json={}) assert response.status_code == 400 assert "common_name required" in response.get_json()["error"] def test_issue_certificate_success(self, client): """Certificate issuance succeeds.""" client.post("/pki/ca") response = client.post("/pki/issue", json={"common_name": "alice"}) assert response.status_code == 201 data = response.get_json() assert data["message"] == "Certificate issued" assert data["common_name"] == "alice" assert "serial" in data assert "fingerprint_sha1" in data assert len(data["fingerprint_sha1"]) == 40 assert "certificate_pem" in data assert "private_key_pem" in data assert "-----BEGIN CERTIFICATE-----" in data["certificate_pem"] assert "-----BEGIN PRIVATE KEY-----" in data["private_key_pem"] def test_issue_multiple_certificates(self, client): """Multiple certificates can be issued.""" client.post("/pki/ca") response1 = client.post("/pki/issue", json={"common_name": "alice"}) response2 = client.post("/pki/issue", json={"common_name": "bob"}) assert response1.status_code == 201 assert response2.status_code == 201 data1 = response1.get_json() data2 = response2.get_json() # Different serials and fingerprints assert data1["serial"] != data2["serial"] assert data1["fingerprint_sha1"] != data2["fingerprint_sha1"] class TestCertificateListing: """Test GET /pki/certs endpoint.""" def test_list_anonymous_empty(self, client): """Anonymous users see empty list.""" client.post("/pki/ca") response = client.get("/pki/certs") assert response.status_code == 200 data = response.get_json() assert data["certificates"] == [] assert data["count"] == 0 def test_list_authenticated_sees_own(self, client): """Authenticated users see certificates they issued.""" client.post("/pki/ca") # Issue certificate as authenticated user issuer_fingerprint = "a" * 40 client.post( "/pki/issue", json={"common_name": "alice"}, headers={"X-SSL-Client-SHA1": issuer_fingerprint}, ) # List as same user response = client.get("/pki/certs", headers={"X-SSL-Client-SHA1": issuer_fingerprint}) assert response.status_code == 200 data = response.get_json() assert data["count"] == 1 assert data["certificates"][0]["common_name"] == "alice" class TestCertificateRevocation: """Test POST /pki/revoke/ endpoint.""" def test_revoke_unauthenticated_fails(self, client): """Revocation requires authentication.""" client.post("/pki/ca") issue_resp = client.post("/pki/issue", json={"common_name": "alice"}) serial = issue_resp.get_json()["serial"] response = client.post(f"/pki/revoke/{serial}") assert response.status_code == 401 def test_revoke_unauthorized_fails(self, client): """Revocation requires ownership.""" client.post("/pki/ca") # Issue as one user issue_resp = client.post( "/pki/issue", json={"common_name": "alice"}, headers={"X-SSL-Client-SHA1": "a" * 40} ) serial = issue_resp.get_json()["serial"] # Try to revoke as different user response = client.post(f"/pki/revoke/{serial}", headers={"X-SSL-Client-SHA1": "b" * 40}) assert response.status_code == 403 def test_revoke_as_issuer_succeeds(self, client): """Issuer can revoke certificate.""" client.post("/pki/ca") issuer = "a" * 40 issue_resp = client.post( "/pki/issue", json={"common_name": "alice"}, headers={"X-SSL-Client-SHA1": issuer} ) serial = issue_resp.get_json()["serial"] response = client.post(f"/pki/revoke/{serial}", headers={"X-SSL-Client-SHA1": issuer}) assert response.status_code == 200 assert response.get_json()["message"] == "Certificate revoked" def test_revoke_nonexistent_fails(self, client): """Revoking nonexistent certificate fails.""" client.post("/pki/ca") response = client.post("/pki/revoke/0" * 32, headers={"X-SSL-Client-SHA1": "a" * 40}) assert response.status_code == 404 def test_revoke_twice_fails(self, client): """Certificate cannot be revoked twice.""" client.post("/pki/ca") issuer = "a" * 40 issue_resp = client.post( "/pki/issue", json={"common_name": "alice"}, headers={"X-SSL-Client-SHA1": issuer} ) serial = issue_resp.get_json()["serial"] # First revocation succeeds response = client.post(f"/pki/revoke/{serial}", headers={"X-SSL-Client-SHA1": issuer}) assert response.status_code == 200 # Second revocation fails response = client.post(f"/pki/revoke/{serial}", headers={"X-SSL-Client-SHA1": issuer}) assert response.status_code == 409 class TestRevocationIntegration: """Test revocation affects authentication.""" def test_revoked_cert_treated_as_anonymous(self, client): """Revoked certificate is treated as anonymous.""" client.post("/pki/ca") # Issue certificate issuer = "a" * 40 issue_resp = client.post( "/pki/issue", json={"common_name": "alice"}, headers={"X-SSL-Client-SHA1": issuer} ) cert_fingerprint = issue_resp.get_json()["fingerprint_sha1"] serial = issue_resp.get_json()["serial"] # Create paste as authenticated user create_resp = client.post( "/", data=b"test content", headers={"X-SSL-Client-SHA1": cert_fingerprint} ) assert create_resp.status_code == 201 paste_id = create_resp.get_json()["id"] assert "owner" in create_resp.get_json() # Revoke the certificate client.post(f"/pki/revoke/{serial}", headers={"X-SSL-Client-SHA1": issuer}) # Try to delete paste with revoked cert - should fail delete_resp = client.delete(f"/{paste_id}", headers={"X-SSL-Client-SHA1": cert_fingerprint}) assert delete_resp.status_code == 401 class TestPKICryptoFunctions: """Test standalone PKI cryptographic functions.""" def test_derive_key_consistency(self): """Key derivation produces consistent results.""" from app.pki import derive_key password = "test-password" salt = b"x" * 32 key1 = derive_key(password, salt) key2 = derive_key(password, salt) assert key1 == key2 assert len(key1) == 32 def test_encrypt_decrypt_roundtrip(self): """Private key encryption/decryption roundtrip.""" from cryptography.hazmat.primitives.asymmetric import ec from app.pki import decrypt_private_key, encrypt_private_key # Generate a test key private_key = ec.generate_private_key(ec.SECP384R1()) password = "test-password" # Encrypt encrypted, salt = encrypt_private_key(private_key, password) # Decrypt decrypted = decrypt_private_key(encrypted, salt, password) # Verify same key assert private_key.private_numbers() == decrypted.private_numbers() def test_wrong_password_fails(self): """Decryption with wrong password fails.""" from cryptography.hazmat.primitives.asymmetric import ec from app.pki import ( InvalidPasswordError, decrypt_private_key, encrypt_private_key, ) private_key = ec.generate_private_key(ec.SECP384R1()) encrypted, salt = encrypt_private_key(private_key, "correct") with pytest.raises(InvalidPasswordError): decrypt_private_key(encrypted, salt, "wrong") def test_fingerprint_calculation(self): """Certificate fingerprint is calculated correctly.""" from datetime import datetime, timedelta from cryptography import x509 # Minimal self-signed cert for testing from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.asymmetric import ec from cryptography.x509.oid import NameOID from app.pki import calculate_fingerprint key = ec.generate_private_key(ec.SECP256R1()) subject = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, "test")]) cert = ( x509.CertificateBuilder() .subject_name(subject) .issuer_name(subject) .public_key(key.public_key()) .serial_number(1) .not_valid_before(datetime.now(UTC)) .not_valid_after(datetime.now(UTC) + timedelta(days=1)) .sign(key, hashes.SHA256()) ) fingerprint = calculate_fingerprint(cert) assert len(fingerprint) == 40 assert all(c in "0123456789abcdef" for c in fingerprint)