Files
flaskpaste/tests/security/race_condition_test.py
Username bd75f81afd add security testing suite and update docs
- tests/security/pentest_session.py: comprehensive 10-phase pentest
- tests/security/profiled_server.py: cProfile-enabled server
- tests/security/cli_security_audit.py: CLI security checks
- tests/security/dos_memory_test.py: memory exhaustion tests
- tests/security/race_condition_test.py: concurrency tests
- docs: add pentest results, profiling analysis, new test commands
2025-12-26 00:39:33 +01:00

224 lines
7.1 KiB
Python

#!/usr/bin/env python3
"""Race condition tests for FlaskPaste."""
import hashlib
import sys
import threading
import time
sys.path.insert(0, ".")
from app import create_app
from app.database import check_content_hash, get_db
def test_dedup_counter_race():
    """Test content hash dedup counter for race conditions.

    Every worker thread pushes its own application context and calls
    ``check_content_hash`` for the same SHA-256 digest.  The workers wait on a
    shared barrier first so their calls genuinely overlap, instead of running
    back-to-back in thread start order.

    The check passes only when:
      * no worker raised,
      * every worker produced a result,
      * exactly ``min(num_threads, CONTENT_DEDUP_MAX)`` calls were allowed
        (the same limit the sequential test verifies), and
      * the counter stored in ``content_hashes`` equals the number of allowed
        calls (a lost read-modify-write update would make them disagree).

    Returns:
        True if the counter behaved atomically under contention, else False.
    """
    print("\n[1] Content Hash Dedup Counter Race Condition")
    print("=" * 50)
    num_threads = 20
    # Upper bound on how long a worker waits for the others; keeps one dead
    # worker from hanging the rest forever (they get BrokenBarrierError).
    barrier_timeout = 10
    app = create_app("testing")
    test_content = b"race condition test content"
    content_hash = hashlib.sha256(test_content).hexdigest()
    with app.app_context():
        # Clear any existing entry so counting starts from zero
        db = get_db()
        db.execute("DELETE FROM content_hashes WHERE hash = ?", (content_hash,))
        db.commit()

        results = []
        errors = []
        start_gate = threading.Barrier(num_threads)

        def make_request():
            # The try wraps the context manager too, so failures while pushing
            # or tearing down the per-thread app context are recorded instead
            # of silently killing the thread.
            try:
                with app.app_context():
                    start_gate.wait(timeout=barrier_timeout)
                    allowed, count = check_content_hash(content_hash)
                    results.append((allowed, count))
            except Exception as e:
                # Keep the type: some exceptions (e.g. BrokenBarrierError)
                # have an empty message.
                errors.append(f"{type(e).__name__}: {e}")

        # Make concurrent requests with the same content
        threads = [threading.Thread(target=make_request) for _ in range(num_threads)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

        # Check results
        print(f" Threads: {num_threads}")
        print(f" Errors: {len(errors)}")
        print(f" Results: {len(results)}")
        if errors:
            print(" FAIL: Errors during concurrent access")
            for e in errors[:3]:
                print(f" {e}")
            return False
        if len(results) != num_threads:
            print(f" FAIL: Expected {num_threads} results, got {len(results)}")
            return False

        # Get final count from database
        db = get_db()
        row = db.execute(
            "SELECT count FROM content_hashes WHERE hash = ?", (content_hash,)
        ).fetchone()
        final_count = row["count"] if row else 0
        print(f" Final DB count: {final_count}")

        allowed_count = sum(1 for a, _ in results if a)
        blocked_count = sum(1 for a, _ in results if not a)
        print(f" Allowed: {allowed_count}, Blocked: {blocked_count}")

        # With DEDUP_MAX=3 (testing config) and more threads than that, exactly
        # 3 requests must be allowed and the rest blocked.
        max_allowed = app.config.get("CONTENT_DEDUP_MAX", 3)
        expected_allowed = min(num_threads, max_allowed)
        if allowed_count > max_allowed:
            print(f" FAIL: Too many requests allowed ({allowed_count} > {max_allowed})")
            return False
        if allowed_count != expected_allowed:
            print(f" FAIL: Expected {expected_allowed} allowed, got {allowed_count}")
            return False
        # The stored counter must track exactly the allowed requests; a lost
        # update under contention would leave it out of step.
        if final_count != allowed_count:
            print(" FAIL: Race condition detected!")
            print(f" Expected count: {allowed_count}, Got: {final_count}")
            return False
        print(f" PASS: Counter correctly tracked ({final_count})")
        return True
def test_dedup_counter_sequential():
    """Test content hash dedup counter in sequential access.

    Calls ``check_content_hash`` for one digest ``CONTENT_DEDUP_MAX + 2``
    times from a single thread.  The first ``CONTENT_DEDUP_MAX`` calls must be
    allowed with counts ``1, 2, ...``.  Every later call must be blocked and
    report the count pinned at the maximum.

    Returns:
        True if the allowed flags and counter sequence match, else False.
    """
    print("\n[2] Content Hash Dedup Counter Sequential")
    print("=" * 50)
    app = create_app("testing")
    test_content = b"sequential test content"
    content_hash = hashlib.sha256(test_content).hexdigest()
    with app.app_context():
        # Clear any existing entry
        db = get_db()
        db.execute("DELETE FROM content_hashes WHERE hash = ?", (content_hash,))
        db.commit()

        max_allowed = app.config.get("CONTENT_DEDUP_MAX", 3)
        # Always go two past the limit so the blocking phase is exercised.
        # A fixed cap (e.g. 10) would skip blocking entirely whenever
        # CONTENT_DEDUP_MAX >= 10.
        test_count = max_allowed + 2
        results = []
        for i in range(test_count):
            allowed, count = check_content_hash(content_hash)
            results.append((allowed, count))
            print(f" Request {i + 1}: allowed={allowed}, count={count}")

        # Count allowed/blocked
        allowed_count = sum(1 for a, _ in results if a)
        blocked_count = sum(1 for a, _ in results if not a)
        print(f" Max allowed (config): {max_allowed}")
        print(f" Allowed: {allowed_count}, Blocked: {blocked_count}")

        # Counter should increment correctly up to max
        expected_allowed = min(test_count, max_allowed)
        expected_blocked = test_count - expected_allowed
        if allowed_count != expected_allowed:
            print(f" FAIL: Expected {expected_allowed} allowed, got {allowed_count}")
            return False
        if blocked_count != expected_blocked:
            print(f" FAIL: Expected {expected_blocked} blocked, got {blocked_count}")
            return False

        # Totals alone don't prove ordering: the allowed requests must come
        # first, and once the limit is hit every later request stays blocked.
        expected_flags = [True] * expected_allowed + [False] * expected_blocked
        actual_flags = [bool(a) for a, _ in results]
        if actual_flags != expected_flags:
            print(" FAIL: Allowed/blocked order mismatch")
            print(f" Expected: {expected_flags}")
            print(f" Got: {actual_flags}")
            return False

        # Verify counter values are sequential up to max, then pinned at max
        expected_counts = list(range(1, expected_allowed + 1))
        expected_counts += [max_allowed] * expected_blocked
        actual_counts = [c for _, c in results]
        if actual_counts != expected_counts:
            print(" FAIL: Counter sequence mismatch")
            print(f" Expected: {expected_counts}")
            print(f" Got: {actual_counts}")
            return False
        print(" PASS: Counter correctly incremented")
        return True
def test_dedup_window_expiry():
    """Test content hash dedup window expiry.

    Seeds ``content_hashes`` with a row whose ``first_seen``/``last_seen``
    lie 10 seconds beyond ``CONTENT_DEDUP_WINDOW``.  Its stored count is at
    least ``CONTENT_DEDUP_MAX``.  A fresh ``check_content_hash`` call must
    treat the row as expired: the request is allowed and the counter restarts
    at 1.

    Returns:
        True if the stale counter was reset, else False.
    """
    print("\n[3] Content Hash Dedup Window Expiry")
    print("=" * 50)
    app = create_app("testing")
    test_content = b"window expiry test"
    content_hash = hashlib.sha256(test_content).hexdigest()
    with app.app_context():
        # Clear any existing entry
        db = get_db()
        db.execute("DELETE FROM content_hashes WHERE hash = ?", (content_hash,))
        db.commit()

        window = app.config.get("CONTENT_DEDUP_WINDOW", 3600)
        max_allowed = app.config.get("CONTENT_DEDUP_MAX", 3)
        # Seed the stale row at or above the limit.  If expiry were broken,
        # the next check would then be *blocked*, so both `allowed` and
        # `count` are meaningful signals.  Stays 5 for the default config.
        old_count = max(5, max_allowed)
        old_time = int(time.time()) - window - 10  # 10 seconds past window
        db.execute(
            "INSERT INTO content_hashes (hash, first_seen, last_seen, count) VALUES (?, ?, ?, ?)",
            (content_hash, old_time, old_time, old_count),
        )
        db.commit()

        # Should reset counter since outside window
        allowed, count = check_content_hash(content_hash)
        print(f" Window: {window}s")
        print(f" Old entry count: {old_count}")
        print(f" After check: allowed={allowed}, count={count}")
        if not allowed or count != 1:
            print(" FAIL: Window expiry not resetting counter")
            return False
        print(" PASS: Counter reset after window expiry")
        return True
def main():
    """Run every race-condition check, then print a PASS/FAIL summary.

    Each check prints its own detailed report while it runs.  Afterwards the
    summary shows one verdict line per check and an ``N/M checks passed``
    tally.

    Returns:
        Exit status for ``sys.exit``: 0 when all checks pass, 1 otherwise.
    """
    rule = "=" * 60
    print(rule)
    print("RACE CONDITION TESTS")
    print(rule)

    checks = (
        ("Dedup Counter Race", test_dedup_counter_race),
        ("Dedup Counter Sequential", test_dedup_counter_sequential),
        ("Dedup Window Expiry", test_dedup_window_expiry),
    )
    # Checks run one after another, in the order declared above.
    outcomes = [(label, run()) for label, run in checks]

    print("\n" + rule)
    print("SUMMARY")
    print(rule)
    for label, ok in outcomes:
        verdict = "PASS" if ok else "FAIL"
        print(f" {verdict}: {label}")

    passed = len([ok for _, ok in outcomes if ok])
    total = len(outcomes)
    print(f"\n{passed}/{total} checks passed")
    return 1 if passed < total else 0


if __name__ == "__main__":
    sys.exit(main())