Files
mockapi/src/mockapi/core/request_validator.py
7000pctAUTO ddaba43ae7
Some checks failed
CI / test (push) Failing after 12s
CI / build (push) Has been skipped
Add all source files for CI to work
2026-03-22 21:22:57 +00:00

127 lines
3.6 KiB
Python

"""Request validation middleware."""
from typing import Any, Dict, List, Optional, Tuple
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse
class RequestValidationError(Exception):
    """Signal that a request did not pass validation.

    The entries given to the constructor are kept on ``errors``; their
    ``str()`` form becomes the exception message.
    """

    def __init__(self, errors: List[Dict[str, Any]]):
        """Record the validation errors and initialise the base exception.

        Args:
            errors: Validation error entries describing the failure
        """
        message = str(errors)
        super().__init__(message)
        # Keep the caller's list itself (not a copy) for later inspection.
        self.errors = errors
class RequestValidator:
    """Check incoming requests against an OpenAPI spec.

    ``validate_request`` runs one hook per request part (path, query,
    headers, body) and aggregates the error entries they return. The
    hooks defined here are placeholders that report no errors.
    """

    def __init__(self, spec: Dict[str, Any], strict: bool = False):
        """Bind the validator to a specification.

        Args:
            spec: OpenAPI specification
            strict: Enable strict validation mode (stored on the instance;
                not consulted by any method in this class)
        """
        self.strict = strict
        self.spec = spec

    def validate_request(
        self,
        method: str,
        path: str,
        query_params: Dict[str, Any],
        headers: Dict[str, str],
        body: Optional[Any] = None,
    ) -> Tuple[bool, List[Dict[str, Any]]]:
        """Run every validation hook and collect the reported errors.

        Args:
            method: HTTP method
            path: Request path
            query_params: Query parameters
            headers: Request headers
            body: Request body

        Returns:
            Tuple of (is_valid, errors); ``is_valid`` is True exactly when
            ``errors`` is empty.
        """
        collected: List[Dict[str, Any]] = []
        # Hook order is fixed: path, query, headers, body. The resulting
        # error list preserves that order.
        hook_results = (
            self._extract_path_params(method, path),
            self._validate_query_params(method, path, query_params),
            self._validate_headers(method, path, headers),
            self._validate_body(method, path, body),
        )
        for found in hook_results:
            # Falsy results (empty list, None) contribute nothing.
            if found:
                collected.extend(found)
        return not collected, collected

    def _extract_path_params(
        self, method: str, path: str
    ) -> List[Dict[str, Any]]:
        """Return path-parameter errors; placeholder that reports none."""
        return []

    def _validate_query_params(
        self, method: str, path: str, query_params: Dict[str, Any]
    ) -> List[Dict[str, Any]]:
        """Return query-parameter errors; placeholder that reports none."""
        return []

    def _validate_headers(
        self, method: str, path: str, headers: Dict[str, str]
    ) -> List[Dict[str, Any]]:
        """Return header errors; placeholder that reports none."""
        return []

    def _validate_body(
        self, method: str, path: str, body: Optional[Any]
    ) -> List[Dict[str, Any]]:
        """Return request-body errors; placeholder that reports none."""
        return []
class ValidationMiddleware(BaseHTTPMiddleware):
    """Middleware that runs a RequestValidator on every request.

    A request that fails validation is answered with a 400 JSON response
    and ``call_next`` is not invoked for it.

    NOTE(review): ``dispatch`` does not pass the request body to the
    validator, so ``validate_request`` always receives its default
    ``body=None`` from this middleware.
    """

    def __init__(self, app, validator: RequestValidator):
        """Wrap ``app`` and remember the validator to use.

        Args:
            app: The ASGI application
            validator: RequestValidator instance
        """
        super().__init__(app)
        self.validator = validator

    async def dispatch(self, request: Request, call_next):
        """Validate the request, then either reject it or pass it on.

        Args:
            request: Incoming request
            call_next: Callable that forwards the request down the chain

        Returns:
            A 400 JSONResponse carrying the validation errors when the
            request is invalid, otherwise the awaited result of
            ``call_next(request)``.
        """
        http_method = request.method
        url_path = str(request.url.path)
        # dict() flattens the query/header mappings into plain dicts
        # before they are handed to the validator.
        query = dict(request.query_params)
        header_map = dict(request.headers)

        is_valid, errors = self.validator.validate_request(
            method=http_method,
            path=url_path,
            query_params=query,
            headers=header_map,
        )
        if is_valid:
            return await call_next(request)

        payload = {
            "error": "Validation failed",
            "details": errors,
        }
        return JSONResponse(status_code=400, content=payload)