"""OpenAPI Specification Parser."""

import json
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
from urllib.parse import unquote

import yaml
from openapi_spec_validator import validate
from openapi_spec_validator.versions import consts as validator_consts  # noqa: F401

from .exceptions import InvalidOpenAPISpecError, UnsupportedVersionError


class SpecParser:
    """Parse and validate OpenAPI/Swagger specifications.

    Typical use is ``SpecParser(path).load()`` followed by the ``get_*``
    accessors. The accessors read ``self.spec`` and ``self.version``, which
    are only populated by :meth:`load`.
    """

    SUPPORTED_VERSIONS = ["2.0", "3.0.0", "3.0.1", "3.0.2", "3.0.3", "3.1.0"]

    # Path-item keys that get_endpoints() treats as operations.
    HTTP_METHODS = ("get", "post", "put", "patch", "delete", "options", "head")

    # Upper bound on chained ``$ref`` hops so a self-referencing spec cannot
    # make _resolve_ref() loop forever.
    _MAX_REF_DEPTH = 32

    def __init__(self, spec_path: Union[str, Path]):
        """Initialize the spec parser.

        Args:
            spec_path: Path to OpenAPI specification file (YAML or JSON).
        """
        self.spec_path = Path(spec_path)
        self.spec: Dict[str, Any] = {}
        self.version: str = ""
        self.base_path: str = ""
        self.servers: List[Dict[str, str]] = []

    def load(self) -> Dict[str, Any]:
        """Load and validate the OpenAPI specification.

        Returns:
            The parsed specification dictionary.

        Raises:
            InvalidOpenAPISpecError: If the specification is invalid.
            UnsupportedVersionError: If the OpenAPI version is not supported.
        """
        self.spec = self._load_file()
        self._validate()
        self._extract_metadata()
        return self.spec

    def _load_file(self) -> Dict[str, Any]:
        """Load the specification file.

        Files with a ``.json`` suffix (case-insensitive) are parsed as JSON;
        everything else is parsed as YAML.

        Returns:
            Parsed specification dictionary (empty if the document is empty).

        Raises:
            InvalidOpenAPISpecError: If the file is missing, cannot be read or
                parsed, or its root is not a mapping.
        """
        if not self.spec_path.exists():
            raise InvalidOpenAPISpecError(f"Specification file not found: {self.spec_path}")

        is_json = self.spec_path.suffix.lower() == ".json"
        try:
            with open(self.spec_path, "r", encoding="utf-8") as f:
                data = json.load(f) if is_json else yaml.safe_load(f)
        except (yaml.YAMLError, json.JSONDecodeError) as e:
            raise InvalidOpenAPISpecError(f"Failed to parse specification: {e}") from e
        except (OSError, UnicodeDecodeError) as e:
            raise InvalidOpenAPISpecError(f"Failed to read specification: {e}") from e

        if data is None:
            # Empty document: let validation report the missing fields.
            return {}
        if not isinstance(data, dict):
            # Every accessor relies on dict.get(), so reject lists/scalars here
            # instead of failing later with an AttributeError.
            raise InvalidOpenAPISpecError(
                f"Failed to parse specification: root must be a mapping, "
                f"got {type(data).__name__}"
            )
        return data

    def _validate(self) -> None:
        """Validate the specification.

        Raises:
            InvalidOpenAPISpecError: If the specification is invalid.
            UnsupportedVersionError: If the OpenAPI version is not supported.
        """
        try:
            validate(self.spec)
        except Exception as e:
            # The validator raises several unrelated error types; this is the
            # boundary where they are folded into the package's own exception.
            raise InvalidOpenAPISpecError(f"Invalid OpenAPI specification: {e}") from e

        version = self._get_version()
        if version not in self.SUPPORTED_VERSIONS:
            raise UnsupportedVersionError(
                f"Unsupported OpenAPI version: {version}. "
                f"Supported versions: {', '.join(self.SUPPORTED_VERSIONS)}"
            )

    def _get_version(self) -> str:
        """Extract the OpenAPI version from the spec.

        Returns:
            The value of ``openapi`` or, failing that, ``swagger``, coerced to
            ``str``; ``"2.0"`` when neither key is present.
        """
        for key in ("openapi", "swagger"):
            if key in self.spec:
                # An unquoted YAML ``swagger: 2.0`` is parsed as a float.
                return str(self.spec[key])
        return "2.0"

    def _extract_metadata(self) -> None:
        """Cache version, base path and servers from the loaded spec."""
        self.version = self._get_version()
        self.base_path = self.spec.get("basePath", "")
        self.servers = self.spec.get("servers") or []

    def _resolve_ref(self, node: Any) -> Any:
        """Follow local ``$ref`` pointers (``#/...``) inside the loaded spec.

        Args:
            node: Any spec fragment; only dicts carrying a string ``$ref`` that
                starts with ``#/`` are dereferenced.

        Returns:
            The referenced fragment, or ``node`` unchanged when it is not a
            local reference or the pointer cannot be resolved.
        """
        for _ in range(self._MAX_REF_DEPTH):
            if not isinstance(node, dict):
                return node
            ref = node.get("$ref")
            if not isinstance(ref, str) or not ref.startswith("#/"):
                return node

            target: Any = self.spec
            for token in ref[2:].split("/"):
                # URI-fragment decoding first, then JSON Pointer unescaping.
                token = unquote(token).replace("~1", "/").replace("~0", "~")
                if not isinstance(target, dict) or token not in target:
                    return node  # dangling pointer: leave the reference as-is
                target = target[token]
            node = target
        return node

    def get_paths(self) -> Dict[str, Any]:
        """Get all paths from the specification.

        Returns:
            Dictionary of paths and their operations.
        """
        return self.spec.get("paths") or {}

    def get_endpoints(self) -> List[Dict[str, Any]]:
        """Extract all endpoints from the specification.

        Returns:
            List of endpoint dictionaries with path, method, and details.
        """
        endpoints = []

        for path, path_item in self.get_paths().items():
            # ``paths`` may carry ``x-`` extension entries with arbitrary values.
            if not isinstance(path_item, dict):
                continue
            if isinstance(path, str) and path.startswith("x-"):
                continue

            for method, operation in path_item.items():
                if not isinstance(method, str) or method.lower() not in self.HTTP_METHODS:
                    continue
                if not isinstance(operation, dict):
                    continue

                endpoints.append(
                    {
                        "path": path,
                        "method": method.lower(),
                        "operation_id": operation.get("operationId", ""),
                        "summary": operation.get("summary", ""),
                        "description": operation.get("description", ""),
                        "tags": operation.get("tags", []),
                        "parameters": self._extract_parameters(path_item, operation),
                        "request_body": self._extract_request_body(operation, path_item),
                        "responses": self._extract_responses(operation),
                        "security": self._extract_security(operation),
                    }
                )

        return endpoints

    def _extract_parameters(self, path_item: Dict, operation: Dict) -> List[Dict[str, Any]]:
        """Extract non-body parameters from path item and operation.

        Local ``$ref`` parameters are resolved first. When the operation
        declares a parameter with the same name and location as the path item,
        the operation-level definition replaces the path-level one.

        Args:
            path_item: The path item dictionary.
            operation: The operation dictionary.

        Returns:
            List of normalized parameter dictionaries.
        """
        parameters: List[Dict[str, Any]] = []
        position: Dict[Tuple[str, Any], int] = {}

        # Path-level first so operation-level entries can override them.
        for source in (path_item, operation):
            for raw in source.get("parameters") or []:
                param = self._resolve_ref(raw)
                if not isinstance(param, dict) or param.get("in") == "body":
                    continue

                normalized = self._normalize_parameter(param)
                name = normalized["name"]
                location = normalized["in"]
                # Unnamed entries (e.g. unresolved refs) are never merged.
                mergeable = (
                    isinstance(name, str)
                    and bool(name)
                    and isinstance(location, str)
                )
                key = (name, location) if mergeable else None

                if key is not None and key in position:
                    parameters[position[key]] = normalized
                    continue
                if key is not None:
                    position[key] = len(parameters)
                parameters.append(normalized)

        return parameters

    def _normalize_parameter(self, param: Dict[str, Any]) -> Dict[str, Any]:
        """Normalize a parameter to a fixed set of keys.

        Args:
            param: The parameter dictionary.

        Returns:
            Normalized parameter dictionary.
        """
        return {
            "name": param.get("name", ""),
            "in": param.get("in", ""),
            "description": param.get("description", ""),
            "required": param.get("required", False),
            "schema": param.get("schema", {}),
            "type": param.get("type", ""),
            "enum": param.get("enum", []),
            "default": param.get("default"),
        }

    def _extract_request_body(
        self, operation: Dict, path_item: Optional[Dict] = None
    ) -> Optional[Dict[str, Any]]:
        """Extract request body from operation.

        For 3.x specs this reads ``requestBody`` and reports the schema of the
        first listed media type. For other versions it looks for an
        ``in: body`` parameter, on the operation first and then on the path
        item when one is supplied.

        Args:
            operation: The operation dictionary.
            path_item: Optional enclosing path item, consulted only for
                non-3.x specs.

        Returns:
            Request body dictionary or None.
        """
        if self.version.startswith("3."):
            request_body = self._resolve_ref(operation.get("requestBody"))
            if not request_body or not isinstance(request_body, dict):
                return None

            content = request_body.get("content") or {}
            media_types = list(content)
            first = (content[media_types[0]] or {}) if media_types else {}

            return {
                "description": request_body.get("description", ""),
                "required": request_body.get("required", False),
                "media_types": media_types,
                "schema": first.get("schema", {}),
            }

        for source in (operation, path_item or {}):
            for raw in source.get("parameters") or []:
                param = self._resolve_ref(raw)
                if isinstance(param, dict) and param.get("in") == "body":
                    return {
                        "description": param.get("description", ""),
                        "required": param.get("required", False),
                        "schema": param.get("schema", {}),
                    }
        return None

    def _extract_responses(self, operation: Dict) -> Dict[str, Any]:
        """Extract responses from operation.

        Local ``$ref`` responses are resolved; entries that are not mappings
        (for example ``x-`` extensions) are skipped.

        Args:
            operation: The operation dictionary.

        Returns:
            Dictionary of response status codes and their details.
        """
        responses = {}
        is_v3 = self.version.startswith("3.")

        for status_code, raw in (operation.get("responses") or {}).items():
            response = self._resolve_ref(raw)
            if not isinstance(response, dict):
                continue

            content = response.get("content") or {}
            media_types = list(content)

            if is_v3:
                first = (content[media_types[0]] or {}) if media_types else {}
                schema = first.get("schema", {})
            else:
                schema = response.get("schema", {})

            responses[status_code] = {
                "description": response.get("description", ""),
                "schema": schema,
                # Falls back to JSON when the response declares no content.
                "media_type": media_types[0] if media_types else "application/json",
            }

        return responses

    def _extract_security(self, operation: Dict) -> List[Dict[str, Any]]:
        """Extract security requirements from operation.

        An operation-level ``security`` key (even an empty list) takes
        precedence over the spec-wide one.

        Args:
            operation: The operation dictionary.

        Returns:
            List of security requirement dictionaries.
        """
        return operation.get("security", self.spec.get("security", []))

    def get_security_schemes(self) -> Dict[str, Any]:
        """Get security schemes from the specification.

        Returns:
            Dictionary of security scheme names and their definitions.
        """
        if self.version.startswith("3."):
            return (self.spec.get("components") or {}).get("securitySchemes", {})
        return self.spec.get("securityDefinitions", {})

    def get_definitions(self) -> Dict[str, Any]:
        """Get schema definitions from the specification.

        Returns:
            Dictionary of schema definitions.
        """
        if self.version.startswith("3."):
            return (self.spec.get("components") or {}).get("schemas", {})
        return self.spec.get("definitions", {})

    def get_info(self) -> Dict[str, str]:
        """Get API info from the specification.

        Returns:
            Dictionary with title, version, and description.
        """
        info = self.spec.get("info") or {}
        return {
            "title": info.get("title", "API"),
            "version": info.get("version", "1.0.0"),
            "description": info.get("description", ""),
        }

    def to_dict(self) -> Dict[str, Any]:
        """Convert the spec to a dictionary.

        Returns:
            Dictionary representation of the parsed spec.
        """
        return {
            "version": self.version,
            "base_path": self.base_path,
            "servers": self.servers,
            "info": self.get_info(),
            "paths": self.get_paths(),
            "endpoints": self.get_endpoints(),
            "security_schemes": self.get_security_schemes(),
            "definitions": self.get_definitions(),
        }