Files
api-testgen-cli/api_testgen/generators/jest.py
CI Bot d369d3b1f8
Some checks failed
CI / test (3.10) (push) Failing after 1m21s
CI / test (3.11) (push) Failing after 1m19s
CI / test (3.9) (push) Failing after 1m22s
CI / lint (push) Failing after 43s
fix: Apply black formatting to resolve CI formatting issues
2026-02-06 07:56:02 +00:00

171 lines
5.0 KiB
Python

"""Jest test generator."""
import re
from pathlib import Path
from typing import Any, Dict, List, Optional
from jinja2 import Environment, FileSystemLoader, TemplateSyntaxError, UndefinedError
from ..core import SpecParser
from ..core.exceptions import TemplateRenderError
class JestGenerator:
    """Generate Jest-compatible integration test templates.

    Whole-suite files are rendered from the ``api.test.js.j2`` Jinja2 template
    (see :meth:`generate`); single-endpoint snippets are built directly in
    Python (see :meth:`generate_endpoint_tests`).
    """

    def __init__(
        self,
        spec_parser: SpecParser,
        output_dir: str = "tests",
        mock_server_url: str = "http://localhost:4010",
    ):
        """Initialize the Jest generator.

        Args:
            spec_parser: The OpenAPI specification parser.
            output_dir: Directory for generated test files.
            mock_server_url: URL of the mock server for testing.
        """
        self.spec_parser = spec_parser
        self.output_dir = Path(output_dir)
        self.mock_server_url = mock_server_url
        # Three ``.parent`` hops climb from this module past the package
        # directory, so templates are loaded from ``<project>/templates/jest``.
        # NOTE(review): that directory sits outside the package; confirm it is
        # shipped when the project is installed rather than run from a checkout.
        self.env = Environment(
            loader=FileSystemLoader(
                str(Path(__file__).parent.parent.parent / "templates" / "jest")
            ),
            trim_blocks=True,
            lstrip_blocks=True,
        )

    def generate(self, output_file: Optional[str] = None) -> List[Path]:
        """Render the Jest test suite for the whole specification.

        Args:
            output_file: Optional explicit output path. When omitted, the file
                is written into ``output_dir`` and named after the API title.

        Returns:
            List containing the path of the generated file.

        Raises:
            TemplateRenderError: If the Jest template has a syntax error or
                references an undefined value while rendering.
        """
        self.output_dir.mkdir(parents=True, exist_ok=True)
        endpoints = self.spec_parser.get_endpoints()
        info = self.spec_parser.get_info()
        context = {
            "api_title": info["title"],
            "api_version": info["version"],
            "endpoints": endpoints,
            "mock_server_url": self.mock_server_url,
            "security_schemes": self.spec_parser.get_security_schemes(),
            "definitions": self.spec_parser.get_definitions(),
        }

        # Only template loading/rendering is guarded; file-system errors below
        # propagate unchanged, exactly as before.
        try:
            template = self.env.get_template("api.test.js.j2")
            content = template.render(context)
        except (TemplateSyntaxError, UndefinedError) as e:
            # Chain the Jinja2 error so its traceback is not lost.
            raise TemplateRenderError(f"Failed to render Jest template: {e}") from e

        if output_file:
            output_path = Path(output_file)
        else:
            # Replace anything outside [a-zA-Z0-9_] with "_"; fall back to
            # "api" so an empty title cannot produce a bare ".test.js" file.
            safe_name = re.sub(r"[^a-zA-Z0-9_]", "_", info["title"].lower())
            output_path = self.output_dir / f"{safe_name or 'api'}.test.js"

        output_path.parent.mkdir(parents=True, exist_ok=True)
        # Explicit UTF-8: the default encoding is locale-dependent and may be
        # unable to encode non-ASCII titles/summaries in the rendered JS.
        output_path.write_text(content, encoding="utf-8")
        return [output_path]

    def generate_endpoint_tests(self, endpoint_path: str, method: str) -> str:
        """Generate a Jest snippet for one endpoint of the specification.

        Args:
            endpoint_path: The API endpoint path exactly as it appears in the
                spec (e.g. ``/users/{id}``).
            method: The HTTP method; matched case-insensitively.

        Returns:
            The test code, or an empty string if no endpoint matches.
        """
        wanted_method = method.lower()
        for endpoint in self.spec_parser.get_endpoints():
            if (
                endpoint["path"] == endpoint_path
                and endpoint["method"].lower() == wanted_method
            ):
                return self._generate_single_test(endpoint)
        return ""

    def _generate_single_test(self, endpoint: Dict[str, Any]) -> str:
        """Generate test code for a single endpoint.

        The snippet calls ``request(baseUrl)``, so it presumably relies on the
        surrounding test file defining ``request`` (supertest) and ``baseUrl``
        — TODO confirm against the Jest template.

        Path parameters are substituted into the URL and query parameters are
        sent via ``.query(...)``. Every dynamic value is emitted as an escaped
        JavaScript string literal, so e.g. a summary containing an apostrophe
        cannot break the generated file.

        Args:
            endpoint: Endpoint dictionary with ``path`` and ``method`` keys and
                optional ``summary`` and ``parameters`` keys.

        Returns:
            String containing the test code.
        """
        endpoint_path = endpoint["path"]
        # supertest exposes lowercase verb methods (.get, .post, ...).
        endpoint_method = endpoint["method"].lower()
        describe_name = endpoint.get("summary") or endpoint_path

        describe_literal = self._js_string(describe_name)
        it_literal = self._js_string(
            f"should {endpoint_method.upper()} {endpoint_path}"
        )
        path_literal = self._js_string(self._build_request_path(endpoint))
        query_chain = self._generate_params(endpoint)

        return f"""
describe({describe_literal}, () => {{
  it({it_literal}, async () => {{
    const response = await request(baseUrl)
      .{endpoint_method}({path_literal}){query_chain};
    expect([200, 201, 204]).toContain(response.status);
  }});
}});
"""

    def _generate_test_name(self, endpoint: Dict[str, Any]) -> str:
        """Generate a valid test function name.

        Each run of non-alphanumeric characters in the path collapses to a
        single underscore and leading/trailing underscores are dropped, so an
        endpoint with method ``get`` and path ``/users/{id}`` yields
        ``get_users_id``. A path with no alphanumerics yields
        ``<method>_default``.

        Args:
            endpoint: The endpoint dictionary.

        Returns:
            A valid JavaScript function name.
        """
        # The "+" quantifier does the collapsing in one pass.
        name = re.sub(r"[^a-zA-Z0-9]+", "_", endpoint["path"]).strip("_")
        method = endpoint["method"]
        return f"{method}_{name}" if name else f"{method}_default"

    def _generate_params(self, endpoint: Dict[str, Any]) -> str:
        """Build the request-chain fragment carrying the query parameters.

        Path parameters are not handled here; ``_build_request_path``
        substitutes them into the URL.

        Args:
            endpoint: The endpoint dictionary.

        Returns:
            A fragment such as ``.query({ 'limit': 'test_limit' })`` to append
            after the verb call, or ``""`` when there are no query parameters.
        """
        pairs = [
            f"{self._js_string(param['name'])}: "
            f"{self._js_string('test_' + param['name'])}"
            for param in endpoint.get("parameters") or []
            if param.get("in") == "query"
        ]
        if not pairs:
            return ""
        return ".query({ " + ", ".join(pairs) + " })"

    def _build_request_path(self, endpoint: Dict[str, Any]) -> str:
        """Return the endpoint path with declared path parameters filled in.

        Each ``{name}`` placeholder for a parameter declared ``in: path`` is
        replaced with ``test_<name>``. Undeclared placeholders are left as-is.

        Args:
            endpoint: The endpoint dictionary.

        Returns:
            The concrete request path.
        """
        path = endpoint["path"]
        for param in endpoint.get("parameters") or []:
            if param.get("in") == "path":
                name = param["name"]
                path = path.replace("{" + name + "}", f"test_{name}")
        return path

    @staticmethod
    def _js_string(value: Any) -> str:
        """Return ``value`` as a single-quoted JavaScript string literal.

        Backslashes, single quotes, and line breaks are escaped so the literal
        stays valid whatever text the spec contains.
        """
        escaped = (
            str(value)
            .replace("\\", "\\\\")
            .replace("'", "\\'")
            .replace("\n", "\\n")
            .replace("\r", "\\r")
        )
        return f"'{escaped}'"