fix: add Gitea Actions CI workflow for automated testing

This commit is contained in:
CI Bot
2026-02-06 06:37:08 +00:00
parent 40a6a4f7d4
commit 839317c44b
24 changed files with 3115 additions and 0 deletions

3
api_testgen/__init__.py Normal file
View File

@@ -0,0 +1,3 @@
"""API TestGen - OpenAPI Specification Test Generator."""
__version__ = "0.1.0"

View File

@@ -0,0 +1 @@
"""CLI module for API TestGen."""

277
api_testgen/cli/main.py Normal file
View File

@@ -0,0 +1,277 @@
"""CLI interface for API TestGen."""
from pathlib import Path
from typing import Optional
import click
import yaml
from ..core import SpecParser, AuthConfig
from ..core.exceptions import InvalidOpenAPISpecError, UnsupportedVersionError
from ..generators import PytestGenerator, JestGenerator, GoGenerator
from ..mocks import MockServerGenerator
@click.group()
@click.version_option(version="0.1.0")
@click.option(
"--spec",
"-s",
type=click.Path(exists=True, file_okay=True, dir_okay=False),
help="Path to OpenAPI specification file",
)
@click.option(
"--output",
"-o",
type=click.Path(file_okay=False, dir_okay=True),
help="Output directory for generated files",
)
@click.option(
"--mock-url",
default="http://localhost:4010",
help="URL of the mock server",
)
@click.pass_context
def main(
ctx: click.Context,
spec: str,
output: str,
mock_url: str,
):
"""API TestGen - Generate integration tests from OpenAPI specifications."""
ctx.ensure_object(dict)
ctx.obj["spec"] = spec
ctx.obj["output"] = output or "./generated"
ctx.obj["mock_url"] = mock_url
@main.command("parse")
@click.pass_context
def parse_spec(ctx: click.Context):
"""Parse and validate an OpenAPI specification."""
spec_path = ctx.obj["spec"]
if not spec_path:
click.echo("Error: --spec option is required", err=True)
raise click.Abort()
try:
parser = SpecParser(spec_path)
spec = parser.load()
info = parser.get_info()
endpoints = parser.get_endpoints()
security_schemes = parser.get_security_schemes()
click.echo(f"API: {info['title']} v{info['version']}")
click.echo(f"OpenAPI Version: {parser.version}")
click.echo(f"Base Path: {parser.base_path}")
click.echo(f"Endpoints: {len(endpoints)}")
click.echo(f"Security Schemes: {len(security_schemes)}")
click.echo()
for endpoint in endpoints:
click.echo(f" {endpoint['method'].upper():6} {endpoint['path']}")
ctx.obj["parser"] = parser
except InvalidOpenAPISpecError as e:
click.echo(f"Error: {e}", err=True)
raise click.Abort()
except UnsupportedVersionError as e:
click.echo(f"Error: {e}", err=True)
raise click.Abort()
@main.command("generate")
@click.argument("framework", type=click.Choice(["pytest", "jest", "go"]))
@click.option(
"--output-file",
"-f",
type=click.Path(file_okay=True, dir_okay=False),
help="Specific output file path",
)
@click.option(
"--package-name",
default="apitest",
help="Go package name (only for go framework)",
)
@click.pass_context
def generate_tests(
ctx: click.Context,
framework: str,
output_file: str,
package_name: str,
):
"""Generate test files for a framework (pytest, jest, or go)."""
spec_path = ctx.obj["spec"]
if not spec_path:
click.echo("Error: --spec option is required", err=True)
raise click.Abort()
output_dir = ctx.obj["output"]
mock_url = ctx.obj["mock_url"]
try:
parser = SpecParser(spec_path)
parser.load()
if framework == "pytest":
generator = PytestGenerator(parser, output_dir, mock_url)
files = generator.generate(output_file)
elif framework == "jest":
generator = JestGenerator(parser, output_dir, mock_url)
files = generator.generate(output_file)
elif framework == "go":
generator = GoGenerator(parser, output_dir, mock_url, package_name)
files = generator.generate(output_file)
click.echo(f"Generated {len(files)} test file(s):")
for f in files:
click.echo(f" - {f}")
except Exception as e:
click.echo(f"Error: {e}", err=True)
raise click.Abort()
@main.command("mock")
@click.option(
"--no-prism-config",
is_flag=True,
help="Skip generating prism-config.json",
)
@click.option(
"--no-docker-compose",
is_flag=True,
help="Skip generating docker-compose.yml",
)
@click.option(
"--no-dockerfile",
is_flag=True,
help="Skip generating Dockerfile",
)
@click.pass_context
def generate_mock(
ctx: click.Context,
no_prism_config: bool,
no_docker_compose: bool,
no_dockerfile: bool,
):
"""Generate mock server configuration files."""
spec_path = ctx.obj["spec"]
if not spec_path:
click.echo("Error: --spec option is required", err=True)
raise click.Abort()
output_dir = ctx.obj["output"]
try:
parser = SpecParser(spec_path)
parser.load()
generator = MockServerGenerator(parser, output_dir)
files = generator.generate(
prism_config=not no_prism_config,
docker_compose=not no_docker_compose,
dockerfile=not no_dockerfile,
)
click.echo(f"Generated {len(files)} mock server file(s):")
for f in files:
click.echo(f" - {f}")
except Exception as e:
click.echo(f"Error: {e}", err=True)
raise click.Abort()
@main.command("all")
@click.argument("framework", type=click.Choice(["pytest", "jest", "go"]))
@click.pass_context
def generate_all(
ctx: click.Context,
framework: str,
):
"""Generate test files and mock server configuration."""
spec_path = ctx.obj["spec"]
if not spec_path:
click.echo("Error: --spec option is required", err=True)
raise click.Abort()
output_dir = ctx.obj["output"]
mock_url = ctx.obj["mock_url"]
try:
parser = SpecParser(spec_path)
parser.load()
click.echo("Generating tests...")
if framework == "pytest":
generator = PytestGenerator(parser, output_dir, mock_url)
files = generator.generate()
elif framework == "jest":
generator = JestGenerator(parser, output_dir, mock_url)
files = generator.generate()
elif framework == "go":
generator = GoGenerator(parser, output_dir, mock_url)
files = generator.generate()
click.echo(f"Generated {len(files)} test file(s)")
click.echo("Generating mock server configuration...")
mock_generator = MockServerGenerator(parser, output_dir)
mock_files = mock_generator.generate()
click.echo(f"Generated {len(mock_files)} mock server file(s)")
click.echo("\nAll files generated successfully!")
except Exception as e:
click.echo(f"Error: {e}", err=True)
raise click.Abort()
@main.command("auth")
@click.argument("scheme_name")
@click.option("--type", "auth_type", type=click.Choice(["apiKey", "bearer", "basic"]), help="Authentication type")
@click.option("--header", help="Header name for API key", default="X-API-Key")
@click.option("--token", help="Bearer token or API key value")
@click.option("--username", help="Username for Basic auth")
@click.option("--password", help="Password for Basic auth")
@click.pass_context
def configure_auth(
ctx: click.Context,
scheme_name: str,
auth_type: str,
header: str,
token: str,
username: str,
password: str,
):
"""Configure authentication for a security scheme."""
auth_config = AuthConfig()
if auth_type == "apiKey":
auth_config.add_api_key(scheme_name, header, token or "")
elif auth_type == "bearer":
auth_config.add_bearer(scheme_name, token or "")
elif auth_type == "basic":
auth_config.add_basic(scheme_name, username or "", password or "")
click.echo(f"Authentication scheme '{scheme_name}' configured:")
methods = auth_config.get_all_methods()
for name, method in methods.items():
click.echo(f" - {name}: {method['type'].value}")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,6 @@
"""Core module for API TestGen."""
from .spec_parser import SpecParser
from .auth import AuthConfig
__all__ = ["SpecParser", "AuthConfig"]

