"""OpenAPI Specification Parser."""

import json
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
from urllib.parse import unquote

import yaml
from openapi_spec_validator import validate

from .exceptions import InvalidOpenAPISpecError, UnsupportedVersionError


class SpecParser:
    """Parse and validate OpenAPI/Swagger specifications."""

    SUPPORTED_VERSIONS = ["2.0", "3.0.0", "3.0.1", "3.0.2", "3.0.3", "3.1.0"]

    # Operation keys of a Path Item Object. Every other key in a path item
    # (``parameters``, ``summary``, ``servers``, ``$ref``, ``x-*``) is not an
    # operation. ``trace`` is only defined by OpenAPI 3.x.
    _HTTP_METHODS = frozenset(
        ("get", "put", "post", "delete", "options", "head", "patch", "trace")
    )

    # Media type reported for a response when the spec does not name one.
    _DEFAULT_MEDIA_TYPE = "application/json"

    def __init__(self, spec_path: Union[str, Path]):
        """Initialize the spec parser.

        Args:
            spec_path: Path to OpenAPI specification file (YAML or JSON).
        """
        self.spec_path = Path(spec_path)
        self.spec: Dict[str, Any] = {}
        self.version: str = ""
        self.base_path: str = ""
        self.servers: List[Dict[str, str]] = []

    def load(self) -> Dict[str, Any]:
        """Load and validate the OpenAPI specification.

        Returns:
            The parsed specification dictionary.

        Raises:
            InvalidOpenAPISpecError: If the specification is invalid.
            UnsupportedVersionError: If the OpenAPI version is not supported.
        """
        self.spec = self._load_file()
        self._validate()
        self._extract_metadata()
        return self.spec

    def _load_file(self) -> Dict[str, Any]:
        """Load the specification file.

        A ``.json`` suffix (case-insensitive) is parsed as JSON; any other
        suffix is parsed as YAML with ``yaml.safe_load``. An empty document
        yields ``{}``.

        Returns:
            Parsed specification dictionary.

        Raises:
            InvalidOpenAPISpecError: If the file does not exist, cannot be read
                or decoded as UTF-8, cannot be parsed, or its top-level value
                is not a mapping.
        """
        if not self.spec_path.exists():
            raise InvalidOpenAPISpecError(f"Specification file not found: {self.spec_path}")

        try:
            with open(self.spec_path, "r", encoding="utf-8") as f:
                if self.spec_path.suffix.lower() == ".json":
                    data = json.load(f)
                else:
                    # .yaml / .yml / unknown suffix. safe_load only builds plain
                    # Python types, never arbitrary objects from the document.
                    data = yaml.safe_load(f)
        except (yaml.YAMLError, json.JSONDecodeError) as e:
            raise InvalidOpenAPISpecError(f"Failed to parse specification: {e}") from e
        except (OSError, UnicodeDecodeError) as e:
            raise InvalidOpenAPISpecError(f"Failed to read specification: {e}") from e

        if data is None:
            return {}
        if not isinstance(data, dict):
            raise InvalidOpenAPISpecError(
                "Failed to parse specification: top-level value must be a mapping, "
                f"got {type(data).__name__}"
            )
        return data

    def _validate(self) -> None:
        """Validate the specification.

        Schema validation runs first; the version is then checked against
        ``SUPPORTED_VERSIONS``.

        Raises:
            InvalidOpenAPISpecError: If the specification is invalid.
            UnsupportedVersionError: If the OpenAPI version is not supported.
        """
        try:
            validate(self.spec)
        except Exception as e:
            # Broad on purpose: any validator failure is reported as an invalid
            # spec; chaining keeps the validator's own traceback available.
            raise InvalidOpenAPISpecError(f"Invalid OpenAPI specification: {e}") from e

        version = self._get_version()
        if version not in self.SUPPORTED_VERSIONS:
            raise UnsupportedVersionError(
                f"Unsupported OpenAPI version: {version}. "
                f"Supported versions: {', '.join(self.SUPPORTED_VERSIONS)}"
            )

    def _get_version(self) -> str:
        """Extract the OpenAPI version from the spec.

        Uses the ``openapi`` field if present, otherwise ``swagger``, otherwise
        falls back to ``"2.0"``.

        Returns:
            The OpenAPI version string.
        """
        if "openapi" in self.spec:
            return str(self.spec["openapi"])
        if "swagger" in self.spec:
            return str(self.spec["swagger"])
        return "2.0"

    def _is_v3(self) -> bool:
        """Return True for an OpenAPI 3.x document (False before ``load()``)."""
        return self.version.startswith("3.")

    def _extract_metadata(self) -> None:
        """Extract metadata from the specification."""
        self.version = self._get_version()
        self.base_path = self.spec.get("basePath", "")
        self.servers = self.spec.get("servers", [])

    def get_paths(self) -> Dict[str, Any]:
        """Get all paths from the specification.

        Returns:
            Dictionary of paths and their operations.
        """
        return self.spec.get("paths", {})

    def get_endpoints(self) -> List[Dict[str, Any]]:
        """Extract all endpoints from the specification.

        Local ``$ref`` path items are resolved. ``x-*`` extension keys under
        ``paths`` and non-operation keys inside a path item are skipped.

        Returns:
            List of endpoint dictionaries with path, method, and details, in
            document order.
        """
        endpoints: List[Dict[str, Any]] = []
        for path, raw_path_item in (self.get_paths() or {}).items():
            if str(path).startswith("x-"):
                continue  # specification extension, not a path
            path_item = self._resolve_ref(raw_path_item)
            if not isinstance(path_item, dict):
                continue
            for method, operation in path_item.items():
                method_name = str(method).lower()
                if method_name not in self._HTTP_METHODS or not isinstance(operation, dict):
                    continue
                endpoints.append(
                    {
                        "path": path,
                        "method": method_name,
                        "operation_id": operation.get("operationId", ""),
                        "summary": operation.get("summary", ""),
                        "description": operation.get("description", ""),
                        "tags": operation.get("tags", []),
                        "parameters": self._extract_parameters(path_item, operation),
                        "request_body": self._extract_request_body(operation, path_item),
                        "responses": self._extract_responses(operation),
                        "security": self._extract_security(operation),
                    }
                )
        return endpoints

    def _resolve_ref(self, obj: Any) -> Any:
        """Resolve a local JSON Reference (``{"$ref": "#/..."}``) against the spec.

        Chains of references are followed. A value that is not a reference
        object is returned unchanged. If a reference is external (files/URLs
        are never fetched), points at a missing location, or forms a cycle,
        the last unresolved reference object is returned.

        Args:
            obj: Any value taken from the specification.

        Returns:
            The referenced value, or the input as described above.
        """
        seen: set = set()
        while isinstance(obj, dict) and isinstance(obj.get("$ref"), str):
            ref = obj["$ref"]
            if not ref.startswith("#/") or ref in seen:
                return obj
            seen.add(ref)
            # RFC 6901 section 6: the fragment is percent-decoded first, then
            # split into tokens; "~1" -> "/" is unescaped before "~0" -> "~".
            pointer = unquote(ref[1:])
            target: Any = self.spec
            for raw_token in pointer[1:].split("/"):
                token = raw_token.replace("~1", "/").replace("~0", "~")
                if isinstance(target, dict) and token in target:
                    target = target[token]
                elif isinstance(target, list) and token.isdecimal() and int(token) < len(target):
                    target = target[int(token)]
                else:
                    return obj
            obj = target
        return obj

    def _extract_parameters(self, path_item: Dict, operation: Dict) -> List[Dict[str, Any]]:
        """Extract parameters from path item and operation.

        Body parameters (``in: body``) are excluded; see
        ``_extract_request_body``. An operation-level parameter replaces a
        path-level one with the same ``name`` and ``in`` (the OpenAPI override
        rule), keeping the path-level position. Local ``$ref`` parameters are
        resolved.

        Args:
            path_item: The path item dictionary.
            operation: The operation dictionary.

        Returns:
            List of parameter dictionaries.
        """
        merged: Dict[Tuple[Any, Any], Dict[str, Any]] = {}
        for source in (path_item, operation):
            for raw in source.get("parameters") or []:
                param = self._resolve_ref(raw)
                if not isinstance(param, dict) or param.get("in") == "body":
                    continue
                if "$ref" in param:
                    # Unresolved (e.g. external) reference: key on the reference
                    # itself so distinct references are not merged together.
                    key = ("$ref", param["$ref"])
                else:
                    key = (param.get("name", ""), param.get("in", ""))
                merged[key] = self._normalize_parameter(param)
        return list(merged.values())

    def _normalize_parameter(self, param: Dict[str, Any]) -> Dict[str, Any]:
        """Normalize a parameter.

        Args:
            param: The parameter dictionary.

        Returns:
            Normalized parameter dictionary.
        """
        return {
            "name": param.get("name", ""),
            "in": param.get("in", ""),
            "description": param.get("description", ""),
            "required": param.get("required", False),
            "schema": param.get("schema", {}),
            "type": param.get("type", ""),
            "enum": param.get("enum", []),
            "default": param.get("default"),
        }

    @staticmethod
    def _first_schema(content: Dict[str, Any]) -> Dict[str, Any]:
        """Return the ``schema`` of the first media type in a ``content`` map, or ``{}``."""
        first = next(iter(content.values()), None)
        return first.get("schema", {}) if isinstance(first, dict) else {}

    def _v2_mime_types(self, operation: Dict, key: str) -> List[str]:
        """Return Swagger 2.0 ``consumes``/``produces`` values for an operation.

        An operation-level list, even an empty one, replaces the document-level
        list.
        """
        if key in operation:
            return list(operation[key] or [])
        return list(self.spec.get(key) or [])

    def _extract_request_body(
        self, operation: Dict, path_item: Optional[Dict] = None
    ) -> Optional[Dict[str, Any]]:
        """Extract request body from operation.

        OpenAPI 3.x reads ``requestBody``. Swagger 2.0 uses the ``in: body``
        parameter, searched on the operation first and then on ``path_item``;
        its ``media_types`` come from ``consumes``. Local ``$ref`` values are
        resolved.

        Args:
            operation: The operation dictionary.
            path_item: The enclosing path item, if known. Only used for
                Swagger 2.0 path-level body parameters.

        Returns:
            Request body dictionary (description, required, media_types,
            schema) or None.
        """
        if self._is_v3():
            request_body = self._resolve_ref(operation.get("requestBody") or {})
            if not isinstance(request_body, dict) or not request_body:
                return None
            content = request_body.get("content") or {}
            return {
                "description": request_body.get("description", ""),
                "required": request_body.get("required", False),
                "media_types": list(content),
                "schema": self._first_schema(content),
            }

        candidates = list(operation.get("parameters") or [])
        if path_item:
            candidates.extend(path_item.get("parameters") or [])
        for raw in candidates:
            param = self._resolve_ref(raw)
            if isinstance(param, dict) and param.get("in") == "body":
                return {
                    "description": param.get("description", ""),
                    "required": param.get("required", False),
                    "media_types": self._v2_mime_types(operation, "consumes"),
                    "schema": param.get("schema", {}),
                }
        return None

    def _extract_responses(self, operation: Dict) -> Dict[str, Any]:
        """Extract responses from operation.

        ``x-*`` extension keys are skipped and local ``$ref`` responses are
        resolved. For OpenAPI 3.x the first ``content`` media type is used;
        for Swagger 2.0 the first ``produces`` entry is used. Either falls back
        to ``application/json``.

        Args:
            operation: The operation dictionary.

        Returns:
            Dictionary of response status codes and their details.
        """
        is_v3 = self._is_v3()
        v2_media_type = self._DEFAULT_MEDIA_TYPE
        if not is_v3:
            produces = self._v2_mime_types(operation, "produces")
            if produces:
                v2_media_type = produces[0]

        responses: Dict[str, Any] = {}
        for status_code, raw in (operation.get("responses") or {}).items():
            if str(status_code).startswith("x-"):
                continue  # specification extension, not a response
            response = self._resolve_ref(raw)
            if not isinstance(response, dict):
                response = {}
            if is_v3:
                content = response.get("content") or {}
                schema = self._first_schema(content)
                media_type = next(iter(content), self._DEFAULT_MEDIA_TYPE)
            else:
                schema = response.get("schema", {})
                media_type = v2_media_type
            responses[status_code] = {
                "description": response.get("description", ""),
                "schema": schema,
                "media_type": media_type,
            }
        return responses

    def _extract_security(self, operation: Dict) -> List[Dict[str, Any]]:
        """Extract security requirements from operation.

        An operation-level ``security`` list (even an empty one) replaces the
        document-level list.

        Args:
            operation: The operation dictionary.

        Returns:
            List of security requirement dictionaries.
        """
        return operation.get("security", self.spec.get("security", []))

    def get_security_schemes(self) -> Dict[str, Any]:
        """Get security schemes from the specification.

        Returns:
            Dictionary of security scheme names and their definitions.
        """
        if self._is_v3():
            return (self.spec.get("components") or {}).get("securitySchemes", {})
        return self.spec.get("securityDefinitions", {})

    def get_definitions(self) -> Dict[str, Any]:
        """Get schema definitions from the specification.

        Returns:
            Dictionary of schema definitions.
        """
        if self._is_v3():
            return (self.spec.get("components") or {}).get("schemas", {})
        return self.spec.get("definitions", {})

    def get_info(self) -> Dict[str, str]:
        """Get API info from the specification.

        Returns:
            Dictionary with title, version, and description.
        """
        info = self.spec.get("info", {})
        return {
            "title": info.get("title", "API"),
            "version": info.get("version", "1.0.0"),
            "description": info.get("description", ""),
        }

    def to_dict(self) -> Dict[str, Any]:
        """Convert the spec to a dictionary.

        Returns:
            Dictionary representation of the parsed spec.
        """
        return {
            "version": self.version,
            "base_path": self.base_path,
            "servers": self.servers,
            "info": self.get_info(),
            "paths": self.get_paths(),
            "endpoints": self.get_endpoints(),
            "security_schemes": self.get_security_schemes(),
            "definitions": self.get_definitions(),
        }