348 lines
12 KiB
Python
348 lines
12 KiB
Python
"""FastAPI server with dynamic route generation from OpenAPI specs."""
|
|
|
|
import re
|
|
from typing import Any, Callable, Dict, List, Optional
|
|
from fastapi import FastAPI, Request, Response, Query, Path
|
|
from fastapi.responses import JSONResponse
|
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
|
|
from openapi_mock.core.spec_parser import (
|
|
load_spec,
|
|
extract_paths,
|
|
extract_schemas,
|
|
extract_path_params,
|
|
get_operation_id,
|
|
get_response_schema,
|
|
)
|
|
from openapi_mock.generators.data_generator import DataGenerator
|
|
|
|
|
|
class ResponseDelayMiddleware(BaseHTTPMiddleware):
    """Middleware that holds each request back for a random interval.

    When a delay range is configured, every request is paused for a duration
    drawn uniformly from that range before being handed to the rest of the
    application.
    """

    def __init__(self, app: FastAPI, delay_range: Optional[tuple] = None):
        """Store the configured delay range.

        Args:
            app: FastAPI application being wrapped.
            delay_range: Optional ``(min_delay, max_delay)`` tuple in seconds.
                A falsy value (``None`` or an empty tuple) disables the delay.
        """
        super().__init__(app)
        self.delay_range = delay_range

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Optionally sleep, then forward the request downstream.

        Args:
            request: Incoming HTTP request.
            call_next: Callable that invokes the next middleware or handler.

        Returns:
            The response produced downstream.
        """
        import asyncio
        import random

        # Guard clause: nothing to wait for when no range is configured.
        if not self.delay_range:
            return await call_next(request)

        lower, upper = self.delay_range
        await asyncio.sleep(random.uniform(lower, upper))
        return await call_next(request)
|
|
|
|
|
|
class AuthMiddleware(BaseHTTPMiddleware):
    """Middleware for API authentication simulation.

    Supported ``auth_type`` values are ``'none'``, ``'bearer'``,
    ``'api_key'`` and ``'basic'``. Any other value lets every request
    through. Rejected requests receive a 401 JSON response whose body is
    ``{'error': <message>}``.
    """

    def __init__(
        self,
        app: FastAPI,
        auth_type: str = 'none',
        api_keys: Optional[List[str]] = None
    ):
        """Initialize the auth middleware.

        Args:
            app: FastAPI application.
            auth_type: Type of authentication (none, bearer, api_key, basic).
            api_keys: List of valid API keys (used only for ``'api_key'``).
        """
        super().__init__(app)
        self.auth_type = auth_type
        self.api_keys = api_keys or []

    @staticmethod
    def _strip_scheme(header: str, scheme: str) -> Optional[str]:
        """Return the credentials that follow ``scheme`` in an Authorization header.

        Auth-scheme names are case-insensitive (RFC 7235 section 2.1), so
        ``'bearer abc'`` and ``'Bearer abc'`` are treated alike.

        Args:
            header: Raw ``Authorization`` header value (may be empty).
            scheme: Scheme name without the trailing space, e.g. ``'Bearer'``.

        Returns:
            The text after ``'<scheme> '``, which may be empty, or ``None``
            when the header does not start with that scheme.
        """
        prefix = scheme + ' '
        if header[:len(prefix)].lower() != prefix.lower():
            return None
        return header[len(prefix):]

    @staticmethod
    def _unauthorized(message: str) -> JSONResponse:
        """Build the 401 response used for every rejection."""
        return JSONResponse(status_code=401, content={'error': message})

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Validate authentication credentials.

        Args:
            request: The incoming request.
            call_next: The next middleware/handler.

        Returns:
            The downstream response, or a 401 JSON response on failure.
        """
        if self.auth_type == 'none':
            return await call_next(request)

        auth_header = request.headers.get('Authorization', '')
        api_key = request.headers.get('X-API-Key', '')

        if self.auth_type == 'bearer':
            token = self._strip_scheme(auth_header, 'Bearer')
            if token is None:
                return self._unauthorized('Missing or invalid Authorization header')
            if not token:
                return self._unauthorized('Invalid token')

        elif self.auth_type == 'api_key':
            # The key may come from ``X-API-Key`` or from ``Authorization: ApiKey <key>``;
            # the dedicated header wins when both are present.
            header_key = self._strip_scheme(auth_header, 'ApiKey')
            if not api_key and header_key is None:
                return self._unauthorized('Missing API key')
            key = api_key or header_key
            if key not in self.api_keys:
                return self._unauthorized('Invalid API key')

        elif self.auth_type == 'basic':
            encoded = self._strip_scheme(auth_header, 'Basic')
            if encoded is None:
                return self._unauthorized('Missing or invalid Authorization header')

            import base64

            try:
                decoded = base64.b64decode(encoded).decode('utf-8')
                username, password = decoded.split(':', 1)
            except ValueError:
                # binascii.Error (bad base64) and UnicodeDecodeError are both
                # ValueError subclasses; a payload without ':' fails unpacking
                # with ValueError too. Anything else is a real bug and should
                # not be disguised as a client error.
                return self._unauthorized('Invalid authorization format')

            if not self._validate_basic_credentials(username, password):
                return self._unauthorized('Invalid credentials')

        return await call_next(request)

    def _validate_basic_credentials(self, username: str, password: str) -> bool:
        """Validate basic auth credentials.

        This simulation accepts any non-empty username/password pair.

        Args:
            username: Username from Basic auth.
            password: Password from Basic auth.

        Returns:
            True if both username and password are non-empty.
        """
        return bool(username and password)
