feat: add managed proxy pool with health testing

ProxyPool replaces ProxySource with:
- Multiple sources: HTTP APIs and text files (one proxy URL per line)
- Deduplication by proto://host:port
- Health testing: full chain test with configurable concurrency
- Mass-failure guard: skip eviction when >90% fail
- Background loops for periodic refresh and health checks
- JSON state persistence with atomic writes (warm starts)
- Backward compat: ProxySource still works for legacy configs

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
user
2026-02-15 06:11:19 +01:00
parent 1780c3a8cd
commit 72adf2f658
3 changed files with 585 additions and 11 deletions

385
src/s5p/pool.py Normal file
View File

@@ -0,0 +1,385 @@
"""Managed proxy pool with multi-source fetching and health testing."""
from __future__ import annotations
import asyncio
import json
import logging
import os
import random
import time
import urllib.request
from dataclasses import dataclass
from pathlib import Path
from urllib.parse import urlencode
from .config import ChainHop, PoolSourceConfig, ProxyPoolConfig, parse_proxy_url
from .proto import ProtoError, build_chain
logger = logging.getLogger("s5p")
VALID_PROTOS = {"socks5", "socks4", "http"}
STATE_VERSION = 1
# -- proxy entry -------------------------------------------------------------
@dataclass
class ProxyEntry:
    """Internal state for a single proxy in the pool."""

    hop: ChainHop            # proxy address (and credentials, if any)
    last_seen: float = 0.0   # epoch seconds when a source last listed this proxy
    last_ok: float = 0.0     # epoch seconds of the last passed health test
    last_test: float = 0.0   # epoch seconds of the last test attempt (pass or fail)
    fails: int = 0           # consecutive failed tests (reset to 0 on success)
    tests: int = 0           # total health tests ever run against this proxy
    alive: bool = False      # result of the most recent health test
