127 lines
3.6 KiB
Python
127 lines
3.6 KiB
Python
"""Request validation middleware."""
|
|
|
|
from typing import Any, Dict, List, Optional, Tuple
|
|
|
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
from starlette.requests import Request
|
|
from starlette.responses import JSONResponse
|
|
|
|
|
|
class RequestValidationError(Exception):
    """Raised when request validation fails."""

    def __init__(self, errors: List[Dict[str, Any]]):
        """Store the validation errors and build the exception message.

        Args:
            errors: Error records describing what failed validation.
        """
        # Kept on the instance so handlers can inspect the individual
        # error records rather than parsing the message.
        self.errors = errors
        # The exception message is simply the string form of the list.
        super().__init__(str(errors))
|
|
|
|
|
|
class RequestValidator:
    """Validates incoming requests against OpenAPI spec."""

    def __init__(self, spec: Dict[str, Any], strict: bool = False):
        """Initialize the request validator.

        Args:
            spec: OpenAPI specification
            strict: Enable strict validation mode
        """
        self.spec = spec
        # Stored for use by validation hooks; nothing in this class reads
        # it yet.
        self.strict = strict

    def validate_request(
        self,
        method: str,
        path: str,
        query_params: Dict[str, Any],
        headers: Dict[str, str],
        body: Optional[Any] = None,
    ) -> Tuple[bool, List[Dict[str, Any]]]:
        """Validate a request against the spec.

        Runs the path, query, header and body checks in that order and
        concatenates whatever error records each one reports.

        Args:
            method: HTTP method
            path: Request path
            query_params: Query parameters
            headers: Request headers
            body: Request body

        Returns:
            Tuple of (is_valid, errors); ``is_valid`` is True exactly when
            ``errors`` is empty.
        """
        errors: List[Dict[str, Any]] = []

        # Each hook returns a list of error records. The tuple is evaluated
        # left to right, so the hooks still run in the original order.
        reported = (
            self._extract_path_params(method, path),
            self._validate_query_params(method, path, query_params),
            self._validate_headers(method, path, headers),
            self._validate_body(method, path, body),
        )
        for found in reported:
            # Skip falsy results so a hook that reports nothing (an empty
            # list, or None from an override) contributes no entries.
            if found:
                errors.extend(found)

        return not errors, errors

    def _extract_path_params(
        self, method: str, path: str
    ) -> List[Dict[str, Any]]:
        """Extract and validate path parameters.

        Placeholder: currently reports no errors.
        """
        return []

    def _validate_query_params(
        self, method: str, path: str, query_params: Dict[str, Any]
    ) -> List[Dict[str, Any]]:
        """Validate query parameters against spec.

        Placeholder: currently reports no errors.
        """
        return []

    def _validate_headers(
        self, method: str, path: str, headers: Dict[str, str]
    ) -> List[Dict[str, Any]]:
        """Validate headers against spec.

        Placeholder: currently reports no errors.
        """
        return []

    def _validate_body(
        self, method: str, path: str, body: Optional[Any]
    ) -> List[Dict[str, Any]]:
        """Validate request body against spec.

        Placeholder: currently reports no errors.
        """
        return []
|
|
|
|
|
|
class ValidationMiddleware(BaseHTTPMiddleware):
    """Middleware that validates requests."""

    def __init__(self, app, validator: RequestValidator):
        """Initialize the validation middleware.

        Args:
            app: The ASGI application
            validator: RequestValidator instance
        """
        super().__init__(app)
        self.validator = validator

    async def dispatch(self, request: Request, call_next):
        """Validate request and process.

        Args:
            request: The incoming request.
            call_next: Callable that forwards the request to the next
                handler and returns its response.

        Returns:
            A 400 JSON response carrying the validation errors when the
            validator rejects the request; otherwise the downstream
            response.
        """
        # No ``body`` argument is passed, so the validator sees its default
        # for that parameter and the request body is not inspected here.
        # NOTE(review): dict() keeps one value per key, so repeated query
        # parameters or headers are presumably collapsed - confirm whether
        # multi-valued inputs need validating.
        is_valid, errors = self.validator.validate_request(
            method=request.method,
            path=str(request.url.path),
            query_params=dict(request.query_params),
            headers=dict(request.headers),
        )

        if not is_valid:
            # Short-circuit: the downstream application is never invoked
            # for a request that failed validation.
            return JSONResponse(
                status_code=400,
                content={
                    "error": "Validation failed",
                    "details": errors,
                },
            )

        return await call_next(request)