"""Go test generator.""" import re from pathlib import Path from typing import Any, Dict, List, Optional from jinja2 import Environment, FileSystemLoader, TemplateSyntaxError, UndefinedError from ..core import SpecParser, AuthConfig from ..core.exceptions import GeneratorError, TemplateRenderError class GoGenerator: """Generate Go-compatible test files.""" def __init__( self, spec_parser: SpecParser, output_dir: str = "tests", mock_server_url: str = "http://localhost:4010", package_name: str = "apitest", ): """Initialize the go generator. Args: spec_parser: The OpenAPI specification parser. output_dir: Directory for generated test files. mock_server_url: URL of the mock server for testing. package_name: Go package name for test files. """ self.spec_parser = spec_parser self.output_dir = Path(output_dir) self.mock_server_url = mock_server_url self.package_name = package_name self.env = Environment( loader=FileSystemLoader(str(Path(__file__).parent.parent.parent / "templates" / "go")), trim_blocks=True, lstrip_blocks=True, ) def generate(self, output_file: Optional[str] = None) -> List[Path]: """Generate Go test files. Args: output_file: Optional specific output file path. Returns: List of generated file paths. """ self.output_dir.mkdir(parents=True, exist_ok=True) endpoints = self.spec_parser.get_endpoints() info = self.spec_parser.get_info() grouped_endpoints = self._group_endpoints_by_path(endpoints) context = { "package_name": self.package_name, "api_title": info["title"], "api_version": info["version"], "endpoints": endpoints, "grouped_endpoints": grouped_endpoints, "mock_server_url": self.mock_server_url, "security_schemes": self.spec_parser.get_security_schemes(), "definitions": self.spec_parser.get_definitions(), } generated_files = [] try: template = self.env.get_template("api_test.go.j2") content = template.render(context) if output_file: output_path = Path(output_file) else: safe_name = re.sub(r"[^a-zA-Z0-9_]", "_", info["title"].lower()) output_path = self.output_dir / f"{safe_name}_test.go" output_path.parent.mkdir(parents=True, exist_ok=True) output_path.write_text(content) generated_files.append(output_path) except (TemplateSyntaxError, UndefinedError) as e: raise TemplateRenderError(f"Failed to render Go template: {e}") return generated_files def _group_endpoints_by_path(self, endpoints: List[Dict[str, Any]]) -> Dict[str, List[Dict[str, Any]]]: """Group endpoints by their path. Args: endpoints: List of endpoint dictionaries. Returns: Dictionary mapping paths to their endpoints. """ grouped = {} for endpoint in endpoints: path = endpoint["path"] if path not in grouped: grouped[path] = [] grouped[path].append(endpoint) return grouped def generate_endpoint_tests(self, endpoint_path: str, method: str) -> str: """Generate test for a specific endpoint. Args: endpoint_path: The API endpoint path. method: The HTTP method. Returns: String containing the test code. """ endpoints = self.spec_parser.get_endpoints() for endpoint in endpoints: if endpoint["path"] == endpoint_path and endpoint["method"] == method.lower(): return self._generate_single_test(endpoint) return "" def _generate_single_test(self, endpoint: Dict[str, Any]) -> str: """Generate test code for a single endpoint. Args: endpoint: The endpoint dictionary. Returns: String containing the test code. """ test_name = self._generate_test_name(endpoint) params = self._generate_params(endpoint) url_params = self._generate_url_params(endpoint) test_code = f''' func Test{test_name}(t *testing.T) {{ client := &http.Client{{Timeout: 10 * time.Second}} url := baseURL + "{endpoint['path']}" var req *http.Request var err error {params} req, err = http.NewRequest("{endpoint['method'].upper()}", url, nil) if err != nil {{ t.Fatal(err) }} for k, v := range getAuthHeaders() {{ req.Header.Set(k, v) }} resp, err := client.Do(req) if err != nil {{ t.Fatalf("Request failed: %v", err) }} defer resp.Body.Close() if !contains([]int{{200, 201, 204}}, resp.StatusCode) {{ t.Errorf("Expected status code in [200, 201, 204], got %d", resp.StatusCode) }} }} ''' return test_code def _generate_test_name(self, endpoint: Dict[str, Any]) -> str: """Generate a valid Go test function name. Args: endpoint: The endpoint dictionary. Returns: A valid Go function name. """ path = endpoint["path"] method = endpoint["method"] name = re.sub(r"[^a-zA-Z0-9]", "_", path.strip("/")) name = re.sub(r"_+/", "_", name) name = re.sub(r"_+$", "", name) name = name.title().replace("_", "") return f"{method.capitalize()}{name}" if name else f"{method.capitalize()}Default" def _generate_params(self, endpoint: Dict[str, Any]) -> str: """Generate parameter variables for test. Args: endpoint: The endpoint dictionary. Returns: String containing parameter declarations. """ params = [] for param in endpoint.get("parameters", []): param_name = param["name"] if param["in"] == "path": params.append(f'{param_name} := "test_{param_name}"') params.append(f'url = strings.Replace(url, "{{"+param_name+"}}", {param_name}, 1)') elif param["in"] == "query": params.append(f'q := url.Values{{{param_name}: []string{{"test"}}}}') params.append(f'url += "?" + q.Encode()') return "\n ".join(params) if params else "" def _generate_url_params(self, endpoint: Dict[str, Any]) -> str: """Generate URL parameters. Args: endpoint: The endpoint dictionary. Returns: String containing URL parameter handling. """ path_params = [p for p in endpoint.get("parameters", []) if p["in"] == "path"] query_params = [p for p in endpoint.get("parameters", []) if p["in"] == "query"] parts = [] for param in path_params: parts.append(f'strings.Replace(url, "{{"+param["name"]+"}}", "test_' + param['name'] + '", 1)') return ""