""" Tests for reverse proxy functionality. These tests start a backend test server and the main PyServe server, then verify that requests are correctly proxied between them. """ import asyncio import pytest import httpx import socket import time from contextlib import asynccontextmanager from typing import AsyncGenerator, Dict, Any, Optional from unittest.mock import patch import uvicorn from starlette.applications import Starlette from starlette.requests import Request from starlette.responses import JSONResponse, PlainTextResponse, Response from starlette.routing import Route from pyserve.config import Config, ServerConfig, HttpConfig, LoggingConfig, ExtensionConfig from pyserve.server import PyServeServer def get_free_port() -> int: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: s.bind(('', 0)) s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) return s.getsockname()[1] # ============== Backend Test Application ============== class BackendTestApp: def __init__(self, port: int): self.port = port self.request_log: list[Dict[str, Any]] = [] self.app = self._create_app() self._server_task = None def _create_app(self) -> Starlette: routes = [ Route("/", self._root, methods=["GET"]), Route("/api/v{version}/users", self._users, methods=["GET", "POST"]), Route("/api/v{version}/users/{user_id}", self._user_detail, methods=["GET", "PUT", "DELETE"]), Route("/api/v{version}/data", self._data, methods=["GET", "POST"]), Route("/echo", self._echo, methods=["GET", "POST", "PUT", "DELETE", "PATCH"]), Route("/headers", self._headers, methods=["GET"]), Route("/slow", self._slow, methods=["GET"]), Route("/status/{code:int}", self._status, methods=["GET"]), Route("/json", self._json, methods=["POST"]), Route("/backend2/", self._backend2_root, methods=["GET"]), Route("/{path:path}", self._catch_all, methods=["GET", "POST", "PUT", "DELETE", "PATCH"]), ] return Starlette(routes=routes) async def _root(self, request: Request) -> Response: self._log_request(request) return JSONResponse({"message": "Backend root", "port": self.port}) async def _users(self, request: Request) -> Response: self._log_request(request) version = request.path_params.get("version", "1") if request.method == "POST": body = await request.json() return JSONResponse({"action": "create_user", "version": version, "data": body}, status_code=201) return JSONResponse({"users": [{"id": 1, "name": "Test User"}], "version": version}) async def _user_detail(self, request: Request) -> Response: self._log_request(request) version = request.path_params.get("version", "1") user_id = request.path_params.get("user_id", "0") if request.method == "DELETE": return JSONResponse({"action": "delete_user", "user_id": user_id, "version": version}) if request.method == "PUT": body = await request.json() return JSONResponse({"action": "update_user", "user_id": user_id, "version": version, "data": body}) return JSONResponse({"user": {"id": user_id, "name": f"User {user_id}"}, "version": version}) async def _data(self, request: Request) -> Response: self._log_request(request) version = request.path_params.get("version", "1") return JSONResponse({"data": "test data", "version": version}) async def _echo(self, request: Request) -> Response: self._log_request(request) body = await request.body() return Response( content=body, status_code=200, media_type=request.headers.get("content-type", "text/plain") ) async def _headers(self, request: Request) -> Response: self._log_request(request) headers = dict(request.headers) return JSONResponse({ "received_headers": headers, "client_ip": request.client.host if request.client else None, }) async def _slow(self, request: Request) -> Response: self._log_request(request) delay = float(request.query_params.get("delay", "2")) await asyncio.sleep(delay) return JSONResponse({"message": "slow response", "delay": delay}) async def _status(self, request: Request) -> Response: self._log_request(request) code = request.path_params.get("code", 200) return PlainTextResponse(f"Status: {code}", status_code=code) async def _json(self, request: Request) -> Response: self._log_request(request) body = await request.json() return JSONResponse({"received": body, "processed": True}) async def _backend2_root(self, request: Request) -> Response: self._log_request(request) return JSONResponse({"message": "Backend2 root", "port": self.port}) async def _catch_all(self, request: Request) -> Response: """Catch-all handler for debugging unmatched routes.""" self._log_request(request) return JSONResponse({ "message": "Catch-all", "path": str(request.url.path), "method": request.method, "port": self.port }) def _log_request(self, request: Request) -> None: self.request_log.append({ "method": request.method, "path": str(request.url.path), "query": str(request.url.query), "headers": dict(request.headers), }) async def start(self) -> None: config = uvicorn.Config( app=self.app, host="127.0.0.1", port=self.port, log_level="critical", access_log=False, ) server = uvicorn.Server(config) self._server_task = asyncio.create_task(server.serve()) # Wait for server to be ready for _ in range(50): # 5 seconds max try: async with httpx.AsyncClient() as client: await client.get(f"http://127.0.0.1:{self.port}/") return except httpx.ConnectError: await asyncio.sleep(0.1) raise RuntimeError(f"Backend server failed to start on port {self.port}") async def stop(self) -> None: if self._server_task: self._server_task.cancel() try: await self._server_task except asyncio.CancelledError: pass # ============== PyServe Test Server ============== class PyServeTestServer: def __init__(self, config: Config): self.config = config self.server = PyServeServer(config) self._server_task = None async def start(self) -> None: assert self.server.app is not None, "Server app not initialized" config = uvicorn.Config( app=self.server.app, host=self.config.server.host, port=self.config.server.port, log_level="critical", access_log=False, ) server = uvicorn.Server(config) self._server_task = asyncio.create_task(server.serve()) # Wait for server to be ready for _ in range(50): # 5 seconds max try: async with httpx.AsyncClient() as client: await client.get(f"http://127.0.0.1:{self.config.server.port}/health") return except httpx.ConnectError: await asyncio.sleep(0.1) raise RuntimeError(f"PyServe server failed to start on port {self.config.server.port}") async def stop(self) -> None: if self._server_task: self._server_task.cancel() try: await self._server_task except asyncio.CancelledError: pass # ============== Fixtures ============== @pytest.fixture def backend_port() -> int: return get_free_port() @pytest.fixture def pyserve_port() -> int: return get_free_port() @pytest.fixture def create_proxy_config(): def _create_config(pyserve_port: int, backend_port: int, extra_locations: Optional[Dict[str, Any]] = None) -> Config: locations = { # API proxy with version capture "~^/api/v(?P\\d+)/": { "proxy_pass": f"http://127.0.0.1:{backend_port}", "headers": [ "X-API-Version: {version}", "X-Forwarded-For: $remote_addr" ] }, # Echo endpoint proxy "=/echo": { "proxy_pass": f"http://127.0.0.1:{backend_port}/echo", }, # Headers test proxy "=/headers": { "proxy_pass": f"http://127.0.0.1:{backend_port}/headers", }, # Slow endpoint proxy with timeout "~^/slow": { "proxy_pass": f"http://127.0.0.1:{backend_port}/slow", "timeout": 5.0, }, # Status endpoint proxy "~^/status/": { "proxy_pass": f"http://127.0.0.1:{backend_port}", }, # JSON endpoint proxy "=/json": { "proxy_pass": f"http://127.0.0.1:{backend_port}/json", }, # Health check "=/health": { "return": "200 OK", "content_type": "text/plain" }, # Default fallback "__default__": { "return": "404 Not Found", "content_type": "text/plain" } } if extra_locations: locations.update(extra_locations) config = Config( http=HttpConfig(static_dir="./static", templates_dir="./templates"), server=ServerConfig(host="127.0.0.1", port=pyserve_port), logging=LoggingConfig(level="ERROR", console_output=False), extensions=[ ExtensionConfig( type="routing", config={"regex_locations": locations} ) ] ) return config return _create_config @asynccontextmanager async def running_servers( pyserve_port: int, backend_port: int, config_factory, extra_locations: Optional[Dict[str, Any]] = None ) -> AsyncGenerator[tuple[PyServeTestServer, BackendTestApp], None]: backend = BackendTestApp(backend_port) config = config_factory(pyserve_port, backend_port, extra_locations) pyserve = PyServeTestServer(config) await backend.start() await pyserve.start() try: yield pyserve, backend finally: await pyserve.stop() await backend.stop() # ============== Tests ============== @pytest.mark.asyncio async def test_basic_proxy_get(backend_port, pyserve_port, create_proxy_config): """Test basic GET request proxying.""" async with running_servers(pyserve_port, backend_port, create_proxy_config): async with httpx.AsyncClient() as client: response = await client.get(f"http://127.0.0.1:{pyserve_port}/api/v1/users") assert response.status_code == 200 data = response.json() assert "users" in data assert data["version"] == "1" @pytest.mark.asyncio async def test_proxy_api_versions(backend_port, pyserve_port, create_proxy_config): """Test that API version is correctly captured and passed.""" async with running_servers(pyserve_port, backend_port, create_proxy_config): async with httpx.AsyncClient() as client: # Test v1 response_v1 = await client.get(f"http://127.0.0.1:{pyserve_port}/api/v1/data") assert response_v1.status_code == 200 assert response_v1.json()["version"] == "1" # Test v2 response_v2 = await client.get(f"http://127.0.0.1:{pyserve_port}/api/v2/data") assert response_v2.status_code == 200 assert response_v2.json()["version"] == "2" @pytest.mark.asyncio async def test_proxy_post_with_body(backend_port, pyserve_port, create_proxy_config): """Test POST request with JSON body.""" async with running_servers(pyserve_port, backend_port, create_proxy_config): async with httpx.AsyncClient() as client: payload = {"name": "New User", "email": "test@example.com"} response = await client.post( f"http://127.0.0.1:{pyserve_port}/api/v1/users", json=payload ) assert response.status_code == 201 data = response.json() assert data["action"] == "create_user" assert data["data"]["name"] == "New User" @pytest.mark.asyncio async def test_proxy_put_request(backend_port, pyserve_port, create_proxy_config): """Test PUT request proxying.""" async with running_servers(pyserve_port, backend_port, create_proxy_config): async with httpx.AsyncClient() as client: payload = {"name": "Updated User"} response = await client.put( f"http://127.0.0.1:{pyserve_port}/api/v1/users/123", json=payload ) assert response.status_code == 200 data = response.json() assert data["action"] == "update_user" assert data["user_id"] == "123" @pytest.mark.asyncio async def test_proxy_delete_request(backend_port, pyserve_port, create_proxy_config): """Test DELETE request proxying.""" async with running_servers(pyserve_port, backend_port, create_proxy_config): async with httpx.AsyncClient() as client: response = await client.delete( f"http://127.0.0.1:{pyserve_port}/api/v1/users/456" ) assert response.status_code == 200 data = response.json() assert data["action"] == "delete_user" assert data["user_id"] == "456" @pytest.mark.asyncio async def test_proxy_headers_forwarding(backend_port, pyserve_port, create_proxy_config): """Test that headers are correctly forwarded.""" async with running_servers(pyserve_port, backend_port, create_proxy_config): async with httpx.AsyncClient() as client: response = await client.get( f"http://127.0.0.1:{pyserve_port}/headers", headers={"X-Custom-Header": "test-value"} ) assert response.status_code == 200 data = response.json() # Check that custom header was forwarded assert data["received_headers"].get("x-custom-header") == "test-value" # Check that X-Forwarded headers were added assert "x-forwarded-for" in data["received_headers"] assert "x-forwarded-proto" in data["received_headers"] @pytest.mark.asyncio async def test_proxy_echo_endpoint(backend_port, pyserve_port, create_proxy_config): """Test echo endpoint that returns request body.""" async with running_servers(pyserve_port, backend_port, create_proxy_config): async with httpx.AsyncClient() as client: test_data = "Hello, Proxy!" response = await client.post( f"http://127.0.0.1:{pyserve_port}/echo", content=test_data, headers={"Content-Type": "text/plain"} ) assert response.status_code == 200 assert response.text == test_data @pytest.mark.asyncio async def test_proxy_json_endpoint(backend_port, pyserve_port, create_proxy_config): """Test JSON processing through proxy.""" async with running_servers(pyserve_port, backend_port, create_proxy_config): async with httpx.AsyncClient() as client: payload = {"key": "value", "numbers": [1, 2, 3]} response = await client.post( f"http://127.0.0.1:{pyserve_port}/json", json=payload ) assert response.status_code == 200 data = response.json() assert data["processed"] is True assert data["received"]["key"] == "value" @pytest.mark.asyncio async def test_proxy_status_codes(backend_port, pyserve_port, create_proxy_config): """Test that various status codes are correctly proxied.""" async with running_servers(pyserve_port, backend_port, create_proxy_config): async with httpx.AsyncClient() as client: for status in [200, 201, 400, 404, 500]: response = await client.get( f"http://127.0.0.1:{pyserve_port}/status/{status}" ) assert response.status_code == status @pytest.mark.asyncio async def test_proxy_query_params(backend_port, pyserve_port, create_proxy_config): """Test that query parameters are passed through.""" async with running_servers(pyserve_port, backend_port, create_proxy_config): async with httpx.AsyncClient() as client: response = await client.get( f"http://127.0.0.1:{pyserve_port}/slow?delay=0.1" ) assert response.status_code == 200 data = response.json() assert data["delay"] == 0.1 @pytest.mark.asyncio async def test_proxy_backend_unavailable(pyserve_port, create_proxy_config): """Test handling when backend is unavailable (502 Bad Gateway).""" # Use a port where nothing is listening unavailable_port = get_free_port() config = create_proxy_config(pyserve_port, unavailable_port) pyserve = PyServeTestServer(config) await pyserve.start() try: async with httpx.AsyncClient() as client: response = await client.get( f"http://127.0.0.1:{pyserve_port}/api/v1/users" ) assert response.status_code == 502 assert "Bad Gateway" in response.text finally: await pyserve.stop() @pytest.mark.asyncio async def test_proxy_timeout(backend_port, pyserve_port, create_proxy_config): """Test proxy timeout handling (504 Gateway Timeout).""" # Create config with very short timeout extra_locations = { "=/timeout-test": { "proxy_pass": f"http://127.0.0.1:{backend_port}/slow?delay=5", "timeout": 0.5, # Very short timeout } } async with running_servers(pyserve_port, backend_port, create_proxy_config, extra_locations): async with httpx.AsyncClient(timeout=10.0) as client: response = await client.get( f"http://127.0.0.1:{pyserve_port}/timeout-test" ) assert response.status_code == 504 assert "Gateway Timeout" in response.text @pytest.mark.asyncio async def test_proxy_health_check_not_proxied(backend_port, pyserve_port, create_proxy_config): """Test that health check endpoint is handled locally, not proxied.""" async with running_servers(pyserve_port, backend_port, create_proxy_config): async with httpx.AsyncClient() as client: response = await client.get( f"http://127.0.0.1:{pyserve_port}/health" ) assert response.status_code == 200 assert response.text == "OK" @pytest.mark.asyncio async def test_proxy_default_fallback(backend_port, pyserve_port, create_proxy_config): """Test default fallback for unmatched routes.""" async with running_servers(pyserve_port, backend_port, create_proxy_config): async with httpx.AsyncClient() as client: response = await client.get( f"http://127.0.0.1:{pyserve_port}/nonexistent/path" ) assert response.status_code == 404 assert "Not Found" in response.text @pytest.mark.asyncio async def test_complex_config_multiple_proxies(backend_port, pyserve_port, create_proxy_config): """Test complex configuration with multiple proxy rules.""" # Create a second backend port for testing multiple backends backend2_port = get_free_port() backend2 = BackendTestApp(backend2_port) extra_locations = { "~^/backend2/": { "proxy_pass": f"http://127.0.0.1:{backend2_port}", } } async with running_servers(pyserve_port, backend_port, create_proxy_config, extra_locations): await backend2.start() try: async with httpx.AsyncClient() as client: # Request to first backend response1 = await client.get( f"http://127.0.0.1:{pyserve_port}/api/v1/users" ) assert response1.status_code == 200 assert response1.json()["users"][0]["id"] == 1 # Request to second backend response2 = await client.get( f"http://127.0.0.1:{pyserve_port}/backend2/" ) assert response2.status_code == 200 data2 = response2.json() assert data2["port"] == backend2_port finally: await backend2.stop() @pytest.mark.asyncio async def test_proxy_large_request_body(backend_port, pyserve_port, create_proxy_config): """Test proxying large request bodies.""" async with running_servers(pyserve_port, backend_port, create_proxy_config): async with httpx.AsyncClient() as client: # Create a large payload large_data = "x" * 100000 # 100KB of data response = await client.post( f"http://127.0.0.1:{pyserve_port}/echo", content=large_data, headers={"Content-Type": "text/plain"} ) assert response.status_code == 200 assert response.text == large_data @pytest.mark.asyncio async def test_proxy_content_type_preservation(backend_port, pyserve_port, create_proxy_config): """Test that content-type is correctly preserved.""" async with running_servers(pyserve_port, backend_port, create_proxy_config): async with httpx.AsyncClient() as client: response = await client.get( f"http://127.0.0.1:{pyserve_port}/api/v1/users" ) assert response.status_code == 200 assert "application/json" in response.headers.get("content-type", "") @pytest.mark.asyncio async def test_proxy_concurrent_requests(backend_port, pyserve_port, create_proxy_config): """Test handling multiple concurrent proxy requests.""" async with running_servers(pyserve_port, backend_port, create_proxy_config): async with httpx.AsyncClient() as client: # Send multiple concurrent requests tasks = [ client.get(f"http://127.0.0.1:{pyserve_port}/api/v{i % 3 + 1}/users") for i in range(10) ] responses = await asyncio.gather(*tasks) # All should succeed for response in responses: assert response.status_code == 200 assert "users" in response.json() @pytest.mark.asyncio async def test_server_default_proxy_timeout(backend_port, pyserve_port): """Test that server-level proxy_timeout is used as default.""" locations = { "~^/slow": { "proxy_pass": f"http://127.0.0.1:{backend_port}/slow", # No timeout specified - should use server default }, "=/health": { "return": "200 OK", "content_type": "text/plain" }, } config = Config( http=HttpConfig(static_dir="./static", templates_dir="./templates"), server=ServerConfig(host="127.0.0.1", port=pyserve_port, proxy_timeout=0.5), # Very short timeout logging=LoggingConfig(level="ERROR", console_output=False), extensions=[ ExtensionConfig( type="routing", config={"regex_locations": locations} ) ] ) backend = BackendTestApp(backend_port) pyserve = PyServeTestServer(config) await backend.start() await pyserve.start() try: async with httpx.AsyncClient(timeout=10.0) as client: # This should timeout because server default is 0.5s and slow endpoint takes 2s response = await client.get(f"http://127.0.0.1:{pyserve_port}/slow?delay=2") assert response.status_code == 504 assert "Gateway Timeout" in response.text finally: await pyserve.stop() await backend.stop() @pytest.mark.asyncio async def test_route_timeout_overrides_server_default(backend_port, pyserve_port): """Test that route-level timeout overrides server default proxy_timeout.""" locations = { "~^/slow": { "proxy_pass": f"http://127.0.0.1:{backend_port}/slow", "timeout": 5.0, # Route-level timeout overrides server default }, "=/health": { "return": "200 OK", "content_type": "text/plain" }, } config = Config( http=HttpConfig(static_dir="./static", templates_dir="./templates"), server=ServerConfig(host="127.0.0.1", port=pyserve_port, proxy_timeout=0.1), # Very short server default logging=LoggingConfig(level="ERROR", console_output=False), extensions=[ ExtensionConfig( type="routing", config={"regex_locations": locations} ) ] ) backend = BackendTestApp(backend_port) pyserve = PyServeTestServer(config) await backend.start() await pyserve.start() try: async with httpx.AsyncClient(timeout=10.0) as client: # This should succeed because route timeout (5s) > delay (0.5s), even though server default is 0.1s response = await client.get(f"http://127.0.0.1:{pyserve_port}/slow?delay=0.5") assert response.status_code == 200 data = response.json() assert data["delay"] == 0.5 finally: await pyserve.stop() await backend.stop() @pytest.mark.asyncio async def test_proxy_timeout_config_parsing(): """Test that proxy_timeout is correctly parsed from config.""" from pyserve.config import Config # Test default value config = Config() assert config.server.proxy_timeout == 30.0 # Test custom value from dict config_with_timeout = Config._from_dict({ "server": { "host": "127.0.0.1", "port": 8080, "proxy_timeout": 60.0 } }) assert config_with_timeout.server.proxy_timeout == 60.0