Add generators and mock server modules
This commit is contained in:
229
api_testgen/generators/go.py
Normal file
229
api_testgen/generators/go.py
Normal file
@@ -0,0 +1,229 @@
|
||||
"""Go test generator."""
|
||||
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from jinja2 import Environment, FileSystemLoader, TemplateSyntaxError, UndefinedError
|
||||
|
||||
from ..core import SpecParser, AuthConfig
|
||||
from ..core.exceptions import GeneratorError, TemplateRenderError
|
||||
|
||||
|
||||
class GoGenerator:
|
||||
"""Generate Go-compatible test files."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
spec_parser: SpecParser,
|
||||
output_dir: str = "tests",
|
||||
mock_server_url: str = "http://localhost:4010",
|
||||
package_name: str = "apitest",
|
||||
):
|
||||
"""Initialize the go generator.
|
||||
|
||||
Args:
|
||||
spec_parser: The OpenAPI specification parser.
|
||||
output_dir: Directory for generated test files.
|
||||
mock_server_url: URL of the mock server for testing.
|
||||
package_name: Go package name for test files.
|
||||
"""
|
||||
self.spec_parser = spec_parser
|
||||
self.output_dir = Path(output_dir)
|
||||
self.mock_server_url = mock_server_url
|
||||
self.package_name = package_name
|
||||
self.env = Environment(
|
||||
loader=FileSystemLoader(str(Path(__file__).parent.parent.parent / "templates" / "go")),
|
||||
trim_blocks=True,
|
||||
lstrip_blocks=True,
|
||||
)
|
||||
|
||||
def generate(self, output_file: Optional[str] = None) -> List[Path]:
|
||||
"""Generate Go test files.
|
||||
|
||||
Args:
|
||||
output_file: Optional specific output file path.
|
||||
|
||||
Returns:
|
||||
List of generated file paths.
|
||||
"""
|
||||
self.output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
endpoints = self.spec_parser.get_endpoints()
|
||||
info = self.spec_parser.get_info()
|
||||
|
||||
grouped_endpoints = self._group_endpoints_by_path(endpoints)
|
||||
|
||||
context = {
|
||||
"package_name": self.package_name,
|
||||
"api_title": info["title"],
|
||||
"api_version": info["version"],
|
||||
"endpoints": endpoints,
|
||||
"grouped_endpoints": grouped_endpoints,
|
||||
"mock_server_url": self.mock_server_url,
|
||||
"security_schemes": self.spec_parser.get_security_schemes(),
|
||||
"definitions": self.spec_parser.get_definitions(),
|
||||
}
|
||||
|
||||
generated_files = []
|
||||
|
||||
try:
|
||||
template = self.env.get_template("api_test.go.j2")
|
||||
content = template.render(context)
|
||||
|
||||
if output_file:
|
||||
output_path = Path(output_file)
|
||||
else:
|
||||
safe_name = re.sub(r"[^a-zA-Z0-9_]", "_", info["title"].lower())
|
||||
output_path = self.output_dir / f"{safe_name}_test.go"
|
||||
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
output_path.write_text(content)
|
||||
generated_files.append(output_path)
|
||||
|
||||
except (TemplateSyntaxError, UndefinedError) as e:
|
||||
raise TemplateRenderError(f"Failed to render Go template: {e}")
|
||||
|
||||
return generated_files
|
||||
|
||||
def _group_endpoints_by_path(self, endpoints: List[Dict[str, Any]]) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""Group endpoints by their path.
|
||||
|
||||
Args:
|
||||
endpoints: List of endpoint dictionaries.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping paths to their endpoints.
|
||||
"""
|
||||
grouped = {}
|
||||
for endpoint in endpoints:
|
||||
path = endpoint["path"]
|
||||
if path not in grouped:
|
||||
grouped[path] = []
|
||||
grouped[path].append(endpoint)
|
||||
return grouped
|
||||
|
||||
def generate_endpoint_tests(self, endpoint_path: str, method: str) -> str:
|
||||
"""Generate test for a specific endpoint.
|
||||
|
||||
Args:
|
||||
endpoint_path: The API endpoint path.
|
||||
method: The HTTP method.
|
||||
|
||||
Returns:
|
||||
String containing the test code.
|
||||
"""
|
||||
endpoints = self.spec_parser.get_endpoints()
|
||||
|
||||
for endpoint in endpoints:
|
||||
if endpoint["path"] == endpoint_path and endpoint["method"] == method.lower():
|
||||
return self._generate_single_test(endpoint)
|
||||
|
||||
return ""
|
||||
|
||||
def _generate_single_test(self, endpoint: Dict[str, Any]) -> str:
|
||||
"""Generate test code for a single endpoint.
|
||||
|
||||
Args:
|
||||
endpoint: The endpoint dictionary.
|
||||
|
||||
Returns:
|
||||
String containing the test code.
|
||||
"""
|
||||
test_name = self._generate_test_name(endpoint)
|
||||
params = self._generate_params(endpoint)
|
||||
url_params = self._generate_url_params(endpoint)
|
||||
|
||||
test_code = f'''
|
||||
func Test{test_name}(t *testing.T) {{
|
||||
client := &http.Client{{Timeout: 10 * time.Second}}
|
||||
url := baseURL + "{endpoint['path']}"
|
||||
|
||||
var req *http.Request
|
||||
var err error
|
||||
|
||||
{params}
|
||||
|
||||
req, err = http.NewRequest("{endpoint['method'].upper()}", url, nil)
|
||||
if err != nil {{
|
||||
t.Fatal(err)
|
||||
}}
|
||||
|
||||
for k, v := range getAuthHeaders() {{
|
||||
req.Header.Set(k, v)
|
||||
}}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {{
|
||||
t.Fatalf("Request failed: %v", err)
|
||||
}}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if !contains([]int{{200, 201, 204}}, resp.StatusCode) {{
|
||||
t.Errorf("Expected status code in [200, 201, 204], got %d", resp.StatusCode)
|
||||
}}
|
||||
}}
|
||||
'''
|
||||
return test_code
|
||||
|
||||
def _generate_test_name(self, endpoint: Dict[str, Any]) -> str:
|
||||
"""Generate a valid Go test function name.
|
||||
|
||||
Args:
|
||||
endpoint: The endpoint dictionary.
|
||||
|
||||
Returns:
|
||||
A valid Go function name.
|
||||
"""
|
||||
path = endpoint["path"]
|
||||
method = endpoint["method"]
|
||||
|
||||
name = re.sub(r"[^a-zA-Z0-9]", "_", path.strip("/"))
|
||||
name = re.sub(r"_+/", "_", name)
|
||||
name = re.sub(r"_+$", "", name)
|
||||
name = name.title().replace("_", "")
|
||||
|
||||
return f"{method.capitalize()}{name}" if name else f"{method.capitalize()}Default"
|
||||
|
||||
def _generate_params(self, endpoint: Dict[str, Any]) -> str:
|
||||
"""Generate parameter variables for test.
|
||||
|
||||
Args:
|
||||
endpoint: The endpoint dictionary.
|
||||
|
||||
Returns:
|
||||
String containing parameter declarations.
|
||||
"""
|
||||
params = []
|
||||
|
||||
for param in endpoint.get("parameters", []):
|
||||
param_name = param["name"]
|
||||
|
||||
if param["in"] == "path":
|
||||
params.append(f'{param_name} := "test_{param_name}"')
|
||||
params.append(f'url = strings.Replace(url, "{{"+param_name+"}}", {param_name}, 1)')
|
||||
|
||||
elif param["in"] == "query":
|
||||
params.append(f'q := url.Values{{{param_name}: []string{{"test"}}}}')
|
||||
params.append(f'url += "?" + q.Encode()')
|
||||
|
||||
return "\n ".join(params) if params else ""
|
||||
|
||||
def _generate_url_params(self, endpoint: Dict[str, Any]) -> str:
|
||||
"""Generate URL parameters.
|
||||
|
||||
Args:
|
||||
endpoint: The endpoint dictionary.
|
||||
|
||||
Returns:
|
||||
String containing URL parameter handling.
|
||||
"""
|
||||
path_params = [p for p in endpoint.get("parameters", []) if p["in"] == "path"]
|
||||
query_params = [p for p in endpoint.get("parameters", []) if p["in"] == "query"]
|
||||
|
||||
parts = []
|
||||
|
||||
for param in path_params:
|
||||
parts.append(f'strings.Replace(url, "{{"+param["name"]+"}}", "test_' + param['name'] + '", 1)')
|
||||
|
||||
return ""
|
||||
Reference in New Issue
Block a user