This commit is contained in:
467
src/core/server.py
Normal file
467
src/core/server.py
Normal file
@@ -0,0 +1,467 @@
|
||||
"""HTTP Mock Server for serving mock API responses."""
|
||||
|
||||
import json
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
from http.server import BaseHTTPRequestHandler, HTTPServer
|
||||
from re import Pattern
|
||||
from typing import Any
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
from src.core.generator import ResponseGenerator
|
||||
from src.core.parser import OpenAPIParser
|
||||
from src.core.validator import RequestValidator
|
||||
|
||||
|
||||
class MockServerError(Exception):
|
||||
"""Base exception for mock server errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class PortInUseError(MockServerError):
|
||||
"""Raised when the specified port is already in use."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class RouteNotFoundError(MockServerError):
|
||||
"""Raised when no route matches the request."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class MockServer:
|
||||
"""HTTP mock server that serves responses based on OpenAPI specifications."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
spec_file: str,
|
||||
port: int = 8080,
|
||||
delay: int = 0,
|
||||
fuzzing: bool = False,
|
||||
seed: int | None = None,
|
||||
) -> None:
|
||||
"""Initialize the mock server.
|
||||
|
||||
Args:
|
||||
spec_file: Path to OpenAPI specification file.
|
||||
port: Port to listen on.
|
||||
delay: Response delay in milliseconds.
|
||||
fuzzing: Enable fuzzing mode for edge case testing.
|
||||
seed: Random seed for reproducible responses.
|
||||
"""
|
||||
self.parser = OpenAPIParser(spec_file)
|
||||
self.port = port
|
||||
self.delay = delay
|
||||
self.fuzzing = fuzzing
|
||||
self.seed = seed
|
||||
self.validator: RequestValidator | None = None
|
||||
self.generator = ResponseGenerator(seed=seed)
|
||||
self._route_cache: dict[str, tuple[dict[str, Any], Pattern, list[str]]] = {}
|
||||
|
||||
self._setup_routes()
|
||||
|
||||
def _setup_routes(self) -> None:
|
||||
"""Compile route patterns for efficient matching."""
|
||||
paths = self.parser.get_paths()
|
||||
|
||||
for path, path_item in paths.items():
|
||||
pattern = self._path_to_regex(path)
|
||||
if pattern:
|
||||
self._route_cache[path] = (path_item, pattern, self._extract_param_names(path))
|
||||
|
||||
def _path_to_regex(self, path: str) -> Pattern[str] | None:
|
||||
"""Convert OpenAPI path to regex pattern.
|
||||
|
||||
Args:
|
||||
path: OpenAPI path with {param} placeholders.
|
||||
|
||||
Returns:
|
||||
Compiled regex pattern.
|
||||
"""
|
||||
try:
|
||||
regex_path = re.sub(r"\{([^}]+)\}", r"(?P<\1>[^/]+)", path)
|
||||
return re.compile(f"^{regex_path}$")
|
||||
except re.error:
|
||||
return None
|
||||
|
||||
def _extract_param_names(self, path: str) -> list[str]:
|
||||
"""Extract parameter names from path.
|
||||
|
||||
Args:
|
||||
path: OpenAPI path.
|
||||
|
||||
Returns:
|
||||
List of parameter names.
|
||||
"""
|
||||
return re.findall(r"\{([^}]+)\}", path)
|
||||
|
||||
def load_spec(self) -> dict[str, Any]:
|
||||
"""Load and validate the OpenAPI specification.
|
||||
|
||||
Returns:
|
||||
The loaded specification.
|
||||
"""
|
||||
spec = self.parser.load()
|
||||
self.parser.validate_spec()
|
||||
self.validator = RequestValidator(spec)
|
||||
self.generator.set_resolved_schemas(self.parser.get_schemas())
|
||||
return spec
|
||||
|
||||
def _parse_request_body(
|
||||
self, content_type: str | None, body: bytes
|
||||
) -> Any:
|
||||
"""Parse request body based on content type.
|
||||
|
||||
Args:
|
||||
content_type: Content-Type header value.
|
||||
body: Raw request body.
|
||||
|
||||
Returns:
|
||||
Parsed body or None.
|
||||
"""
|
||||
if not body:
|
||||
return None
|
||||
|
||||
if content_type and "application/json" in content_type:
|
||||
try:
|
||||
return json.loads(body.decode("utf-8"))
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
def _get_delay(self) -> float:
|
||||
"""Get the response delay.
|
||||
|
||||
Returns:
|
||||
Delay in seconds.
|
||||
"""
|
||||
if self.delay <= 0:
|
||||
return 0.0
|
||||
|
||||
if isinstance(self.delay, tuple):
|
||||
min_ms, max_ms = self.delay
|
||||
return random.uniform(min_ms, max_ms) / 1000.0
|
||||
|
||||
return self.delay / 1000.0
|
||||
|
||||
def _find_route(
|
||||
self, path: str, method: str
|
||||
) -> tuple[dict[str, Any], dict[str, str]]:
|
||||
"""Find matching route for the request.
|
||||
|
||||
Args:
|
||||
request_path: Request path.
|
||||
method: HTTP method.
|
||||
|
||||
Returns:
|
||||
Tuple of (path_item, path_params) or raises RouteNotFoundError.
|
||||
"""
|
||||
for route_path, (path_item, pattern, param_names) in self._route_cache.items():
|
||||
match = pattern.match(path)
|
||||
if match:
|
||||
path_params = match.groupdict()
|
||||
method_lower = method.lower()
|
||||
if method_lower in path_item:
|
||||
return path_item, path_params
|
||||
|
||||
raise RouteNotFoundError(f"No route found for {method.upper()} {path}")
|
||||
|
||||
def _get_response_schema(
|
||||
self, path_item: dict[str, Any], method: str, status_code: str = "200"
|
||||
) -> dict[str, Any] | None:
|
||||
"""Get response schema for the operation.
|
||||
|
||||
Args:
|
||||
path_item: Path item from OpenAPI spec.
|
||||
method: HTTP method.
|
||||
status_code: Response status code.
|
||||
|
||||
Returns:
|
||||
Response schema or None.
|
||||
"""
|
||||
operation = path_item.get(method.lower(), {})
|
||||
responses = operation.get("responses", {})
|
||||
|
||||
if status_code not in responses:
|
||||
return None
|
||||
|
||||
response = responses[status_code]
|
||||
content = response.get("content", {})
|
||||
json_content = content.get("application/json", {})
|
||||
|
||||
return json_content.get("schema")
|
||||
|
||||
def _handle_request(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
headers: dict[str, Any],
|
||||
query_params: dict[str, Any],
|
||||
body: Any,
|
||||
) -> dict[str, Any]:
|
||||
"""Handle an incoming request.
|
||||
|
||||
Args:
|
||||
method: HTTP method.
|
||||
path: Request path.
|
||||
headers: Request headers.
|
||||
query_params: Query parameters.
|
||||
body: Request body.
|
||||
|
||||
Returns:
|
||||
Response dictionary.
|
||||
"""
|
||||
try:
|
||||
path_item, path_params = self._find_route(path, method)
|
||||
except RouteNotFoundError:
|
||||
return {
|
||||
"status_code": 404,
|
||||
"body": {
|
||||
"error": {
|
||||
"type": "Not Found",
|
||||
"message": f"No route found for {method.upper()} {path}",
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
operation = path_item.get(method.lower(), {})
|
||||
|
||||
if self.validator:
|
||||
validation_error = self.validator.validate_request(
|
||||
method, path, headers, query_params, path_params, body, operation
|
||||
)
|
||||
if validation_error:
|
||||
return self.validator.create_error_response(validation_error)
|
||||
|
||||
status_code = 200
|
||||
response_schema = self._get_response_schema(path_item, method)
|
||||
|
||||
if self.fuzzing:
|
||||
from src.core.fuzzer import Fuzzer
|
||||
fuzzer = Fuzzer()
|
||||
if response_schema:
|
||||
body = fuzzer.fuzz_schema(response_schema)
|
||||
elif response_schema:
|
||||
body = self.generator.generate(response_schema)
|
||||
|
||||
return {
|
||||
"status_code": status_code,
|
||||
"body": body,
|
||||
}
|
||||
|
||||
def start(self, blocking: bool = True) -> None:
|
||||
"""Start the mock server.
|
||||
|
||||
Args:
|
||||
blocking: Whether to block the main thread.
|
||||
"""
|
||||
self.load_spec()
|
||||
|
||||
server_address = ("", self.port)
|
||||
handler = self._create_request_handler()
|
||||
|
||||
try:
|
||||
httpd = HTTPServer(server_address, handler)
|
||||
except OSError as e:
|
||||
if e.errno == 98:
|
||||
raise PortInUseError(f"Port {self.port} is already in use") from e
|
||||
raise
|
||||
|
||||
print(f"Mock server starting on port {self.port}...")
|
||||
print(f"Fuzzing mode: {'enabled' if self.fuzzing else 'disabled'}")
|
||||
|
||||
if blocking:
|
||||
try:
|
||||
httpd.serve_forever()
|
||||
except KeyboardInterrupt:
|
||||
print("\nShutting down mock server...")
|
||||
httpd.shutdown()
|
||||
|
||||
def _create_request_handler(self) -> type:
|
||||
"""Create the HTTP request handler class.
|
||||
|
||||
Returns:
|
||||
Request handler class.
|
||||
"""
|
||||
delay = self.delay
|
||||
fuzzing = self.fuzzing
|
||||
generator = self.generator
|
||||
validator = self.validator
|
||||
route_cache = self._route_cache
|
||||
|
||||
class RequestHandler(BaseHTTPRequestHandler):
|
||||
def log_message(self, format: str, *args: Any) -> None:
|
||||
"""Suppress default logging."""
|
||||
pass
|
||||
|
||||
def do_method(self, method: str) -> None:
|
||||
"""Handle any HTTP method."""
|
||||
parsed_url = urlparse(self.path)
|
||||
request_path = parsed_url.path
|
||||
query_string = parsed_url.query
|
||||
|
||||
query_params = parse_qs(query_string) if query_string else {}
|
||||
|
||||
content_length = int(self.headers.get("Content-Length", 0))
|
||||
body = self.rfile.read(content_length) if content_length > 0 else b""
|
||||
|
||||
content_type = self.headers.get("Content-Type")
|
||||
parsed_body = self._parse_request_body(content_type, body)
|
||||
|
||||
request_headers: dict[str, Any] = {}
|
||||
for header_name, header_value in self.headers.items():
|
||||
request_headers[header_name.lower()] = header_value
|
||||
|
||||
delay_time = self._get_delay() if delay else 0
|
||||
if delay_time > 0:
|
||||
time.sleep(delay_time)
|
||||
|
||||
response = self._handle_request(
|
||||
method, request_path, request_headers, query_params, parsed_body
|
||||
)
|
||||
|
||||
self._send_response(response)
|
||||
|
||||
def _parse_request_body(self, content_type: str | None, body: bytes) -> Any:
|
||||
"""Parse request body."""
|
||||
if not body:
|
||||
return None
|
||||
|
||||
if content_type and "application/json" in content_type:
|
||||
try:
|
||||
return json.loads(body.decode("utf-8"))
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
def _get_delay(self) -> float:
|
||||
"""Get response delay."""
|
||||
if delay <= 0:
|
||||
return 0.0
|
||||
|
||||
if isinstance(delay, tuple):
|
||||
min_ms, max_ms = delay
|
||||
return random.uniform(min_ms, max_ms) / 1000.0
|
||||
|
||||
return delay / 1000.0
|
||||
|
||||
def _handle_request(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
headers: dict[str, Any],
|
||||
query_params: dict[str, Any],
|
||||
body: Any,
|
||||
) -> dict[str, Any]:
|
||||
"""Handle request."""
|
||||
try:
|
||||
for route_path, (path_item, pattern, param_names) in route_cache.items():
|
||||
match = pattern.match(path)
|
||||
if match:
|
||||
path_params = match.groupdict()
|
||||
method_lower = method.lower()
|
||||
if method_lower in path_item:
|
||||
break
|
||||
else:
|
||||
return {
|
||||
"status_code": 404,
|
||||
"body": {
|
||||
"error": {
|
||||
"type": "Not Found",
|
||||
"message": f"No route found for {method.upper()} {path}",
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
operation = path_item.get(method_lower, {})
|
||||
|
||||
if validator:
|
||||
validation_error = validator.validate_request(
|
||||
method, path, headers, query_params, path_params, body, operation
|
||||
)
|
||||
if validation_error:
|
||||
return validator.create_error_response(validation_error)
|
||||
|
||||
response_schema = self._get_response_schema(path_item, method_lower)
|
||||
|
||||
if fuzzing:
|
||||
from src.core.fuzzer import Fuzzer
|
||||
fuzzer = Fuzzer()
|
||||
if response_schema:
|
||||
body = fuzzer.fuzz_schema(response_schema)
|
||||
elif response_schema:
|
||||
body = generator.generate(response_schema)
|
||||
else:
|
||||
body = None
|
||||
|
||||
return {
|
||||
"status_code": 200,
|
||||
"body": body,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"status_code": 500,
|
||||
"body": {
|
||||
"error": {
|
||||
"type": "Internal Server Error",
|
||||
"message": str(e),
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
def _get_response_schema(
|
||||
self, path_item: dict[str, Any], method: str, status_code: str = "200"
|
||||
) -> dict[str, Any] | None:
|
||||
"""Get response schema."""
|
||||
operation = path_item.get(method, {})
|
||||
responses = operation.get("responses", {})
|
||||
|
||||
if status_code not in responses:
|
||||
return None
|
||||
|
||||
response = responses[status_code]
|
||||
content = response.get("content", {})
|
||||
json_content = content.get("application/json", {})
|
||||
|
||||
return json_content.get("schema")
|
||||
|
||||
def _send_response(self, response: dict[str, Any]) -> None:
|
||||
"""Send HTTP response."""
|
||||
status_code = response.get("status_code", 200)
|
||||
body = response.get("body")
|
||||
|
||||
self.send_response(status_code)
|
||||
self.send_header("Content-Type", "application/json")
|
||||
self.end_headers()
|
||||
|
||||
if body is not None:
|
||||
body_json = json.dumps(body, indent=2)
|
||||
self.wfile.write(body_json.encode("utf-8"))
|
||||
|
||||
def do_GET(self) -> None:
|
||||
"""Handle GET requests."""
|
||||
self.do_method("GET")
|
||||
|
||||
def do_POST(self) -> None:
|
||||
"""Handle POST requests."""
|
||||
self.do_method("POST")
|
||||
|
||||
def do_PUT(self) -> None:
|
||||
"""Handle PUT requests."""
|
||||
self.do_method("PUT")
|
||||
|
||||
def do_PATCH(self) -> None:
|
||||
"""Handle PATCH requests."""
|
||||
self.do_method("PATCH")
|
||||
|
||||
def do_DELETE(self) -> None:
|
||||
"""Handle DELETE requests."""
|
||||
self.do_method("DELETE")
|
||||
|
||||
return RequestHandler
|
||||
Reference in New Issue
Block a user