Split authentication into two functions:
- get_client_fingerprint(): Identity for ownership (any cert)
- get_client_id(): Elevated privileges (trusted certs only)

Behavior:
- Anonymous: Create only, strict limits
- Untrusted cert: Create + delete/update/list own pastes, strict limits
- Trusted cert: All operations, relaxed limits (50MB, 5x rate)

Updated tests to reflect the new behavior, where revoked certs can still manage their own pastes.
598 lines
21 KiB
Python
598 lines
21 KiB
Python
"""Tests for PKI (Certificate Authority) functionality."""
|
|
|
|
from datetime import UTC, datetime, timedelta
|
|
|
|
import pytest
|
|
|
|
from app.pki import reset_pki
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def reset_pki_state(app):
    """Reset PKI state around every test and empty the PKI database tables.

    Runs automatically for each test: ``reset_pki()`` is called before the
    test, all rows are deleted from the two PKI tables, and ``reset_pki()``
    is called once more after the test has finished.
    """
    reset_pki()

    # Remove every persisted certificate and CA row.
    with app.app_context():
        from app.database import get_db

        conn = get_db()
        for statement in (
            "DELETE FROM issued_certificates",
            "DELETE FROM certificate_authority",
        ):
            conn.execute(statement)
        conn.commit()

    yield

    reset_pki()
|
|
|
|
|
|
class TestPKIStatus:
    """Tests for the GET /pki status endpoint."""

    def test_pki_status_when_enabled(self, client):
        """Before any CA is generated, status reports enabled, no CA, and a hint."""
        resp = client.get("/pki")
        assert resp.status_code == 200

        body = resp.get_json()
        assert body["enabled"] is True
        assert body["ca_exists"] is False
        assert "hint" in body

    def test_pki_status_after_ca_generation(self, client):
        """After generating a CA, status reports its name and SHA-1 fingerprint."""
        # A CA must exist before its details can show up in the status.
        client.post("/pki/ca", json={"common_name": "Test CA"})

        resp = client.get("/pki")
        assert resp.status_code == 200

        body = resp.get_json()
        assert body["enabled"] is True
        assert body["ca_exists"] is True
        assert body["common_name"] == "Test CA"
        assert "fingerprint_sha1" in body
        # The fingerprint is expected to be 40 characters long.
        assert len(body["fingerprint_sha1"]) == 40
|
|
|
|
|
|
class TestCAGeneration:
    """Tests for the POST /pki/ca endpoint."""

    def test_generate_ca_success(self, client):
        """Posting without a body creates a CA named "FlaskPaste CA"."""
        resp = client.post("/pki/ca")
        assert resp.status_code == 201

        body = resp.get_json()
        assert body["message"] == "CA generated"
        assert body["common_name"] == "FlaskPaste CA"
        for field in ("fingerprint_sha1", "created_at", "expires_at"):
            assert field in body
        assert body["download"] == "/pki/ca.crt"

    def test_generate_ca_custom_name(self, client):
        """A supplied common_name is echoed back in the response."""
        resp = client.post("/pki/ca", json={"common_name": "My Custom CA"})
        assert resp.status_code == 201
        assert resp.get_json()["common_name"] == "My Custom CA"

    def test_generate_ca_twice_fails(self, client):
        """A second generation attempt is rejected with 409."""
        first = client.post("/pki/ca")
        assert first.status_code == 201

        second = client.post("/pki/ca")
        assert second.status_code == 409
        assert "already exists" in second.get_json()["error"]
|
|
|
|
|
|
class TestCADownload:
    """Tests for the GET /pki/ca.crt endpoint."""

    def test_download_ca_not_initialized(self, client):
        """Requesting the CA certificate before one exists returns 404."""
        assert client.get("/pki/ca.crt").status_code == 404

    def test_download_ca_success(self, client):
        """Once a CA exists, its certificate is served as a PEM file."""
        client.post("/pki/ca", json={"common_name": "Test CA"})

        resp = client.get("/pki/ca.crt")
        assert resp.status_code == 200
        assert resp.content_type == "application/x-pem-file"

        # Both PEM delimiters must be present in the payload.
        for marker in (b"-----BEGIN CERTIFICATE-----", b"-----END CERTIFICATE-----"):
            assert marker in resp.data
