This commit is contained in:
324
src/core/parser.py
Normal file
324
src/core/parser.py
Normal file
@@ -0,0 +1,324 @@
|
||||
"""OpenAPI Specification Parser."""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import yaml
|
||||
from openapi_spec_validator import validate
|
||||
|
||||
|
||||
class OpenAPIParserError(Exception):
|
||||
"""Base exception for OpenAPI parser errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class InvalidOpenAPIFormat(OpenAPIParserError):
|
||||
"""Raised when the OpenAPI specification format is invalid."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class UnsupportedOpenAPIVersion(OpenAPIParserError):
|
||||
"""Raised when the OpenAPI version is not supported."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class OpenAPIParser:
|
||||
"""Parser for OpenAPI specifications (3.0, 3.1, and Swagger 2.0)."""
|
||||
|
||||
SUPPORTED_VERSIONS = {
|
||||
"3.0.0", "3.0.1", "3.0.2", "3.0.3", "3.0.4",
|
||||
"3.1.0", "3.1.1",
|
||||
}
|
||||
|
||||
def __init__(self, spec_file: str) -> None:
|
||||
"""Initialize the parser with a spec file path.
|
||||
|
||||
Args:
|
||||
spec_file: Path to the OpenAPI specification file (YAML or JSON).
|
||||
"""
|
||||
self.spec_file = Path(spec_file)
|
||||
self.spec: dict[str, Any] | None = None
|
||||
self.version: str | None = None
|
||||
|
||||
def load(self) -> dict[str, Any]:
|
||||
"""Load and parse the OpenAPI specification from file.
|
||||
|
||||
Returns:
|
||||
The parsed OpenAPI specification as a dictionary.
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If the spec file does not exist.
|
||||
InvalidOpenAPIFormat: If the file format is invalid.
|
||||
"""
|
||||
if not self.spec_file.exists():
|
||||
raise FileNotFoundError(f"OpenAPI spec file not found: {self.spec_file}")
|
||||
|
||||
content = self._read_file()
|
||||
self.spec = self._parse_content(content)
|
||||
self.version = self._extract_version(self.spec)
|
||||
|
||||
return self.spec
|
||||
|
||||
def _read_file(self) -> str:
|
||||
"""Read the spec file content.
|
||||
|
||||
Returns:
|
||||
The file content as a string.
|
||||
|
||||
Raises:
|
||||
InvalidOpenAPIFormat: If the file cannot be read.
|
||||
"""
|
||||
try:
|
||||
return self.spec_file.read_text(encoding="utf-8")
|
||||
except Exception as e:
|
||||
raise InvalidOpenAPIFormat(f"Failed to read spec file: {e}") from e
|
||||
|
||||
def _parse_content(self, content: str) -> dict[str, Any]:
|
||||
"""Parse the content based on file extension.
|
||||
|
||||
Args:
|
||||
content: The raw file content.
|
||||
|
||||
Returns:
|
||||
The parsed specification as a dictionary.
|
||||
|
||||
Raises:
|
||||
InvalidOpenAPIFormat: If the content cannot be parsed.
|
||||
"""
|
||||
suffix = self.spec_file.suffix.lower()
|
||||
|
||||
try:
|
||||
if suffix in {".yaml", ".yml"}:
|
||||
return yaml.safe_load(content)
|
||||
elif suffix == ".json":
|
||||
return json.loads(content)
|
||||
else:
|
||||
try:
|
||||
return yaml.safe_load(content)
|
||||
except yaml.YAMLError:
|
||||
return json.loads(content)
|
||||
except (json.JSONDecodeError, yaml.YAMLError) as e:
|
||||
raise InvalidOpenAPIFormat(f"Failed to parse spec file: {e}") from e
|
||||
|
||||
def _extract_version(self, spec: dict[str, Any]) -> str:
|
||||
"""Extract the OpenAPI version from the spec.
|
||||
|
||||
Args:
|
||||
spec: The parsed OpenAPI specification.
|
||||
|
||||
Returns:
|
||||
The OpenAPI version string.
|
||||
|
||||
Raises:
|
||||
UnsupportedOpenAPIVersion: If the version is not supported.
|
||||
"""
|
||||
openapi_version = spec.get("openapi", "")
|
||||
if not openapi_version:
|
||||
raise InvalidOpenAPIFormat("Missing 'openapi' version field")
|
||||
|
||||
if openapi_version not in self.SUPPORTED_VERSIONS:
|
||||
raise UnsupportedOpenAPIVersion(
|
||||
f"Unsupported OpenAPI version: {openapi_version}. "
|
||||
f"Supported versions: {', '.join(self.SUPPORTED_VERSIONS)}"
|
||||
)
|
||||
|
||||
return openapi_version
|
||||
|
||||
def validate_spec(self) -> list[str]:
|
||||
"""Validate the OpenAPI specification.
|
||||
|
||||
Returns:
|
||||
List of validation errors (empty if valid).
|
||||
|
||||
Raises:
|
||||
InvalidOpenAPIFormat: If validation fails.
|
||||
"""
|
||||
if self.spec is None:
|
||||
self.load()
|
||||
|
||||
try:
|
||||
validate(self.spec)
|
||||
return []
|
||||
except Exception as e:
|
||||
raise InvalidOpenAPIFormat(f"Specification validation failed: {e}") from e
|
||||
|
||||
def get_paths(self) -> dict[str, Any]:
|
||||
"""Extract all paths from the specification.
|
||||
|
||||
Returns:
|
||||
Dictionary of path strings to path item objects.
|
||||
"""
|
||||
if self.spec is None:
|
||||
self.load()
|
||||
|
||||
return self.spec.get("paths", {})
|
||||
|
||||
def get_schemas(self) -> dict[str, Any]:
|
||||
"""Extract all schemas from components.
|
||||
|
||||
Returns:
|
||||
Dictionary of schema names to schema objects.
|
||||
"""
|
||||
if self.spec is None:
|
||||
self.load()
|
||||
|
||||
components = self.spec.get("components", {})
|
||||
return components.get("schemas", {})
|
||||
|
||||
def get_definitions(self) -> dict[str, Any]:
|
||||
"""Extract all definitions (Swagger 2.0).
|
||||
|
||||
Returns:
|
||||
Dictionary of definition names to schema objects.
|
||||
"""
|
||||
if self.spec is None:
|
||||
self.load()
|
||||
|
||||
return self.spec.get("definitions", {})
|
||||
|
||||
def get_request_body_schema(
|
||||
self, path: str, method: str
|
||||
) -> dict[str, Any] | None:
|
||||
"""Get the request body schema for a specific path and method.
|
||||
|
||||
Args:
|
||||
path: The API path.
|
||||
method: The HTTP method (get, post, put, delete, etc.).
|
||||
|
||||
Returns:
|
||||
The request body schema or None if not found.
|
||||
"""
|
||||
if self.spec is None:
|
||||
self.load()
|
||||
|
||||
path_item = self.get_paths().get(path, {})
|
||||
operation = path_item.get(method.lower(), {})
|
||||
request_body = operation.get("requestBody", {})
|
||||
|
||||
if not request_body:
|
||||
return None
|
||||
|
||||
content = request_body.get("content", {})
|
||||
json_content = content.get("application/json", {})
|
||||
|
||||
if json_content:
|
||||
schema = json_content.get("schema", {})
|
||||
return schema
|
||||
|
||||
return None
|
||||
|
||||
def get_response_schema(
|
||||
self, path: str, method: str, status_code: str = "200"
|
||||
) -> dict[str, Any] | None:
|
||||
"""Get the response schema for a specific path, method, and status code.
|
||||
|
||||
Args:
|
||||
path: The API path.
|
||||
method: The HTTP method.
|
||||
status_code: The response status code.
|
||||
|
||||
Returns:
|
||||
The response schema or None if not found.
|
||||
"""
|
||||
if self.spec is None:
|
||||
self.load()
|
||||
|
||||
path_item = self.get_paths().get(path, {})
|
||||
operation = path_item.get(method.lower(), {})
|
||||
responses = operation.get("responses", {})
|
||||
|
||||
if status_code not in responses:
|
||||
return None
|
||||
|
||||
response = responses[status_code]
|
||||
content = response.get("content", {})
|
||||
json_content = content.get("application/json", {})
|
||||
|
||||
if json_content:
|
||||
return json_content.get("schema", {})
|
||||
|
||||
return None
|
||||
|
||||
def get_path_parameters(self, path: str) -> list[dict[str, Any]]:
|
||||
"""Extract path parameters for a given path.
|
||||
|
||||
Args:
|
||||
path: The API path.
|
||||
|
||||
Returns:
|
||||
List of path parameter definitions.
|
||||
"""
|
||||
if self.spec is None:
|
||||
self.load()
|
||||
|
||||
path_item = self.get_paths().get(path, {})
|
||||
parameters = path_item.get("parameters", [])
|
||||
|
||||
return parameters
|
||||
|
||||
def get_operation_parameters(
|
||||
self, path: str, method: str
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Get all parameters for a specific operation.
|
||||
|
||||
Args:
|
||||
path: The API path.
|
||||
method: The HTTP method.
|
||||
|
||||
Returns:
|
||||
List of parameter definitions.
|
||||
"""
|
||||
if self.spec is None:
|
||||
self.load()
|
||||
|
||||
path_item = self.get_paths().get(path, {})
|
||||
operation = path_item.get(method.lower(), {})
|
||||
|
||||
return operation.get("parameters", [])
|
||||
|
||||
def get_security_schemes(self) -> dict[str, Any]:
|
||||
"""Get security schemes from components.
|
||||
|
||||
Returns:
|
||||
Dictionary of security scheme definitions.
|
||||
"""
|
||||
if self.spec is None:
|
||||
self.load()
|
||||
|
||||
components = self.spec.get("components", {})
|
||||
return components.get("securitySchemes", {})
|
||||
|
||||
def get_servers(self) -> list[dict[str, Any]]:
|
||||
"""Get servers from the specification.
|
||||
|
||||
Returns:
|
||||
List of server definitions.
|
||||
"""
|
||||
if self.spec is None:
|
||||
self.load()
|
||||
|
||||
return self.spec.get("servers", [])
|
||||
|
||||
def get_base_path(self) -> str:
|
||||
"""Get the base path from the specification.
|
||||
|
||||
Returns:
|
||||
The base path string.
|
||||
"""
|
||||
if self.spec is None:
|
||||
self.load()
|
||||
|
||||
if self.version and self.version.startswith("3."):
|
||||
servers = self.get_servers()
|
||||
if servers:
|
||||
url = servers[0].get("url", "")
|
||||
parsed = urlparse(url)
|
||||
return parsed.path
|
||||
return ""
|
||||
|
||||
return self.spec.get("basePath", "")
|
||||
Reference in New Issue
Block a user