Files
s5p/src/s5p/server.py
T
user b72d083f56 feat: wire control API into server and config
Add api_host/api_port to Config dataclass, parse api_listen key in
load_config(), add --api [HOST:]PORT CLI flag. Start/stop API server
in serve() alongside the SOCKS5 listener.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 19:03:44 +01:00

327 lines
10 KiB
Python

"""SOCKS5 proxy server with proxy-chain support."""
from __future__ import annotations
import asyncio
import logging
import signal
import struct
import time
from .api import start_api
from .config import Config, load_config
from .connpool import FirstHopPool
from .metrics import Metrics
from .pool import ProxyPool
from .proto import ProtoError, Socks5Reply, build_chain, read_socks5_address
# Every log record from this module goes through the "s5p" logger.
logger = logging.getLogger("s5p")
# Chunk size, in bytes, for each read in the relay loop (64 KiB).
BUFFER_SIZE = 64 * 1024
# -- relay -------------------------------------------------------------------
async def _relay(
    reader: asyncio.StreamReader,
    writer: asyncio.StreamWriter,
) -> int:
    """Copy bytes from *reader* to *writer* until EOF or a socket error.

    *writer* is closed on every exit path. Closing a StreamWriter closes its
    transport, which makes that connection's reader hit EOF (or an error), so
    a relay running in the opposite direction over it ends as well.

    Returns:
        The number of bytes copied.

    Raises:
        asyncio.CancelledError: propagated (after *writer* has been closed)
            when the task is cancelled, rather than being turned into a
            normal return.
    """
    total = 0
    try:
        while True:
            data = await reader.read(BUFFER_SIZE)
            if not data:  # EOF
                break
            total += len(data)
            writer.write(data)
            await writer.drain()  # wait while the write buffer is full
    except OSError:
        # Covers ConnectionError (reset, broken pipe): the peer went away,
        # which simply ends this direction of the session. CancelledError is
        # deliberately not caught so cancellation reaches the awaiting task;
        # the finally block below still closes the writer.
        pass
    finally:
        try:
            writer.close()
            await writer.wait_closed()
        except OSError:
            pass
    return total
# -- SOCKS5 server -----------------------------------------------------------
def _socks5_reply(rep: int) -> bytes:
    """Return a 10-byte SOCKS5 reply carrying status code *rep*.

    The bound address in the reply is always IPv4 0.0.0.0, port 0.
    """
    # VER=5, REP, RSV=0, ATYP=1 (IPv4), BND.ADDR (4 bytes), BND.PORT (2 bytes)
    return struct.pack("!BBBB4sH", 0x05, rep, 0x00, 0x01, bytes(4), 0)
# Upper bound, in seconds, on each read of the SOCKS5 handshake. Without it a
# client that stalls mid-handshake would hold a connection slot indefinitely.
_HANDSHAKE_TIMEOUT = 10.0


async def _send_reply_quietly(writer: asyncio.StreamWriter, rep: int) -> None:
    """Best-effort: send a SOCKS5 reply, ignoring a client that already left."""
    try:
        writer.write(_socks5_reply(rep))
        await writer.drain()
    except OSError:
        pass


async def _socks5_handshake(
    reader: asyncio.StreamReader,
    writer: asyncio.StreamWriter,
    tag: str,
) -> tuple[str, int] | None:
    """Negotiate the no-auth method and read a CONNECT request.

    Returns the requested ``(host, port)``, or ``None`` when the client was
    turned away (bad version, no acceptable method, unsupported command) and
    the connection should simply be closed.
    """
    header = await asyncio.wait_for(reader.readexactly(2), timeout=_HANDSHAKE_TIMEOUT)
    if header[0] != 0x05:
        logger.warning("[%s] bad socks version: %d", tag, header[0])
        return None
    # header[1] is the number of method bytes that follow
    methods = await asyncio.wait_for(
        reader.readexactly(header[1]), timeout=_HANDSHAKE_TIMEOUT,
    )
    if 0x00 not in methods:
        writer.write(b"\x05\xff")  # no acceptable method
        await writer.drain()
        return None
    writer.write(b"\x05\x00")  # select "no authentication"
    await writer.drain()

    # request: VER, CMD, RSV; the address follows and is parsed separately
    req = await asyncio.wait_for(reader.readexactly(3), timeout=_HANDSHAKE_TIMEOUT)
    if req[0] != 0x05:
        return None
    if req[1] != 0x01:  # only CONNECT is supported
        writer.write(_socks5_reply(Socks5Reply.COMMAND_NOT_SUPPORTED))
        await writer.drain()
        return None
    host, port = await asyncio.wait_for(
        read_socks5_address(reader), timeout=_HANDSHAKE_TIMEOUT,
    )
    return host, port