|
|
|
|
|
|
class TestCertificateIssuance:
    """Tests for the POST /pki/issue endpoint."""

    def test_issue_without_ca_fails(self, client):
        """Issuing before a CA has been generated returns 404."""
        resp = client.post("/pki/issue", json={"common_name": "alice"})
        assert resp.status_code == 404

    def test_issue_without_name_fails(self, client):
        """An empty JSON body is rejected with 400 and a common_name error."""
        client.post("/pki/ca")

        resp = client.post("/pki/issue", json={})
        assert resp.status_code == 400
        assert "common_name required" in resp.get_json()["error"]

    def test_issue_certificate_success(self, client):
        """A successful issuance returns metadata plus PEM certificate and key."""
        client.post("/pki/ca")

        resp = client.post("/pki/issue", json={"common_name": "alice"})
        assert resp.status_code == 201

        issued = resp.get_json()
        assert issued["message"] == "Certificate issued"
        assert issued["common_name"] == "alice"
        assert "serial" in issued
        assert "fingerprint_sha1" in issued
        assert len(issued["fingerprint_sha1"]) == 40

        # Both PEM documents must be present and carry the expected header.
        assert "certificate_pem" in issued
        assert "private_key_pem" in issued
        assert "-----BEGIN CERTIFICATE-----" in issued["certificate_pem"]
        assert "-----BEGIN PRIVATE KEY-----" in issued["private_key_pem"]

    def test_issue_multiple_certificates(self, client):
        """Two issued certificates get distinct serials and fingerprints."""
        client.post("/pki/ca")

        alice_resp = client.post("/pki/issue", json={"common_name": "alice"})
        bob_resp = client.post("/pki/issue", json={"common_name": "bob"})

        assert alice_resp.status_code == 201
        assert bob_resp.status_code == 201

        alice = alice_resp.get_json()
        bob = bob_resp.get_json()

        assert alice["serial"] != bob["serial"]
        assert alice["fingerprint_sha1"] != bob["fingerprint_sha1"]
|
|
|
|
|
|
class TestCertificateListing:
    """Tests for the GET /pki/certs endpoint."""

    def test_list_anonymous_empty(self, client):
        """A request without a client fingerprint header gets an empty list."""
        client.post("/pki/ca")

        resp = client.get("/pki/certs")
        assert resp.status_code == 200

        listing = resp.get_json()
        assert listing["certificates"] == []
        assert listing["count"] == 0

    def test_list_authenticated_sees_own(self, client):
        """A client sees the certificate it issued under the same fingerprint."""
        client.post("/pki/ca")

        # The same fingerprint header is used to issue and then to list.
        auth = {"X-SSL-Client-SHA1": "a" * 40}
        client.post("/pki/issue", json={"common_name": "alice"}, headers=auth)

        resp = client.get("/pki/certs", headers=auth)
        assert resp.status_code == 200

        listing = resp.get_json()
        assert listing["count"] == 1
        assert listing["certificates"][0]["common_name"] == "alice"
|
|
|
|
|
|
class TestCertificateRevocation:
    """Test POST /pki/revoke/<serial> endpoint."""

    def test_revoke_unauthenticated_fails(self, client):
        """Revocation without a client fingerprint header returns 401."""
        client.post("/pki/ca")
        issue_resp = client.post("/pki/issue", json={"common_name": "alice"})
        serial = issue_resp.get_json()["serial"]

        response = client.post(f"/pki/revoke/{serial}")
        assert response.status_code == 401

    def test_revoke_unauthorized_fails(self, client):
        """Revocation by a fingerprint other than the issuer's returns 403."""
        client.post("/pki/ca")

        # Issue as one user
        issue_resp = client.post(
            "/pki/issue", json={"common_name": "alice"}, headers={"X-SSL-Client-SHA1": "a" * 40}
        )
        serial = issue_resp.get_json()["serial"]

        # Try to revoke as different user
        response = client.post(f"/pki/revoke/{serial}", headers={"X-SSL-Client-SHA1": "b" * 40})
        assert response.status_code == 403

    def test_revoke_as_issuer_succeeds(self, client):
        """The issuing fingerprint can revoke the certificate it issued."""
        client.post("/pki/ca")

        issuer = "a" * 40
        issue_resp = client.post(
            "/pki/issue", json={"common_name": "alice"}, headers={"X-SSL-Client-SHA1": issuer}
        )
        serial = issue_resp.get_json()["serial"]

        response = client.post(f"/pki/revoke/{serial}", headers={"X-SSL-Client-SHA1": issuer})
        assert response.status_code == 200
        assert response.get_json()["message"] == "Certificate revoked"

    def test_revoke_nonexistent_fails(self, client):
        """Revoking a serial that was never issued returns 404."""
        client.post("/pki/ca")

        # Build the path as prefix + 32 zeros. The previous expression
        # ``"/pki/revoke/0" * 32`` repeated the *whole path* 32 times because
        # ``*`` applied to the entire string literal, so the request never
        # reached the revoke route at all.
        unknown_serial = "0" * 32
        response = client.post(
            f"/pki/revoke/{unknown_serial}", headers={"X-SSL-Client-SHA1": "a" * 40}
        )
        assert response.status_code == 404

    def test_revoke_twice_fails(self, client):
        """A second revocation of the same certificate returns 409."""
        client.post("/pki/ca")

        issuer = "a" * 40
        issue_resp = client.post(
            "/pki/issue", json={"common_name": "alice"}, headers={"X-SSL-Client-SHA1": issuer}
        )
        serial = issue_resp.get_json()["serial"]

        # First revocation succeeds
        response = client.post(f"/pki/revoke/{serial}", headers={"X-SSL-Client-SHA1": issuer})
        assert response.status_code == 200

        # Second revocation fails
        response = client.post(f"/pki/revoke/{serial}", headers={"X-SSL-Client-SHA1": issuer})
        assert response.status_code == 409