# -- proxy pool --------------------------------------------------------------
class ProxyPool:
    """Managed proxy pool with source fetching, health testing, and persistence.

    Drop-in replacement for ProxySource: exposes ``start()``, ``stop()``,
    ``get()``, ``count``, and ``alive_count``.
    """

    def __init__(
        self,
        cfg: ProxyPoolConfig,
        chain: list[ChainHop],
        timeout: float,
    ) -> None:
        """
        Args:
            cfg: Pool configuration (sources, health-test settings, state file).
            chain: Static upstream chain; the tested proxy is appended to it.
            timeout: Per-connection timeout kept for interface parity with
                ProxySource (health tests use ``cfg.test_timeout`` instead).
        """
        self._cfg = cfg
        self._chain = list(chain)
        self._timeout = timeout
        self._proxies: dict[str, ProxyEntry] = {}
        self._alive_keys: list[str] = []
        self._tasks: list[asyncio.Task] = []
        self._stop = asyncio.Event()
        self._state_path = self._resolve_state_path()

    # -- public interface ----------------------------------------------------

    async def start(self) -> None:
        """Load state, fetch sources, run initial health test, start loops."""
        self._load_state()
        await self._fetch_all_sources()
        await self._run_health_tests()
        self._save_state()
        self._tasks.append(asyncio.create_task(self._refresh_loop()))
        self._tasks.append(asyncio.create_task(self._health_loop()))

    async def stop(self) -> None:
        """Cancel background tasks and save state."""
        self._stop.set()
        for task in self._tasks:
            task.cancel()
        for task in self._tasks:
            try:
                await task
            except asyncio.CancelledError:
                pass
        self._tasks.clear()
        self._save_state()

    async def get(self) -> ChainHop | None:
        """Return a random alive proxy, or None if the pool is empty."""
        if not self._alive_keys:
            return None
        key = random.choice(self._alive_keys)
        entry = self._proxies.get(key)
        return entry.hop if entry else None

    @property
    def count(self) -> int:
        """Total proxies in pool (alive or not)."""
        return len(self._proxies)

    @property
    def alive_count(self) -> int:
        """Number of alive proxies."""
        return len(self._alive_keys)

    # -- source fetching -----------------------------------------------------

    async def _fetch_all_sources(self) -> None:
        """Fetch proxies from all configured sources and merge into the pool.

        Blocking I/O (HTTP fetch, file read) runs in the default executor so
        the event loop is never blocked. A failing source is logged and
        skipped; batches from the remaining sources are still merged.
        """
        loop = asyncio.get_running_loop()
        proxies: list[ChainHop] = []
        for src in self._cfg.sources:
            try:
                if src.url:
                    batch = await loop.run_in_executor(None, self._fetch_api_sync, src)
                    logger.info("pool: fetched %d proxies from %s", len(batch), src.url)
                    proxies.extend(batch)
                elif src.file:
                    batch = await loop.run_in_executor(None, self._fetch_file_sync, src)
                    logger.info("pool: loaded %d proxies from %s", len(batch), src.file)
                    proxies.extend(batch)
            except Exception as e:
                label = src.url or src.file or "?"
                logger.warning("pool: source %s failed: %s", label, e)
        self._merge(proxies)

    def _fetch_api_sync(self, src: PoolSourceConfig) -> list[ChainHop]:
        """Fetch proxies from an HTTP API (runs in executor).

        Expects a JSON body shaped like
        ``{"proxies": [{"proto": ..., "proxy": "host:port"}, ...]}``;
        entries with unknown protocols or malformed addresses are dropped.
        """
        params: dict[str, str] = {}
        if src.limit:
            params["limit"] = str(src.limit)
        if src.proto:
            params["proto"] = src.proto
        if src.country:
            params["country"] = src.country
        url = src.url
        if params:
            sep = "&" if "?" in url else "?"
            url = f"{url}{sep}{urlencode(params)}"
        req = urllib.request.Request(url, headers={"Accept": "application/json"})
        with urllib.request.urlopen(req, timeout=10) as resp:
            data = json.loads(resp.read())
        proxies: list[ChainHop] = []
        for entry in data.get("proxies", []):
            proto = entry.get("proto")
            addr = entry.get("proxy", "")
            if not proto or proto not in VALID_PROTOS or ":" not in addr:
                continue
            # rsplit: take the trailing ":port" even if the host part
            # itself contains colons
            host, port_str = addr.rsplit(":", 1)
            try:
                port = int(port_str)
            except ValueError:
                continue
            proxies.append(ChainHop(proto=proto, host=host, port=port))
        return proxies

    def _fetch_file_sync(self, src: PoolSourceConfig) -> list[ChainHop]:
        """Parse a text file with one proxy URL per line (runs in executor).

        Blank lines and ``#`` comments are skipped; unparseable lines are
        logged at debug level and dropped. When ``src.proto`` is set, only
        matching entries are kept.
        """
        path = Path(src.file).expanduser()
        if not path.is_file():
            logger.warning("pool: file not found: %s", path)
            return []
        proxies: list[ChainHop] = []
        for line in path.read_text().splitlines():
            line = line.strip()
            if not line or line.startswith("#"):
                continue
            try:
                hop = parse_proxy_url(line)
            except ValueError as e:
                logger.debug("pool: skipping invalid line %r: %s", line, e)
                continue
            if src.proto and hop.proto != src.proto:
                continue
            proxies.append(hop)
        return proxies

    def _merge(self, proxies: list[ChainHop]) -> None:
        """Deduplicate and merge fetched proxies into the pool.

        Known proxies get a refreshed ``last_seen`` and an updated hop (e.g.
        new credentials); unknown ones are added as not-yet-alive entries.
        """
        now = time.time()
        for hop in proxies:
            key = f"{hop.proto}://{hop.host}:{hop.port}"
            if key in self._proxies:
                self._proxies[key].last_seen = now
                self._proxies[key].hop = hop
            else:
                self._proxies[key] = ProxyEntry(hop=hop, last_seen=now)

    # -- health testing ------------------------------------------------------

    def _parse_test_url(self) -> tuple[str, int, str]:
        """Split ``cfg.test_url`` into ``(host, port, path)`` for the probe.

        Only plain-HTTP URLs are usable here: the probe speaks raw HTTP/1.1
        over the tunneled socket without TLS. Port defaults to 80.
        """
        rest = self._cfg.test_url.split("//", 1)[-1]
        netloc, slash, tail = rest.partition("/")
        path = "/" + tail if slash else "/"
        host, colon, port_str = netloc.partition(":")
        port = int(port_str) if colon and port_str.isdigit() else 80
        return host, port, path

    async def _test_proxy(self, key: str, entry: ProxyEntry) -> bool:
        """Test a single proxy by building the full chain and sending HTTP GET.

        The tunnel target and the request's Host/path are both derived from
        ``cfg.test_url`` (previously the tunnel was hardcoded to
        httpbin.org:80, which broke custom test URLs). A 2xx status line
        counts as success.
        """
        host, port, path = self._parse_test_url()
        chain = self._chain + [entry.hop]
        entry.last_test = time.time()
        entry.tests += 1
        try:
            reader, writer = await build_chain(
                chain, host, port, timeout=self._cfg.test_timeout,
            )
        except (ProtoError, asyncio.TimeoutError, TimeoutError, ConnectionError, OSError):
            return False
        try:
            # Host header must carry the port when it is non-default
            host_hdr = host if port == 80 else f"{host}:{port}"
            request = f"GET {path} HTTP/1.1\r\nHost: {host_hdr}\r\nConnection: close\r\n\r\n"
            writer.write(request.encode())
            await writer.drain()
            line = await asyncio.wait_for(reader.readline(), timeout=self._cfg.test_timeout)
            parts = line.decode("utf-8", errors="replace").split(None, 2)
            # status line: "HTTP/1.1 200 OK" -> parts[1] is the status code
            return len(parts) >= 2 and parts[1].startswith("2")
        except (asyncio.TimeoutError, TimeoutError, ConnectionError, OSError):
            return False
        finally:
            try:
                writer.close()
                await writer.wait_closed()
            except OSError:
                pass

    async def _run_health_tests(self) -> None:
        """Test all proxies with bounded concurrency, then update liveness.

        Mass-failure guard: when more than 90% of a non-trivial pool fails
        (e.g. a local network outage), eviction is skipped so one bad cycle
        cannot wipe the whole pool.
        """
        if not self._proxies:
            return
        sem = asyncio.Semaphore(self._cfg.test_concurrency)
        results: dict[str, bool] = {}

        async def _test(key: str, entry: ProxyEntry) -> None:
            async with sem:
                results[key] = await self._test_proxy(key, entry)

        # snapshot items(): the dict must not change size while gathering
        await asyncio.gather(*(_test(k, e) for k, e in list(self._proxies.items())))
        total = len(results)
        passed = sum(1 for ok in results.values() if ok)
        fail_rate = (total - passed) / total if total else 0.0
        skip_eviction = fail_rate > 0.90 and total > 10
        if skip_eviction:
            logger.warning(
                "pool: %d/%d tests failed (%.0f%%), skipping eviction",
                total - passed, total, fail_rate * 100,
            )
        now = time.time()
        evict_keys: list[str] = []
        for key, ok in results.items():
            entry = self._proxies.get(key)
            if not entry:
                continue
            if ok:
                entry.alive = True
                entry.fails = 0
                entry.last_ok = now
            else:
                entry.alive = False
                entry.fails += 1
                if not skip_eviction and entry.fails >= self._cfg.max_fails:
                    evict_keys.append(key)
        for key in evict_keys:
            del self._proxies[key]
        self._rebuild_alive()
        logger.info(
            "pool: %d proxies, %d alive%s",
            len(self._proxies),
            len(self._alive_keys),
            f" (evicted {len(evict_keys)})" if evict_keys else "",
        )

    def _rebuild_alive(self) -> None:
        """Rebuild the alive keys list from current state."""
        self._alive_keys = [k for k, e in self._proxies.items() if e.alive]

    # -- background loops ----------------------------------------------------

    async def _refresh_loop(self) -> None:
        """Periodically re-fetch sources and merge, until stop is set.

        Catches asyncio.TimeoutError as well as builtin TimeoutError: on
        Python < 3.11 they are distinct classes and wait_for raises the
        former, which would otherwise kill this loop silently.
        """
        while not self._stop.is_set():
            try:
                await asyncio.wait_for(self._stop.wait(), timeout=self._cfg.refresh)
            except (asyncio.TimeoutError, TimeoutError):
                pass
            if self._stop.is_set():
                break
            await self._fetch_all_sources()
            self._save_state()

    async def _health_loop(self) -> None:
        """Periodically test proxies and evict dead ones, until stop is set."""
        while not self._stop.is_set():
            try:
                await asyncio.wait_for(self._stop.wait(), timeout=self._cfg.test_interval)
            except (asyncio.TimeoutError, TimeoutError):
                pass
            if self._stop.is_set():
                break
            await self._run_health_tests()
            self._save_state()

    # -- persistence ---------------------------------------------------------

    def _resolve_state_path(self) -> Path:
        """Resolve state file path, defaulting to ~/.cache/s5p/pool.json."""
        if self._cfg.state_file:
            return Path(self._cfg.state_file).expanduser()
        cache_dir = Path.home() / ".cache" / "s5p"
        return cache_dir / "pool.json"

    def _load_state(self) -> None:
        """Load proxy state from JSON file (warm start).

        A missing file is fine; a version mismatch or corrupt file starts
        fresh (corruption also clears any partially-loaded entries).
        """
        if not self._state_path.is_file():
            return
        try:
            data = json.loads(self._state_path.read_text())
            if data.get("version") != STATE_VERSION:
                logger.warning("pool: state file version mismatch, starting fresh")
                return
            for key, entry in data.get("proxies", {}).items():
                hop = ChainHop(
                    proto=entry["proto"],
                    host=entry["host"],
                    port=int(entry["port"]),
                    username=entry.get("username"),
                    password=entry.get("password"),
                )
                self._proxies[key] = ProxyEntry(
                    hop=hop,
                    last_seen=entry.get("last_seen", 0.0),
                    last_ok=entry.get("last_ok", 0.0),
                    # .get keeps compatibility with state files written
                    # before last_test was persisted
                    last_test=entry.get("last_test", 0.0),
                    fails=entry.get("fails", 0),
                    tests=entry.get("tests", 0),
                    alive=entry.get("alive", False),
                )
            self._rebuild_alive()
            logger.info(
                "pool: loaded state (%d proxies, %d alive)",
                len(self._proxies), len(self._alive_keys),
            )
        except (json.JSONDecodeError, KeyError, TypeError, ValueError) as e:
            logger.warning("pool: corrupt state file: %s", e)
            self._proxies.clear()
            self._alive_keys.clear()

    def _save_state(self) -> None:
        """Save proxy state to JSON file (atomic write via temp + rename)."""
        self._state_path.parent.mkdir(parents=True, exist_ok=True)
        # append ".tmp" rather than with_suffix(): "pool.json" -> "pool.json.tmp"
        # (with_suffix would replace ".json" and could collide with other files)
        tmp = self._state_path.with_name(self._state_path.name + ".tmp")
        proxies = {}
        for key, entry in self._proxies.items():
            proxies[key] = {
                "proto": entry.hop.proto,
                "host": entry.hop.host,
                "port": entry.hop.port,
                "username": entry.hop.username,
                "password": entry.hop.password,
                "last_seen": entry.last_seen,
                "last_ok": entry.last_ok,
                "last_test": entry.last_test,
                "fails": entry.fails,
                "tests": entry.tests,
                "alive": entry.alive,
            }
        data = {
            "version": STATE_VERSION,
            "updated": time.time(),
            "proxies": proxies,
        }
        try:
            tmp.write_text(json.dumps(data, indent=2))
            os.replace(tmp, self._state_path)
        except OSError as e:
            logger.warning("pool: failed to save state: %s", e)