async def _connect_with_retry(
    config: Config,
    target_host: str,
    target_port: int,
    tag: str,
    proxy_pool: ProxyPool | None,
    metrics: Metrics | None,
    first_hop_pool: FirstHopPool | None,
) -> tuple[asyncio.StreamReader, asyncio.StreamWriter]:
    """Build the chain to the target, retrying with fresh pool proxies.

    Each attempt uses ``config.chain`` plus, when the pool hands one out, a
    pool proxy appended as the final hop. Only a proxy pool can change the
    path between attempts, so without one a single attempt is made. A failing
    pool proxy is reported back to the pool; the last attempt's error is
    re-raised.
    """
    # Clamp to >= 1: with retries == 0 the loop would never run and the caller
    # would go on to report success without any remote connection.
    attempts = max(1, config.retries) if proxy_pool else 1
    for attempt in range(1, attempts + 1):
        effective_chain = list(config.chain)
        pool_hop = None
        if proxy_pool:
            pool_hop = await proxy_pool.get()
            if pool_hop:
                effective_chain.append(pool_hop)
                logger.debug("[%s] +proxy %s", tag, pool_hop)
        t0 = time.monotonic()
        try:
            remote = await build_chain(
                effective_chain, target_host, target_port,
                timeout=config.timeout, first_hop_pool=first_hop_pool,
            )
        # asyncio.TimeoutError differs from the builtin before Python 3.11
        except (ProtoError, TimeoutError, asyncio.TimeoutError, ConnectionError, OSError) as e:
            if pool_hop and proxy_pool:
                proxy_pool.report_failure(pool_hop)
            if metrics:
                metrics.retries += 1
            if attempt == attempts:
                raise
            logger.debug("[%s] attempt %d/%d failed: %s", tag, attempt, attempts, e)
            continue
        logger.debug("[%s] chain up in %.0fms", tag, (time.monotonic() - t0) * 1000)
        return remote
    raise AssertionError("unreachable: attempts is always >= 1")


async def _handle_client(
    client_reader: asyncio.StreamReader,
    client_writer: asyncio.StreamWriter,
    config: Config,
    proxy_pool: ProxyPool | None = None,
    metrics: Metrics | None = None,
    first_hop_pool: FirstHopPool | None = None,
) -> None:
    """Serve one SOCKS5 client: handshake, build the chain, relay, clean up.

    Only the no-auth method and the CONNECT command are accepted. Failures
    are counted in *metrics* and, where the protocol allows, answered with a
    SOCKS5 error reply. The client connection is always closed on return.

    Args:
        client_reader: stream of the accepted client socket.
        client_writer: writer for the accepted client socket.
        config: server configuration; chain, retries and timeout are read on
            every connection.
        proxy_pool: optional pool that supplies an extra final hop per
            attempt; enables retries.
        metrics: optional counters, updated in place.
        first_hop_pool: optional FirstHopPool passed through to build_chain.
    """
    peer = client_writer.get_extra_info("peername")
    tag = f"{peer[0]}:{peer[1]}" if peer else "?"
    if metrics:
        metrics.connections += 1
    try:
        target = await _socks5_handshake(client_reader, client_writer, tag)
        if target is None:
            return
        target_host, target_port = target
        logger.info("[%s] connect %s:%d", tag, target_host, target_port)
        remote_reader, remote_writer = await _connect_with_retry(
            config, target_host, target_port, tag,
            proxy_pool, metrics, first_hop_pool,
        )

        # -- success --
        if metrics:
            metrics.success += 1
            metrics.active += 1
        try:
            try:
                client_writer.write(_socks5_reply(Socks5Reply.SUCCEEDED))
                await client_writer.drain()
            except BaseException:
                # The relays close remote_writer when they finish; if the
                # client is already gone they never start, so close it here
                # instead of leaking the upstream connection.
                remote_writer.close()
                raise
            # -- relay --
            bytes_up, bytes_down = await asyncio.gather(
                _relay(client_reader, remote_writer),
                _relay(remote_reader, client_writer),
            )
            if metrics:
                metrics.bytes_in += bytes_up
                metrics.bytes_out += bytes_down
        finally:
            # Paired with the increment above on every exit path, including a
            # failed SUCCEEDED reply, so the active gauge cannot drift upward.
            if metrics:
                metrics.active -= 1
    except ProtoError as e:
        logger.warning("[%s] %s", tag, e)
        if metrics:
            metrics.failed += 1
        await _send_reply_quietly(client_writer, e.reply)
    except (TimeoutError, asyncio.TimeoutError):
        # must precede the OSError clause: builtin TimeoutError subclasses it
        logger.warning("[%s] timeout", tag)
        if metrics:
            metrics.failed += 1
        await _send_reply_quietly(client_writer, Socks5Reply.TTL_EXPIRED)
    except (ConnectionError, OSError) as e:
        logger.debug("[%s] %s", tag, e)
        if metrics:
            metrics.failed += 1
        await _send_reply_quietly(client_writer, Socks5Reply.CONNECTION_REFUSED)
    except Exception:
        logger.exception("[%s] unexpected error", tag)
        if metrics:
            metrics.failed += 1
    finally:
        try:
            client_writer.close()
            await client_writer.wait_closed()
        except OSError:
            pass
