#!/usr/bin/env python3
"""Race condition tests for FlaskPaste.

Intended to be run directly as a script: each ``test_*`` check prints a short
report and returns True (pass) or False (fail), and ``main()`` prints a summary
and exits non-zero if any check failed.

NOTE(review): because the checks return booleans instead of asserting, a
pytest run that collects this file would treat a returned False as a pass
(pytest at most warns about non-None return values). Run it as a script.
"""

import hashlib
import sys
import threading
import time

sys.path.insert(0, ".")

from app import create_app
from app.database import check_content_hash, get_db

# Number of threads calling check_content_hash() concurrently in the race test.
RACE_THREADS = 20

# Seconds a race-test worker waits at the start barrier for its siblings.
# Bounded so that a thread dying before the barrier breaks it (and is reported
# as an error) instead of hanging the whole run.
BARRIER_TIMEOUT = 10.0


def test_dedup_counter_race():
    """Test content hash dedup counter for race conditions.

    Deletes any existing row for a fixed content hash, then releases
    RACE_THREADS threads at once, each calling ``check_content_hash`` on that
    hash inside its own app context.

    The check passes when:
      * no thread raised,
      * exactly ``min(RACE_THREADS, CONTENT_DEDUP_MAX)`` calls were allowed, and
      * the counter stored in ``content_hashes`` equals the number of allowed calls.

    Returns:
        True if the counter was tracked correctly, False otherwise.
    """
    print("\n[1] Content Hash Dedup Counter Race Condition")
    print("=" * 50)

    app = create_app("testing")
    test_content = b"race condition test content"
    content_hash = hashlib.sha256(test_content).hexdigest()

    with app.app_context():
        # Clear any existing entry
        db = get_db()
        db.execute("DELETE FROM content_hashes WHERE hash = ?", (content_hash,))
        db.commit()

        results = []
        errors = []
        # Every worker parks on this barrier and they are all released together,
        # so the check_content_hash() calls genuinely overlap rather than running
        # roughly one after another as the threads are started.
        start_barrier = threading.Barrier(RACE_THREADS)

        def make_request():
            with app.app_context():
                try:
                    start_barrier.wait(timeout=BARRIER_TIMEOUT)
                    allowed, count = check_content_hash(content_hash)
                    results.append((allowed, count))
                except Exception as e:
                    # Include the type: some exceptions (e.g. BrokenBarrierError)
                    # stringify to an empty message.
                    errors.append(f"{type(e).__name__}: {e}")

        # Make concurrent requests with the same content
        threads = [threading.Thread(target=make_request) for _ in range(RACE_THREADS)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

        # Check results
        print(f" Threads: {RACE_THREADS}")
        print(f" Errors: {len(errors)}")
        print(f" Results: {len(results)}")

        if errors:
            print(" FAIL: Errors during concurrent access")
            for e in errors[:3]:
                print(f" {e}")
            return False

        # Get final count from database
        db = get_db()
        row = db.execute(
            "SELECT count FROM content_hashes WHERE hash = ?", (content_hash,)
        ).fetchone()
        final_count = row["count"] if row else 0
        print(f" Final DB count: {final_count}")

        # Count how many got allowed
        allowed_count = sum(1 for a, _ in results if a)
        blocked_count = len(results) - allowed_count
        print(f" Allowed: {allowed_count}, Blocked: {blocked_count}")

        # With DEDUP_MAX=3 (testing config), we should see exactly 3 allowed
        # and the rest blocked
        max_allowed = app.config.get("CONTENT_DEDUP_MAX", 3)
        expected_allowed = min(RACE_THREADS, max_allowed)

        if allowed_count > max_allowed:
            print(f" FAIL: Too many requests allowed ({allowed_count} > {max_allowed})")
            return False

        # The stored counter must match the number of allowed calls (blocked
        # calls leave it at the cap, as the sequential check also expects);
        # a mismatch means lost or double-counted updates.
        if final_count != allowed_count:
            print(" FAIL: Race condition detected!")
            print(f" Expected count: {allowed_count}, Got: {final_count}")
            return False

        # Under-allowing is also a defect (e.g. the counter was bumped without
        # the request being admitted), not just over-allowing.
        if allowed_count != expected_allowed:
            print(f" FAIL: Expected {expected_allowed} allowed, got {allowed_count}")
            return False

        print(f" PASS: Counter correctly tracked ({final_count})")
        return True


def test_dedup_counter_sequential():
    """Test content hash dedup counter in sequential access.

    Calls ``check_content_hash`` repeatedly on one freshly cleared hash and
    checks that:
      * exactly ``CONTENT_DEDUP_MAX`` calls are allowed, and
      * the reported counts run 1, 2, ..., max, then stay at max for the
        blocked calls.

    Returns:
        True if every expectation holds, False otherwise.
    """
    print("\n[2] Content Hash Dedup Counter Sequential")
    print("=" * 50)

    app = create_app("testing")
    test_content = b"sequential test content"
    content_hash = hashlib.sha256(test_content).hexdigest()

    with app.app_context():
        # Clear any existing entry
        db = get_db()
        db.execute("DELETE FROM content_hashes WHERE hash = ?", (content_hash,))
        db.commit()

        max_allowed = app.config.get("CONTENT_DEDUP_MAX", 3)
        # Test with max_allowed + 2 to verify blocking.
        # NOTE(review): the cap of 10 means that if CONTENT_DEDUP_MAX >= 10 no
        # request is ever blocked, so the blocking path goes untested.
        test_count = min(max_allowed + 2, 10)

        results = []
        for i in range(test_count):
            allowed, count = check_content_hash(content_hash)
            results.append((allowed, count))
            print(f" Request {i + 1}: allowed={allowed}, count={count}")

        # Count allowed/blocked
        allowed_count = sum(1 for a, _ in results if a)
        blocked_count = sum(1 for a, _ in results if not a)

        print(f" Max allowed (config): {max_allowed}")
        print(f" Allowed: {allowed_count}, Blocked: {blocked_count}")

        # Counter should increment correctly up to max
        expected_allowed = min(test_count, max_allowed)
        expected_blocked = test_count - expected_allowed

        if allowed_count != expected_allowed:
            print(f" FAIL: Expected {expected_allowed} allowed, got {allowed_count}")
            return False

        if blocked_count != expected_blocked:
            print(f" FAIL: Expected {expected_blocked} blocked, got {blocked_count}")
            return False

        # Verify counter values are sequential up to max
        expected_counts = list(range(1, min(test_count, max_allowed) + 1))
        if test_count > max_allowed:
            expected_counts += [max_allowed] * (test_count - max_allowed)

        actual_counts = [c for _, c in results]
        if actual_counts != expected_counts:
            print(" FAIL: Counter sequence mismatch")
            print(f" Expected: {expected_counts}")
            print(f" Got: {actual_counts}")
            return False

        print(" PASS: Counter correctly incremented")
        return True


def test_dedup_window_expiry():
    """Test content hash dedup window expiry.

    Seeds ``content_hashes`` with a row whose ``first_seen`` / ``last_seen``
    lie 10 seconds beyond ``CONTENT_DEDUP_WINDOW`` and whose count is 5. It then
    checks that the next ``check_content_hash`` call is allowed with its count
    reset to 1.

    Returns:
        True if the stale counter was reset, False otherwise.
    """
    print("\n[3] Content Hash Dedup Window Expiry")
    print("=" * 50)

    app = create_app("testing")
    test_content = b"window expiry test"
    content_hash = hashlib.sha256(test_content).hexdigest()

    with app.app_context():
        # Clear any existing entry
        db = get_db()
        db.execute("DELETE FROM content_hashes WHERE hash = ?", (content_hash,))
        db.commit()

        # Insert an old entry (past window)
        window = app.config.get("CONTENT_DEDUP_WINDOW", 3600)
        old_time = int(time.time()) - window - 10  # 10 seconds past window
        db.execute(
            "INSERT INTO content_hashes (hash, first_seen, last_seen, count) VALUES (?, ?, ?, 5)",
            (content_hash, old_time, old_time),
        )
        db.commit()

        # Should reset counter since outside window
        allowed, count = check_content_hash(content_hash)

        print(f" Window: {window}s")
        print(" Old entry count: 5")
        print(f" After check: allowed={allowed}, count={count}")

        if not allowed or count != 1:
            print(" FAIL: Window expiry not resetting counter")
            return False

        print(" PASS: Counter reset after window expiry")
        return True


def _run_check(name, check):
    """Run one check; an unexpected exception is reported and counted as a failure."""
    try:
        return name, bool(check())
    except Exception as e:
        # Top-level boundary: report and keep going so the remaining checks
        # still run and the summary is always printed.
        print(f" FAIL: {name} raised {type(e).__name__}: {e}")
        return name, False


def main():
    """Run every check, print a summary, and return the exit code (0 if all passed, else 1)."""
    print("=" * 60)
    print("RACE CONDITION TESTS")
    print("=" * 60)

    checks = [
        ("Dedup Counter Race", test_dedup_counter_race),
        ("Dedup Counter Sequential", test_dedup_counter_sequential),
        ("Dedup Window Expiry", test_dedup_window_expiry),
    ]
    results = [_run_check(name, check) for name, check in checks]

    print("\n" + "=" * 60)
    print("SUMMARY")
    print("=" * 60)

    passed = sum(1 for _, r in results if r)
    total = len(results)

    for name, result in results:
        status = "PASS" if result else "FAIL"
        print(f" {status}: {name}")

    print(f"\n{passed}/{total} checks passed")
    return 0 if passed == total else 1


if __name__ == "__main__":
    sys.exit(main())