|
|
|
|
|
|
class TestRevocationIntegration:
    """Test how certificate revocation interacts with paste ownership."""

    def test_revoked_cert_treated_as_anonymous(self, client):
        """A revoked certificate can still delete a paste it created.

        The test name is historical. What is asserted here is that ownership
        follows the certificate fingerprint, so deleting one's own paste keeps
        working after the certificate has been revoked.
        """
        client.post("/pki/ca")

        # Issue certificate
        issuer = "a" * 40
        issue_resp = client.post(
            "/pki/issue", json={"common_name": "alice"}, headers={"X-SSL-Client-SHA1": issuer}
        )
        issued = issue_resp.get_json()
        cert_fingerprint = issued["fingerprint_sha1"]
        serial = issued["serial"]

        # Create paste as authenticated user
        create_resp = client.post(
            "/", data=b"test content", headers={"X-SSL-Client-SHA1": cert_fingerprint}
        )
        assert create_resp.status_code == 201
        paste_id = create_resp.get_json()["id"]
        assert "owner" in create_resp.get_json()

        # Revoke the certificate and confirm it actually happened; without
        # this check the delete below would pass even if revocation failed,
        # and the test would say nothing about revoked certificates.
        revoke_resp = client.post(f"/pki/revoke/{serial}", headers={"X-SSL-Client-SHA1": issuer})
        assert revoke_resp.status_code == 200

        # Revoked cert can still delete their own paste (ownership by fingerprint)
        # They just lose elevated rate/size limits
        delete_resp = client.delete(f"/{paste_id}", headers={"X-SSL-Client-SHA1": cert_fingerprint})
        assert delete_resp.status_code == 200
|
|
|
|
|
|
class TestPKICryptoFunctions:
    """Tests that call the app.pki cryptographic helpers directly."""

    def test_derive_key_consistency(self):
        """derive_key returns the same 32-item key for identical inputs."""
        from app.pki import derive_key

        secret = "test-password"
        fixed_salt = b"x" * 32

        first = derive_key(secret, fixed_salt)
        second = derive_key(secret, fixed_salt)

        assert first == second
        assert len(first) == 32

    def test_encrypt_decrypt_roundtrip(self):
        """Encrypting then decrypting a private key yields the same key."""
        from cryptography.hazmat.primitives.asymmetric import ec

        from app.pki import decrypt_private_key, encrypt_private_key

        original = ec.generate_private_key(ec.SECP384R1())
        secret = "test-password"

        blob, salt = encrypt_private_key(original, secret)
        restored = decrypt_private_key(blob, salt, secret)

        # Comparing private numbers checks key equality without serializing.
        assert original.private_numbers() == restored.private_numbers()

    def test_wrong_password_fails(self):
        """Decrypting with a different password raises InvalidPasswordError."""
        from cryptography.hazmat.primitives.asymmetric import ec

        from app.pki import (
            InvalidPasswordError,
            decrypt_private_key,
            encrypt_private_key,
        )

        blob, salt = encrypt_private_key(ec.generate_private_key(ec.SECP384R1()), "correct")

        with pytest.raises(InvalidPasswordError):
            decrypt_private_key(blob, salt, "wrong")

    def test_fingerprint_calculation(self):
        """calculate_fingerprint returns 40 lowercase hexadecimal characters."""
        from cryptography import x509
        from cryptography.hazmat.primitives import hashes
        from cryptography.hazmat.primitives.asymmetric import ec
        from cryptography.x509.oid import NameOID

        from app.pki import calculate_fingerprint

        # Build a throwaway self-signed certificate to fingerprint.
        signing_key = ec.generate_private_key(ec.SECP256R1())
        name = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, "test")])

        builder = x509.CertificateBuilder()
        builder = builder.subject_name(name)
        builder = builder.issuer_name(name)
        builder = builder.public_key(signing_key.public_key())
        builder = builder.serial_number(1)
        builder = builder.not_valid_before(datetime.now(UTC))
        builder = builder.not_valid_after(datetime.now(UTC) + timedelta(days=1))
        cert = builder.sign(signing_key, hashes.SHA256())

        digest = calculate_fingerprint(cert)

        assert len(digest) == 40
        assert all(ch in "0123456789abcdef" for ch in digest)
