diff --git a/api_testgen/generators/pytest.py b/api_testgen/generators/pytest.py new file mode 100644 index 0000000..0644564 --- /dev/null +++ b/api_testgen/generators/pytest.py @@ -0,0 +1,199 @@ +"""Pytest test generator.""" + +import re +from pathlib import Path +from typing import Any, Dict, List, Optional + +from jinja2 import Environment, FileSystemLoader, TemplateSyntaxError, UndefinedError + +from ..core import SpecParser, AuthConfig +from ..core.exceptions import GeneratorError, TemplateRenderError + + +class PytestGenerator: + """Generate pytest-compatible integration test templates.""" + + def __init__( + self, + spec_parser: SpecParser, + output_dir: str = "tests", + mock_server_url: str = "http://localhost:4010", + ): + """Initialize the pytest generator. + + Args: + spec_parser: The OpenAPI specification parser. + output_dir: Directory for generated test files. + mock_server_url: URL of the mock server for testing. + """ + self.spec_parser = spec_parser + self.output_dir = Path(output_dir) + self.mock_server_url = mock_server_url + self.env = Environment( + loader=FileSystemLoader(str(Path(__file__).parent.parent.parent / "templates" / "pytest")), + trim_blocks=True, + lstrip_blocks=True, + ) + + def generate(self, output_file: Optional[str] = None) -> List[Path]: + """Generate pytest test files. + + Args: + output_file: Optional specific output file path. + + Returns: + List of generated file paths. + """ + self.output_dir.mkdir(parents=True, exist_ok=True) + + endpoints = self.spec_parser.get_endpoints() + info = self.spec_parser.get_info() + + context = { + "api_title": info["title"], + "api_version": info["version"], + "endpoints": endpoints, + "mock_server_url": self.mock_server_url, + "security_schemes": self.spec_parser.get_security_schemes(), + "definitions": self.spec_parser.get_definitions(), + } + + generated_files = [] + + try: + template = self.env.get_template("test_base.py.j2") + content = template.render(context) + + if output_file: + output_path = Path(output_file) + else: + safe_name = re.sub(r"[^a-zA-Z0-9_]", "_", info["title"].lower()) + output_path = self.output_dir / f"test_{safe_name}.py" + + output_path.parent.mkdir(parents=True, exist_ok=True) + output_path.write_text(content) + generated_files.append(output_path) + + except (TemplateSyntaxError, UndefinedError) as e: + raise TemplateRenderError(f"Failed to render pytest template: {e}") + + return generated_files + + def generate_endpoint_tests(self, endpoint_path: str, method: str) -> str: + """Generate test for a specific endpoint. + + Args: + endpoint_path: The API endpoint path. + method: The HTTP method. + + Returns: + String containing the test code. + """ + endpoints = self.spec_parser.get_endpoints() + + for endpoint in endpoints: + if endpoint["path"] == endpoint_path and endpoint["method"] == method.lower(): + return self._generate_single_test(endpoint) + + return "" + + def _generate_single_test(self, endpoint: Dict[str, Any]) -> str: + """Generate test code for a single endpoint. + + Args: + endpoint: The endpoint dictionary. + + Returns: + String containing the test code. + """ + test_name = self._generate_test_name(endpoint) + params = self._generate_parameters(endpoint) + auth_headers = self._generate_auth_headers(endpoint) + + test_code = f''' +def test_{test_name}(base_url, {params}): + """Test {endpoint["summary"] or endpoint["path"]} endpoint.""" + url = f"{base_url}{endpoint['path']}" + + headers = {{"Content-Type": "application/json"}} + {auth_headers} + + response = requests.{endpoint['method']}(url, json={{}} if method == "POST" else None, headers=headers) + + assert response.status_code in [200, 201, 204] +''' + return test_code + + def _generate_test_name(self, endpoint: Dict[str, Any]) -> str: + """Generate a valid test function name. + + Args: + endpoint: The endpoint dictionary. + + Returns: + A valid Python function name. + """ + path = endpoint["path"] + method = endpoint["method"] + + name = re.sub(r"[^a-zA-Z0-9]", "_", path.strip("/")) + name = re.sub(r"_+/", "_", name) + name = re.sub(r"_+$", "", name) + + return f"{method}_{name}" if name else f"{method}_default" + + def _generate_parameters(self, endpoint: Dict[str, Any]) -> str: + """Generate parameters for test function. + + Args: + endpoint: The endpoint dictionary. + + Returns: + String containing parameter declarations. + """ + params = [] + + for param in endpoint.get("parameters", []): + param_name = param["name"] + + if param["in"] == "path": + params.append(f'{param_name}="test_{param_name}"') + + elif param["in"] == "query": + params.append(f'{param_name}=None') + + return ", ".join(params) + + def _generate_auth_headers(self, endpoint: Dict[str, Any]) -> str: + """Generate authentication headers for endpoint. + + Args: + endpoint: The endpoint dictionary. + + Returns: + String containing header assignments. + """ + security_requirements = endpoint.get("security", []) + + if not security_requirements: + return "" + + schemes = self.spec_parser.get_security_schemes() + auth_config = AuthConfig() + + try: + auth_config.build_from_spec(schemes, security_requirements) + except Exception: + return "" + + headers = [] + + for scheme_name in security_requirements[0].keys(): + method = auth_config.get_auth_method(scheme_name) + if method: + if method["type"] == "apiKey": + headers.append(f'headers["{method["header_name"]}"] = "test_api_key"') + elif method["type"] == "bearer": + headers.append('headers["Authorization"] = "Bearer test_token"') + + return "\n ".join(headers) if headers else ""