Some checks failed
CI / Lint & Format (push) Failing after 15s
CI / Unit Tests (push) Has been skipped
CI / Memory Leak Check (push) Has been skipped
CI / SBOM Generation (push) Has been skipped
CI / Security Scan (push) Successful in 19s
CI / Security Tests (push) Has been skipped
- tests/security/pentest_session.py: comprehensive 10-phase pentest
- tests/security/profiled_server.py: cProfile-enabled server
- tests/security/cli_security_audit.py: CLI security checks
- tests/security/dos_memory_test.py: memory exhaustion tests
- tests/security/race_condition_test.py: concurrency tests
- docs: add pentest results, profiling analysis, new test commands
224 lines
7.1 KiB
Python
224 lines
7.1 KiB
Python
#!/usr/bin/env python3
|
|
"""Race condition tests for FlaskPaste."""
|
|
|
|
import hashlib
|
|
import sys
|
|
import threading
|
|
import time
|
|
|
|
sys.path.insert(0, ".")
|
|
|
|
from app import create_app
|
|
from app.database import check_content_hash, get_db
|
|
|
|
|
|
def test_dedup_counter_race():
    """Test content hash dedup counter for race conditions.

    Fires a burst of concurrent ``check_content_hash`` calls for one content
    hash and verifies three things:

    * no call raised;
    * exactly ``min(num_threads, CONTENT_DEDUP_MAX)`` calls were allowed;
    * the stored ``content_hashes.count`` equals the number of allowed calls.

    Returns:
        True if the counter behaved correctly under concurrency, else False.
    """
    print("\n[1] Content Hash Dedup Counter Race Condition")
    print("=" * 50)

    # Single source of truth for the burst size (used for threads and output).
    num_threads = 20

    app = create_app("testing")
    test_content = b"race condition test content"
    content_hash = hashlib.sha256(test_content).hexdigest()

    with app.app_context():
        # Clear any existing entry so the counter starts from scratch.
        db = get_db()
        db.execute("DELETE FROM content_hashes WHERE hash = ?", (content_hash,))
        db.commit()

        results = []
        errors = []

        # Hold every worker at a gate and release them together. Without it,
        # threads started one by one in a loop can each finish before the next
        # begins, so the calls would never actually overlap. The timeout turns
        # a stuck gate into BrokenBarrierError (recorded below) instead of a hang.
        start_gate = threading.Barrier(num_threads, timeout=10)

        def make_request():
            try:
                start_gate.wait()
                # Each worker pushes its own app context for its DB access.
                with app.app_context():
                    allowed, count = check_content_hash(content_hash)
                results.append((allowed, count))
            except Exception as e:
                # Collected rather than raised so the main thread can report it.
                errors.append(str(e))

        # Make num_threads concurrent requests with the same content.
        threads = [threading.Thread(target=make_request) for _ in range(num_threads)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

        # Check results
        print(f" Threads: {num_threads}")
        print(f" Errors: {len(errors)}")
        print(f" Results: {len(results)}")

        if errors:
            print(" FAIL: Errors during concurrent access")
            for e in errors[:3]:
                print(f" {e}")
            return False

        # Get final count from database
        db = get_db()
        row = db.execute(
            "SELECT count FROM content_hashes WHERE hash = ?", (content_hash,)
        ).fetchone()
        final_count = row["count"] if row else 0
        print(f" Final DB count: {final_count}")

        # Count how many got allowed
        allowed_count = sum(1 for a, _ in results if a)
        blocked_count = len(results) - allowed_count
        print(f" Allowed: {allowed_count}, Blocked: {blocked_count}")

        max_allowed = app.config.get("CONTENT_DEDUP_MAX", 3)
        # With the counter starting from zero, the first max_allowed submissions
        # are allowed and the rest blocked. A correct implementation therefore
        # allows exactly this many, no more and no fewer.
        expected_allowed = min(num_threads, max_allowed)

        # A lost update shows up as the stored counter disagreeing with the
        # number of requests actually let through.
        if allowed_count <= max_allowed and final_count != allowed_count:
            print(" FAIL: Race condition detected!")
            print(f" Expected count: {allowed_count}, Got: {final_count}")
            return False

        if allowed_count > max_allowed:
            print(f" FAIL: Too many requests allowed ({allowed_count} > {max_allowed})")
            return False

        # Previously unchecked: fewer allowed than expected also means the
        # counter misbehaved (e.g. over-counted) under concurrency.
        if allowed_count != expected_allowed:
            print(f" FAIL: Too few requests allowed ({allowed_count} < {expected_allowed})")
            return False

        print(f" PASS: Counter correctly tracked ({final_count})")
        return True
|
|
|
|
|
|
def test_dedup_counter_sequential():
    """Exercise the content-hash dedup counter from a single caller.

    Submits one hash ``CONTENT_DEDUP_MAX + 2`` times (capped at 10) inside a
    single app context. It then checks two things:

    * the number of allowed and blocked verdicts;
    * that the reported counts climb 1, 2, ... up to the limit and then stay
      at the limit.

    Returns:
        True when the observed verdicts and counts match, else False.
    """
    print("\n[2] Content Hash Dedup Counter Sequential")
    print("=" * 50)

    app = create_app("testing")
    digest = hashlib.sha256(b"sequential test content").hexdigest()

    with app.app_context():
        # Start from a clean slate for this hash.
        conn = get_db()
        conn.execute("DELETE FROM content_hashes WHERE hash = ?", (digest,))
        conn.commit()

        limit = app.config.get("CONTENT_DEDUP_MAX", 3)
        # Go two past the limit so blocking is exercised.
        # NOTE(review): the cap of 10 means no request is ever blocked when the
        # configured limit is 10 or more — confirm that is acceptable.
        attempts = min(limit + 2, 10)

        observed = []
        for n in range(1, attempts + 1):
            verdict, tally = check_content_hash(digest)
            observed.append((verdict, tally))
            print(f" Request {n}: allowed={verdict}, count={tally}")

        n_allowed = len([v for v, _ in observed if v])
        n_blocked = len(observed) - n_allowed

        print(f" Max allowed (config): {limit}")
        print(f" Allowed: {n_allowed}, Blocked: {n_blocked}")

        want_allowed = min(attempts, limit)
        want_blocked = attempts - want_allowed

        if n_allowed != want_allowed:
            print(f" FAIL: Expected {want_allowed} allowed, got {n_allowed}")
            return False

        if n_blocked != want_blocked:
            print(f" FAIL: Expected {want_blocked} blocked, got {n_blocked}")
            return False

        # Counts rise 1..limit, then every further request reports the limit.
        want_counts = [*range(1, want_allowed + 1)]
        overflow = attempts - limit
        if overflow > 0:
            want_counts.extend([limit] * overflow)

        got_counts = [tally for _, tally in observed]

        if got_counts != want_counts:
            print(" FAIL: Counter sequence mismatch")
            print(f" Expected: {want_counts}")
            print(f" Got: {got_counts}")
            return False

        print(" PASS: Counter correctly incremented")
        return True
|
|
|
|
|
|
def test_dedup_window_expiry():
    """Check that a dedup entry older than the window stops counting.

    Seeds ``content_hashes`` with a row for the test hash whose timestamps lie
    10 seconds beyond ``CONTENT_DEDUP_WINDOW`` and whose count is 5. The next
    ``check_content_hash`` call is expected to be allowed with a count of 1.

    Returns:
        True if the stale counter was reset, else False.
    """
    print("\n[3] Content Hash Dedup Window Expiry")
    print("=" * 50)

    app = create_app("testing")
    digest = hashlib.sha256(b"window expiry test").hexdigest()

    with app.app_context():
        # Drop any leftover row for this hash before seeding a stale one.
        conn = get_db()
        conn.execute("DELETE FROM content_hashes WHERE hash = ?", (digest,))
        conn.commit()

        window = app.config.get("CONTENT_DEDUP_WINDOW", 3600)
        # Place both timestamps ten seconds past the window edge.
        stale_ts = int(time.time()) - window - 10

        conn.execute(
            "INSERT INTO content_hashes (hash, first_seen, last_seen, count) VALUES (?, ?, ?, 5)",
            (digest, stale_ts, stale_ts),
        )
        conn.commit()

        # The stale entry should be treated as fresh, starting the count over.
        allowed, count = check_content_hash(digest)

        print(f" Window: {window}s")
        print(" Old entry count: 5")
        print(f" After check: allowed={allowed}, count={count}")

        if allowed and count == 1:
            print(" PASS: Counter reset after window expiry")
            return True

        print(" FAIL: Window expiry not resetting counter")
        return False
|
|
|
|
|
|
def main():
    """Run each race-condition check in order, print a summary, and return an exit code.

    Returns:
        0 when every check passed, 1 otherwise.
    """
    banner = "=" * 60
    print(banner)
    print("RACE CONDITION TESTS")
    print(banner)

    # (label, check) pairs, executed in this order.
    checks = (
        ("Dedup Counter Race", test_dedup_counter_race),
        ("Dedup Counter Sequential", test_dedup_counter_sequential),
        ("Dedup Window Expiry", test_dedup_window_expiry),
    )
    outcomes = [(label, run()) for label, run in checks]

    print("\n" + banner)
    print("SUMMARY")
    print(banner)

    n_passed = sum(bool(ok) for _, ok in outcomes)
    n_total = len(outcomes)

    for label, ok in outcomes:
        print(f" {'PASS' if ok else 'FAIL'}: {label}")

    print(f"\n{n_passed}/{n_total} checks passed")

    return 1 if n_passed != n_total else 0
|
|
|
|
|
|
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    raise SystemExit(main())
|