|
|
|
|
|
|
def _convert_param_name(name: str) -> str:
|
|
"""Convert parameter name to valid Python identifier.
|
|
|
|
Args:
|
|
name: Parameter name from OpenAPI spec.
|
|
|
|
Returns:
|
|
Valid Python identifier.
|
|
"""
|
|
return re.sub(r'[^a-zA-Z0-9_]', '_', name)
|
|
|
|
|
|
class OpenAPIMockServer:
    """Mock server generated from an OpenAPI specification.

    Loads the spec, creates a FastAPI application described by the spec's
    ``info`` section, installs the optional delay/auth middleware, and
    registers one route per operation found under ``paths``. Each route
    returns data produced by ``DataGenerator`` from the operation's response
    schema, or a fixed placeholder message when no schema is available.
    """

    # Keys of an OpenAPI Path Item Object that name operations. Other keys
    # that may appear there (``parameters``, ``summary``, ``servers``,
    # ``$ref``, ``x-*`` extensions) are not operations and are skipped.
    _HTTP_METHODS = frozenset(
        {'get', 'post', 'put', 'delete', 'patch', 'options', 'head', 'trace'}
    )

    def __init__(
        self,
        spec_path: str,
        delay_range: Optional[tuple] = None,
        auth_type: str = 'none',
        auth_config: Optional[Dict[str, Any]] = None
    ):
        """Initialize the mock server.

        Args:
            spec_path: Path to the OpenAPI specification file.
            delay_range: Optional response delay range (min, max) in seconds.
            auth_type: Type of authentication to simulate
                (``'none'``, ``'bearer'``, ``'api_key'`` or ``'basic'``).
            auth_config: Additional authentication configuration; its
                ``'api_keys'`` entry is read when auth is enabled.
        """
        self.spec_path = spec_path
        self.spec = load_spec(spec_path)
        self.delay_range = delay_range
        self.auth_type = auth_type
        self.auth_config = auth_config or {}

        # ``info`` may be missing, or explicitly null, in a loosely written spec.
        info = self.spec.get('info') or {}
        self.app = FastAPI(
            title=info.get('title', 'Mock API'),
            version=info.get('version', '1.0.0'),
            description=info.get('description', ''),
            openapi_version=self.spec.get('openapi', '3.0.0'),
        )

        self._setup_middleware()
        self._setup_routes()

    def _setup_middleware(self) -> None:
        """Install the delay and authentication middleware when configured."""
        if self.delay_range:
            self.app.add_middleware(ResponseDelayMiddleware, delay_range=self.delay_range)

        if self.auth_type != 'none':
            api_keys = self.auth_config.get('api_keys', [])
            self.app.add_middleware(
                AuthMiddleware,
                auth_type=self.auth_type,
                api_keys=api_keys
            )

    def _setup_routes(self) -> None:
        """Register one route per operation declared under the spec's ``paths``."""
        paths = extract_paths(self.spec)
        schemas = extract_schemas(self.spec)
        # Handlers read ``self.data_generator`` at request time. Component
        # schemas are cached on it, presumably so ``$ref``s can be resolved.
        self.data_generator = DataGenerator()
        self.data_generator.set_ref_cache(schemas)

        for path, path_item in paths.items():
            path_params = extract_path_params(path)

            for method_key, operation in path_item.items():
                method = method_key.lower()
                if method not in self._HTTP_METHODS:
                    continue

                # The operation id is derived from the key exactly as written in the spec.
                operation_id = get_operation_id(path, method_key)
                response_schema = get_response_schema(self.spec, path, method)

                self._create_route_handler(
                    path=path,
                    method=method,
                    path_params=path_params,
                    operation=operation,
                    response_schema=response_schema,
                    operation_id=operation_id,
                    summary=operation.get('summary', ''),
                    description=operation.get('description', ''),
                    tags=operation.get('tags', []),
                )

    def _create_route_handler(
        self,
        path: str,
        method: str,
        path_params: List[str],
        operation: Dict[str, Any],
        response_schema: Optional[Dict[str, Any]],
        operation_id: str,
        summary: str,
        description: str,
        tags: List[str]
    ) -> None:
        """Create and register a route handler for one operation.

        The handler does not read or validate request parameters; it only
        returns generated data. ``path_params`` and ``operation`` are accepted
        for interface compatibility but are not used here.

        Args:
            path: API path (OpenAPI ``{param}`` templating, which FastAPI shares).
            method: Lower-case HTTP method.
            path_params: List of path parameter names (unused).
            operation: Operation definition from spec (unused).
            response_schema: Response schema, or a falsy value if none.
            operation_id: Generated operation ID.
            summary: Operation summary; defaults to "<METHOD> <path>" if empty.
            description: Operation description.
            tags: Operation tags.
        """
        # No docstring on purpose: FastAPI can fall back to the endpoint's
        # docstring when ``description`` is empty, which would leak into the
        # generated OpenAPI document.
        async def handler(request: Request) -> Any:
            if response_schema:
                return self.data_generator.generate(response_schema)
            return {'message': 'No schema defined'}

        handler.__name__ = f"{method}_{path.replace('/', '_').replace('-', '_')}"

        self.app.add_api_route(
            path,
            handler,
            methods=[method.upper()],
            # FastAPI >= 0.89 infers a response model from the endpoint's return
            # annotation and validates every response against it. A
            # ``Dict[str, Any]`` model would reject array-typed schemas with a
            # 500. Mock payloads are arbitrary JSON, so disable it explicitly.
            response_model=None,
            summary=summary or f"{method.upper()} {path}",
            description=description,
            tags=tags or None,
            operation_id=operation_id,
        )
|
|
|
|
|
|
def create_app(
    spec_path: str,
    delay_range: Optional[tuple] = None,
    auth_type: str = 'none',
    auth_config: Optional[Dict[str, Any]] = None
) -> FastAPI:
    """Build a FastAPI app that mocks the API described by an OpenAPI spec.

    Convenience wrapper: constructs an ``OpenAPIMockServer`` and hands back
    only its application object.

    Args:
        spec_path: Path to the OpenAPI specification file.
        delay_range: Optional response delay range (min, max) in seconds.
        auth_type: Type of authentication to simulate.
        auth_config: Additional authentication configuration.

    Returns:
        The configured FastAPI application.
    """
    mock_server = OpenAPIMockServer(
        spec_path,
        delay_range=delay_range,
        auth_type=auth_type,
        auth_config=auth_config,
    )
    return mock_server.app
|