This commit is contained in:
467
src/core/server.py
Normal file
467
src/core/server.py
Normal file
@@ -0,0 +1,467 @@
|
|||||||
|
"""HTTP Mock Server for serving mock API responses."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import random
|
||||||
|
import re
|
||||||
|
import time
|
||||||
|
from http.server import BaseHTTPRequestHandler, HTTPServer
|
||||||
|
from re import Pattern
|
||||||
|
from typing import Any
|
||||||
|
from urllib.parse import parse_qs, urlparse
|
||||||
|
|
||||||
|
from src.core.generator import ResponseGenerator
|
||||||
|
from src.core.parser import OpenAPIParser
|
||||||
|
from src.core.validator import RequestValidator
|
||||||
|
|
||||||
|
|
||||||
|
class MockServerError(Exception):
|
||||||
|
"""Base exception for mock server errors."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class PortInUseError(MockServerError):
|
||||||
|
"""Raised when the specified port is already in use."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class RouteNotFoundError(MockServerError):
|
||||||
|
"""Raised when no route matches the request."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class MockServer:
|
||||||
|
"""HTTP mock server that serves responses based on OpenAPI specifications."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
spec_file: str,
|
||||||
|
port: int = 8080,
|
||||||
|
delay: int = 0,
|
||||||
|
fuzzing: bool = False,
|
||||||
|
seed: int | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Initialize the mock server.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
spec_file: Path to OpenAPI specification file.
|
||||||
|
port: Port to listen on.
|
||||||
|
delay: Response delay in milliseconds.
|
||||||
|
fuzzing: Enable fuzzing mode for edge case testing.
|
||||||
|
seed: Random seed for reproducible responses.
|
||||||
|
"""
|
||||||
|
self.parser = OpenAPIParser(spec_file)
|
||||||
|
self.port = port
|
||||||
|
self.delay = delay
|
||||||
|
self.fuzzing = fuzzing
|
||||||
|
self.seed = seed
|
||||||
|
self.validator: RequestValidator | None = None
|
||||||
|
self.generator = ResponseGenerator(seed=seed)
|
||||||
|
self._route_cache: dict[str, tuple[dict[str, Any], Pattern, list[str]]] = {}
|
||||||
|
|
||||||
|
self._setup_routes()
|
||||||
|
|
||||||
|
def _setup_routes(self) -> None:
|
||||||
|
"""Compile route patterns for efficient matching."""
|
||||||
|
paths = self.parser.get_paths()
|
||||||
|
|
||||||
|
for path, path_item in paths.items():
|
||||||
|
pattern = self._path_to_regex(path)
|
||||||
|
if pattern:
|
||||||
|
self._route_cache[path] = (path_item, pattern, self._extract_param_names(path))
|
||||||
|
|
||||||
|
def _path_to_regex(self, path: str) -> Pattern[str] | None:
|
||||||
|
"""Convert OpenAPI path to regex pattern.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path: OpenAPI path with {param} placeholders.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Compiled regex pattern.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
regex_path = re.sub(r"\{([^}]+)\}", r"(?P<\1>[^/]+)", path)
|
||||||
|
return re.compile(f"^{regex_path}$")
|
||||||
|
except re.error:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _extract_param_names(self, path: str) -> list[str]:
|
||||||
|
"""Extract parameter names from path.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path: OpenAPI path.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of parameter names.
|
||||||
|
"""
|
||||||
|
return re.findall(r"\{([^}]+)\}", path)
|
||||||
|
|
||||||
|
def load_spec(self) -> dict[str, Any]:
|
||||||
|
"""Load and validate the OpenAPI specification.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The loaded specification.
|
||||||
|
"""
|
||||||
|
spec = self.parser.load()
|
||||||
|
self.parser.validate_spec()
|
||||||
|
self.validator = RequestValidator(spec)
|
||||||
|
self.generator.set_resolved_schemas(self.parser.get_schemas())
|
||||||
|
return spec
|
||||||
|
|
||||||
|
def _parse_request_body(
|
||||||
|
self, content_type: str | None, body: bytes
|
||||||
|
) -> Any:
|
||||||
|
"""Parse request body based on content type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content_type: Content-Type header value.
|
||||||
|
body: Raw request body.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Parsed body or None.
|
||||||
|
"""
|
||||||
|
if not body:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if content_type and "application/json" in content_type:
|
||||||
|
try:
|
||||||
|
return json.loads(body.decode("utf-8"))
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _get_delay(self) -> float:
|
||||||
|
"""Get the response delay.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Delay in seconds.
|
||||||
|
"""
|
||||||
|
if self.delay <= 0:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
if isinstance(self.delay, tuple):
|
||||||
|
min_ms, max_ms = self.delay
|
||||||
|
return random.uniform(min_ms, max_ms) / 1000.0
|
||||||
|
|
||||||
|
return self.delay / 1000.0
|
||||||
|
|
||||||
|
def _find_route(
|
||||||
|
self, path: str, method: str
|
||||||
|
) -> tuple[dict[str, Any], dict[str, str]]:
|
||||||
|
"""Find matching route for the request.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request_path: Request path.
|
||||||
|
method: HTTP method.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (path_item, path_params) or raises RouteNotFoundError.
|
||||||
|
"""
|
||||||
|
for route_path, (path_item, pattern, param_names) in self._route_cache.items():
|
||||||
|
match = pattern.match(path)
|
||||||
|
if match:
|
||||||
|
path_params = match.groupdict()
|
||||||
|
method_lower = method.lower()
|
||||||
|
if method_lower in path_item:
|
||||||
|
return path_item, path_params
|
||||||
|
|
||||||
|
raise RouteNotFoundError(f"No route found for {method.upper()} {path}")
|
||||||
|
|
||||||
|
def _get_response_schema(
|
||||||
|
self, path_item: dict[str, Any], method: str, status_code: str = "200"
|
||||||
|
) -> dict[str, Any] | None:
|
||||||
|
"""Get response schema for the operation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path_item: Path item from OpenAPI spec.
|
||||||
|
method: HTTP method.
|
||||||
|
status_code: Response status code.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Response schema or None.
|
||||||
|
"""
|
||||||
|
operation = path_item.get(method.lower(), {})
|
||||||
|
responses = operation.get("responses", {})
|
||||||
|
|
||||||
|
if status_code not in responses:
|
||||||
|
return None
|
||||||
|
|
||||||
|
response = responses[status_code]
|
||||||
|
content = response.get("content", {})
|
||||||
|
json_content = content.get("application/json", {})
|
||||||
|
|
||||||
|
return json_content.get("schema")
|
||||||
|
|
||||||
|
def _handle_request(
|
||||||
|
self,
|
||||||
|
method: str,
|
||||||
|
path: str,
|
||||||
|
headers: dict[str, Any],
|
||||||
|
query_params: dict[str, Any],
|
||||||
|
body: Any,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Handle an incoming request.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
method: HTTP method.
|
||||||
|
path: Request path.
|
||||||
|
headers: Request headers.
|
||||||
|
query_params: Query parameters.
|
||||||
|
body: Request body.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Response dictionary.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
path_item, path_params = self._find_route(path, method)
|
||||||
|
except RouteNotFoundError:
|
||||||
|
return {
|
||||||
|
"status_code": 404,
|
||||||
|
"body": {
|
||||||
|
"error": {
|
||||||
|
"type": "Not Found",
|
||||||
|
"message": f"No route found for {method.upper()} {path}",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
operation = path_item.get(method.lower(), {})
|
||||||
|
|
||||||
|
if self.validator:
|
||||||
|
validation_error = self.validator.validate_request(
|
||||||
|
method, path, headers, query_params, path_params, body, operation
|
||||||
|
)
|
||||||
|
if validation_error:
|
||||||
|
return self.validator.create_error_response(validation_error)
|
||||||
|
|
||||||
|
status_code = 200
|
||||||
|
response_schema = self._get_response_schema(path_item, method)
|
||||||
|
|
||||||
|
if self.fuzzing:
|
||||||
|
from src.core.fuzzer import Fuzzer
|
||||||
|
fuzzer = Fuzzer()
|
||||||
|
if response_schema:
|
||||||
|
body = fuzzer.fuzz_schema(response_schema)
|
||||||
|
elif response_schema:
|
||||||
|
body = self.generator.generate(response_schema)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status_code": status_code,
|
||||||
|
"body": body,
|
||||||
|
}
|
||||||
|
|
||||||
|
def start(self, blocking: bool = True) -> None:
|
||||||
|
"""Start the mock server.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
blocking: Whether to block the main thread.
|
||||||
|
"""
|
||||||
|
self.load_spec()
|
||||||
|
|
||||||
|
server_address = ("", self.port)
|
||||||
|
handler = self._create_request_handler()
|
||||||
|
|
||||||
|
try:
|
||||||
|
httpd = HTTPServer(server_address, handler)
|
||||||
|
except OSError as e:
|
||||||
|
if e.errno == 98:
|
||||||
|
raise PortInUseError(f"Port {self.port} is already in use") from e
|
||||||
|
raise
|
||||||
|
|
||||||
|
print(f"Mock server starting on port {self.port}...")
|
||||||
|
print(f"Fuzzing mode: {'enabled' if self.fuzzing else 'disabled'}")
|
||||||
|
|
||||||
|
if blocking:
|
||||||
|
try:
|
||||||
|
httpd.serve_forever()
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\nShutting down mock server...")
|
||||||
|
httpd.shutdown()
|
||||||
|
|
||||||
|
def _create_request_handler(self) -> type:
|
||||||
|
"""Create the HTTP request handler class.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Request handler class.
|
||||||
|
"""
|
||||||
|
delay = self.delay
|
||||||
|
fuzzing = self.fuzzing
|
||||||
|
generator = self.generator
|
||||||
|
validator = self.validator
|
||||||
|
route_cache = self._route_cache
|
||||||
|
|
||||||
|
class RequestHandler(BaseHTTPRequestHandler):
|
||||||
|
def log_message(self, format: str, *args: Any) -> None:
|
||||||
|
"""Suppress default logging."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def do_method(self, method: str) -> None:
|
||||||
|
"""Handle any HTTP method."""
|
||||||
|
parsed_url = urlparse(self.path)
|
||||||
|
request_path = parsed_url.path
|
||||||
|
query_string = parsed_url.query
|
||||||
|
|
||||||
|
query_params = parse_qs(query_string) if query_string else {}
|
||||||
|
|
||||||
|
content_length = int(self.headers.get("Content-Length", 0))
|
||||||
|
body = self.rfile.read(content_length) if content_length > 0 else b""
|
||||||
|
|
||||||
|
content_type = self.headers.get("Content-Type")
|
||||||
|
parsed_body = self._parse_request_body(content_type, body)
|
||||||
|
|
||||||
|
request_headers: dict[str, Any] = {}
|
||||||
|
for header_name, header_value in self.headers.items():
|
||||||
|
request_headers[header_name.lower()] = header_value
|
||||||
|
|
||||||
|
delay_time = self._get_delay() if delay else 0
|
||||||
|
if delay_time > 0:
|
||||||
|
time.sleep(delay_time)
|
||||||
|
|
||||||
|
response = self._handle_request(
|
||||||
|
method, request_path, request_headers, query_params, parsed_body
|
||||||
|
)
|
||||||
|
|
||||||
|
self._send_response(response)
|
||||||
|
|
||||||
|
def _parse_request_body(self, content_type: str | None, body: bytes) -> Any:
|
||||||
|
"""Parse request body."""
|
||||||
|
if not body:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if content_type and "application/json" in content_type:
|
||||||
|
try:
|
||||||
|
return json.loads(body.decode("utf-8"))
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _get_delay(self) -> float:
|
||||||
|
"""Get response delay."""
|
||||||
|
if delay <= 0:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
if isinstance(delay, tuple):
|
||||||
|
min_ms, max_ms = delay
|
||||||
|
return random.uniform(min_ms, max_ms) / 1000.0
|
||||||
|
|
||||||
|
return delay / 1000.0
|
||||||
|
|
||||||
|
def _handle_request(
|
||||||
|
self,
|
||||||
|
method: str,
|
||||||
|
path: str,
|
||||||
|
headers: dict[str, Any],
|
||||||
|
query_params: dict[str, Any],
|
||||||
|
body: Any,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Handle request."""
|
||||||
|
try:
|
||||||
|
for route_path, (path_item, pattern, param_names) in route_cache.items():
|
||||||
|
match = pattern.match(path)
|
||||||
|
if match:
|
||||||
|
path_params = match.groupdict()
|
||||||
|
method_lower = method.lower()
|
||||||
|
if method_lower in path_item:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
return {
|
||||||
|
"status_code": 404,
|
||||||
|
"body": {
|
||||||
|
"error": {
|
||||||
|
"type": "Not Found",
|
||||||
|
"message": f"No route found for {method.upper()} {path}",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
operation = path_item.get(method_lower, {})
|
||||||
|
|
||||||
|
if validator:
|
||||||
|
validation_error = validator.validate_request(
|
||||||
|
method, path, headers, query_params, path_params, body, operation
|
||||||
|
)
|
||||||
|
if validation_error:
|
||||||
|
return validator.create_error_response(validation_error)
|
||||||
|
|
||||||
|
response_schema = self._get_response_schema(path_item, method_lower)
|
||||||
|
|
||||||
|
if fuzzing:
|
||||||
|
from src.core.fuzzer import Fuzzer
|
||||||
|
fuzzer = Fuzzer()
|
||||||
|
if response_schema:
|
||||||
|
body = fuzzer.fuzz_schema(response_schema)
|
||||||
|
elif response_schema:
|
||||||
|
body = generator.generate(response_schema)
|
||||||
|
else:
|
||||||
|
body = None
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status_code": 200,
|
||||||
|
"body": body,
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return {
|
||||||
|
"status_code": 500,
|
||||||
|
"body": {
|
||||||
|
"error": {
|
||||||
|
"type": "Internal Server Error",
|
||||||
|
"message": str(e),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def _get_response_schema(
|
||||||
|
self, path_item: dict[str, Any], method: str, status_code: str = "200"
|
||||||
|
) -> dict[str, Any] | None:
|
||||||
|
"""Get response schema."""
|
||||||
|
operation = path_item.get(method, {})
|
||||||
|
responses = operation.get("responses", {})
|
||||||
|
|
||||||
|
if status_code not in responses:
|
||||||
|
return None
|
||||||
|
|
||||||
|
response = responses[status_code]
|
||||||
|
content = response.get("content", {})
|
||||||
|
json_content = content.get("application/json", {})
|
||||||
|
|
||||||
|
return json_content.get("schema")
|
||||||
|
|
||||||
|
def _send_response(self, response: dict[str, Any]) -> None:
|
||||||
|
"""Send HTTP response."""
|
||||||
|
status_code = response.get("status_code", 200)
|
||||||
|
body = response.get("body")
|
||||||
|
|
||||||
|
self.send_response(status_code)
|
||||||
|
self.send_header("Content-Type", "application/json")
|
||||||
|
self.end_headers()
|
||||||
|
|
||||||
|
if body is not None:
|
||||||
|
body_json = json.dumps(body, indent=2)
|
||||||
|
self.wfile.write(body_json.encode("utf-8"))
|
||||||
|
|
||||||
|
def do_GET(self) -> None:
|
||||||
|
"""Handle GET requests."""
|
||||||
|
self.do_method("GET")
|
||||||
|
|
||||||
|
def do_POST(self) -> None:
|
||||||
|
"""Handle POST requests."""
|
||||||
|
self.do_method("POST")
|
||||||
|
|
||||||
|
def do_PUT(self) -> None:
|
||||||
|
"""Handle PUT requests."""
|
||||||
|
self.do_method("PUT")
|
||||||
|
|
||||||
|
def do_PATCH(self) -> None:
|
||||||
|
"""Handle PATCH requests."""
|
||||||
|
self.do_method("PATCH")
|
||||||
|
|
||||||
|
def do_DELETE(self) -> None:
|
||||||
|
"""Handle DELETE requests."""
|
||||||
|
self.do_method("DELETE")
|
||||||
|
|
||||||
|
return RequestHandler
|
||||||
Reference in New Issue
Block a user