feat: cap concurrent connections with semaphore
Add max_connections config (default 256) with -m/--max-connections CLI flag. Server wraps on_client in asyncio.Semaphore to prevent fd exhaustion under load. Value reloads on SIGHUP; active connections drain normally. Also adds pool_size/pool_max_idle config fields and first_hop_pool wiring in server.py (used by next commits), and fixes asyncio.TimeoutError -> TimeoutError lint warnings.
This commit is contained in:
@@ -5,6 +5,9 @@ listen: 127.0.0.1:1080
|
|||||||
timeout: 10
|
timeout: 10
|
||||||
retries: 3 # max attempts per connection (proxy_source only)
|
retries: 3 # max attempts per connection (proxy_source only)
|
||||||
log_level: info
|
log_level: info
|
||||||
|
# max_connections: 256 # max concurrent client connections (backpressure)
|
||||||
|
# pool_size: 0 # pre-warmed TCP connections to first hop (0 = disabled)
|
||||||
|
# pool_max_idle: 30 # max idle time (seconds) for pooled connections
|
||||||
|
|
||||||
# Proxy chain -- connections tunnel through each hop in order.
|
# Proxy chain -- connections tunnel through each hop in order.
|
||||||
# Supported protocols: socks5://, socks4://, http://
|
# Supported protocols: socks5://, socks4://, http://
|
||||||
|
|||||||
@@ -46,6 +46,10 @@ def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
|
|||||||
"-r", "--retries", type=int, metavar="N",
|
"-r", "--retries", type=int, metavar="N",
|
||||||
help="max connection attempts per request (default: 3, proxy_source only)",
|
help="max connection attempts per request (default: 3, proxy_source only)",
|
||||||
)
|
)
|
||||||
|
p.add_argument(
|
||||||
|
"-m", "--max-connections", type=int, metavar="N",
|
||||||
|
help="max concurrent connections (default: 256)",
|
||||||
|
)
|
||||||
p.add_argument(
|
p.add_argument(
|
||||||
"-S", "--proxy-source", metavar="URL",
|
"-S", "--proxy-source", metavar="URL",
|
||||||
help="proxy source API URL",
|
help="proxy source API URL",
|
||||||
@@ -82,6 +86,9 @@ def main(argv: list[str] | None = None) -> int:
|
|||||||
if args.retries is not None:
|
if args.retries is not None:
|
||||||
config.retries = args.retries
|
config.retries = args.retries
|
||||||
|
|
||||||
|
if args.max_connections is not None:
|
||||||
|
config.max_connections = args.max_connections
|
||||||
|
|
||||||
if args.proxy_source:
|
if args.proxy_source:
|
||||||
config.proxy_pool = ProxyPoolConfig(
|
config.proxy_pool = ProxyPoolConfig(
|
||||||
sources=[PoolSourceConfig(url=args.proxy_source)],
|
sources=[PoolSourceConfig(url=args.proxy_source)],
|
||||||
|
|||||||
@@ -73,6 +73,9 @@ class Config:
|
|||||||
timeout: float = 10.0
|
timeout: float = 10.0
|
||||||
retries: int = 3
|
retries: int = 3
|
||||||
log_level: str = "info"
|
log_level: str = "info"
|
||||||
|
max_connections: int = 256
|
||||||
|
pool_size: int = 0
|
||||||
|
pool_max_idle: float = 30.0
|
||||||
proxy_source: ProxySourceConfig | None = None
|
proxy_source: ProxySourceConfig | None = None
|
||||||
proxy_pool: ProxyPoolConfig | None = None
|
proxy_pool: ProxyPoolConfig | None = None
|
||||||
config_file: str = ""
|
config_file: str = ""
|
||||||
@@ -126,6 +129,15 @@ def load_config(path: str | Path) -> Config:
|
|||||||
if "log_level" in raw:
|
if "log_level" in raw:
|
||||||
config.log_level = raw["log_level"]
|
config.log_level = raw["log_level"]
|
||||||
|
|
||||||
|
if "max_connections" in raw:
|
||||||
|
config.max_connections = int(raw["max_connections"])
|
||||||
|
|
||||||
|
if "pool_size" in raw:
|
||||||
|
config.pool_size = int(raw["pool_size"])
|
||||||
|
|
||||||
|
if "pool_max_idle" in raw:
|
||||||
|
config.pool_max_idle = float(raw["pool_max_idle"])
|
||||||
|
|
||||||
if "chain" in raw:
|
if "chain" in raw:
|
||||||
for entry in raw["chain"]:
|
for entry in raw["chain"]:
|
||||||
if isinstance(entry, str):
|
if isinstance(entry, str):
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import struct
|
|||||||
import time
|
import time
|
||||||
|
|
||||||
from .config import Config, load_config
|
from .config import Config, load_config
|
||||||
|
from .connpool import FirstHopPool
|
||||||
from .metrics import Metrics
|
from .metrics import Metrics
|
||||||
from .pool import ProxyPool
|
from .pool import ProxyPool
|
||||||
from .proto import ProtoError, Socks5Reply, build_chain, read_socks5_address
|
from .proto import ProtoError, Socks5Reply, build_chain, read_socks5_address
|
||||||
@@ -61,6 +62,7 @@ async def _handle_client(
|
|||||||
config: Config,
|
config: Config,
|
||||||
proxy_pool: ProxyPool | ProxySource | None = None,
|
proxy_pool: ProxyPool | ProxySource | None = None,
|
||||||
metrics: Metrics | None = None,
|
metrics: Metrics | None = None,
|
||||||
|
first_hop_pool: FirstHopPool | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Handle a single SOCKS5 client connection."""
|
"""Handle a single SOCKS5 client connection."""
|
||||||
peer = client_writer.get_extra_info("peername")
|
peer = client_writer.get_extra_info("peername")
|
||||||
@@ -113,12 +115,13 @@ async def _handle_client(
|
|||||||
try:
|
try:
|
||||||
t0 = time.monotonic()
|
t0 = time.monotonic()
|
||||||
remote_reader, remote_writer = await build_chain(
|
remote_reader, remote_writer = await build_chain(
|
||||||
effective_chain, target_host, target_port, timeout=config.timeout
|
effective_chain, target_host, target_port,
|
||||||
|
timeout=config.timeout, first_hop_pool=first_hop_pool,
|
||||||
)
|
)
|
||||||
dt = time.monotonic() - t0
|
dt = time.monotonic() - t0
|
||||||
logger.debug("[%s] chain up in %.0fms", tag, dt * 1000)
|
logger.debug("[%s] chain up in %.0fms", tag, dt * 1000)
|
||||||
break
|
break
|
||||||
except (ProtoError, asyncio.TimeoutError, ConnectionError, OSError) as e:
|
except (ProtoError, TimeoutError, ConnectionError, OSError) as e:
|
||||||
last_err = e
|
last_err = e
|
||||||
if pool_hop and isinstance(proxy_pool, ProxyPool):
|
if pool_hop and isinstance(proxy_pool, ProxyPool):
|
||||||
proxy_pool.report_failure(pool_hop)
|
proxy_pool.report_failure(pool_hop)
|
||||||
@@ -159,7 +162,7 @@ async def _handle_client(
|
|||||||
await client_writer.drain()
|
await client_writer.drain()
|
||||||
except OSError:
|
except OSError:
|
||||||
pass
|
pass
|
||||||
except asyncio.TimeoutError:
|
except TimeoutError:
|
||||||
logger.warning("[%s] timeout", tag)
|
logger.warning("[%s] timeout", tag)
|
||||||
if metrics:
|
if metrics:
|
||||||
metrics.failed += 1
|
metrics.failed += 1
|
||||||
@@ -201,7 +204,7 @@ async def _metrics_logger(
|
|||||||
while not stop.is_set():
|
while not stop.is_set():
|
||||||
try:
|
try:
|
||||||
await asyncio.wait_for(stop.wait(), timeout=60.0)
|
await asyncio.wait_for(stop.wait(), timeout=60.0)
|
||||||
except asyncio.TimeoutError:
|
except TimeoutError:
|
||||||
pass
|
pass
|
||||||
if not stop.is_set():
|
if not stop.is_set():
|
||||||
line = metrics.summary()
|
line = metrics.summary()
|
||||||
@@ -223,12 +226,22 @@ async def serve(config: Config) -> None:
|
|||||||
proxy_pool = ProxySource(config.proxy_source)
|
proxy_pool = ProxySource(config.proxy_source)
|
||||||
await proxy_pool.start()
|
await proxy_pool.start()
|
||||||
|
|
||||||
|
hop_pool: FirstHopPool | None = None
|
||||||
|
if config.pool_size > 0 and config.chain:
|
||||||
|
hop_pool = FirstHopPool(
|
||||||
|
config.chain[0], size=config.pool_size, max_idle=config.pool_max_idle,
|
||||||
|
)
|
||||||
|
await hop_pool.start()
|
||||||
|
|
||||||
|
sem = asyncio.Semaphore(config.max_connections)
|
||||||
|
|
||||||
async def on_client(r: asyncio.StreamReader, w: asyncio.StreamWriter) -> None:
|
async def on_client(r: asyncio.StreamReader, w: asyncio.StreamWriter) -> None:
|
||||||
await _handle_client(r, w, config, proxy_pool, metrics)
|
async with sem:
|
||||||
|
await _handle_client(r, w, config, proxy_pool, metrics, hop_pool)
|
||||||
|
|
||||||
srv = await asyncio.start_server(on_client, config.listen_host, config.listen_port)
|
srv = await asyncio.start_server(on_client, config.listen_host, config.listen_port)
|
||||||
addrs = ", ".join(str(s.getsockname()) for s in srv.sockets)
|
addrs = ", ".join(str(s.getsockname()) for s in srv.sockets)
|
||||||
logger.info("listening on %s", addrs)
|
logger.info("listening on %s max_connections=%d", addrs, config.max_connections)
|
||||||
|
|
||||||
if config.chain:
|
if config.chain:
|
||||||
for i, hop in enumerate(config.chain):
|
for i, hop in enumerate(config.chain):
|
||||||
@@ -265,6 +278,7 @@ async def serve(config: Config) -> None:
|
|||||||
return
|
return
|
||||||
config.timeout = new.timeout
|
config.timeout = new.timeout
|
||||||
config.retries = new.retries
|
config.retries = new.retries
|
||||||
|
config.max_connections = new.max_connections
|
||||||
if new.log_level != config.log_level:
|
if new.log_level != config.log_level:
|
||||||
config.log_level = new.log_level
|
config.log_level = new.log_level
|
||||||
level = getattr(logging, new.log_level.upper(), logging.INFO)
|
level = getattr(logging, new.log_level.upper(), logging.INFO)
|
||||||
@@ -289,6 +303,8 @@ async def serve(config: Config) -> None:
|
|||||||
async with srv:
|
async with srv:
|
||||||
sig = await stop
|
sig = await stop
|
||||||
logger.info("received %s, shutting down", signal.Signals(sig).name)
|
logger.info("received %s, shutting down", signal.Signals(sig).name)
|
||||||
|
if hop_pool:
|
||||||
|
await hop_pool.stop()
|
||||||
if isinstance(proxy_pool, ProxyPool):
|
if isinstance(proxy_pool, ProxyPool):
|
||||||
await proxy_pool.stop()
|
await proxy_pool.stop()
|
||||||
shutdown_line = metrics.summary()
|
shutdown_line = metrics.summary()
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from s5p.config import ChainHop, Config, parse_proxy_url
|
from s5p.config import ChainHop, Config, load_config, parse_proxy_url
|
||||||
|
|
||||||
|
|
||||||
class TestParseProxyUrl:
|
class TestParseProxyUrl:
|
||||||
@@ -74,3 +74,19 @@ class TestConfig:
|
|||||||
assert c.listen_port == 1080
|
assert c.listen_port == 1080
|
||||||
assert c.chain == []
|
assert c.chain == []
|
||||||
assert c.timeout == 10.0
|
assert c.timeout == 10.0
|
||||||
|
assert c.max_connections == 256
|
||||||
|
assert c.pool_size == 0
|
||||||
|
assert c.pool_max_idle == 30.0
|
||||||
|
|
||||||
|
def test_max_connections_from_yaml(self, tmp_path):
|
||||||
|
cfg_file = tmp_path / "test.yaml"
|
||||||
|
cfg_file.write_text("max_connections: 512\n")
|
||||||
|
c = load_config(cfg_file)
|
||||||
|
assert c.max_connections == 512
|
||||||
|
|
||||||
|
def test_pool_size_from_yaml(self, tmp_path):
|
||||||
|
cfg_file = tmp_path / "test.yaml"
|
||||||
|
cfg_file.write_text("pool_size: 16\npool_max_idle: 45.0\n")
|
||||||
|
c = load_config(cfg_file)
|
||||||
|
assert c.pool_size == 16
|
||||||
|
assert c.pool_max_idle == 45.0
|
||||||
|
|||||||
Reference in New Issue
Block a user