feat: add bypass rules, weighted pool selection, integration tests

Per-listener bypass rules skip the chain for local/private destinations
(CIDR, exact IP/hostname, domain suffix). Weighted multi-candidate pool
selection biases toward pools with more alive proxies. End-to-end
integration tests validate the full client->s5p->hop->target path using
mock SOCKS5 proxies.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
user
2026-02-20 19:58:12 +01:00
parent ef0d8f347b
commit c191942712
11 changed files with 745 additions and 69 deletions

138
tests/conftest.py Normal file
View File

@@ -0,0 +1,138 @@
"""Shared helpers for integration tests."""
from __future__ import annotations
import asyncio
import socket
import struct
from s5p.proto import encode_address, read_socks5_address
def free_port() -> int:
"""Return an available TCP port."""
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("127.0.0.1", 0))
return s.getsockname()[1]
# -- echo server -------------------------------------------------------------
async def _echo_handler(
reader: asyncio.StreamReader, writer: asyncio.StreamWriter,
) -> None:
"""Echo back everything received, then close."""
try:
while True:
data = await reader.read(65536)
if not data:
break
writer.write(data)
await writer.drain()
except (ConnectionError, asyncio.CancelledError):
pass
finally:
writer.close()
await writer.wait_closed()
async def start_echo_server() -> tuple[str, int, asyncio.Server]:
"""Start a TCP echo server. Returns (host, port, server)."""
host = "127.0.0.1"
port = free_port()
srv = await asyncio.start_server(_echo_handler, host, port)
await srv.start_serving()
return host, port, srv
# -- mock SOCKS5 proxy -------------------------------------------------------
async def _mock_socks5_handler(
reader: asyncio.StreamReader, writer: asyncio.StreamWriter,
) -> None:
"""Minimal SOCKS5 proxy: greeting, CONNECT, relay."""
remote_writer = None
try:
# greeting
header = await reader.readexactly(2)
if header[0] != 0x05:
return
await reader.readexactly(header[1]) # skip methods
writer.write(b"\x05\x00")
await writer.drain()
# connect request
req = await reader.readexactly(3)
if req[0] != 0x05 or req[1] != 0x01:
return
target_host, target_port = await read_socks5_address(reader)
# connect to actual target
try:
remote_reader, remote_writer = await asyncio.wait_for(
asyncio.open_connection(target_host, target_port),
timeout=5.0,
)
except (OSError, TimeoutError):
# connection refused reply
reply = struct.pack("!BBB", 0x05, 0x05, 0x00)
reply += b"\x01\x00\x00\x00\x00\x00\x00"
writer.write(reply)
await writer.drain()
return
# success reply
atyp, addr_bytes = encode_address(target_host)
reply = struct.pack("!BBB", 0x05, 0x00, 0x00)
reply += bytes([atyp]) + addr_bytes + struct.pack("!H", target_port)
writer.write(reply)
await writer.drain()
# relay both directions (close dst on EOF so peer sees shutdown)
async def _fwd(src: asyncio.StreamReader, dst: asyncio.StreamWriter) -> None:
try:
while True:
data = await src.read(65536)
if not data:
break
dst.write(data)
await dst.drain()
except (ConnectionError, asyncio.CancelledError):
pass
finally:
try:
dst.close()
await dst.wait_closed()
except OSError:
pass
await asyncio.gather(
_fwd(reader, remote_writer),
_fwd(remote_reader, writer),
)
except (ConnectionError, asyncio.IncompleteReadError, asyncio.CancelledError):
pass
finally:
if remote_writer:
remote_writer.close()
try:
await remote_writer.wait_closed()
except OSError:
pass
writer.close()
try:
await writer.wait_closed()
except OSError:
pass
async def start_mock_socks5() -> tuple[str, int, asyncio.Server]:
"""Start a mock SOCKS5 proxy. Returns (host, port, server)."""
host = "127.0.0.1"
port = free_port()
srv = await asyncio.start_server(_mock_socks5_handler, host, port)
await srv.start_serving()
return host, port, srv

View File

