Files
api-testgen-cli/api_testgen/core/spec_parser.py
CI Bot 123a4f7d1d fix: resolve CI linting and code quality issues
- Remove unused imports across all generator files
- Remove unused variables (spec, url_params, query_params, test_name)
- Fix f-strings without placeholders in auth.py and go.py
- Fix duplicate BASIC auth handling with wrong indentation
- Add missing pytest fixtures (sample_openapi_spec, temp_spec_file, temp_json_spec_file)
- Add missing TemplateRenderError import to generator files
2026-02-06 07:05:58 +00:00

307 lines
10 KiB
Python

"""OpenAPI Specification Parser."""
import yaml
import json
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from openapi_spec_validator import validate
from .exceptions import InvalidOpenAPISpecError, UnsupportedVersionError
class SpecParser:
    """Parse and validate OpenAPI/Swagger specifications.

    Call ``load()`` first: it reads the file and populates ``spec``,
    ``version``, ``base_path`` and ``servers``. The ``get_*`` accessors and
    ``to_dict()`` only read that state, and branch on ``version`` to choose
    between the Swagger 2.0 and OpenAPI 3.x document layouts.
    """

    SUPPORTED_VERSIONS = ["2.0", "3.0.0", "3.0.1", "3.0.2", "3.0.3", "3.1.0"]

    # Path Item keys treated as operations; every other key of a path item
    # (``parameters``, ``summary``, ``x-*`` extensions, ...) is skipped.
    _HTTP_METHODS = frozenset(
        {"get", "post", "put", "patch", "delete", "options", "head"}
    )

    def __init__(self, spec_path: Union[str, Path]):
        """Initialize the spec parser.

        Args:
            spec_path: Path to OpenAPI specification file (YAML or JSON).
        """
        self.spec_path = Path(spec_path)
        self.spec: Dict[str, Any] = {}
        self.version: str = ""
        self.base_path: str = ""
        self.servers: List[Dict[str, str]] = []

    def load(self) -> Dict[str, Any]:
        """Load and validate the OpenAPI specification.

        Returns:
            The parsed specification dictionary.

        Raises:
            InvalidOpenAPISpecError: If the specification is invalid.
            UnsupportedVersionError: If the OpenAPI version is not supported.
        """
        self.spec = self._load_file()
        self._validate()
        self._extract_metadata()
        return self.spec

    def _load_file(self) -> Dict[str, Any]:
        """Read and parse the specification file.

        Files whose suffix is ``.json`` (case-insensitive) are parsed as JSON;
        every other suffix is parsed as YAML. An empty YAML document yields ``{}``.

        Returns:
            Parsed specification dictionary.

        Raises:
            InvalidOpenAPISpecError: If the file is missing, cannot be read,
                is not valid YAML/JSON, or its top level is not a mapping.
        """
        if not self.spec_path.exists():
            raise InvalidOpenAPISpecError(f"Specification file not found: {self.spec_path}")
        is_json = self.spec_path.suffix.lower() == ".json"
        try:
            with open(self.spec_path, "r", encoding="utf-8") as f:
                data = json.load(f) if is_json else (yaml.safe_load(f) or {})
        except (yaml.YAMLError, json.JSONDecodeError) as e:
            raise InvalidOpenAPISpecError(f"Failed to parse specification: {e}") from e
        except (OSError, UnicodeDecodeError) as e:
            # e.g. the path is a directory, is unreadable, or is not UTF-8 text.
            raise InvalidOpenAPISpecError(f"Failed to read specification: {e}") from e
        if not isinstance(data, dict):
            # A list/scalar root would break the key lookups in _get_version.
            raise InvalidOpenAPISpecError(
                f"Specification root must be a mapping, got {type(data).__name__}"
            )
        return data

    def _validate(self) -> None:
        """Check the declared version, then validate the spec's structure.

        The version is checked before structural validation: every validator
        failure is re-raised as ``InvalidOpenAPISpecError``, so checking the
        version afterwards would let that error mask an unsupported version.

        Raises:
            UnsupportedVersionError: If the OpenAPI version is not supported.
            InvalidOpenAPISpecError: If the specification is invalid.
        """
        version = self._get_version()
        if version not in self.SUPPORTED_VERSIONS:
            raise UnsupportedVersionError(
                f"Unsupported OpenAPI version: {version}. "
                f"Supported versions: {', '.join(self.SUPPORTED_VERSIONS)}"
            )
        try:
            validate(self.spec)
        except Exception as e:
            # Boundary: surface any validator failure as this package's error type.
            raise InvalidOpenAPISpecError(f"Invalid OpenAPI specification: {e}") from e

    def _get_version(self) -> str:
        """Return the declared specification version as a string.

        Reads ``openapi`` (3.x) or else ``swagger`` (2.0), falling back to
        ``"2.0"`` when neither key is present. The value is passed through
        ``str()`` because an unquoted YAML scalar such as ``swagger: 2.0`` is
        loaded as a float.

        Returns:
            The OpenAPI version string.
        """
        for key in ("openapi", "swagger"):
            if key in self.spec:
                return str(self.spec[key])
        return "2.0"

    def _extract_metadata(self) -> None:
        """Populate ``version``, ``base_path`` and ``servers`` from the spec."""
        self.version = self._get_version()
        self.base_path = self.spec.get("basePath", "")
        self.servers = self.spec.get("servers", [])

    def get_paths(self) -> Dict[str, Any]:
        """Get all paths from the specification.

        Returns:
            Dictionary of paths and their operations (raw, including any
            ``x-*`` extension keys).
        """
        return self.spec.get("paths", {})

    def get_endpoints(self) -> List[Dict[str, Any]]:
        """Extract all endpoints (one per path and HTTP method) from the spec.

        ``x-*`` specification-extension keys under ``paths`` are skipped
        because their values are not Path Item Objects.

        Returns:
            List of endpoint dictionaries with path, method, and details.
        """
        endpoints = []
        for path, path_item in self.get_paths().items():
            if str(path).startswith("x-"):
                continue
            for method, operation in path_item.items():
                method_name = method.lower()
                if method_name not in self._HTTP_METHODS:
                    continue
                endpoints.append({
                    "path": path,
                    "method": method_name,
                    "operation_id": operation.get("operationId", ""),
                    "summary": operation.get("summary", ""),
                    "description": operation.get("description", ""),
                    "tags": operation.get("tags", []),
                    "parameters": self._extract_parameters(path_item, operation),
                    "request_body": self._extract_request_body(operation, path_item),
                    "responses": self._extract_responses(operation),
                    "security": self._extract_security(operation),
                })
        return endpoints

    def _extract_parameters(self, path_item: Dict, operation: Dict) -> List[Dict[str, Any]]:
        """Merge path-level and operation-level parameters, excluding body params.

        A parameter is identified by its (``name``, ``in``) pair, or by its
        ``$ref`` string when it is an unresolved reference. When the operation
        redefines a path-level parameter, the operation's definition replaces
        it (keeping the path-level position), as the OpenAPI/Swagger specs
        require. Swagger 2.0 ``in: body`` parameters are excluded here; they
        are reported by ``_extract_request_body`` instead.

        Args:
            path_item: The path item dictionary.
            operation: The operation dictionary.

        Returns:
            List of normalized parameter dictionaries.
        """
        merged: Dict[tuple, Dict[str, Any]] = {}
        for param in [*path_item.get("parameters", []), *operation.get("parameters", [])]:
            if param.get("in") == "body":
                continue
            # Re-assigning an existing dict key keeps its original position.
            merged[self._parameter_key(param)] = param
        return [self._normalize_parameter(param) for param in merged.values()]

    @staticmethod
    def _parameter_key(param: Dict[str, Any]) -> tuple:
        """Return the identity used to detect overridden/duplicate parameters."""
        if "$ref" in param:
            return ("$ref", param["$ref"])
        return (param.get("name"), param.get("in"))

    def _normalize_parameter(self, param: Dict[str, Any]) -> Dict[str, Any]:
        """Normalize a parameter into a version-independent shape.

        Swagger 2.0 puts ``type``/``enum``/``default`` on the parameter itself,
        while OpenAPI 3.x puts them inside ``schema``. Parameter-level values
        win; otherwise they are taken from ``schema`` so both versions fill
        these keys. An OpenAPI 3.1 type list such as ``["string", "null"]`` is
        reduced to its first non-``"null"`` entry.

        Args:
            param: The parameter dictionary.

        Returns:
            Normalized parameter dictionary.
        """
        raw_schema = param.get("schema")
        schema = raw_schema if isinstance(raw_schema, dict) else {}
        schema_type = schema.get("type", "")
        if isinstance(schema_type, list):
            schema_type = next((t for t in schema_type if t != "null"), "")
        return {
            "name": param.get("name", ""),
            "in": param.get("in", ""),
            "description": param.get("description", ""),
            "required": param.get("required", False),
            "schema": param.get("schema", {}),
            "type": param.get("type", schema_type),
            "enum": param.get("enum", schema.get("enum", [])),
            "default": param.get("default", schema.get("default")),
        }

    def _extract_request_body(
        self, operation: Dict, path_item: Optional[Dict] = None
    ) -> Optional[Dict[str, Any]]:
        """Extract the request body of an operation.

        OpenAPI 3.x: read from ``requestBody``; the schema is taken from the
        first declared media type.
        Swagger 2.0: read from the ``in: body`` parameter, searching the
        operation's parameters first and then ``path_item``'s path-level
        parameters; ``media_types`` comes from the operation's ``consumes``,
        falling back to the top-level ``consumes``.

        Args:
            operation: The operation dictionary.
            path_item: Optional enclosing path item, consulted (Swagger 2.0
                only) for a path-level body parameter.

        Returns:
            Dictionary with ``description``, ``required``, ``media_types`` and
            ``schema`` keys, or None if the operation has no request body.
        """
        if self.version.startswith("3."):
            request_body = operation.get("requestBody", {})
            if not request_body:
                return None
            content = request_body.get("content", {})
            media_types = list(content.keys())
            return {
                "description": request_body.get("description", ""),
                "required": request_body.get("required", False),
                "media_types": media_types,
                "schema": content.get(media_types[0], {}).get("schema", {}) if media_types else {},
            }

        candidates = list(operation.get("parameters", []))
        if path_item:
            candidates.extend(path_item.get("parameters", []))
        body = next((p for p in candidates if p.get("in") == "body"), None)
        if body is None:
            return None
        return {
            "description": body.get("description", ""),
            "required": body.get("required", False),
            "schema": body.get("schema", {}),
            # Operation-level ``consumes`` (even an empty list) overrides the global one.
            "media_types": list(operation.get("consumes", self.spec.get("consumes", []))),
        }

    def _extract_responses(self, operation: Dict) -> Dict[str, Any]:
        """Extract responses from an operation.

        Status codes are returned as strings (unquoted YAML keys such as
        ``200:`` load as ints), and ``x-*`` extension keys are skipped.

        Args:
            operation: The operation dictionary.

        Returns:
            Dictionary mapping status code strings to description, schema and
            media type (``"application/json"`` when no content is declared).
        """
        responses: Dict[str, Any] = {}
        for status_code, response in operation.get("responses", {}).items():
            code = str(status_code)
            if code.startswith("x-"):
                continue
            content = response.get("content", {})
            media_types = list(content.keys())
            if self.version.startswith("3."):
                schema = content.get(media_types[0], {}).get("schema", {}) if media_types else {}
            else:
                schema = response.get("schema", {})
            responses[code] = {
                "description": response.get("description", ""),
                "schema": schema,
                "media_type": media_types[0] if media_types else "application/json",
            }
        return responses

    def _extract_security(self, operation: Dict) -> List[Dict[str, Any]]:
        """Extract security requirements from an operation.

        An operation-level ``security`` key (even an empty list) replaces the
        top-level requirement; otherwise the top-level one is used.

        Args:
            operation: The operation dictionary.

        Returns:
            List of security requirement dictionaries.
        """
        return operation.get("security", self.spec.get("security", []))

    def get_security_schemes(self) -> Dict[str, Any]:
        """Get security schemes from the specification.

        Returns:
            Dictionary of security scheme names and their definitions.
        """
        if self.version.startswith("3."):
            return self.spec.get("components", {}).get("securitySchemes", {})
        return self.spec.get("securityDefinitions", {})

    def get_definitions(self) -> Dict[str, Any]:
        """Get schema definitions from the specification.

        Returns:
            Dictionary of schema definitions.
        """
        if self.version.startswith("3."):
            return self.spec.get("components", {}).get("schemas", {})
        return self.spec.get("definitions", {})

    def get_info(self) -> Dict[str, str]:
        """Get API info from the specification.

        Returns:
            Dictionary with title, version, and description (with defaults
            ``"API"``, ``"1.0.0"`` and ``""``).
        """
        info = self.spec.get("info", {})
        return {
            "title": info.get("title", "API"),
            "version": info.get("version", "1.0.0"),
            "description": info.get("description", ""),
        }

    def to_dict(self) -> Dict[str, Any]:
        """Convert the parsed spec to a summary dictionary.

        Returns:
            Dictionary representation of the parsed spec.
        """
        return {
            "version": self.version,
            "base_path": self.base_path,
            "servers": self.servers,
            "info": self.get_info(),
            "paths": self.get_paths(),
            "endpoints": self.get_endpoints(),
            "security_schemes": self.get_security_schemes(),
            "definitions": self.get_definitions(),
        }