Files
s5p/src/s5p/server.py
user 818169758b feat: add SIGHUP hot config reload
On SIGHUP, re-read the YAML config file and update mutable runtime
settings: timeout, retries, log_level, and pool config (sources,
intervals, thresholds). Pool triggers an immediate source re-fetch.
Listen address and chain require restart.
2026-02-15 16:02:57 +01:00

297 lines
9.7 KiB
Python

"""SOCKS5 proxy server with proxy-chain support."""
from __future__ import annotations
import asyncio
import logging
import signal
import struct
import time
from .config import Config, load_config
from .metrics import Metrics
from .pool import ProxyPool
from .proto import ProtoError, Socks5Reply, build_chain, read_socks5_address
from .source import ProxySource
logger = logging.getLogger("s5p")
BUFFER_SIZE = 65536
# -- relay -------------------------------------------------------------------
async def _relay(
reader: asyncio.StreamReader,
writer: asyncio.StreamWriter,
) -> int:
"""Unidirectional data relay. Returns total bytes transferred."""
total = 0
try:
while True:
data = await reader.read(BUFFER_SIZE)
if not data:
break
total += len(data)
writer.write(data)
await writer.drain()
except (ConnectionError, asyncio.CancelledError, OSError):
pass
finally:
try:
writer.close()
await writer.wait_closed()
except (OSError, ConnectionError):
pass
return total
# -- SOCKS5 server -----------------------------------------------------------
def _socks5_reply(rep: int) -> bytes:
"""Build a minimal SOCKS5 reply packet."""
return struct.pack("!BBB", 0x05, rep, 0x00) + b"\x01\x00\x00\x00\x00\x00\x00"
async def _handle_client(
    client_reader: asyncio.StreamReader,
    client_writer: asyncio.StreamWriter,
    config: Config,
    proxy_pool: ProxyPool | ProxySource | None = None,
    metrics: Metrics | None = None,
) -> None:
    """Handle a single SOCKS5 client connection.

    Runs the SOCKS5 greeting (no-auth only) and CONNECT handshake, builds
    the upstream chain — optionally appending one hop drawn from
    *proxy_pool*, retrying up to ``config.retries`` times — then relays
    data in both directions until either side closes.

    Failures are reported to the client with the closest SOCKS5 reply
    code and counted in *metrics* when provided.  The client connection
    is always closed on exit.
    """
    peer = client_writer.get_extra_info("peername")
    tag = f"{peer[0]}:{peer[1]}" if peer else "?"
    if metrics:
        metrics.connections += 1
    try:
        # -- greeting --
        header = await asyncio.wait_for(client_reader.readexactly(2), timeout=10.0)
        if header[0] != 0x05:
            logger.warning("[%s] bad socks version: %d", tag, header[0])
            return
        # Timeout here too: a client that sends only the 2-byte header must
        # not be able to hold the handler forever.
        methods = await asyncio.wait_for(
            client_reader.readexactly(header[1]), timeout=10.0
        )
        if 0x00 not in methods:
            # Only "no authentication required" is supported.
            client_writer.write(b"\x05\xff")
            await client_writer.drain()
            return
        client_writer.write(b"\x05\x00")
        await client_writer.drain()
        # -- connect request --
        req = await asyncio.wait_for(client_reader.readexactly(3), timeout=10.0)
        if req[0] != 0x05:
            return
        if req[1] != 0x01:  # only CONNECT is implemented
            client_writer.write(_socks5_reply(Socks5Reply.COMMAND_NOT_SUPPORTED))
            await client_writer.drain()
            return
        # Bound the address read as well, for the same reason as above.
        target_host, target_port = await asyncio.wait_for(
            read_socks5_address(client_reader), timeout=10.0
        )
        logger.info("[%s] connect %s:%d", tag, target_host, target_port)
        # -- build chain (with retry) --
        # Clamp to at least one attempt: with config.retries <= 0 the loop
        # below would never run and remote_reader/remote_writer would be
        # referenced unbound.
        attempts = max(1, config.retries) if proxy_pool else 1
        last_err: Exception | None = None
        for attempt in range(attempts):
            effective_chain = list(config.chain)
            pool_hop = None
            if proxy_pool:
                pool_hop = await proxy_pool.get()
                if pool_hop:
                    effective_chain.append(pool_hop)
                    logger.debug("[%s] +proxy %s", tag, pool_hop)
            try:
                t0 = time.monotonic()
                remote_reader, remote_writer = await build_chain(
                    effective_chain, target_host, target_port, timeout=config.timeout
                )
                dt = time.monotonic() - t0
                logger.debug("[%s] chain up in %.0fms", tag, dt * 1000)
                break
            except (ProtoError, asyncio.TimeoutError, ConnectionError, OSError) as e:
                last_err = e
                if pool_hop and isinstance(proxy_pool, ProxyPool):
                    # Only ProxyPool tracks per-proxy health.
                    proxy_pool.report_failure(pool_hop)
                if metrics:
                    metrics.retries += 1
                if attempt + 1 < attempts:
                    logger.debug("[%s] attempt %d/%d failed: %s", tag, attempt + 1, attempts, e)
                    continue
                raise last_err
        # -- success --
        if metrics:
            metrics.success += 1
            metrics.active += 1
        client_writer.write(_socks5_reply(Socks5Reply.SUCCEEDED))
        await client_writer.drain()
        # -- relay --
        try:
            # client->remote is "up" (bytes_in), remote->client is "down".
            bytes_up, bytes_down = await asyncio.gather(
                _relay(client_reader, remote_writer),
                _relay(remote_reader, client_writer),
            )
            if metrics:
                metrics.bytes_in += bytes_up
                metrics.bytes_out += bytes_down
        finally:
            if metrics:
                metrics.active -= 1
    except ProtoError as e:
        # Protocol-level failure: relay the specific SOCKS5 reply code.
        logger.warning("[%s] %s", tag, e)
        if metrics:
            metrics.failed += 1
        try:
            client_writer.write(_socks5_reply(e.reply))
            await client_writer.drain()
        except OSError:
            pass
    except asyncio.TimeoutError:
        logger.warning("[%s] timeout", tag)
        if metrics:
            metrics.failed += 1
        try:
            client_writer.write(_socks5_reply(Socks5Reply.TTL_EXPIRED))
            await client_writer.drain()
        except OSError:
            pass
    except (ConnectionError, OSError) as e:
        logger.debug("[%s] %s", tag, e)
        if metrics:
            metrics.failed += 1
        try:
            client_writer.write(_socks5_reply(Socks5Reply.CONNECTION_REFUSED))
            await client_writer.drain()
        except OSError:
            pass
    except Exception:
        # Last-resort guard so one bad connection cannot kill the server.
        logger.exception("[%s] unexpected error", tag)
        if metrics:
            metrics.failed += 1
    finally:
        try:
            client_writer.close()
            await client_writer.wait_closed()
        except OSError:
            pass
