This commit is contained in:
@@ -1,174 +1,297 @@
|
|||||||
"""Parser for OpenAPI specifications."""
|
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import yaml
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import Any, Optional
|
||||||
|
|
||||||
from openapi_spec_validator import validate as validate_spec
|
from openapi_spec_validator import validate
|
||||||
from pydantic import ValidationError
|
|
||||||
|
|
||||||
from src.core.models import OpenAPISpec
|
from src.core.models import OpenAPISpec, Operation, Parameter, PathItem, Response, Schema
|
||||||
|
|
||||||
|
|
||||||
class ParseError(Exception):
    """Exception raised when parsing fails.

    The constructor arguments are kept on the instance as ``message``,
    ``path`` and ``line`` so callers can inspect them; the text passed to
    ``Exception`` is the message with the location appended when known.
    """

    def __init__(self, message: str, path: Optional[str] = None, line: Optional[int] = None):
        self.message = message
        self.path = path
        self.line = line
        super().__init__(self._format_message())

    def _format_message(self) -> str:
        """Return the message, suffixed with ``(at path)`` or ``(at path:line)``."""
        if not self.path:
            # Without a path there is nothing to anchor a line number to.
            return self.message
        # Compare against None so that a line number of 0 is still reported.
        if self.line is not None:
            return f"{self.message} (at {self.path}:{self.line})"
        return f"{self.message} (at {self.path})"
|
|
||||||
|
|
||||||
|
|
||||||
class SpecValidationError(ParseError):
    """Exception raised when spec validation fails.

    Adds no behaviour of its own; it is a distinct subclass of ParseError so
    that validation failures can be caught separately from other parse errors.
    """
|
|
||||||
|
|
||||||
|
|
||||||
class OpenAPIParser:
|
class OpenAPIParser:
|
||||||
"""Parser for OpenAPI 3.0/3.1 specifications."""
|
def __init__(self, spec_data: dict[str, Any]):
|
||||||
|
self.spec_data = spec_data
|
||||||
SUPPORTED_VERSIONS = ["3.0.0", "3.0.1", "3.0.2", "3.0.3", "3.1.0"]
|
self._resolved_refs: dict[str, Any] = {}
|
||||||
|
self._components_schemas: dict[str, Schema] = {}
|
||||||
def __init__(self, spec_path: str):
|
self._components_responses: dict[str, Response] = {}
|
||||||
self.spec_path = Path(spec_path)
|
self._components_request_bodies: dict[str, Any] = {}
|
||||||
|
|
||||||
def load(self) -> Dict[str, Any]:
    """Load the spec file and return its contents.

    The parser is chosen from the file extension, compared
    case-insensitively: ``.yaml``/``.yml`` go through ``yaml.safe_load`` and
    ``.json`` through ``json.loads``.

    Returns:
        The parsed document.

    Raises:
        ParseError: If the file does not exist, cannot be parsed, or has an
            unsupported extension. Parse failures keep the underlying
            exception as ``__cause__``.
    """
    spec_file = self.spec_path
    location = str(spec_file)

    if not spec_file.exists():
        raise ParseError(f"Spec file not found: {spec_file}", location)

    # Decode explicitly as UTF-8 rather than relying on the platform's
    # default encoding, which varies between machines.
    content = spec_file.read_text(encoding="utf-8")
    suffix = spec_file.suffix.lower()

    if suffix in (".yaml", ".yml"):
        try:
            return yaml.safe_load(content)
        except yaml.YAMLError as e:
            raise ParseError(f"Invalid YAML format: {e}", location) from e

    if suffix == ".json":
        try:
            return json.loads(content)
        except json.JSONDecodeError as e:
            raise ParseError(f"Invalid JSON format: {e}", location) from e

    raise ParseError(
        f"Unsupported file format: {spec_file.suffix}. Use .yaml, .yml, or .json",
        location,
    )
