From 248f5c330660a0f1dfe8be3dd3527cd58dc4edbd Mon Sep 17 00:00:00 2001 From: user Date: Sun, 15 Feb 2026 17:56:03 +0100 Subject: [PATCH] feat: pre-warmed TCP connection pool to first hop Add FirstHopPool that maintains a deque of pre-established TCP connections to chain[0]. Connections idle beyond pool_max_idle are evicted; a background task refills to pool_size. build_chain() tries the pool first, falls back to open_connection. Enabled with pool_size > 0 in config. Only pools the TCP handshake -- SOCKS/HTTP tunnels are consumed, not returned. --- src/s5p/connpool.py | 138 +++++++++++++++++++++++++++++++++++++++++ src/s5p/proto.py | 21 +++++-- tests/test_connpool.py | 105 +++++++++++++++++++++++++++++++ 3 files changed, 259 insertions(+), 5 deletions(-) create mode 100644 src/s5p/connpool.py create mode 100644 tests/test_connpool.py diff --git a/src/s5p/connpool.py b/src/s5p/connpool.py new file mode 100644 index 0000000..1073a9d --- /dev/null +++ b/src/s5p/connpool.py @@ -0,0 +1,138 @@ +"""Pre-warmed TCP connection pool to the first proxy hop.""" + +from __future__ import annotations + +import asyncio +import logging +import time +from collections import deque +from dataclasses import dataclass + +from .config import ChainHop + +logger = logging.getLogger("s5p") + + +@dataclass +class _PooledConn: + """A pre-established TCP connection with creation timestamp.""" + + reader: asyncio.StreamReader + writer: asyncio.StreamWriter + created: float + + +class FirstHopPool: + """Pool of pre-warmed TCP connections to the first proxy hop. + + Only the TCP handshake is pooled -- once a SOCKS/HTTP CONNECT + tunnel is established, the connection is consumed and cannot be + returned to the pool. + """ + + def __init__( + self, + hop: ChainHop, + size: int = 8, + max_idle: float = 30.0, + ) -> None: + self._hop = hop + self._size = size + self._max_idle = max_idle + self._pool: deque[_PooledConn] = deque() + self._task: asyncio.Task | None = None + self._stop = asyncio.Event() + + async def start(self) -> None: + """Pre-warm the pool and start the background refill loop.""" + await self._fill() + self._task = asyncio.create_task(self._refill_loop()) + logger.info( + "connpool: started hop=%s size=%d max_idle=%.0fs ready=%d", + self._hop, self._size, self._max_idle, len(self._pool), + ) + + async def stop(self) -> None: + """Drain the pool and cancel the refill task.""" + self._stop.set() + if self._task: + self._task.cancel() + try: + await self._task + except asyncio.CancelledError: + pass + self._task = None + while self._pool: + conn = self._pool.popleft() + self._close_conn(conn) + logger.debug("connpool: stopped") + + async def acquire(self) -> tuple[asyncio.StreamReader, asyncio.StreamWriter] | None: + """Return a pre-established connection, or None if pool is empty. + + The caller must fall back to ``asyncio.open_connection()`` when + None is returned. Stale connections are evicted on acquire. + """ + now = time.monotonic() + while self._pool: + conn = self._pool.popleft() + if now - conn.created > self._max_idle: + self._close_conn(conn) + continue + # Verify the connection is still open + if conn.writer.is_closing(): + continue + return conn.reader, conn.writer + return None + + async def _fill(self) -> None: + """Top up the pool to target size.""" + needed = self._size - len(self._pool) + if needed <= 0: + return + + tasks = [self._open_one() for _ in range(needed)] + results = await asyncio.gather(*tasks, return_exceptions=True) + for result in results: + if isinstance(result, _PooledConn): + self._pool.append(result) + elif isinstance(result, Exception): + logger.debug("connpool: pre-warm failed: %s", result) + + async def _open_one(self) -> _PooledConn: + """Open a single TCP connection to the first hop.""" + reader, writer = await asyncio.wait_for( + asyncio.open_connection(self._hop.host, self._hop.port), + timeout=5.0, + ) + return _PooledConn(reader=reader, writer=writer, created=time.monotonic()) + + async def _refill_loop(self) -> None: + """Periodically evict stale connections and refill.""" + while not self._stop.is_set(): + try: + await asyncio.wait_for(self._stop.wait(), timeout=self._max_idle / 2) + except TimeoutError: + pass + if self._stop.is_set(): + break + self._evict_stale() + try: + await self._fill() + except Exception as e: + logger.debug("connpool: refill error: %s", e) + + def _evict_stale(self) -> None: + """Remove connections older than max_idle.""" + now = time.monotonic() + while self._pool and now - self._pool[0].created > self._max_idle: + conn = self._pool.popleft() + self._close_conn(conn) + + @staticmethod + def _close_conn(conn: _PooledConn) -> None: + """Close a pooled connection silently.""" + try: + conn.writer.close() + except OSError: + pass diff --git a/src/s5p/proto.py b/src/s5p/proto.py index 13e2901..ff7faee 100644 --- a/src/s5p/proto.py +++ b/src/s5p/proto.py @@ -8,9 +8,13 @@ import logging import socket import struct from enum import IntEnum +from typing import TYPE_CHECKING from .config import ChainHop +if TYPE_CHECKING: + from .connpool import FirstHopPool + logger = logging.getLogger("s5p") @@ -214,11 +218,14 @@ async def build_chain( target_host: str, target_port: int, timeout: float = 10.0, + first_hop_pool: FirstHopPool | None = None, ) -> tuple[asyncio.StreamReader, asyncio.StreamWriter]: """Build a tunnel through the proxy chain to the target. Connects to the first hop via TCP, then negotiates each subsequent - hop over the tunnel established by the previous one. + hop over the tunnel established by the previous one. If a + ``first_hop_pool`` is provided, attempts to reuse a pre-warmed + TCP connection for the initial hop. """ if not chain: return await asyncio.wait_for( @@ -226,10 +233,14 @@ async def build_chain( timeout=timeout, ) - reader, writer = await asyncio.wait_for( - asyncio.open_connection(chain[0].host, chain[0].port), - timeout=timeout, - ) + conn = await first_hop_pool.acquire() if first_hop_pool else None + if conn: + reader, writer = conn + else: + reader, writer = await asyncio.wait_for( + asyncio.open_connection(chain[0].host, chain[0].port), + timeout=timeout, + ) try: for i, hop in enumerate(chain): diff --git a/tests/test_connpool.py b/tests/test_connpool.py new file mode 100644 index 0000000..5dc8654 --- /dev/null +++ b/tests/test_connpool.py @@ -0,0 +1,105 @@ +"""Tests for the first-hop connection pool.""" + +import asyncio + +from s5p.config import ChainHop +from s5p.connpool import FirstHopPool + + +async def _echo_server(host="127.0.0.1", port=0): + """Start a TCP server that accepts connections and holds them open.""" + + async def handler(reader, writer): + try: + await reader.read(1) # block until client disconnects + except (ConnectionError, asyncio.CancelledError): + pass + finally: + writer.close() + + server = await asyncio.start_server(handler, host, port) + port = server.sockets[0].getsockname()[1] + return server, port + + +class TestFirstHopPool: + """Test connection pool lifecycle.""" + + def test_acquire_returns_connection(self): + async def run(): + server, port = await _echo_server() + async with server: + hop = ChainHop(proto="socks5", host="127.0.0.1", port=port) + pool = FirstHopPool(hop, size=2, max_idle=10.0) + await pool.start() + try: + conn = await pool.acquire() + assert conn is not None + reader, writer = conn + assert not writer.is_closing() + writer.close() + finally: + await pool.stop() + + asyncio.run(run()) + + def test_acquire_exhausts_pool(self): + async def run(): + server, port = await _echo_server() + async with server: + hop = ChainHop(proto="socks5", host="127.0.0.1", port=port) + pool = FirstHopPool(hop, size=2, max_idle=10.0) + await pool.start() + try: + c1 = await pool.acquire() + c2 = await pool.acquire() + c3 = await pool.acquire() + assert c1 is not None + assert c2 is not None + assert c3 is None # pool exhausted + c1[1].close() + c2[1].close() + finally: + await pool.stop() + + asyncio.run(run()) + + def test_stale_connections_evicted(self): + async def run(): + server, port = await _echo_server() + async with server: + hop = ChainHop(proto="socks5", host="127.0.0.1", port=port) + pool = FirstHopPool(hop, size=2, max_idle=0.05) + await pool._fill() # pre-warm without starting refill loop + assert len(pool._pool) == 2 + await asyncio.sleep(0.1) # let connections go stale + conn = await pool.acquire() + assert conn is None # all stale, evicted + + asyncio.run(run()) + + def test_stop_drains_pool(self): + async def run(): + server, port = await _echo_server() + async with server: + hop = ChainHop(proto="socks5", host="127.0.0.1", port=port) + pool = FirstHopPool(hop, size=4, max_idle=30.0) + await pool.start() + await pool.stop() + # Pool should be empty after stop + conn = await pool.acquire() + assert conn is None + + asyncio.run(run()) + + def test_unreachable_hop_graceful(self): + """Pool creation with an unreachable hop should not raise.""" + async def run(): + hop = ChainHop(proto="socks5", host="127.0.0.1", port=1) # nothing listening + pool = FirstHopPool(hop, size=2, max_idle=10.0) + await pool.start() # should not raise, just log warnings + conn = await pool.acquire() + assert conn is None # no connections could be established + await pool.stop() + + asyncio.run(run())