313
api_testgen/core/auth.py Normal file
View File

@@ -0,0 +1,313 @@
"""Authentication configuration for API TestGen."""
from typing import Any, Dict, List, Optional, Union
from enum import Enum
from .exceptions import AuthConfigError, MissingSecuritySchemeError
class AuthType(str, Enum):
"""Types of authentication."""
API_KEY = "apiKey"
BEARER = "bearer"
BASIC = "basic"
NONE = "none"
class AuthConfig:
"""Authentication configuration for generated tests."""
def __init__(self):
"""Initialize authentication configuration."""
self._auth_methods: Dict[str, Dict[str, Any]] = {}
def add_api_key(
self,
scheme_name: str,
header_name: str = "X-API-Key",
api_key: str = "",
) -> "AuthConfig":
"""Add API key authentication.
Args:
scheme_name: Name of the security scheme in OpenAPI spec.
header_name: Name of the header containing the API key.
api_key: The API key value (can be set later).
Returns:
Self for method chaining.
"""
self._auth_methods[scheme_name] = {
"type": AuthType.API_KEY,
"header_name": header_name,
"api_key": api_key,
}
return self
def add_bearer(
self,
scheme_name: str,
token: str = "",
token_prefix: str = "Bearer",
) -> "AuthConfig":
"""Add Bearer token authentication.
Args:
scheme_name: Name of the security scheme in OpenAPI spec.
token: The Bearer token value (can be set later).
token_prefix: The token prefix (default: Bearer).
Returns:
Self for method chaining.
"""
self._auth_methods[scheme_name] = {
"type": AuthType.BEARER,
"token": token,
"token_prefix": token_prefix,
}
return self
def add_basic(
self,
scheme_name: str,
username: str = "",
password: str = "",
) -> "AuthConfig":
"""Add Basic authentication.
Args:
scheme_name: Name of the security scheme in OpenAPI spec.
username: The username (can be set later).
password: The password (can be set later).
Returns:
Self for method chaining.
"""
self._auth_methods[scheme_name] = {
"type": AuthType.BASIC,
"username": username,
"password": password,
}
return self
def get_auth_method(self, scheme_name: str) -> Optional[Dict[str, Any]]:
"""Get authentication method by scheme name.
Args:
scheme_name: Name of the security scheme.
Returns:
Authentication method configuration or None.
"""
return self._auth_methods.get(scheme_name)
def get_all_methods(self) -> Dict[str, Dict[str, Any]]:
"""Get all configured authentication methods.
Returns:
Dictionary of scheme names and their configurations.
"""
return self._auth_methods.copy()
def get_headers(self, scheme_name: str) -> Dict[str, str]:
"""Get authentication headers for a scheme.
Args:
scheme_name: Name of the security scheme.
Returns:
Dictionary of header names and values.
Raises:
AuthConfigError: If scheme is not configured.
"""
method = self.get_auth_method(scheme_name)
if not method:
raise AuthConfigError(f"Authentication scheme '{scheme_name}' not configured")
if method["type"] == AuthType.API_KEY:
return {method["header_name"]: method["api_key"]}
elif method["type"] == AuthType.BEARER:
return {"Authorization": f"{method['token_prefix']} {method['token']}"}
elif method["type"] == AuthType.BASIC:
import base64
credentials = f"{method['username']}:{method['password']}"
encoded = base64.b64encode(credentials.encode()).decode()
return {"Authorization": f"Basic {encoded}"}
return {}
def build_from_spec(
self,
security_schemes: Dict[str, Any],
security_requirements: List[Dict[str, Any]],
) -> "AuthConfig":
"""Build auth configuration from OpenAPI security schemes.
Args:
security_schemes: Security schemes from OpenAPI spec.
security_requirements: Security requirements from endpoint.
Returns:
Self for method chaining.
Raises:
MissingSecuritySchemeError: If required scheme is not defined.
"""
for requirement in security_requirements:
for scheme_name in requirement.keys():
if scheme_name not in self._auth_methods:
if scheme_name not in security_schemes:
raise MissingSecuritySchemeError(
f"Security scheme '{scheme_name}' not found in spec"
)
scheme = security_schemes[scheme_name]
self._add_scheme_from_spec(scheme_name, scheme)
return self
def _add_scheme_from_spec(self, scheme_name: str, scheme: Dict[str, Any]) -> None:
"""Add authentication scheme from OpenAPI spec definition.
Args:
scheme_name: Name of the security scheme.
scheme: The security scheme definition from OpenAPI spec.
"""
scheme_type = scheme.get("type", "")
if scheme_type == "apiKey":
self.add_api_key(
scheme_name,
header_name=scheme.get("name", "X-API-Key"),
)
elif scheme_type == "http":
scheme_scheme = scheme.get("scheme", "").lower()
if scheme_scheme == "bearer":
self.add_bearer(scheme_name)
elif scheme_scheme == "basic":
self.add_basic(scheme_name)
elif scheme_type == "openIdConnect":
self.add_bearer(scheme_name)
elif scheme_type == "oauth2":
self.add_bearer(scheme_name)
def generate_auth_code(self, scheme_name: str, framework: str = "pytest") -> str:
"""Generate authentication code for a test framework.
Args:
scheme_name: Name of the security scheme.
framework: Target test framework (pytest, jest, go).
Returns:
String containing authentication code snippet.
"""
method = self.get_auth_method(scheme_name)
if not method:
return ""
if framework == "pytest":
return self._generate_pytest_auth(method)
elif framework == "jest":
return self._generate_jest_auth(method)
elif framework == "go":
return self._generate_go_auth(method)
return ""
def _generate_pytest_auth(self, method: Dict[str, Any]) -> str:
"""Generate pytest authentication code.
Args:
method: Authentication method configuration.
Returns:
String containing pytest auth code.
"""
if method["type"] == AuthType.API_KEY:
return f'''
@pytest.fixture
def api_key_headers():
return {{"{method['header_name']}": "{method['api_key']}"}}
'''
elif method["type"] == AuthType.BEARER:
return f'''
@pytest.fixture
def bearer_headers():
return {{"Authorization": "{method['token_prefix']} {method['token']}"}}
'''
elif method["type"] == AuthType.BASIC:
return f'''
import base64
@pytest.fixture
def basic_headers():
credentials = f"{{"{method['username']}"}}:{{"{method['password']}"}}"
encoded = base64.b64encode(credentials.encode()).decode()
return {{"Authorization": f"Basic {{encoded}}"}}
'''
return ""
def _generate_jest_auth(self, method: Dict[str, Any]) -> str:
"""Generate Jest authentication code.
Args:
method: Authentication method configuration.
Returns:
String containing Jest auth code.
"""
if method["type"] == AuthType.API_KEY:
return f'''
const getApiKeyHeaders = () => ({{
"{method['header_name']}": process.env.API_KEY || "{method['api_key']}",
}});
'''
elif method["type"] == AuthType.BEARER:
return f'''
const getBearerHeaders = () => ({{
Authorization: `${{process.env.TOKEN_PREFIX || "{method['token_prefix']}"}} ${{process.env.TOKEN || "{method['token']}"}}`,
}});
'''
elif method["type"] == AuthType.BASIC:
return f'''
const getBasicHeaders = () => {{
const credentials = Buffer.from(`${{process.env.USERNAME || "{method['username']}"}}:${{process.env.PASSWORD || "{method['password']}"}}`).toString('base64');
return {{ Authorization: `Basic ${{credentials}}` }};
}};
'''
return ""
def _generate_go_auth(self, method: Dict[str, Any]) -> str:
"""Generate Go authentication code.
Args:
method: Authentication method configuration.
Returns:
String containing Go auth code.
"""
if method["type"] == AuthType.API_KEY:
return f'''
func getAPIKeyHeaders() map[string]string {{
return map[string]string{{
"{method['header_name']}": os.Getenv("API_KEY"),
}}
}}
'''
elif method["type"] == AuthType.BEARER:
return f'''
func getBearerHeaders() map[string]string {{
return map[string]string{{
"Authorization": fmt.Sprintf("%s %s", os.Getenv("TOKEN_PREFIX"), os.Getenv("TOKEN")),
}}
}}
'''
elif method["type"] == AuthType.BASIC:
return f'''
func getBasicHeaders(username, password string) map[string]string {{
auth := username + ":" + password
encoded := base64.StdEncoding.EncodeToString([]byte(auth))
return map[string]string{{
"Authorization": "Basic " + encoded,
}}
}}
'''
return ""

