"""Pytest test generator."""

import logging
import re
from pathlib import Path
from typing import Any, Dict, List, Optional

from jinja2 import Environment, FileSystemLoader, TemplateSyntaxError, UndefinedError

from ..core import SpecParser, AuthConfig
from ..core.exceptions import GeneratorError, TemplateRenderError

logger = logging.getLogger(__name__)


class PytestGenerator:
    """Generate pytest-compatible integration test templates."""

    def __init__(
        self,
        spec_parser: SpecParser,
        output_dir: str = "tests",
        mock_server_url: str = "http://localhost:4010",
    ):
        """Initialize the pytest generator.

        Args:
            spec_parser: The OpenAPI specification parser.
            output_dir: Directory for generated test files.
            mock_server_url: URL of the mock server for testing.
        """
        self.spec_parser = spec_parser
        self.output_dir = Path(output_dir)
        self.mock_server_url = mock_server_url
        # Templates are looked up in "<three levels above this file>/templates/pytest".
        template_dir = Path(__file__).parent.parent.parent / "templates" / "pytest"
        self.env = Environment(
            loader=FileSystemLoader(str(template_dir)),
            trim_blocks=True,
            lstrip_blocks=True,
        )

    def generate(self, output_file: Optional[str] = None) -> List[Path]:
        """Generate pytest test files.

        Renders the ``test_base.py.j2`` template with the API info, endpoints,
        security schemes and definitions obtained from the spec parser, and
        writes the result to disk.

        Args:
            output_file: Optional specific output file path. When omitted, the
                file is written to ``<output_dir>/test_<sanitized title>.py``.

        Returns:
            List of generated file paths.

        Raises:
            TemplateRenderError: If the template has a syntax error or
                references an undefined variable.
        """
        self.output_dir.mkdir(parents=True, exist_ok=True)

        endpoints = self.spec_parser.get_endpoints()
        info = self.spec_parser.get_info()
        context = {
            "api_title": info["title"],
            "api_version": info["version"],
            "endpoints": endpoints,
            "mock_server_url": self.mock_server_url,
            "security_schemes": self.spec_parser.get_security_schemes(),
            "definitions": self.spec_parser.get_definitions(),
        }

        # Only the template calls can raise the exceptions handled here, so
        # the try body is limited to them.
        try:
            template = self.env.get_template("test_base.py.j2")
            content = template.render(context)
        except (TemplateSyntaxError, UndefinedError) as e:
            # Chain the cause so the original Jinja traceback is preserved.
            raise TemplateRenderError(f"Failed to render pytest template: {e}") from e

        if output_file:
            output_path = Path(output_file)
        else:
            safe_name = re.sub(r"[^a-zA-Z0-9_]", "_", info["title"].lower())
            output_path = self.output_dir / f"test_{safe_name}.py"

        output_path.parent.mkdir(parents=True, exist_ok=True)
        # Explicit encoding: API titles/summaries may contain non-ASCII text
        # and the platform default encoding is not guaranteed to handle it.
        output_path.write_text(content, encoding="utf-8")

        return [output_path]

    def generate_endpoint_tests(self, endpoint_path: str, method: str) -> str:
        """Generate test for a specific endpoint.

        Args:
            endpoint_path: The API endpoint path.
            method: The HTTP method (matched case-insensitively).

        Returns:
            String containing the test code, or an empty string when no
            endpoint matches the given path and method.
        """
        wanted_method = method.lower()
        for endpoint in self.spec_parser.get_endpoints():
            # Normalize both sides so "GET" and "get" are treated alike.
            if (
                endpoint["path"] == endpoint_path
                and str(endpoint["method"]).lower() == wanted_method
            ):
                return self._generate_single_test(endpoint)
        return ""

    def _generate_single_test(self, endpoint: Dict[str, Any]) -> str:
        """Generate test code for a single endpoint.

        Args:
            endpoint: The endpoint dictionary. ``path`` and ``method`` are
                required; ``summary``, ``parameters`` and ``security`` are
                optional.

        Returns:
            String containing the test code: a leading blank line followed by
            one ``test_*`` function definition.
        """
        test_name = self._generate_test_name(endpoint)
        params = self._generate_parameters(endpoint)
        auth_headers = self._generate_auth_headers(endpoint)

        path = endpoint["path"]
        http_method = str(endpoint["method"]).lower()

        # Keep the generated docstring on one line and make sure the summary
        # cannot terminate it early.
        description = " ".join(str(endpoint.get("summary") or path).split())
        description = description.replace('"""', "'''")

        signature = f"base_url, {params}" if params else "base_url"

        # The request body is decided here, at generation time; the generated
        # test has no ``method`` variable to branch on.
        json_body = "{}" if http_method == "post" else "None"

        lines = [
            "",
            f"def test_{test_name}({signature}):",
            f'    """Test {description} endpoint."""',
            # ``{{base_url}}`` is escaped so it is emitted literally and
            # resolved when the generated test runs, not while generating.
            f'    url = f"{{base_url}}{path}"',
            '    headers = {"Content-Type": "application/json"}',
        ]
        # Re-indent every auth statement to the generated function body.
        lines.extend(
            f"    {statement.strip()}"
            for statement in auth_headers.splitlines()
            if statement.strip()
        )
        lines.extend(
            [
                f"    response = requests.{http_method}"
                f"(url, json={json_body}, headers=headers)",
                "    assert response.status_code in [200, 201, 204]",
            ]
        )
        return "\n".join(lines) + "\n"

    def _generate_test_name(self, endpoint: Dict[str, Any]) -> str:
        """Generate a valid test function name.

        Every run of non-alphanumeric characters in the path becomes a single
        underscore, and leading/trailing underscores are dropped, e.g.
        ``/users/{id}`` with method ``get`` yields ``get_users_id``.

        Args:
            endpoint: The endpoint dictionary.

        Returns:
            A valid Python function name (without the ``test_`` prefix).
        """
        method = str(endpoint["method"]).lower()
        name = re.sub(r"[^a-zA-Z0-9]+", "_", endpoint["path"]).strip("_")
        return f"{method}_{name}" if name else f"{method}_default"

    def _generate_parameters(self, endpoint: Dict[str, Any]) -> str:
        """Generate parameters for test function.

        Path parameters default to ``"test_<name>"`` and query parameters to
        ``None``; parameters in any other location are ignored.

        Args:
            endpoint: The endpoint dictionary.

        Returns:
            String containing parameter declarations, comma separated.
        """
        declarations: List[str] = []
        seen = set()
        for param in endpoint.get("parameters") or []:
            # Skip entries that are not plain parameter objects (for example
            # unresolved references without "name"/"in" keys).
            if not isinstance(param, dict):
                continue
            param_name = param.get("name")
            location = param.get("in")
            # A repeated argument name would make the generated signature a
            # syntax error, so only the first occurrence is kept.
            if not param_name or param_name in seen:
                continue
            if location == "path":
                declarations.append(f'{param_name}="test_{param_name}"')
            elif location == "query":
                declarations.append(f"{param_name}=None")
            else:
                continue
            seen.add(param_name)
        return ", ".join(declarations)

    def _generate_auth_headers(self, endpoint: Dict[str, Any]) -> str:
        """Generate authentication headers for endpoint.

        Only the first security requirement of the endpoint is used. This is
        best effort: if the auth configuration cannot be built, a warning is
        logged and no headers are generated.

        Args:
            endpoint: The endpoint dictionary.

        Returns:
            String containing header assignments, one statement per line, or
            an empty string when no headers apply.
        """
        security_requirements = endpoint.get("security", [])
        if not security_requirements:
            return ""

        schemes = self.spec_parser.get_security_schemes()
        auth_config = AuthConfig()
        try:
            auth_config.build_from_spec(schemes, security_requirements)
        except Exception as e:
            # Deliberately broad: a broken security section should not abort
            # test generation, but it should not go unnoticed either.
            logger.warning(
                "Could not build auth config for %s; generating test without "
                "auth headers: %s",
                endpoint.get("path"),
                e,
            )
            return ""

        headers = []
        for scheme_name in security_requirements[0]:
            auth_method = auth_config.get_auth_method(scheme_name)
            if not auth_method:
                continue
            if auth_method["type"] == "apiKey":
                headers.append(
                    f'headers["{auth_method["header_name"]}"] = "test_api_key"'
                )
            elif auth_method["type"] == "bearer":
                headers.append('headers["Authorization"] = "Bearer test_token"')

        # Continuation lines carry the 4-space indent of a function body.
        return "\n    ".join(headers)