Add source files: core parser module
Some checks failed
CI / test (push) Failing after 12s

This commit is contained in:
2026-02-02 19:54:00 +00:00
parent bed71e9428
commit f111799e4f

324
src/core/parser.py Normal file
View File

@@ -0,0 +1,324 @@
"""OpenAPI Specification Parser."""
import json
from pathlib import Path
from typing import Any
from urllib.parse import urlparse
import yaml
from openapi_spec_validator import validate
class OpenAPIParserError(Exception):
    """Root of the exception hierarchy raised by the OpenAPI parser.

    Catch this type to handle every parser-specific failure in one place.
    """
class InvalidOpenAPIFormat(OpenAPIParserError):
    """Signals that a specification could not be read, parsed, or validated."""
class UnsupportedOpenAPIVersion(OpenAPIParserError):
    """Signals that a specification declares a version the parser rejects."""
class OpenAPIParser:
"""Parser for OpenAPI specifications (3.0, 3.1, and Swagger 2.0)."""
SUPPORTED_VERSIONS = {
"3.0.0", "3.0.1", "3.0.2", "3.0.3", "3.0.4",
"3.1.0", "3.1.1",
}
def __init__(self, spec_file: str) -> None:
"""Initialize the parser with a spec file path.
Args:
spec_file: Path to the OpenAPI specification file (YAML or JSON).
"""
self.spec_file = Path(spec_file)
self.spec: dict[str, Any] | None = None
self.version: str | None = None
def load(self) -> dict[str, Any]:
"""Load and parse the OpenAPI specification from file.
Returns:
The parsed OpenAPI specification as a dictionary.
Raises:
FileNotFoundError: If the spec file does not exist.
InvalidOpenAPIFormat: If the file format is invalid.
"""
if not self.spec_file.exists():
raise FileNotFoundError(f"OpenAPI spec file not found: {self.spec_file}")
content = self._read_file()
self.spec = self._parse_content(content)
self.version = self._extract_version(self.spec)
return self.spec
def _read_file(self) -> str:
"""Read the spec file content.
Returns:
The file content as a string.
Raises:
InvalidOpenAPIFormat: If the file cannot be read.
"""
try:
return self.spec_file.read_text(encoding="utf-8")
except Exception as e:
raise InvalidOpenAPIFormat(f"Failed to read spec file: {e}") from e
def _parse_content(self, content: str) -> dict[str, Any]:
"""Parse the content based on file extension.
Args:
content: The raw file content.
Returns:
The parsed specification as a dictionary.
Raises:
InvalidOpenAPIFormat: If the content cannot be parsed.
"""
suffix = self.spec_file.suffix.lower()
try:
if suffix in {".yaml", ".yml"}:
return yaml.safe_load(content)
elif suffix == ".json":
return json.loads(content)
else:
try:
return yaml.safe_load(content)
except yaml.YAMLError:
return json.loads(content)
except (json.JSONDecodeError, yaml.YAMLError) as e:
raise InvalidOpenAPIFormat(f"Failed to parse spec file: {e}") from e
def _extract_version(self, spec: dict[str, Any]) -> str:
"""Extract the OpenAPI version from the spec.
Args:
spec: The parsed OpenAPI specification.
Returns:
The OpenAPI version string.
Raises:
UnsupportedOpenAPIVersion: If the version is not supported.
"""
openapi_version = spec.get("openapi", "")
if not openapi_version:
raise InvalidOpenAPIFormat("Missing 'openapi' version field")
if openapi_version not in self.SUPPORTED_VERSIONS:
raise UnsupportedOpenAPIVersion(
f"Unsupported OpenAPI version: {openapi_version}. "
f"Supported versions: {', '.join(self.SUPPORTED_VERSIONS)}"
)
return openapi_version
def validate_spec(self) -> list[str]:
"""Validate the OpenAPI specification.
Returns:
List of validation errors (empty if valid).
Raises:
InvalidOpenAPIFormat: If validation fails.
"""
if self.spec is None:
self.load()
try:
validate(self.spec)
return []
except Exception as e:
raise InvalidOpenAPIFormat(f"Specification validation failed: {e}") from e
def get_paths(self) -> dict[str, Any]:
"""Extract all paths from the specification.
Returns:
Dictionary of path strings to path item objects.
"""
if self.spec is None:
self.load()
return self.spec.get("paths", {})
def get_schemas(self) -> dict[str, Any]:
"""Extract all schemas from components.
Returns:
Dictionary of schema names to schema objects.
"""
if self.spec is None:
self.load()
components = self.spec.get("components", {})
return components.get("schemas", {})
def get_definitions(self) -> dict[str, Any]:
"""Extract all definitions (Swagger 2.0).
Returns:
Dictionary of definition names to schema objects.
"""
if self.spec is None:
self.load()
return self.spec.get("definitions", {})
def get_request_body_schema(
self, path: str, method: str
) -> dict[str, Any] | None:
"""Get the request body schema for a specific path and method.
Args:
path: The API path.
method: The HTTP method (get, post, put, delete, etc.).
Returns:
The request body schema or None if not found.
"""
if self.spec is None:
self.load()
path_item = self.get_paths().get(path, {})
operation = path_item.get(method.lower(), {})
request_body = operation.get("requestBody", {})
if not request_body:
return None
content = request_body.get("content", {})
json_content = content.get("application/json", {})
if json_content:
schema = json_content.get("schema", {})
return schema
return None
def get_response_schema(
self, path: str, method: str, status_code: str = "200"
) -> dict[str, Any] | None:
"""Get the response schema for a specific path, method, and status code.
Args:
path: The API path.
method: The HTTP method.
status_code: The response status code.
Returns:
The response schema or None if not found.
"""
if self.spec is None:
self.load()
path_item = self.get_paths().get(path, {})
operation = path_item.get(method.lower(), {})
responses = operation.get("responses", {})
if status_code not in responses:
return None
response = responses[status_code]
content = response.get("content", {})
json_content = content.get("application/json", {})
if json_content:
return json_content.get("schema", {})
return None
def get_path_parameters(self, path: str) -> list[dict[str, Any]]:
"""Extract path parameters for a given path.
Args:
path: The API path.
Returns:
List of path parameter definitions.
"""
if self.spec is None:
self.load()
path_item = self.get_paths().get(path, {})
parameters = path_item.get("parameters", [])
return parameters
def get_operation_parameters(
self, path: str, method: str
) -> list[dict[str, Any]]:
"""Get all parameters for a specific operation.
Args:
path: The API path.
method: The HTTP method.
Returns:
List of parameter definitions.
"""
if self.spec is None:
self.load()
path_item = self.get_paths().get(path, {})
operation = path_item.get(method.lower(), {})
return operation.get("parameters", [])
def get_security_schemes(self) -> dict[str, Any]:
"""Get security schemes from components.
Returns:
Dictionary of security scheme definitions.
"""
if self.spec is None:
self.load()
components = self.spec.get("components", {})
return components.get("securitySchemes", {})
def get_servers(self) -> list[dict[str, Any]]:
"""Get servers from the specification.
Returns:
List of server definitions.
"""
if self.spec is None:
self.load()
return self.spec.get("servers", [])
def get_base_path(self) -> str:
"""Get the base path from the specification.
Returns:
The base path string.
"""
if self.spec is None:
self.load()
if self.version and self.version.startswith("3."):
servers = self.get_servers()
if servers:
url = servers[0].get("url", "")
parsed = urlparse(url)
return parsed.path
return ""
return self.spec.get("basePath", "")