View File

@@ -0,0 +1,36 @@
"""Custom exceptions for API TestGen."""
class SpecParserError(Exception):
"""Base exception for spec parser errors."""
pass
class InvalidOpenAPISpecError(SpecParserError):
"""Raised when OpenAPI specification is invalid."""
pass
class UnsupportedVersionError(SpecParserError):
"""Raised when OpenAPI version is not supported."""
pass
class AuthConfigError(Exception):
"""Base exception for auth configuration errors."""
pass
class MissingSecuritySchemeError(AuthConfigError):
"""Raised when security scheme is not defined in spec."""
pass
class GeneratorError(Exception):
"""Base exception for generator errors."""
pass
class TemplateRenderError(GeneratorError):
"""Raised when template rendering fails."""
pass

View File

@@ -0,0 +1,307 @@
"""OpenAPI Specification Parser."""
import yaml
import json
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from openapi_spec_validator import validate
from openapi_spec_validator.versions import consts as validator_consts
from .exceptions import InvalidOpenAPISpecError, UnsupportedVersionError
class SpecParser:
"""Parse and validate OpenAPI/Swagger specifications."""
SUPPORTED_VERSIONS = ["2.0", "3.0.0", "3.0.1", "3.0.2", "3.0.3", "3.1.0"]
def __init__(self, spec_path: Union[str, Path]):
"""Initialize the spec parser.
Args:
spec_path: Path to OpenAPI specification file (YAML or JSON).
"""
self.spec_path = Path(spec_path)
self.spec: Dict[str, Any] = {}
self.version: str = ""
self.base_path: str = ""
self.servers: List[Dict[str, str]] = []
def load(self) -> Dict[str, Any]:
"""Load and validate the OpenAPI specification.
Returns:
The parsed specification dictionary.
Raises:
InvalidOpenAPISpecError: If the specification is invalid.
UnsupportedVersionError: If the OpenAPI version is not supported.
"""
self.spec = self._load_file()
self._validate()
self._extract_metadata()
return self.spec
def _load_file(self) -> Dict[str, Any]:
"""Load the specification file.
Returns:
Parsed specification dictionary.
Raises:
InvalidOpenAPISpecError: If the file cannot be loaded or parsed.
"""
if not self.spec_path.exists():
raise InvalidOpenAPISpecError(f"Specification file not found: {self.spec_path}")
try:
with open(self.spec_path, 'r', encoding='utf-8') as f:
if self.spec_path.suffix in ['.yaml', '.yml']:
return yaml.safe_load(f) or {}
elif self.spec_path.suffix == '.json':
return json.load(f)
else:
return yaml.safe_load(f) or {}
except (yaml.YAMLError, json.JSONDecodeError) as e:
raise InvalidOpenAPISpecError(f"Failed to parse specification: {e}")
def _validate(self) -> None:
"""Validate the specification.
Raises:
InvalidOpenAPISpecError: If the specification is invalid.
UnsupportedVersionError: If the OpenAPI version is not supported.
"""
try:
validate(self.spec)
except Exception as e:
raise InvalidOpenAPISpecError(f"Invalid OpenAPI specification: {e}")
version = self._get_version()
if version not in self.SUPPORTED_VERSIONS:
raise UnsupportedVersionError(
f"Unsupported OpenAPI version: {version}. "
f"Supported versions: {', '.join(self.SUPPORTED_VERSIONS)}"
)
def _get_version(self) -> str:
"""Extract the OpenAPI version from the spec.
Returns:
The OpenAPI version string.
"""
if "openapi" in self.spec:
return self.spec["openapi"]
elif "swagger" in self.spec:
return self.spec["swagger"]
return "2.0"
def _extract_metadata(self) -> None:
"""Extract metadata from the specification."""
self.version = self._get_version()
self.base_path = self.spec.get("basePath", "")
self.servers = self.spec.get("servers", [])
def get_paths(self) -> Dict[str, Any]:
"""Get all paths from the specification.
Returns:
Dictionary of paths and their operations.
"""
return self.spec.get("paths", {})
def get_endpoints(self) -> List[Dict[str, Any]]:
"""Extract all endpoints from the specification.
Returns:
List of endpoint dictionaries with path, method, and details.
"""
endpoints = []
paths = self.get_paths()
for path, path_item in paths.items():
for method, operation in path_item.items():
if method.lower() in ["get", "post", "put", "patch", "delete", "options", "head"]:
endpoint = {
"path": path,
"method": method.lower(),
"operation_id": operation.get("operationId", ""),
"summary": operation.get("summary", ""),
"description": operation.get("description", ""),
"tags": operation.get("tags", []),
"parameters": self._extract_parameters(path_item, operation),
"request_body": self._extract_request_body(operation),
"responses": self._extract_responses(operation),
"security": self._extract_security(operation),
}
endpoints.append(endpoint)
return endpoints
def _extract_parameters(self, path_item: Dict, operation: Dict) -> List[Dict[str, Any]]:
"""Extract parameters from path item and operation.
Args:
path_item: The path item dictionary.
operation: The operation dictionary.
Returns:
List of parameter dictionaries.
"""
parameters = []
for param in path_item.get("parameters", []):
if param.get("in") != "body":
parameters.append(self._normalize_parameter(param))
for param in operation.get("parameters", []):
if param.get("in") != "body":
parameters.append(self._normalize_parameter(param))
return parameters
def _normalize_parameter(self, param: Dict[str, Any]) -> Dict[str, Any]:
"""Normalize a parameter.
Args:
param: The parameter dictionary.
Returns:
Normalized parameter dictionary.
"""
return {
"name": param.get("name", ""),
"in": param.get("in", ""),
"description": param.get("description", ""),
"required": param.get("required", False),
"schema": param.get("schema", {}),
"type": param.get("type", ""),
"enum": param.get("enum", []),
"default": param.get("default"),
}
def _extract_request_body(self, operation: Dict) -> Optional[Dict[str, Any]]:
"""Extract request body from operation.
Args:
operation: The operation dictionary.
Returns:
Request body dictionary or None.
"""
if self.version.startswith("3."):
request_body = operation.get("requestBody", {})
if not request_body:
return None
content = request_body.get("content", {})
media_types = list(content.keys())
return {
"description": request_body.get("description", ""),
"required": request_body.get("required", False),
"media_types": media_types,
"schema": content.get(media_types[0], {}).get("schema", {}) if media_types else {},
}
else:
params = operation.get("parameters", [])
for param in params:
if param.get("in") == "body":
return {
"description": param.get("description", ""),
"required": param.get("required", False),
"schema": param.get("schema", {}),
}
return None
def _extract_responses(self, operation: Dict) -> Dict[str, Any]:
"""Extract responses from operation.
Args:
operation: The operation dictionary.
Returns:
Dictionary of response status codes and their details.
"""
responses = {}
for status_code, response in operation.get("responses", {}).items():
content = response.get("content", {})
if self.version.startswith("3."):
media_types = list(content.keys())
schema = content.get(media_types[0], {}).get("schema", {}) if media_types else {}
else:
schema = response.get("schema", {})
responses[status_code] = {
"description": response.get("description", ""),
"schema": schema,
"media_type": list(content.keys())[0] if content else "application/json",
}
return responses
def _extract_security(self, operation: Dict) -> List[Dict[str, Any]]:
"""Extract security requirements from operation.
Args:
operation: The operation dictionary.
Returns:
List of security requirement dictionaries.
"""
return operation.get("security", self.spec.get("security", []))
def get_security_schemes(self) -> Dict[str, Any]:
"""Get security schemes from the specification.
Returns:
Dictionary of security scheme names and their definitions.
"""
if self.version.startswith("3."):
return self.spec.get("components", {}).get("securitySchemes", {})
else:
return self.spec.get("securityDefinitions", {})
def get_definitions(self) -> Dict[str, Any]:
"""Get schema definitions from the specification.
Returns:
Dictionary of schema definitions.
"""
if self.version.startswith("3."):
return self.spec.get("components", {}).get("schemas", {})
else:
return self.spec.get("definitions", {})
def get_info(self) -> Dict[str, str]:
"""Get API info from the specification.
Returns:
Dictionary with title, version, and description.
"""
info = self.spec.get("info", {})
return {
"title": info.get("title", "API"),
"version": info.get("version", "1.0.0"),
"description": info.get("description", ""),
}
def to_dict(self) -> Dict[str, Any]:
"""Convert the spec to a dictionary.
Returns:
Dictionary representation of the parsed spec.
"""
return {
"version": self.version,
"base_path": self.base_path,
"servers": self.servers,
"info": self.get_info(),
"paths": self.get_paths(),
"endpoints": self.get_endpoints(),
"security_schemes": self.get_security_schemes(),
"definitions": self.get_definitions(),
}

