fix: add Gitea Actions CI workflow for automated testing
This commit is contained in:
3
api_testgen/__init__.py
Normal file
3
api_testgen/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""API TestGen - OpenAPI Specification Test Generator."""
|
||||
|
||||
__version__ = "0.1.0"
|
||||
1
api_testgen/cli/__init__.py
Normal file
1
api_testgen/cli/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""CLI module for API TestGen."""
|
||||
277
api_testgen/cli/main.py
Normal file
277
api_testgen/cli/main.py
Normal file
@@ -0,0 +1,277 @@
|
||||
"""CLI interface for API TestGen."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import click
|
||||
import yaml
|
||||
|
||||
from ..core import SpecParser, AuthConfig
|
||||
from ..core.exceptions import InvalidOpenAPISpecError, UnsupportedVersionError
|
||||
from ..generators import PytestGenerator, JestGenerator, GoGenerator
|
||||
from ..mocks import MockServerGenerator
|
||||
|
||||
|
||||
@click.group()
|
||||
@click.version_option(version="0.1.0")
|
||||
@click.option(
|
||||
"--spec",
|
||||
"-s",
|
||||
type=click.Path(exists=True, file_okay=True, dir_okay=False),
|
||||
help="Path to OpenAPI specification file",
|
||||
)
|
||||
@click.option(
|
||||
"--output",
|
||||
"-o",
|
||||
type=click.Path(file_okay=False, dir_okay=True),
|
||||
help="Output directory for generated files",
|
||||
)
|
||||
@click.option(
|
||||
"--mock-url",
|
||||
default="http://localhost:4010",
|
||||
help="URL of the mock server",
|
||||
)
|
||||
@click.pass_context
|
||||
def main(
|
||||
ctx: click.Context,
|
||||
spec: str,
|
||||
output: str,
|
||||
mock_url: str,
|
||||
):
|
||||
"""API TestGen - Generate integration tests from OpenAPI specifications."""
|
||||
ctx.ensure_object(dict)
|
||||
ctx.obj["spec"] = spec
|
||||
ctx.obj["output"] = output or "./generated"
|
||||
ctx.obj["mock_url"] = mock_url
|
||||
|
||||
|
||||
@main.command("parse")
|
||||
@click.pass_context
|
||||
def parse_spec(ctx: click.Context):
|
||||
"""Parse and validate an OpenAPI specification."""
|
||||
spec_path = ctx.obj["spec"]
|
||||
|
||||
if not spec_path:
|
||||
click.echo("Error: --spec option is required", err=True)
|
||||
raise click.Abort()
|
||||
|
||||
try:
|
||||
parser = SpecParser(spec_path)
|
||||
spec = parser.load()
|
||||
|
||||
info = parser.get_info()
|
||||
endpoints = parser.get_endpoints()
|
||||
security_schemes = parser.get_security_schemes()
|
||||
|
||||
click.echo(f"API: {info['title']} v{info['version']}")
|
||||
click.echo(f"OpenAPI Version: {parser.version}")
|
||||
click.echo(f"Base Path: {parser.base_path}")
|
||||
click.echo(f"Endpoints: {len(endpoints)}")
|
||||
click.echo(f"Security Schemes: {len(security_schemes)}")
|
||||
click.echo()
|
||||
|
||||
for endpoint in endpoints:
|
||||
click.echo(f" {endpoint['method'].upper():6} {endpoint['path']}")
|
||||
|
||||
ctx.obj["parser"] = parser
|
||||
|
||||
except InvalidOpenAPISpecError as e:
|
||||
click.echo(f"Error: {e}", err=True)
|
||||
raise click.Abort()
|
||||
except UnsupportedVersionError as e:
|
||||
click.echo(f"Error: {e}", err=True)
|
||||
raise click.Abort()
|
||||
|
||||
|
||||
@main.command("generate")
|
||||
@click.argument("framework", type=click.Choice(["pytest", "jest", "go"]))
|
||||
@click.option(
|
||||
"--output-file",
|
||||
"-f",
|
||||
type=click.Path(file_okay=True, dir_okay=False),
|
||||
help="Specific output file path",
|
||||
)
|
||||
@click.option(
|
||||
"--package-name",
|
||||
default="apitest",
|
||||
help="Go package name (only for go framework)",
|
||||
)
|
||||
@click.pass_context
|
||||
def generate_tests(
|
||||
ctx: click.Context,
|
||||
framework: str,
|
||||
output_file: str,
|
||||
package_name: str,
|
||||
):
|
||||
"""Generate test files for a framework (pytest, jest, or go)."""
|
||||
spec_path = ctx.obj["spec"]
|
||||
|
||||
if not spec_path:
|
||||
click.echo("Error: --spec option is required", err=True)
|
||||
raise click.Abort()
|
||||
|
||||
output_dir = ctx.obj["output"]
|
||||
mock_url = ctx.obj["mock_url"]
|
||||
|
||||
try:
|
||||
parser = SpecParser(spec_path)
|
||||
parser.load()
|
||||
|
||||
if framework == "pytest":
|
||||
generator = PytestGenerator(parser, output_dir, mock_url)
|
||||
files = generator.generate(output_file)
|
||||
|
||||
elif framework == "jest":
|
||||
generator = JestGenerator(parser, output_dir, mock_url)
|
||||
files = generator.generate(output_file)
|
||||
|
||||
elif framework == "go":
|
||||
generator = GoGenerator(parser, output_dir, mock_url, package_name)
|
||||
files = generator.generate(output_file)
|
||||
|
||||
click.echo(f"Generated {len(files)} test file(s):")
|
||||
for f in files:
|
||||
click.echo(f" - {f}")
|
||||
|
||||
except Exception as e:
|
||||
click.echo(f"Error: {e}", err=True)
|
||||
raise click.Abort()
|
||||
|
||||
|
||||
@main.command("mock")
|
||||
@click.option(
|
||||
"--no-prism-config",
|
||||
is_flag=True,
|
||||
help="Skip generating prism-config.json",
|
||||
)
|
||||
@click.option(
|
||||
"--no-docker-compose",
|
||||
is_flag=True,
|
||||
help="Skip generating docker-compose.yml",
|
||||
)
|
||||
@click.option(
|
||||
"--no-dockerfile",
|
||||
is_flag=True,
|
||||
help="Skip generating Dockerfile",
|
||||
)
|
||||
@click.pass_context
|
||||
def generate_mock(
|
||||
ctx: click.Context,
|
||||
no_prism_config: bool,
|
||||
no_docker_compose: bool,
|
||||
no_dockerfile: bool,
|
||||
):
|
||||
"""Generate mock server configuration files."""
|
||||
spec_path = ctx.obj["spec"]
|
||||
|
||||
if not spec_path:
|
||||
click.echo("Error: --spec option is required", err=True)
|
||||
raise click.Abort()
|
||||
|
||||
output_dir = ctx.obj["output"]
|
||||
|
||||
try:
|
||||
parser = SpecParser(spec_path)
|
||||
parser.load()
|
||||
|
||||
generator = MockServerGenerator(parser, output_dir)
|
||||
|
||||
files = generator.generate(
|
||||
prism_config=not no_prism_config,
|
||||
docker_compose=not no_docker_compose,
|
||||
dockerfile=not no_dockerfile,
|
||||
)
|
||||
|
||||
click.echo(f"Generated {len(files)} mock server file(s):")
|
||||
for f in files:
|
||||
click.echo(f" - {f}")
|
||||
|
||||
except Exception as e:
|
||||
click.echo(f"Error: {e}", err=True)
|
||||
raise click.Abort()
|
||||
|
||||
|
||||
@main.command("all")
|
||||
@click.argument("framework", type=click.Choice(["pytest", "jest", "go"]))
|
||||
@click.pass_context
|
||||
def generate_all(
|
||||
ctx: click.Context,
|
||||
framework: str,
|
||||
):
|
||||
"""Generate test files and mock server configuration."""
|
||||
spec_path = ctx.obj["spec"]
|
||||
|
||||
if not spec_path:
|
||||
click.echo("Error: --spec option is required", err=True)
|
||||
raise click.Abort()
|
||||
|
||||
output_dir = ctx.obj["output"]
|
||||
mock_url = ctx.obj["mock_url"]
|
||||
|
||||
try:
|
||||
parser = SpecParser(spec_path)
|
||||
parser.load()
|
||||
|
||||
click.echo("Generating tests...")
|
||||
if framework == "pytest":
|
||||
generator = PytestGenerator(parser, output_dir, mock_url)
|
||||
files = generator.generate()
|
||||
|
||||
elif framework == "jest":
|
||||
generator = JestGenerator(parser, output_dir, mock_url)
|
||||
files = generator.generate()
|
||||
|
||||
elif framework == "go":
|
||||
generator = GoGenerator(parser, output_dir, mock_url)
|
||||
files = generator.generate()
|
||||
|
||||
click.echo(f"Generated {len(files)} test file(s)")
|
||||
|
||||
click.echo("Generating mock server configuration...")
|
||||
mock_generator = MockServerGenerator(parser, output_dir)
|
||||
mock_files = mock_generator.generate()
|
||||
|
||||
click.echo(f"Generated {len(mock_files)} mock server file(s)")
|
||||
|
||||
click.echo("\nAll files generated successfully!")
|
||||
|
||||
except Exception as e:
|
||||
click.echo(f"Error: {e}", err=True)
|
||||
raise click.Abort()
|
||||
|
||||
|
||||
@main.command("auth")
|
||||
@click.argument("scheme_name")
|
||||
@click.option("--type", "auth_type", type=click.Choice(["apiKey", "bearer", "basic"]), help="Authentication type")
|
||||
@click.option("--header", help="Header name for API key", default="X-API-Key")
|
||||
@click.option("--token", help="Bearer token or API key value")
|
||||
@click.option("--username", help="Username for Basic auth")
|
||||
@click.option("--password", help="Password for Basic auth")
|
||||
@click.pass_context
|
||||
def configure_auth(
|
||||
ctx: click.Context,
|
||||
scheme_name: str,
|
||||
auth_type: str,
|
||||
header: str,
|
||||
token: str,
|
||||
username: str,
|
||||
password: str,
|
||||
):
|
||||
"""Configure authentication for a security scheme."""
|
||||
auth_config = AuthConfig()
|
||||
|
||||
if auth_type == "apiKey":
|
||||
auth_config.add_api_key(scheme_name, header, token or "")
|
||||
elif auth_type == "bearer":
|
||||
auth_config.add_bearer(scheme_name, token or "")
|
||||
elif auth_type == "basic":
|
||||
auth_config.add_basic(scheme_name, username or "", password or "")
|
||||
|
||||
click.echo(f"Authentication scheme '{scheme_name}' configured:")
|
||||
methods = auth_config.get_all_methods()
|
||||
for name, method in methods.items():
|
||||
click.echo(f" - {name}: {method['type'].value}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
6
api_testgen/core/__init__.py
Normal file
6
api_testgen/core/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""Core module for API TestGen."""
|
||||
|
||||
from .spec_parser import SpecParser
|
||||
from .auth import AuthConfig
|
||||
|
||||
__all__ = ["SpecParser", "AuthConfig"]
|
||||
313
api_testgen/core/auth.py
Normal file
313
api_testgen/core/auth.py
Normal file
@@ -0,0 +1,313 @@
|
||||
"""Authentication configuration for API TestGen."""
|
||||
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from enum import Enum
|
||||
|
||||
from .exceptions import AuthConfigError, MissingSecuritySchemeError
|
||||
|
||||
|
||||
class AuthType(str, Enum):
|
||||
"""Types of authentication."""
|
||||
API_KEY = "apiKey"
|
||||
BEARER = "bearer"
|
||||
BASIC = "basic"
|
||||
NONE = "none"
|
||||
|
||||
|
||||
class AuthConfig:
|
||||
"""Authentication configuration for generated tests."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize authentication configuration."""
|
||||
self._auth_methods: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
def add_api_key(
|
||||
self,
|
||||
scheme_name: str,
|
||||
header_name: str = "X-API-Key",
|
||||
api_key: str = "",
|
||||
) -> "AuthConfig":
|
||||
"""Add API key authentication.
|
||||
|
||||
Args:
|
||||
scheme_name: Name of the security scheme in OpenAPI spec.
|
||||
header_name: Name of the header containing the API key.
|
||||
api_key: The API key value (can be set later).
|
||||
|
||||
Returns:
|
||||
Self for method chaining.
|
||||
"""
|
||||
self._auth_methods[scheme_name] = {
|
||||
"type": AuthType.API_KEY,
|
||||
"header_name": header_name,
|
||||
"api_key": api_key,
|
||||
}
|
||||
return self
|
||||
|
||||
def add_bearer(
|
||||
self,
|
||||
scheme_name: str,
|
||||
token: str = "",
|
||||
token_prefix: str = "Bearer",
|
||||
) -> "AuthConfig":
|
||||
"""Add Bearer token authentication.
|
||||
|
||||
Args:
|
||||
scheme_name: Name of the security scheme in OpenAPI spec.
|
||||
token: The Bearer token value (can be set later).
|
||||
token_prefix: The token prefix (default: Bearer).
|
||||
|
||||
Returns:
|
||||
Self for method chaining.
|
||||
"""
|
||||
self._auth_methods[scheme_name] = {
|
||||
"type": AuthType.BEARER,
|
||||
"token": token,
|
||||
"token_prefix": token_prefix,
|
||||
}
|
||||
return self
|
||||
|
||||
def add_basic(
|
||||
self,
|
||||
scheme_name: str,
|
||||
username: str = "",
|
||||
password: str = "",
|
||||
) -> "AuthConfig":
|
||||
"""Add Basic authentication.
|
||||
|
||||
Args:
|
||||
scheme_name: Name of the security scheme in OpenAPI spec.
|
||||
username: The username (can be set later).
|
||||
password: The password (can be set later).
|
||||
|
||||
Returns:
|
||||
Self for method chaining.
|
||||
"""
|
||||
self._auth_methods[scheme_name] = {
|
||||
"type": AuthType.BASIC,
|
||||
"username": username,
|
||||
"password": password,
|
||||
}
|
||||
return self
|
||||
|
||||
def get_auth_method(self, scheme_name: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get authentication method by scheme name.
|
||||
|
||||
Args:
|
||||
scheme_name: Name of the security scheme.
|
||||
|
||||
Returns:
|
||||
Authentication method configuration or None.
|
||||
"""
|
||||
return self._auth_methods.get(scheme_name)
|
||||
|
||||
def get_all_methods(self) -> Dict[str, Dict[str, Any]]:
|
||||
"""Get all configured authentication methods.
|
||||
|
||||
Returns:
|
||||
Dictionary of scheme names and their configurations.
|
||||
"""
|
||||
return self._auth_methods.copy()
|
||||
|
||||
def get_headers(self, scheme_name: str) -> Dict[str, str]:
|
||||
"""Get authentication headers for a scheme.
|
||||
|
||||
Args:
|
||||
scheme_name: Name of the security scheme.
|
||||
|
||||
Returns:
|
||||
Dictionary of header names and values.
|
||||
|
||||
Raises:
|
||||
AuthConfigError: If scheme is not configured.
|
||||
"""
|
||||
method = self.get_auth_method(scheme_name)
|
||||
if not method:
|
||||
raise AuthConfigError(f"Authentication scheme '{scheme_name}' not configured")
|
||||
|
||||
if method["type"] == AuthType.API_KEY:
|
||||
return {method["header_name"]: method["api_key"]}
|
||||
elif method["type"] == AuthType.BEARER:
|
||||
return {"Authorization": f"{method['token_prefix']} {method['token']}"}
|
||||
elif method["type"] == AuthType.BASIC:
|
||||
import base64
|
||||
credentials = f"{method['username']}:{method['password']}"
|
||||
encoded = base64.b64encode(credentials.encode()).decode()
|
||||
return {"Authorization": f"Basic {encoded}"}
|
||||
return {}
|
||||
|
||||
def build_from_spec(
|
||||
self,
|
||||
security_schemes: Dict[str, Any],
|
||||
security_requirements: List[Dict[str, Any]],
|
||||
) -> "AuthConfig":
|
||||
"""Build auth configuration from OpenAPI security schemes.
|
||||
|
||||
Args:
|
||||
security_schemes: Security schemes from OpenAPI spec.
|
||||
security_requirements: Security requirements from endpoint.
|
||||
|
||||
Returns:
|
||||
Self for method chaining.
|
||||
|
||||
Raises:
|
||||
MissingSecuritySchemeError: If required scheme is not defined.
|
||||
"""
|
||||
for requirement in security_requirements:
|
||||
for scheme_name in requirement.keys():
|
||||
if scheme_name not in self._auth_methods:
|
||||
if scheme_name not in security_schemes:
|
||||
raise MissingSecuritySchemeError(
|
||||
f"Security scheme '{scheme_name}' not found in spec"
|
||||
)
|
||||
scheme = security_schemes[scheme_name]
|
||||
self._add_scheme_from_spec(scheme_name, scheme)
|
||||
|
||||
return self
|
||||
|
||||
def _add_scheme_from_spec(self, scheme_name: str, scheme: Dict[str, Any]) -> None:
|
||||
"""Add authentication scheme from OpenAPI spec definition.
|
||||
|
||||
Args:
|
||||
scheme_name: Name of the security scheme.
|
||||
scheme: The security scheme definition from OpenAPI spec.
|
||||
"""
|
||||
scheme_type = scheme.get("type", "")
|
||||
|
||||
if scheme_type == "apiKey":
|
||||
self.add_api_key(
|
||||
scheme_name,
|
||||
header_name=scheme.get("name", "X-API-Key"),
|
||||
)
|
||||
elif scheme_type == "http":
|
||||
scheme_scheme = scheme.get("scheme", "").lower()
|
||||
if scheme_scheme == "bearer":
|
||||
self.add_bearer(scheme_name)
|
||||
elif scheme_scheme == "basic":
|
||||
self.add_basic(scheme_name)
|
||||
elif scheme_type == "openIdConnect":
|
||||
self.add_bearer(scheme_name)
|
||||
elif scheme_type == "oauth2":
|
||||
self.add_bearer(scheme_name)
|
||||
|
||||
def generate_auth_code(self, scheme_name: str, framework: str = "pytest") -> str:
|
||||
"""Generate authentication code for a test framework.
|
||||
|
||||
Args:
|
||||
scheme_name: Name of the security scheme.
|
||||
framework: Target test framework (pytest, jest, go).
|
||||
|
||||
Returns:
|
||||
String containing authentication code snippet.
|
||||
"""
|
||||
method = self.get_auth_method(scheme_name)
|
||||
if not method:
|
||||
return ""
|
||||
|
||||
if framework == "pytest":
|
||||
return self._generate_pytest_auth(method)
|
||||
elif framework == "jest":
|
||||
return self._generate_jest_auth(method)
|
||||
elif framework == "go":
|
||||
return self._generate_go_auth(method)
|
||||
return ""
|
||||
|
||||
def _generate_pytest_auth(self, method: Dict[str, Any]) -> str:
|
||||
"""Generate pytest authentication code.
|
||||
|
||||
Args:
|
||||
method: Authentication method configuration.
|
||||
|
||||
Returns:
|
||||
String containing pytest auth code.
|
||||
"""
|
||||
if method["type"] == AuthType.API_KEY:
|
||||
return f'''
|
||||
@pytest.fixture
|
||||
def api_key_headers():
|
||||
return {{"{method['header_name']}": "{method['api_key']}"}}
|
||||
'''
|
||||
elif method["type"] == AuthType.BEARER:
|
||||
return f'''
|
||||
@pytest.fixture
|
||||
def bearer_headers():
|
||||
return {{"Authorization": "{method['token_prefix']} {method['token']}"}}
|
||||
'''
|
||||
elif method["type"] == AuthType.BASIC:
|
||||
return f'''
|
||||
import base64
|
||||
|
||||
@pytest.fixture
|
||||
def basic_headers():
|
||||
credentials = f"{{"{method['username']}"}}:{{"{method['password']}"}}"
|
||||
encoded = base64.b64encode(credentials.encode()).decode()
|
||||
return {{"Authorization": f"Basic {{encoded}}"}}
|
||||
'''
|
||||
return ""
|
||||
|
||||
def _generate_jest_auth(self, method: Dict[str, Any]) -> str:
|
||||
"""Generate Jest authentication code.
|
||||
|
||||
Args:
|
||||
method: Authentication method configuration.
|
||||
|
||||
Returns:
|
||||
String containing Jest auth code.
|
||||
"""
|
||||
if method["type"] == AuthType.API_KEY:
|
||||
return f'''
|
||||
const getApiKeyHeaders = () => ({{
|
||||
"{method['header_name']}": process.env.API_KEY || "{method['api_key']}",
|
||||
}});
|
||||
'''
|
||||
elif method["type"] == AuthType.BEARER:
|
||||
return f'''
|
||||
const getBearerHeaders = () => ({{
|
||||
Authorization: `${{process.env.TOKEN_PREFIX || "{method['token_prefix']}"}} ${{process.env.TOKEN || "{method['token']}"}}`,
|
||||
}});
|
||||
'''
|
||||
elif method["type"] == AuthType.BASIC:
|
||||
return f'''
|
||||
const getBasicHeaders = () => {{
|
||||
const credentials = Buffer.from(`${{process.env.USERNAME || "{method['username']}"}}:${{process.env.PASSWORD || "{method['password']}"}}`).toString('base64');
|
||||
return {{ Authorization: `Basic ${{credentials}}` }};
|
||||
}};
|
||||
'''
|
||||
return ""
|
||||
|
||||
def _generate_go_auth(self, method: Dict[str, Any]) -> str:
|
||||
"""Generate Go authentication code.
|
||||
|
||||
Args:
|
||||
method: Authentication method configuration.
|
||||
|
||||
Returns:
|
||||
String containing Go auth code.
|
||||
"""
|
||||
if method["type"] == AuthType.API_KEY:
|
||||
return f'''
|
||||
func getAPIKeyHeaders() map[string]string {{
|
||||
return map[string]string{{
|
||||
"{method['header_name']}": os.Getenv("API_KEY"),
|
||||
}}
|
||||
}}
|
||||
'''
|
||||
elif method["type"] == AuthType.BEARER:
|
||||
return f'''
|
||||
func getBearerHeaders() map[string]string {{
|
||||
return map[string]string{{
|
||||
"Authorization": fmt.Sprintf("%s %s", os.Getenv("TOKEN_PREFIX"), os.Getenv("TOKEN")),
|
||||
}}
|
||||
}}
|
||||
'''
|
||||
elif method["type"] == AuthType.BASIC:
|
||||
return f'''
|
||||
func getBasicHeaders(username, password string) map[string]string {{
|
||||
auth := username + ":" + password
|
||||
encoded := base64.StdEncoding.EncodeToString([]byte(auth))
|
||||
return map[string]string{{
|
||||
"Authorization": "Basic " + encoded,
|
||||
}}
|
||||
}}
|
||||
'''
|
||||
return ""
|
||||
36
api_testgen/core/exceptions.py
Normal file
36
api_testgen/core/exceptions.py
Normal file
@@ -0,0 +1,36 @@
|
||||
"""Custom exceptions for API TestGen."""
|
||||
|
||||
|
||||
class SpecParserError(Exception):
|
||||
"""Base exception for spec parser errors."""
|
||||
pass
|
||||
|
||||
|
||||
class InvalidOpenAPISpecError(SpecParserError):
|
||||
"""Raised when OpenAPI specification is invalid."""
|
||||
pass
|
||||
|
||||
|
||||
class UnsupportedVersionError(SpecParserError):
|
||||
"""Raised when OpenAPI version is not supported."""
|
||||
pass
|
||||
|
||||
|
||||
class AuthConfigError(Exception):
|
||||
"""Base exception for auth configuration errors."""
|
||||
pass
|
||||
|
||||
|
||||
class MissingSecuritySchemeError(AuthConfigError):
|
||||
"""Raised when security scheme is not defined in spec."""
|
||||
pass
|
||||
|
||||
|
||||
class GeneratorError(Exception):
|
||||
"""Base exception for generator errors."""
|
||||
pass
|
||||
|
||||
|
||||
class TemplateRenderError(GeneratorError):
|
||||
"""Raised when template rendering fails."""
|
||||
pass
|
||||
307
api_testgen/core/spec_parser.py
Normal file
307
api_testgen/core/spec_parser.py
Normal file
@@ -0,0 +1,307 @@
|
||||
"""OpenAPI Specification Parser."""
|
||||
|
||||
import yaml
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from openapi_spec_validator import validate
|
||||
from openapi_spec_validator.versions import consts as validator_consts
|
||||
|
||||
from .exceptions import InvalidOpenAPISpecError, UnsupportedVersionError
|
||||
|
||||
|
||||
class SpecParser:
|
||||
"""Parse and validate OpenAPI/Swagger specifications."""
|
||||
|
||||
SUPPORTED_VERSIONS = ["2.0", "3.0.0", "3.0.1", "3.0.2", "3.0.3", "3.1.0"]
|
||||
|
||||
def __init__(self, spec_path: Union[str, Path]):
|
||||
"""Initialize the spec parser.
|
||||
|
||||
Args:
|
||||
spec_path: Path to OpenAPI specification file (YAML or JSON).
|
||||
"""
|
||||
self.spec_path = Path(spec_path)
|
||||
self.spec: Dict[str, Any] = {}
|
||||
self.version: str = ""
|
||||
self.base_path: str = ""
|
||||
self.servers: List[Dict[str, str]] = []
|
||||
|
||||
def load(self) -> Dict[str, Any]:
|
||||
"""Load and validate the OpenAPI specification.
|
||||
|
||||
Returns:
|
||||
The parsed specification dictionary.
|
||||
|
||||
Raises:
|
||||
InvalidOpenAPISpecError: If the specification is invalid.
|
||||
UnsupportedVersionError: If the OpenAPI version is not supported.
|
||||
"""
|
||||
self.spec = self._load_file()
|
||||
self._validate()
|
||||
self._extract_metadata()
|
||||
return self.spec
|
||||
|
||||
def _load_file(self) -> Dict[str, Any]:
|
||||
"""Load the specification file.
|
||||
|
||||
Returns:
|
||||
Parsed specification dictionary.
|
||||
|
||||
Raises:
|
||||
InvalidOpenAPISpecError: If the file cannot be loaded or parsed.
|
||||
"""
|
||||
if not self.spec_path.exists():
|
||||
raise InvalidOpenAPISpecError(f"Specification file not found: {self.spec_path}")
|
||||
|
||||
try:
|
||||
with open(self.spec_path, 'r', encoding='utf-8') as f:
|
||||
if self.spec_path.suffix in ['.yaml', '.yml']:
|
||||
return yaml.safe_load(f) or {}
|
||||
elif self.spec_path.suffix == '.json':
|
||||
return json.load(f)
|
||||
else:
|
||||
return yaml.safe_load(f) or {}
|
||||
except (yaml.YAMLError, json.JSONDecodeError) as e:
|
||||
raise InvalidOpenAPISpecError(f"Failed to parse specification: {e}")
|
||||
|
||||
def _validate(self) -> None:
|
||||
"""Validate the specification.
|
||||
|
||||
Raises:
|
||||
InvalidOpenAPISpecError: If the specification is invalid.
|
||||
UnsupportedVersionError: If the OpenAPI version is not supported.
|
||||
"""
|
||||
try:
|
||||
validate(self.spec)
|
||||
except Exception as e:
|
||||
raise InvalidOpenAPISpecError(f"Invalid OpenAPI specification: {e}")
|
||||
|
||||
version = self._get_version()
|
||||
if version not in self.SUPPORTED_VERSIONS:
|
||||
raise UnsupportedVersionError(
|
||||
f"Unsupported OpenAPI version: {version}. "
|
||||
f"Supported versions: {', '.join(self.SUPPORTED_VERSIONS)}"
|
||||
)
|
||||
|
||||
def _get_version(self) -> str:
|
||||
"""Extract the OpenAPI version from the spec.
|
||||
|
||||
Returns:
|
||||
The OpenAPI version string.
|
||||
"""
|
||||
if "openapi" in self.spec:
|
||||
return self.spec["openapi"]
|
||||
elif "swagger" in self.spec:
|
||||
return self.spec["swagger"]
|
||||
return "2.0"
|
||||
|
||||
def _extract_metadata(self) -> None:
|
||||
"""Extract metadata from the specification."""
|
||||
self.version = self._get_version()
|
||||
self.base_path = self.spec.get("basePath", "")
|
||||
self.servers = self.spec.get("servers", [])
|
||||
|
||||
def get_paths(self) -> Dict[str, Any]:
|
||||
"""Get all paths from the specification.
|
||||
|
||||
Returns:
|
||||
Dictionary of paths and their operations.
|
||||
"""
|
||||
return self.spec.get("paths", {})
|
||||
|
||||
def get_endpoints(self) -> List[Dict[str, Any]]:
|
||||
"""Extract all endpoints from the specification.
|
||||
|
||||
Returns:
|
||||
List of endpoint dictionaries with path, method, and details.
|
||||
"""
|
||||
endpoints = []
|
||||
paths = self.get_paths()
|
||||
|
||||
for path, path_item in paths.items():
|
||||
for method, operation in path_item.items():
|
||||
if method.lower() in ["get", "post", "put", "patch", "delete", "options", "head"]:
|
||||
endpoint = {
|
||||
"path": path,
|
||||
"method": method.lower(),
|
||||
"operation_id": operation.get("operationId", ""),
|
||||
"summary": operation.get("summary", ""),
|
||||
"description": operation.get("description", ""),
|
||||
"tags": operation.get("tags", []),
|
||||
"parameters": self._extract_parameters(path_item, operation),
|
||||
"request_body": self._extract_request_body(operation),
|
||||
"responses": self._extract_responses(operation),
|
||||
"security": self._extract_security(operation),
|
||||
}
|
||||
endpoints.append(endpoint)
|
||||
|
||||
return endpoints
|
||||
|
||||
def _extract_parameters(self, path_item: Dict, operation: Dict) -> List[Dict[str, Any]]:
|
||||
"""Extract parameters from path item and operation.
|
||||
|
||||
Args:
|
||||
path_item: The path item dictionary.
|
||||
operation: The operation dictionary.
|
||||
|
||||
Returns:
|
||||
List of parameter dictionaries.
|
||||
"""
|
||||
parameters = []
|
||||
|
||||
for param in path_item.get("parameters", []):
|
||||
if param.get("in") != "body":
|
||||
parameters.append(self._normalize_parameter(param))
|
||||
|
||||
for param in operation.get("parameters", []):
|
||||
if param.get("in") != "body":
|
||||
parameters.append(self._normalize_parameter(param))
|
||||
|
||||
return parameters
|
||||
|
||||
def _normalize_parameter(self, param: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Normalize a parameter.
|
||||
|
||||
Args:
|
||||
param: The parameter dictionary.
|
||||
|
||||
Returns:
|
||||
Normalized parameter dictionary.
|
||||
"""
|
||||
return {
|
||||
"name": param.get("name", ""),
|
||||
"in": param.get("in", ""),
|
||||
"description": param.get("description", ""),
|
||||
"required": param.get("required", False),
|
||||
"schema": param.get("schema", {}),
|
||||
"type": param.get("type", ""),
|
||||
"enum": param.get("enum", []),
|
||||
"default": param.get("default"),
|
||||
}
|
||||
|
||||
def _extract_request_body(self, operation: Dict) -> Optional[Dict[str, Any]]:
|
||||
"""Extract request body from operation.
|
||||
|
||||
Args:
|
||||
operation: The operation dictionary.
|
||||
|
||||
Returns:
|
||||
Request body dictionary or None.
|
||||
"""
|
||||
if self.version.startswith("3."):
|
||||
request_body = operation.get("requestBody", {})
|
||||
if not request_body:
|
||||
return None
|
||||
|
||||
content = request_body.get("content", {})
|
||||
media_types = list(content.keys())
|
||||
|
||||
return {
|
||||
"description": request_body.get("description", ""),
|
||||
"required": request_body.get("required", False),
|
||||
"media_types": media_types,
|
||||
"schema": content.get(media_types[0], {}).get("schema", {}) if media_types else {},
|
||||
}
|
||||
else:
|
||||
params = operation.get("parameters", [])
|
||||
for param in params:
|
||||
if param.get("in") == "body":
|
||||
return {
|
||||
"description": param.get("description", ""),
|
||||
"required": param.get("required", False),
|
||||
"schema": param.get("schema", {}),
|
||||
}
|
||||
return None
|
||||
|
||||
def _extract_responses(self, operation: Dict) -> Dict[str, Any]:
|
||||
"""Extract responses from operation.
|
||||
|
||||
Args:
|
||||
operation: The operation dictionary.
|
||||
|
||||
Returns:
|
||||
Dictionary of response status codes and their details.
|
||||
"""
|
||||
responses = {}
|
||||
|
||||
for status_code, response in operation.get("responses", {}).items():
|
||||
content = response.get("content", {})
|
||||
|
||||
if self.version.startswith("3."):
|
||||
media_types = list(content.keys())
|
||||
schema = content.get(media_types[0], {}).get("schema", {}) if media_types else {}
|
||||
else:
|
||||
schema = response.get("schema", {})
|
||||
|
||||
responses[status_code] = {
|
||||
"description": response.get("description", ""),
|
||||
"schema": schema,
|
||||
"media_type": list(content.keys())[0] if content else "application/json",
|
||||
}
|
||||
|
||||
return responses
|
||||
|
||||
def _extract_security(self, operation: Dict) -> List[Dict[str, Any]]:
|
||||
"""Extract security requirements from operation.
|
||||
|
||||
Args:
|
||||
operation: The operation dictionary.
|
||||
|
||||
Returns:
|
||||
List of security requirement dictionaries.
|
||||
"""
|
||||
return operation.get("security", self.spec.get("security", []))
|
||||
|
||||
def get_security_schemes(self) -> Dict[str, Any]:
|
||||
"""Get security schemes from the specification.
|
||||
|
||||
Returns:
|
||||
Dictionary of security scheme names and their definitions.
|
||||
"""
|
||||
if self.version.startswith("3."):
|
||||
return self.spec.get("components", {}).get("securitySchemes", {})
|
||||
else:
|
||||
return self.spec.get("securityDefinitions", {})
|
||||
|
||||
def get_definitions(self) -> Dict[str, Any]:
|
||||
"""Get schema definitions from the specification.
|
||||
|
||||
Returns:
|
||||
Dictionary of schema definitions.
|
||||
"""
|
||||
if self.version.startswith("3."):
|
||||
return self.spec.get("components", {}).get("schemas", {})
|
||||
else:
|
||||
return self.spec.get("definitions", {})
|
||||
|
||||
def get_info(self) -> Dict[str, str]:
|
||||
"""Get API info from the specification.
|
||||
|
||||
Returns:
|
||||
Dictionary with title, version, and description.
|
||||
"""
|
||||
info = self.spec.get("info", {})
|
||||
return {
|
||||
"title": info.get("title", "API"),
|
||||
"version": info.get("version", "1.0.0"),
|
||||
"description": info.get("description", ""),
|
||||
}
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert the spec to a dictionary.
|
||||
|
||||
Returns:
|
||||
Dictionary representation of the parsed spec.
|
||||
"""
|
||||
return {
|
||||
"version": self.version,
|
||||
"base_path": self.base_path,
|
||||
"servers": self.servers,
|
||||
"info": self.get_info(),
|
||||
"paths": self.get_paths(),
|
||||
"endpoints": self.get_endpoints(),
|
||||
"security_schemes": self.get_security_schemes(),
|
||||
"definitions": self.get_definitions(),
|
||||
}
|
||||
7
api_testgen/generators/__init__.py
Normal file
7
api_testgen/generators/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""Generators module for API TestGen."""
|
||||
|
||||
from .pytest import PytestGenerator
|
||||
from .jest import JestGenerator
|
||||
from .go import GoGenerator
|
||||
|
||||
__all__ = ["PytestGenerator", "JestGenerator", "GoGenerator"]
|
||||
229
api_testgen/generators/go.py
Normal file
229
api_testgen/generators/go.py
Normal file
@@ -0,0 +1,229 @@
|
||||
"""Go test generator."""
|
||||
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from jinja2 import Environment, FileSystemLoader, TemplateSyntaxError, UndefinedError
|
||||
|
||||
from ..core import SpecParser, AuthConfig
|
||||
from ..core.exceptions import GeneratorError, TemplateRenderError
|
||||
|
||||
|
||||
class GoGenerator:
|
||||
"""Generate Go-compatible test files."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
spec_parser: SpecParser,
|
||||
output_dir: str = "tests",
|
||||
mock_server_url: str = "http://localhost:4010",
|
||||
package_name: str = "apitest",
|
||||
):
|
||||
"""Initialize the go generator.
|
||||
|
||||
Args:
|
||||
spec_parser: The OpenAPI specification parser.
|
||||
output_dir: Directory for generated test files.
|
||||
mock_server_url: URL of the mock server for testing.
|
||||
package_name: Go package name for test files.
|
||||
"""
|
||||
self.spec_parser = spec_parser
|
||||
self.output_dir = Path(output_dir)
|
||||
self.mock_server_url = mock_server_url
|
||||
self.package_name = package_name
|
||||
self.env = Environment(
|
||||
loader=FileSystemLoader(str(Path(__file__).parent.parent.parent / "templates" / "go")),
|
||||
trim_blocks=True,
|
||||
lstrip_blocks=True,
|
||||
)
|
||||
|
||||
def generate(self, output_file: Optional[str] = None) -> List[Path]:
|
||||
"""Generate Go test files.
|
||||
|
||||
Args:
|
||||
output_file: Optional specific output file path.
|
||||
|
||||
Returns:
|
||||
List of generated file paths.
|
||||
"""
|
||||
self.output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
endpoints = self.spec_parser.get_endpoints()
|
||||
info = self.spec_parser.get_info()
|
||||
|
||||
grouped_endpoints = self._group_endpoints_by_path(endpoints)
|
||||
|
||||
context = {
|
||||
"package_name": self.package_name,
|
||||
"api_title": info["title"],
|
||||
"api_version": info["version"],
|
||||
"endpoints": endpoints,
|
||||
"grouped_endpoints": grouped_endpoints,
|
||||
"mock_server_url": self.mock_server_url,
|
||||
"security_schemes": self.spec_parser.get_security_schemes(),
|
||||
"definitions": self.spec_parser.get_definitions(),
|
||||
}
|
||||
|
||||
generated_files = []
|
||||
|
||||
try:
|
||||
template = self.env.get_template("api_test.go.j2")
|
||||
content = template.render(context)
|
||||
|
||||
if output_file:
|
||||
output_path = Path(output_file)
|
||||
else:
|
||||
safe_name = re.sub(r"[^a-zA-Z0-9_]", "_", info["title"].lower())
|
||||
output_path = self.output_dir / f"{safe_name}_test.go"
|
||||
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
output_path.write_text(content)
|
||||
generated_files.append(output_path)
|
||||
|
||||
except (TemplateSyntaxError, UndefinedError) as e:
|
||||
raise TemplateRenderError(f"Failed to render Go template: {e}")
|
||||
|
||||
return generated_files
|
||||
|
||||
def _group_endpoints_by_path(self, endpoints: List[Dict[str, Any]]) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""Group endpoints by their path.
|
||||
|
||||
Args:
|
||||
endpoints: List of endpoint dictionaries.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping paths to their endpoints.
|
||||
"""
|
||||
grouped = {}
|
||||
for endpoint in endpoints:
|
||||
path = endpoint["path"]
|
||||
if path not in grouped:
|
||||
grouped[path] = []
|
||||
grouped[path].append(endpoint)
|
||||
return grouped
|
||||
|
||||
def generate_endpoint_tests(self, endpoint_path: str, method: str) -> str:
|
||||
"""Generate test for a specific endpoint.
|
||||
|
||||
Args:
|
||||
endpoint_path: The API endpoint path.
|
||||
method: The HTTP method.
|
||||
|
||||
Returns:
|
||||
String containing the test code.
|
||||
"""
|
||||
endpoints = self.spec_parser.get_endpoints()
|
||||
|
||||
for endpoint in endpoints:
|
||||
if endpoint["path"] == endpoint_path and endpoint["method"] == method.lower():
|
||||
return self._generate_single_test(endpoint)
|
||||
|
||||
return ""
|
||||
|
||||
def _generate_single_test(self, endpoint: Dict[str, Any]) -> str:
|
||||
"""Generate test code for a single endpoint.
|
||||
|
||||
Args:
|
||||
endpoint: The endpoint dictionary.
|
||||
|
||||
Returns:
|
||||
String containing the test code.
|
||||
"""
|
||||
test_name = self._generate_test_name(endpoint)
|
||||
params = self._generate_params(endpoint)
|
||||
url_params = self._generate_url_params(endpoint)
|
||||
|
||||
test_code = f'''
|
||||
func Test{test_name}(t *testing.T) {{
|
||||
client := &http.Client{{Timeout: 10 * time.Second}}
|
||||
url := baseURL + "{endpoint['path']}"
|
||||
|
||||
var req *http.Request
|
||||
var err error
|
||||
|
||||
{params}
|
||||
|
||||
req, err = http.NewRequest("{endpoint['method'].upper()}", url, nil)
|
||||
if err != nil {{
|
||||
t.Fatal(err)
|
||||
}}
|
||||
|
||||
for k, v := range getAuthHeaders() {{
|
||||
req.Header.Set(k, v)
|
||||
}}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {{
|
||||
t.Fatalf("Request failed: %v", err)
|
||||
}}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if !contains([]int{{200, 201, 204}}, resp.StatusCode) {{
|
||||
t.Errorf("Expected status code in [200, 201, 204], got %d", resp.StatusCode)
|
||||
}}
|
||||
}}
|
||||
'''
|
||||
return test_code
|
||||
|
||||
def _generate_test_name(self, endpoint: Dict[str, Any]) -> str:
|
||||
"""Generate a valid Go test function name.
|
||||
|
||||
Args:
|
||||
endpoint: The endpoint dictionary.
|
||||
|
||||
Returns:
|
||||
A valid Go function name.
|
||||
"""
|
||||
path = endpoint["path"]
|
||||
method = endpoint["method"]
|
||||
|
||||
name = re.sub(r"[^a-zA-Z0-9]", "_", path.strip("/"))
|
||||
name = re.sub(r"_+/", "_", name)
|
||||
name = re.sub(r"_+$", "", name)
|
||||
name = name.title().replace("_", "")
|
||||
|
||||
return f"{method.capitalize()}{name}" if name else f"{method.capitalize()}Default"
|
||||
|
||||
def _generate_params(self, endpoint: Dict[str, Any]) -> str:
|
||||
"""Generate parameter variables for test.
|
||||
|
||||
Args:
|
||||
endpoint: The endpoint dictionary.
|
||||
|
||||
Returns:
|
||||
String containing parameter declarations.
|
||||
"""
|
||||
params = []
|
||||
|
||||
for param in endpoint.get("parameters", []):
|
||||
param_name = param["name"]
|
||||
|
||||
if param["in"] == "path":
|
||||
params.append(f'{param_name} := "test_{param_name}"')
|
||||
params.append(f'url = strings.Replace(url, "{{'+param_name+'}}", {param_name}, 1)')
|
||||
|
||||
elif param["in"] == "query":
|
||||
params.append(f'q := url.Values{{{param_name}: []string{{"test"}}}}')
|
||||
params.append(f'url += "?" + q.Encode()')
|
||||
|
||||
return "\n ".join(params) if params else ""
|
||||
|
||||
def _generate_url_params(self, endpoint: Dict[str, Any]) -> str:
|
||||
"""Generate URL parameters.
|
||||
|
||||
Args:
|
||||
endpoint: The endpoint dictionary.
|
||||
|
||||
Returns:
|
||||
String containing URL parameter handling.
|
||||
"""
|
||||
path_params = [p for p in endpoint.get("parameters", []) if p["in"] == "path"]
|
||||
query_params = [p for p in endpoint.get("parameters", []) if p["in"] == "query"]
|
||||
|
||||
parts = []
|
||||
|
||||
for param in path_params:
|
||||
parts.append(f'strings.Replace(url, "{{'+param['name']+'}}", "test_' + param['name'] + '", 1)')
|
||||
|
||||
return ""
|
||||
169
api_testgen/generators/jest.py
Normal file
169
api_testgen/generators/jest.py
Normal file
@@ -0,0 +1,169 @@
|
||||
"""Jest test generator."""
|
||||
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from jinja2 import Environment, FileSystemLoader, TemplateSyntaxError, UndefinedError
|
||||
|
||||
from ..core import SpecParser, AuthConfig
|
||||
from ..core.exceptions import GeneratorError, TemplateRenderError
|
||||
|
||||
|
||||
class JestGenerator:
|
||||
"""Generate Jest-compatible integration test templates."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
spec_parser: SpecParser,
|
||||
output_dir: str = "tests",
|
||||
mock_server_url: str = "http://localhost:4010",
|
||||
):
|
||||
"""Initialize the jest generator.
|
||||
|
||||
Args:
|
||||
spec_parser: The OpenAPI specification parser.
|
||||
output_dir: Directory for generated test files.
|
||||
mock_server_url: URL of the mock server for testing.
|
||||
"""
|
||||
self.spec_parser = spec_parser
|
||||
self.output_dir = Path(output_dir)
|
||||
self.mock_server_url = mock_server_url
|
||||
self.env = Environment(
|
||||
loader=FileSystemLoader(str(Path(__file__).parent.parent.parent / "templates" / "jest")),
|
||||
trim_blocks=True,
|
||||
lstrip_blocks=True,
|
||||
)
|
||||
|
||||
def generate(self, output_file: Optional[str] = None) -> List[Path]:
|
||||
"""Generate Jest test files.
|
||||
|
||||
Args:
|
||||
output_file: Optional specific output file path.
|
||||
|
||||
Returns:
|
||||
List of generated file paths.
|
||||
"""
|
||||
self.output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
endpoints = self.spec_parser.get_endpoints()
|
||||
info = self.spec_parser.get_info()
|
||||
|
||||
context = {
|
||||
"api_title": info["title"],
|
||||
"api_version": info["version"],
|
||||
"endpoints": endpoints,
|
||||
"mock_server_url": self.mock_server_url,
|
||||
"security_schemes": self.spec_parser.get_security_schemes(),
|
||||
"definitions": self.spec_parser.get_definitions(),
|
||||
}
|
||||
|
||||
generated_files = []
|
||||
|
||||
try:
|
||||
template = self.env.get_template("api.test.js.j2")
|
||||
content = template.render(context)
|
||||
|
||||
if output_file:
|
||||
output_path = Path(output_file)
|
||||
else:
|
||||
safe_name = re.sub(r"[^a-zA-Z0-9_]", "_", info["title"].lower())
|
||||
output_path = self.output_dir / f"{safe_name}.test.js"
|
||||
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
output_path.write_text(content)
|
||||
generated_files.append(output_path)
|
||||
|
||||
except (TemplateSyntaxError, UndefinedError) as e:
|
||||
raise TemplateRenderError(f"Failed to render Jest template: {e}")
|
||||
|
||||
return generated_files
|
||||
|
||||
def generate_endpoint_tests(self, endpoint_path: str, method: str) -> str:
|
||||
"""Generate test for a specific endpoint.
|
||||
|
||||
Args:
|
||||
endpoint_path: The API endpoint path.
|
||||
method: The HTTP method.
|
||||
|
||||
Returns:
|
||||
String containing the test code.
|
||||
"""
|
||||
endpoints = self.spec_parser.get_endpoints()
|
||||
|
||||
for endpoint in endpoints:
|
||||
if endpoint["path"] == endpoint_path and endpoint["method"] == method.lower():
|
||||
return self._generate_single_test(endpoint)
|
||||
|
||||
return ""
|
||||
|
||||
def _generate_single_test(self, endpoint: Dict[str, Any]) -> str:
|
||||
"""Generate test code for a single endpoint.
|
||||
|
||||
Args:
|
||||
endpoint: The endpoint dictionary.
|
||||
|
||||
Returns:
|
||||
String containing the test code.
|
||||
"""
|
||||
test_name = self._generate_test_name(endpoint)
|
||||
describe_name = endpoint["summary"] or endpoint["path"]
|
||||
|
||||
params = self._generate_params(endpoint)
|
||||
|
||||
endpoint_path = endpoint["path"]
|
||||
endpoint_method = endpoint["method"]
|
||||
|
||||
test_code = f'''
|
||||
describe('{describe_name}', () => {{
|
||||
it('should {endpoint_method.upper()} {endpoint_path}', async () => {{
|
||||
const response = await request(baseUrl)
|
||||
.{endpoint_method}('{endpoint_path}'{params});
|
||||
|
||||
expect([200, 201, 204]).toContain(response.status);
|
||||
}});
|
||||
}});
|
||||
'''
|
||||
return test_code
|
||||
|
||||
def _generate_test_name(self, endpoint: Dict[str, Any]) -> str:
|
||||
"""Generate a valid test function name.
|
||||
|
||||
Args:
|
||||
endpoint: The endpoint dictionary.
|
||||
|
||||
Returns:
|
||||
A valid JavaScript function name.
|
||||
"""
|
||||
path = endpoint["path"]
|
||||
method = endpoint["method"]
|
||||
|
||||
name = re.sub(r"[^a-zA-Z0-9]", "_", path.strip("/"))
|
||||
name = re.sub(r"_+/", "_", name)
|
||||
name = re.sub(r"_+$", "", name)
|
||||
|
||||
return f"{method}_{name}" if name else f"{method}_default"
|
||||
|
||||
def _generate_params(self, endpoint: Dict[str, Any]) -> str:
|
||||
"""Generate parameters for test request.
|
||||
|
||||
Args:
|
||||
endpoint: The endpoint dictionary.
|
||||
|
||||
Returns:
|
||||
String containing parameter chain.
|
||||
"""
|
||||
parts = []
|
||||
|
||||
for param in endpoint.get("parameters", []):
|
||||
param_name = param["name"]
|
||||
|
||||
if param["in"] == "path":
|
||||
parts.append(f'{param_name}="test_{param_name}"')
|
||||
|
||||
elif param["in"] == "query":
|
||||
parts.append(f'{param_name}')
|
||||
|
||||
if parts:
|
||||
return ", {" + ", ".join(parts) + "}"
|
||||
return ""
|
||||
199
api_testgen/generators/pytest.py
Normal file
199
api_testgen/generators/pytest.py
Normal file
@@ -0,0 +1,199 @@
|
||||
"""Pytest test generator."""
|
||||
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from jinja2 import Environment, FileSystemLoader, TemplateSyntaxError, UndefinedError
|
||||
|
||||
from ..core import SpecParser, AuthConfig
|
||||
from ..core.exceptions import GeneratorError, TemplateRenderError
|
||||
|
||||
|
||||
class PytestGenerator:
|
||||
"""Generate pytest-compatible integration test templates."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
spec_parser: SpecParser,
|
||||
output_dir: str = "tests",
|
||||
mock_server_url: str = "http://localhost:4010",
|
||||
):
|
||||
"""Initialize the pytest generator.
|
||||
|
||||
Args:
|
||||
spec_parser: The OpenAPI specification parser.
|
||||
output_dir: Directory for generated test files.
|
||||
mock_server_url: URL of the mock server for testing.
|
||||
"""
|
||||
self.spec_parser = spec_parser
|
||||
self.output_dir = Path(output_dir)
|
||||
self.mock_server_url = mock_server_url
|
||||
self.env = Environment(
|
||||
loader=FileSystemLoader(str(Path(__file__).parent.parent.parent / "templates" / "pytest")),
|
||||
trim_blocks=True,
|
||||
lstrip_blocks=True,
|
||||
)
|
||||
|
||||
def generate(self, output_file: Optional[str] = None) -> List[Path]:
|
||||
"""Generate pytest test files.
|
||||
|
||||
Args:
|
||||
output_file: Optional specific output file path.
|
||||
|
||||
Returns:
|
||||
List of generated file paths.
|
||||
"""
|
||||
self.output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
endpoints = self.spec_parser.get_endpoints()
|
||||
info = self.spec_parser.get_info()
|
||||
|
||||
context = {
|
||||
"api_title": info["title"],
|
||||
"api_version": info["version"],
|
||||
"endpoints": endpoints,
|
||||
"mock_server_url": self.mock_server_url,
|
||||
"security_schemes": self.spec_parser.get_security_schemes(),
|
||||
"definitions": self.spec_parser.get_definitions(),
|
||||
}
|
||||
|
||||
generated_files = []
|
||||
|
||||
try:
|
||||
template = self.env.get_template("test_base.py.j2")
|
||||
content = template.render(context)
|
||||
|
||||
if output_file:
|
||||
output_path = Path(output_file)
|
||||
else:
|
||||
safe_name = re.sub(r"[^a-zA-Z0-9_]", "_", info["title"].lower())
|
||||
output_path = self.output_dir / f"test_{safe_name}.py"
|
||||
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
output_path.write_text(content)
|
||||
generated_files.append(output_path)
|
||||
|
||||
except (TemplateSyntaxError, UndefinedError) as e:
|
||||
raise TemplateRenderError(f"Failed to render pytest template: {e}")
|
||||
|
||||
return generated_files
|
||||
|
||||
def generate_endpoint_tests(self, endpoint_path: str, method: str) -> str:
|
||||
"""Generate test for a specific endpoint.
|
||||
|
||||
Args:
|
||||
endpoint_path: The API endpoint path.
|
||||
method: The HTTP method.
|
||||
|
||||
Returns:
|
||||
String containing the test code.
|
||||
"""
|
||||
endpoints = self.spec_parser.get_endpoints()
|
||||
|
||||
for endpoint in endpoints:
|
||||
if endpoint["path"] == endpoint_path and endpoint["method"] == method.lower():
|
||||
return self._generate_single_test(endpoint)
|
||||
|
||||
return ""
|
||||
|
||||
def _generate_single_test(self, endpoint: Dict[str, Any]) -> str:
|
||||
"""Generate test code for a single endpoint.
|
||||
|
||||
Args:
|
||||
endpoint: The endpoint dictionary.
|
||||
|
||||
Returns:
|
||||
String containing the test code.
|
||||
"""
|
||||
test_name = self._generate_test_name(endpoint)
|
||||
params = self._generate_parameters(endpoint)
|
||||
auth_headers = self._generate_auth_headers(endpoint)
|
||||
|
||||
test_code = f'''
|
||||
def test_{test_name}(base_url, {params}):
|
||||
"""Test {endpoint["summary"] or endpoint["path"]} endpoint."""
|
||||
url = f"{{base_url}}{endpoint['path']}"
|
||||
|
||||
headers = {{"Content-Type": "application/json"}}
|
||||
{auth_headers}
|
||||
|
||||
response = requests.{endpoint['method']}(url, json={{}} if method == "POST" else None, headers=headers)
|
||||
|
||||
assert response.status_code in [200, 201, 204]
|
||||
'''
|
||||
return test_code
|
||||
|
||||
def _generate_test_name(self, endpoint: Dict[str, Any]) -> str:
|
||||
"""Generate a valid test function name.
|
||||
|
||||
Args:
|
||||
endpoint: The endpoint dictionary.
|
||||
|
||||
Returns:
|
||||
A valid Python function name.
|
||||
"""
|
||||
path = endpoint["path"]
|
||||
method = endpoint["method"]
|
||||
|
||||
name = re.sub(r"[^a-zA-Z0-9]", "_", path.strip("/"))
|
||||
name = re.sub(r"_+/", "_", name)
|
||||
name = re.sub(r"_+$", "", name)
|
||||
|
||||
return f"{method}_{name}" if name else f"{method}_default"
|
||||
|
||||
def _generate_parameters(self, endpoint: Dict[str, Any]) -> str:
|
||||
"""Generate parameters for test function.
|
||||
|
||||
Args:
|
||||
endpoint: The endpoint dictionary.
|
||||
|
||||
Returns:
|
||||
String containing parameter declarations.
|
||||
"""
|
||||
params = []
|
||||
|
||||
for param in endpoint.get("parameters", []):
|
||||
param_name = param["name"]
|
||||
|
||||
if param["in"] == "path":
|
||||
params.append(f'{param_name}="test_{param_name}"')
|
||||
|
||||
elif param["in"] == "query":
|
||||
params.append(f'{param_name}=None')
|
||||
|
||||
return ", ".join(params)
|
||||
|
||||
def _generate_auth_headers(self, endpoint: Dict[str, Any]) -> str:
|
||||
"""Generate authentication headers for endpoint.
|
||||
|
||||
Args:
|
||||
endpoint: The endpoint dictionary.
|
||||
|
||||
Returns:
|
||||
String containing header assignments.
|
||||
"""
|
||||
security_requirements = endpoint.get("security", [])
|
||||
|
||||
if not security_requirements:
|
||||
return ""
|
||||
|
||||
schemes = self.spec_parser.get_security_schemes()
|
||||
auth_config = AuthConfig()
|
||||
|
||||
try:
|
||||
auth_config.build_from_spec(schemes, security_requirements)
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
headers = []
|
||||
|
||||
for scheme_name in security_requirements[0].keys():
|
||||
method = auth_config.get_auth_method(scheme_name)
|
||||
if method:
|
||||
if method["type"] == "apiKey":
|
||||
headers.append(f'headers["{method["header_name"]}"] = "test_api_key"')
|
||||
elif method["type"] == "bearer":
|
||||
headers.append('headers["Authorization"] = "Bearer test_token"')
|
||||
|
||||
return "\n ".join(headers) if headers else ""
|
||||
5
api_testgen/mocks/__init__.py
Normal file
5
api_testgen/mocks/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Mocks module for API TestGen."""
|
||||
|
||||
from .generator import MockServerGenerator
|
||||
|
||||
__all__ = ["MockServerGenerator"]
|
||||
278
api_testgen/mocks/generator.py
Normal file
278
api_testgen/mocks/generator.py
Normal file
@@ -0,0 +1,278 @@
|
||||
"""Mock server generator for Prism/OpenAPI mock support."""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from ..core import SpecParser
|
||||
|
||||
|
||||
class MockServerGenerator:
|
||||
"""Generate Prism/OpenAPI mock server configurations."""
|
||||
|
||||
DEFAULT_PORT = 4010
|
||||
DEFAULT_HOST = "0.0.0.0"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
spec_parser: SpecParser,
|
||||
output_dir: str = ".",
|
||||
):
|
||||
"""Initialize the mock server generator.
|
||||
|
||||
Args:
|
||||
spec_parser: The OpenAPI specification parser.
|
||||
output_dir: Directory for generated configuration files.
|
||||
"""
|
||||
self.spec_parser = spec_parser
|
||||
self.output_dir = Path(output_dir)
|
||||
|
||||
def generate(
|
||||
self,
|
||||
prism_config: bool = True,
|
||||
docker_compose: bool = True,
|
||||
dockerfile: bool = True,
|
||||
) -> List[Path]:
|
||||
"""Generate mock server configuration files.
|
||||
|
||||
Args:
|
||||
prism_config: Whether to generate prism-config.json.
|
||||
docker_compose: Whether to generate docker-compose.yml.
|
||||
dockerfile: Whether to generate Dockerfile.
|
||||
|
||||
Returns:
|
||||
List of generated file paths.
|
||||
"""
|
||||
self.output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
generated_files = []
|
||||
|
||||
if prism_config:
|
||||
generated_files.append(self.generate_prism_config())
|
||||
|
||||
if docker_compose:
|
||||
generated_files.append(self.generate_docker_compose())
|
||||
|
||||
if dockerfile:
|
||||
generated_files.append(self.generate_dockerfile())
|
||||
|
||||
return generated_files
|
||||
|
||||
def generate_prism_config(self, output_file: Optional[str] = None) -> Path:
|
||||
"""Generate Prism mock server configuration.
|
||||
|
||||
Args:
|
||||
output_file: Optional output file path.
|
||||
|
||||
Returns:
|
||||
Path to the generated file.
|
||||
"""
|
||||
if output_file:
|
||||
output_path = Path(output_file)
|
||||
else:
|
||||
output_path = self.output_dir / "prism-config.json"
|
||||
|
||||
config = {
|
||||
"mock": {
|
||||
"host": self.DEFAULT_HOST,
|
||||
"port": self.DEFAULT_PORT,
|
||||
"cors": {
|
||||
"enabled": True,
|
||||
"allowOrigin": "*",
|
||||
"allowMethods": ["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS", "HEAD"],
|
||||
"allowHeaders": ["Content-Type", "Authorization", "X-API-Key"],
|
||||
},
|
||||
"Validation": {
|
||||
"request": True,
|
||||
"response": True,
|
||||
},
|
||||
},
|
||||
"logging": {
|
||||
"level": "info",
|
||||
"format": "json",
|
||||
},
|
||||
}
|
||||
|
||||
output_path.write_text(json.dumps(config, indent=2))
|
||||
return output_path
|
||||
|
||||
def generate_docker_compose(self, output_file: Optional[str] = None) -> Path:
|
||||
"""Generate Docker Compose configuration for mock server.
|
||||
|
||||
Args:
|
||||
output_file: Optional output file path.
|
||||
|
||||
Returns:
|
||||
Path to the generated file.
|
||||
"""
|
||||
if output_file:
|
||||
output_path = Path(output_file)
|
||||
else:
|
||||
output_path = self.output_dir / "docker-compose.yml"
|
||||
|
||||
spec_info = self.spec_parser.get_info()
|
||||
|
||||
compose_content = f'''version: '3.8'
|
||||
|
||||
services:
|
||||
mock-server:
|
||||
image: stoplight/prism:latest
|
||||
container_name: "{spec_info['title']}-mock-server"
|
||||
command: >
|
||||
mock
|
||||
--spec /app/openapi.yaml
|
||||
--port {self.DEFAULT_PORT}
|
||||
--host {self.DEFAULT_HOST}
|
||||
ports:
|
||||
- "{self.DEFAULT_PORT}:{self.DEFAULT_PORT}"
|
||||
volumes:
|
||||
- ./:/app
|
||||
restart: unless-stopped
|
||||
healthcheck:
|
||||
test: ["CMD", "wget", "-q", "--spider", f"http://localhost:{self.DEFAULT_PORT}/health"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
|
||||
mock-server-https:
|
||||
image: stoplight/prism:latest
|
||||
container_name: "{spec_info['title']}-mock-server-https"
|
||||
command: >
|
||||
mock
|
||||
--spec /app/openapi.yaml
|
||||
--port {self.DEFAULT_PORT + 1}
|
||||
--host {self.DEFAULT_HOST}
|
||||
ports:
|
||||
- "{self.DEFAULT_PORT + 1}:{self.DEFAULT_PORT + 1}"
|
||||
volumes:
|
||||
- ./:/app
|
||||
restart: unless-stopped
|
||||
'''
|
||||
|
||||
output_path.write_text(compose_content)
|
||||
return output_path
|
||||
|
||||
def generate_dockerfile(self, output_file: Optional[str] = None) -> Path:
|
||||
"""Generate Dockerfile for mock server.
|
||||
|
||||
Args:
|
||||
output_file: Optional output file path.
|
||||
|
||||
Returns:
|
||||
Path to the generated file.
|
||||
"""
|
||||
if output_file:
|
||||
output_path = Path(output_file)
|
||||
else:
|
||||
output_path = self.output_dir / "Dockerfile"
|
||||
|
||||
spec_info = self.spec_parser.get_info()
|
||||
|
||||
dockerfile_content = f'''FROM stoplight/prism:latest
|
||||
|
||||
LABEL maintainer="developer@example.com"
|
||||
LABEL description="Mock server for {spec_info['title']} API"
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
COPY openapi.yaml .
|
||||
|
||||
EXPOSE {self.DEFAULT_PORT}
|
||||
|
||||
CMD ["mock", "--spec", "openapi.yaml", "--port", "{self.DEFAULT_PORT}", "--host", "0.0.0.0"]
|
||||
'''
|
||||
|
||||
output_path.write_text(dockerfile_content)
|
||||
return output_path
|
||||
|
||||
def generate_start_script(self, output_file: Optional[str] = None) -> Path:
|
||||
"""Generate shell script to start mock server.
|
||||
|
||||
Args:
|
||||
output_file: Optional output file path.
|
||||
|
||||
Returns:
|
||||
Path to the generated file.
|
||||
"""
|
||||
if output_file:
|
||||
output_path = Path(output_file)
|
||||
else:
|
||||
output_path = self.output_dir / "start-mock-server.sh"
|
||||
|
||||
script_content = f'''#!/bin/bash
|
||||
|
||||
set -e
|
||||
|
||||
SCRIPT_DIR="$(cd "$(dirname "${{BASH_SOURCE[0]}}")" && pwd)"
|
||||
cd "$SCRIPT_DIR"
|
||||
|
||||
PORT={self.DEFAULT_PORT}
|
||||
HOST="0.0.0.0"
|
||||
|
||||
echo "Starting mock server for API..."
|
||||
echo "Mock server will be available at: http://localhost:$PORT"
|
||||
|
||||
docker compose up -d mock-server
|
||||
|
||||
echo "Mock server started successfully!"
|
||||
echo "To stop: docker compose down"
|
||||
'''
|
||||
|
||||
output_path.write_text(script_content)
|
||||
output_path.chmod(0o755)
|
||||
return output_path
|
||||
|
||||
def generate_health_check(self) -> Path:
|
||||
"""Generate health check endpoint configuration.
|
||||
|
||||
Returns:
|
||||
Path to the generated file.
|
||||
"""
|
||||
output_path = self.output_dir / "health-check.json"
|
||||
|
||||
spec_info = self.spec_parser.get_info()
|
||||
|
||||
health_check = {
|
||||
"openapi": self.spec_parser.version,
|
||||
"info": {
|
||||
"title": f"{spec_info['title']} - Health Check",
|
||||
"version": spec_info["version"],
|
||||
"description": "Health check endpoint for the mock server",
|
||||
},
|
||||
"paths": {
|
||||
"/health": {
|
||||
"get": {
|
||||
"summary": "Health check endpoint",
|
||||
"description": "Returns the health status of the mock server",
|
||||
"operationId": "healthCheck",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Mock server is healthy",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"status": {
|
||||
"type": "string",
|
||||
"enum": ["healthy"],
|
||||
},
|
||||
"api": {
|
||||
"type": "string",
|
||||
},
|
||||
"version": {
|
||||
"type": "string",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
output_path.write_text(json.dumps(health_check, indent=2))
|
||||
return output_path
|
||||
Reference in New Issue
Block a user