fix: Update max-line-length in .flake8 and refactor routing and server code for improved readability and functionality

This commit is contained in:
Илья Глазунов 2025-09-02 14:23:01 +03:00
parent 84cd1c974f
commit 6b157d7626
3 changed files with 113 additions and 93 deletions

View File

@ -1,4 +1,4 @@
[flake8] [flake8]
max-line-length = 100 max-line-length = 120
exclude = __pycache__,.git,.venv,venv,build,dist exclude = __pycache__,.git,.venv,venv,build,dist
ignore = E203,W503 ignore = E203,W503

View File

@ -8,34 +8,36 @@ from .logging_utils import get_logger
logger = get_logger(__name__) logger = get_logger(__name__)
class RouteMatch: class RouteMatch:
def __init__(self, config: Dict[str, Any], params: Optional[Dict[str, str]] = None): def __init__(self, config: Dict[str, Any], params: Optional[Dict[str, str]] = None):
self.config = config self.config = config
self.params = params or {} self.params = params or {}
class Router: class Router:
def __init__(self, static_dir: str = "./static"): def __init__(self, static_dir: str = "./static"):
self.static_dir = Path(static_dir) self.static_dir = Path(static_dir)
self.routes: Dict[Pattern, Dict[str, Any]] = {} self.routes: Dict[Pattern, Dict[str, Any]] = {}
self.exact_routes: Dict[str, Dict[str, Any]] = {} self.exact_routes: Dict[str, Dict[str, Any]] = {}
self.default_route: Optional[Dict[str, Any]] = None self.default_route: Optional[Dict[str, Any]] = None
def add_route(self, pattern: str, config: Dict[str, Any]) -> None: def add_route(self, pattern: str, config: Dict[str, Any]) -> None:
if pattern.startswith("="): if pattern.startswith("="):
exact_path = pattern[1:] exact_path = pattern[1:]
self.exact_routes[exact_path] = config self.exact_routes[exact_path] = config
logger.debug(f"Added exact route: {exact_path}") logger.debug(f"Added exact route: {exact_path}")
return return
if pattern == "__default__": if pattern == "__default__":
self.default_route = config self.default_route = config
logger.debug("Added default route") logger.debug("Added default route")
return return
if pattern.startswith("~"): if pattern.startswith("~"):
case_insensitive = pattern.startswith("~*") case_insensitive = pattern.startswith("~*")
regex_pattern = pattern[2:] if case_insensitive else pattern[1:] regex_pattern = pattern[2:] if case_insensitive else pattern[1:]
flags = re.IGNORECASE if case_insensitive else 0 flags = re.IGNORECASE if case_insensitive else 0
try: try:
compiled_pattern = re.compile(regex_pattern, flags) compiled_pattern = re.compile(regex_pattern, flags)
@ -43,47 +45,48 @@ class Router:
logger.debug(f"Added regex route: {pattern}") logger.debug(f"Added regex route: {pattern}")
except re.error as e: except re.error as e:
logger.error(f"Regex compilation error {pattern}: {e}") logger.error(f"Regex compilation error {pattern}: {e}")
def match(self, path: str) -> Optional[RouteMatch]: def match(self, path: str) -> Optional[RouteMatch]:
if path in self.exact_routes: if path in self.exact_routes:
return RouteMatch(self.exact_routes[path]) return RouteMatch(self.exact_routes[path])
for pattern, config in self.routes.items(): for pattern, config in self.routes.items():
match = pattern.search(path) match = pattern.search(path)
if match: if match:
params = match.groupdict() params = match.groupdict()
return RouteMatch(config, params) return RouteMatch(config, params)
if self.default_route: if self.default_route:
return RouteMatch(self.default_route) return RouteMatch(self.default_route)
return None return None
class RequestHandler: class RequestHandler:
def __init__(self, router: Router, static_dir: str = "./static"): def __init__(self, router: Router, static_dir: str = "./static"):
self.router = router self.router = router
self.static_dir = Path(static_dir) self.static_dir = Path(static_dir)
async def handle(self, request: Request) -> Response: async def handle(self, request: Request) -> Response:
path = request.url.path path = request.url.path
logger.info(f"{request.method} {path}") logger.info(f"{request.method} {path}")
route_match = self.router.match(path) route_match = self.router.match(path)
if not route_match: if not route_match:
return PlainTextResponse("404 Not Found", status_code=404) return PlainTextResponse("404 Not Found", status_code=404)
try: try:
return await self._process_route(request, route_match) return await self._process_route(request, route_match)
except Exception as e: except Exception as e:
logger.error(f"Request processing error {path}: {e}") logger.error(f"Request processing error {path}: {e}")
return PlainTextResponse("500 Internal Server Error", status_code=500) return PlainTextResponse("500 Internal Server Error", status_code=500)
async def _process_route(self, request: Request, route_match: RouteMatch) -> Response: async def _process_route(self, request: Request, route_match: RouteMatch) -> Response:
config = route_match.config config = route_match.config
path = request.url.path # HINT: Not using it right now
# path = request.url.path
if "return" in config: if "return" in config:
status_text = config["return"] status_text = config["return"]
if " " in status_text: if " " in status_text:
@ -92,32 +95,32 @@ class RequestHandler:
else: else:
status_code = int(status_text) status_code = int(status_text)
text = "" text = ""
content_type = config.get("content_type", "text/plain") content_type = config.get("content_type", "text/plain")
return PlainTextResponse(text, status_code=status_code, return PlainTextResponse(text, status_code=status_code,
media_type=content_type) media_type=content_type)
if "proxy_pass" in config: if "proxy_pass" in config:
return await self._handle_proxy(request, config, route_match.params) return await self._handle_proxy(request, config, route_match.params)
if "root" in config: if "root" in config:
return await self._handle_static(request, config) return await self._handle_static(request, config)
if config.get("spa_fallback"): if config.get("spa_fallback"):
return await self._handle_spa_fallback(request, config) return await self._handle_spa_fallback(request, config)
return PlainTextResponse("404 Not Found", status_code=404) return PlainTextResponse("404 Not Found", status_code=404)
async def _handle_static(self, request: Request, config: Dict[str, Any]) -> Response: async def _handle_static(self, request: Request, config: Dict[str, Any]) -> Response:
root = Path(config["root"]) root = Path(config["root"])
path = request.url.path.lstrip("/") path = request.url.path.lstrip("/")
if not path or path == "/": if not path or path == "/":
index_file = config.get("index_file", "index.html") index_file = config.get("index_file", "index.html")
file_path = root / index_file file_path = root / index_file
else: else:
file_path = root / path file_path = root / path
try: try:
file_path = file_path.resolve() file_path = file_path.resolve()
root = root.resolve() root = root.resolve()
@ -125,59 +128,59 @@ class RequestHandler:
return PlainTextResponse("403 Forbidden", status_code=403) return PlainTextResponse("403 Forbidden", status_code=403)
except OSError: except OSError:
return PlainTextResponse("404 Not Found", status_code=404) return PlainTextResponse("404 Not Found", status_code=404)
if not file_path.exists() or not file_path.is_file(): if not file_path.exists() or not file_path.is_file():
return PlainTextResponse("404 Not Found", status_code=404) return PlainTextResponse("404 Not Found", status_code=404)
content_type, _ = mimetypes.guess_type(str(file_path)) content_type, _ = mimetypes.guess_type(str(file_path))
response = FileResponse(str(file_path), media_type=content_type) response = FileResponse(str(file_path), media_type=content_type)
if "headers" in config: if "headers" in config:
for header in config["headers"]: for header in config["headers"]:
if ":" in header: if ":" in header:
name, value = header.split(":", 1) name, value = header.split(":", 1)
response.headers[name.strip()] = value.strip() response.headers[name.strip()] = value.strip()
if "cache_control" in config: if "cache_control" in config:
response.headers["Cache-Control"] = config["cache_control"] response.headers["Cache-Control"] = config["cache_control"]
return response return response
async def _handle_spa_fallback(self, request: Request, config: Dict[str, Any]) -> Response: async def _handle_spa_fallback(self, request: Request, config: Dict[str, Any]) -> Response:
path = request.url.path path = request.url.path
exclude_patterns = config.get("exclude_patterns", []) exclude_patterns = config.get("exclude_patterns", [])
for pattern in exclude_patterns: for pattern in exclude_patterns:
if path.startswith(pattern): if path.startswith(pattern):
return PlainTextResponse("404 Not Found", status_code=404) return PlainTextResponse("404 Not Found", status_code=404)
root = Path(config.get("root", self.static_dir)) root = Path(config.get("root", self.static_dir))
index_file = config.get("index_file", "index.html") index_file = config.get("index_file", "index.html")
file_path = root / index_file file_path = root / index_file
if file_path.exists() and file_path.is_file(): if file_path.exists() and file_path.is_file():
return FileResponse(str(file_path), media_type="text/html") return FileResponse(str(file_path), media_type="text/html")
return PlainTextResponse("404 Not Found", status_code=404) return PlainTextResponse("404 Not Found", status_code=404)
async def _handle_proxy(self, request: Request, config: Dict[str, Any], async def _handle_proxy(self, request: Request, config: Dict[str, Any],
params: Dict[str, str]) -> Response: params: Dict[str, str]) -> Response:
# TODO: Реализовать полноценное проксирование # TODO: Реализовать полноценное проксирование
proxy_url = config["proxy_pass"] proxy_url = config["proxy_pass"]
for key, value in params.items(): for key, value in params.items():
proxy_url = proxy_url.replace(f"{{{key}}}", value) proxy_url = proxy_url.replace(f"{{{key}}}", value)
logger.info(f"Proxying request to: {proxy_url}") logger.info(f"Proxying request to: {proxy_url}")
return PlainTextResponse(f"Proxy to: {proxy_url}", status_code=200) return PlainTextResponse(f"Proxy to: {proxy_url}", status_code=200)
def create_router_from_config(regex_locations: Dict[str, Dict[str, Any]]) -> Router: def create_router_from_config(regex_locations: Dict[str, Dict[str, Any]]) -> Router:
router = Router() router = Router()
for pattern, config in regex_locations.items(): for pattern, config in regex_locations.items():
router.add_route(pattern, config) router.add_route(pattern, config)
return router return router