@@ -132,7 +132,7 @@ class TestHandleStatus:
ListenerConfig(
listen_host="0.0.0.0", listen_port=1081,
chain=[ChainHop("socks5", "127.0.0.1", 9050)],
pool_seq=["default"],
pool_seq=[["default"]],
),
],
)
@@ -180,7 +180,7 @@ class TestHandleStatusMultiPool:
ListenerConfig(
listen_host="0.0.0.0", listen_port=1080,
chain=[ChainHop("socks5", "127.0.0.1", 9050)],
pool_seq=["clean", "clean"], pool_name="clean",
pool_seq=[["clean"], ["clean"]], pool_name="clean",
),
],
)
@@ -195,13 +195,13 @@ class TestHandleStatusMultiPool:
ListenerConfig(
listen_host="0.0.0.0", listen_port=1080,
chain=[ChainHop("socks5", "127.0.0.1", 9050)],
pool_seq=["clean", "mitm"], pool_name="clean",
pool_seq=[["clean"], ["mitm"]], pool_name="clean",
),
],
)
ctx = _make_ctx(config=config)
_, body = _handle_status(ctx)
assert body["listeners"][0]["pool_seq"] == ["clean", "mitm"]
assert body["listeners"][0]["pool_seq"] == [["clean"], ["mitm"]]
assert body["listeners"][0]["pool_hops"] == 2
def test_multi_pool_in_config(self):
@@ -211,13 +211,13 @@ class TestHandleStatusMultiPool:
ListenerConfig(
listen_host="0.0.0.0", listen_port=1080,
chain=[ChainHop("socks5", "127.0.0.1", 9050)],
pool_seq=["clean", "mitm"], pool_name="clean",
pool_seq=[["clean"], ["mitm"]], pool_name="clean",
),
],
)
ctx = _make_ctx(config=config)
_, body = _handle_config(ctx)
assert body["listeners"][0]["pool_seq"] == ["clean", "mitm"]
assert body["listeners"][0]["pool_seq"] == [["clean"], ["mitm"]]
class TestHandleMetrics:
@@ -395,7 +395,7 @@ class TestHandleConfig:
listeners=[ListenerConfig(
listen_host="0.0.0.0", listen_port=1080,
chain=[ChainHop("socks5", "127.0.0.1", 9050)],
pool_seq=["clean", "clean"], pool_name="clean",
pool_seq=[["clean"], ["clean"]], pool_name="clean",
)],
)
ctx = _make_ctx(config=config)

View File

