pyserveX/pyserve/routing.py

187 lines
6.5 KiB
Python

import re
import mimetypes
from pathlib import Path
from typing import Dict, Any, Optional, Pattern
from starlette.requests import Request
from starlette.responses import Response, FileResponse, PlainTextResponse
from .logging_utils import get_logger
# Module-level logger used by Router and RequestHandler below.
logger = get_logger(__name__)
class RouteMatch:
    """Outcome of a successful route lookup.

    Instances carry two attributes, both set in ``__init__``:

    * ``config`` - the route configuration mapping, stored exactly as given;
    * ``params`` - the parameters supplied by the caller, or a fresh empty
      dict when none (or a falsy value such as ``{}``) is supplied.
    """

    def __init__(self, config: Dict[str, Any], params: Optional[Dict[str, str]] = None):
        self.config = config
        # A falsy ``params`` (None or empty) is replaced by a new empty dict.
        self.params = params if params else {}
class Router:
    """Maps request paths to route configurations.

    The kind of route is selected by the pattern passed to ``add_route``:

    * ``=<path>``   - exact path match;
    * ``~<regex>``  - case-sensitive regular expression;
    * ``~*<regex>`` - case-insensitive regular expression;
    * ``__default__`` - fallback used when no other route matches.

    ``match`` tries exact routes first, then regex routes in the order they
    were added, then the default route.
    """

    def __init__(self, static_dir: str = "./static"):
        """Create an empty router.

        Args:
            static_dir: Stored as ``self.static_dir`` (a ``Path``).
        """
        self.static_dir = Path(static_dir)
        # Regex routes; dict insertion order is the order they are tried in.
        self.routes: Dict[Pattern, Dict[str, Any]] = {}
        self.exact_routes: Dict[str, Dict[str, Any]] = {}
        self.default_route: Optional[Dict[str, Any]] = None

    def add_route(self, pattern: str, config: Dict[str, Any]) -> None:
        """Register ``config`` under ``pattern``.

        Invalid regular expressions and patterns of an unrecognised form are
        logged and skipped rather than raising, so one bad entry does not
        prevent the remaining routes from being registered.

        Args:
            pattern: ``=<path>``, ``~<regex>``, ``~*<regex>`` or
                ``__default__``.
            config: Route configuration stored for later lookup.
        """
        if pattern.startswith("="):
            exact_path = pattern[1:]
            self.exact_routes[exact_path] = config
            logger.debug(f"Added exact route: {exact_path}")
            return

        if pattern == "__default__":
            self.default_route = config
            logger.debug("Added default route")
            return

        if pattern.startswith("~"):
            case_insensitive = pattern.startswith("~*")
            regex_pattern = pattern[2:] if case_insensitive else pattern[1:]
            flags = re.IGNORECASE if case_insensitive else 0
            try:
                compiled_pattern = re.compile(regex_pattern, flags)
            except re.error as e:
                logger.error(f"Regex compilation error {pattern}: {e}")
                return
            self.routes[compiled_pattern] = config
            logger.debug(f"Added regex route: {pattern}")
            return

        # Previously such patterns were dropped without any trace, which made
        # misconfigured routes very hard to diagnose.
        logger.error(f"Unsupported route pattern ignored: {pattern}")

    def match(self, path: str) -> Optional[RouteMatch]:
        """Find the route for ``path``.

        Args:
            path: Request path to look up.

        Returns:
            A ``RouteMatch`` for the first matching route (exact, then regex,
            then default), or ``None`` when nothing matches. For regex routes
            ``params`` holds the named groups of the match.
        """
        if path in self.exact_routes:
            return RouteMatch(self.exact_routes[path])

        for pattern, config in self.routes.items():
            found = pattern.search(path)
            if found is not None:
                # default="" keeps params a str -> str mapping: named groups
                # that did not participate would otherwise be None.
                return RouteMatch(config, found.groupdict(default=""))

        # Compare against None so that an empty default config still counts
        # as a configured default route.
        if self.default_route is not None:
            return RouteMatch(self.default_route)
        return None
