reverse proxy added

added tests for reverse proxy too
This commit is contained in:
Илья Глазунов 2025-12-03 00:05:11 +03:00
parent e2646a752a
commit 5262c5e1fb
7 changed files with 899 additions and 8 deletions

84
poetry.lock generated
View File

@ -65,6 +65,18 @@ d = ["aiohttp (>=3.10)"]
jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"]
uvloop = ["uvloop (>=0.15.2)"]
[[package]]
name = "certifi"
version = "2025.11.12"
description = "Python package for providing Mozilla's CA Bundle."
optional = false
python-versions = ">=3.7"
groups = ["main"]
files = [
{file = "certifi-2025.11.12-py3-none-any.whl", hash = "sha256:97de8790030bbd5c2d96b7ec782fc2f7820ef8dba6db909ccf95449f2d062d4b"},
{file = "certifi-2025.11.12.tar.gz", hash = "sha256:d8ab5478f2ecd78af242878415affce761ca6bc54a22a27e026d7c25357c3316"},
]
[[package]]
name = "click"
version = "8.2.1"
@ -223,6 +235,28 @@ files = [
{file = "h11-0.16.0.tar.gz", hash = "sha256:4e35b956cf45792e4caa5885e69fba00bdbc6ffafbfa020300e549b208ee5ff1"},
]
[[package]]
name = "httpcore"
version = "1.0.9"
description = "A minimal low-level HTTP client."
optional = false
python-versions = ">=3.8"
groups = ["main"]
files = [
{file = "httpcore-1.0.9-py3-none-any.whl", hash = "sha256:2d400746a40668fc9dec9810239072b40b4484b640a8c38fd654a024c7a1bf55"},
{file = "httpcore-1.0.9.tar.gz", hash = "sha256:6e34463af53fd2ab5d807f399a9b45ea31c3dfa2276f15a2c3f00afff6e176e8"},
]
[package.dependencies]
certifi = "*"
h11 = ">=0.16"
[package.extras]
asyncio = ["anyio (>=4.0,<5.0)"]
http2 = ["h2 (>=3,<5)"]
socks = ["socksio (==1.*)"]
trio = ["trio (>=0.22.0,<1.0)"]
[[package]]
name = "httptools"
version = "0.6.4"
@ -279,6 +313,32 @@ files = [
[package.extras]
test = ["Cython (>=0.29.24)"]
[[package]]
name = "httpx"
version = "0.27.2"
description = "The next generation HTTP client."
optional = false
python-versions = ">=3.8"
groups = ["main"]
files = [
{file = "httpx-0.27.2-py3-none-any.whl", hash = "sha256:7bb2708e112d8fdd7829cd4243970f0c223274051cb35ee80c03301ee29a3df0"},
{file = "httpx-0.27.2.tar.gz", hash = "sha256:f7c2be1d2f3c3c3160d441802406b206c2b76f5947b11115e6df10c6c65e66c2"},
]
[package.dependencies]
anyio = "*"
certifi = "*"
httpcore = "==1.*"
idna = "*"
sniffio = "*"
[package.extras]
brotli = ["brotli ; platform_python_implementation == \"CPython\"", "brotlicffi ; platform_python_implementation != \"CPython\""]
cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"]
http2 = ["h2 (>=3,<5)"]
socks = ["socksio (==1.*)"]
zstd = ["zstandard (>=0.18.0)"]
[[package]]
name = "idna"
version = "3.10"
@ -524,6 +584,26 @@ pygments = ">=2.7.2"
[package.extras]
dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests", "setuptools", "xmlschema"]
[[package]]
name = "pytest-asyncio"
version = "1.3.0"
description = "Pytest support for asyncio"
optional = false
python-versions = ">=3.10"
groups = ["main", "dev"]
files = [
{file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"},
{file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"},
]
[package.dependencies]
pytest = ">=8.2,<10"
typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""}
[package.extras]
docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"]
testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"]
[[package]]
name = "pytest-cov"
version = "6.2.1"
@ -967,9 +1047,9 @@ files = [
]
[extras]
dev = ["black", "flake8", "isort", "mypy", "pytest", "pytest-cov"]
dev = ["black", "flake8", "isort", "mypy", "pytest", "pytest-asyncio", "pytest-cov"]
[metadata]
lock-version = "2.1"
python-versions = ">=3.12"
content-hash = "5eda39db8e3d119d03c8e6083d1f9cd14691669a7130fb17b1445a0dd7bb79e7"
content-hash = "e68108657ddfdc07ac0c4f5dbd9c5d2950e78b8b0053e4487ebf2327bbf4e020"

View File

@ -14,6 +14,7 @@ dependencies = [
"pyyaml (>=6.0,<7.0)",
"types-pyyaml (>=6.0.12.20250822,<7.0.0.0)",
"structlog (>=25.4.0,<26.0.0)",
"httpx (>=0.27.0,<0.28.0)",
]
[project.scripts]
@ -23,6 +24,7 @@ pyserve = "pyserve.cli:main"
dev = [
"pytest",
"pytest-cov",
"pytest-asyncio",
"black",
"isort",
"mypy",
@ -73,4 +75,5 @@ black = "^25.1.0"
isort = "^6.0.1"
mypy = "^1.17.1"
flake8 = "^7.3.0"
pytest-asyncio = "^1.3.0"

View File

@ -18,6 +18,7 @@ class ServerConfig:
port: int = 8080
backlog: int = 5
default_root: bool = False
proxy_timeout: float = 30.0
redirect_instructions: Dict[str, str] = field(default_factory=dict)
@ -112,6 +113,7 @@ class Config:
port=server_data.get('port', config.server.port),
backlog=server_data.get('backlog', config.server.backlog),
default_root=server_data.get('default_root', config.server.default_root),
proxy_timeout=server_data.get('proxy_timeout', config.server.proxy_timeout),
redirect_instructions=server_data.get('redirect_instructions', {})
)

View File

@ -33,9 +33,10 @@ class RoutingExtension(Extension):
from .routing import create_router_from_config
regex_locations = config.get("regex_locations", {})
default_proxy_timeout = config.get("default_proxy_timeout", 30.0)
self.router = create_router_from_config(regex_locations)
from .routing import RequestHandler
self.handler = RequestHandler(self.router)
self.handler = RequestHandler(self.router, default_proxy_timeout=default_proxy_timeout)
async def process_request(self, request: Request) -> Optional[Response]:
try:

View File

@ -2,6 +2,8 @@ import re
import mimetypes
from pathlib import Path
from typing import Dict, Any, Optional, Pattern
from urllib.parse import urlparse
import httpx
from starlette.requests import Request
from starlette.responses import Response, FileResponse, PlainTextResponse
from .logging_utils import get_logger
@ -63,9 +65,10 @@ class Router:
class RequestHandler:
def __init__(self, router: Router, static_dir: str = "./static"):
def __init__(self, router: Router, static_dir: str = "./static", default_proxy_timeout: float = 30.0):
self.router = router
self.static_dir = Path(static_dir)
self.default_proxy_timeout = default_proxy_timeout
async def handle(self, request: Request) -> Response:
path = request.url.path
@ -166,15 +169,89 @@ class RequestHandler:
async def _handle_proxy(self, request: Request, config: Dict[str, Any],
params: Dict[str, str]) -> Response:
# TODO: implement real proxying
proxy_url = config["proxy_pass"]
for key, value in params.items():
proxy_url = proxy_url.replace(f"{{{key}}}", value)
logger.info(f"Proxying request to: {proxy_url}")
parsed_proxy = urlparse(proxy_url)
return PlainTextResponse(f"Proxy to: {proxy_url}", status_code=200)
original_path = request.url.path
if parsed_proxy.path and parsed_proxy.path not in ("/", ""):
target_url = proxy_url
else:
base_url = f"{parsed_proxy.scheme}://{parsed_proxy.netloc}"
target_url = f"{base_url}{original_path}"
if request.url.query:
separator = "&" if "?" in target_url else "?"
target_url = f"{target_url}{separator}{request.url.query}"
logger.info(f"Proxying request to: {target_url}")
proxy_headers = dict(request.headers)
hop_by_hop_headers = [
"connection", "keep-alive", "proxy-authenticate",
"proxy-authorization", "te", "trailers", "transfer-encoding",
"upgrade", "host"
]
for header in hop_by_hop_headers:
proxy_headers.pop(header, None)
client_ip = request.client.host if request.client else "unknown"
proxy_headers["X-Forwarded-For"] = client_ip
proxy_headers["X-Forwarded-Proto"] = request.url.scheme
proxy_headers["X-Forwarded-Host"] = request.headers.get("host", "")
proxy_headers["X-Real-IP"] = client_ip
proxy_headers["Host"] = parsed_proxy.netloc
if "headers" in config:
for header in config["headers"]:
if ":" in header:
name, value = header.split(":", 1)
value = value.strip()
for key, param_value in params.items():
value = value.replace(f"{{{key}}}", param_value)
value = value.replace("$remote_addr", client_ip)
proxy_headers[name.strip()] = value
body = await request.body()
timeout = config.get("timeout", self.default_proxy_timeout)
try:
async with httpx.AsyncClient(timeout=timeout, follow_redirects=False) as client:
proxy_response = await client.request(
method=request.method,
url=target_url,
headers=proxy_headers,
content=body if body else None,
)
response_headers = dict(proxy_response.headers)
for header in hop_by_hop_headers:
response_headers.pop(header, None)
return Response(
content=proxy_response.content,
status_code=proxy_response.status_code,
headers=response_headers,
media_type=proxy_response.headers.get("content-type"),
)
except httpx.ConnectError as e:
logger.error(f"Proxy connection error to {target_url}: {e}")
return PlainTextResponse("502 Bad Gateway", status_code=502)
except httpx.TimeoutException as e:
logger.error(f"Proxy timeout to {target_url}: {e}")
return PlainTextResponse("504 Gateway Timeout", status_code=504)
except Exception as e:
logger.error(f"Proxy error to {target_url}: {e}")
return PlainTextResponse("502 Bad Gateway", status_code=502)
def create_router_from_config(regex_locations: Dict[str, Dict[str, Any]]) -> Router:

View File

@ -76,9 +76,13 @@ class PyServeServer:
def _load_extensions(self) -> None:
for ext_config in self.config.extensions:
config = ext_config.config.copy()
if ext_config.type == "routing":
config.setdefault("default_proxy_timeout", self.config.server.proxy_timeout)
self.extension_manager.load_extension(
ext_config.type,
ext_config.config
config
)
def _create_app(self) -> None:

724
tests/test_reverse_proxy.py Normal file
View File

@ -0,0 +1,724 @@
"""
Tests for reverse proxy functionality.
These tests start a backend test server and the main PyServe server,
then verify that requests are correctly proxied between them.
"""
import asyncio
import pytest
import httpx
import socket
import time
from contextlib import asynccontextmanager
from typing import AsyncGenerator, Dict, Any, Optional
from unittest.mock import patch
import uvicorn
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.responses import JSONResponse, PlainTextResponse, Response
from starlette.routing import Route
from pyserve.config import Config, ServerConfig, HttpConfig, LoggingConfig, ExtensionConfig
from pyserve.server import PyServeServer
def get_free_port() -> int:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(('', 0))
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
return s.getsockname()[1]
# ============== Backend Test Application ==============
class BackendTestApp:
def __init__(self, port: int):
self.port = port
self.request_log: list[Dict[str, Any]] = []
self.app = self._create_app()
self._server_task = None
def _create_app(self) -> Starlette:
routes = [
Route("/", self._root, methods=["GET"]),
Route("/api/v{version}/users", self._users, methods=["GET", "POST"]),
Route("/api/v{version}/users/{user_id}", self._user_detail, methods=["GET", "PUT", "DELETE"]),
Route("/api/v{version}/data", self._data, methods=["GET", "POST"]),
Route("/echo", self._echo, methods=["GET", "POST", "PUT", "DELETE", "PATCH"]),
Route("/headers", self._headers, methods=["GET"]),
Route("/slow", self._slow, methods=["GET"]),
Route("/status/{code:int}", self._status, methods=["GET"]),
Route("/json", self._json, methods=["POST"]),
Route("/backend2/", self._backend2_root, methods=["GET"]),
Route("/{path:path}", self._catch_all, methods=["GET", "POST", "PUT", "DELETE", "PATCH"]),
]
return Starlette(routes=routes)
async def _root(self, request: Request) -> Response:
self._log_request(request)
return JSONResponse({"message": "Backend root", "port": self.port})
async def _users(self, request: Request) -> Response:
self._log_request(request)
version = request.path_params.get("version", "1")
if request.method == "POST":
body = await request.json()
return JSONResponse({"action": "create_user", "version": version, "data": body}, status_code=201)
return JSONResponse({"users": [{"id": 1, "name": "Test User"}], "version": version})
async def _user_detail(self, request: Request) -> Response:
self._log_request(request)
version = request.path_params.get("version", "1")
user_id = request.path_params.get("user_id", "0")
if request.method == "DELETE":
return JSONResponse({"action": "delete_user", "user_id": user_id, "version": version})
if request.method == "PUT":
body = await request.json()
return JSONResponse({"action": "update_user", "user_id": user_id, "version": version, "data": body})
return JSONResponse({"user": {"id": user_id, "name": f"User {user_id}"}, "version": version})
async def _data(self, request: Request) -> Response:
self._log_request(request)
version = request.path_params.get("version", "1")
return JSONResponse({"data": "test data", "version": version})
async def _echo(self, request: Request) -> Response:
self._log_request(request)
body = await request.body()
return Response(
content=body,
status_code=200,
media_type=request.headers.get("content-type", "text/plain")
)
async def _headers(self, request: Request) -> Response:
self._log_request(request)
headers = dict(request.headers)
return JSONResponse({
"received_headers": headers,
"client_ip": request.client.host if request.client else None,
})
async def _slow(self, request: Request) -> Response:
self._log_request(request)
delay = float(request.query_params.get("delay", "2"))
await asyncio.sleep(delay)
return JSONResponse({"message": "slow response", "delay": delay})
async def _status(self, request: Request) -> Response:
self._log_request(request)
code = request.path_params.get("code", 200)
return PlainTextResponse(f"Status: {code}", status_code=code)
async def _json(self, request: Request) -> Response:
self._log_request(request)
body = await request.json()
return JSONResponse({"received": body, "processed": True})
async def _backend2_root(self, request: Request) -> Response:
self._log_request(request)
return JSONResponse({"message": "Backend2 root", "port": self.port})
async def _catch_all(self, request: Request) -> Response:
"""Catch-all handler for debugging unmatched routes."""
self._log_request(request)
return JSONResponse({
"message": "Catch-all",
"path": str(request.url.path),
"method": request.method,
"port": self.port
})
def _log_request(self, request: Request) -> None:
self.request_log.append({
"method": request.method,
"path": str(request.url.path),
"query": str(request.url.query),
"headers": dict(request.headers),
})
async def start(self) -> None:
config = uvicorn.Config(
app=self.app,
host="127.0.0.1",
port=self.port,
log_level="critical",
access_log=False,
)
server = uvicorn.Server(config)
self._server_task = asyncio.create_task(server.serve())
# Wait for server to be ready
for _ in range(50): # 5 seconds max
try:
async with httpx.AsyncClient() as client:
await client.get(f"http://127.0.0.1:{self.port}/")
return
except httpx.ConnectError:
await asyncio.sleep(0.1)
raise RuntimeError(f"Backend server failed to start on port {self.port}")
async def stop(self) -> None:
if self._server_task:
self._server_task.cancel()
try:
await self._server_task
except asyncio.CancelledError:
pass
# ============== PyServe Test Server ==============
class PyServeTestServer:
def __init__(self, config: Config):
self.config = config
self.server = PyServeServer(config)
self._server_task = None
async def start(self) -> None:
assert self.server.app is not None, "Server app not initialized"
config = uvicorn.Config(
app=self.server.app,
host=self.config.server.host,
port=self.config.server.port,
log_level="critical",
access_log=False,
)
server = uvicorn.Server(config)
self._server_task = asyncio.create_task(server.serve())
# Wait for server to be ready
for _ in range(50): # 5 seconds max
try:
async with httpx.AsyncClient() as client:
await client.get(f"http://127.0.0.1:{self.config.server.port}/health")
return
except httpx.ConnectError:
await asyncio.sleep(0.1)
raise RuntimeError(f"PyServe server failed to start on port {self.config.server.port}")
async def stop(self) -> None:
if self._server_task:
self._server_task.cancel()
try:
await self._server_task
except asyncio.CancelledError:
pass
# ============== Fixtures ==============
@pytest.fixture
def backend_port() -> int:
return get_free_port()
@pytest.fixture
def pyserve_port() -> int:
return get_free_port()
@pytest.fixture
def create_proxy_config():
def _create_config(pyserve_port: int, backend_port: int, extra_locations: Optional[Dict[str, Any]] = None) -> Config:
locations = {
# API proxy with version capture
"~^/api/v(?P<version>\\d+)/": {
"proxy_pass": f"http://127.0.0.1:{backend_port}",
"headers": [
"X-API-Version: {version}",
"X-Forwarded-For: $remote_addr"
]
},
# Echo endpoint proxy
"=/echo": {
"proxy_pass": f"http://127.0.0.1:{backend_port}/echo",
},
# Headers test proxy
"=/headers": {
"proxy_pass": f"http://127.0.0.1:{backend_port}/headers",
},
# Slow endpoint proxy with timeout
"~^/slow": {
"proxy_pass": f"http://127.0.0.1:{backend_port}/slow",
"timeout": 5.0,
},
# Status endpoint proxy
"~^/status/": {
"proxy_pass": f"http://127.0.0.1:{backend_port}",
},
# JSON endpoint proxy
"=/json": {
"proxy_pass": f"http://127.0.0.1:{backend_port}/json",
},
# Health check
"=/health": {
"return": "200 OK",
"content_type": "text/plain"
},
# Default fallback
"__default__": {
"return": "404 Not Found",
"content_type": "text/plain"
}
}
if extra_locations:
locations.update(extra_locations)
config = Config(
http=HttpConfig(static_dir="./static", templates_dir="./templates"),
server=ServerConfig(host="127.0.0.1", port=pyserve_port),
logging=LoggingConfig(level="ERROR", console_output=False),
extensions=[
ExtensionConfig(
type="routing",
config={"regex_locations": locations}
)
]
)
return config
return _create_config
@asynccontextmanager
async def running_servers(
pyserve_port: int,
backend_port: int,
config_factory,
extra_locations: Optional[Dict[str, Any]] = None
) -> AsyncGenerator[tuple[PyServeTestServer, BackendTestApp], None]:
backend = BackendTestApp(backend_port)
config = config_factory(pyserve_port, backend_port, extra_locations)
pyserve = PyServeTestServer(config)
await backend.start()
await pyserve.start()
try:
yield pyserve, backend
finally:
await pyserve.stop()
await backend.stop()
# ============== Tests ==============
@pytest.mark.asyncio
async def test_basic_proxy_get(backend_port, pyserve_port, create_proxy_config):
"""Test basic GET request proxying."""
async with running_servers(pyserve_port, backend_port, create_proxy_config):
async with httpx.AsyncClient() as client:
response = await client.get(f"http://127.0.0.1:{pyserve_port}/api/v1/users")
assert response.status_code == 200
data = response.json()
assert "users" in data
assert data["version"] == "1"
@pytest.mark.asyncio
async def test_proxy_api_versions(backend_port, pyserve_port, create_proxy_config):
"""Test that API version is correctly captured and passed."""
async with running_servers(pyserve_port, backend_port, create_proxy_config):
async with httpx.AsyncClient() as client:
# Test v1
response_v1 = await client.get(f"http://127.0.0.1:{pyserve_port}/api/v1/data")
assert response_v1.status_code == 200
assert response_v1.json()["version"] == "1"
# Test v2
response_v2 = await client.get(f"http://127.0.0.1:{pyserve_port}/api/v2/data")
assert response_v2.status_code == 200
assert response_v2.json()["version"] == "2"
@pytest.mark.asyncio
async def test_proxy_post_with_body(backend_port, pyserve_port, create_proxy_config):
"""Test POST request with JSON body."""
async with running_servers(pyserve_port, backend_port, create_proxy_config):
async with httpx.AsyncClient() as client:
payload = {"name": "New User", "email": "test@example.com"}
response = await client.post(
f"http://127.0.0.1:{pyserve_port}/api/v1/users",
json=payload
)
assert response.status_code == 201
data = response.json()
assert data["action"] == "create_user"
assert data["data"]["name"] == "New User"
@pytest.mark.asyncio
async def test_proxy_put_request(backend_port, pyserve_port, create_proxy_config):
"""Test PUT request proxying."""
async with running_servers(pyserve_port, backend_port, create_proxy_config):
async with httpx.AsyncClient() as client:
payload = {"name": "Updated User"}
response = await client.put(
f"http://127.0.0.1:{pyserve_port}/api/v1/users/123",
json=payload
)
assert response.status_code == 200
data = response.json()
assert data["action"] == "update_user"
assert data["user_id"] == "123"
@pytest.mark.asyncio
async def test_proxy_delete_request(backend_port, pyserve_port, create_proxy_config):
"""Test DELETE request proxying."""
async with running_servers(pyserve_port, backend_port, create_proxy_config):
async with httpx.AsyncClient() as client:
response = await client.delete(
f"http://127.0.0.1:{pyserve_port}/api/v1/users/456"
)
assert response.status_code == 200
data = response.json()
assert data["action"] == "delete_user"
assert data["user_id"] == "456"
@pytest.mark.asyncio
async def test_proxy_headers_forwarding(backend_port, pyserve_port, create_proxy_config):
"""Test that headers are correctly forwarded."""
async with running_servers(pyserve_port, backend_port, create_proxy_config):
async with httpx.AsyncClient() as client:
response = await client.get(
f"http://127.0.0.1:{pyserve_port}/headers",
headers={"X-Custom-Header": "test-value"}
)
assert response.status_code == 200
data = response.json()
# Check that custom header was forwarded
assert data["received_headers"].get("x-custom-header") == "test-value"
# Check that X-Forwarded headers were added
assert "x-forwarded-for" in data["received_headers"]
assert "x-forwarded-proto" in data["received_headers"]
@pytest.mark.asyncio
async def test_proxy_echo_endpoint(backend_port, pyserve_port, create_proxy_config):
"""Test echo endpoint that returns request body."""
async with running_servers(pyserve_port, backend_port, create_proxy_config):
async with httpx.AsyncClient() as client:
test_data = "Hello, Proxy!"
response = await client.post(
f"http://127.0.0.1:{pyserve_port}/echo",
content=test_data,
headers={"Content-Type": "text/plain"}
)
assert response.status_code == 200
assert response.text == test_data
@pytest.mark.asyncio
async def test_proxy_json_endpoint(backend_port, pyserve_port, create_proxy_config):
"""Test JSON processing through proxy."""
async with running_servers(pyserve_port, backend_port, create_proxy_config):
async with httpx.AsyncClient() as client:
payload = {"key": "value", "numbers": [1, 2, 3]}
response = await client.post(
f"http://127.0.0.1:{pyserve_port}/json",
json=payload
)
assert response.status_code == 200
data = response.json()
assert data["processed"] is True
assert data["received"]["key"] == "value"
@pytest.mark.asyncio
async def test_proxy_status_codes(backend_port, pyserve_port, create_proxy_config):
"""Test that various status codes are correctly proxied."""
async with running_servers(pyserve_port, backend_port, create_proxy_config):
async with httpx.AsyncClient() as client:
for status in [200, 201, 400, 404, 500]:
response = await client.get(
f"http://127.0.0.1:{pyserve_port}/status/{status}"
)
assert response.status_code == status
@pytest.mark.asyncio
async def test_proxy_query_params(backend_port, pyserve_port, create_proxy_config):
"""Test that query parameters are passed through."""
async with running_servers(pyserve_port, backend_port, create_proxy_config):
async with httpx.AsyncClient() as client:
response = await client.get(
f"http://127.0.0.1:{pyserve_port}/slow?delay=0.1"
)
assert response.status_code == 200
data = response.json()
assert data["delay"] == 0.1
@pytest.mark.asyncio
async def test_proxy_backend_unavailable(pyserve_port, create_proxy_config):
"""Test handling when backend is unavailable (502 Bad Gateway)."""
# Use a port where nothing is listening
unavailable_port = get_free_port()
config = create_proxy_config(pyserve_port, unavailable_port)
pyserve = PyServeTestServer(config)
await pyserve.start()
try:
async with httpx.AsyncClient() as client:
response = await client.get(
f"http://127.0.0.1:{pyserve_port}/api/v1/users"
)
assert response.status_code == 502
assert "Bad Gateway" in response.text
finally:
await pyserve.stop()
@pytest.mark.asyncio
async def test_proxy_timeout(backend_port, pyserve_port, create_proxy_config):
"""Test proxy timeout handling (504 Gateway Timeout)."""
# Create config with very short timeout
extra_locations = {
"=/timeout-test": {
"proxy_pass": f"http://127.0.0.1:{backend_port}/slow?delay=5",
"timeout": 0.5, # Very short timeout
}
}
async with running_servers(pyserve_port, backend_port, create_proxy_config, extra_locations):
async with httpx.AsyncClient(timeout=10.0) as client:
response = await client.get(
f"http://127.0.0.1:{pyserve_port}/timeout-test"
)
assert response.status_code == 504
assert "Gateway Timeout" in response.text
@pytest.mark.asyncio
async def test_proxy_health_check_not_proxied(backend_port, pyserve_port, create_proxy_config):
"""Test that health check endpoint is handled locally, not proxied."""
async with running_servers(pyserve_port, backend_port, create_proxy_config):
async with httpx.AsyncClient() as client:
response = await client.get(
f"http://127.0.0.1:{pyserve_port}/health"
)
assert response.status_code == 200
assert response.text == "OK"
@pytest.mark.asyncio
async def test_proxy_default_fallback(backend_port, pyserve_port, create_proxy_config):
"""Test default fallback for unmatched routes."""
async with running_servers(pyserve_port, backend_port, create_proxy_config):
async with httpx.AsyncClient() as client:
response = await client.get(
f"http://127.0.0.1:{pyserve_port}/nonexistent/path"
)
assert response.status_code == 404
assert "Not Found" in response.text
@pytest.mark.asyncio
async def test_complex_config_multiple_proxies(backend_port, pyserve_port, create_proxy_config):
"""Test complex configuration with multiple proxy rules."""
# Create a second backend port for testing multiple backends
backend2_port = get_free_port()
backend2 = BackendTestApp(backend2_port)
extra_locations = {
"~^/backend2/": {
"proxy_pass": f"http://127.0.0.1:{backend2_port}",
}
}
async with running_servers(pyserve_port, backend_port, create_proxy_config, extra_locations):
await backend2.start()
try:
async with httpx.AsyncClient() as client:
# Request to first backend
response1 = await client.get(
f"http://127.0.0.1:{pyserve_port}/api/v1/users"
)
assert response1.status_code == 200
assert response1.json()["users"][0]["id"] == 1
# Request to second backend
response2 = await client.get(
f"http://127.0.0.1:{pyserve_port}/backend2/"
)
assert response2.status_code == 200
data2 = response2.json()
assert data2["port"] == backend2_port
finally:
await backend2.stop()
@pytest.mark.asyncio
async def test_proxy_large_request_body(backend_port, pyserve_port, create_proxy_config):
"""Test proxying large request bodies."""
async with running_servers(pyserve_port, backend_port, create_proxy_config):
async with httpx.AsyncClient() as client:
# Create a large payload
large_data = "x" * 100000 # 100KB of data
response = await client.post(
f"http://127.0.0.1:{pyserve_port}/echo",
content=large_data,
headers={"Content-Type": "text/plain"}
)
assert response.status_code == 200
assert response.text == large_data
@pytest.mark.asyncio
async def test_proxy_content_type_preservation(backend_port, pyserve_port, create_proxy_config):
"""Test that content-type is correctly preserved."""
async with running_servers(pyserve_port, backend_port, create_proxy_config):
async with httpx.AsyncClient() as client:
response = await client.get(
f"http://127.0.0.1:{pyserve_port}/api/v1/users"
)
assert response.status_code == 200
assert "application/json" in response.headers.get("content-type", "")
@pytest.mark.asyncio
async def test_proxy_concurrent_requests(backend_port, pyserve_port, create_proxy_config):
"""Test handling multiple concurrent proxy requests."""
async with running_servers(pyserve_port, backend_port, create_proxy_config):
async with httpx.AsyncClient() as client:
# Send multiple concurrent requests
tasks = [
client.get(f"http://127.0.0.1:{pyserve_port}/api/v{i % 3 + 1}/users")
for i in range(10)
]
responses = await asyncio.gather(*tasks)
# All should succeed
for response in responses:
assert response.status_code == 200
assert "users" in response.json()
@pytest.mark.asyncio
async def test_server_default_proxy_timeout(backend_port, pyserve_port):
"""Test that server-level proxy_timeout is used as default."""
locations = {
"~^/slow": {
"proxy_pass": f"http://127.0.0.1:{backend_port}/slow",
# No timeout specified - should use server default
},
"=/health": {
"return": "200 OK",
"content_type": "text/plain"
},
}
config = Config(
http=HttpConfig(static_dir="./static", templates_dir="./templates"),
server=ServerConfig(host="127.0.0.1", port=pyserve_port, proxy_timeout=0.5), # Very short timeout
logging=LoggingConfig(level="ERROR", console_output=False),
extensions=[
ExtensionConfig(
type="routing",
config={"regex_locations": locations}
)
]
)
backend = BackendTestApp(backend_port)
pyserve = PyServeTestServer(config)
await backend.start()
await pyserve.start()
try:
async with httpx.AsyncClient(timeout=10.0) as client:
# This should timeout because server default is 0.5s and slow endpoint takes 2s
response = await client.get(f"http://127.0.0.1:{pyserve_port}/slow?delay=2")
assert response.status_code == 504
assert "Gateway Timeout" in response.text
finally:
await pyserve.stop()
await backend.stop()
@pytest.mark.asyncio
async def test_route_timeout_overrides_server_default(backend_port, pyserve_port):
"""Test that route-level timeout overrides server default proxy_timeout."""
locations = {
"~^/slow": {
"proxy_pass": f"http://127.0.0.1:{backend_port}/slow",
"timeout": 5.0, # Route-level timeout overrides server default
},
"=/health": {
"return": "200 OK",
"content_type": "text/plain"
},
}
config = Config(
http=HttpConfig(static_dir="./static", templates_dir="./templates"),
server=ServerConfig(host="127.0.0.1", port=pyserve_port, proxy_timeout=0.1), # Very short server default
logging=LoggingConfig(level="ERROR", console_output=False),
extensions=[
ExtensionConfig(
type="routing",
config={"regex_locations": locations}
)
]
)
backend = BackendTestApp(backend_port)
pyserve = PyServeTestServer(config)
await backend.start()
await pyserve.start()
try:
async with httpx.AsyncClient(timeout=10.0) as client:
# This should succeed because route timeout (5s) > delay (0.5s), even though server default is 0.1s
response = await client.get(f"http://127.0.0.1:{pyserve_port}/slow?delay=0.5")
assert response.status_code == 200
data = response.json()
assert data["delay"] == 0.5
finally:
await pyserve.stop()
await backend.stop()
@pytest.mark.asyncio
async def test_proxy_timeout_config_parsing():
"""Test that proxy_timeout is correctly parsed from config."""
from pyserve.config import Config
# Test default value
config = Config()
assert config.server.proxy_timeout == 30.0
# Test custom value from dict
config_with_timeout = Config._from_dict({
"server": {
"host": "127.0.0.1",
"port": 8080,
"proxy_timeout": 60.0
}
})
assert config_with_timeout.server.proxy_timeout == 60.0