@@ -10,6 +10,7 @@ from s5p.config import (
parse_api_proxies,
parse_proxy_url,
)
from s5p.server import _bypass_match
class TestParseProxyUrl:
@@ -411,7 +412,7 @@ class TestPoolSeq:
"""Test per-hop pool references (pool:name syntax)."""
def test_bare_pool_uses_default_name(self, tmp_path):
"""Bare `pool` + `pool: clean` -> pool_seq=["clean"]."""
"""Bare `pool` + `pool: clean` -> pool_seq=[["clean"]]."""
cfg_file = tmp_path / "test.yaml"
cfg_file.write_text(
"listeners:\n"
@@ -421,10 +422,10 @@ class TestPoolSeq:
" - pool\n"
)
c = load_config(cfg_file)
assert c.listeners[0].pool_seq == ["clean"]
assert c.listeners[0].pool_seq == [["clean"]]
def test_bare_pool_no_pool_name(self, tmp_path):
"""Bare `pool` with no `pool:` key -> pool_seq=["default"]."""
"""Bare `pool` with no `pool:` key -> pool_seq=[["default"]]."""
cfg_file = tmp_path / "test.yaml"
cfg_file.write_text(
"listeners:\n"
@@ -433,10 +434,10 @@ class TestPoolSeq:
" - pool\n"
)
c = load_config(cfg_file)
assert c.listeners[0].pool_seq == ["default"]
assert c.listeners[0].pool_seq == [["default"]]
def test_pool_colon_name(self, tmp_path):
"""`pool:clean, pool:mitm` -> pool_seq=["clean", "mitm"]."""
"""`pool:clean, pool:mitm` -> pool_seq=[["clean"], ["mitm"]]."""
cfg_file = tmp_path / "test.yaml"
cfg_file.write_text(
"listeners:\n"
@@ -446,10 +447,10 @@ class TestPoolSeq:
" - pool:mitm\n"
)
c = load_config(cfg_file)
assert c.listeners[0].pool_seq == ["clean", "mitm"]
assert c.listeners[0].pool_seq == [["clean"], ["mitm"]]
def test_mixed_bare_and_named(self, tmp_path):
"""Bare `pool` + `pool:mitm` with `pool: clean` -> ["clean", "mitm"]."""
"""Bare `pool` + `pool:mitm` with `pool: clean` -> [["clean"], ["mitm"]]."""
cfg_file = tmp_path / "test.yaml"
cfg_file.write_text(
"listeners:\n"
@@ -460,10 +461,10 @@ class TestPoolSeq:
" - pool:mitm\n"
)
c = load_config(cfg_file)
assert c.listeners[0].pool_seq == ["clean", "mitm"]
assert c.listeners[0].pool_seq == [["clean"], ["mitm"]]
def test_pool_colon_case_insensitive_prefix(self, tmp_path):
"""`Pool:MyPool` -> pool_seq=["MyPool"] (prefix case-insensitive)."""
"""`Pool:MyPool` -> pool_seq=[["MyPool"]] (prefix case-insensitive)."""
cfg_file = tmp_path / "test.yaml"
cfg_file.write_text(
"listeners:\n"
@@ -472,7 +473,7 @@ class TestPoolSeq:
" - Pool:MyPool\n"
)
c = load_config(cfg_file)
assert c.listeners[0].pool_seq == ["MyPool"]
assert c.listeners[0].pool_seq == [["MyPool"]]
def test_pool_colon_empty_is_bare(self, tmp_path):
"""`pool:` (empty name) -> treated as bare pool."""
@@ -485,17 +486,17 @@ class TestPoolSeq:
" - pool:\n"
)
c = load_config(cfg_file)
assert c.listeners[0].pool_seq == ["clean"]
assert c.listeners[0].pool_seq == [["clean"]]
def test_backward_compat_pool_hops_property(self):
"""pool_hops property returns len(pool_seq)."""
lc = ListenerConfig(pool_seq=["clean", "mitm"])
lc = ListenerConfig(pool_seq=[["clean"], ["mitm"]])
assert lc.pool_hops == 2
lc2 = ListenerConfig()
assert lc2.pool_hops == 0
def test_legacy_auto_append(self, tmp_path):
"""Singular `proxy_pool:` -> pool_seq=["default"]."""
"""Singular `proxy_pool:` -> pool_seq=[["default"]]."""
cfg_file = tmp_path / "test.yaml"
cfg_file.write_text(
"listen: 0.0.0.0:1080\n"
@@ -507,9 +508,26 @@ class TestPoolSeq:
)
c = load_config(cfg_file)
lc = c.listeners[0]
assert lc.pool_seq == ["default"]
assert lc.pool_seq == [["default"]]
assert lc.pool_hops == 1
def test_list_candidates(self, tmp_path):
"""List in chain -> multi-candidate hop."""
cfg_file = tmp_path / "test.yaml"
cfg_file.write_text(
"listeners:\n"
" - listen: 1080\n"
" chain:\n"
" - socks5://tor:9050\n"
" - [pool:clean, pool:mitm]\n"
" - [pool:clean, pool:mitm]\n"
)
c = load_config(cfg_file)
lc = c.listeners[0]
assert len(lc.chain) == 1
assert lc.pool_hops == 2
assert lc.pool_seq == [["clean", "mitm"], ["clean", "mitm"]]
class TestListenerBackwardCompat:
"""Test backward-compatible single listener from old format."""
@@ -573,3 +591,79 @@ class TestListenerPoolCompat:
lc = c.listeners[0]
# explicit listeners: no auto pool_hops
assert lc.pool_hops == 0
class TestBypassConfig:
"""Test bypass rules in listener config."""
def test_bypass_from_yaml(self, tmp_path):
cfg_file = tmp_path / "test.yaml"
cfg_file.write_text(
"listeners:\n"
" - listen: 1080\n"
" bypass:\n"
" - 127.0.0.0/8\n"
" - 192.168.0.0/16\n"
" - localhost\n"
" - .local\n"
" chain:\n"
" - socks5://127.0.0.1:9050\n"
)
c = load_config(cfg_file)
lc = c.listeners[0]
assert lc.bypass == ["127.0.0.0/8", "192.168.0.0/16", "localhost", ".local"]
def test_bypass_empty_default(self):
lc = ListenerConfig()
assert lc.bypass == []
def test_bypass_absent_from_yaml(self, tmp_path):
cfg_file = tmp_path / "test.yaml"
cfg_file.write_text(
"listeners:\n"
" - listen: 1080\n"
" chain:\n"
" - socks5://127.0.0.1:9050\n"
)
c = load_config(cfg_file)
assert c.listeners[0].bypass == []
class TestBypassMatch:
"""Test _bypass_match function."""
def test_cidr_ipv4(self):
assert _bypass_match(["10.0.0.0/8"], "10.1.2.3") is True
assert _bypass_match(["10.0.0.0/8"], "11.0.0.1") is False
def test_cidr_ipv6(self):
assert _bypass_match(["fc00::/7"], "fd00::1") is True
assert _bypass_match(["fc00::/7"], "2001:db8::1") is False
def test_exact_ip(self):
assert _bypass_match(["127.0.0.1"], "127.0.0.1") is True
assert _bypass_match(["127.0.0.1"], "127.0.0.2") is False
def test_exact_hostname(self):
assert _bypass_match(["localhost"], "localhost") is True
assert _bypass_match(["localhost"], "otherhost") is False
def test_domain_suffix(self):
assert _bypass_match([".local"], "myhost.local") is True
assert _bypass_match([".local"], "local") is True
assert _bypass_match([".local"], "notlocal") is False
assert _bypass_match([".example.com"], "api.example.com") is True
assert _bypass_match([".example.com"], "example.com") is True
def test_multiple_rules(self):
rules = ["10.0.0.0/8", "192.168.0.0/16", ".local"]
assert _bypass_match(rules, "10.1.2.3") is True
assert _bypass_match(rules, "192.168.1.1") is True
assert _bypass_match(rules, "host.local") is True
assert _bypass_match(rules, "8.8.8.8") is False
def test_empty_rules(self):
assert _bypass_match([], "anything") is False
def test_hostname_not_matched_by_cidr(self):
assert _bypass_match(["10.0.0.0/8"], "example.com") is False