|
|
|
|
|
|
class TestRegistration:
    """Tests for public certificate registration (/register and /register/challenge)."""

    def test_register_challenge_returns_token(self, client, app):
        """With difficulty 10, the challenge returns a nonce, token and difficulty."""
        with app.app_context():
            app.config["REGISTER_POW_DIFFICULTY"] = 10

        resp = client.get("/register/challenge")
        assert resp.status_code == 200

        challenge = resp.get_json()
        assert challenge["enabled"] is True
        assert "nonce" in challenge
        assert "token" in challenge
        assert challenge["difficulty"] == 10
        assert challenge["purpose"] == "registration"

    def test_register_challenge_disabled(self, client, app):
        """With difficulty 0, the challenge endpoint reports enabled = False."""
        with app.app_context():
            app.config["REGISTER_POW_DIFFICULTY"] = 0

        resp = client.get("/register/challenge")
        assert resp.status_code == 200
        assert resp.get_json()["enabled"] is False

    def test_register_requires_pow(self, client, app):
        """With difficulty 10, registering without proof-of-work returns 400."""
        with app.app_context():
            app.config["REGISTER_POW_DIFFICULTY"] = 10

        resp = client.post("/register", json={"common_name": "test"})
        assert resp.status_code == 400
        assert "Proof-of-work required" in resp.get_json()["error"]

    def test_register_with_pow_disabled_succeeds(self, client, app):
        """With difficulty 0, registration returns a PKCS#12 body and fingerprint header."""
        with app.app_context():
            app.config["REGISTER_POW_DIFFICULTY"] = 0

        resp = client.post("/register", json={"common_name": "test-client"})
        assert resp.status_code == 200
        assert resp.content_type == "application/x-pkcs12"
        assert "X-Fingerprint-SHA1" in resp.headers
        assert len(resp.headers["X-Fingerprint-SHA1"]) == 40

    def test_register_auto_generates_ca(self, client, app):
        """Registering when no CA exists leaves a "FlaskPaste CA" behind."""
        from app.pki import get_ca_info

        with app.app_context():
            app.config["REGISTER_POW_DIFFICULTY"] = 0

            # Precondition: there is no CA yet.
            assert get_ca_info() is None

        resp = client.post("/register", json={"common_name": "first-client"})
        assert resp.status_code == 200

        # Postcondition: a CA now exists with the expected common name.
        with app.app_context():
            ca = get_ca_info()
            assert ca is not None
            assert ca["common_name"] == "FlaskPaste CA"

    def test_register_returns_pkcs12(self, client, app):
        """The response body loads as PKCS#12 with a key, a cert and one extra cert."""
        from cryptography.hazmat.primitives.serialization import pkcs12

        with app.app_context():
            app.config["REGISTER_POW_DIFFICULTY"] = 0

        resp = client.post("/register", json={"common_name": "my-client"})
        assert resp.status_code == 200

        # The bundle must load without a password.
        key, cert, extra_certs = pkcs12.load_key_and_certificates(resp.data, password=None)

        assert key is not None
        assert cert is not None
        # Should include CA certificate
        assert extra_certs is not None
        assert len(extra_certs) == 1

    def test_register_generates_common_name(self, client, app):
        """Without a common_name, the download filename contains "client-" and ".p12"."""
        with app.app_context():
            app.config["REGISTER_POW_DIFFICULTY"] = 0

        resp = client.post("/register")
        assert resp.status_code == 200

        # The generated name is only observable via Content-Disposition.
        content_disposition = resp.headers["Content-Disposition"]
        assert "client-" in content_disposition
        assert ".p12" in content_disposition

    def test_register_respects_custom_common_name(self, client, app):
        """A supplied common_name is used for the download filename."""
        with app.app_context():
            app.config["REGISTER_POW_DIFFICULTY"] = 0

        resp = client.post("/register", json={"common_name": "custom-name"})
        assert resp.status_code == 200
        assert "custom-name.p12" in resp.headers["Content-Disposition"]

    def test_register_without_pki_password_fails(self, client, app):
        """With an empty PKI_CA_PASSWORD, registration returns 503."""
        with app.app_context():
            app.config["PKI_CA_PASSWORD"] = ""
            app.config["REGISTER_POW_DIFFICULTY"] = 0

        resp = client.post("/register", json={"common_name": "test"})
        assert resp.status_code == 503
        assert "not available" in resp.get_json()["error"]