View File

@@ -0,0 +1,7 @@
"""Generators module for API TestGen."""
from .pytest import PytestGenerator
from .jest import JestGenerator
from .go import GoGenerator
__all__ = ["PytestGenerator", "JestGenerator", "GoGenerator"]

View File

@@ -0,0 +1,229 @@
"""Go test generator."""
import re
from pathlib import Path
from typing import Any, Dict, List, Optional
from jinja2 import Environment, FileSystemLoader, TemplateSyntaxError, UndefinedError
from ..core import SpecParser, AuthConfig
from ..core.exceptions import GeneratorError, TemplateRenderError
class GoGenerator:
"""Generate Go-compatible test files."""
def __init__(
self,
spec_parser: SpecParser,
output_dir: str = "tests",
mock_server_url: str = "http://localhost:4010",
package_name: str = "apitest",
):
"""Initialize the go generator.
Args:
spec_parser: The OpenAPI specification parser.
output_dir: Directory for generated test files.
mock_server_url: URL of the mock server for testing.
package_name: Go package name for test files.
"""
self.spec_parser = spec_parser
self.output_dir = Path(output_dir)
self.mock_server_url = mock_server_url
self.package_name = package_name
self.env = Environment(
loader=FileSystemLoader(str(Path(__file__).parent.parent.parent / "templates" / "go")),
trim_blocks=True,
lstrip_blocks=True,
)
def generate(self, output_file: Optional[str] = None) -> List[Path]:
"""Generate Go test files.
Args:
output_file: Optional specific output file path.
Returns:
List of generated file paths.
"""
self.output_dir.mkdir(parents=True, exist_ok=True)
endpoints = self.spec_parser.get_endpoints()
info = self.spec_parser.get_info()
grouped_endpoints = self._group_endpoints_by_path(endpoints)
context = {
"package_name": self.package_name,
"api_title": info["title"],
"api_version": info["version"],
"endpoints": endpoints,
"grouped_endpoints": grouped_endpoints,
"mock_server_url": self.mock_server_url,
"security_schemes": self.spec_parser.get_security_schemes(),
"definitions": self.spec_parser.get_definitions(),
}
generated_files = []
try:
template = self.env.get_template("api_test.go.j2")
content = template.render(context)
if output_file:
output_path = Path(output_file)
else:
safe_name = re.sub(r"[^a-zA-Z0-9_]", "_", info["title"].lower())
output_path = self.output_dir / f"{safe_name}_test.go"
output_path.parent.mkdir(parents=True, exist_ok=True)
output_path.write_text(content)
generated_files.append(output_path)
except (TemplateSyntaxError, UndefinedError) as e:
raise TemplateRenderError(f"Failed to render Go template: {e}")
return generated_files
def _group_endpoints_by_path(self, endpoints: List[Dict[str, Any]]) -> Dict[str, List[Dict[str, Any]]]:
"""Group endpoints by their path.
Args:
endpoints: List of endpoint dictionaries.
Returns:
Dictionary mapping paths to their endpoints.
"""
grouped = {}
for endpoint in endpoints:
path = endpoint["path"]
if path not in grouped:
grouped[path] = []
grouped[path].append(endpoint)
return grouped
def generate_endpoint_tests(self, endpoint_path: str, method: str) -> str:
"""Generate test for a specific endpoint.
Args:
endpoint_path: The API endpoint path.
method: The HTTP method.
Returns:
String containing the test code.
"""
endpoints = self.spec_parser.get_endpoints()
for endpoint in endpoints:
if endpoint["path"] == endpoint_path and endpoint["method"] == method.lower():
return self._generate_single_test(endpoint)
return ""
def _generate_single_test(self, endpoint: Dict[str, Any]) -> str:
"""Generate test code for a single endpoint.
Args:
endpoint: The endpoint dictionary.
Returns:
String containing the test code.
"""
test_name = self._generate_test_name(endpoint)
params = self._generate_params(endpoint)
url_params = self._generate_url_params(endpoint)
test_code = f'''
func Test{test_name}(t *testing.T) {{
client := &http.Client{{Timeout: 10 * time.Second}}
url := baseURL + "{endpoint['path']}"
var req *http.Request
var err error
{params}
req, err = http.NewRequest("{endpoint['method'].upper()}", url, nil)
if err != nil {{
t.Fatal(err)
}}
for k, v := range getAuthHeaders() {{
req.Header.Set(k, v)
}}
resp, err := client.Do(req)
if err != nil {{
t.Fatalf("Request failed: %v", err)
}}
defer resp.Body.Close()
if !contains([]int{{200, 201, 204}}, resp.StatusCode) {{
t.Errorf("Expected status code in [200, 201, 204], got %d", resp.StatusCode)
}}
}}
'''
return test_code
def _generate_test_name(self, endpoint: Dict[str, Any]) -> str:
"""Generate a valid Go test function name.
Args:
endpoint: The endpoint dictionary.
Returns:
A valid Go function name.
"""
path = endpoint["path"]
method = endpoint["method"]
name = re.sub(r"[^a-zA-Z0-9]", "_", path.strip("/"))
name = re.sub(r"_+/", "_", name)
name = re.sub(r"_+$", "", name)
name = name.title().replace("_", "")
return f"{method.capitalize()}{name}" if name else f"{method.capitalize()}Default"
def _generate_params(self, endpoint: Dict[str, Any]) -> str:
"""Generate parameter variables for test.
Args:
endpoint: The endpoint dictionary.
Returns:
String containing parameter declarations.
"""
params = []
for param in endpoint.get("parameters", []):
param_name = param["name"]
if param["in"] == "path":
params.append(f'{param_name} := "test_{param_name}"')
params.append(f'url = strings.Replace(url, "{{'+param_name+'}}", {param_name}, 1)')
elif param["in"] == "query":
params.append(f'q := url.Values{{{param_name}: []string{{"test"}}}}')
params.append(f'url += "?" + q.Encode()')
return "\n ".join(params) if params else ""
def _generate_url_params(self, endpoint: Dict[str, Any]) -> str:
"""Generate URL parameters.
Args:
endpoint: The endpoint dictionary.
Returns:
String containing URL parameter handling.
"""
path_params = [p for p in endpoint.get("parameters", []) if p["in"] == "path"]
query_params = [p for p in endpoint.get("parameters", []) if p["in"] == "query"]
parts = []
for param in path_params:
parts.append(f'strings.Replace(url, "{{'+param['name']+'}}", "test_' + param['name'] + '", 1)')
return ""