# -- entry point -------------------------------------------------------------
async def _metrics_logger(
metrics: Metrics,
stop: asyncio.Event,
pool: ProxyPool | None = None,
) -> None:
"""Log metrics summary every 60 seconds."""
while not stop.is_set():
try:
await asyncio.wait_for(stop.wait(), timeout=60.0)
except asyncio.TimeoutError:
pass
if not stop.is_set():
line = metrics.summary()
if pool:
line += f" pool={pool.alive_count}/{pool.count}"
logger.info("metrics: %s", line)
async def serve(config: Config) -> None:
    """Start the SOCKS5 proxy server and run until SIGTERM/SIGINT.

    Sets up the optional proxy pool/source, installs signal handlers
    (SIGHUP hot-reloads timeout, retries, log_level and pool settings;
    listen address and chain require a restart), starts the periodic
    metrics logger, and performs an orderly shutdown when a stop signal
    arrives.
    """
    metrics = Metrics()
    proxy_pool: ProxyPool | ProxySource | None = None
    if config.proxy_pool and config.proxy_pool.sources:
        pool = ProxyPool(config.proxy_pool, config.chain, config.timeout)
        await pool.start()
        proxy_pool = pool
    elif config.proxy_source and config.proxy_source.url:
        proxy_pool = ProxySource(config.proxy_source)
        await proxy_pool.start()

    async def on_client(r: asyncio.StreamReader, w: asyncio.StreamWriter) -> None:
        await _handle_client(r, w, config, proxy_pool, metrics)

    srv = await asyncio.start_server(on_client, config.listen_host, config.listen_port)
    addrs = ", ".join(str(s.getsockname()) for s in srv.sockets)
    logger.info("listening on %s", addrs)
    if config.chain:
        for i, hop in enumerate(config.chain):
            logger.info(" chain[%d] %s", i, hop)
    else:
        logger.info(" mode: direct (no chain)")
    if isinstance(proxy_pool, ProxyPool):
        nsrc = len(config.proxy_pool.sources)
        logger.info(
            " pool: %d proxies, %d alive (from %d source%s)",
            proxy_pool.count, proxy_pool.alive_count, nsrc, "s" if nsrc != 1 else "",
        )
        logger.info(" retries: %d", config.retries)
    elif proxy_pool:
        logger.info(" proxy source: %s (%d proxies)", config.proxy_source.url, proxy_pool.count)
        logger.info(" retries: %d", config.retries)

    loop = asyncio.get_running_loop()
    stop: asyncio.Future[int] = loop.create_future()

    def _on_stop_signal(sig: int) -> None:
        # A repeated SIGTERM/SIGINT must not crash the handler: the
        # future can only be resolved once (InvalidStateError otherwise).
        if not stop.done():
            stop.set_result(sig)

    for sig in (signal.SIGTERM, signal.SIGINT):
        loop.add_signal_handler(sig, _on_stop_signal, sig)

    # SIGHUP: hot-reload config (timeout, retries, log_level, pool settings)
    async def _reload() -> None:
        if not config.config_file:
            logger.warning("reload: no config file specified, ignoring SIGHUP")
            return
        try:
            new = load_config(config.config_file)
        except Exception as e:
            # Best-effort: keep running with the old config on any parse error.
            logger.warning("reload: failed to read config: %s", e)
            return
        config.timeout = new.timeout
        config.retries = new.retries
        if new.log_level != config.log_level:
            config.log_level = new.log_level
            logging.getLogger("s5p").setLevel(
                getattr(logging, new.log_level.upper(), logging.INFO),
            )
        if isinstance(proxy_pool, ProxyPool) and new.proxy_pool:
            await proxy_pool.reload(new.proxy_pool)
        logger.info("reload: config reloaded")

    # The event loop keeps only weak references to tasks; hold a strong
    # reference so an in-flight reload task cannot be garbage-collected.
    reload_tasks: set[asyncio.Task] = set()

    def _on_sighup() -> None:
        task = asyncio.ensure_future(_reload())
        reload_tasks.add(task)
        task.add_done_callback(reload_tasks.discard)

    loop.add_signal_handler(signal.SIGHUP, _on_sighup)

    metrics_stop = asyncio.Event()
    pool_ref = proxy_pool if isinstance(proxy_pool, ProxyPool) else None
    metrics_task = asyncio.create_task(_metrics_logger(metrics, metrics_stop, pool_ref))
    async with srv:
        sig = await stop
        logger.info("received %s, shutting down", signal.Signals(sig).name)
    # Server is closed; drain the pool and log a final metrics line.
    if isinstance(proxy_pool, ProxyPool):
        await proxy_pool.stop()
    shutdown_line = metrics.summary()
    if pool_ref:
        shutdown_line += f" pool={pool_ref.alive_count}/{pool_ref.count}"
    logger.info("metrics: %s", shutdown_line)
    metrics_stop.set()
    await metrics_task