|
|
|
|
|
|
class TestPKCS12Creation:
    """Tests for the app.pki.create_pkcs12 bundle builder."""

    def test_create_pkcs12_without_password(self):
        """A bundle built with password=None loads back with key, cert and one CA."""
        from cryptography import x509
        from cryptography.hazmat.primitives import hashes
        from cryptography.hazmat.primitives.asymmetric import ec
        from cryptography.hazmat.primitives.serialization import pkcs12
        from cryptography.x509.oid import NameOID

        from app.pki import create_pkcs12

        ca_key = ec.generate_private_key(ec.SECP384R1())
        client_key = ec.generate_private_key(ec.SECP384R1())

        issued_at = datetime.now(UTC)
        ca_name = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, "Test CA")])
        client_name = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, "Client")])

        # Self-signed CA certificate, valid for 365 days.
        ca_cert = (
            x509.CertificateBuilder()
            .subject_name(ca_name)
            .issuer_name(ca_name)
            .public_key(ca_key.public_key())
            .serial_number(1)
            .not_valid_before(issued_at)
            .not_valid_after(issued_at + timedelta(days=365))
            .sign(ca_key, hashes.SHA256())
        )

        # Client certificate signed by the CA key, valid for 30 days.
        client_cert = (
            x509.CertificateBuilder()
            .subject_name(client_name)
            .issuer_name(ca_name)
            .public_key(client_key.public_key())
            .serial_number(2)
            .not_valid_before(issued_at)
            .not_valid_after(issued_at + timedelta(days=30))
            .sign(ca_key, hashes.SHA256())
        )

        bundle = create_pkcs12(
            private_key=client_key,
            certificate=client_cert,
            ca_certificate=ca_cert,
            friendly_name="test-client",
            password=None,
        )

        key, cert, cas = pkcs12.load_key_and_certificates(bundle, password=None)

        assert key is not None
        assert cert is not None
        assert cert.subject.get_attributes_for_oid(NameOID.COMMON_NAME)[0].value == "Client"
        assert cas is not None
        assert len(cas) == 1

    def test_create_pkcs12_with_password(self):
        """A password-protected bundle fails to load without it and loads with it."""
        from cryptography import x509
        from cryptography.hazmat.primitives import hashes
        from cryptography.hazmat.primitives.asymmetric import ec
        from cryptography.hazmat.primitives.serialization import pkcs12
        from cryptography.x509.oid import NameOID

        from app.pki import create_pkcs12

        signing_key = ec.generate_private_key(ec.SECP384R1())
        issued_at = datetime.now(UTC)
        name = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, "Test")])

        # One self-signed certificate serves as both the leaf and the CA entry.
        self_signed = (
            x509.CertificateBuilder()
            .subject_name(name)
            .issuer_name(name)
            .public_key(signing_key.public_key())
            .serial_number(1)
            .not_valid_before(issued_at)
            .not_valid_after(issued_at + timedelta(days=1))
            .sign(signing_key, hashes.SHA256())
        )

        bundle = create_pkcs12(
            private_key=signing_key,
            certificate=self_signed,
            ca_certificate=self_signed,
            friendly_name="test",
            password=b"secret123",
        )

        # Loading without the password must be rejected.
        with pytest.raises(ValueError):
            pkcs12.load_key_and_certificates(bundle, password=None)

        # Loading with the matching password must succeed.
        key, _, _ = pkcs12.load_key_and_certificates(bundle, password=b"secret123")
        assert key is not None
|