On SIGHUP, re-read the YAML config file and apply its mutable runtime settings: timeout, retries, log_level, and pool config (sources, intervals, thresholds). A pool reload triggers an immediate re-fetch of the proxy sources. Changes to the listen address or the chain still require a restart.
297 lines
9.7 KiB
Python
297 lines
9.7 KiB
Python
"""SOCKS5 proxy server with proxy-chain support."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import logging
|
|
import signal
|
|
import struct
|
|
import time
|
|
|
|
from .config import Config, load_config
|
|
from .metrics import Metrics
|
|
from .pool import ProxyPool
|
|
from .proto import ProtoError, Socks5Reply, build_chain, read_socks5_address
|
|
from .source import ProxySource
|
|
|
|
logger = logging.getLogger("s5p")

# Maximum number of bytes read per chunk when relaying (64 KiB).
BUFFER_SIZE = 64 * 1024
|
|
|
|
|
|
# -- relay -------------------------------------------------------------------
|
|
|
|
|
|
async def _relay(
|
|
reader: asyncio.StreamReader,
|
|
writer: asyncio.StreamWriter,
|
|
) -> int:
|
|
"""Unidirectional data relay. Returns total bytes transferred."""
|
|
total = 0
|
|
try:
|
|
while True:
|
|
data = await reader.read(BUFFER_SIZE)
|
|
if not data:
|
|
break
|
|
total += len(data)
|
|
writer.write(data)
|
|
await writer.drain()
|
|
except (ConnectionError, asyncio.CancelledError, OSError):
|
|
pass
|
|
finally:
|
|
try:
|
|
writer.close()
|
|
await writer.wait_closed()
|
|
except (OSError, ConnectionError):
|
|
pass
|
|
return total
|
|
|
|
|
|
# -- SOCKS5 server -----------------------------------------------------------
|
|
|
|
|
|
def _socks5_reply(rep: int) -> bytes:
|
|
"""Build a minimal SOCKS5 reply packet."""
|
|
return struct.pack("!BBB", 0x05, rep, 0x00) + b"\x01\x00\x00\x00\x00\x00\x00"
|
|
|
|
|
|
async def _send_reply(writer: asyncio.StreamWriter, rep: int) -> None:
    """Best-effort send of a SOCKS5 reply; the client may already be gone."""
    try:
        writer.write(_socks5_reply(rep))
        await writer.drain()
    except OSError:
        pass


async def _connect_with_retry(
    config: Config,
    proxy_pool: ProxyPool | ProxySource | None,
    metrics: Metrics | None,
    tag: str,
    host: str,
    port: int,
) -> tuple[asyncio.StreamReader, asyncio.StreamWriter]:
    """Open ``config.chain`` (plus one pool proxy) to host:port, retrying on failure.

    With a proxy source, up to ``config.retries`` attempts are made (at least
    one), each with a freshly drawn proxy; failed proxies of a ProxyPool are
    reported. Every failed attempt increments ``metrics.retries``. The error
    of the last attempt is re-raised.
    """
    # config.retries can be changed by a SIGHUP reload; with 0 attempts no
    # connection would ever be opened, so always try at least once.
    attempts = max(1, config.retries) if proxy_pool else 1
    attempt = 0
    while True:
        attempt += 1
        effective_chain = list(config.chain)
        pool_hop = await proxy_pool.get() if proxy_pool else None
        if pool_hop:
            effective_chain.append(pool_hop)
            logger.debug("[%s] +proxy %s", tag, pool_hop)

        t0 = time.monotonic()
        try:
            streams = await build_chain(
                effective_chain, host, port, timeout=config.timeout
            )
        except (
            ProtoError,
            asyncio.TimeoutError,
            ConnectionError,
            OSError,
            # An upstream hop that hangs up mid-handshake may surface as
            # IncompleteReadError (from readexactly); treat it as a failed hop.
            asyncio.IncompleteReadError,
        ) as e:
            if pool_hop and isinstance(proxy_pool, ProxyPool):
                proxy_pool.report_failure(pool_hop)
            if metrics:
                metrics.retries += 1
            if attempt >= attempts:
                raise
            logger.debug("[%s] attempt %d/%d failed: %s", tag, attempt, attempts, e)
        else:
            logger.debug("[%s] chain up in %.0fms", tag, (time.monotonic() - t0) * 1000)
            return streams


async def _handle_client(
    client_reader: asyncio.StreamReader,
    client_writer: asyncio.StreamWriter,
    config: Config,
    proxy_pool: ProxyPool | ProxySource | None = None,
    metrics: Metrics | None = None,
) -> None:
    """Serve one SOCKS5 client: negotiate, open the upstream chain, relay.

    Only the "no authentication" method (0x00) and the CONNECT command (0x01)
    are accepted. The upstream is ``config.chain`` plus, when *proxy_pool* is
    given, one proxy drawn from it per attempt (see ``_connect_with_retry``).

    Failures are answered with a SOCKS5 error reply where possible and counted
    in ``metrics.failed``; the client connection is always closed on return.
    """
    handshake_timeout = 10.0  # seconds allowed for each handshake read
    peer = client_writer.get_extra_info("peername")
    tag = f"{peer[0]}:{peer[1]}" if peer else "?"

    if metrics:
        metrics.connections += 1

    try:
        # -- greeting: VER, NMETHODS, METHODS --
        header = await asyncio.wait_for(
            client_reader.readexactly(2), timeout=handshake_timeout
        )
        if header[0] != 0x05:
            logger.warning("[%s] bad socks version: %d", tag, header[0])
            return

        methods = await asyncio.wait_for(
            client_reader.readexactly(header[1]), timeout=handshake_timeout
        )
        if 0x00 not in methods:
            client_writer.write(b"\x05\xff")  # no acceptable method
            await client_writer.drain()
            return

        client_writer.write(b"\x05\x00")
        await client_writer.drain()

        # -- request: VER, CMD, RSV, then the destination address --
        req = await asyncio.wait_for(
            client_reader.readexactly(3), timeout=handshake_timeout
        )
        if req[0] != 0x05:
            return
        if req[1] != 0x01:  # only CONNECT is supported
            client_writer.write(_socks5_reply(Socks5Reply.COMMAND_NOT_SUPPORTED))
            await client_writer.drain()
            return

        target_host, target_port = await asyncio.wait_for(
            read_socks5_address(client_reader), timeout=handshake_timeout
        )
        logger.info("[%s] connect %s:%d", tag, target_host, target_port)

        remote_reader, remote_writer = await _connect_with_retry(
            config, proxy_pool, metrics, tag, target_host, target_port
        )

        # -- success --
        if metrics:
            metrics.success += 1
            metrics.active += 1
        try:
            client_writer.write(_socks5_reply(Socks5Reply.SUCCEEDED))
            await client_writer.drain()

            # -- relay --
            bytes_up, bytes_down = await asyncio.gather(
                _relay(client_reader, remote_writer),
                _relay(remote_reader, client_writer),
            )
            if metrics:
                metrics.bytes_in += bytes_up
                metrics.bytes_out += bytes_down
        finally:
            # Also covers a failed SUCCEEDED reply, which previously skipped
            # both this decrement and the upstream close.
            if metrics:
                metrics.active -= 1
            remote_writer.close()  # idempotent if _relay already closed it

    except ProtoError as e:
        logger.warning("[%s] %s", tag, e)
        if metrics:
            metrics.failed += 1
        await _send_reply(client_writer, e.reply)

    except asyncio.TimeoutError:
        logger.warning("[%s] timeout", tag)
        if metrics:
            metrics.failed += 1
        await _send_reply(client_writer, Socks5Reply.TTL_EXPIRED)

    except (ConnectionError, OSError, asyncio.IncompleteReadError) as e:
        # IncompleteReadError: a peer (client or upstream) closed mid-handshake;
        # an ordinary disconnect, not an "unexpected error" worth a traceback.
        logger.debug("[%s] %s", tag, e)
        if metrics:
            metrics.failed += 1
        await _send_reply(client_writer, Socks5Reply.CONNECTION_REFUSED)

    except Exception:
        logger.exception("[%s] unexpected error", tag)
        if metrics:
            metrics.failed += 1

    finally:
        try:
            client_writer.close()
            await client_writer.wait_closed()
        except OSError:
            pass
|
|
|
|
|
|
# -- entry point -------------------------------------------------------------
|
|
|
|
|
|
async def _metrics_logger(
|
|
metrics: Metrics,
|
|
stop: asyncio.Event,
|
|
pool: ProxyPool | None = None,
|
|
) -> None:
|
|
"""Log metrics summary every 60 seconds."""
|
|
while not stop.is_set():
|
|
try:
|
|
await asyncio.wait_for(stop.wait(), timeout=60.0)
|
|
except asyncio.TimeoutError:
|
|
pass
|
|
if not stop.is_set():
|
|
line = metrics.summary()
|
|
if pool:
|
|
line += f" pool={pool.alive_count}/{pool.count}"
|
|
logger.info("metrics: %s", line)
|
|
|
|
|
|
async def _apply_config_reload(
    config: Config,
    proxy_pool: ProxyPool | ProxySource | None,
) -> None:
    """Re-read ``config.config_file`` and apply its hot-reloadable settings.

    Updates ``timeout``, ``retries`` and ``log_level`` on *config* in place and
    forwards the ``proxy_pool`` section to an active ProxyPool. The listen
    address and chain are not touched (they need a restart). Errors are
    logged; settings applied before a failing pool reload stay applied.
    """
    if not config.config_file:
        logger.warning("reload: no config file specified, ignoring SIGHUP")
        return
    try:
        new = load_config(config.config_file)
    except Exception as e:
        logger.warning("reload: failed to read config: %s", e)
        return

    config.timeout = new.timeout
    config.retries = new.retries
    if new.log_level != config.log_level:
        config.log_level = new.log_level
        logger.setLevel(getattr(logging, new.log_level.upper(), logging.INFO))

    if isinstance(proxy_pool, ProxyPool) and new.proxy_pool:
        try:
            await proxy_pool.reload(new.proxy_pool)
        except Exception:
            # This runs in a detached task; without this the error would only
            # show up later as "Task exception was never retrieved".
            logger.exception("reload: failed to apply pool config")
            return

    logger.info("reload: config reloaded")


def _log_startup(
    config: Config,
    srv: asyncio.Server,
    proxy_pool: ProxyPool | ProxySource | None,
) -> None:
    """Log the listening sockets, the static chain and the proxy source in use."""
    addrs = ", ".join(str(s.getsockname()) for s in srv.sockets)
    logger.info("listening on %s", addrs)

    if config.chain:
        for i, hop in enumerate(config.chain):
            logger.info(" chain[%d] %s", i, hop)
    else:
        logger.info(" mode: direct (no chain)")

    if isinstance(proxy_pool, ProxyPool):
        nsrc = len(config.proxy_pool.sources)
        logger.info(
            " pool: %d proxies, %d alive (from %d source%s)",
            proxy_pool.count, proxy_pool.alive_count, nsrc, "s" if nsrc != 1 else "",
        )
        logger.info(" retries: %d", config.retries)
    elif proxy_pool:
        logger.info(" proxy source: %s (%d proxies)", config.proxy_source.url, proxy_pool.count)
        logger.info(" retries: %d", config.retries)


async def _cancel_tasks(tasks: set[asyncio.Task[None]], timeout: float = 5.0) -> None:
    """Cancel every unfinished task in *tasks* and wait up to *timeout* seconds."""
    pending = [t for t in tasks if not t.done()]  # copy: tasks remove themselves
    for task in pending:
        task.cancel()
    if pending:  # asyncio.wait() rejects an empty collection
        await asyncio.wait(pending, timeout=timeout)


async def serve(config: Config) -> None:
    """Start the SOCKS5 proxy server and run until SIGTERM or SIGINT.

    Starts the optional proxy pool / proxy source, listens on
    ``config.listen_host:config.listen_port`` and hands each connection to
    ``_handle_client``. SIGHUP re-reads the config file and applies timeout,
    retries, log_level and pool settings without a restart. Metrics are
    logged periodically by ``_metrics_logger`` and once more at shutdown.
    """
    metrics = Metrics()

    proxy_pool: ProxyPool | ProxySource | None = None
    if config.proxy_pool and config.proxy_pool.sources:
        pool = ProxyPool(config.proxy_pool, config.chain, config.timeout)
        await pool.start()
        proxy_pool = pool
    elif config.proxy_source and config.proxy_source.url:
        proxy_pool = ProxySource(config.proxy_source)
        await proxy_pool.start()
    pool_ref = proxy_pool if isinstance(proxy_pool, ProxyPool) else None

    # Track live client handlers so shutdown can end them (see below).
    client_tasks: set[asyncio.Task[None]] = set()

    async def on_client(r: asyncio.StreamReader, w: asyncio.StreamWriter) -> None:
        task = asyncio.current_task()
        if task is not None:
            client_tasks.add(task)
        try:
            await _handle_client(r, w, config, proxy_pool, metrics)
        finally:
            client_tasks.discard(task)

    try:
        srv = await asyncio.start_server(on_client, config.listen_host, config.listen_port)
    except OSError:
        # e.g. address already in use: don't leave the started pool running.
        if pool_ref is not None:
            await pool_ref.stop()
        raise

    _log_startup(config, srv, proxy_pool)

    loop = asyncio.get_running_loop()
    stop: asyncio.Future[int] = loop.create_future()

    def _on_stop_signal(signum: int) -> None:
        # A repeated SIGINT/SIGTERM must not call set_result() a second time,
        # which would raise InvalidStateError inside the signal callback.
        if not stop.done():
            stop.set_result(signum)

    for signum in (signal.SIGTERM, signal.SIGINT):
        loop.add_signal_handler(signum, _on_stop_signal, signum)

    # SIGHUP: hot-reload config. Reloads run one at a time, and each task is
    # referenced until done so it cannot be garbage-collected mid-run.
    reload_lock = asyncio.Lock()
    reload_tasks: set[asyncio.Task[None]] = set()

    async def _reload() -> None:
        async with reload_lock:
            await _apply_config_reload(config, proxy_pool)

    def _on_sighup() -> None:
        task = loop.create_task(_reload())
        reload_tasks.add(task)
        task.add_done_callback(reload_tasks.discard)

    loop.add_signal_handler(signal.SIGHUP, _on_sighup)

    metrics_stop = asyncio.Event()
    metrics_task = asyncio.create_task(_metrics_logger(metrics, metrics_stop, pool_ref))

    try:
        async with srv:
            signum = await stop
            logger.info("received %s, shutting down", signal.Signals(signum).name)
            # Stop accepting, then end in-flight sessions: since Python 3.12,
            # leaving `async with srv` (Server.wait_closed) waits for every
            # open connection, so long-lived relays would stall shutdown.
            srv.close()
            await _cancel_tasks(client_tasks)
            if pool_ref is not None:
                await pool_ref.stop()
            shutdown_line = metrics.summary()
            if pool_ref:
                shutdown_line += f" pool={pool_ref.alive_count}/{pool_ref.count}"
            logger.info("metrics: %s", shutdown_line)
    finally:
        metrics_stop.set()
        await metrics_task
        for signum in (signal.SIGTERM, signal.SIGINT, signal.SIGHUP):
            loop.remove_signal_handler(signum)
|