View File

@@ -0,0 +1,169 @@
"""Jest test generator."""
import re
from pathlib import Path
from typing import Any, Dict, List, Optional
from jinja2 import Environment, FileSystemLoader, TemplateSyntaxError, UndefinedError
from ..core import SpecParser, AuthConfig
from ..core.exceptions import GeneratorError, TemplateRenderError
class JestGenerator:
"""Generate Jest-compatible integration test templates."""
def __init__(
self,
spec_parser: SpecParser,
output_dir: str = "tests",
mock_server_url: str = "http://localhost:4010",
):
"""Initialize the jest generator.
Args:
spec_parser: The OpenAPI specification parser.
output_dir: Directory for generated test files.
mock_server_url: URL of the mock server for testing.
"""
self.spec_parser = spec_parser
self.output_dir = Path(output_dir)
self.mock_server_url = mock_server_url
self.env = Environment(
loader=FileSystemLoader(str(Path(__file__).parent.parent.parent / "templates" / "jest")),
trim_blocks=True,
lstrip_blocks=True,
)
def generate(self, output_file: Optional[str] = None) -> List[Path]:
"""Generate Jest test files.
Args:
output_file: Optional specific output file path.
Returns:
List of generated file paths.
"""
self.output_dir.mkdir(parents=True, exist_ok=True)
endpoints = self.spec_parser.get_endpoints()
info = self.spec_parser.get_info()
context = {
"api_title": info["title"],
"api_version": info["version"],
"endpoints": endpoints,
"mock_server_url": self.mock_server_url,
"security_schemes": self.spec_parser.get_security_schemes(),
"definitions": self.spec_parser.get_definitions(),
}
generated_files = []
try:
template = self.env.get_template("api.test.js.j2")
content = template.render(context)
if output_file:
output_path = Path(output_file)
else:
safe_name = re.sub(r"[^a-zA-Z0-9_]", "_", info["title"].lower())
output_path = self.output_dir / f"{safe_name}.test.js"
output_path.parent.mkdir(parents=True, exist_ok=True)
output_path.write_text(content)
generated_files.append(output_path)
except (TemplateSyntaxError, UndefinedError) as e:
raise TemplateRenderError(f"Failed to render Jest template: {e}")
return generated_files
def generate_endpoint_tests(self, endpoint_path: str, method: str) -> str:
"""Generate test for a specific endpoint.
Args:
endpoint_path: The API endpoint path.
method: The HTTP method.
Returns:
String containing the test code.
"""
endpoints = self.spec_parser.get_endpoints()
for endpoint in endpoints:
if endpoint["path"] == endpoint_path and endpoint["method"] == method.lower():
return self._generate_single_test(endpoint)
return ""
def _generate_single_test(self, endpoint: Dict[str, Any]) -> str:
"""Generate test code for a single endpoint.
Args:
endpoint: The endpoint dictionary.
Returns:
String containing the test code.
"""
test_name = self._generate_test_name(endpoint)
describe_name = endpoint["summary"] or endpoint["path"]
params = self._generate_params(endpoint)
endpoint_path = endpoint["path"]
endpoint_method = endpoint["method"]
test_code = f'''
describe('{describe_name}', () => {{
it('should {endpoint_method.upper()} {endpoint_path}', async () => {{
const response = await request(baseUrl)
.{endpoint_method}('{endpoint_path}'{params});
expect([200, 201, 204]).toContain(response.status);
}});
}});
'''
return test_code
def _generate_test_name(self, endpoint: Dict[str, Any]) -> str:
"""Generate a valid test function name.
Args:
endpoint: The endpoint dictionary.
Returns:
A valid JavaScript function name.
"""
path = endpoint["path"]
method = endpoint["method"]
name = re.sub(r"[^a-zA-Z0-9]", "_", path.strip("/"))
name = re.sub(r"_+/", "_", name)
name = re.sub(r"_+$", "", name)
return f"{method}_{name}" if name else f"{method}_default"
def _generate_params(self, endpoint: Dict[str, Any]) -> str:
"""Generate parameters for test request.
Args:
endpoint: The endpoint dictionary.
Returns:
String containing parameter chain.
"""
parts = []
for param in endpoint.get("parameters", []):
param_name = param["name"]
if param["in"] == "path":
parts.append(f'{param_name}="test_{param_name}"')
elif param["in"] == "query":
parts.append(f'{param_name}')
if parts:
return ", {" + ", ".join(parts) + "}"
return ""

