- Remove unused imports across all generator files - Remove unused variables (spec, url_params, query_params, test_name) - Fix f-strings without placeholders in auth.py and go.py - Fix duplicate BASIC auth handling with wrong indentation - Add missing pytest fixtures (sample_openapi_spec, temp_spec_file, temp_json_spec_file) - Add missing TemplateRenderError import to generator files
307 lines
10 KiB
Python
307 lines
10 KiB
Python
"""OpenAPI Specification Parser."""
|
|
|
|
import yaml
|
|
import json
|
|
from pathlib import Path
|
|
from typing import Any, Dict, List, Optional, Union
|
|
|
|
from openapi_spec_validator import validate
|
|
|
|
from .exceptions import InvalidOpenAPISpecError, UnsupportedVersionError
|
|
|
|
|
|
class SpecParser:
    """Parse and validate OpenAPI/Swagger specifications.

    Typical use is ``SpecParser(path).load()`` followed by the ``get_*``
    accessors, all of which read from the dictionary stored in ``self.spec``.
    """

    # Values of the top-level ``openapi``/``swagger`` field accepted by
    # ``_validate``; anything else raises UnsupportedVersionError.
    SUPPORTED_VERSIONS = ["2.0", "3.0.0", "3.0.1", "3.0.2", "3.0.3", "3.1.0"]

    # Path-item keys that ``get_endpoints`` treats as operations.
    # NOTE(review): OAS 3 also defines "trace"; it is left out (unchanged
    # behaviour) until the code generators are confirmed to handle it.
    HTTP_METHODS = ("get", "post", "put", "patch", "delete", "options", "head")

    def __init__(self, spec_path: Union[str, Path]):
        """Initialize the spec parser.

        Args:
            spec_path: Path to OpenAPI specification file (YAML or JSON).
        """
        self.spec_path = Path(spec_path)
        self.spec: Dict[str, Any] = {}
        self.version: str = ""
        self.base_path: str = ""
        self.servers: List[Dict[str, str]] = []

    def load(self) -> Dict[str, Any]:
        """Load and validate the OpenAPI specification.

        Returns:
            The parsed specification dictionary.

        Raises:
            InvalidOpenAPISpecError: If the specification is invalid.
            UnsupportedVersionError: If the OpenAPI version is not supported.
        """
        self.spec = self._load_file()
        self._validate()
        self._extract_metadata()
        return self.spec

    def _load_file(self) -> Dict[str, Any]:
        """Load the specification file.

        Files whose suffix is ``.json`` (case-insensitive) are parsed with the
        ``json`` module; every other suffix is parsed with ``yaml.safe_load``,
        and an empty YAML document yields ``{}``.

        Returns:
            Parsed specification dictionary.

        Raises:
            InvalidOpenAPISpecError: If the file is missing or unreadable,
                cannot be parsed, or its top level is not a mapping.
        """
        if not self.spec_path.exists():
            raise InvalidOpenAPISpecError(f"Specification file not found: {self.spec_path}")

        is_json = self.spec_path.suffix.lower() == ".json"
        try:
            with open(self.spec_path, "r", encoding="utf-8") as f:
                data = json.load(f) if is_json else (yaml.safe_load(f) or {})
        except (OSError, UnicodeDecodeError) as e:
            # e.g. the path is a directory, is not readable, or is not UTF-8.
            raise InvalidOpenAPISpecError(f"Failed to read specification: {e}") from e
        except (yaml.YAMLError, json.JSONDecodeError) as e:
            raise InvalidOpenAPISpecError(f"Failed to parse specification: {e}") from e

        # A list or scalar document would break every mapping lookup later on.
        if not isinstance(data, dict):
            raise InvalidOpenAPISpecError(
                f"Specification root must be a mapping, got {type(data).__name__}"
            )
        return data

    def _validate(self) -> None:
        """Validate the specification.

        The version is checked first so that an unsupported version surfaces
        as UnsupportedVersionError rather than as a generic validator failure.

        Raises:
            InvalidOpenAPISpecError: If the specification is invalid.
            UnsupportedVersionError: If the OpenAPI version is not supported.
        """
        version = self._get_version()
        if version not in self.SUPPORTED_VERSIONS:
            raise UnsupportedVersionError(
                f"Unsupported OpenAPI version: {version}. "
                f"Supported versions: {', '.join(self.SUPPORTED_VERSIONS)}"
            )

        try:
            validate(self.spec)
        except Exception as e:
            # The validator's exception types vary between library releases,
            # so everything is mapped to the package's own error type here.
            raise InvalidOpenAPISpecError(f"Invalid OpenAPI specification: {e}") from e

    def _get_version(self) -> str:
        """Extract the OpenAPI version from the spec.

        ``openapi`` takes precedence over ``swagger``; when neither key is
        present, ``"2.0"`` is assumed.

        Returns:
            The OpenAPI version string (coerced with ``str`` because an
            unquoted YAML value such as ``2.0`` loads as a float).
        """
        for key in ("openapi", "swagger"):
            if key in self.spec:
                return str(self.spec[key])
        return "2.0"

    def _extract_metadata(self) -> None:
        """Populate ``version``, ``base_path`` and ``servers`` from the spec."""
        self.version = self._get_version()
        self.base_path = self.spec.get("basePath") or ""
        self.servers = self.spec.get("servers") or []

    def _resolve_ref(self, node: Any) -> Any:
        """Follow same-document ``$ref`` pointers (``#/...``) in the spec.

        External references (other files or URLs), dangling pointers and
        reference cycles are returned unchanged, so callers still see the
        original ``{"$ref": ...}`` object in those cases.

        Args:
            node: Any spec node, possibly a ``{"$ref": "#/..."}`` object.

        Returns:
            The referenced node, or ``node`` itself when it is not a
            resolvable local reference.
        """
        seen = set()
        while isinstance(node, dict) and isinstance(node.get("$ref"), str):
            ref = node["$ref"]
            if not ref.startswith("#/") or ref in seen:
                return node
            seen.add(ref)
            target: Any = self.spec
            for token in ref[2:].split("/"):
                # JSON Pointer unescaping (RFC 6901): "~1" -> "/", then "~0" -> "~".
                token = token.replace("~1", "/").replace("~0", "~")
                if not isinstance(target, dict) or token not in target:
                    return node
                target = target[token]
            node = target
        return node

    def get_paths(self) -> Dict[str, Any]:
        """Get all paths from the specification.

        Returns:
            Dictionary of paths and their operations (``{}`` when absent or null).
        """
        return self.spec.get("paths") or {}

    def get_endpoints(self) -> List[Dict[str, Any]]:
        """Extract all endpoints from the specification.

        Extension keys (``x-...``) in the Paths object and non-mapping path
        items or operations are skipped.

        Returns:
            List of endpoint dictionaries with path, method, and details.
        """
        endpoints: List[Dict[str, Any]] = []

        for path, raw_item in self.get_paths().items():
            if str(path).startswith("x-"):
                continue
            path_item = self._resolve_ref(raw_item)
            if not isinstance(path_item, dict):
                continue

            for method, operation in path_item.items():
                method = str(method).lower()
                # Path items also carry "parameters", "summary", "servers", ...
                if method not in self.HTTP_METHODS or not isinstance(operation, dict):
                    continue
                endpoints.append({
                    "path": path,
                    "method": method,
                    "operation_id": operation.get("operationId", ""),
                    "summary": operation.get("summary", ""),
                    "description": operation.get("description", ""),
                    "tags": operation.get("tags", []),
                    "parameters": self._extract_parameters(path_item, operation),
                    "request_body": self._extract_request_body(operation, path_item),
                    "responses": self._extract_responses(operation),
                    "security": self._extract_security(operation),
                })

        return endpoints

    def _extract_parameters(self, path_item: Dict, operation: Dict) -> List[Dict[str, Any]]:
        """Extract non-body parameters from path item and operation.

        Local ``$ref`` parameters are resolved. Per the OpenAPI rules, an
        operation-level parameter replaces a path-level one with the same
        ``name`` and ``in``; the result keeps first-seen order.

        Args:
            path_item: The path item dictionary.
            operation: The operation dictionary.

        Returns:
            List of normalized parameter dictionaries.
        """
        merged: Dict[Any, Dict[str, Any]] = {}

        # Path-level first so operation-level entries overwrite them.
        for source in (path_item, operation):
            for raw in source.get("parameters") or []:
                param = self._resolve_ref(raw)
                if not isinstance(param, dict) or param.get("in") == "body":
                    continue
                if "$ref" in param:
                    # Unresolvable (e.g. external) reference: key by the ref
                    # itself so distinct refs are not collapsed together.
                    key: Any = ("$ref", param["$ref"])
                else:
                    key = (param.get("name", ""), param.get("in", ""))
                merged[key] = self._normalize_parameter(param)

        return list(merged.values())

    def _normalize_parameter(self, param: Dict[str, Any]) -> Dict[str, Any]:
        """Normalize a parameter.

        Swagger 2.0 puts ``type``/``enum``/``default`` on the parameter itself,
        while OpenAPI 3 puts them inside ``schema``; the top-level value wins
        and the schema value is used as a fallback.

        Args:
            param: The parameter dictionary.

        Returns:
            Normalized parameter dictionary.
        """
        schema = param.get("schema", {})
        fallback = schema if isinstance(schema, dict) else {}
        return {
            "name": param.get("name", ""),
            "in": param.get("in", ""),
            "description": param.get("description", ""),
            "required": param.get("required", False),
            "schema": schema,
            "type": param.get("type", fallback.get("type", "")),
            "enum": param.get("enum", fallback.get("enum", [])),
            "default": param.get("default", fallback.get("default")),
        }

    def _extract_request_body(
        self, operation: Dict, path_item: Optional[Dict] = None
    ) -> Optional[Dict[str, Any]]:
        """Extract request body from operation.

        For OpenAPI 3 the ``requestBody`` object is used (a local ``$ref`` is
        resolved) and the schema of its first media type is reported. For
        Swagger 2.0 the ``in: body`` parameter is used, checking the operation
        first and then ``path_item`` (when given).

        Args:
            operation: The operation dictionary.
            path_item: Optional enclosing path item, for Swagger 2.0 body
                parameters declared at path level.

        Returns:
            Request body dictionary or None.
        """
        if self.version.startswith("3."):
            request_body = self._resolve_ref(operation.get("requestBody") or {})
            if not request_body:
                return None

            content = request_body.get("content") or {}
            media_types = list(content)
            first = (content[media_types[0]] if media_types else None) or {}

            return {
                "description": request_body.get("description", ""),
                "required": request_body.get("required", False),
                "media_types": media_types,
                "schema": first.get("schema", {}),
            }

        # Operation-level parameters come first so they take precedence.
        candidates = list(operation.get("parameters") or [])
        if path_item:
            candidates.extend(path_item.get("parameters") or [])
        for raw in candidates:
            param = self._resolve_ref(raw)
            if isinstance(param, dict) and param.get("in") == "body":
                return {
                    "description": param.get("description", ""),
                    "required": param.get("required", False),
                    "schema": param.get("schema", {}),
                }
        return None

    def _extract_responses(self, operation: Dict) -> Dict[str, Any]:
        """Extract responses from operation.

        Status codes are returned as strings (unquoted YAML codes load as
        ints), ``x-`` extension keys are skipped, and local ``$ref``
        responses are resolved.

        Args:
            operation: The operation dictionary.

        Returns:
            Dictionary of response status codes and their details.
        """
        responses: Dict[str, Any] = {}

        for status_code, raw in (operation.get("responses") or {}).items():
            code = str(status_code)
            if code.startswith("x-"):
                continue
            response = self._resolve_ref(raw)
            if not isinstance(response, dict):
                response = {}

            content = response.get("content") or {}
            media_types = list(content)

            if self.version.startswith("3."):
                first = (content[media_types[0]] if media_types else None) or {}
                schema = first.get("schema", {})
            else:
                schema = response.get("schema", {})

            responses[code] = {
                "description": response.get("description", ""),
                "schema": schema,
                "media_type": media_types[0] if media_types else "application/json",
            }

        return responses

    def _extract_security(self, operation: Dict) -> List[Dict[str, Any]]:
        """Extract security requirements from operation.

        An operation-level ``security`` key (even an empty list) overrides the
        document-level one.

        Args:
            operation: The operation dictionary.

        Returns:
            List of security requirement dictionaries.
        """
        return operation.get("security", self.spec.get("security", []))

    def get_security_schemes(self) -> Dict[str, Any]:
        """Get security schemes from the specification.

        Returns:
            Dictionary of security scheme names and their definitions.
        """
        if self.version.startswith("3."):
            return (self.spec.get("components") or {}).get("securitySchemes") or {}
        return self.spec.get("securityDefinitions") or {}

    def get_definitions(self) -> Dict[str, Any]:
        """Get schema definitions from the specification.

        Returns:
            Dictionary of schema definitions.
        """
        if self.version.startswith("3."):
            return (self.spec.get("components") or {}).get("schemas") or {}
        return self.spec.get("definitions") or {}

    def get_info(self) -> Dict[str, str]:
        """Get API info from the specification.

        Returns:
            Dictionary with title, version, and description.
        """
        info = self.spec.get("info") or {}
        return {
            "title": info.get("title", "API"),
            "version": info.get("version", "1.0.0"),
            "description": info.get("description", ""),
        }

    def to_dict(self) -> Dict[str, Any]:
        """Convert the spec to a dictionary.

        Returns:
            Dictionary representation of the parsed spec.
        """
        return {
            "version": self.version,
            "base_path": self.base_path,
            "servers": self.servers,
            "info": self.get_info(),
            "paths": self.get_paths(),
            "endpoints": self.get_endpoints(),
            "security_schemes": self.get_security_schemes(),
            "definitions": self.get_definitions(),
        }
|