feat: pre-warmed TCP connection pool to first hop

Add FirstHopPool that maintains a deque of pre-established TCP
connections to chain[0]. Connections idle beyond pool_max_idle are
evicted; a background task refills to pool_size. build_chain() tries
the pool first, falls back to open_connection. Enabled with
pool_size > 0 in config. Only pools the TCP handshake -- SOCKS/HTTP
tunnels are consumed, not returned.
This commit is contained in:
user
2026-02-15 17:56:03 +01:00
parent 903cb38b9f
commit 248f5c3306
3 changed files with 259 additions and 5 deletions

138
src/s5p/connpool.py Normal file
View File

@@ -0,0 +1,138 @@
"""Pre-warmed TCP connection pool to the first proxy hop."""
from __future__ import annotations
import asyncio
import logging
import time
from collections import deque
from dataclasses import dataclass
from .config import ChainHop
logger = logging.getLogger("s5p")
@dataclass
class _PooledConn:
"""A pre-established TCP connection with creation timestamp."""
reader: asyncio.StreamReader
writer: asyncio.StreamWriter
created: float
class FirstHopPool:
"""Pool of pre-warmed TCP connections to the first proxy hop.
Only the TCP handshake is pooled -- once a SOCKS/HTTP CONNECT
tunnel is established, the connection is consumed and cannot be
returned to the pool.
"""
def __init__(
self,
hop: ChainHop,
size: int = 8,
max_idle: float = 30.0,
) -> None:
self._hop = hop
self._size = size
self._max_idle = max_idle
self._pool: deque[_PooledConn] = deque()
self._task: asyncio.Task | None = None
self._stop = asyncio.Event()
async def start(self) -> None:
"""Pre-warm the pool and start the background refill loop."""
await self._fill()
self._task = asyncio.create_task(self._refill_loop())
logger.info(
"connpool: started hop=%s size=%d max_idle=%.0fs ready=%d",
self._hop, self._size, self._max_idle, len(self._pool),
)
async def stop(self) -> None:
"""Drain the pool and cancel the refill task."""
self._stop.set()
if self._task:
self._task.cancel()
try:
await self._task
except asyncio.CancelledError:
pass
self._task = None
while self._pool:
conn = self._pool.popleft()
self._close_conn(conn)
logger.debug("connpool: stopped")
async def acquire(self) -> tuple[asyncio.StreamReader, asyncio.StreamWriter] | None:
"""Return a pre-established connection, or None if pool is empty.
The caller must fall back to ``asyncio.open_connection()`` when
None is returned. Stale connections are evicted on acquire.
"""
now = time.monotonic()
while self._pool:
conn = self._pool.popleft()
if now - conn.created > self._max_idle:
self._close_conn(conn)
continue
# Verify the connection is still open
if conn.writer.is_closing():
continue
return conn.reader, conn.writer
return None
async def _fill(self) -> None:
"""Top up the pool to target size."""
needed = self._size - len(self._pool)
if needed <= 0:
return
tasks = [self._open_one() for _ in range(needed)]
results = await asyncio.gather(*tasks, return_exceptions=True)
for result in results:
if isinstance(result, _PooledConn):
self._pool.append(result)
elif isinstance(result, Exception):
logger.debug("connpool: pre-warm failed: %s", result)
async def _open_one(self) -> _PooledConn:
"""Open a single TCP connection to the first hop."""
reader, writer = await asyncio.wait_for(
asyncio.open_connection(self._hop.host, self._hop.port),
timeout=5.0,
)
return _PooledConn(reader=reader, writer=writer, created=time.monotonic())
async def _refill_loop(self) -> None:
"""Periodically evict stale connections and refill."""
while not self._stop.is_set():
try:
await asyncio.wait_for(self._stop.wait(), timeout=self._max_idle / 2)
except TimeoutError:
pass
if self._stop.is_set():
break
self._evict_stale()
try:
await self._fill()
except Exception as e:
logger.debug("connpool: refill error: %s", e)
def _evict_stale(self) -> None:
"""Remove connections older than max_idle."""
now = time.monotonic()
while self._pool and now - self._pool[0].created > self._max_idle:
conn = self._pool.popleft()
self._close_conn(conn)
@staticmethod
def _close_conn(conn: _PooledConn) -> None:
"""Close a pooled connection silently."""
try:
conn.writer.close()
except OSError:
pass

View File