View File

@@ -0,0 +1,199 @@
"""Pytest test generator."""
import re
from pathlib import Path
from typing import Any, Dict, List, Optional
from jinja2 import Environment, FileSystemLoader, TemplateSyntaxError, UndefinedError
from ..core import SpecParser, AuthConfig
from ..core.exceptions import GeneratorError, TemplateRenderError
class PytestGenerator:
"""Generate pytest-compatible integration test templates."""
def __init__(
self,
spec_parser: SpecParser,
output_dir: str = "tests",
mock_server_url: str = "http://localhost:4010",
):
"""Initialize the pytest generator.
Args:
spec_parser: The OpenAPI specification parser.
output_dir: Directory for generated test files.
mock_server_url: URL of the mock server for testing.
"""
self.spec_parser = spec_parser
self.output_dir = Path(output_dir)
self.mock_server_url = mock_server_url
self.env = Environment(
loader=FileSystemLoader(str(Path(__file__).parent.parent.parent / "templates" / "pytest")),
trim_blocks=True,
lstrip_blocks=True,
)
def generate(self, output_file: Optional[str] = None) -> List[Path]:
"""Generate pytest test files.
Args:
output_file: Optional specific output file path.
Returns:
List of generated file paths.
"""
self.output_dir.mkdir(parents=True, exist_ok=True)
endpoints = self.spec_parser.get_endpoints()
info = self.spec_parser.get_info()
context = {
"api_title": info["title"],
"api_version": info["version"],
"endpoints": endpoints,
"mock_server_url": self.mock_server_url,
"security_schemes": self.spec_parser.get_security_schemes(),
"definitions": self.spec_parser.get_definitions(),
}
generated_files = []
try:
template = self.env.get_template("test_base.py.j2")
content = template.render(context)
if output_file:
output_path = Path(output_file)
else:
safe_name = re.sub(r"[^a-zA-Z0-9_]", "_", info["title"].lower())
output_path = self.output_dir / f"test_{safe_name}.py"
output_path.parent.mkdir(parents=True, exist_ok=True)
output_path.write_text(content)
generated_files.append(output_path)
except (TemplateSyntaxError, UndefinedError) as e:
raise TemplateRenderError(f"Failed to render pytest template: {e}")
return generated_files
def generate_endpoint_tests(self, endpoint_path: str, method: str) -> str:
"""Generate test for a specific endpoint.
Args:
endpoint_path: The API endpoint path.
method: The HTTP method.
Returns:
String containing the test code.
"""
endpoints = self.spec_parser.get_endpoints()
for endpoint in endpoints:
if endpoint["path"] == endpoint_path and endpoint["method"] == method.lower():
return self._generate_single_test(endpoint)
return ""
def _generate_single_test(self, endpoint: Dict[str, Any]) -> str:
"""Generate test code for a single endpoint.
Args:
endpoint: The endpoint dictionary.
Returns:
String containing the test code.
"""
test_name = self._generate_test_name(endpoint)
params = self._generate_parameters(endpoint)
auth_headers = self._generate_auth_headers(endpoint)
test_code = f'''
def test_{test_name}(base_url, {params}):
"""Test {endpoint["summary"] or endpoint["path"]} endpoint."""
url = f"{{base_url}}{endpoint['path']}"
headers = {{"Content-Type": "application/json"}}
{auth_headers}
response = requests.{endpoint['method']}(url, json={{}} if method == "POST" else None, headers=headers)
assert response.status_code in [200, 201, 204]
'''
return test_code
def _generate_test_name(self, endpoint: Dict[str, Any]) -> str:
"""Generate a valid test function name.
Args:
endpoint: The endpoint dictionary.
Returns:
A valid Python function name.
"""
path = endpoint["path"]
method = endpoint["method"]
name = re.sub(r"[^a-zA-Z0-9]", "_", path.strip("/"))
name = re.sub(r"_+/", "_", name)
name = re.sub(r"_+$", "", name)
return f"{method}_{name}" if name else f"{method}_default"
def _generate_parameters(self, endpoint: Dict[str, Any]) -> str:
"""Generate parameters for test function.
Args:
endpoint: The endpoint dictionary.
Returns:
String containing parameter declarations.
"""
params = []
for param in endpoint.get("parameters", []):
param_name = param["name"]
if param["in"] == "path":
params.append(f'{param_name}="test_{param_name}"')
elif param["in"] == "query":
params.append(f'{param_name}=None')
return ", ".join(params)
def _generate_auth_headers(self, endpoint: Dict[str, Any]) -> str:
"""Generate authentication headers for endpoint.
Args:
endpoint: The endpoint dictionary.
Returns:
String containing header assignments.
"""
security_requirements = endpoint.get("security", [])
if not security_requirements:
return ""
schemes = self.spec_parser.get_security_schemes()
auth_config = AuthConfig()
try:
auth_config.build_from_spec(schemes, security_requirements)
except Exception:
return ""
headers = []
for scheme_name in security_requirements[0].keys():
method = auth_config.get_auth_method(scheme_name)
if method:
if method["type"] == "apiKey":
headers.append(f'headers["{method["header_name"]}"] = "test_api_key"')
elif method["type"] == "bearer":
headers.append('headers["Authorization"] = "Bearer test_token"')
return "\n ".join(headers) if headers else ""

