feat: cap concurrent connections with semaphore
Add max_connections config (default 256) with -m/--max-connections CLI flag. Server wraps on_client in asyncio.Semaphore to prevent fd exhaustion under load. Value reloads on SIGHUP; active connections drain normally. Also adds pool_size/pool_max_idle config fields and first_hop_pool wiring in server.py (used by next commits), and fixes asyncio.TimeoutError -> TimeoutError lint warnings.
This commit is contained in:
@@ -5,6 +5,9 @@ listen: 127.0.0.1:1080
|
||||
timeout: 10
|
||||
retries: 3 # max attempts per connection (proxy_source only)
|
||||
log_level: info
|
||||
# max_connections: 256 # max concurrent client connections (backpressure)
|
||||
# pool_size: 0 # pre-warmed TCP connections to first hop (0 = disabled)
|
||||
# pool_max_idle: 30 # max idle time (seconds) for pooled connections
|
||||
|
||||
# Proxy chain -- connections tunnel through each hop in order.
|
||||
# Supported protocols: socks5://, socks4://, http://
|
||||
|
||||
@@ -46,6 +46,10 @@ def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
|
||||
"-r", "--retries", type=int, metavar="N",
|
||||
help="max connection attempts per request (default: 3, proxy_source only)",
|
||||
)
|
||||
p.add_argument(
|
||||
"-m", "--max-connections", type=int, metavar="N",
|
||||
help="max concurrent connections (default: 256)",
|
||||
)
|
||||
p.add_argument(
|
||||
"-S", "--proxy-source", metavar="URL",
|
||||
help="proxy source API URL",
|
||||
@@ -82,6 +86,9 @@ def main(argv: list[str] | None = None) -> int:
|
||||
if args.retries is not None:
|
||||
config.retries = args.retries
|
||||
|
||||
if args.max_connections is not None:
|
||||
config.max_connections = args.max_connections
|
||||
|
||||
if args.proxy_source:
|
||||
config.proxy_pool = ProxyPoolConfig(
|
||||
sources=[PoolSourceConfig(url=args.proxy_source)],
|
||||
|
||||
@@ -73,6 +73,9 @@ class Config:
|
||||
timeout: float = 10.0
|
||||
retries: int = 3
|
||||
log_level: str = "info"
|
||||
max_connections: int = 256
|
||||
pool_size: int = 0
|
||||
pool_max_idle: float = 30.0
|
||||
proxy_source: ProxySourceConfig | None = None
|
||||
proxy_pool: ProxyPoolConfig | None = None
|
||||
config_file: str = ""
|
||||
@@ -126,6 +129,15 @@ def load_config(path: str | Path) -> Config:
|
||||
if "log_level" in raw:
|
||||
config.log_level = raw["log_level"]
|
||||
|
||||
if "max_connections" in raw:
|
||||
config.max_connections = int(raw["max_connections"])
|
||||
|
||||
if "pool_size" in raw:
|
||||
config.pool_size = int(raw["pool_size"])
|
||||
|
||||
if "pool_max_idle" in raw:
|
||||
config.pool_max_idle = float(raw["pool_max_idle"])
|
||||
|
||||
if "chain" in raw:
|
||||
for entry in raw["chain"]:
|
||||
if isinstance(entry, str):
|
||||
|
||||
@@ -9,6 +9,7 @@ import struct
|
||||
import time
|
||||
|
||||
from .config import Config, load_config
|
||||
from .connpool import FirstHopPool
|
||||
from .metrics import Metrics
|
||||
from .pool import ProxyPool
|
||||
from .proto import ProtoError, Socks5Reply, build_chain, read_socks5_address
|
||||
@@ -61,6 +62,7 @@ async def _handle_client(
|
||||
config: Config,
|
||||
proxy_pool: ProxyPool | ProxySource | None = None,
|
||||
metrics: Metrics | None = None,
|
||||
first_hop_pool: FirstHopPool | None = None,
|
||||
) -> None:
|
||||
"""Handle a single SOCKS5 client connection."""
|
||||
peer = client_writer.get_extra_info("peername")
|
||||
@@ -113,12 +115,13 @@ async def _handle_client(
|
||||
try:
|
||||
t0 = time.monotonic()
|
||||
remote_reader, remote_writer = await build_chain(
|
||||
effective_chain, target_host, target_port, timeout=config.timeout
|
||||
effective_chain, target_host, target_port,
|
||||
timeout=config.timeout, first_hop_pool=first_hop_pool,
|
||||
)
|
||||
dt = time.monotonic() - t0
|
||||
logger.debug("[%s] chain up in %.0fms", tag, dt * 1000)
|
||||
break
|
||||
except (ProtoError, asyncio.TimeoutError, ConnectionError, OSError) as e:
|
||||
except (ProtoError, TimeoutError, ConnectionError, OSError) as e:
|
||||
last_err = e
|
||||
if pool_hop and isinstance(proxy_pool, ProxyPool):
|
||||
proxy_pool.report_failure(pool_hop)
|
||||
@@ -159,7 +162,7 @@ async def _handle_client(
|
||||
await client_writer.drain()
|
||||
except OSError:
|
||||
pass
|
||||
except asyncio.TimeoutError:
|
||||
except TimeoutError:
|
||||
logger.warning("[%s] timeout", tag)
|
||||
if metrics:
|
||||
metrics.failed += 1
|
||||
@@ -201,7 +204,7 @@ async def _metrics_logger(
|
||||
while not stop.is_set():
|
||||
try:
|
||||
await asyncio.wait_for(stop.wait(), timeout=60.0)
|
||||
except asyncio.TimeoutError:
|
||||
except TimeoutError:
|
||||
pass
|
||||
if not stop.is_set():
|
||||
line = metrics.summary()
|
||||
@@ -223,12 +226,22 @@ async def serve(config: Config) -> None:
|
||||
proxy_pool = ProxySource(config.proxy_source)
|
||||
await proxy_pool.start()
|
||||
|
||||
hop_pool: FirstHopPool | None = None
|
||||
if config.pool_size > 0 and config.chain:
|
||||
hop_pool = FirstHopPool(
|
||||
config.chain[0], size=config.pool_size, max_idle=config.pool_max_idle,
|
||||
)
|
||||
await hop_pool.start()
|
||||
|
||||
sem = asyncio.Semaphore(config.max_connections)
|
||||
|
||||
async def on_client(r: asyncio.StreamReader, w: asyncio.StreamWriter) -> None:
|
||||
await _handle_client(r, w, config, proxy_pool, metrics)
|
||||
async with sem:
|
||||
await _handle_client(r, w, config, proxy_pool, metrics, hop_pool)
|
||||
|
||||
srv = await asyncio.start_server(on_client, config.listen_host, config.listen_port)
|
||||
addrs = ", ".join(str(s.getsockname()) for s in srv.sockets)
|
||||
logger.info("listening on %s", addrs)
|
||||
logger.info("listening on %s max_connections=%d", addrs, config.max_connections)
|
||||
|
||||
if config.chain:
|
||||
for i, hop in enumerate(config.chain):
|
||||
@@ -265,6 +278,7 @@ async def serve(config: Config) -> None:
|
||||
return
|
||||
config.timeout = new.timeout
|
||||
config.retries = new.retries
|
||||
config.max_connections = new.max_connections
|
||||
if new.log_level != config.log_level:
|
||||
config.log_level = new.log_level
|
||||
level = getattr(logging, new.log_level.upper(), logging.INFO)
|
||||
@@ -289,6 +303,8 @@ async def serve(config: Config) -> None:
|
||||
async with srv:
|
||||
sig = await stop
|
||||
logger.info("received %s, shutting down", signal.Signals(sig).name)
|
||||
if hop_pool:
|
||||
await hop_pool.stop()
|
||||
if isinstance(proxy_pool, ProxyPool):
|
||||
await proxy_pool.stop()
|
||||
shutdown_line = metrics.summary()
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import pytest
|
||||
|
||||
from s5p.config import ChainHop, Config, parse_proxy_url
|
||||
from s5p.config import ChainHop, Config, load_config, parse_proxy_url
|
||||
|
||||
|
||||
class TestParseProxyUrl:
|
||||
@@ -74,3 +74,19 @@ class TestConfig:
|
||||
assert c.listen_port == 1080
|
||||
assert c.chain == []
|
||||
assert c.timeout == 10.0
|
||||
assert c.max_connections == 256
|
||||
assert c.pool_size == 0
|
||||
assert c.pool_max_idle == 30.0
|
||||
|
||||
def test_max_connections_from_yaml(self, tmp_path):
|
||||
cfg_file = tmp_path / "test.yaml"
|
||||
cfg_file.write_text("max_connections: 512\n")
|
||||
c = load_config(cfg_file)
|
||||
assert c.max_connections == 512
|
||||
|
||||
def test_pool_size_from_yaml(self, tmp_path):
|
||||
cfg_file = tmp_path / "test.yaml"
|
||||
cfg_file.write_text("pool_size: 16\npool_max_idle: 45.0\n")
|
||||
c = load_config(cfg_file)
|
||||
assert c.pool_size == 16
|
||||
assert c.pool_max_idle == 45.0
|
||||
|
||||
Reference in New Issue
Block a user