View File

@@ -10,6 +10,7 @@ import time
from .config import Config
from .metrics import Metrics
from .pool import ProxyPool
from .proto import ProtoError, Socks5Reply, build_chain, read_socks5_address
from .source import ProxySource
@@ -58,7 +59,7 @@ async def _handle_client(
client_reader: asyncio.StreamReader,
client_writer: asyncio.StreamWriter,
config: Config,
proxy_source: ProxySource | None = None,
proxy_pool: ProxyPool | ProxySource | None = None,
metrics: Metrics | None = None,
) -> None:
"""Handle a single SOCKS5 client connection."""
@@ -97,13 +98,13 @@ async def _handle_client(
logger.info("[%s] connect %s:%d", tag, target_host, target_port)
# -- build chain (with retry) --
attempts = config.retries if proxy_source else 1
attempts = config.retries if proxy_pool else 1
last_err: Exception | None = None
for attempt in range(attempts):
effective_chain = list(config.chain)
if proxy_source:
hop = await proxy_source.get()
if proxy_pool:
hop = await proxy_pool.get()
if hop:
effective_chain.append(hop)
logger.debug("[%s] +proxy %s", tag, hop)
@@ -203,13 +204,17 @@ async def serve(config: Config) -> None:
"""Start the SOCKS5 proxy server."""
metrics = Metrics()
proxy_source: ProxySource | None = None
if config.proxy_source and config.proxy_source.url:
proxy_source = ProxySource(config.proxy_source)
await proxy_source.start()
proxy_pool: ProxyPool | ProxySource | None = None
if config.proxy_pool and config.proxy_pool.sources:
pool = ProxyPool(config.proxy_pool, config.chain, config.timeout)
await pool.start()
proxy_pool = pool
elif config.proxy_source and config.proxy_source.url:
proxy_pool = ProxySource(config.proxy_source)
await proxy_pool.start()
async def on_client(r: asyncio.StreamReader, w: asyncio.StreamWriter) -> None:
await _handle_client(r, w, config, proxy_source, metrics)
await _handle_client(r, w, config, proxy_pool, metrics)
srv = await asyncio.start_server(on_client, config.listen_host, config.listen_port)
addrs = ", ".join(str(s.getsockname()) for s in srv.sockets)
@@ -221,8 +226,15 @@ async def serve(config: Config) -> None:
else:
logger.info(" mode: direct (no chain)")
if proxy_source:
logger.info(" proxy source: %s (%d proxies)", config.proxy_source.url, proxy_source.count)
if isinstance(proxy_pool, ProxyPool):
nsrc = len(config.proxy_pool.sources)
logger.info(
" pool: %d proxies, %d alive (from %d source%s)",
proxy_pool.count, proxy_pool.alive_count, nsrc, "s" if nsrc != 1 else "",
)
logger.info(" retries: %d", config.retries)
elif proxy_pool:
logger.info(" proxy source: %s (%d proxies)", config.proxy_source.url, proxy_pool.count)
logger.info(" retries: %d", config.retries)
loop = asyncio.get_running_loop()
@@ -237,6 +249,8 @@ async def serve(config: Config) -> None:
async with srv:
sig = await stop
logger.info("received %s, shutting down", signal.Signals(sig).name)
if isinstance(proxy_pool, ProxyPool):
await proxy_pool.stop()
logger.info("metrics: %s", metrics.summary())
metrics_stop.set()
await metrics_task

175
tests/test_pool.py Normal file
View File

@@ -0,0 +1,175 @@
"""Tests for the managed proxy pool."""
import json
import time
import pytest
from s5p.config import ChainHop, PoolSourceConfig, ProxyPoolConfig
from s5p.pool import ProxyEntry, ProxyPool
class TestProxyEntry:
    """A freshly created ProxyEntry must be dead and untested."""

    def test_defaults(self):
        entry = ProxyEntry(hop=ChainHop(proto="socks5", host="1.2.3.4", port=1080))
        assert entry.alive is False
        assert entry.fails == 0
        assert entry.tests == 0
class TestProxyPoolMerge:
    """Merging deduplicates by proto://host:port and refreshes timestamps."""

    def test_merge_dedup(self):
        pool = ProxyPool(ProxyPoolConfig(sources=[]), [], timeout=10.0)
        duplicate = ChainHop(proto="socks5", host="1.2.3.4", port=1080)
        other = ChainHop(proto="socks5", host="5.6.7.8", port=1080)
        pool._merge([duplicate, duplicate, other])
        assert pool.count == 2

    def test_merge_updates_existing(self):
        pool = ProxyPool(ProxyPoolConfig(sources=[]), [], timeout=10.0)
        hop = ChainHop(proto="socks5", host="1.2.3.4", port=1080)
        pool._merge([hop])
        seen_before = pool._proxies["socks5://1.2.3.4:1080"].last_seen
        time.sleep(0.01)
        # re-merging the same proxy must bump last_seen, not add a new entry
        pool._merge([hop])
        seen_after = pool._proxies["socks5://1.2.3.4:1080"].last_seen
        assert seen_after >= seen_before
        assert pool.count == 1
class TestProxyPoolGet:
    """get() returns None on an empty pool and an alive hop otherwise."""

    def test_get_empty(self):
        import asyncio
        pool = ProxyPool(ProxyPoolConfig(sources=[]), [], timeout=10.0)
        assert asyncio.run(pool.get()) is None

    def test_get_returns_alive(self):
        import asyncio
        pool = ProxyPool(ProxyPoolConfig(sources=[]), [], timeout=10.0)
        hop = ChainHop(proto="socks5", host="1.2.3.4", port=1080)
        pool._proxies["socks5://1.2.3.4:1080"] = ProxyEntry(hop=hop, alive=True)
        pool._rebuild_alive()
        chosen = asyncio.run(pool.get())
        assert chosen is not None
        assert chosen.host == "1.2.3.4"
class TestProxyPoolFetchFile:
    """File sources: comment/blank skipping, auth parsing, proto filtering."""

    def test_parse_file(self, tmp_path):
        listing = tmp_path / "proxies.txt"
        listing.write_text(
            "# comment\n"
            "socks5://1.2.3.4:1080\n"
            "socks5://user:pass@5.6.7.8:1080\n"
            "http://proxy.example.com:8080\n"
            "\n"
            " # another comment\n"
        )
        pool = ProxyPool(ProxyPoolConfig(sources=[]), [], timeout=10.0)
        hops = pool._fetch_file_sync(PoolSourceConfig(file=str(listing)))
        assert len(hops) == 3
        assert hops[0].proto == "socks5"
        assert hops[0].host == "1.2.3.4"
        assert hops[1].username == "user"
        assert hops[1].password == "pass"
        assert hops[2].proto == "http"

    def test_parse_file_with_proto_filter(self, tmp_path):
        listing = tmp_path / "proxies.txt"
        listing.write_text(
            "socks5://1.2.3.4:1080\n"
            "http://proxy.example.com:8080\n"
        )
        pool = ProxyPool(ProxyPoolConfig(sources=[]), [], timeout=10.0)
        src = PoolSourceConfig(file=str(listing), proto="socks5")
        hops = pool._fetch_file_sync(src)
        # the http entry must be filtered out
        assert len(hops) == 1
        assert hops[0].proto == "socks5"

    def test_missing_file(self, tmp_path):
        pool = ProxyPool(ProxyPoolConfig(sources=[]), [], timeout=10.0)
        src = PoolSourceConfig(file=str(tmp_path / "nonexistent.txt"))
        assert pool._fetch_file_sync(src) == []
class TestProxyPoolPersistence:
    """Round-trip state save/load, corrupt and missing files, credentials."""

    def test_save_and_load(self, tmp_path):
        cfg = ProxyPoolConfig(sources=[], state_file=str(tmp_path / "pool.json"))
        original = ProxyPool(cfg, [], timeout=10.0)
        hop = ChainHop(proto="socks5", host="1.2.3.4", port=1080)
        original._proxies["socks5://1.2.3.4:1080"] = ProxyEntry(
            hop=hop, alive=True, fails=0, tests=5, last_ok=1000.0,
        )
        original._rebuild_alive()
        original._save_state()
        # a fresh pool must reconstruct the same state from disk
        restored = ProxyPool(cfg, [], timeout=10.0)
        restored._load_state()
        assert restored.count == 1
        assert restored.alive_count == 1
        entry = restored._proxies["socks5://1.2.3.4:1080"]
        assert entry.hop.host == "1.2.3.4"
        assert entry.tests == 5
        assert entry.alive is True

    def test_corrupt_state(self, tmp_path):
        state_file = tmp_path / "pool.json"
        state_file.write_text("{invalid json")
        pool = ProxyPool(
            ProxyPoolConfig(sources=[], state_file=str(state_file)), [], timeout=10.0,
        )
        pool._load_state()
        # corrupt state must be discarded, not raise
        assert pool.count == 0

    def test_missing_state(self, tmp_path):
        cfg = ProxyPoolConfig(sources=[], state_file=str(tmp_path / "missing.json"))
        pool = ProxyPool(cfg, [], timeout=10.0)
        pool._load_state()  # should not raise
        assert pool.count == 0

    def test_state_with_auth(self, tmp_path):
        cfg = ProxyPoolConfig(sources=[], state_file=str(tmp_path / "pool.json"))
        original = ProxyPool(cfg, [], timeout=10.0)
        hop = ChainHop(
            proto="socks5", host="1.2.3.4", port=1080,
            username="user", password="pass",
        )
        original._proxies["socks5://1.2.3.4:1080"] = ProxyEntry(hop=hop, alive=True)
        original._save_state()
        restored = ProxyPool(cfg, [], timeout=10.0)
        restored._load_state()
        entry = restored._proxies["socks5://1.2.3.4:1080"]
        assert entry.hop.username == "user"
        assert entry.hop.password == "pass"