View File

@@ -0,0 +1,5 @@
"""Mocks module for API TestGen."""
from .generator import MockServerGenerator
__all__ = ["MockServerGenerator"]

View File

@@ -0,0 +1,278 @@
"""Mock server generator for Prism/OpenAPI mock support."""
import json
from pathlib import Path
from typing import Any, Dict, List, Optional
from ..core import SpecParser
class MockServerGenerator:
"""Generate Prism/OpenAPI mock server configurations."""
DEFAULT_PORT = 4010
DEFAULT_HOST = "0.0.0.0"
def __init__(
self,
spec_parser: SpecParser,
output_dir: str = ".",
):
"""Initialize the mock server generator.
Args:
spec_parser: The OpenAPI specification parser.
output_dir: Directory for generated configuration files.
"""
self.spec_parser = spec_parser
self.output_dir = Path(output_dir)
def generate(
self,
prism_config: bool = True,
docker_compose: bool = True,
dockerfile: bool = True,
) -> List[Path]:
"""Generate mock server configuration files.
Args:
prism_config: Whether to generate prism-config.json.
docker_compose: Whether to generate docker-compose.yml.
dockerfile: Whether to generate Dockerfile.
Returns:
List of generated file paths.
"""
self.output_dir.mkdir(parents=True, exist_ok=True)
generated_files = []
if prism_config:
generated_files.append(self.generate_prism_config())
if docker_compose:
generated_files.append(self.generate_docker_compose())
if dockerfile:
generated_files.append(self.generate_dockerfile())
return generated_files
def generate_prism_config(self, output_file: Optional[str] = None) -> Path:
"""Generate Prism mock server configuration.
Args:
output_file: Optional output file path.
Returns:
Path to the generated file.
"""
if output_file:
output_path = Path(output_file)
else:
output_path = self.output_dir / "prism-config.json"
config = {
"mock": {
"host": self.DEFAULT_HOST,
"port": self.DEFAULT_PORT,
"cors": {
"enabled": True,
"allowOrigin": "*",
"allowMethods": ["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS", "HEAD"],
"allowHeaders": ["Content-Type", "Authorization", "X-API-Key"],
},
"Validation": {
"request": True,
"response": True,
},
},
"logging": {
"level": "info",
"format": "json",
},
}
output_path.write_text(json.dumps(config, indent=2))
return output_path
def generate_docker_compose(self, output_file: Optional[str] = None) -> Path:
"""Generate Docker Compose configuration for mock server.
Args:
output_file: Optional output file path.
Returns:
Path to the generated file.
"""
if output_file:
output_path = Path(output_file)
else:
output_path = self.output_dir / "docker-compose.yml"
spec_info = self.spec_parser.get_info()
compose_content = f'''version: '3.8'
services:
mock-server:
image: stoplight/prism:latest
container_name: "{spec_info['title']}-mock-server"
command: >
mock
--spec /app/openapi.yaml
--port {self.DEFAULT_PORT}
--host {self.DEFAULT_HOST}
ports:
- "{self.DEFAULT_PORT}:{self.DEFAULT_PORT}"
volumes:
- ./:/app
restart: unless-stopped
healthcheck:
test: ["CMD", "wget", "-q", "--spider", f"http://localhost:{self.DEFAULT_PORT}/health"]
interval: 30s
timeout: 10s
retries: 3
mock-server-https:
image: stoplight/prism:latest
container_name: "{spec_info['title']}-mock-server-https"
command: >
mock
--spec /app/openapi.yaml
--port {self.DEFAULT_PORT + 1}
--host {self.DEFAULT_HOST}
ports:
- "{self.DEFAULT_PORT + 1}:{self.DEFAULT_PORT + 1}"
volumes:
- ./:/app
restart: unless-stopped
'''
output_path.write_text(compose_content)
return output_path
def generate_dockerfile(self, output_file: Optional[str] = None) -> Path:
"""Generate Dockerfile for mock server.
Args:
output_file: Optional output file path.
Returns:
Path to the generated file.
"""
if output_file:
output_path = Path(output_file)
else:
output_path = self.output_dir / "Dockerfile"
spec_info = self.spec_parser.get_info()
dockerfile_content = f'''FROM stoplight/prism:latest
LABEL maintainer="developer@example.com"
LABEL description="Mock server for {spec_info['title']} API"
WORKDIR /app
COPY openapi.yaml .
EXPOSE {self.DEFAULT_PORT}
CMD ["mock", "--spec", "openapi.yaml", "--port", "{self.DEFAULT_PORT}", "--host", "0.0.0.0"]
'''
output_path.write_text(dockerfile_content)
return output_path
def generate_start_script(self, output_file: Optional[str] = None) -> Path:
"""Generate shell script to start mock server.
Args:
output_file: Optional output file path.
Returns:
Path to the generated file.
"""
if output_file:
output_path = Path(output_file)
else:
output_path = self.output_dir / "start-mock-server.sh"
script_content = f'''#!/bin/bash
set -e
SCRIPT_DIR="$(cd "$(dirname "${{BASH_SOURCE[0]}}")" && pwd)"
cd "$SCRIPT_DIR"
PORT={self.DEFAULT_PORT}
HOST="0.0.0.0"
echo "Starting mock server for API..."
echo "Mock server will be available at: http://localhost:$PORT"
docker compose up -d mock-server
echo "Mock server started successfully!"
echo "To stop: docker compose down"
'''
output_path.write_text(script_content)
output_path.chmod(0o755)
return output_path
def generate_health_check(self) -> Path:
"""Generate health check endpoint configuration.
Returns:
Path to the generated file.
"""
output_path = self.output_dir / "health-check.json"
spec_info = self.spec_parser.get_info()
health_check = {
"openapi": self.spec_parser.version,
"info": {
"title": f"{spec_info['title']} - Health Check",
"version": spec_info["version"],
"description": "Health check endpoint for the mock server",
},
"paths": {
"/health": {
"get": {
"summary": "Health check endpoint",
"description": "Returns the health status of the mock server",
"operationId": "healthCheck",
"responses": {
"200": {
"description": "Mock server is healthy",
"content": {
"application/json": {
"schema": {
"type": "object",
"properties": {
"status": {
"type": "string",
"enum": ["healthy"],
},
"api": {
"type": "string",
},
"version": {
"type": "string",
},
},
},
}
},
}
},
}
}
},
}
output_path.write_text(json.dumps(health_check, indent=2))
return output_path