class RequestHandler:
    """Turns a request into a response based on the matched route's config.

    Recognised config keys, checked in this order by ``_process_route``:
    ``return``, ``proxy_pass``, ``root``, ``spa_fallback``.
    """

    def __init__(self, router: Router, static_dir: str = "./static"):
        """Store the router and the fallback static directory.

        Args:
            router: Router used to look up the route for each request path.
            static_dir: Directory used by the SPA fallback when the route
                config has no ``root``.
        """
        self.router = router
        self.static_dir = Path(static_dir)

    async def handle(self, request: Request) -> Response:
        """Entry point: route the request and build the response.

        Returns a 404 plain-text response when no route matches, and a 500
        plain-text response when processing the route raises.
        """
        path = request.url.path
        logger.info(f"{request.method} {path}")

        route_match = self.router.match(path)
        if route_match is None:
            return PlainTextResponse("404 Not Found", status_code=404)

        try:
            return await self._process_route(request, route_match)
        except Exception as e:
            # Top-level boundary: never let a handler error escape.
            logger.error(f"Request processing error {path}: {e}")
            return PlainTextResponse("500 Internal Server Error", status_code=500)

    async def _process_route(self, request: Request, route_match: RouteMatch) -> Response:
        """Dispatch on the keys present in the matched route's config."""
        config = route_match.config

        if "return" in config:
            # Format is "<status>" or "<status> <body text>". str() lets a
            # bare numeric value (e.g. an int produced by the config loader)
            # work instead of failing on the substring test.
            status_part, _, text = str(config["return"]).partition(" ")
            status_code = int(status_part)
            content_type = config.get("content_type", "text/plain")
            return PlainTextResponse(text, status_code=status_code,
                                     media_type=content_type)

        if "proxy_pass" in config:
            return await self._handle_proxy(request, config, route_match.params)

        if "root" in config:
            return await self._handle_static(request, config)

        if config.get("spa_fallback"):
            return await self._handle_spa_fallback(request, config)

        return PlainTextResponse("404 Not Found", status_code=404)

    async def _handle_static(self, request: Request, config: Dict[str, Any]) -> Response:
        """Serve a file from ``config["root"]`` for the request path.

        An empty path (after stripping leading slashes) serves
        ``config["index_file"]`` (default ``index.html``). Returns 403 when
        the resolved file lies outside the resolved root, and 404 when the
        path cannot be resolved or is not an existing regular file.
        Optional ``headers`` ("Name: value" strings) and ``cache_control``
        from the config are applied to the response.
        """
        root = Path(config["root"])
        path = request.url.path.lstrip("/")

        if not path:
            file_path = root / config.get("index_file", "index.html")
        else:
            file_path = root / path

        try:
            file_path = file_path.resolve()
            root = root.resolve()
        except (OSError, ValueError, RuntimeError):
            # ValueError: e.g. embedded NUL byte in the path;
            # RuntimeError: symlink loop on older Python versions.
            return PlainTextResponse("404 Not Found", status_code=404)

        # Component-wise containment check. A plain string-prefix comparison
        # would accept sibling directories that merely share a prefix with
        # the root (e.g. "/srv/www-private" for root "/srv/www").
        try:
            file_path.relative_to(root)
        except ValueError:
            return PlainTextResponse("403 Forbidden", status_code=403)

        if not file_path.is_file():
            return PlainTextResponse("404 Not Found", status_code=404)

        content_type, _ = mimetypes.guess_type(str(file_path))
        response = FileResponse(str(file_path), media_type=content_type)

        for header in config.get("headers", ()):
            # Entries without a colon are ignored.
            if ":" in header:
                name, value = header.split(":", 1)
                response.headers[name.strip()] = value.strip()

        if "cache_control" in config:
            response.headers["Cache-Control"] = config["cache_control"]

        return response

    async def _handle_spa_fallback(self, request: Request, config: Dict[str, Any]) -> Response:
        """Serve the index file for any path not excluded by prefix.

        Paths starting with one of ``config["exclude_patterns"]`` get a 404.
        Otherwise ``index_file`` (default ``index.html``) under ``root``
        (default: the handler's ``static_dir``) is returned as ``text/html``,
        or a 404 when that file does not exist.
        """
        path = request.url.path

        exclude_patterns = config.get("exclude_patterns", [])
        if any(path.startswith(prefix) for prefix in exclude_patterns):
            return PlainTextResponse("404 Not Found", status_code=404)

        root = Path(config.get("root", self.static_dir))
        file_path = root / config.get("index_file", "index.html")

        if file_path.is_file():
            return FileResponse(str(file_path), media_type="text/html")

        return PlainTextResponse("404 Not Found", status_code=404)

    async def _handle_proxy(self, request: Request, config: Dict[str, Any],
                            params: Dict[str, str]) -> Response:
        """Placeholder proxy handler.

        Substitutes ``{name}`` placeholders in ``config["proxy_pass"]`` with
        the route parameters and returns the resulting URL as plain text;
        no upstream request is made yet.
        """
        # TODO: implement real proxying
        proxy_url = config["proxy_pass"]

        for key, value in params.items():
            # A named group that did not take part in the match may arrive
            # as None; substitute an empty string instead of crashing.
            proxy_url = proxy_url.replace(f"{{{key}}}", "" if value is None else value)

        logger.info(f"Proxying request to: {proxy_url}")

        return PlainTextResponse(f"Proxy to: {proxy_url}", status_code=200)
def create_router_from_config(regex_locations: Dict[str, Dict[str, Any]]) -> Router:
    """Build a ``Router`` from a pattern -> config mapping.

    Every key of ``regex_locations`` is passed to ``Router.add_route`` as the
    pattern together with its value as the route config, in the mapping's
    iteration order. The router is created with its default arguments.
    """
    configured = Router()
    for route_pattern in regex_locations:
        configured.add_route(route_pattern, regex_locations[route_pattern])
    return configured