reverse proxy added
added tests for reverse proxy too
This commit is contained in:
parent
e2646a752a
commit
5262c5e1fb
84
poetry.lock
generated
84
poetry.lock
generated
@ -65,6 +65,18 @@ d = ["aiohttp (>=3.10)"]
|
||||
jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"]
|
||||
uvloop = ["uvloop (>=0.15.2)"]
|
||||
|
||||
[[package]]
|
||||
name = "certifi"
|
||||
version = "2025.11.12"
|
||||
description = "Python package for providing Mozilla's CA Bundle."
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "certifi-2025.11.12-py3-none-any.whl", hash = "sha256:97de8790030bbd5c2d96b7ec782fc2f7820ef8dba6db909ccf95449f2d062d4b"},
|
||||
{file = "certifi-2025.11.12.tar.gz", hash = "sha256:d8ab5478f2ecd78af242878415affce761ca6bc54a22a27e026d7c25357c3316"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "click"
|
||||
version = "8.2.1"
|
||||
@ -223,6 +235,28 @@ files = [
|
||||
{file = "h11-0.16.0.tar.gz", hash = "sha256:4e35b956cf45792e4caa5885e69fba00bdbc6ffafbfa020300e549b208ee5ff1"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "httpcore"
|
||||
version = "1.0.9"
|
||||
description = "A minimal low-level HTTP client."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "httpcore-1.0.9-py3-none-any.whl", hash = "sha256:2d400746a40668fc9dec9810239072b40b4484b640a8c38fd654a024c7a1bf55"},
|
||||
{file = "httpcore-1.0.9.tar.gz", hash = "sha256:6e34463af53fd2ab5d807f399a9b45ea31c3dfa2276f15a2c3f00afff6e176e8"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
certifi = "*"
|
||||
h11 = ">=0.16"
|
||||
|
||||
[package.extras]
|
||||
asyncio = ["anyio (>=4.0,<5.0)"]
|
||||
http2 = ["h2 (>=3,<5)"]
|
||||
socks = ["socksio (==1.*)"]
|
||||
trio = ["trio (>=0.22.0,<1.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "httptools"
|
||||
version = "0.6.4"
|
||||
@ -279,6 +313,32 @@ files = [
|
||||
[package.extras]
|
||||
test = ["Cython (>=0.29.24)"]
|
||||
|
||||
[[package]]
|
||||
name = "httpx"
|
||||
version = "0.27.2"
|
||||
description = "The next generation HTTP client."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "httpx-0.27.2-py3-none-any.whl", hash = "sha256:7bb2708e112d8fdd7829cd4243970f0c223274051cb35ee80c03301ee29a3df0"},
|
||||
{file = "httpx-0.27.2.tar.gz", hash = "sha256:f7c2be1d2f3c3c3160d441802406b206c2b76f5947b11115e6df10c6c65e66c2"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
anyio = "*"
|
||||
certifi = "*"
|
||||
httpcore = "==1.*"
|
||||
idna = "*"
|
||||
sniffio = "*"
|
||||
|
||||
[package.extras]
|
||||
brotli = ["brotli ; platform_python_implementation == \"CPython\"", "brotlicffi ; platform_python_implementation != \"CPython\""]
|
||||
cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"]
|
||||
http2 = ["h2 (>=3,<5)"]
|
||||
socks = ["socksio (==1.*)"]
|
||||
zstd = ["zstandard (>=0.18.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "idna"
|
||||
version = "3.10"
|
||||
@ -524,6 +584,26 @@ pygments = ">=2.7.2"
|
||||
[package.extras]
|
||||
dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests", "setuptools", "xmlschema"]
|
||||
|
||||
[[package]]
|
||||
name = "pytest-asyncio"
|
||||
version = "1.3.0"
|
||||
description = "Pytest support for asyncio"
|
||||
optional = false
|
||||
python-versions = ">=3.10"
|
||||
groups = ["main", "dev"]
|
||||
files = [
|
||||
{file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"},
|
||||
{file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
pytest = ">=8.2,<10"
|
||||
typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""}
|
||||
|
||||
[package.extras]
|
||||
docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"]
|
||||
testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"]
|
||||
|
||||
[[package]]
|
||||
name = "pytest-cov"
|
||||
version = "6.2.1"
|
||||
@ -967,9 +1047,9 @@ files = [
|
||||
]
|
||||
|
||||
[extras]
|
||||
dev = ["black", "flake8", "isort", "mypy", "pytest", "pytest-cov"]
|
||||
dev = ["black", "flake8", "isort", "mypy", "pytest", "pytest-asyncio", "pytest-cov"]
|
||||
|
||||
[metadata]
|
||||
lock-version = "2.1"
|
||||
python-versions = ">=3.12"
|
||||
content-hash = "5eda39db8e3d119d03c8e6083d1f9cd14691669a7130fb17b1445a0dd7bb79e7"
|
||||
content-hash = "e68108657ddfdc07ac0c4f5dbd9c5d2950e78b8b0053e4487ebf2327bbf4e020"
|
||||
|
||||
@ -14,6 +14,7 @@ dependencies = [
|
||||
"pyyaml (>=6.0,<7.0)",
|
||||
"types-pyyaml (>=6.0.12.20250822,<7.0.0.0)",
|
||||
"structlog (>=25.4.0,<26.0.0)",
|
||||
"httpx (>=0.27.0,<0.28.0)",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
@ -23,6 +24,7 @@ pyserve = "pyserve.cli:main"
|
||||
dev = [
|
||||
"pytest",
|
||||
"pytest-cov",
|
||||
"pytest-asyncio",
|
||||
"black",
|
||||
"isort",
|
||||
"mypy",
|
||||
@ -73,4 +75,5 @@ black = "^25.1.0"
|
||||
isort = "^6.0.1"
|
||||
mypy = "^1.17.1"
|
||||
flake8 = "^7.3.0"
|
||||
pytest-asyncio = "^1.3.0"
|
||||
|
||||
|
||||
@ -18,6 +18,7 @@ class ServerConfig:
|
||||
port: int = 8080
|
||||
backlog: int = 5
|
||||
default_root: bool = False
|
||||
proxy_timeout: float = 30.0
|
||||
redirect_instructions: Dict[str, str] = field(default_factory=dict)
|
||||
|
||||
|
||||
@ -112,6 +113,7 @@ class Config:
|
||||
port=server_data.get('port', config.server.port),
|
||||
backlog=server_data.get('backlog', config.server.backlog),
|
||||
default_root=server_data.get('default_root', config.server.default_root),
|
||||
proxy_timeout=server_data.get('proxy_timeout', config.server.proxy_timeout),
|
||||
redirect_instructions=server_data.get('redirect_instructions', {})
|
||||
)
|
||||
|
||||
|
||||
@ -33,9 +33,10 @@ class RoutingExtension(Extension):
|
||||
from .routing import create_router_from_config
|
||||
|
||||
regex_locations = config.get("regex_locations", {})
|
||||
default_proxy_timeout = config.get("default_proxy_timeout", 30.0)
|
||||
self.router = create_router_from_config(regex_locations)
|
||||
from .routing import RequestHandler
|
||||
self.handler = RequestHandler(self.router)
|
||||
self.handler = RequestHandler(self.router, default_proxy_timeout=default_proxy_timeout)
|
||||
|
||||
async def process_request(self, request: Request) -> Optional[Response]:
|
||||
try:
|
||||
|
||||
@ -2,6 +2,8 @@ import re
|
||||
import mimetypes
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Optional, Pattern
|
||||
from urllib.parse import urlparse
|
||||
import httpx
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response, FileResponse, PlainTextResponse
|
||||
from .logging_utils import get_logger
|
||||
@ -63,9 +65,10 @@ class Router:
|
||||
|
||||
|
||||
class RequestHandler:
|
||||
def __init__(self, router: Router, static_dir: str = "./static"):
|
||||
def __init__(self, router: Router, static_dir: str = "./static", default_proxy_timeout: float = 30.0):
|
||||
self.router = router
|
||||
self.static_dir = Path(static_dir)
|
||||
self.default_proxy_timeout = default_proxy_timeout
|
||||
|
||||
async def handle(self, request: Request) -> Response:
|
||||
path = request.url.path
|
||||
@ -166,15 +169,89 @@ class RequestHandler:
|
||||
|
||||
async def _handle_proxy(self, request: Request, config: Dict[str, Any],
|
||||
params: Dict[str, str]) -> Response:
|
||||
# TODO: implement real proxying
|
||||
proxy_url = config["proxy_pass"]
|
||||
|
||||
for key, value in params.items():
|
||||
proxy_url = proxy_url.replace(f"{{{key}}}", value)
|
||||
|
||||
logger.info(f"Proxying request to: {proxy_url}")
|
||||
parsed_proxy = urlparse(proxy_url)
|
||||
|
||||
return PlainTextResponse(f"Proxy to: {proxy_url}", status_code=200)
|
||||
original_path = request.url.path
|
||||
|
||||
if parsed_proxy.path and parsed_proxy.path not in ("/", ""):
|
||||
target_url = proxy_url
|
||||
else:
|
||||
base_url = f"{parsed_proxy.scheme}://{parsed_proxy.netloc}"
|
||||
target_url = f"{base_url}{original_path}"
|
||||
|
||||
if request.url.query:
|
||||
separator = "&" if "?" in target_url else "?"
|
||||
target_url = f"{target_url}{separator}{request.url.query}"
|
||||
|
||||
logger.info(f"Proxying request to: {target_url}")
|
||||
|
||||
proxy_headers = dict(request.headers)
|
||||
|
||||
hop_by_hop_headers = [
|
||||
"connection", "keep-alive", "proxy-authenticate",
|
||||
"proxy-authorization", "te", "trailers", "transfer-encoding",
|
||||
"upgrade", "host"
|
||||
]
|
||||
for header in hop_by_hop_headers:
|
||||
proxy_headers.pop(header, None)
|
||||
|
||||
client_ip = request.client.host if request.client else "unknown"
|
||||
proxy_headers["X-Forwarded-For"] = client_ip
|
||||
proxy_headers["X-Forwarded-Proto"] = request.url.scheme
|
||||
proxy_headers["X-Forwarded-Host"] = request.headers.get("host", "")
|
||||
proxy_headers["X-Real-IP"] = client_ip
|
||||
|
||||
proxy_headers["Host"] = parsed_proxy.netloc
|
||||
|
||||
if "headers" in config:
|
||||
for header in config["headers"]:
|
||||
if ":" in header:
|
||||
name, value = header.split(":", 1)
|
||||
value = value.strip()
|
||||
for key, param_value in params.items():
|
||||
value = value.replace(f"{{{key}}}", param_value)
|
||||
value = value.replace("$remote_addr", client_ip)
|
||||
proxy_headers[name.strip()] = value
|
||||
|
||||
body = await request.body()
|
||||
|
||||
timeout = config.get("timeout", self.default_proxy_timeout)
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=timeout, follow_redirects=False) as client:
|
||||
proxy_response = await client.request(
|
||||
method=request.method,
|
||||
url=target_url,
|
||||
headers=proxy_headers,
|
||||
content=body if body else None,
|
||||
)
|
||||
|
||||
response_headers = dict(proxy_response.headers)
|
||||
|
||||
for header in hop_by_hop_headers:
|
||||
response_headers.pop(header, None)
|
||||
|
||||
return Response(
|
||||
content=proxy_response.content,
|
||||
status_code=proxy_response.status_code,
|
||||
headers=response_headers,
|
||||
media_type=proxy_response.headers.get("content-type"),
|
||||
)
|
||||
|
||||
except httpx.ConnectError as e:
|
||||
logger.error(f"Proxy connection error to {target_url}: {e}")
|
||||
return PlainTextResponse("502 Bad Gateway", status_code=502)
|
||||
except httpx.TimeoutException as e:
|
||||
logger.error(f"Proxy timeout to {target_url}: {e}")
|
||||
return PlainTextResponse("504 Gateway Timeout", status_code=504)
|
||||
except Exception as e:
|
||||
logger.error(f"Proxy error to {target_url}: {e}")
|
||||
return PlainTextResponse("502 Bad Gateway", status_code=502)
|
||||
|
||||
|
||||
def create_router_from_config(regex_locations: Dict[str, Dict[str, Any]]) -> Router:
|
||||
|
||||
@ -76,9 +76,13 @@ class PyServeServer:
|
||||
|
||||
def _load_extensions(self) -> None:
|
||||
for ext_config in self.config.extensions:
|
||||
config = ext_config.config.copy()
|
||||
if ext_config.type == "routing":
|
||||
config.setdefault("default_proxy_timeout", self.config.server.proxy_timeout)
|
||||
|
||||
self.extension_manager.load_extension(
|
||||
ext_config.type,
|
||||
ext_config.config
|
||||
config
|
||||
)
|
||||
|
||||
def _create_app(self) -> None:
|
||||
|
||||
724
tests/test_reverse_proxy.py
Normal file
724
tests/test_reverse_proxy.py
Normal file
@ -0,0 +1,724 @@
|
||||
"""
|
||||
Tests for reverse proxy functionality.
|
||||
|
||||
These tests start a backend test server and the main PyServe server,
|
||||
then verify that requests are correctly proxied between them.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import pytest
|
||||
import httpx
|
||||
import socket
|
||||
import time
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import AsyncGenerator, Dict, Any, Optional
|
||||
from unittest.mock import patch
|
||||
|
||||
import uvicorn
|
||||
from starlette.applications import Starlette
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse, PlainTextResponse, Response
|
||||
from starlette.routing import Route
|
||||
|
||||
from pyserve.config import Config, ServerConfig, HttpConfig, LoggingConfig, ExtensionConfig
|
||||
from pyserve.server import PyServeServer
|
||||
|
||||
|
||||
def get_free_port() -> int:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.bind(('', 0))
|
||||
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
return s.getsockname()[1]
|
||||
|
||||
|
||||
# ============== Backend Test Application ==============
|
||||
|
||||
class BackendTestApp:
|
||||
def __init__(self, port: int):
|
||||
self.port = port
|
||||
self.request_log: list[Dict[str, Any]] = []
|
||||
self.app = self._create_app()
|
||||
self._server_task = None
|
||||
|
||||
def _create_app(self) -> Starlette:
|
||||
routes = [
|
||||
Route("/", self._root, methods=["GET"]),
|
||||
Route("/api/v{version}/users", self._users, methods=["GET", "POST"]),
|
||||
Route("/api/v{version}/users/{user_id}", self._user_detail, methods=["GET", "PUT", "DELETE"]),
|
||||
Route("/api/v{version}/data", self._data, methods=["GET", "POST"]),
|
||||
Route("/echo", self._echo, methods=["GET", "POST", "PUT", "DELETE", "PATCH"]),
|
||||
Route("/headers", self._headers, methods=["GET"]),
|
||||
Route("/slow", self._slow, methods=["GET"]),
|
||||
Route("/status/{code:int}", self._status, methods=["GET"]),
|
||||
Route("/json", self._json, methods=["POST"]),
|
||||
Route("/backend2/", self._backend2_root, methods=["GET"]),
|
||||
Route("/{path:path}", self._catch_all, methods=["GET", "POST", "PUT", "DELETE", "PATCH"]),
|
||||
]
|
||||
return Starlette(routes=routes)
|
||||
|
||||
async def _root(self, request: Request) -> Response:
|
||||
self._log_request(request)
|
||||
return JSONResponse({"message": "Backend root", "port": self.port})
|
||||
|
||||
async def _users(self, request: Request) -> Response:
|
||||
self._log_request(request)
|
||||
version = request.path_params.get("version", "1")
|
||||
if request.method == "POST":
|
||||
body = await request.json()
|
||||
return JSONResponse({"action": "create_user", "version": version, "data": body}, status_code=201)
|
||||
return JSONResponse({"users": [{"id": 1, "name": "Test User"}], "version": version})
|
||||
|
||||
async def _user_detail(self, request: Request) -> Response:
|
||||
self._log_request(request)
|
||||
version = request.path_params.get("version", "1")
|
||||
user_id = request.path_params.get("user_id", "0")
|
||||
if request.method == "DELETE":
|
||||
return JSONResponse({"action": "delete_user", "user_id": user_id, "version": version})
|
||||
if request.method == "PUT":
|
||||
body = await request.json()
|
||||
return JSONResponse({"action": "update_user", "user_id": user_id, "version": version, "data": body})
|
||||
return JSONResponse({"user": {"id": user_id, "name": f"User {user_id}"}, "version": version})
|
||||
|
||||
async def _data(self, request: Request) -> Response:
|
||||
self._log_request(request)
|
||||
version = request.path_params.get("version", "1")
|
||||
return JSONResponse({"data": "test data", "version": version})
|
||||
|
||||
async def _echo(self, request: Request) -> Response:
|
||||
self._log_request(request)
|
||||
body = await request.body()
|
||||
return Response(
|
||||
content=body,
|
||||
status_code=200,
|
||||
media_type=request.headers.get("content-type", "text/plain")
|
||||
)
|
||||
|
||||
async def _headers(self, request: Request) -> Response:
|
||||
self._log_request(request)
|
||||
headers = dict(request.headers)
|
||||
return JSONResponse({
|
||||
"received_headers": headers,
|
||||
"client_ip": request.client.host if request.client else None,
|
||||
})
|
||||
|
||||
async def _slow(self, request: Request) -> Response:
|
||||
self._log_request(request)
|
||||
delay = float(request.query_params.get("delay", "2"))
|
||||
await asyncio.sleep(delay)
|
||||
return JSONResponse({"message": "slow response", "delay": delay})
|
||||
|
||||
async def _status(self, request: Request) -> Response:
|
||||
self._log_request(request)
|
||||
code = request.path_params.get("code", 200)
|
||||
return PlainTextResponse(f"Status: {code}", status_code=code)
|
||||
|
||||
async def _json(self, request: Request) -> Response:
|
||||
self._log_request(request)
|
||||
body = await request.json()
|
||||
return JSONResponse({"received": body, "processed": True})
|
||||
|
||||
async def _backend2_root(self, request: Request) -> Response:
|
||||
self._log_request(request)
|
||||
return JSONResponse({"message": "Backend2 root", "port": self.port})
|
||||
|
||||
async def _catch_all(self, request: Request) -> Response:
|
||||
"""Catch-all handler for debugging unmatched routes."""
|
||||
self._log_request(request)
|
||||
return JSONResponse({
|
||||
"message": "Catch-all",
|
||||
"path": str(request.url.path),
|
||||
"method": request.method,
|
||||
"port": self.port
|
||||
})
|
||||
|
||||
def _log_request(self, request: Request) -> None:
|
||||
self.request_log.append({
|
||||
"method": request.method,
|
||||
"path": str(request.url.path),
|
||||
"query": str(request.url.query),
|
||||
"headers": dict(request.headers),
|
||||
})
|
||||
|
||||
async def start(self) -> None:
|
||||
config = uvicorn.Config(
|
||||
app=self.app,
|
||||
host="127.0.0.1",
|
||||
port=self.port,
|
||||
log_level="critical",
|
||||
access_log=False,
|
||||
)
|
||||
server = uvicorn.Server(config)
|
||||
self._server_task = asyncio.create_task(server.serve())
|
||||
|
||||
# Wait for server to be ready
|
||||
for _ in range(50): # 5 seconds max
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
await client.get(f"http://127.0.0.1:{self.port}/")
|
||||
return
|
||||
except httpx.ConnectError:
|
||||
await asyncio.sleep(0.1)
|
||||
raise RuntimeError(f"Backend server failed to start on port {self.port}")
|
||||
|
||||
async def stop(self) -> None:
|
||||
if self._server_task:
|
||||
self._server_task.cancel()
|
||||
try:
|
||||
await self._server_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
|
||||
# ============== PyServe Test Server ==============
|
||||
|
||||
class PyServeTestServer:
|
||||
def __init__(self, config: Config):
|
||||
self.config = config
|
||||
self.server = PyServeServer(config)
|
||||
self._server_task = None
|
||||
|
||||
async def start(self) -> None:
|
||||
assert self.server.app is not None, "Server app not initialized"
|
||||
config = uvicorn.Config(
|
||||
app=self.server.app,
|
||||
host=self.config.server.host,
|
||||
port=self.config.server.port,
|
||||
log_level="critical",
|
||||
access_log=False,
|
||||
)
|
||||
server = uvicorn.Server(config)
|
||||
self._server_task = asyncio.create_task(server.serve())
|
||||
|
||||
# Wait for server to be ready
|
||||
for _ in range(50): # 5 seconds max
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
await client.get(f"http://127.0.0.1:{self.config.server.port}/health")
|
||||
return
|
||||
except httpx.ConnectError:
|
||||
await asyncio.sleep(0.1)
|
||||
raise RuntimeError(f"PyServe server failed to start on port {self.config.server.port}")
|
||||
|
||||
async def stop(self) -> None:
|
||||
if self._server_task:
|
||||
self._server_task.cancel()
|
||||
try:
|
||||
await self._server_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
|
||||
# ============== Fixtures ==============
|
||||
|
||||
@pytest.fixture
|
||||
def backend_port() -> int:
|
||||
return get_free_port()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def pyserve_port() -> int:
|
||||
return get_free_port()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def create_proxy_config():
|
||||
def _create_config(pyserve_port: int, backend_port: int, extra_locations: Optional[Dict[str, Any]] = None) -> Config:
|
||||
locations = {
|
||||
# API proxy with version capture
|
||||
"~^/api/v(?P<version>\\d+)/": {
|
||||
"proxy_pass": f"http://127.0.0.1:{backend_port}",
|
||||
"headers": [
|
||||
"X-API-Version: {version}",
|
||||
"X-Forwarded-For: $remote_addr"
|
||||
]
|
||||
},
|
||||
# Echo endpoint proxy
|
||||
"=/echo": {
|
||||
"proxy_pass": f"http://127.0.0.1:{backend_port}/echo",
|
||||
},
|
||||
# Headers test proxy
|
||||
"=/headers": {
|
||||
"proxy_pass": f"http://127.0.0.1:{backend_port}/headers",
|
||||
},
|
||||
# Slow endpoint proxy with timeout
|
||||
"~^/slow": {
|
||||
"proxy_pass": f"http://127.0.0.1:{backend_port}/slow",
|
||||
"timeout": 5.0,
|
||||
},
|
||||
# Status endpoint proxy
|
||||
"~^/status/": {
|
||||
"proxy_pass": f"http://127.0.0.1:{backend_port}",
|
||||
},
|
||||
# JSON endpoint proxy
|
||||
"=/json": {
|
||||
"proxy_pass": f"http://127.0.0.1:{backend_port}/json",
|
||||
},
|
||||
# Health check
|
||||
"=/health": {
|
||||
"return": "200 OK",
|
||||
"content_type": "text/plain"
|
||||
},
|
||||
# Default fallback
|
||||
"__default__": {
|
||||
"return": "404 Not Found",
|
||||
"content_type": "text/plain"
|
||||
}
|
||||
}
|
||||
|
||||
if extra_locations:
|
||||
locations.update(extra_locations)
|
||||
|
||||
config = Config(
|
||||
http=HttpConfig(static_dir="./static", templates_dir="./templates"),
|
||||
server=ServerConfig(host="127.0.0.1", port=pyserve_port),
|
||||
logging=LoggingConfig(level="ERROR", console_output=False),
|
||||
extensions=[
|
||||
ExtensionConfig(
|
||||
type="routing",
|
||||
config={"regex_locations": locations}
|
||||
)
|
||||
]
|
||||
)
|
||||
return config
|
||||
|
||||
return _create_config
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def running_servers(
|
||||
pyserve_port: int,
|
||||
backend_port: int,
|
||||
config_factory,
|
||||
extra_locations: Optional[Dict[str, Any]] = None
|
||||
) -> AsyncGenerator[tuple[PyServeTestServer, BackendTestApp], None]:
|
||||
backend = BackendTestApp(backend_port)
|
||||
config = config_factory(pyserve_port, backend_port, extra_locations)
|
||||
pyserve = PyServeTestServer(config)
|
||||
|
||||
await backend.start()
|
||||
await pyserve.start()
|
||||
|
||||
try:
|
||||
yield pyserve, backend
|
||||
finally:
|
||||
await pyserve.stop()
|
||||
await backend.stop()
|
||||
|
||||
|
||||
# ============== Tests ==============
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_proxy_get(backend_port, pyserve_port, create_proxy_config):
|
||||
"""Test basic GET request proxying."""
|
||||
async with running_servers(pyserve_port, backend_port, create_proxy_config):
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(f"http://127.0.0.1:{pyserve_port}/api/v1/users")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "users" in data
|
||||
assert data["version"] == "1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_proxy_api_versions(backend_port, pyserve_port, create_proxy_config):
|
||||
"""Test that API version is correctly captured and passed."""
|
||||
async with running_servers(pyserve_port, backend_port, create_proxy_config):
|
||||
async with httpx.AsyncClient() as client:
|
||||
# Test v1
|
||||
response_v1 = await client.get(f"http://127.0.0.1:{pyserve_port}/api/v1/data")
|
||||
assert response_v1.status_code == 200
|
||||
assert response_v1.json()["version"] == "1"
|
||||
|
||||
# Test v2
|
||||
response_v2 = await client.get(f"http://127.0.0.1:{pyserve_port}/api/v2/data")
|
||||
assert response_v2.status_code == 200
|
||||
assert response_v2.json()["version"] == "2"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_proxy_post_with_body(backend_port, pyserve_port, create_proxy_config):
|
||||
"""Test POST request with JSON body."""
|
||||
async with running_servers(pyserve_port, backend_port, create_proxy_config):
|
||||
async with httpx.AsyncClient() as client:
|
||||
payload = {"name": "New User", "email": "test@example.com"}
|
||||
response = await client.post(
|
||||
f"http://127.0.0.1:{pyserve_port}/api/v1/users",
|
||||
json=payload
|
||||
)
|
||||
|
||||
assert response.status_code == 201
|
||||
data = response.json()
|
||||
assert data["action"] == "create_user"
|
||||
assert data["data"]["name"] == "New User"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_proxy_put_request(backend_port, pyserve_port, create_proxy_config):
|
||||
"""Test PUT request proxying."""
|
||||
async with running_servers(pyserve_port, backend_port, create_proxy_config):
|
||||
async with httpx.AsyncClient() as client:
|
||||
payload = {"name": "Updated User"}
|
||||
response = await client.put(
|
||||
f"http://127.0.0.1:{pyserve_port}/api/v1/users/123",
|
||||
json=payload
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["action"] == "update_user"
|
||||
assert data["user_id"] == "123"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_proxy_delete_request(backend_port, pyserve_port, create_proxy_config):
|
||||
"""Test DELETE request proxying."""
|
||||
async with running_servers(pyserve_port, backend_port, create_proxy_config):
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.delete(
|
||||
f"http://127.0.0.1:{pyserve_port}/api/v1/users/456"
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["action"] == "delete_user"
|
||||
assert data["user_id"] == "456"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_proxy_headers_forwarding(backend_port, pyserve_port, create_proxy_config):
|
||||
"""Test that headers are correctly forwarded."""
|
||||
async with running_servers(pyserve_port, backend_port, create_proxy_config):
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"http://127.0.0.1:{pyserve_port}/headers",
|
||||
headers={"X-Custom-Header": "test-value"}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
# Check that custom header was forwarded
|
||||
assert data["received_headers"].get("x-custom-header") == "test-value"
|
||||
|
||||
# Check that X-Forwarded headers were added
|
||||
assert "x-forwarded-for" in data["received_headers"]
|
||||
assert "x-forwarded-proto" in data["received_headers"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_proxy_echo_endpoint(backend_port, pyserve_port, create_proxy_config):
|
||||
"""Test echo endpoint that returns request body."""
|
||||
async with running_servers(pyserve_port, backend_port, create_proxy_config):
|
||||
async with httpx.AsyncClient() as client:
|
||||
test_data = "Hello, Proxy!"
|
||||
response = await client.post(
|
||||
f"http://127.0.0.1:{pyserve_port}/echo",
|
||||
content=test_data,
|
||||
headers={"Content-Type": "text/plain"}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.text == test_data
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_proxy_json_endpoint(backend_port, pyserve_port, create_proxy_config):
|
||||
"""Test JSON processing through proxy."""
|
||||
async with running_servers(pyserve_port, backend_port, create_proxy_config):
|
||||
async with httpx.AsyncClient() as client:
|
||||
payload = {"key": "value", "numbers": [1, 2, 3]}
|
||||
response = await client.post(
|
||||
f"http://127.0.0.1:{pyserve_port}/json",
|
||||
json=payload
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["processed"] is True
|
||||
assert data["received"]["key"] == "value"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_proxy_status_codes(backend_port, pyserve_port, create_proxy_config):
|
||||
"""Test that various status codes are correctly proxied."""
|
||||
async with running_servers(pyserve_port, backend_port, create_proxy_config):
|
||||
async with httpx.AsyncClient() as client:
|
||||
for status in [200, 201, 400, 404, 500]:
|
||||
response = await client.get(
|
||||
f"http://127.0.0.1:{pyserve_port}/status/{status}"
|
||||
)
|
||||
assert response.status_code == status
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_proxy_query_params(backend_port, pyserve_port, create_proxy_config):
|
||||
"""Test that query parameters are passed through."""
|
||||
async with running_servers(pyserve_port, backend_port, create_proxy_config):
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"http://127.0.0.1:{pyserve_port}/slow?delay=0.1"
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["delay"] == 0.1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_proxy_backend_unavailable(pyserve_port, create_proxy_config):
|
||||
"""Test handling when backend is unavailable (502 Bad Gateway)."""
|
||||
# Use a port where nothing is listening
|
||||
unavailable_port = get_free_port()
|
||||
|
||||
config = create_proxy_config(pyserve_port, unavailable_port)
|
||||
pyserve = PyServeTestServer(config)
|
||||
|
||||
await pyserve.start()
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"http://127.0.0.1:{pyserve_port}/api/v1/users"
|
||||
)
|
||||
|
||||
assert response.status_code == 502
|
||||
assert "Bad Gateway" in response.text
|
||||
finally:
|
||||
await pyserve.stop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_proxy_timeout(backend_port, pyserve_port, create_proxy_config):
|
||||
"""Test proxy timeout handling (504 Gateway Timeout)."""
|
||||
# Create config with very short timeout
|
||||
extra_locations = {
|
||||
"=/timeout-test": {
|
||||
"proxy_pass": f"http://127.0.0.1:{backend_port}/slow?delay=5",
|
||||
"timeout": 0.5, # Very short timeout
|
||||
}
|
||||
}
|
||||
|
||||
async with running_servers(pyserve_port, backend_port, create_proxy_config, extra_locations):
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
response = await client.get(
|
||||
f"http://127.0.0.1:{pyserve_port}/timeout-test"
|
||||
)
|
||||
|
||||
assert response.status_code == 504
|
||||
assert "Gateway Timeout" in response.text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_proxy_health_check_not_proxied(backend_port, pyserve_port, create_proxy_config):
|
||||
"""Test that health check endpoint is handled locally, not proxied."""
|
||||
async with running_servers(pyserve_port, backend_port, create_proxy_config):
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"http://127.0.0.1:{pyserve_port}/health"
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.text == "OK"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_proxy_default_fallback(backend_port, pyserve_port, create_proxy_config):
|
||||
"""Test default fallback for unmatched routes."""
|
||||
async with running_servers(pyserve_port, backend_port, create_proxy_config):
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"http://127.0.0.1:{pyserve_port}/nonexistent/path"
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
assert "Not Found" in response.text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_complex_config_multiple_proxies(backend_port, pyserve_port, create_proxy_config):
|
||||
"""Test complex configuration with multiple proxy rules."""
|
||||
# Create a second backend port for testing multiple backends
|
||||
backend2_port = get_free_port()
|
||||
backend2 = BackendTestApp(backend2_port)
|
||||
|
||||
extra_locations = {
|
||||
"~^/backend2/": {
|
||||
"proxy_pass": f"http://127.0.0.1:{backend2_port}",
|
||||
}
|
||||
}
|
||||
|
||||
async with running_servers(pyserve_port, backend_port, create_proxy_config, extra_locations):
|
||||
await backend2.start()
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
# Request to first backend
|
||||
response1 = await client.get(
|
||||
f"http://127.0.0.1:{pyserve_port}/api/v1/users"
|
||||
)
|
||||
assert response1.status_code == 200
|
||||
assert response1.json()["users"][0]["id"] == 1
|
||||
|
||||
# Request to second backend
|
||||
response2 = await client.get(
|
||||
f"http://127.0.0.1:{pyserve_port}/backend2/"
|
||||
)
|
||||
assert response2.status_code == 200
|
||||
data2 = response2.json()
|
||||
assert data2["port"] == backend2_port
|
||||
finally:
|
||||
await backend2.stop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_proxy_large_request_body(backend_port, pyserve_port, create_proxy_config):
|
||||
"""Test proxying large request bodies."""
|
||||
async with running_servers(pyserve_port, backend_port, create_proxy_config):
|
||||
async with httpx.AsyncClient() as client:
|
||||
# Create a large payload
|
||||
large_data = "x" * 100000 # 100KB of data
|
||||
response = await client.post(
|
||||
f"http://127.0.0.1:{pyserve_port}/echo",
|
||||
content=large_data,
|
||||
headers={"Content-Type": "text/plain"}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.text == large_data
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_proxy_content_type_preservation(backend_port, pyserve_port, create_proxy_config):
|
||||
"""Test that content-type is correctly preserved."""
|
||||
async with running_servers(pyserve_port, backend_port, create_proxy_config):
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"http://127.0.0.1:{pyserve_port}/api/v1/users"
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert "application/json" in response.headers.get("content-type", "")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_proxy_concurrent_requests(backend_port, pyserve_port, create_proxy_config):
|
||||
"""Test handling multiple concurrent proxy requests."""
|
||||
async with running_servers(pyserve_port, backend_port, create_proxy_config):
|
||||
async with httpx.AsyncClient() as client:
|
||||
# Send multiple concurrent requests
|
||||
tasks = [
|
||||
client.get(f"http://127.0.0.1:{pyserve_port}/api/v{i % 3 + 1}/users")
|
||||
for i in range(10)
|
||||
]
|
||||
responses = await asyncio.gather(*tasks)
|
||||
|
||||
# All should succeed
|
||||
for response in responses:
|
||||
assert response.status_code == 200
|
||||
assert "users" in response.json()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_server_default_proxy_timeout(backend_port, pyserve_port):
|
||||
"""Test that server-level proxy_timeout is used as default."""
|
||||
locations = {
|
||||
"~^/slow": {
|
||||
"proxy_pass": f"http://127.0.0.1:{backend_port}/slow",
|
||||
# No timeout specified - should use server default
|
||||
},
|
||||
"=/health": {
|
||||
"return": "200 OK",
|
||||
"content_type": "text/plain"
|
||||
},
|
||||
}
|
||||
|
||||
config = Config(
|
||||
http=HttpConfig(static_dir="./static", templates_dir="./templates"),
|
||||
server=ServerConfig(host="127.0.0.1", port=pyserve_port, proxy_timeout=0.5), # Very short timeout
|
||||
logging=LoggingConfig(level="ERROR", console_output=False),
|
||||
extensions=[
|
||||
ExtensionConfig(
|
||||
type="routing",
|
||||
config={"regex_locations": locations}
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
backend = BackendTestApp(backend_port)
|
||||
pyserve = PyServeTestServer(config)
|
||||
|
||||
await backend.start()
|
||||
await pyserve.start()
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
# This should timeout because server default is 0.5s and slow endpoint takes 2s
|
||||
response = await client.get(f"http://127.0.0.1:{pyserve_port}/slow?delay=2")
|
||||
assert response.status_code == 504
|
||||
assert "Gateway Timeout" in response.text
|
||||
finally:
|
||||
await pyserve.stop()
|
||||
await backend.stop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_route_timeout_overrides_server_default(backend_port, pyserve_port):
|
||||
"""Test that route-level timeout overrides server default proxy_timeout."""
|
||||
locations = {
|
||||
"~^/slow": {
|
||||
"proxy_pass": f"http://127.0.0.1:{backend_port}/slow",
|
||||
"timeout": 5.0, # Route-level timeout overrides server default
|
||||
},
|
||||
"=/health": {
|
||||
"return": "200 OK",
|
||||
"content_type": "text/plain"
|
||||
},
|
||||
}
|
||||
|
||||
config = Config(
|
||||
http=HttpConfig(static_dir="./static", templates_dir="./templates"),
|
||||
server=ServerConfig(host="127.0.0.1", port=pyserve_port, proxy_timeout=0.1), # Very short server default
|
||||
logging=LoggingConfig(level="ERROR", console_output=False),
|
||||
extensions=[
|
||||
ExtensionConfig(
|
||||
type="routing",
|
||||
config={"regex_locations": locations}
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
backend = BackendTestApp(backend_port)
|
||||
pyserve = PyServeTestServer(config)
|
||||
|
||||
await backend.start()
|
||||
await pyserve.start()
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
# This should succeed because route timeout (5s) > delay (0.5s), even though server default is 0.1s
|
||||
response = await client.get(f"http://127.0.0.1:{pyserve_port}/slow?delay=0.5")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["delay"] == 0.5
|
||||
finally:
|
||||
await pyserve.stop()
|
||||
await backend.stop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_proxy_timeout_config_parsing():
|
||||
"""Test that proxy_timeout is correctly parsed from config."""
|
||||
from pyserve.config import Config
|
||||
|
||||
# Test default value
|
||||
config = Config()
|
||||
assert config.server.proxy_timeout == 30.0
|
||||
|
||||
# Test custom value from dict
|
||||
config_with_timeout = Config._from_dict({
|
||||
"server": {
|
||||
"host": "127.0.0.1",
|
||||
"port": 8080,
|
||||
"proxy_timeout": 60.0
|
||||
}
|
||||
})
|
||||
assert config_with_timeout.server.proxy_timeout == 60.0
|
||||
Loading…
x
Reference in New Issue
Block a user