feat: cap concurrent connections with semaphore

Add max_connections config (default 256) with -m/--max-connections CLI
flag. Server wraps on_client in asyncio.Semaphore to prevent fd
exhaustion under load. Value reloads on SIGHUP; active connections
drain normally. Also adds pool_size/pool_max_idle config fields and
first_hop_pool wiring in server.py (used by next commits), and fixes
asyncio.TimeoutError -> TimeoutError lint warnings.
This commit is contained in:
user
2026-02-15 17:55:50 +01:00
parent 076213a830
commit 714e8efb3d
5 changed files with 61 additions and 7 deletions

View File

@@ -5,6 +5,9 @@ listen: 127.0.0.1:1080
timeout: 10
retries: 3 # max attempts per connection (proxy_source only)
log_level: info
# max_connections: 256 # max concurrent client connections (backpressure)
# pool_size: 0 # pre-warmed TCP connections to first hop (0 = disabled)
# pool_max_idle: 30 # max idle time (seconds) for pooled connections
# Proxy chain -- connections tunnel through each hop in order.
# Supported protocols: socks5://, socks4://, http://

View File

@@ -46,6 +46,10 @@ def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
"-r", "--retries", type=int, metavar="N",
help="max connection attempts per request (default: 3, proxy_source only)",
)
p.add_argument(
"-m", "--max-connections", type=int, metavar="N",
help="max concurrent connections (default: 256)",
)
p.add_argument(
"-S", "--proxy-source", metavar="URL",
help="proxy source API URL",
@@ -82,6 +86,9 @@ def main(argv: list[str] | None = None) -> int:
if args.retries is not None:
config.retries = args.retries
if args.max_connections is not None:
config.max_connections = args.max_connections
if args.proxy_source:
config.proxy_pool = ProxyPoolConfig(
sources=[PoolSourceConfig(url=args.proxy_source)],

View File

@@ -73,6 +73,9 @@ class Config:
timeout: float = 10.0
retries: int = 3
log_level: str = "info"
max_connections: int = 256
pool_size: int = 0
pool_max_idle: float = 30.0
proxy_source: ProxySourceConfig | None = None
proxy_pool: ProxyPoolConfig | None = None
config_file: str = ""
@@ -126,6 +129,15 @@ def load_config(path: str | Path) -> Config:
if "log_level" in raw:
config.log_level = raw["log_level"]
if "max_connections" in raw:
config.max_connections = int(raw["max_connections"])
if "pool_size" in raw:
config.pool_size = int(raw["pool_size"])
if "pool_max_idle" in raw:
config.pool_max_idle = float(raw["pool_max_idle"])
if "chain" in raw:
for entry in raw["chain"]:
if isinstance(entry, str):

View File

@@ -9,6 +9,7 @@ import struct
import time
from .config import Config, load_config
from .connpool import FirstHopPool
from .metrics import Metrics
from .pool import ProxyPool
from .proto import ProtoError, Socks5Reply, build_chain, read_socks5_address
@@ -61,6 +62,7 @@ async def _handle_client(
config: Config,
proxy_pool: ProxyPool | ProxySource | None = None,
metrics: Metrics | None = None,
first_hop_pool: FirstHopPool | None = None,
) -> None:
"""Handle a single SOCKS5 client connection."""
peer = client_writer.get_extra_info("peername")
@@ -113,12 +115,13 @@ async def _handle_client(
try:
t0 = time.monotonic()
remote_reader, remote_writer = await build_chain(
effective_chain, target_host, target_port, timeout=config.timeout
effective_chain, target_host, target_port,
timeout=config.timeout, first_hop_pool=first_hop_pool,
)
dt = time.monotonic() - t0
logger.debug("[%s] chain up in %.0fms", tag, dt * 1000)
break
except (ProtoError, asyncio.TimeoutError, ConnectionError, OSError) as e:
except (ProtoError, TimeoutError, ConnectionError, OSError) as e:
last_err = e
if pool_hop and isinstance(proxy_pool, ProxyPool):
proxy_pool.report_failure(pool_hop)
@@ -159,7 +162,7 @@ async def _handle_client(
await client_writer.drain()
except OSError:
pass
except asyncio.TimeoutError:
except TimeoutError:
logger.warning("[%s] timeout", tag)
if metrics:
metrics.failed += 1
@@ -201,7 +204,7 @@ async def _metrics_logger(
while not stop.is_set():
try:
await asyncio.wait_for(stop.wait(), timeout=60.0)
except asyncio.TimeoutError:
except TimeoutError:
pass
if not stop.is_set():
line = metrics.summary()
@@ -223,12 +226,22 @@ async def serve(config: Config) -> None:
proxy_pool = ProxySource(config.proxy_source)
await proxy_pool.start()
hop_pool: FirstHopPool | None = None
if config.pool_size > 0 and config.chain:
hop_pool = FirstHopPool(
config.chain[0], size=config.pool_size, max_idle=config.pool_max_idle,
)
await hop_pool.start()
sem = asyncio.Semaphore(config.max_connections)
async def on_client(r: asyncio.StreamReader, w: asyncio.StreamWriter) -> None:
await _handle_client(r, w, config, proxy_pool, metrics)
async with sem:
await _handle_client(r, w, config, proxy_pool, metrics, hop_pool)
srv = await asyncio.start_server(on_client, config.listen_host, config.listen_port)
addrs = ", ".join(str(s.getsockname()) for s in srv.sockets)
logger.info("listening on %s", addrs)
logger.info("listening on %s max_connections=%d", addrs, config.max_connections)
if config.chain:
for i, hop in enumerate(config.chain):
@@ -265,6 +278,7 @@ async def serve(config: Config) -> None:
return
config.timeout = new.timeout
config.retries = new.retries
config.max_connections = new.max_connections
if new.log_level != config.log_level:
config.log_level = new.log_level
level = getattr(logging, new.log_level.upper(), logging.INFO)
@@ -289,6 +303,8 @@ async def serve(config: Config) -> None:
async with srv:
sig = await stop
logger.info("received %s, shutting down", signal.Signals(sig).name)
if hop_pool:
await hop_pool.stop()
if isinstance(proxy_pool, ProxyPool):
await proxy_pool.stop()
shutdown_line = metrics.summary()

View File

@@ -2,7 +2,7 @@
import pytest
from s5p.config import ChainHop, Config, parse_proxy_url
from s5p.config import ChainHop, Config, load_config, parse_proxy_url
class TestParseProxyUrl:
@@ -74,3 +74,19 @@ class TestConfig:
assert c.listen_port == 1080
assert c.chain == []
assert c.timeout == 10.0
assert c.max_connections == 256
assert c.pool_size == 0
assert c.pool_max_idle == 30.0
def test_max_connections_from_yaml(self, tmp_path):
    """A max_connections value in YAML overrides the 256 default."""
    path = tmp_path / "test.yaml"
    path.write_text("max_connections: 512\n")
    loaded = load_config(path)
    assert loaded.max_connections == 512
def test_pool_size_from_yaml(self, tmp_path):
    """pool_size and pool_max_idle are both parsed from YAML."""
    path = tmp_path / "test.yaml"
    path.write_text("pool_size: 16\npool_max_idle: 45.0\n")
    loaded = load_config(path)
    assert loaded.pool_size == 16
    assert loaded.pool_max_idle == 45.0