@@ -8,9 +8,13 @@ import logging
import socket
import struct
from enum import IntEnum
from typing import TYPE_CHECKING
from .config import ChainHop
if TYPE_CHECKING:
from .connpool import FirstHopPool
logger = logging.getLogger("s5p")
@@ -214,11 +218,14 @@ async def build_chain(
target_host: str,
target_port: int,
timeout: float = 10.0,
first_hop_pool: FirstHopPool | None = None,
) -> tuple[asyncio.StreamReader, asyncio.StreamWriter]:
"""Build a tunnel through the proxy chain to the target.
Connects to the first hop via TCP, then negotiates each subsequent
hop over the tunnel established by the previous one.
hop over the tunnel established by the previous one. If a
``first_hop_pool`` is provided, attempts to reuse a pre-warmed
TCP connection for the initial hop.
"""
if not chain:
return await asyncio.wait_for(
@@ -226,10 +233,14 @@ async def build_chain(
timeout=timeout,
)
reader, writer = await asyncio.wait_for(
asyncio.open_connection(chain[0].host, chain[0].port),
timeout=timeout,
)
conn = await first_hop_pool.acquire() if first_hop_pool else None
if conn:
reader, writer = conn
else:
reader, writer = await asyncio.wait_for(
asyncio.open_connection(chain[0].host, chain[0].port),
timeout=timeout,
)
try:
for i, hop in enumerate(chain):

105
tests/test_connpool.py Normal file
View File

@@ -0,0 +1,105 @@
"""Tests for the first-hop connection pool."""
import asyncio
from s5p.config import ChainHop
from s5p.connpool import FirstHopPool
async def _echo_server(host="127.0.0.1", port=0):
"""Start a TCP server that accepts connections and holds them open."""
async def handler(reader, writer):
try:
await reader.read(1) # block until client disconnects
except (ConnectionError, asyncio.CancelledError):
pass
finally:
writer.close()
server = await asyncio.start_server(handler, host, port)
port = server.sockets[0].getsockname()[1]
return server, port
class TestFirstHopPool:
"""Test connection pool lifecycle."""
def test_acquire_returns_connection(self):
async def run():
server, port = await _echo_server()
async with server:
hop = ChainHop(proto="socks5", host="127.0.0.1", port=port)
pool = FirstHopPool(hop, size=2, max_idle=10.0)
await pool.start()
try:
conn = await pool.acquire()
assert conn is not None
reader, writer = conn
assert not writer.is_closing()
writer.close()
finally:
await pool.stop()
asyncio.run(run())
def test_acquire_exhausts_pool(self):
async def run():
server, port = await _echo_server()
async with server:
hop = ChainHop(proto="socks5", host="127.0.0.1", port=port)
pool = FirstHopPool(hop, size=2, max_idle=10.0)
await pool.start()
try:
c1 = await pool.acquire()
c2 = await pool.acquire()
c3 = await pool.acquire()
assert c1 is not None
assert c2 is not None
assert c3 is None # pool exhausted
c1[1].close()
c2[1].close()
finally:
await pool.stop()
asyncio.run(run())
def test_stale_connections_evicted(self):
async def run():
server, port = await _echo_server()
async with server:
hop = ChainHop(proto="socks5", host="127.0.0.1", port=port)
pool = FirstHopPool(hop, size=2, max_idle=0.05)
await pool._fill() # pre-warm without starting refill loop
assert len(pool._pool) == 2
await asyncio.sleep(0.1) # let connections go stale
conn = await pool.acquire()
assert conn is None # all stale, evicted
asyncio.run(run())
def test_stop_drains_pool(self):
async def run():
server, port = await _echo_server()
async with server:
hop = ChainHop(proto="socks5", host="127.0.0.1", port=port)
pool = FirstHopPool(hop, size=4, max_idle=30.0)
await pool.start()
await pool.stop()
# Pool should be empty after stop
conn = await pool.acquire()
assert conn is None
asyncio.run(run())
def test_unreachable_hop_graceful(self):
"""Pool creation with an unreachable hop should not raise."""
async def run():
hop = ChainHop(proto="socks5", host="127.0.0.1", port=1) # nothing listening
pool = FirstHopPool(hop, size=2, max_idle=10.0)
await pool.start() # should not raise, just log warnings
conn = await pool.acquire()
assert conn is None # no connections could be established
await pool.stop()
asyncio.run(run())