# -- entry point -------------------------------------------------------------
async def _metrics_logger(
    metrics: Metrics,
    stop: asyncio.Event,
    pool: ProxyPool | None = None,
) -> None:
    """Log a metrics summary line every 60 seconds until *stop* is set.

    Args:
        metrics: counters whose ``summary()`` string is logged.
        stop: event that ends the loop; setting it wakes this task at once
            instead of waiting out the current interval.
        pool: optional proxy pool; when given, ``pool=alive/total`` is
            appended to each line.
    """
    while not stop.is_set():
        try:
            await asyncio.wait_for(stop.wait(), timeout=60.0)
        except (TimeoutError, asyncio.TimeoutError):
            # Interval elapsed with no stop request. Before Python 3.11
            # asyncio.wait_for raised asyncio.TimeoutError, which is not the
            # builtin TimeoutError; catching both keeps this task alive there.
            pass
        if stop.is_set():
            break
        line = metrics.summary()
        if pool:
            line += f" pool={pool.alive_count}/{pool.count}"
        logger.info("metrics: %s", line)
async def _reload_config(config: Config, proxy_pool: ProxyPool | None) -> None:
    """Re-read ``config.config_file`` and apply hot-reloadable settings in place.

    Applied: timeout, retries, max_connections, log_level and, when both the
    running server and the new file have one, the proxy-pool settings. If the
    file cannot be read, the error is logged and the running configuration is
    left untouched.
    """
    if not config.config_file:
        logger.warning("reload: no config file specified, ignoring SIGHUP")
        return
    try:
        new = load_config(config.config_file)
    except Exception as e:  # any read/parse error: keep the old config
        logger.warning("reload: failed to read config: %s", e)
        return
    config.timeout = new.timeout
    config.retries = new.retries
    # NOTE: serve() sizes its connection semaphore once at startup and nothing
    # here resizes it, so a changed value is recorded but not enforced until
    # restart.
    config.max_connections = new.max_connections
    if new.log_level != config.log_level:
        config.log_level = new.log_level
        level = getattr(logging, new.log_level.upper(), logging.INFO)
        root = logging.getLogger()
        root.setLevel(level)
        for handler in root.handlers:
            handler.setLevel(level)
        logger.setLevel(level)
    if proxy_pool and new.proxy_pool:
        await proxy_pool.reload(new.proxy_pool)
    logger.info("reload: config reloaded")


def _log_startup(
    srv: asyncio.Server, config: Config, proxy_pool: ProxyPool | None,
) -> None:
    """Log the listening address(es), chain layout and pool/retry settings."""
    addrs = ", ".join(str(s.getsockname()) for s in srv.sockets)
    logger.info("listening on %s max_connections=%d", addrs, config.max_connections)
    if config.chain:
        for i, hop in enumerate(config.chain):
            logger.info(" chain[%d] %s", i, hop)
    else:
        logger.info(" mode: direct (no chain)")
    if proxy_pool:
        nsrc = len(config.proxy_pool.sources)
        logger.info(
            " pool: %d proxies, %d alive (from %d source%s)",
            proxy_pool.count, proxy_pool.alive_count, nsrc, "s" if nsrc != 1 else "",
        )
    logger.info(" retries: %d", config.retries)


async def _cancel_tasks(tasks: set[asyncio.Task[None]]) -> None:
    """Cancel a snapshot of *tasks* and wait until all of them have finished."""
    pending = list(tasks)
    for task in pending:
        task.cancel()
    await asyncio.gather(*pending, return_exceptions=True)


