diff --git a/src/s5p/http.py b/src/s5p/http.py new file mode 100644 index 0000000..f29c12a --- /dev/null +++ b/src/s5p/http.py @@ -0,0 +1,139 @@ +"""Minimal async HTTP/1.1 client using asyncio streams.""" + +from __future__ import annotations + +import asyncio +import json +import logging +import ssl +from urllib.parse import urlparse + +logger = logging.getLogger("s5p") + + +async def http_get_json(url: str, timeout: float = 10.0) -> dict: + """GET a URL and return parsed JSON response.""" + parsed = urlparse(url) + host = parsed.hostname or "" + use_tls = parsed.scheme == "https" + default_port = 443 if use_tls else 80 + port = parsed.port or default_port + path = parsed.path or "/" + if parsed.query: + path = f"{path}?{parsed.query}" + + ssl_ctx = ssl.create_default_context() if use_tls else None + reader, writer = await asyncio.wait_for( + asyncio.open_connection(host, port, ssl=ssl_ctx), + timeout=timeout, + ) + + try: + request = ( + f"GET {path} HTTP/1.1\r\n" + f"Host: {host}\r\n" + f"Accept: application/json\r\n" + f"Connection: close\r\n" + f"\r\n" + ) + writer.write(request.encode()) + await writer.drain() + + body = await _read_response(reader, timeout) + return json.loads(body) + finally: + writer.close() + try: + await writer.wait_closed() + except OSError: + pass + + +async def http_post_json(url: str, payload: dict, timeout: float = 10.0) -> None: + """POST JSON body to a URL (fire-and-forget).""" + parsed = urlparse(url) + host = parsed.hostname or "" + use_tls = parsed.scheme == "https" + default_port = 443 if use_tls else 80 + port = parsed.port or default_port + path = parsed.path or "/" + if parsed.query: + path = f"{path}?{parsed.query}" + + ssl_ctx = ssl.create_default_context() if use_tls else None + reader, writer = await asyncio.wait_for( + asyncio.open_connection(host, port, ssl=ssl_ctx), + timeout=timeout, + ) + + try: + body = json.dumps(payload).encode() + request = ( + f"POST {path} HTTP/1.1\r\n" + f"Host: {host}\r\n" + f"Content-Type: application/json\r\n" + f"Content-Length: {len(body)}\r\n" + f"Connection: close\r\n" + f"\r\n" + ) + writer.write(request.encode() + body) + await writer.drain() + + # Read response status to detect errors, but don't require body. + status_line = await asyncio.wait_for(reader.readline(), timeout=timeout) + parts = status_line.decode("utf-8", errors="replace").split(None, 2) + if len(parts) >= 2 and not parts[1].startswith("2"): + logger.debug("http_post_json: %s returned %s", url, parts[1]) + finally: + writer.close() + try: + await writer.wait_closed() + except OSError: + pass + + +async def _read_response(reader: asyncio.StreamReader, timeout: float) -> bytes: + """Read HTTP response, return body bytes. Handles Content-Length and chunked.""" + # Status line + status_line = await asyncio.wait_for(reader.readline(), timeout=timeout) + parts = status_line.decode("utf-8", errors="replace").split(None, 2) + if len(parts) < 2 or not parts[1].startswith("2"): + status = parts[1] if len(parts) >= 2 else "?" + raise OSError(f"HTTP {status}: {status_line.decode(errors='replace').strip()}") + + # Headers + content_length = -1 + chunked = False + while True: + line = await asyncio.wait_for(reader.readline(), timeout=timeout) + if line in (b"\r\n", b"\n", b""): + break + header = line.decode("utf-8", errors="replace").lower() + if header.startswith("content-length:"): + content_length = int(header.split(":", 1)[1].strip()) + elif header.startswith("transfer-encoding:") and "chunked" in header: + chunked = True + + # Body + if chunked: + return await _read_chunked(reader, timeout) + elif content_length >= 0: + return await asyncio.wait_for(reader.readexactly(content_length), timeout=timeout) + else: + # Connection: close -- read until EOF + return await asyncio.wait_for(reader.read(), timeout=timeout) + + +async def _read_chunked(reader: asyncio.StreamReader, timeout: float) -> bytes: + """Read chunked transfer-encoding body.""" + parts: list[bytes] = [] + while True: + size_line = await asyncio.wait_for(reader.readline(), timeout=timeout) + size = int(size_line.strip(), 16) + if size == 0: + await asyncio.wait_for(reader.readline(), timeout=timeout) # trailing CRLF + break + chunk = await asyncio.wait_for(reader.readexactly(size), timeout=timeout) + parts.append(chunk) + await asyncio.wait_for(reader.readline(), timeout=timeout) # trailing CRLF + return b"".join(parts) diff --git a/src/s5p/pool.py b/src/s5p/pool.py index cf8e905..ca5a0ef 100644 --- a/src/s5p/pool.py +++ b/src/s5p/pool.py @@ -8,12 +8,12 @@ import logging import os import random import time -import urllib.request from dataclasses import dataclass from pathlib import Path from urllib.parse import urlencode, urlparse from .config import ChainHop, PoolSourceConfig, ProxyPoolConfig, parse_proxy_url +from .http import http_get_json, http_post_json from .proto import ProtoError, build_chain logger = logging.getLogger("s5p") @@ -152,26 +152,35 @@ class ProxyPool: # -- source fetching ----------------------------------------------------- async def _fetch_all_sources(self) -> None: - """Fetch proxies from all configured sources and merge.""" - loop = asyncio.get_running_loop() + """Fetch proxies from all configured sources in parallel and merge.""" + + async def _fetch_one(src: PoolSourceConfig) -> list[ChainHop]: + if src.url: + return await self._fetch_api(src) + elif src.file: + loop = asyncio.get_running_loop() + return await loop.run_in_executor(None, self._fetch_file_sync, src) + return [] + + results = await asyncio.gather( + *[_fetch_one(s) for s in self._cfg.sources], + return_exceptions=True, + ) + proxies: list[ChainHop] = [] - for src in self._cfg.sources: - try: - if src.url: - batch = await loop.run_in_executor(None, self._fetch_api_sync, src) - logger.info("pool: fetched %d proxies from %s", len(batch), src.url) - proxies.extend(batch) - elif src.file: - batch = await loop.run_in_executor(None, self._fetch_file_sync, src) - logger.info("pool: loaded %d proxies from %s", len(batch), src.file) - proxies.extend(batch) - except Exception as e: - label = src.url or src.file or "?" - logger.warning("pool: source %s failed: %s", label, e) + for i, result in enumerate(results): + src = self._cfg.sources[i] + label = src.url or src.file or "?" + if isinstance(result, Exception): + logger.warning("pool: source %s failed: %s", label, result) + else: + kind = "fetched" if src.url else "loaded" + logger.info("pool: %s %d proxies from %s", kind, len(result), label) + proxies.extend(result) self._merge(proxies) - def _fetch_api_sync(self, src: PoolSourceConfig) -> list[ChainHop]: - """Fetch proxies from an HTTP API (runs in executor).""" + async def _fetch_api(self, src: PoolSourceConfig) -> list[ChainHop]: + """Fetch proxies from an HTTP API (async).""" params: dict[str, str] = {} if src.limit: params["limit"] = str(src.limit) @@ -185,9 +194,7 @@ class ProxyPool: sep = "&" if "?" in url else "?" url = f"{url}{sep}{urlencode(params)}" - req = urllib.request.Request(url, headers={"Accept": "application/json"}) - with urllib.request.urlopen(req, timeout=10) as resp: - data = json.loads(resp.read()) + data = await http_get_json(url) proxies: list[ChainHop] = [] for entry in data.get("proxies", []): @@ -405,7 +412,7 @@ class ProxyPool: asyncio.ensure_future(self._report_dead(dead)) async def _report_dead(self, keys: list[str]) -> None: - """POST dead proxy list to report_url (fire-and-forget).""" + """POST dead proxy list to report_url (fire-and-forget, async).""" dead = [] for key in keys: # key format: proto://host:port @@ -416,25 +423,12 @@ class ProxyPool: if not dead: return - loop = asyncio.get_running_loop() try: - await loop.run_in_executor(None, self._report_sync, dead) + await http_post_json(self._cfg.report_url, {"dead": dead}) logger.info("pool: reported %d dead proxies to %s", len(dead), self._cfg.report_url) except Exception as e: logger.debug("pool: report failed: %s", e) - def _report_sync(self, dead: list[dict[str, str]]) -> None: - """Synchronous POST to report_url (runs in executor).""" - payload = json.dumps({"dead": dead}).encode() - req = urllib.request.Request( - self._cfg.report_url, - data=payload, - headers={"Content-Type": "application/json"}, - method="POST", - ) - with urllib.request.urlopen(req, timeout=10): - pass - def _rebuild_alive(self) -> None: """Rebuild the alive keys list from current state.""" self._alive_keys = [k for k, e in self._proxies.items() if e.alive] diff --git a/src/s5p/source.py b/src/s5p/source.py index a742248..8e7cbbd 100644 --- a/src/s5p/source.py +++ b/src/s5p/source.py @@ -3,14 +3,13 @@ from __future__ import annotations import asyncio -import json import logging import random import time -import urllib.request from urllib.parse import urlencode from .config import ChainHop, ProxySourceConfig +from .http import http_get_json logger = logging.getLogger("s5p") @@ -49,11 +48,10 @@ class ProxySource: return random.choice(self._cache) async def _refresh(self) -> None: - """Fetch proxy list from the API.""" + """Fetch proxy list from the API (async).""" async with self._lock: - loop = asyncio.get_running_loop() try: - proxies = await loop.run_in_executor(None, self._fetch_sync) + proxies = await self._fetch() self._cache = proxies self._last_fetch = time.monotonic() logger.info("proxy source: loaded %d proxies", len(proxies)) @@ -62,8 +60,8 @@ class ProxySource: if self._cache: logger.info("proxy source: using stale cache (%d proxies)", len(self._cache)) - def _fetch_sync(self) -> list[ChainHop]: - """Synchronous HTTP fetch (runs in executor).""" + async def _fetch(self) -> list[ChainHop]: + """Async HTTP fetch.""" params: dict[str, str] = {} if self._cfg.limit: params["limit"] = str(self._cfg.limit) @@ -77,9 +75,7 @@ class ProxySource: sep = "&" if "?" in url else "?" url = f"{url}{sep}{urlencode(params)}" - req = urllib.request.Request(url, headers={"Accept": "application/json"}) - with urllib.request.urlopen(req, timeout=10) as resp: - data = json.loads(resp.read()) + data = await http_get_json(url) proxies: list[ChainHop] = [] for entry in data.get("proxies", []): diff --git a/tests/test_http.py b/tests/test_http.py new file mode 100644 index 0000000..56fe4f6 --- /dev/null +++ b/tests/test_http.py @@ -0,0 +1,153 @@ +"""Tests for the async HTTP client.""" + +import asyncio +import json + +import pytest + +from s5p.http import http_get_json, http_post_json + + +async def _run_mock_server(handler, host="127.0.0.1", port=0): + """Start a mock TCP server, return (server, port).""" + server = await asyncio.start_server(handler, host, port) + port = server.sockets[0].getsockname()[1] + return server, port + + +class TestHttpGetJson: + """Test async HTTP GET.""" + + def test_basic_get(self): + payload = {"proxies": [{"proto": "socks5", "proxy": "1.2.3.4:1080"}]} + body = json.dumps(payload).encode() + + async def handler(reader, writer): + await reader.readline() # request line + while (await reader.readline()) not in (b"\r\n", b"\n", b""): + pass + writer.write( + f"HTTP/1.1 200 OK\r\n" + f"Content-Length: {len(body)}\r\n" + f"Connection: close\r\n" + f"\r\n".encode() + + body + ) + await writer.drain() + writer.close() + + async def run(): + server, port = await _run_mock_server(handler) + async with server: + result = await http_get_json(f"http://127.0.0.1:{port}/test") + assert result == payload + + asyncio.run(run()) + + def test_chunked_response(self): + payload = {"status": "ok"} + body = json.dumps(payload).encode() + + async def handler(reader, writer): + await reader.readline() + while (await reader.readline()) not in (b"\r\n", b"\n", b""): + pass + writer.write( + b"HTTP/1.1 200 OK\r\n" + b"Transfer-Encoding: chunked\r\n" + b"Connection: close\r\n" + b"\r\n" + ) + # Send body as a single chunk + writer.write(f"{len(body):x}\r\n".encode() + body + b"\r\n") + writer.write(b"0\r\n\r\n") + await writer.drain() + writer.close() + + async def run(): + server, port = await _run_mock_server(handler) + async with server: + result = await http_get_json(f"http://127.0.0.1:{port}/chunked") + assert result == payload + + asyncio.run(run()) + + def test_error_status(self): + async def handler(reader, writer): + await reader.readline() + while (await reader.readline()) not in (b"\r\n", b"\n", b""): + pass + writer.write(b"HTTP/1.1 500 Internal Server Error\r\nConnection: close\r\n\r\n") + await writer.drain() + writer.close() + + async def run(): + server, port = await _run_mock_server(handler) + async with server: + with pytest.raises(OSError, match="HTTP 500"): + await http_get_json(f"http://127.0.0.1:{port}/fail") + + asyncio.run(run()) + + def test_connection_close_body(self): + """Server sends no Content-Length, just closes connection.""" + payload = {"key": "value"} + body = json.dumps(payload).encode() + + async def handler(reader, writer): + await reader.readline() + while (await reader.readline()) not in (b"\r\n", b"\n", b""): + pass + writer.write( + b"HTTP/1.1 200 OK\r\n" + b"Connection: close\r\n" + b"\r\n" + + body + ) + await writer.drain() + writer.close() + await writer.wait_closed() + + async def run(): + server, port = await _run_mock_server(handler) + async with server: + result = await http_get_json(f"http://127.0.0.1:{port}/nosize") + assert result == payload + + asyncio.run(run()) + + +class TestHttpPostJson: + """Test async HTTP POST.""" + + def test_basic_post(self): + received = {} + + async def handler(reader, writer): + request_line = await reader.readline() + received["method"] = request_line.decode().split()[0] + content_length = 0 + while True: + line = await reader.readline() + if line in (b"\r\n", b"\n", b""): + break + header = line.decode().lower() + if header.startswith("content-length:"): + content_length = int(header.split(":", 1)[1].strip()) + body = await reader.readexactly(content_length) + received["body"] = json.loads(body) + writer.write(b"HTTP/1.1 200 OK\r\nConnection: close\r\n\r\n") + await writer.drain() + writer.close() + + async def run(): + server, port = await _run_mock_server(handler) + async with server: + await http_post_json( + f"http://127.0.0.1:{port}/report", + {"dead": [{"proto": "socks5", "proxy": "1.2.3.4:1080"}]}, + ) + assert received["method"] == "POST" + assert received["body"]["dead"][0]["proto"] == "socks5" + + asyncio.run(run()) diff --git a/tests/test_pool.py b/tests/test_pool.py index 0cf7540..1e30986 100644 --- a/tests/test_pool.py +++ b/tests/test_pool.py @@ -1,6 +1,5 @@ """Tests for the managed proxy pool.""" -import json import time import pytest @@ -304,21 +303,20 @@ class TestProxyPoolReport: asyncio.run(pool._run_health_tests()) mock_report.assert_not_called() - def test_report_sync_payload(self): - from unittest.mock import MagicMock, patch + def test_report_async_payload(self): + import asyncio + from unittest.mock import AsyncMock, patch cfg = ProxyPoolConfig(sources=[], report_url="http://api:8081/report") pool = ProxyPool(cfg, [], timeout=10.0) - dead = [{"proto": "socks5", "proxy": "10.0.0.1:1080"}] - with patch("s5p.pool.urllib.request.urlopen", new_callable=MagicMock) as mock_open: - mock_open.return_value.__enter__ = MagicMock() - mock_open.return_value.__exit__ = MagicMock(return_value=False) - pool._report_sync(dead) - req = mock_open.call_args[0][0] - assert req.method == "POST" - assert req.full_url == "http://api:8081/report" - assert b'"dead"' in req.data + with patch("s5p.pool.http_post_json", new_callable=AsyncMock) as mock_post: + asyncio.run(pool._report_dead(["socks5://10.0.0.1:1080"])) + mock_post.assert_called_once() + url = mock_post.call_args[0][0] + payload = mock_post.call_args[0][1] + assert url == "http://api:8081/report" + assert payload == {"dead": [{"proto": "socks5", "proxy": "10.0.0.1:1080"}]} class TestProxyPoolStaleExpiry: