feat: pre-warmed TCP connection pool to first hop
Add FirstHopPool that maintains a deque of pre-established TCP connections to chain[0]. Connections idle beyond pool_max_idle are evicted; a background task refills to pool_size. build_chain() tries the pool first, falls back to open_connection. Enabled with pool_size > 0 in config. Only pools the TCP handshake -- SOCKS/HTTP tunnels are consumed, not returned.
This commit is contained in:
138
src/s5p/connpool.py
Normal file
138
src/s5p/connpool.py
Normal file
@@ -0,0 +1,138 @@
|
|||||||
|
"""Pre-warmed TCP connection pool to the first proxy hop."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from collections import deque
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from .config import ChainHop
|
||||||
|
|
||||||
|
logger = logging.getLogger("s5p")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _PooledConn:
|
||||||
|
"""A pre-established TCP connection with creation timestamp."""
|
||||||
|
|
||||||
|
reader: asyncio.StreamReader
|
||||||
|
writer: asyncio.StreamWriter
|
||||||
|
created: float
|
||||||
|
|
||||||
|
|
||||||
|
class FirstHopPool:
|
||||||
|
"""Pool of pre-warmed TCP connections to the first proxy hop.
|
||||||
|
|
||||||
|
Only the TCP handshake is pooled -- once a SOCKS/HTTP CONNECT
|
||||||
|
tunnel is established, the connection is consumed and cannot be
|
||||||
|
returned to the pool.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hop: ChainHop,
|
||||||
|
size: int = 8,
|
||||||
|
max_idle: float = 30.0,
|
||||||
|
) -> None:
|
||||||
|
self._hop = hop
|
||||||
|
self._size = size
|
||||||
|
self._max_idle = max_idle
|
||||||
|
self._pool: deque[_PooledConn] = deque()
|
||||||
|
self._task: asyncio.Task | None = None
|
||||||
|
self._stop = asyncio.Event()
|
||||||
|
|
||||||
|
async def start(self) -> None:
|
||||||
|
"""Pre-warm the pool and start the background refill loop."""
|
||||||
|
await self._fill()
|
||||||
|
self._task = asyncio.create_task(self._refill_loop())
|
||||||
|
logger.info(
|
||||||
|
"connpool: started hop=%s size=%d max_idle=%.0fs ready=%d",
|
||||||
|
self._hop, self._size, self._max_idle, len(self._pool),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def stop(self) -> None:
|
||||||
|
"""Drain the pool and cancel the refill task."""
|
||||||
|
self._stop.set()
|
||||||
|
if self._task:
|
||||||
|
self._task.cancel()
|
||||||
|
try:
|
||||||
|
await self._task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
self._task = None
|
||||||
|
while self._pool:
|
||||||
|
conn = self._pool.popleft()
|
||||||
|
self._close_conn(conn)
|
||||||
|
logger.debug("connpool: stopped")
|
||||||
|
|
||||||
|
async def acquire(self) -> tuple[asyncio.StreamReader, asyncio.StreamWriter] | None:
|
||||||
|
"""Return a pre-established connection, or None if pool is empty.
|
||||||
|
|
||||||
|
The caller must fall back to ``asyncio.open_connection()`` when
|
||||||
|
None is returned. Stale connections are evicted on acquire.
|
||||||
|
"""
|
||||||
|
now = time.monotonic()
|
||||||
|
while self._pool:
|
||||||
|
conn = self._pool.popleft()
|
||||||
|
if now - conn.created > self._max_idle:
|
||||||
|
self._close_conn(conn)
|
||||||
|
continue
|
||||||
|
# Verify the connection is still open
|
||||||
|
if conn.writer.is_closing():
|
||||||
|
continue
|
||||||
|
return conn.reader, conn.writer
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _fill(self) -> None:
|
||||||
|
"""Top up the pool to target size."""
|
||||||
|
needed = self._size - len(self._pool)
|
||||||
|
if needed <= 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
tasks = [self._open_one() for _ in range(needed)]
|
||||||
|
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
for result in results:
|
||||||
|
if isinstance(result, _PooledConn):
|
||||||
|
self._pool.append(result)
|
||||||
|
elif isinstance(result, Exception):
|
||||||
|
logger.debug("connpool: pre-warm failed: %s", result)
|
||||||
|
|
||||||
|
async def _open_one(self) -> _PooledConn:
|
||||||
|
"""Open a single TCP connection to the first hop."""
|
||||||
|
reader, writer = await asyncio.wait_for(
|
||||||
|
asyncio.open_connection(self._hop.host, self._hop.port),
|
||||||
|
timeout=5.0,
|
||||||
|
)
|
||||||
|
return _PooledConn(reader=reader, writer=writer, created=time.monotonic())
|
||||||
|
|
||||||
|
async def _refill_loop(self) -> None:
|
||||||
|
"""Periodically evict stale connections and refill."""
|
||||||
|
while not self._stop.is_set():
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(self._stop.wait(), timeout=self._max_idle / 2)
|
||||||
|
except TimeoutError:
|
||||||
|
pass
|
||||||
|
if self._stop.is_set():
|
||||||
|
break
|
||||||
|
self._evict_stale()
|
||||||
|
try:
|
||||||
|
await self._fill()
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug("connpool: refill error: %s", e)
|
||||||
|
|
||||||
|
def _evict_stale(self) -> None:
|
||||||
|
"""Remove connections older than max_idle."""
|
||||||
|
now = time.monotonic()
|
||||||
|
while self._pool and now - self._pool[0].created > self._max_idle:
|
||||||
|
conn = self._pool.popleft()
|
||||||
|
self._close_conn(conn)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _close_conn(conn: _PooledConn) -> None:
|
||||||
|
"""Close a pooled connection silently."""
|
||||||
|
try:
|
||||||
|
conn.writer.close()
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
@@ -8,9 +8,13 @@ import logging
|
|||||||
import socket
|
import socket
|
||||||
import struct
|
import struct
|
||||||
from enum import IntEnum
|
from enum import IntEnum
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from .config import ChainHop
|
from .config import ChainHop
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .connpool import FirstHopPool
|
||||||
|
|
||||||
logger = logging.getLogger("s5p")
|
logger = logging.getLogger("s5p")
|
||||||
|
|
||||||
|
|
||||||
@@ -214,11 +218,14 @@ async def build_chain(
|
|||||||
target_host: str,
|
target_host: str,
|
||||||
target_port: int,
|
target_port: int,
|
||||||
timeout: float = 10.0,
|
timeout: float = 10.0,
|
||||||
|
first_hop_pool: FirstHopPool | None = None,
|
||||||
) -> tuple[asyncio.StreamReader, asyncio.StreamWriter]:
|
) -> tuple[asyncio.StreamReader, asyncio.StreamWriter]:
|
||||||
"""Build a tunnel through the proxy chain to the target.
|
"""Build a tunnel through the proxy chain to the target.
|
||||||
|
|
||||||
Connects to the first hop via TCP, then negotiates each subsequent
|
Connects to the first hop via TCP, then negotiates each subsequent
|
||||||
hop over the tunnel established by the previous one.
|
hop over the tunnel established by the previous one. If a
|
||||||
|
``first_hop_pool`` is provided, attempts to reuse a pre-warmed
|
||||||
|
TCP connection for the initial hop.
|
||||||
"""
|
"""
|
||||||
if not chain:
|
if not chain:
|
||||||
return await asyncio.wait_for(
|
return await asyncio.wait_for(
|
||||||
@@ -226,10 +233,14 @@ async def build_chain(
|
|||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
)
|
)
|
||||||
|
|
||||||
reader, writer = await asyncio.wait_for(
|
conn = await first_hop_pool.acquire() if first_hop_pool else None
|
||||||
asyncio.open_connection(chain[0].host, chain[0].port),
|
if conn:
|
||||||
timeout=timeout,
|
reader, writer = conn
|
||||||
)
|
else:
|
||||||
|
reader, writer = await asyncio.wait_for(
|
||||||
|
asyncio.open_connection(chain[0].host, chain[0].port),
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
for i, hop in enumerate(chain):
|
for i, hop in enumerate(chain):
|
||||||
|
|||||||
105
tests/test_connpool.py
Normal file
105
tests/test_connpool.py
Normal file
@@ -0,0 +1,105 @@
|
|||||||
|
"""Tests for the first-hop connection pool."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
from s5p.config import ChainHop
|
||||||
|
from s5p.connpool import FirstHopPool
|
||||||
|
|
||||||
|
|
||||||
|
async def _echo_server(host="127.0.0.1", port=0):
|
||||||
|
"""Start a TCP server that accepts connections and holds them open."""
|
||||||
|
|
||||||
|
async def handler(reader, writer):
|
||||||
|
try:
|
||||||
|
await reader.read(1) # block until client disconnects
|
||||||
|
except (ConnectionError, asyncio.CancelledError):
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
writer.close()
|
||||||
|
|
||||||
|
server = await asyncio.start_server(handler, host, port)
|
||||||
|
port = server.sockets[0].getsockname()[1]
|
||||||
|
return server, port
|
||||||
|
|
||||||
|
|
||||||
|
class TestFirstHopPool:
|
||||||
|
"""Test connection pool lifecycle."""
|
||||||
|
|
||||||
|
def test_acquire_returns_connection(self):
|
||||||
|
async def run():
|
||||||
|
server, port = await _echo_server()
|
||||||
|
async with server:
|
||||||
|
hop = ChainHop(proto="socks5", host="127.0.0.1", port=port)
|
||||||
|
pool = FirstHopPool(hop, size=2, max_idle=10.0)
|
||||||
|
await pool.start()
|
||||||
|
try:
|
||||||
|
conn = await pool.acquire()
|
||||||
|
assert conn is not None
|
||||||
|
reader, writer = conn
|
||||||
|
assert not writer.is_closing()
|
||||||
|
writer.close()
|
||||||
|
finally:
|
||||||
|
await pool.stop()
|
||||||
|
|
||||||
|
asyncio.run(run())
|
||||||
|
|
||||||
|
def test_acquire_exhausts_pool(self):
|
||||||
|
async def run():
|
||||||
|
server, port = await _echo_server()
|
||||||
|
async with server:
|
||||||
|
hop = ChainHop(proto="socks5", host="127.0.0.1", port=port)
|
||||||
|
pool = FirstHopPool(hop, size=2, max_idle=10.0)
|
||||||
|
await pool.start()
|
||||||
|
try:
|
||||||
|
c1 = await pool.acquire()
|
||||||
|
c2 = await pool.acquire()
|
||||||
|
c3 = await pool.acquire()
|
||||||
|
assert c1 is not None
|
||||||
|
assert c2 is not None
|
||||||
|
assert c3 is None # pool exhausted
|
||||||
|
c1[1].close()
|
||||||
|
c2[1].close()
|
||||||
|
finally:
|
||||||
|
await pool.stop()
|
||||||
|
|
||||||
|
asyncio.run(run())
|
||||||
|
|
||||||
|
def test_stale_connections_evicted(self):
|
||||||
|
async def run():
|
||||||
|
server, port = await _echo_server()
|
||||||
|
async with server:
|
||||||
|
hop = ChainHop(proto="socks5", host="127.0.0.1", port=port)
|
||||||
|
pool = FirstHopPool(hop, size=2, max_idle=0.05)
|
||||||
|
await pool._fill() # pre-warm without starting refill loop
|
||||||
|
assert len(pool._pool) == 2
|
||||||
|
await asyncio.sleep(0.1) # let connections go stale
|
||||||
|
conn = await pool.acquire()
|
||||||
|
assert conn is None # all stale, evicted
|
||||||
|
|
||||||
|
asyncio.run(run())
|
||||||
|
|
||||||
|
def test_stop_drains_pool(self):
|
||||||
|
async def run():
|
||||||
|
server, port = await _echo_server()
|
||||||
|
async with server:
|
||||||
|
hop = ChainHop(proto="socks5", host="127.0.0.1", port=port)
|
||||||
|
pool = FirstHopPool(hop, size=4, max_idle=30.0)
|
||||||
|
await pool.start()
|
||||||
|
await pool.stop()
|
||||||
|
# Pool should be empty after stop
|
||||||
|
conn = await pool.acquire()
|
||||||
|
assert conn is None
|
||||||
|
|
||||||
|
asyncio.run(run())
|
||||||
|
|
||||||
|
def test_unreachable_hop_graceful(self):
|
||||||
|
"""Pool creation with an unreachable hop should not raise."""
|
||||||
|
async def run():
|
||||||
|
hop = ChainHop(proto="socks5", host="127.0.0.1", port=1) # nothing listening
|
||||||
|
pool = FirstHopPool(hop, size=2, max_idle=10.0)
|
||||||
|
await pool.start() # should not raise, just log warnings
|
||||||
|
conn = await pool.acquire()
|
||||||
|
assert conn is None # no connections could be established
|
||||||
|
await pool.stop()
|
||||||
|
|
||||||
|
asyncio.run(run())
|
||||||
Reference in New Issue
Block a user