From 6a284941ebabd96c6a6b39cc7d512ddcef10ea05 Mon Sep 17 00:00:00 2001 From: 7000pctAUTO Date: Fri, 6 Feb 2026 04:49:19 +0000 Subject: [PATCH] Add generators and mock server modules --- api_testgen/generators/go.py | 229 +++++++++++++++++++++++++++++++++++ 1 file changed, 229 insertions(+) create mode 100644 api_testgen/generators/go.py diff --git a/api_testgen/generators/go.py b/api_testgen/generators/go.py new file mode 100644 index 0000000..d10f481 --- /dev/null +++ b/api_testgen/generators/go.py @@ -0,0 +1,229 @@ +"""Go test generator.""" + +import re +from pathlib import Path +from typing import Any, Dict, List, Optional + +from jinja2 import Environment, FileSystemLoader, TemplateSyntaxError, UndefinedError + +from ..core import SpecParser, AuthConfig +from ..core.exceptions import GeneratorError, TemplateRenderError + + +class GoGenerator: + """Generate Go-compatible test files.""" + + def __init__( + self, + spec_parser: SpecParser, + output_dir: str = "tests", + mock_server_url: str = "http://localhost:4010", + package_name: str = "apitest", + ): + """Initialize the go generator. + + Args: + spec_parser: The OpenAPI specification parser. + output_dir: Directory for generated test files. + mock_server_url: URL of the mock server for testing. + package_name: Go package name for test files. + """ + self.spec_parser = spec_parser + self.output_dir = Path(output_dir) + self.mock_server_url = mock_server_url + self.package_name = package_name + self.env = Environment( + loader=FileSystemLoader(str(Path(__file__).parent.parent.parent / "templates" / "go")), + trim_blocks=True, + lstrip_blocks=True, + ) + + def generate(self, output_file: Optional[str] = None) -> List[Path]: + """Generate Go test files. + + Args: + output_file: Optional specific output file path. + + Returns: + List of generated file paths. + """ + self.output_dir.mkdir(parents=True, exist_ok=True) + + endpoints = self.spec_parser.get_endpoints() + info = self.spec_parser.get_info() + + grouped_endpoints = self._group_endpoints_by_path(endpoints) + + context = { + "package_name": self.package_name, + "api_title": info["title"], + "api_version": info["version"], + "endpoints": endpoints, + "grouped_endpoints": grouped_endpoints, + "mock_server_url": self.mock_server_url, + "security_schemes": self.spec_parser.get_security_schemes(), + "definitions": self.spec_parser.get_definitions(), + } + + generated_files = [] + + try: + template = self.env.get_template("api_test.go.j2") + content = template.render(context) + + if output_file: + output_path = Path(output_file) + else: + safe_name = re.sub(r"[^a-zA-Z0-9_]", "_", info["title"].lower()) + output_path = self.output_dir / f"{safe_name}_test.go" + + output_path.parent.mkdir(parents=True, exist_ok=True) + output_path.write_text(content) + generated_files.append(output_path) + + except (TemplateSyntaxError, UndefinedError) as e: + raise TemplateRenderError(f"Failed to render Go template: {e}") + + return generated_files + + def _group_endpoints_by_path(self, endpoints: List[Dict[str, Any]]) -> Dict[str, List[Dict[str, Any]]]: + """Group endpoints by their path. + + Args: + endpoints: List of endpoint dictionaries. + + Returns: + Dictionary mapping paths to their endpoints. + """ + grouped = {} + for endpoint in endpoints: + path = endpoint["path"] + if path not in grouped: + grouped[path] = [] + grouped[path].append(endpoint) + return grouped + + def generate_endpoint_tests(self, endpoint_path: str, method: str) -> str: + """Generate test for a specific endpoint. + + Args: + endpoint_path: The API endpoint path. + method: The HTTP method. + + Returns: + String containing the test code. + """ + endpoints = self.spec_parser.get_endpoints() + + for endpoint in endpoints: + if endpoint["path"] == endpoint_path and endpoint["method"] == method.lower(): + return self._generate_single_test(endpoint) + + return "" + + def _generate_single_test(self, endpoint: Dict[str, Any]) -> str: + """Generate test code for a single endpoint. + + Args: + endpoint: The endpoint dictionary. + + Returns: + String containing the test code. + """ + test_name = self._generate_test_name(endpoint) + params = self._generate_params(endpoint) + url_params = self._generate_url_params(endpoint) + + test_code = f''' +func Test{test_name}(t *testing.T) {{ + client := &http.Client{{Timeout: 10 * time.Second}} + url := baseURL + "{endpoint['path']}" + + var req *http.Request + var err error + + {params} + + req, err = http.NewRequest("{endpoint['method'].upper()}", url, nil) + if err != nil {{ + t.Fatal(err) + }} + + for k, v := range getAuthHeaders() {{ + req.Header.Set(k, v) + }} + + resp, err := client.Do(req) + if err != nil {{ + t.Fatalf("Request failed: %v", err) + }} + defer resp.Body.Close() + + if !contains([]int{{200, 201, 204}}, resp.StatusCode) {{ + t.Errorf("Expected status code in [200, 201, 204], got %d", resp.StatusCode) + }} +}} +''' + return test_code + + def _generate_test_name(self, endpoint: Dict[str, Any]) -> str: + """Generate a valid Go test function name. + + Args: + endpoint: The endpoint dictionary. + + Returns: + A valid Go function name. + """ + path = endpoint["path"] + method = endpoint["method"] + + name = re.sub(r"[^a-zA-Z0-9]", "_", path.strip("/")) + name = re.sub(r"_+/", "_", name) + name = re.sub(r"_+$", "", name) + name = name.title().replace("_", "") + + return f"{method.capitalize()}{name}" if name else f"{method.capitalize()}Default" + + def _generate_params(self, endpoint: Dict[str, Any]) -> str: + """Generate parameter variables for test. + + Args: + endpoint: The endpoint dictionary. + + Returns: + String containing parameter declarations. + """ + params = [] + + for param in endpoint.get("parameters", []): + param_name = param["name"] + + if param["in"] == "path": + params.append(f'{param_name} := "test_{param_name}"') + params.append(f'url = strings.Replace(url, "{{"+param_name+"}}", {param_name}, 1)') + + elif param["in"] == "query": + params.append(f'q := url.Values{{{param_name}: []string{{"test"}}}}') + params.append(f'url += "?" + q.Encode()') + + return "\n ".join(params) if params else "" + + def _generate_url_params(self, endpoint: Dict[str, Any]) -> str: + """Generate URL parameters. + + Args: + endpoint: The endpoint dictionary. + + Returns: + String containing URL parameter handling. + """ + path_params = [p for p in endpoint.get("parameters", []) if p["in"] == "path"] + query_params = [p for p in endpoint.get("parameters", []) if p["in"] == "query"] + + parts = [] + + for param in path_params: + parts.append(f'strings.Replace(url, "{{"+param["name"]+"}}", "test_' + param['name'] + '", 1)') + + return ""