285
tests/test_integration.py Normal file
View File

@@ -0,0 +1,285 @@
"""End-to-end integration tests with mock SOCKS5 proxies."""
from __future__ import annotations
import asyncio
import struct
from s5p.config import ChainHop, ListenerConfig
from s5p.proto import encode_address
from s5p.server import _handle_client
from .conftest import free_port, start_echo_server, start_mock_socks5
# -- helpers -----------------------------------------------------------------
async def _socks5_connect(
host: str, port: int, target_host: str, target_port: int,
) -> tuple[asyncio.StreamReader, asyncio.StreamWriter]:
"""Connect as a SOCKS5 client, perform greeting + CONNECT."""
reader, writer = await asyncio.open_connection(host, port)
# greeting: version 5, 1 method (no-auth)
writer.write(b"\x05\x01\x00")
await writer.drain()
resp = await reader.readexactly(2)
assert resp == b"\x05\x00", f"greeting failed: {resp!r}"
# connect request
atyp, addr_bytes = encode_address(target_host)
writer.write(
struct.pack("!BBB", 0x05, 0x01, 0x00)
+ bytes([atyp])
+ addr_bytes
+ struct.pack("!H", target_port)
)
await writer.drain()
# read reply
rep_header = await reader.readexactly(3)
atyp_resp = (await reader.readexactly(1))[0]
if atyp_resp == 0x01:
await reader.readexactly(4)
elif atyp_resp == 0x03:
length = (await reader.readexactly(1))[0]
await reader.readexactly(length)
elif atyp_resp == 0x04:
await reader.readexactly(16)
await reader.readexactly(2) # port
if rep_header[1] != 0x00:
writer.close()
await writer.wait_closed()
raise ConnectionError(f"SOCKS5 reply={rep_header[1]:#x}")
return reader, writer
async def _close_server(srv: asyncio.Server) -> None:
"""Close a server and wait."""
srv.close()
await srv.wait_closed()
# -- tests -------------------------------------------------------------------
class TestDirectNoChain:
"""Client -> s5p -> echo (empty chain)."""
def test_echo(self):
async def _run():
servers = []
try:
echo_host, echo_port, echo_srv = await start_echo_server()
servers.append(echo_srv)
listener = ListenerConfig(listen_host="127.0.0.1", listen_port=free_port())
s5p_srv = await asyncio.start_server(
lambda r, w: _handle_client(r, w, listener, timeout=5.0, retries=1),
listener.listen_host, listener.listen_port,
)
servers.append(s5p_srv)
await s5p_srv.start_serving()
reader, writer = await _socks5_connect(
listener.listen_host, listener.listen_port, echo_host, echo_port,
)
writer.write(b"hello direct")
await writer.drain()
data = await asyncio.wait_for(reader.read(4096), timeout=2.0)
assert data == b"hello direct"
writer.close()
await writer.wait_closed()
finally:
for s in servers:
await _close_server(s)
asyncio.run(_run())
class TestSingleHop:
"""Client -> s5p -> mock socks5 -> echo."""
def test_echo_through_one_hop(self):
async def _run():
servers = []
try:
echo_host, echo_port, echo_srv = await start_echo_server()
servers.append(echo_srv)
mock_host, mock_port, mock_srv = await start_mock_socks5()
servers.append(mock_srv)
listener = ListenerConfig(
listen_host="127.0.0.1",
listen_port=free_port(),
chain=[ChainHop(proto="socks5", host=mock_host, port=mock_port)],
)
s5p_srv = await asyncio.start_server(
lambda r, w: _handle_client(r, w, listener, timeout=5.0, retries=1),
listener.listen_host, listener.listen_port,
)
servers.append(s5p_srv)
await s5p_srv.start_serving()
reader, writer = await _socks5_connect(
listener.listen_host, listener.listen_port, echo_host, echo_port,
)
writer.write(b"hello one hop")
await writer.drain()
data = await asyncio.wait_for(reader.read(4096), timeout=2.0)
assert data == b"hello one hop"
writer.close()
await writer.wait_closed()
finally:
for s in servers:
await _close_server(s)
asyncio.run(_run())
class TestTwoHops:
"""Client -> s5p -> mock1 -> mock2 -> echo."""
def test_echo_through_two_hops(self):
async def _run():
servers = []
try:
echo_host, echo_port, echo_srv = await start_echo_server()
servers.append(echo_srv)
m1_host, m1_port, m1_srv = await start_mock_socks5()
servers.append(m1_srv)
m2_host, m2_port, m2_srv = await start_mock_socks5()
servers.append(m2_srv)
listener = ListenerConfig(
listen_host="127.0.0.1",
listen_port=free_port(),
chain=[
ChainHop(proto="socks5", host=m1_host, port=m1_port),
ChainHop(proto="socks5", host=m2_host, port=m2_port),
],
)
s5p_srv = await asyncio.start_server(
lambda r, w: _handle_client(r, w, listener, timeout=5.0, retries=1),
listener.listen_host, listener.listen_port,
)
servers.append(s5p_srv)
await s5p_srv.start_serving()
reader, writer = await _socks5_connect(
listener.listen_host, listener.listen_port, echo_host, echo_port,
)
writer.write(b"hello two hops")
await writer.drain()
data = await asyncio.wait_for(reader.read(4096), timeout=2.0)
assert data == b"hello two hops"
writer.close()
await writer.wait_closed()
finally:
for s in servers:
await _close_server(s)
asyncio.run(_run())
class TestConnectionRefused:
"""Dead hop returns SOCKS5 error to client."""
def test_refused(self):
async def _run():
servers = []
try:
# use a port with nothing listening
dead_port = free_port()
listener = ListenerConfig(
listen_host="127.0.0.1",
listen_port=free_port(),
chain=[ChainHop(proto="socks5", host="127.0.0.1", port=dead_port)],
)
s5p_srv = await asyncio.start_server(
lambda r, w: _handle_client(r, w, listener, timeout=3.0, retries=1),
listener.listen_host, listener.listen_port,
)
servers.append(s5p_srv)
await s5p_srv.start_serving()
reader, writer = await asyncio.open_connection(
listener.listen_host, listener.listen_port,
)
# greeting
writer.write(b"\x05\x01\x00")
await writer.drain()
resp = await reader.readexactly(2)
assert resp == b"\x05\x00"
# connect to a dummy target
atyp, addr_bytes = encode_address("127.0.0.1")
writer.write(
struct.pack("!BBB", 0x05, 0x01, 0x00)
+ bytes([atyp])
+ addr_bytes
+ struct.pack("!H", 9999)
)
await writer.drain()
# should get error reply (non-zero rep field)
rep = await asyncio.wait_for(reader.read(4096), timeout=5.0)
assert len(rep) >= 3
assert rep[1] != 0x00, "expected non-zero SOCKS5 reply code"
writer.close()
await writer.wait_closed()
finally:
for s in servers:
await _close_server(s)
asyncio.run(_run())
class TestBypassDirectConnect:
"""Target matches bypass rule -> chain skipped, direct connect to echo."""
def test_bypass_skips_chain(self):
async def _run():
servers = []
try:
echo_host, echo_port, echo_srv = await start_echo_server()
servers.append(echo_srv)
# dead hop -- would fail if bypass didn't skip it
dead_port = free_port()
listener = ListenerConfig(
listen_host="127.0.0.1",
listen_port=free_port(),
chain=[ChainHop(proto="socks5", host="127.0.0.1", port=dead_port)],
bypass=["127.0.0.0/8"],
)
s5p_srv = await asyncio.start_server(
lambda r, w: _handle_client(r, w, listener, timeout=5.0, retries=1),
listener.listen_host, listener.listen_port,
)
servers.append(s5p_srv)
await s5p_srv.start_serving()
reader, writer = await _socks5_connect(
listener.listen_host, listener.listen_port, echo_host, echo_port,
)
writer.write(b"hello bypass")
await writer.drain()
data = await asyncio.wait_for(reader.read(4096), timeout=2.0)
assert data == b"hello bypass"
writer.close()
await writer.wait_closed()
finally:
for s in servers:
await _close_server(s)
asyncio.run(_run())