async def serve(config: Config) -> None:
    """Run the SOCKS5 proxy until SIGTERM or SIGINT is received.

    Starts the optional proxy pool and first-hop pool, listens on
    ``config.listen_host:config.listen_port``, starts the control API when
    ``config.api_port`` is set, re-reads the config file on SIGHUP and logs
    a metrics summary every minute. On shutdown, live client sessions are
    cancelled, then the API server and pools are stopped. Pools and the
    listener are also released when startup fails part-way.
    """
    # register signal handlers early so SIGTERM is never ignored
    loop = asyncio.get_running_loop()
    stop: asyncio.Future[signal.Signals] = loop.create_future()

    def _request_stop(signum: signal.Signals) -> None:
        # Ignore repeats (e.g. a second Ctrl-C during shutdown): set_result()
        # on an already-finished future raises InvalidStateError.
        if not stop.done():
            stop.set_result(signum)

    for signum in (signal.SIGTERM, signal.SIGINT):
        loop.add_signal_handler(signum, _request_stop, signum)

    metrics = Metrics()
    proxy_pool: ProxyPool | None = None
    hop_pool: FirstHopPool | None = None
    api_srv: asyncio.Server | None = None
    metrics_stop = asyncio.Event()
    metrics_task: asyncio.Task[None] | None = None
    # The event loop keeps only weak references to tasks, so hold strong ones
    # for in-flight SIGHUP reloads and for client sessions (cancelled on exit).
    reload_tasks: set[asyncio.Task[None]] = set()
    client_tasks: set[asyncio.Task[None]] = set()

    try:
        if config.proxy_pool and config.proxy_pool.sources:
            pool = ProxyPool(config.proxy_pool, config.chain, config.timeout)
            await pool.start()
            proxy_pool = pool  # assigned only once started, so cleanup is safe
        if config.pool_size > 0 and config.chain:
            first_hop = FirstHopPool(
                config.chain[0], size=config.pool_size, max_idle=config.pool_max_idle,
            )
            await first_hop.start()
            hop_pool = first_hop

        sem = asyncio.Semaphore(config.max_connections)

        async def on_client(r: asyncio.StreamReader, w: asyncio.StreamWriter) -> None:
            task = asyncio.current_task()
            if task is not None:
                client_tasks.add(task)
            try:
                async with sem:
                    await _handle_client(r, w, config, proxy_pool, metrics, hop_pool)
            finally:
                client_tasks.discard(task)

        srv = await asyncio.start_server(on_client, config.listen_host, config.listen_port)
        async with srv:
            try:
                _log_startup(srv, config, proxy_pool)

                # -- hot reload: SIGHUP and the API's reload_fn --------------
                async def _reload() -> None:
                    await _reload_config(config, proxy_pool)

                def _reload_done(task: asyncio.Task[None]) -> None:
                    reload_tasks.discard(task)
                    if not task.cancelled() and task.exception() is not None:
                        logger.warning("reload: failed: %s", task.exception())

                def _on_sighup() -> None:
                    task = asyncio.create_task(_reload())
                    reload_tasks.add(task)
                    task.add_done_callback(_reload_done)

                loop.add_signal_handler(signal.SIGHUP, _on_sighup)

                # -- control API ---------------------------------------------
                if config.api_port:
                    api_ctx: dict = {
                        "config": config,
                        "metrics": metrics,
                        "pool": proxy_pool,
                        "hop_pool": hop_pool,
                        "reload_fn": _reload,
                    }
                    api_srv = await start_api(config.api_host, config.api_port, api_ctx)

                metrics_task = asyncio.create_task(
                    _metrics_logger(metrics, metrics_stop, proxy_pool),
                )
                sig = await stop
                logger.info("received %s, shutting down", signal.Signals(sig).name)
            finally:
                # Stop accepting, then cancel live sessions: on recent Python
                # (3.12+) Server.wait_closed(), run when this `async with`
                # exits, also waits for accepted connections, so long-lived
                # relays would otherwise stall shutdown.
                srv.close()
                await _cancel_tasks(client_tasks)
    finally:
        if api_srv is not None:
            api_srv.close()
            await api_srv.wait_closed()
        if hop_pool is not None:
            await hop_pool.stop()
        if proxy_pool is not None:
            await proxy_pool.stop()
        if metrics_task is not None:
            shutdown_line = metrics.summary()
            if proxy_pool:
                shutdown_line += f" pool={proxy_pool.alive_count}/{proxy_pool.count}"
            logger.info("metrics: %s", shutdown_line)
            metrics_stop.set()
            await metrics_task