|
|
||||||
|
|
||||||
def validate(self, spec_data: Optional[Dict[str, Any]] = None) -> Tuple[bool, List[str]]:
|
|
||||||
"""
|
|
||||||
Validate the OpenAPI specification.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (is_valid, list of errors)
|
|
||||||
"""
|
|
||||||
if spec_data is None:
|
|
||||||
spec_data = self.load()
|
|
||||||
|
|
||||||
|
def validate(self) -> list[str]:
|
||||||
errors = []
|
errors = []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
validate_spec(spec_data)
|
validate(self.spec_data)
|
||||||
return True, []
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = str(e)
|
errors.append(str(e))
|
||||||
errors.append(error_msg)
|
return errors
|
||||||
return False, errors
|
|
||||||
|
|
||||||
def parse(self) -> OpenAPISpec:
|
def parse(self) -> OpenAPISpec:
|
||||||
"""
|
self._extract_components()
|
||||||
Parse and validate the OpenAPI specification.
|
return OpenAPISpec(
|
||||||
|
openapi=self.spec_data.get("openapi", "3.0.0"),
|
||||||
Returns:
|
info=self._parse_info(),
|
||||||
OpenAPISpec object
|
servers=self._parse_servers(),
|
||||||
|
paths=self._parse_paths(),
|
||||||
Raises:
|
components=self._parse_components(),
|
||||||
ParseError: If the spec cannot be parsed
|
security=self.spec_data.get("security"),
|
||||||
SpecValidationError: If the spec is invalid
|
tags=self._parse_tags(),
|
||||||
"""
|
external_docs=self.spec_data.get("externalDocs"),
|
||||||
spec_data = self.load()
|
|
||||||
|
|
||||||
is_valid, errors = self.validate(spec_data)
|
|
||||||
if not is_valid:
|
|
||||||
error_text = "; ".join(errors)
|
|
||||||
raise SpecValidationError(f"Invalid OpenAPI specification: {error_text}", str(self.spec_path))
|
|
||||||
|
|
||||||
try:
|
|
||||||
return OpenAPISpec(**spec_data)
|
|
||||||
except ValidationError as e:
|
|
||||||
error_messages = []
|
|
||||||
for error in e.errors():
|
|
||||||
loc = ".".join(str(l) for l in error["loc"])
|
|
||||||
error_messages.append(f"{loc}: {error['msg']}")
|
|
||||||
raise ParseError(
|
|
||||||
f"Schema validation failed: {'; '.join(error_messages)}",
|
|
||||||
str(self.spec_path)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def parse_with_examples(self) -> Dict[str, Any]:
|
def _extract_components(self) -> None:
|
||||||
"""
|
components = self.spec_data.get("components", {})
|
||||||
Parse the spec and add generated examples.
|
if "schemas" in components:
|
||||||
|
for name, schema_data in components["schemas"].items():
|
||||||
Returns:
|
self._components_schemas[name] = self._parse_schema(schema_data)
|
||||||
Dictionary containing parsed spec with examples
|
if "responses" in components:
|
||||||
"""
|
self._components_responses = components["responses"]
|
||||||
from src.utils.examples import ExampleGenerator
|
if "requestBodies" in components:
|
||||||
|
self._components_request_bodies = components["requestBodies"]
|
||||||
spec = self.parse()
|
|
||||||
generator = ExampleGenerator()
|
|
||||||
|
|
||||||
endpoints = []
|
|
||||||
for endpoint in spec.get_endpoints():
|
|
||||||
endpoint_dict = endpoint.model_dump(mode="json", exclude_none=True)
|
|
||||||
|
|
||||||
if endpoint.requestBody:
|
|
||||||
endpoint_dict["requestBodyExample"] = generator.generate_from_content(
|
|
||||||
endpoint.requestBody.content or {}
|
|
||||||
)
|
|
||||||
|
|
||||||
responses = {}
|
|
||||||
for status_code, response in endpoint.responses.items():
|
|
||||||
response_dict = response.model_dump(mode="json", exclude_none=True)
|
|
||||||
if response.content:
|
|
||||||
response_dict["example"] = generator.generate_from_content(response.content or {})
|
|
||||||
responses[status_code] = response_dict
|
|
||||||
endpoint_dict["responses"] = responses
|
|
||||||
|
|
||||||
endpoints.append(endpoint_dict)
|
|
||||||
|
|
||||||
schemas = {}
|
|
||||||
for name, schema in spec.get_schemas().items():
|
|
||||||
schemas[name] = {
|
|
||||||
"name": name,
|
|
||||||
"schema": schema.model_dump(mode="json", exclude_none=True),
|
|
||||||
"example": generator.generate_example(schema),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
def _parse_info(self) -> dict[str, Any]:
|
||||||
|
info_data = self.spec_data.get("info", {})
|
||||||
|
contact_data = info_data.get("contact", {})
|
||||||
|
license_data = info_data.get("license", {})
|
||||||
return {
|
return {
|
||||||
"spec": spec.model_dump(mode="json", exclude_none=True),
|
"title": info_data.get("title", "API"),
|
||||||
"endpoints": endpoints,
|
"version": info_data.get("version", "1.0.0"),
|
||||||
"schemas": schemas,
|
"description": info_data.get("description"),
|
||||||
"tags": [tag.model_dump(mode="json", exclude_none=True) for tag in spec.get_tags()],
|
"terms_of_service": info_data.get("termsOfService"),
|
||||||
|
"contact": {
|
||||||
|
"name": contact_data.get("name"),
|
||||||
|
"url": contact_data.get("url"),
|
||||||
|
"email": contact_data.get("email"),
|
||||||
|
} if contact_data else None,
|
||||||
|
"license": {
|
||||||
|
"name": license_data.get("name", ""),
|
||||||
|
"url": license_data.get("url"),
|
||||||
|
} if license_data else None,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def _parse_servers(self) -> Optional[list[dict[str, Any]]]:
|
||||||
|
servers = self.spec_data.get("servers", [])
|
||||||
|
return [{"url": s.get("url", "/"), "description": s.get("description")} for s in servers]
|
||||||
|
|
||||||
def parse_spec_file(spec_path: str) -> OpenAPISpec:
|
def _parse_paths(self) -> dict[str, PathItem]:
|
||||||
"""Convenience function to parse an OpenAPI spec file."""
|
paths = {}
|
||||||
parser = OpenAPIParser(spec_path)
|
for path, path_item in self.spec_data.get("paths", {}).items():
|
||||||
|
if path.startswith("/"):
|
||||||
|
path_item_data = path_item if path_item else {}
|
||||||
|
paths[path] = self._parse_path_item(path_item_data)
|
||||||
|
return paths
|
||||||
|
|
||||||
|
def _parse_path_item(self, data: dict[str, Any]) -> PathItem:
    """Build a PathItem from one raw path-item mapping.

    Each HTTP method key present in ``data`` is parsed with
    ``_parse_operation`` and passed to ``PathItem`` as a keyword argument of
    the same name; the path-level ``parameters`` go through
    ``_parse_parameters``. ``$ref``, ``summary``, ``description`` and
    ``servers`` are forwarded as found.
    """
    http_methods = ("get", "put", "post", "delete", "options", "head", "patch", "trace")
    operations = {
        verb: self._parse_operation(data[verb])
        for verb in http_methods
        if verb in data
    }
    shared_parameters = self._parse_parameters(data.get("parameters", []))
    return PathItem(
        ref=data.get("$ref"),
        summary=data.get("summary"),
        description=data.get("description"),
        servers=data.get("servers"),
        parameters=shared_parameters,
        **operations,
    )
|
||||||
|
|
||||||
|
def _parse_operation(self, data: dict[str, Any]) -> Operation:
    """Build an Operation from one raw operation mapping.

    Responses are parsed per status code with ``_parse_response``, parameters
    with ``_parse_parameters`` and a non-empty ``requestBody`` with
    ``_parse_request_body``. The remaining fields are forwarded unchanged,
    with camelCase spec keys mapped to the model's snake_case arguments.
    """
    parsed_responses = {
        status: self._parse_response(raw_response)
        for status, raw_response in data.get("responses", {}).items()
    }
    parsed_parameters = self._parse_parameters(data.get("parameters", []))

    raw_body = data.get("requestBody")
    parsed_body = self._parse_request_body(raw_body) if raw_body else None

    return Operation(
        tags=data.get("tags"),
        summary=data.get("summary"),
        description=data.get("description"),
        external_docs=data.get("externalDocs"),
        operation_id=data.get("operationId"),
        parameters=parsed_parameters,
        request_body=parsed_body,
        responses=parsed_responses,
        deprecated=data.get("deprecated"),
        security=data.get("security"),
        servers=data.get("servers"),
    )
|
||||||
|
|
||||||
|
def _parse_parameters(self, params: list[dict[str, Any]]) -> list[Parameter]:
    """Convert raw parameter mappings into Parameter objects.

    Args:
        params: Raw parameter entries; ``None`` is treated as an empty list.
            An entry may be a reference (``{"$ref": ...}``), in which case
            the referenced definition is parsed instead.

    Returns:
        One Parameter per entry, in input order. A missing ``name`` becomes
        ``""`` and a missing ``in`` becomes ``"query"``.
    """
    parsed: list[Parameter] = []
    for entry in params or []:
        # Follow a reference first; otherwise the entry has no name/in of its
        # own and would be parsed as an anonymous query parameter. An
        # unresolvable reference is left as-is.
        if "$ref" in entry:
            target = self._resolve_ref(entry["$ref"])
            if isinstance(target, dict):
                entry = target

        raw_schema = entry.get("schema")
        parsed.append(
            Parameter(
                name=entry.get("name", ""),
                in_=entry.get("in", "query"),
                description=entry.get("description"),
                required=entry.get("required"),
                deprecated=entry.get("deprecated"),
                allow_empty_value=entry.get("allowEmptyValue"),
                style=entry.get("style"),
                explode=entry.get("explode"),
                allow_reserved=entry.get("allowReserved"),
                schema=self._parse_schema(raw_schema) if raw_schema else None,
                example=entry.get("example"),
                examples=entry.get("examples"),
            )
        )
    return parsed
|
||||||
|
|
||||||
|
def _parse_response(self, data: dict[str, Any]) -> Response:
    """Build a Response from one raw response mapping.

    A reference (``{"$ref": ...}``) is resolved first so the referenced
    definition is parsed rather than an empty response. Each media type under
    ``content`` is reduced to its parsed ``schema`` plus ``example`` and
    ``examples``; ``headers`` and ``links`` are forwarded unchanged and a
    missing ``description`` becomes ``""``.
    """
    if "$ref" in data:
        target = self._resolve_ref(data["$ref"])
        # Keep the original mapping when the reference cannot be resolved.
        if isinstance(target, dict):
            data = target

    content = {}
    # ``or {}`` also covers an explicit ``content: null``.
    for content_type, media in (data.get("content") or {}).items():
        raw_schema = media.get("schema")
        content[content_type] = {
            "schema": self._parse_schema(raw_schema) if raw_schema else None,
            "example": media.get("example"),
            "examples": media.get("examples"),
        }

    return Response(
        description=data.get("description", ""),
        content=content,
        headers=data.get("headers"),
        links=data.get("links"),
    )
|
||||||
|
|
||||||
|
def _parse_request_body(self, data: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
content = {}
|
||||||
|
for content_type, content_data in data.get("content", {}).items():
|
||||||
|
content[content_type] = {
|
||||||
|
"schema": self._parse_schema(content_data.get("schema"))
|
||||||
|
if content_data.get("schema") else None,
|
||||||
|
"example": content_data.get("example"),
|
||||||
|
"examples": content_data.get("examples"),
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
"description": data.get("description"),
|
||||||
|
"required": data.get("required"),
|
||||||
|
"content": content,
|
||||||
|
}
|
||||||
|
|
||||||
|
def _parse_schema(self, data: Any) -> Optional[Schema]:
    """Recursively convert a raw schema mapping into a Schema.

    Anything that is not a dict (including ``None``) yields ``None``. A
    ``$ref`` that resolves to a truthy value is replaced by the parsed
    target; otherwise the mapping is copied and its nested schemas
    (``allOf``/``anyOf``/``oneOf``/``not``, ``items``, ``properties``,
    ``additionalProperties``) are converted in place before being passed to
    ``Schema`` as keyword arguments. Nested values that are not dicts are
    left untouched.

    NOTE(review): there is no cycle guard, so a schema whose ``$ref`` chain
    leads back to itself would recurse without bound -- confirm whether such
    specs need to be supported.
    """
    if not isinstance(data, dict):
        return None

    if "$ref" in data:
        target = self._resolve_ref(data["$ref"])
        if target:
            return self._parse_schema(target)
        # Unresolved references fall through and are passed to Schema as-is.

    def convert(node: Any) -> Any:
        # Only dict-shaped nodes are schemas; booleans, lists etc. pass through.
        return self._parse_schema(node) if isinstance(node, dict) else node

    fields = dict(data)

    for keyword in ("allOf", "anyOf", "oneOf", "not"):
        if keyword not in fields:
            continue
        value = fields[keyword]
        if isinstance(value, list):
            fields[keyword] = [convert(member) for member in value]
        else:
            fields[keyword] = convert(value)

    if "items" in fields:
        fields["items"] = convert(fields["items"])

    if "properties" in fields:
        fields["properties"] = {
            prop_name: convert(prop_schema)
            for prop_name, prop_schema in fields["properties"].items()
        }

    if "additionalProperties" in fields:
        fields["additionalProperties"] = convert(fields["additionalProperties"])

    return Schema(**fields)
|
||||||
|
|
||||||
|
def _resolve_ref(self, ref: str) -> Optional[dict[str, Any]]:
|
||||||
|
if ref in self._resolved_refs:
|
||||||
|
return self._resolved_refs[ref]
|
||||||
|
if ref.startswith("#/components/"):
|
||||||
|
parts = ref.split("/")[2:]
|
||||||
|
current = self.spec_data.get("components", {})
|
||||||
|
for part in parts:
|
||||||
|
if isinstance(current, dict) and part in current:
|
||||||
|
current = current[part]
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
self._resolved_refs[ref] = current
|
||||||
|
return current
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _parse_components(self) -> Optional[dict[str, Any]]:
|
||||||
|
components = self.spec_data.get("components")
|
||||||
|
if not components:
|
||||||
|
return None
|
||||||
|
security_schemes = {}
|
||||||
|
for name, scheme in components.get("securitySchemes", {}).items():
|
||||||
|
security_schemes[name] = {
|
||||||
|
"type": scheme.get("type"),
|
||||||
|
"scheme": scheme.get("scheme"),
|
||||||
|
"bearer_format": scheme.get("bearerFormat"),
|
||||||
|
"flows": scheme.get("flows"),
|
||||||
|
"open_id_connect_url": scheme.get("openIdConnectUrl"),
|
||||||
|
"description": scheme.get("description"),
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
"schemas": self._components_schemas,
|
||||||
|
"responses": self._components_responses,
|
||||||
|
"parameters": components.get("parameters"),
|
||||||
|
"request_bodies": self._components_request_bodies,
|
||||||
|
"headers": components.get("headers"),
|
||||||
|
"security_schemes": security_schemes,
|
||||||
|
"links": components.get("links"),
|
||||||
|
"callbacks": components.get("callbacks"),
|
||||||
|
}
|
||||||
|
|
||||||
|
def _parse_tags(self) -> Optional[list[dict[str, Any]]]:
|
||||||
|
tags = self.spec_data.get("tags", [])
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"name": t.get("name"),
|
||||||
|
"description": t.get("description"),
|
||||||
|
"external_docs": t.get("externalDocs"),
|
||||||
|
}
|
||||||
|
for t in tags
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _basic_validate(spec_data: dict[str, Any]) -> tuple:
|
||||||
|
errors = []
|
||||||
|
if not isinstance(spec_data, dict):
|
||||||
|
errors.append("Spec must be a dictionary")
|
||||||
|
return False, errors
|
||||||
|
if "openapi" not in spec_data:
|
||||||
|
errors.append("Missing 'openapi' version")
|
||||||
|
return False, errors
|
||||||
|
if "info" not in spec_data:
|
||||||
|
errors.append("Missing 'info' object")
|
||||||
|
return False, errors
|
||||||
|
info = spec_data.get("info", {})
|
||||||
|
if not isinstance(info, dict):
|
||||||
|
errors.append("'info' must be an object")
|
||||||
|
return False, errors
|
||||||
|
if "title" not in info:
|
||||||
|
errors.append("Missing 'info.title'")
|
||||||
|
return False, errors
|
||||||
|
if "version" not in info:
|
||||||
|
errors.append("Missing 'info.version'")
|
||||||
|
return False, errors
|
||||||
|
return True, []
|
||||||
|
|
||||||
|
|
||||||
|
def parse_openapi_spec(spec_source: str | Path | dict[str, Any]) -> OpenAPISpec:
    """Parse an OpenAPI spec given as a dict, a Path, or a path string.

    A dict is used directly; anything else is treated as a file location and
    read with ``_load_file``. The data is validated before parsing.

    Raises:
        ValueError: If ``OpenAPIParser.validate`` reports any errors.
    """
    if isinstance(spec_source, dict):
        spec_data = spec_source
    else:
        # Path() accepts both str and Path, so one branch covers both cases.
        spec_data = _load_file(Path(spec_source))

    parser = OpenAPIParser(spec_data)
    if errors := parser.validate():
        raise ValueError(f"Invalid OpenAPI spec: {errors}")
    return parser.parse()
|
||||||
|
|
||||||
|
|
||||||
def validate_spec_file(spec_path: str) -> Tuple[bool, List[str]]:
|
def _load_file(path: Path) -> dict[str, Any]:
|
||||||
"""Convenience function to validate an OpenAPI spec file."""
|
content = path.read_text()
|
||||||
parser = OpenAPIParser(spec_path)
|
if path.suffix in [".yaml", ".yml"]:
|
||||||
return parser.validate()
|
import yaml
|
||||||
|
|
||||||
|
return yaml.safe_load(content)
|
||||||
|
return json.loads(content)
|
||||||
|
|||||||
Reference in New Issue
Block a user