View File

@ -4,8 +4,8 @@ import time
from starlette.applications import Starlette from starlette.applications import Starlette
from starlette.requests import Request from starlette.requests import Request
from starlette.responses import Response, PlainTextResponse from starlette.responses import Response, PlainTextResponse
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.routing import Route from starlette.routing import Route
from starlette.types import ASGIApp, Receive, Scope, Send
from pathlib import Path from pathlib import Path
from typing import Optional, Dict, Any from typing import Optional, Dict, Any
@ -17,22 +17,28 @@ from . import __version__
logger = get_logger(__name__) logger = get_logger(__name__)
class PyServeMiddleware(BaseHTTPMiddleware): class PyServeMiddleware:
def __init__(self, app, extension_manager: ExtensionManager): def __init__(self, app: ASGIApp, extension_manager: ExtensionManager):
super().__init__(app) self.app = app
self.extension_manager = extension_manager self.extension_manager = extension_manager
self.access_logger = get_logger('pyserve.access') self.access_logger = get_logger('pyserve.access')
async def dispatch(self, request: Request, call_next): async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] != "http":
await self.app(scope, receive, send)
return
start_time = time.time() start_time = time.time()
request = Request(scope, receive)
response = await self.extension_manager.process_request(request) response = await self.extension_manager.process_request(request)
if response is None: if response is None:
response = await call_next(request) await self.app(scope, receive, send)
return
response = await self.extension_manager.process_response(request, response) response = await self.extension_manager.process_response(request, response)
response.headers["Server"] = f"pyserve/{__version__}" response.headers["Server"] = f"pyserve/{__version__}"
client_ip = request.client.host if request.client else "unknown" client_ip = request.client.host if request.client else "unknown"
method = request.method method = request.method
path = str(request.url.path) path = str(request.url.path)
@ -41,13 +47,13 @@ class PyServeMiddleware(BaseHTTPMiddleware):
path += f"?{query}" path += f"?{query}"
status_code = response.status_code status_code = response.status_code
process_time = round((time.time() - start_time) * 1000, 2) process_time = round((time.time() - start_time) * 1000, 2)
self.access_logger.info(f"{client_ip} - {method} {path} - {status_code} - {process_time}ms") self.access_logger.info(f"{client_ip} - {method} {path} - {status_code} - {process_time}ms")
return response await response(scope, receive, send)
class PyServeServer: class PyServeServer:
def __init__(self, config: Config): def __init__(self, config: Config):
self.config = config self.config = config
self.extension_manager = ExtensionManager() self.extension_manager = ExtensionManager()
@ -55,34 +61,45 @@ class PyServeServer:
self._setup_logging() self._setup_logging()
self._load_extensions() self._load_extensions()
self._create_app() self._create_app()
def _setup_logging(self) -> None: def _setup_logging(self) -> None:
self.config.setup_logging() self.config.setup_logging()
logger.info("PyServe сервер инициализирован") logger.info("PyServe server initialized")
def _load_extensions(self) -> None: def _load_extensions(self) -> None:
for ext_config in self.config.extensions: for ext_config in self.config.extensions:
self.extension_manager.load_extension( self.extension_manager.load_extension(
ext_config.type, ext_config.type,
ext_config.config ext_config.config
) )
def _create_app(self) -> None: def _create_app(self) -> None:
routes = [ routes = [
Route("/health", self._health_check, methods=["GET"]), Route("/health", self._health_check, methods=["GET"]),
Route("/metrics", self._metrics, methods=["GET"]), Route("/metrics", self._metrics, methods=["GET"]),
Route("/{path:path}", self._catch_all, methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"]), Route(
"/{path:path}",
self._catch_all,
methods=[
"GET",
"POST",
"PUT",
"DELETE",
"PATCH",
"OPTIONS"
]
),
] ]
self.app = Starlette(routes=routes) self.app = Starlette(routes=routes)
self.app.add_middleware(PyServeMiddleware, extension_manager=self.extension_manager) self.app.add_middleware(PyServeMiddleware, extension_manager=self.extension_manager)
async def _health_check(self, request: Request) -> Response: async def _health_check(self, request: Request) -> Response:
return PlainTextResponse("OK", status_code=200) return PlainTextResponse("OK", status_code=200)
async def _metrics(self, request: Request) -> Response: async def _metrics(self, request: Request) -> Response:
metrics = {} metrics = {}
for extension in self.extension_manager.extensions: for extension in self.extension_manager.extensions:
if hasattr(extension, 'get_metrics'): if hasattr(extension, 'get_metrics'):
try: try:
@ -96,22 +113,22 @@ class PyServeServer:
json.dumps(metrics, ensure_ascii=False, indent=2), json.dumps(metrics, ensure_ascii=False, indent=2),
media_type="application/json" media_type="application/json"
) )
async def _catch_all(self, request: Request) -> Response: async def _catch_all(self, request: Request) -> Response:
return PlainTextResponse("404 Not Found", status_code=404) return PlainTextResponse("404 Not Found", status_code=404)
def _create_ssl_context(self) -> Optional[ssl.SSLContext]: def _create_ssl_context(self) -> Optional[ssl.SSLContext]:
if not self.config.ssl.enabled: if not self.config.ssl.enabled:
return None return None
if not Path(self.config.ssl.cert_file).exists(): if not Path(self.config.ssl.cert_file).exists():
logger.error(f"SSL certificate not found: {self.config.ssl.cert_file}") logger.error(f"SSL certificate not found: {self.config.ssl.cert_file}")
return None return None
if not Path(self.config.ssl.key_file).exists(): if not Path(self.config.ssl.key_file).exists():
logger.error(f"SSL key not found: {self.config.ssl.key_file}") logger.error(f"SSL key not found: {self.config.ssl.key_file}")
return None return None
try: try:
context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
context.load_cert_chain( context.load_cert_chain(
@ -123,17 +140,16 @@ class PyServeServer:
except Exception as e: except Exception as e:
logger.error(f"Error creating SSL context: {e}") logger.error(f"Error creating SSL context: {e}")
return None return None
def run(self) -> None: def run(self) -> None:
if not self.config.validate(): if not self.config.validate():
logger.error("Configuration is invalid, server cannot be started") logger.error("Configuration is invalid, server cannot be started")
return return
self._ensure_directories() self._ensure_directories()
ssl_context = self._create_ssl_context() ssl_context = self._create_ssl_context()
uvicorn_config = { uvicorn_config: Dict[str, Any] = {
"app": self.app,
"host": self.config.server.host, "host": self.config.server.host,
"port": self.config.server.port, "port": self.config.server.port,
"log_level": "critical", "log_level": "critical",
@ -141,7 +157,7 @@ class PyServeServer:
"use_colors": False, "use_colors": False,
"server_header": False, "server_header": False,
} }
if ssl_context: if ssl_context:
uvicorn_config.update({ uvicorn_config.update({
"ssl_keyfile": self.config.ssl.key_file, "ssl_keyfile": self.config.ssl.key_file,
@ -154,21 +170,22 @@ class PyServeServer:
logger.info(f"Starting PyServe server at {protocol}://{self.config.server.host}:{self.config.server.port}") logger.info(f"Starting PyServe server at {protocol}://{self.config.server.host}:{self.config.server.port}")
try: try:
uvicorn.run(**uvicorn_config) assert self.app is not None, "App not initialized"
uvicorn.run(self.app, **uvicorn_config)
except KeyboardInterrupt: except KeyboardInterrupt:
logger.info("Received shutdown signal") logger.info("Received shutdown signal")
except Exception as e: except Exception as e:
logger.error(f"Error starting server: {e}") logger.error(f"Error starting server: {e}")
finally: finally:
self.shutdown() self.shutdown()
async def run_async(self) -> None: async def run_async(self) -> None:
if not self.config.validate(): if not self.config.validate():
logger.error("Configuration is invalid, server cannot be started") logger.error("Configuration is invalid, server cannot be started")
return return
self._ensure_directories() self._ensure_directories()
config = uvicorn.Config( config = uvicorn.Config(
app=self.app, # type: ignore app=self.app, # type: ignore
host=self.config.server.host, host=self.config.server.host,
@ -177,24 +194,24 @@ class PyServeServer:
access_log=False, access_log=False,
use_colors=False, use_colors=False,
) )
server = uvicorn.Server(config) server = uvicorn.Server(config)
try: try:
await server.serve() await server.serve()
finally: finally:
self.shutdown() self.shutdown()
def _ensure_directories(self) -> None: def _ensure_directories(self) -> None:
directories = [ directories = [
self.config.http.static_dir, self.config.http.static_dir,
self.config.http.templates_dir, self.config.http.templates_dir,
] ]
log_dir = Path(self.config.logging.log_file).parent log_dir = Path(self.config.logging.log_file).parent
if log_dir != Path("."): if log_dir != Path("."):
directories.append(str(log_dir)) directories.append(str(log_dir))
for directory in directories: for directory in directories:
Path(directory).mkdir(parents=True, exist_ok=True) Path(directory).mkdir(parents=True, exist_ok=True)
logger.debug(f"Created/checked directory: {directory}") logger.debug(f"Created/checked directory: {directory}")
@ -202,7 +219,7 @@ class PyServeServer:
def shutdown(self) -> None: def shutdown(self) -> None:
logger.info("Shutting down PyServe server") logger.info("Shutting down PyServe server")
self.extension_manager.cleanup() self.extension_manager.cleanup()
from .logging_utils import shutdown_logging from .logging_utils import shutdown_logging
shutdown_logging() shutdown_logging()
@ -210,10 +227,10 @@ class PyServeServer:
def add_extension(self, extension_type: str, config: Dict[str, Any]) -> None: def add_extension(self, extension_type: str, config: Dict[str, Any]) -> None:
self.extension_manager.load_extension(extension_type, config) self.extension_manager.load_extension(extension_type, config)
def get_metrics(self) -> Dict[str, Any]: def get_metrics(self) -> Dict[str, Any]:
metrics = {"server_status": "running"} metrics = {"server_status": "running"}
for extension in self.extension_manager.extensions: for extension in self.extension_manager.extensions:
if hasattr(extension, 'get_metrics'): if hasattr(extension, 'get_metrics'):
try: try: