"""SOCKS5 proxy server with proxy-chain support.""" from __future__ import annotations import asyncio import logging import signal import struct import time from .api import start_api from .config import Config, load_config from .connpool import FirstHopPool from .metrics import Metrics from .pool import ProxyPool from .proto import ProtoError, Socks5Reply, build_chain, read_socks5_address logger = logging.getLogger("s5p") BUFFER_SIZE = 65536 # -- relay ------------------------------------------------------------------- async def _relay( reader: asyncio.StreamReader, writer: asyncio.StreamWriter, ) -> int: """Unidirectional data relay. Returns total bytes transferred.""" total = 0 try: while True: data = await reader.read(BUFFER_SIZE) if not data: break total += len(data) writer.write(data) await writer.drain() except (ConnectionError, asyncio.CancelledError, OSError): pass finally: try: writer.close() await writer.wait_closed() except (OSError, ConnectionError): pass return total # -- SOCKS5 server ----------------------------------------------------------- def _socks5_reply(rep: int) -> bytes: """Build a minimal SOCKS5 reply packet.""" return struct.pack("!BBB", 0x05, rep, 0x00) + b"\x01\x00\x00\x00\x00\x00\x00" async def _handle_client( client_reader: asyncio.StreamReader, client_writer: asyncio.StreamWriter, config: Config, proxy_pool: ProxyPool | None = None, metrics: Metrics | None = None, first_hop_pool: FirstHopPool | None = None, ) -> None: """Handle a single SOCKS5 client connection.""" peer = client_writer.get_extra_info("peername") tag = f"{peer[0]}:{peer[1]}" if peer else "?" if metrics: metrics.connections += 1 try: # -- greeting -- header = await asyncio.wait_for(client_reader.readexactly(2), timeout=10.0) if header[0] != 0x05: logger.warning("[%s] bad socks version: %d", tag, header[0]) return methods = await client_reader.readexactly(header[1]) if 0x00 not in methods: client_writer.write(b"\x05\xff") await client_writer.drain() return client_writer.write(b"\x05\x00") await client_writer.drain() # -- connect request -- req = await asyncio.wait_for(client_reader.readexactly(3), timeout=10.0) if req[0] != 0x05: return if req[1] != 0x01: client_writer.write(_socks5_reply(Socks5Reply.COMMAND_NOT_SUPPORTED)) await client_writer.drain() return target_host, target_port = await read_socks5_address(client_reader) logger.info("[%s] connect %s:%d", tag, target_host, target_port) # -- build chain (with retry) -- attempts = config.retries if proxy_pool else 1 last_err: Exception | None = None for attempt in range(attempts): effective_chain = list(config.chain) pool_hop = None if proxy_pool: pool_hop = await proxy_pool.get() if pool_hop: effective_chain.append(pool_hop) logger.debug("[%s] +proxy %s", tag, pool_hop) try: t0 = time.monotonic() remote_reader, remote_writer = await build_chain( effective_chain, target_host, target_port, timeout=config.timeout, first_hop_pool=first_hop_pool, ) dt = time.monotonic() - t0 logger.debug("[%s] chain up in %.0fms", tag, dt * 1000) break except (ProtoError, TimeoutError, ConnectionError, OSError) as e: last_err = e if pool_hop and proxy_pool: proxy_pool.report_failure(pool_hop) if metrics: metrics.retries += 1 if attempt + 1 < attempts: logger.debug("[%s] attempt %d/%d failed: %s", tag, attempt + 1, attempts, e) continue raise last_err # -- success -- if metrics: metrics.success += 1 metrics.active += 1 client_writer.write(_socks5_reply(Socks5Reply.SUCCEEDED)) await client_writer.drain() # -- relay -- try: bytes_up, bytes_down = await asyncio.gather( _relay(client_reader, remote_writer), _relay(remote_reader, client_writer), ) if metrics: metrics.bytes_in += bytes_up metrics.bytes_out += bytes_down finally: if metrics: metrics.active -= 1 except ProtoError as e: logger.warning("[%s] %s", tag, e) if metrics: metrics.failed += 1 try: client_writer.write(_socks5_reply(e.reply)) await client_writer.drain() except OSError: pass except TimeoutError: logger.warning("[%s] timeout", tag) if metrics: metrics.failed += 1 try: client_writer.write(_socks5_reply(Socks5Reply.TTL_EXPIRED)) await client_writer.drain() except OSError: pass except (ConnectionError, OSError) as e: logger.debug("[%s] %s", tag, e) if metrics: metrics.failed += 1 try: client_writer.write(_socks5_reply(Socks5Reply.CONNECTION_REFUSED)) await client_writer.drain() except OSError: pass except Exception: logger.exception("[%s] unexpected error", tag) if metrics: metrics.failed += 1 finally: try: client_writer.close() await client_writer.wait_closed() except OSError: pass # -- entry point ------------------------------------------------------------- async def _metrics_logger( metrics: Metrics, stop: asyncio.Event, pool: ProxyPool | None = None, ) -> None: """Log metrics summary every 60 seconds.""" while not stop.is_set(): try: await asyncio.wait_for(stop.wait(), timeout=60.0) except TimeoutError: pass if not stop.is_set(): line = metrics.summary() if pool: line += f" pool={pool.alive_count}/{pool.count}" logger.info("metrics: %s", line) async def serve(config: Config) -> None: """Start the SOCKS5 proxy server.""" # register signal handlers early so SIGTERM is never ignored loop = asyncio.get_running_loop() stop = loop.create_future() for sig in (signal.SIGTERM, signal.SIGINT): loop.add_signal_handler(sig, lambda s=sig: stop.set_result(s)) metrics = Metrics() proxy_pool: ProxyPool | None = None if config.proxy_pool and config.proxy_pool.sources: proxy_pool = ProxyPool(config.proxy_pool, config.chain, config.timeout) await proxy_pool.start() hop_pool: FirstHopPool | None = None if config.pool_size > 0 and config.chain: hop_pool = FirstHopPool( config.chain[0], size=config.pool_size, max_idle=config.pool_max_idle, ) await hop_pool.start() sem = asyncio.Semaphore(config.max_connections) async def on_client(r: asyncio.StreamReader, w: asyncio.StreamWriter) -> None: async with sem: await _handle_client(r, w, config, proxy_pool, metrics, hop_pool) srv = await asyncio.start_server(on_client, config.listen_host, config.listen_port) addrs = ", ".join(str(s.getsockname()) for s in srv.sockets) logger.info("listening on %s max_connections=%d", addrs, config.max_connections) if config.chain: for i, hop in enumerate(config.chain): logger.info(" chain[%d] %s", i, hop) else: logger.info(" mode: direct (no chain)") if proxy_pool: nsrc = len(config.proxy_pool.sources) logger.info( " pool: %d proxies, %d alive (from %d source%s)", proxy_pool.count, proxy_pool.alive_count, nsrc, "s" if nsrc != 1 else "", ) logger.info(" retries: %d", config.retries) # -- control API --------------------------------------------------------- api_srv: asyncio.Server | None = None if config.api_port: api_ctx: dict = { "config": config, "metrics": metrics, "pool": proxy_pool, "hop_pool": hop_pool, } # SIGHUP: hot-reload config (timeout, retries, log_level, pool settings) async def _reload() -> None: if not config.config_file: logger.warning("reload: no config file specified, ignoring SIGHUP") return try: new = load_config(config.config_file) except Exception as e: logger.warning("reload: failed to read config: %s", e) return config.timeout = new.timeout config.retries = new.retries config.max_connections = new.max_connections if new.log_level != config.log_level: config.log_level = new.log_level level = getattr(logging, new.log_level.upper(), logging.INFO) root = logging.getLogger() root.setLevel(level) for h in root.handlers: h.setLevel(level) logging.getLogger("s5p").setLevel(level) if proxy_pool and new.proxy_pool: await proxy_pool.reload(new.proxy_pool) logger.info("reload: config reloaded") def _on_sighup() -> None: asyncio.create_task(_reload()) loop.add_signal_handler(signal.SIGHUP, _on_sighup) if config.api_port: api_ctx["reload_fn"] = _reload api_srv = await start_api(config.api_host, config.api_port, api_ctx) metrics_stop = asyncio.Event() pool_ref = proxy_pool metrics_task = asyncio.create_task(_metrics_logger(metrics, metrics_stop, pool_ref)) async with srv: sig = await stop logger.info("received %s, shutting down", signal.Signals(sig).name) if api_srv: api_srv.close() await api_srv.wait_closed() if hop_pool: await hop_pool.stop() if proxy_pool: await proxy_pool.stop() shutdown_line = metrics.summary() if pool_ref: shutdown_line += f" pool={pool_ref.alive_count}/{pool_ref.count}" logger.info("metrics: %s", shutdown_line) metrics_stop.set() await metrics_task