fix: add Gitea Actions CI workflow for automated testing

This commit is contained in:
CI Bot
2026-02-06 06:37:08 +00:00
parent 40a6a4f7d4
commit 839317c44b
24 changed files with 3115 additions and 0 deletions

62
.gitea/workflows/ci.yml Normal file
View File

@@ -0,0 +1,62 @@
name: CI
on:
push:
branches: [main, master]
pull_request:
branches: [main, master]
jobs:
test:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.9", "3.10", "3.11"]
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -e .
- name: Run unit tests
run: pytest tests/unit/ -v
- name: Run integration tests
run: pytest tests/integration/ -v
- name: Run tests with coverage
run: python -m pytest tests/ --cov=api_testgen
- name: Upload coverage report
uses: codecov/codecov-action@v4
with:
files: ./coverage.xml
fail_ci_if_error: false
lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.11"
- name: Install linting tools
run: pip install ruff mypy
- name: Run ruff
run: ruff check .
- name: Run mypy
run: mypy api_testgen/ --ignore-missing-imports

3
api_testgen/__init__.py Normal file
View File

@@ -0,0 +1,3 @@
"""API TestGen - OpenAPI Specification Test Generator."""
__version__ = "0.1.0"

View File

@@ -0,0 +1 @@
"""CLI module for API TestGen."""

277
api_testgen/cli/main.py Normal file
View File

@@ -0,0 +1,277 @@
"""CLI interface for API TestGen."""
from pathlib import Path
from typing import Optional
import click
import yaml
from ..core import SpecParser, AuthConfig
from ..core.exceptions import InvalidOpenAPISpecError, UnsupportedVersionError
from ..generators import PytestGenerator, JestGenerator, GoGenerator
from ..mocks import MockServerGenerator
@click.group()
@click.version_option(version="0.1.0")
@click.option(
"--spec",
"-s",
type=click.Path(exists=True, file_okay=True, dir_okay=False),
help="Path to OpenAPI specification file",
)
@click.option(
"--output",
"-o",
type=click.Path(file_okay=False, dir_okay=True),
help="Output directory for generated files",
)
@click.option(
"--mock-url",
default="http://localhost:4010",
help="URL of the mock server",
)
@click.pass_context
def main(
ctx: click.Context,
spec: str,
output: str,
mock_url: str,
):
"""API TestGen - Generate integration tests from OpenAPI specifications."""
ctx.ensure_object(dict)
ctx.obj["spec"] = spec
ctx.obj["output"] = output or "./generated"
ctx.obj["mock_url"] = mock_url
@main.command("parse")
@click.pass_context
def parse_spec(ctx: click.Context):
"""Parse and validate an OpenAPI specification."""
spec_path = ctx.obj["spec"]
if not spec_path:
click.echo("Error: --spec option is required", err=True)
raise click.Abort()
try:
parser = SpecParser(spec_path)
spec = parser.load()
info = parser.get_info()
endpoints = parser.get_endpoints()
security_schemes = parser.get_security_schemes()
click.echo(f"API: {info['title']} v{info['version']}")
click.echo(f"OpenAPI Version: {parser.version}")
click.echo(f"Base Path: {parser.base_path}")
click.echo(f"Endpoints: {len(endpoints)}")
click.echo(f"Security Schemes: {len(security_schemes)}")
click.echo()
for endpoint in endpoints:
click.echo(f" {endpoint['method'].upper():6} {endpoint['path']}")
ctx.obj["parser"] = parser
except InvalidOpenAPISpecError as e:
click.echo(f"Error: {e}", err=True)
raise click.Abort()
except UnsupportedVersionError as e:
click.echo(f"Error: {e}", err=True)
raise click.Abort()
@main.command("generate")
@click.argument("framework", type=click.Choice(["pytest", "jest", "go"]))
@click.option(
"--output-file",
"-f",
type=click.Path(file_okay=True, dir_okay=False),
help="Specific output file path",
)
@click.option(
"--package-name",
default="apitest",
help="Go package name (only for go framework)",
)
@click.pass_context
def generate_tests(
ctx: click.Context,
framework: str,
output_file: str,
package_name: str,
):
"""Generate test files for a framework (pytest, jest, or go)."""
spec_path = ctx.obj["spec"]
if not spec_path:
click.echo("Error: --spec option is required", err=True)
raise click.Abort()
output_dir = ctx.obj["output"]
mock_url = ctx.obj["mock_url"]
try:
parser = SpecParser(spec_path)
parser.load()
if framework == "pytest":
generator = PytestGenerator(parser, output_dir, mock_url)
files = generator.generate(output_file)
elif framework == "jest":
generator = JestGenerator(parser, output_dir, mock_url)
files = generator.generate(output_file)
elif framework == "go":
generator = GoGenerator(parser, output_dir, mock_url, package_name)
files = generator.generate(output_file)
click.echo(f"Generated {len(files)} test file(s):")
for f in files:
click.echo(f" - {f}")
except Exception as e:
click.echo(f"Error: {e}", err=True)
raise click.Abort()
@main.command("mock")
@click.option(
"--no-prism-config",
is_flag=True,
help="Skip generating prism-config.json",
)
@click.option(
"--no-docker-compose",
is_flag=True,
help="Skip generating docker-compose.yml",
)
@click.option(
"--no-dockerfile",
is_flag=True,
help="Skip generating Dockerfile",
)
@click.pass_context
def generate_mock(
ctx: click.Context,
no_prism_config: bool,
no_docker_compose: bool,
no_dockerfile: bool,
):
"""Generate mock server configuration files."""
spec_path = ctx.obj["spec"]
if not spec_path:
click.echo("Error: --spec option is required", err=True)
raise click.Abort()
output_dir = ctx.obj["output"]
try:
parser = SpecParser(spec_path)
parser.load()
generator = MockServerGenerator(parser, output_dir)
files = generator.generate(
prism_config=not no_prism_config,
docker_compose=not no_docker_compose,
dockerfile=not no_dockerfile,
)
click.echo(f"Generated {len(files)} mock server file(s):")
for f in files:
click.echo(f" - {f}")
except Exception as e:
click.echo(f"Error: {e}", err=True)
raise click.Abort()
@main.command("all")
@click.argument("framework", type=click.Choice(["pytest", "jest", "go"]))
@click.pass_context
def generate_all(
ctx: click.Context,
framework: str,
):
"""Generate test files and mock server configuration."""
spec_path = ctx.obj["spec"]
if not spec_path:
click.echo("Error: --spec option is required", err=True)
raise click.Abort()
output_dir = ctx.obj["output"]
mock_url = ctx.obj["mock_url"]
try:
parser = SpecParser(spec_path)
parser.load()
click.echo("Generating tests...")
if framework == "pytest":
generator = PytestGenerator(parser, output_dir, mock_url)
files = generator.generate()
elif framework == "jest":
generator = JestGenerator(parser, output_dir, mock_url)
files = generator.generate()
elif framework == "go":
generator = GoGenerator(parser, output_dir, mock_url)
files = generator.generate()
click.echo(f"Generated {len(files)} test file(s)")
click.echo("Generating mock server configuration...")
mock_generator = MockServerGenerator(parser, output_dir)
mock_files = mock_generator.generate()
click.echo(f"Generated {len(mock_files)} mock server file(s)")
click.echo("\nAll files generated successfully!")
except Exception as e:
click.echo(f"Error: {e}", err=True)
raise click.Abort()
@main.command("auth")
@click.argument("scheme_name")
@click.option("--type", "auth_type", type=click.Choice(["apiKey", "bearer", "basic"]), help="Authentication type")
@click.option("--header", help="Header name for API key", default="X-API-Key")
@click.option("--token", help="Bearer token or API key value")
@click.option("--username", help="Username for Basic auth")
@click.option("--password", help="Password for Basic auth")
@click.pass_context
def configure_auth(
ctx: click.Context,
scheme_name: str,
auth_type: str,
header: str,
token: str,
username: str,
password: str,
):
"""Configure authentication for a security scheme."""
auth_config = AuthConfig()
if auth_type == "apiKey":
auth_config.add_api_key(scheme_name, header, token or "")
elif auth_type == "bearer":
auth_config.add_bearer(scheme_name, token or "")
elif auth_type == "basic":
auth_config.add_basic(scheme_name, username or "", password or "")
click.echo(f"Authentication scheme '{scheme_name}' configured:")
methods = auth_config.get_all_methods()
for name, method in methods.items():
click.echo(f" - {name}: {method['type'].value}")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,6 @@
"""Core module for API TestGen."""
from .spec_parser import SpecParser
from .auth import AuthConfig
__all__ = ["SpecParser", "AuthConfig"]

313
api_testgen/core/auth.py Normal file
View File

@@ -0,0 +1,313 @@
"""Authentication configuration for API TestGen."""
from typing import Any, Dict, List, Optional, Union
from enum import Enum
from .exceptions import AuthConfigError, MissingSecuritySchemeError
class AuthType(str, Enum):
"""Types of authentication."""
API_KEY = "apiKey"
BEARER = "bearer"
BASIC = "basic"
NONE = "none"
class AuthConfig:
"""Authentication configuration for generated tests."""
def __init__(self):
"""Initialize authentication configuration."""
self._auth_methods: Dict[str, Dict[str, Any]] = {}
def add_api_key(
self,
scheme_name: str,
header_name: str = "X-API-Key",
api_key: str = "",
) -> "AuthConfig":
"""Add API key authentication.
Args:
scheme_name: Name of the security scheme in OpenAPI spec.
header_name: Name of the header containing the API key.
api_key: The API key value (can be set later).
Returns:
Self for method chaining.
"""
self._auth_methods[scheme_name] = {
"type": AuthType.API_KEY,
"header_name": header_name,
"api_key": api_key,
}
return self
def add_bearer(
self,
scheme_name: str,
token: str = "",
token_prefix: str = "Bearer",
) -> "AuthConfig":
"""Add Bearer token authentication.
Args:
scheme_name: Name of the security scheme in OpenAPI spec.
token: The Bearer token value (can be set later).
token_prefix: The token prefix (default: Bearer).
Returns:
Self for method chaining.
"""
self._auth_methods[scheme_name] = {
"type": AuthType.BEARER,
"token": token,
"token_prefix": token_prefix,
}
return self
def add_basic(
self,
scheme_name: str,
username: str = "",
password: str = "",
) -> "AuthConfig":
"""Add Basic authentication.
Args:
scheme_name: Name of the security scheme in OpenAPI spec.
username: The username (can be set later).
password: The password (can be set later).
Returns:
Self for method chaining.
"""
self._auth_methods[scheme_name] = {
"type": AuthType.BASIC,
"username": username,
"password": password,
}
return self
def get_auth_method(self, scheme_name: str) -> Optional[Dict[str, Any]]:
"""Get authentication method by scheme name.
Args:
scheme_name: Name of the security scheme.
Returns:
Authentication method configuration or None.
"""
return self._auth_methods.get(scheme_name)
def get_all_methods(self) -> Dict[str, Dict[str, Any]]:
"""Get all configured authentication methods.
Returns:
Dictionary of scheme names and their configurations.
"""
return self._auth_methods.copy()
def get_headers(self, scheme_name: str) -> Dict[str, str]:
"""Get authentication headers for a scheme.
Args:
scheme_name: Name of the security scheme.
Returns:
Dictionary of header names and values.
Raises:
AuthConfigError: If scheme is not configured.
"""
method = self.get_auth_method(scheme_name)
if not method:
raise AuthConfigError(f"Authentication scheme '{scheme_name}' not configured")
if method["type"] == AuthType.API_KEY:
return {method["header_name"]: method["api_key"]}
elif method["type"] == AuthType.BEARER:
return {"Authorization": f"{method['token_prefix']} {method['token']}"}
elif method["type"] == AuthType.BASIC:
import base64
credentials = f"{method['username']}:{method['password']}"
encoded = base64.b64encode(credentials.encode()).decode()
return {"Authorization": f"Basic {encoded}"}
return {}
def build_from_spec(
self,
security_schemes: Dict[str, Any],
security_requirements: List[Dict[str, Any]],
) -> "AuthConfig":
"""Build auth configuration from OpenAPI security schemes.
Args:
security_schemes: Security schemes from OpenAPI spec.
security_requirements: Security requirements from endpoint.
Returns:
Self for method chaining.
Raises:
MissingSecuritySchemeError: If required scheme is not defined.
"""
for requirement in security_requirements:
for scheme_name in requirement.keys():
if scheme_name not in self._auth_methods:
if scheme_name not in security_schemes:
raise MissingSecuritySchemeError(
f"Security scheme '{scheme_name}' not found in spec"
)
scheme = security_schemes[scheme_name]
self._add_scheme_from_spec(scheme_name, scheme)
return self
def _add_scheme_from_spec(self, scheme_name: str, scheme: Dict[str, Any]) -> None:
"""Add authentication scheme from OpenAPI spec definition.
Args:
scheme_name: Name of the security scheme.
scheme: The security scheme definition from OpenAPI spec.
"""
scheme_type = scheme.get("type", "")
if scheme_type == "apiKey":
self.add_api_key(
scheme_name,
header_name=scheme.get("name", "X-API-Key"),
)
elif scheme_type == "http":
scheme_scheme = scheme.get("scheme", "").lower()
if scheme_scheme == "bearer":
self.add_bearer(scheme_name)
elif scheme_scheme == "basic":
self.add_basic(scheme_name)
elif scheme_type == "openIdConnect":
self.add_bearer(scheme_name)
elif scheme_type == "oauth2":
self.add_bearer(scheme_name)
def generate_auth_code(self, scheme_name: str, framework: str = "pytest") -> str:
"""Generate authentication code for a test framework.
Args:
scheme_name: Name of the security scheme.
framework: Target test framework (pytest, jest, go).
Returns:
String containing authentication code snippet.
"""
method = self.get_auth_method(scheme_name)
if not method:
return ""
if framework == "pytest":
return self._generate_pytest_auth(method)
elif framework == "jest":
return self._generate_jest_auth(method)
elif framework == "go":
return self._generate_go_auth(method)
return ""
def _generate_pytest_auth(self, method: Dict[str, Any]) -> str:
"""Generate pytest authentication code.
Args:
method: Authentication method configuration.
Returns:
String containing pytest auth code.
"""
if method["type"] == AuthType.API_KEY:
return f'''
@pytest.fixture
def api_key_headers():
return {{"{method['header_name']}": "{method['api_key']}"}}
'''
elif method["type"] == AuthType.BEARER:
return f'''
@pytest.fixture
def bearer_headers():
return {{"Authorization": "{method['token_prefix']} {method['token']}"}}
'''
elif method["type"] == AuthType.BASIC:
return f'''
import base64
@pytest.fixture
def basic_headers():
credentials = f"{{"{method['username']}"}}:{{"{method['password']}"}}"
encoded = base64.b64encode(credentials.encode()).decode()
return {{"Authorization": f"Basic {{encoded}}"}}
'''
return ""
def _generate_jest_auth(self, method: Dict[str, Any]) -> str:
"""Generate Jest authentication code.
Args:
method: Authentication method configuration.
Returns:
String containing Jest auth code.
"""
if method["type"] == AuthType.API_KEY:
return f'''
const getApiKeyHeaders = () => ({{
"{method['header_name']}": process.env.API_KEY || "{method['api_key']}",
}});
'''
elif method["type"] == AuthType.BEARER:
return f'''
const getBearerHeaders = () => ({{
Authorization: `${{process.env.TOKEN_PREFIX || "{method['token_prefix']}"}} ${{process.env.TOKEN || "{method['token']}"}}`,
}});
'''
elif method["type"] == AuthType.BASIC:
return f'''
const getBasicHeaders = () => {{
const credentials = Buffer.from(`${{process.env.USERNAME || "{method['username']}"}}:${{process.env.PASSWORD || "{method['password']}"}}`).toString('base64');
return {{ Authorization: `Basic ${{credentials}}` }};
}};
'''
return ""
def _generate_go_auth(self, method: Dict[str, Any]) -> str:
"""Generate Go authentication code.
Args:
method: Authentication method configuration.
Returns:
String containing Go auth code.
"""
if method["type"] == AuthType.API_KEY:
return f'''
func getAPIKeyHeaders() map[string]string {{
return map[string]string{{
"{method['header_name']}": os.Getenv("API_KEY"),
}}
}}
'''
elif method["type"] == AuthType.BEARER:
return f'''
func getBearerHeaders() map[string]string {{
return map[string]string{{
"Authorization": fmt.Sprintf("%s %s", os.Getenv("TOKEN_PREFIX"), os.Getenv("TOKEN")),
}}
}}
'''
elif method["type"] == AuthType.BASIC:
return f'''
func getBasicHeaders(username, password string) map[string]string {{
auth := username + ":" + password
encoded := base64.StdEncoding.EncodeToString([]byte(auth))
return map[string]string{{
"Authorization": "Basic " + encoded,
}}
}}
'''
return ""

View File

@@ -0,0 +1,36 @@
"""Custom exceptions for API TestGen."""
class SpecParserError(Exception):
"""Base exception for spec parser errors."""
pass
class InvalidOpenAPISpecError(SpecParserError):
"""Raised when OpenAPI specification is invalid."""
pass
class UnsupportedVersionError(SpecParserError):
"""Raised when OpenAPI version is not supported."""
pass
class AuthConfigError(Exception):
"""Base exception for auth configuration errors."""
pass
class MissingSecuritySchemeError(AuthConfigError):
"""Raised when security scheme is not defined in spec."""
pass
class GeneratorError(Exception):
"""Base exception for generator errors."""
pass
class TemplateRenderError(GeneratorError):
"""Raised when template rendering fails."""
pass

View File

@@ -0,0 +1,307 @@
"""OpenAPI Specification Parser."""
import yaml
import json
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from openapi_spec_validator import validate
from openapi_spec_validator.versions import consts as validator_consts
from .exceptions import InvalidOpenAPISpecError, UnsupportedVersionError
class SpecParser:
"""Parse and validate OpenAPI/Swagger specifications."""
SUPPORTED_VERSIONS = ["2.0", "3.0.0", "3.0.1", "3.0.2", "3.0.3", "3.1.0"]
def __init__(self, spec_path: Union[str, Path]):
"""Initialize the spec parser.
Args:
spec_path: Path to OpenAPI specification file (YAML or JSON).
"""
self.spec_path = Path(spec_path)
self.spec: Dict[str, Any] = {}
self.version: str = ""
self.base_path: str = ""
self.servers: List[Dict[str, str]] = []
def load(self) -> Dict[str, Any]:
"""Load and validate the OpenAPI specification.
Returns:
The parsed specification dictionary.
Raises:
InvalidOpenAPISpecError: If the specification is invalid.
UnsupportedVersionError: If the OpenAPI version is not supported.
"""
self.spec = self._load_file()
self._validate()
self._extract_metadata()
return self.spec
def _load_file(self) -> Dict[str, Any]:
"""Load the specification file.
Returns:
Parsed specification dictionary.
Raises:
InvalidOpenAPISpecError: If the file cannot be loaded or parsed.
"""
if not self.spec_path.exists():
raise InvalidOpenAPISpecError(f"Specification file not found: {self.spec_path}")
try:
with open(self.spec_path, 'r', encoding='utf-8') as f:
if self.spec_path.suffix in ['.yaml', '.yml']:
return yaml.safe_load(f) or {}
elif self.spec_path.suffix == '.json':
return json.load(f)
else:
return yaml.safe_load(f) or {}
except (yaml.YAMLError, json.JSONDecodeError) as e:
raise InvalidOpenAPISpecError(f"Failed to parse specification: {e}")
def _validate(self) -> None:
"""Validate the specification.
Raises:
InvalidOpenAPISpecError: If the specification is invalid.
UnsupportedVersionError: If the OpenAPI version is not supported.
"""
try:
validate(self.spec)
except Exception as e:
raise InvalidOpenAPISpecError(f"Invalid OpenAPI specification: {e}")
version = self._get_version()
if version not in self.SUPPORTED_VERSIONS:
raise UnsupportedVersionError(
f"Unsupported OpenAPI version: {version}. "
f"Supported versions: {', '.join(self.SUPPORTED_VERSIONS)}"
)
def _get_version(self) -> str:
"""Extract the OpenAPI version from the spec.
Returns:
The OpenAPI version string.
"""
if "openapi" in self.spec:
return self.spec["openapi"]
elif "swagger" in self.spec:
return self.spec["swagger"]
return "2.0"
def _extract_metadata(self) -> None:
"""Extract metadata from the specification."""
self.version = self._get_version()
self.base_path = self.spec.get("basePath", "")
self.servers = self.spec.get("servers", [])
def get_paths(self) -> Dict[str, Any]:
"""Get all paths from the specification.
Returns:
Dictionary of paths and their operations.
"""
return self.spec.get("paths", {})
def get_endpoints(self) -> List[Dict[str, Any]]:
"""Extract all endpoints from the specification.
Returns:
List of endpoint dictionaries with path, method, and details.
"""
endpoints = []
paths = self.get_paths()
for path, path_item in paths.items():
for method, operation in path_item.items():
if method.lower() in ["get", "post", "put", "patch", "delete", "options", "head"]:
endpoint = {
"path": path,
"method": method.lower(),
"operation_id": operation.get("operationId", ""),
"summary": operation.get("summary", ""),
"description": operation.get("description", ""),
"tags": operation.get("tags", []),
"parameters": self._extract_parameters(path_item, operation),
"request_body": self._extract_request_body(operation),
"responses": self._extract_responses(operation),
"security": self._extract_security(operation),
}
endpoints.append(endpoint)
return endpoints
def _extract_parameters(self, path_item: Dict, operation: Dict) -> List[Dict[str, Any]]:
"""Extract parameters from path item and operation.
Args:
path_item: The path item dictionary.
operation: The operation dictionary.
Returns:
List of parameter dictionaries.
"""
parameters = []
for param in path_item.get("parameters", []):
if param.get("in") != "body":
parameters.append(self._normalize_parameter(param))
for param in operation.get("parameters", []):
if param.get("in") != "body":
parameters.append(self._normalize_parameter(param))
return parameters
def _normalize_parameter(self, param: Dict[str, Any]) -> Dict[str, Any]:
"""Normalize a parameter.
Args:
param: The parameter dictionary.
Returns:
Normalized parameter dictionary.
"""
return {
"name": param.get("name", ""),
"in": param.get("in", ""),
"description": param.get("description", ""),
"required": param.get("required", False),
"schema": param.get("schema", {}),
"type": param.get("type", ""),
"enum": param.get("enum", []),
"default": param.get("default"),
}
def _extract_request_body(self, operation: Dict) -> Optional[Dict[str, Any]]:
"""Extract request body from operation.
Args:
operation: The operation dictionary.
Returns:
Request body dictionary or None.
"""
if self.version.startswith("3."):
request_body = operation.get("requestBody", {})
if not request_body:
return None
content = request_body.get("content", {})
media_types = list(content.keys())
return {
"description": request_body.get("description", ""),
"required": request_body.get("required", False),
"media_types": media_types,
"schema": content.get(media_types[0], {}).get("schema", {}) if media_types else {},
}
else:
params = operation.get("parameters", [])
for param in params:
if param.get("in") == "body":
return {
"description": param.get("description", ""),
"required": param.get("required", False),
"schema": param.get("schema", {}),
}
return None
def _extract_responses(self, operation: Dict) -> Dict[str, Any]:
"""Extract responses from operation.
Args:
operation: The operation dictionary.
Returns:
Dictionary of response status codes and their details.
"""
responses = {}
for status_code, response in operation.get("responses", {}).items():
content = response.get("content", {})
if self.version.startswith("3."):
media_types = list(content.keys())
schema = content.get(media_types[0], {}).get("schema", {}) if media_types else {}
else:
schema = response.get("schema", {})
responses[status_code] = {
"description": response.get("description", ""),
"schema": schema,
"media_type": list(content.keys())[0] if content else "application/json",
}
return responses
def _extract_security(self, operation: Dict) -> List[Dict[str, Any]]:
"""Extract security requirements from operation.
Args:
operation: The operation dictionary.
Returns:
List of security requirement dictionaries.
"""
return operation.get("security", self.spec.get("security", []))
def get_security_schemes(self) -> Dict[str, Any]:
"""Get security schemes from the specification.
Returns:
Dictionary of security scheme names and their definitions.
"""
if self.version.startswith("3."):
return self.spec.get("components", {}).get("securitySchemes", {})
else:
return self.spec.get("securityDefinitions", {})
def get_definitions(self) -> Dict[str, Any]:
"""Get schema definitions from the specification.
Returns:
Dictionary of schema definitions.
"""
if self.version.startswith("3."):
return self.spec.get("components", {}).get("schemas", {})
else:
return self.spec.get("definitions", {})
def get_info(self) -> Dict[str, str]:
"""Get API info from the specification.
Returns:
Dictionary with title, version, and description.
"""
info = self.spec.get("info", {})
return {
"title": info.get("title", "API"),
"version": info.get("version", "1.0.0"),
"description": info.get("description", ""),
}
def to_dict(self) -> Dict[str, Any]:
"""Convert the spec to a dictionary.
Returns:
Dictionary representation of the parsed spec.
"""
return {
"version": self.version,
"base_path": self.base_path,
"servers": self.servers,
"info": self.get_info(),
"paths": self.get_paths(),
"endpoints": self.get_endpoints(),
"security_schemes": self.get_security_schemes(),
"definitions": self.get_definitions(),
}

View File

@@ -0,0 +1,7 @@
"""Generators module for API TestGen."""
from .pytest import PytestGenerator
from .jest import JestGenerator
from .go import GoGenerator
__all__ = ["PytestGenerator", "JestGenerator", "GoGenerator"]

View File

@@ -0,0 +1,229 @@
"""Go test generator."""
import re
from pathlib import Path
from typing import Any, Dict, List, Optional
from jinja2 import Environment, FileSystemLoader, TemplateSyntaxError, UndefinedError
from ..core import SpecParser, AuthConfig
from ..core.exceptions import GeneratorError, TemplateRenderError
class GoGenerator:
"""Generate Go-compatible test files."""
def __init__(
self,
spec_parser: SpecParser,
output_dir: str = "tests",
mock_server_url: str = "http://localhost:4010",
package_name: str = "apitest",
):
"""Initialize the go generator.
Args:
spec_parser: The OpenAPI specification parser.
output_dir: Directory for generated test files.
mock_server_url: URL of the mock server for testing.
package_name: Go package name for test files.
"""
self.spec_parser = spec_parser
self.output_dir = Path(output_dir)
self.mock_server_url = mock_server_url
self.package_name = package_name
self.env = Environment(
loader=FileSystemLoader(str(Path(__file__).parent.parent.parent / "templates" / "go")),
trim_blocks=True,
lstrip_blocks=True,
)
def generate(self, output_file: Optional[str] = None) -> List[Path]:
"""Generate Go test files.
Args:
output_file: Optional specific output file path.
Returns:
List of generated file paths.
"""
self.output_dir.mkdir(parents=True, exist_ok=True)
endpoints = self.spec_parser.get_endpoints()
info = self.spec_parser.get_info()
grouped_endpoints = self._group_endpoints_by_path(endpoints)
context = {
"package_name": self.package_name,
"api_title": info["title"],
"api_version": info["version"],
"endpoints": endpoints,
"grouped_endpoints": grouped_endpoints,
"mock_server_url": self.mock_server_url,
"security_schemes": self.spec_parser.get_security_schemes(),
"definitions": self.spec_parser.get_definitions(),
}
generated_files = []
try:
template = self.env.get_template("api_test.go.j2")
content = template.render(context)
if output_file:
output_path = Path(output_file)
else:
safe_name = re.sub(r"[^a-zA-Z0-9_]", "_", info["title"].lower())
output_path = self.output_dir / f"{safe_name}_test.go"
output_path.parent.mkdir(parents=True, exist_ok=True)
output_path.write_text(content)
generated_files.append(output_path)
except (TemplateSyntaxError, UndefinedError) as e:
raise TemplateRenderError(f"Failed to render Go template: {e}")
return generated_files
def _group_endpoints_by_path(self, endpoints: List[Dict[str, Any]]) -> Dict[str, List[Dict[str, Any]]]:
"""Group endpoints by their path.
Args:
endpoints: List of endpoint dictionaries.
Returns:
Dictionary mapping paths to their endpoints.
"""
grouped = {}
for endpoint in endpoints:
path = endpoint["path"]
if path not in grouped:
grouped[path] = []
grouped[path].append(endpoint)
return grouped
def generate_endpoint_tests(self, endpoint_path: str, method: str) -> str:
"""Generate test for a specific endpoint.
Args:
endpoint_path: The API endpoint path.
method: The HTTP method.
Returns:
String containing the test code.
"""
endpoints = self.spec_parser.get_endpoints()
for endpoint in endpoints:
if endpoint["path"] == endpoint_path and endpoint["method"] == method.lower():
return self._generate_single_test(endpoint)
return ""
def _generate_single_test(self, endpoint: Dict[str, Any]) -> str:
"""Generate test code for a single endpoint.
Args:
endpoint: The endpoint dictionary.
Returns:
String containing the test code.
"""
test_name = self._generate_test_name(endpoint)
params = self._generate_params(endpoint)
url_params = self._generate_url_params(endpoint)
test_code = f'''
func Test{test_name}(t *testing.T) {{
client := &http.Client{{Timeout: 10 * time.Second}}
url := baseURL + "{endpoint['path']}"
var req *http.Request
var err error
{params}
req, err = http.NewRequest("{endpoint['method'].upper()}", url, nil)
if err != nil {{
t.Fatal(err)
}}
for k, v := range getAuthHeaders() {{
req.Header.Set(k, v)
}}
resp, err := client.Do(req)
if err != nil {{
t.Fatalf("Request failed: %v", err)
}}
defer resp.Body.Close()
if !contains([]int{{200, 201, 204}}, resp.StatusCode) {{
t.Errorf("Expected status code in [200, 201, 204], got %d", resp.StatusCode)
}}
}}
'''
return test_code
def _generate_test_name(self, endpoint: Dict[str, Any]) -> str:
"""Generate a valid Go test function name.
Args:
endpoint: The endpoint dictionary.
Returns:
A valid Go function name.
"""
path = endpoint["path"]
method = endpoint["method"]
name = re.sub(r"[^a-zA-Z0-9]", "_", path.strip("/"))
name = re.sub(r"_+/", "_", name)
name = re.sub(r"_+$", "", name)
name = name.title().replace("_", "")
return f"{method.capitalize()}{name}" if name else f"{method.capitalize()}Default"
def _generate_params(self, endpoint: Dict[str, Any]) -> str:
"""Generate parameter variables for test.
Args:
endpoint: The endpoint dictionary.
Returns:
String containing parameter declarations.
"""
params = []
for param in endpoint.get("parameters", []):
param_name = param["name"]
if param["in"] == "path":
params.append(f'{param_name} := "test_{param_name}"')
params.append(f'url = strings.Replace(url, "{{'+param_name+'}}", {param_name}, 1)')
elif param["in"] == "query":
params.append(f'q := url.Values{{{param_name}: []string{{"test"}}}}')
params.append(f'url += "?" + q.Encode()')
return "\n ".join(params) if params else ""
def _generate_url_params(self, endpoint: Dict[str, Any]) -> str:
"""Generate URL parameters.
Args:
endpoint: The endpoint dictionary.
Returns:
String containing URL parameter handling.
"""
path_params = [p for p in endpoint.get("parameters", []) if p["in"] == "path"]
query_params = [p for p in endpoint.get("parameters", []) if p["in"] == "query"]
parts = []
for param in path_params:
parts.append(f'strings.Replace(url, "{{'+param['name']+'}}", "test_' + param['name'] + '", 1)')
return ""

View File

@@ -0,0 +1,169 @@
"""Jest test generator."""
import re
from pathlib import Path
from typing import Any, Dict, List, Optional
from jinja2 import Environment, FileSystemLoader, TemplateSyntaxError, UndefinedError
from ..core import SpecParser, AuthConfig
from ..core.exceptions import GeneratorError, TemplateRenderError
class JestGenerator:
"""Generate Jest-compatible integration test templates."""
def __init__(
self,
spec_parser: SpecParser,
output_dir: str = "tests",
mock_server_url: str = "http://localhost:4010",
):
"""Initialize the jest generator.
Args:
spec_parser: The OpenAPI specification parser.
output_dir: Directory for generated test files.
mock_server_url: URL of the mock server for testing.
"""
self.spec_parser = spec_parser
self.output_dir = Path(output_dir)
self.mock_server_url = mock_server_url
self.env = Environment(
loader=FileSystemLoader(str(Path(__file__).parent.parent.parent / "templates" / "jest")),
trim_blocks=True,
lstrip_blocks=True,
)
def generate(self, output_file: Optional[str] = None) -> List[Path]:
"""Generate Jest test files.
Args:
output_file: Optional specific output file path.
Returns:
List of generated file paths.
"""
self.output_dir.mkdir(parents=True, exist_ok=True)
endpoints = self.spec_parser.get_endpoints()
info = self.spec_parser.get_info()
context = {
"api_title": info["title"],
"api_version": info["version"],
"endpoints": endpoints,
"mock_server_url": self.mock_server_url,
"security_schemes": self.spec_parser.get_security_schemes(),
"definitions": self.spec_parser.get_definitions(),
}
generated_files = []
try:
template = self.env.get_template("api.test.js.j2")
content = template.render(context)
if output_file:
output_path = Path(output_file)
else:
safe_name = re.sub(r"[^a-zA-Z0-9_]", "_", info["title"].lower())
output_path = self.output_dir / f"{safe_name}.test.js"
output_path.parent.mkdir(parents=True, exist_ok=True)
output_path.write_text(content)
generated_files.append(output_path)
except (TemplateSyntaxError, UndefinedError) as e:
raise TemplateRenderError(f"Failed to render Jest template: {e}")
return generated_files
def generate_endpoint_tests(self, endpoint_path: str, method: str) -> str:
"""Generate test for a specific endpoint.
Args:
endpoint_path: The API endpoint path.
method: The HTTP method.
Returns:
String containing the test code.
"""
endpoints = self.spec_parser.get_endpoints()
for endpoint in endpoints:
if endpoint["path"] == endpoint_path and endpoint["method"] == method.lower():
return self._generate_single_test(endpoint)
return ""
def _generate_single_test(self, endpoint: Dict[str, Any]) -> str:
"""Generate test code for a single endpoint.
Args:
endpoint: The endpoint dictionary.
Returns:
String containing the test code.
"""
test_name = self._generate_test_name(endpoint)
describe_name = endpoint["summary"] or endpoint["path"]
params = self._generate_params(endpoint)
endpoint_path = endpoint["path"]
endpoint_method = endpoint["method"]
test_code = f'''
describe('{describe_name}', () => {{
it('should {endpoint_method.upper()} {endpoint_path}', async () => {{
const response = await request(baseUrl)
.{endpoint_method}('{endpoint_path}'{params});
expect([200, 201, 204]).toContain(response.status);
}});
}});
'''
return test_code
def _generate_test_name(self, endpoint: Dict[str, Any]) -> str:
"""Generate a valid test function name.
Args:
endpoint: The endpoint dictionary.
Returns:
A valid JavaScript function name.
"""
path = endpoint["path"]
method = endpoint["method"]
name = re.sub(r"[^a-zA-Z0-9]", "_", path.strip("/"))
name = re.sub(r"_+/", "_", name)
name = re.sub(r"_+$", "", name)
return f"{method}_{name}" if name else f"{method}_default"
def _generate_params(self, endpoint: Dict[str, Any]) -> str:
"""Generate parameters for test request.
Args:
endpoint: The endpoint dictionary.
Returns:
String containing parameter chain.
"""
parts = []
for param in endpoint.get("parameters", []):
param_name = param["name"]
if param["in"] == "path":
parts.append(f'{param_name}="test_{param_name}"')
elif param["in"] == "query":
parts.append(f'{param_name}')
if parts:
return ", {" + ", ".join(parts) + "}"
return ""

View File

@@ -0,0 +1,199 @@
"""Pytest test generator."""
import re
from pathlib import Path
from typing import Any, Dict, List, Optional
from jinja2 import Environment, FileSystemLoader, TemplateSyntaxError, UndefinedError
from ..core import SpecParser, AuthConfig
from ..core.exceptions import GeneratorError, TemplateRenderError
class PytestGenerator:
"""Generate pytest-compatible integration test templates."""
def __init__(
self,
spec_parser: SpecParser,
output_dir: str = "tests",
mock_server_url: str = "http://localhost:4010",
):
"""Initialize the pytest generator.
Args:
spec_parser: The OpenAPI specification parser.
output_dir: Directory for generated test files.
mock_server_url: URL of the mock server for testing.
"""
self.spec_parser = spec_parser
self.output_dir = Path(output_dir)
self.mock_server_url = mock_server_url
self.env = Environment(
loader=FileSystemLoader(str(Path(__file__).parent.parent.parent / "templates" / "pytest")),
trim_blocks=True,
lstrip_blocks=True,
)
def generate(self, output_file: Optional[str] = None) -> List[Path]:
"""Generate pytest test files.
Args:
output_file: Optional specific output file path.
Returns:
List of generated file paths.
"""
self.output_dir.mkdir(parents=True, exist_ok=True)
endpoints = self.spec_parser.get_endpoints()
info = self.spec_parser.get_info()
context = {
"api_title": info["title"],
"api_version": info["version"],
"endpoints": endpoints,
"mock_server_url": self.mock_server_url,
"security_schemes": self.spec_parser.get_security_schemes(),
"definitions": self.spec_parser.get_definitions(),
}
generated_files = []
try:
template = self.env.get_template("test_base.py.j2")
content = template.render(context)
if output_file:
output_path = Path(output_file)
else:
safe_name = re.sub(r"[^a-zA-Z0-9_]", "_", info["title"].lower())
output_path = self.output_dir / f"test_{safe_name}.py"
output_path.parent.mkdir(parents=True, exist_ok=True)
output_path.write_text(content)
generated_files.append(output_path)
except (TemplateSyntaxError, UndefinedError) as e:
raise TemplateRenderError(f"Failed to render pytest template: {e}")
return generated_files
def generate_endpoint_tests(self, endpoint_path: str, method: str) -> str:
"""Generate test for a specific endpoint.
Args:
endpoint_path: The API endpoint path.
method: The HTTP method.
Returns:
String containing the test code.
"""
endpoints = self.spec_parser.get_endpoints()
for endpoint in endpoints:
if endpoint["path"] == endpoint_path and endpoint["method"] == method.lower():
return self._generate_single_test(endpoint)
return ""
def _generate_single_test(self, endpoint: Dict[str, Any]) -> str:
"""Generate test code for a single endpoint.
Args:
endpoint: The endpoint dictionary.
Returns:
String containing the test code.
"""
test_name = self._generate_test_name(endpoint)
params = self._generate_parameters(endpoint)
auth_headers = self._generate_auth_headers(endpoint)
test_code = f'''
def test_{test_name}(base_url, {params}):
"""Test {endpoint["summary"] or endpoint["path"]} endpoint."""
url = f"{{base_url}}{endpoint['path']}"
headers = {{"Content-Type": "application/json"}}
{auth_headers}
response = requests.{endpoint['method']}(url, json={{}} if method == "POST" else None, headers=headers)
assert response.status_code in [200, 201, 204]
'''
return test_code
def _generate_test_name(self, endpoint: Dict[str, Any]) -> str:
"""Generate a valid test function name.
Args:
endpoint: The endpoint dictionary.
Returns:
A valid Python function name.
"""
path = endpoint["path"]
method = endpoint["method"]
name = re.sub(r"[^a-zA-Z0-9]", "_", path.strip("/"))
name = re.sub(r"_+/", "_", name)
name = re.sub(r"_+$", "", name)
return f"{method}_{name}" if name else f"{method}_default"
def _generate_parameters(self, endpoint: Dict[str, Any]) -> str:
"""Generate parameters for test function.
Args:
endpoint: The endpoint dictionary.
Returns:
String containing parameter declarations.
"""
params = []
for param in endpoint.get("parameters", []):
param_name = param["name"]
if param["in"] == "path":
params.append(f'{param_name}="test_{param_name}"')
elif param["in"] == "query":
params.append(f'{param_name}=None')
return ", ".join(params)
def _generate_auth_headers(self, endpoint: Dict[str, Any]) -> str:
"""Generate authentication headers for endpoint.
Args:
endpoint: The endpoint dictionary.
Returns:
String containing header assignments.
"""
security_requirements = endpoint.get("security", [])
if not security_requirements:
return ""
schemes = self.spec_parser.get_security_schemes()
auth_config = AuthConfig()
try:
auth_config.build_from_spec(schemes, security_requirements)
except Exception:
return ""
headers = []
for scheme_name in security_requirements[0].keys():
method = auth_config.get_auth_method(scheme_name)
if method:
if method["type"] == "apiKey":
headers.append(f'headers["{method["header_name"]}"] = "test_api_key"')
elif method["type"] == "bearer":
headers.append('headers["Authorization"] = "Bearer test_token"')
return "\n ".join(headers) if headers else ""

View File

@@ -0,0 +1,5 @@
"""Mocks module for API TestGen."""
from .generator import MockServerGenerator
__all__ = ["MockServerGenerator"]

View File

@@ -0,0 +1,278 @@
"""Mock server generator for Prism/OpenAPI mock support."""
import json
from pathlib import Path
from typing import Any, Dict, List, Optional
from ..core import SpecParser
class MockServerGenerator:
"""Generate Prism/OpenAPI mock server configurations."""
DEFAULT_PORT = 4010
DEFAULT_HOST = "0.0.0.0"
def __init__(
self,
spec_parser: SpecParser,
output_dir: str = ".",
):
"""Initialize the mock server generator.
Args:
spec_parser: The OpenAPI specification parser.
output_dir: Directory for generated configuration files.
"""
self.spec_parser = spec_parser
self.output_dir = Path(output_dir)
def generate(
self,
prism_config: bool = True,
docker_compose: bool = True,
dockerfile: bool = True,
) -> List[Path]:
"""Generate mock server configuration files.
Args:
prism_config: Whether to generate prism-config.json.
docker_compose: Whether to generate docker-compose.yml.
dockerfile: Whether to generate Dockerfile.
Returns:
List of generated file paths.
"""
self.output_dir.mkdir(parents=True, exist_ok=True)
generated_files = []
if prism_config:
generated_files.append(self.generate_prism_config())
if docker_compose:
generated_files.append(self.generate_docker_compose())
if dockerfile:
generated_files.append(self.generate_dockerfile())
return generated_files
def generate_prism_config(self, output_file: Optional[str] = None) -> Path:
"""Generate Prism mock server configuration.
Args:
output_file: Optional output file path.
Returns:
Path to the generated file.
"""
if output_file:
output_path = Path(output_file)
else:
output_path = self.output_dir / "prism-config.json"
config = {
"mock": {
"host": self.DEFAULT_HOST,
"port": self.DEFAULT_PORT,
"cors": {
"enabled": True,
"allowOrigin": "*",
"allowMethods": ["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS", "HEAD"],
"allowHeaders": ["Content-Type", "Authorization", "X-API-Key"],
},
"Validation": {
"request": True,
"response": True,
},
},
"logging": {
"level": "info",
"format": "json",
},
}
output_path.write_text(json.dumps(config, indent=2))
return output_path
def generate_docker_compose(self, output_file: Optional[str] = None) -> Path:
"""Generate Docker Compose configuration for mock server.
Args:
output_file: Optional output file path.
Returns:
Path to the generated file.
"""
if output_file:
output_path = Path(output_file)
else:
output_path = self.output_dir / "docker-compose.yml"
spec_info = self.spec_parser.get_info()
compose_content = f'''version: '3.8'
services:
mock-server:
image: stoplight/prism:latest
container_name: "{spec_info['title']}-mock-server"
command: >
mock
--spec /app/openapi.yaml
--port {self.DEFAULT_PORT}
--host {self.DEFAULT_HOST}
ports:
- "{self.DEFAULT_PORT}:{self.DEFAULT_PORT}"
volumes:
- ./:/app
restart: unless-stopped
healthcheck:
test: ["CMD", "wget", "-q", "--spider", f"http://localhost:{self.DEFAULT_PORT}/health"]
interval: 30s
timeout: 10s
retries: 3
mock-server-https:
image: stoplight/prism:latest
container_name: "{spec_info['title']}-mock-server-https"
command: >
mock
--spec /app/openapi.yaml
--port {self.DEFAULT_PORT + 1}
--host {self.DEFAULT_HOST}
ports:
- "{self.DEFAULT_PORT + 1}:{self.DEFAULT_PORT + 1}"
volumes:
- ./:/app
restart: unless-stopped
'''
output_path.write_text(compose_content)
return output_path
def generate_dockerfile(self, output_file: Optional[str] = None) -> Path:
"""Generate Dockerfile for mock server.
Args:
output_file: Optional output file path.
Returns:
Path to the generated file.
"""
if output_file:
output_path = Path(output_file)
else:
output_path = self.output_dir / "Dockerfile"
spec_info = self.spec_parser.get_info()
dockerfile_content = f'''FROM stoplight/prism:latest
LABEL maintainer="developer@example.com"
LABEL description="Mock server for {spec_info['title']} API"
WORKDIR /app
COPY openapi.yaml .
EXPOSE {self.DEFAULT_PORT}
CMD ["mock", "--spec", "openapi.yaml", "--port", "{self.DEFAULT_PORT}", "--host", "0.0.0.0"]
'''
output_path.write_text(dockerfile_content)
return output_path
def generate_start_script(self, output_file: Optional[str] = None) -> Path:
"""Generate shell script to start mock server.
Args:
output_file: Optional output file path.
Returns:
Path to the generated file.
"""
if output_file:
output_path = Path(output_file)
else:
output_path = self.output_dir / "start-mock-server.sh"
script_content = f'''#!/bin/bash
set -e
SCRIPT_DIR="$(cd "$(dirname "${{BASH_SOURCE[0]}}")" && pwd)"
cd "$SCRIPT_DIR"
PORT={self.DEFAULT_PORT}
HOST="0.0.0.0"
echo "Starting mock server for API..."
echo "Mock server will be available at: http://localhost:$PORT"
docker compose up -d mock-server
echo "Mock server started successfully!"
echo "To stop: docker compose down"
'''
output_path.write_text(script_content)
output_path.chmod(0o755)
return output_path
def generate_health_check(self) -> Path:
"""Generate health check endpoint configuration.
Returns:
Path to the generated file.
"""
output_path = self.output_dir / "health-check.json"
spec_info = self.spec_parser.get_info()
health_check = {
"openapi": self.spec_parser.version,
"info": {
"title": f"{spec_info['title']} - Health Check",
"version": spec_info["version"],
"description": "Health check endpoint for the mock server",
},
"paths": {
"/health": {
"get": {
"summary": "Health check endpoint",
"description": "Returns the health status of the mock server",
"operationId": "healthCheck",
"responses": {
"200": {
"description": "Mock server is healthy",
"content": {
"application/json": {
"schema": {
"type": "object",
"properties": {
"status": {
"type": "string",
"enum": ["healthy"],
},
"api": {
"type": "string",
},
"version": {
"type": "string",
},
},
},
}
},
}
},
}
}
},
}
output_path.write_text(json.dumps(health_check, indent=2))
return output_path

297
examples/petstore.yaml Normal file
View File

@@ -0,0 +1,297 @@
openapi: 3.0.0
info:
title: Pet Store API
description: A sample API for managing pets in a store
version: 1.0.0
servers:
- url: https://api.petstore.example.com/v1
description: Production server
- url: http://localhost:4010
description: Mock server for testing
tags:
- name: pets
description: Operations on pets
- name: store
description: Store management operations
- name: users
description: User management operations
paths:
/pets:
get:
summary: List all pets
description: Returns a list of pets with optional filtering
operationId: listPets
tags:
- pets
parameters:
- name: status
in: query
description: Filter pets by status
schema:
type: string
enum:
- available
- pending
- sold
- name: limit
in: query
description: Maximum number of pets to return
schema:
type: integer
default: 20
- name: offset
in: query
description: Number of pets to skip
schema:
type: integer
default: 0
responses:
'200':
description: A list of pets
content:
application/json:
schema:
type: array
items:
$ref: '#/components/schemas/Pet'
post:
summary: Create a new pet
description: Creates a new pet in the store
operationId: createPet
tags:
- pets
security:
- BearerAuth: []
- ApiKeyAuth: []
requestBody:
required: true
content:
application/json:
schema:
$ref: '#/components/schemas/PetInput'
responses:
'201':
description: Pet created successfully
content:
application/json:
schema:
$ref: '#/components/schemas/Pet'
'400':
description: Invalid input
/pets/{petId}:
get:
summary: Get a pet by ID
description: Returns a single pet by its ID
operationId: getPetById
tags:
- pets
parameters:
- name: petId
in: path
description: ID of the pet to retrieve
required: true
schema:
type: string
responses:
'200':
description: A single pet
content:
application/json:
schema:
$ref: '#/components/schemas/Pet'
'404':
description: Pet not found
put:
summary: Update a pet
description: Updates an existing pet in the store
operationId: updatePet
tags:
- pets
security:
- BearerAuth: []
parameters:
- name: petId
in: path
description: ID of the pet to update
required: true
schema:
type: string
requestBody:
required: true
content:
application/json:
schema:
$ref: '#/components/schemas/PetInput'
responses:
'200':
description: Pet updated successfully
content:
application/json:
schema:
$ref: '#/components/schemas/Pet'
'404':
description: Pet not found
delete:
summary: Delete a pet
description: Deletes a pet from the store
operationId: deletePet
tags:
- pets
security:
- ApiKeyAuth: []
parameters:
- name: petId
in: path
description: ID of the pet to delete
required: true
schema:
type: string
responses:
'204':
description: Pet deleted successfully
'404':
description: Pet not found
/store/inventory:
get:
summary: Get store inventory
description: Returns pet inventory by status
operationId: getInventory
tags:
- store
security:
- ApiKeyAuth: []
responses:
'200':
description: Inventory counts by status
content:
application/json:
schema:
type: object
additionalProperties:
type: integer
/users:
get:
summary: List all users
description: Returns a list of users
operationId: listUsers
tags:
- users
responses:
'200':
description: A list of users
content:
application/json:
schema:
type: array
items:
$ref: '#/components/schemas/User'
post:
summary: Create a user
description: Creates a new user
operationId: createUser
tags:
- users
requestBody:
required: true
content:
application/json:
schema:
$ref: '#/components/schemas/UserInput'
responses:
'201':
description: User created successfully
components:
securitySchemes:
ApiKeyAuth:
type: apiKey
name: X-API-Key
in: header
BearerAuth:
type: http
scheme: bearer
bearerFormat: JWT
schemas:
Pet:
type: object
required:
- name
- status
properties:
id:
type: string
format: uuid
description: Unique identifier
name:
type: string
description: Name of the pet
status:
type: string
enum:
- available
- pending
- sold
description: Status of the pet
tags:
type: array
items:
type: string
description: Tags associated with the pet
photoUrls:
type: array
items:
type: string
format: uri
description: URLs of pet photos
PetInput:
type: object
required:
- name
- status
properties:
name:
type: string
description: Name of the pet
status:
type: string
enum:
- available
- pending
- sold
description: Status of the pet
tags:
type: array
items:
type: string
User:
type: object
properties:
id:
type: string
format: uuid
username:
type: string
description: Unique username
email:
type: string
format: email
firstName:
type: string
lastName:
type: string
UserInput:
type: object
required:
- username
- email
properties:
username:
type: string
email:
type: string
format: email
firstName:
type: string
lastName:
type: string

129
templates/go/api_test.go.j2 Normal file
View File

@@ -0,0 +1,129 @@
package {{ package_name }}
import (
"encoding/json"
"fmt"
"net/http"
"os"
"strings"
"testing"
"time"
)
const (
baseURL = "{{ mock_server_url }}"
testTimeout = 10 * time.Second
)
{% if security_schemes %}
{% for scheme_name, scheme in security_schemes.items() %}
{% if scheme.type == "apiKey" %}
func get{{ scheme_name|capitalize }}Headers() map[string]string {
return map[string]string{
"{{ scheme.name }}": os.Getenv("API_KEY"),
}
}
{% elif scheme.type == "http" and scheme.scheme == "bearer" %}
func getBearerHeaders() map[string]string {
return map[string]string{
"Authorization": fmt.Sprintf("Bearer %s", os.Getenv("TOKEN")),
}
}
{% elif scheme.type == "http" and scheme.scheme == "basic" %}
func getBasicHeaders(username, password string) map[string]string {
auth := username + ":" + password
encoded := base64.StdEncoding.EncodeToString([]byte(auth))
return map[string]string{
"Authorization": "Basic " + encoded,
}
}
{% endif %}
{% endfor %}
{% endif %}
func contains(slice []int, item int) bool {
for _, s := range slice {
if s == item {
return true
}
}
return false
}
{% for path, path_endpoints in grouped_endpoints.items() %}
{% for endpoint in path_endpoints %}
{% set test_name = (endpoint.method + "_" + path.strip("/").replace("/", "_").replace("{", "").replace("}", "")).title().replace("_", "") %}
func Test{{ test_name }}(t *testing.T) {
client := &http.Client{Timeout: testTimeout}
url := baseURL + "{{ path }}"
{% for param in endpoint.parameters %}
{% if param.in == "path" %}
url = strings.Replace(url, "{+{{ param.name }}}", "{{ param.name }}", 1)
{% endif %}
{% endfor %}
{% if endpoint.method in ["post", "put", "patch"] %}
body := `{}`
req, err := http.NewRequest("{{ endpoint.method.upper() }}", url, strings.NewReader(body))
{% else %}
req, err := http.NewRequest("{{ endpoint.method.upper() }}", url, nil)
{% endif %}
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
req.Header.Set("Content-Type", "application/json")
{% if endpoint.security %}
{% set scheme_name = endpoint.security[0].keys()|first %}
{% set scheme = security_schemes[scheme_name] %}
{% if scheme.type == "apiKey" %}
for k, v := range get{{ scheme_name|capitalize }}Headers() {
req.Header.Set(k, v)
}
{% elif scheme.type == "http" and scheme.scheme == "bearer" %}
for k, v := range getBearerHeaders() {
req.Header.Set(k, v)
}
{% endif %}
{% endif %}
resp, err := client.Do(req)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
if !contains([]int{200, 201, 204}, resp.StatusCode) {
t.Errorf("Expected status in [200, 201, 204], got %d", resp.StatusCode)
}
if resp.StatusCode == 200 || resp.StatusCode == 201 {
var data interface{}
if err := json.NewDecoder(resp.Body).Decode(&data); err != nil {
t.Errorf("Failed to decode response: %v", err)
}
}
}
{% endfor %}
{% endfor %}
func TestAPIHealth(t *testing.T) {
client := &http.Client{Timeout: testTimeout}
resp, err := client.Get(baseURL + "/health")
if err != nil {
t.Fatalf("Health check failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
t.Errorf("Expected status 200, got %d", resp.StatusCode)
}
}

View File

@@ -0,0 +1,101 @@
/**
* Generated Jest tests for {{ api_title }} API v{{ api_version }}
*
* This file was auto-generated by API TestGen.
*/
const request = require('supertest');
const BASE_URL = process.env.MOCK_SERVER_URL || '{{ mock_server_url }}';
{% if security_schemes %}
{% for scheme_name, scheme in security_schemes.items() %}
{% if scheme.type == "apiKey" %}
const get{{ scheme_name|capitalize }}Headers = () => ({
"{{ scheme.name }}": process.env.API_KEY || 'your-api-key',
});
{% elif scheme.type == "http" and scheme.scheme == "bearer" %}
const getBearerHeaders = () => ({
Authorization: `Bearer ${process.env.TOKEN || 'your-token'}`,
});
{% elif scheme.type == "http" and scheme.scheme == "basic" %}
const getBasicHeaders = () => {
const credentials = Buffer.from(`${process.env.USERNAME || 'username'}:${process.env.PASSWORD || 'password'}`).toString('base64');
return { Authorization: `Basic ${credentials}` };
};
{% endif %}
{% endfor %}
{% endif %}
const validateResponse = (response, expectedStatus) => {
expect(response.headers['content-type']).toMatch(/application\/json/);
if (expectedStatus) {
expect(response.status).toBe(expectedStatus);
}
};
describe('{{ api_title }} API', () => {
{% for endpoint in endpoints %}
{% set endpoint_id = endpoint.operation_id or (endpoint.method + "_" + endpoint.path.strip("/").replace("/", "_").replace("{", "").replace("}", "")) %}
describe('{{ endpoint.method.upper() }} {{ endpoint.path }}', () => {
{{ endpoint.description or "" }}
it('should return valid response', async () => {
{% if endpoint.security %}
{% set scheme_name = endpoint.security[0].keys()|first %}
{% set scheme = security_schemes[scheme_name] %}
{% if scheme.type == "apiKey" %}
const headers = get{{ scheme_name|capitalize }}Headers();
{% elif scheme.type == "http" and scheme.scheme == "bearer" %}
const headers = getBearerHeaders();
{% elif scheme.type == "http" and scheme.scheme == "basic" %}
const headers = getBasicHeaders();
{% else %}
const headers = {};
{% endif %}
{% else %}
const headers = { 'Content-Type': 'application/json' };
{% endif %}
{% if endpoint.method in ["post", "put", "patch"] %}
const body = {};
const response = await request(BASE_URL)
.{{ endpoint["method"] }}('{{ endpoint["path"] }}')
.send(body)
.set(headers);
{% else %}
{% if endpoint.parameters|selectattr("in", "equalto", "query")|list %}
{% set query_params = endpoint.parameters|selectattr("in", "equalto", "query")|list %}
const queryParams = {
{% for param in query_params %}
{{ param.name }}: 'test'{% if not loop.last %},
{% endif %}
{% endfor %}
};
const response = await request(BASE_URL)
.{{ endpoint["method"] }}('{{ endpoint["path"] }}')
.query(queryParams)
.set(headers);
{% else %}
const response = await request(BASE_URL)
.{{ endpoint["method"] }}('{{ endpoint["path"] }}')
.set(headers);
{% endif %}
{% endif %}
expect([200, 201, 204]).toContain(response.status);
{% if endpoint.responses["200"] %}
validateResponse(response, {{ endpoint.responses["200"] }});
{% else %}
validateResponse(response);
{% endif %}
});
});
{% endfor %}
});

View File

@@ -0,0 +1,115 @@
"""
Generated pytest tests for {{ api_title }} API v{{ api_version }}
This file was auto-generated by API TestGen.
"""
import pytest
import requests
import json
from jsonschema import validate, ValidationError
BASE_URL = "{{ mock_server_url }}"
{% if security_schemes %}
{% for scheme_name, scheme in security_schemes.items() %}
{% if scheme.type == "apiKey" %}
@pytest.fixture
def {{ scheme_name }}_headers():
"""API Key authentication headers."""
return {"{{ scheme.name }}": "your-api-key"}
{% elif scheme.type == "http" and scheme.scheme == "bearer" %}
@pytest.fixture
def bearer_headers():
"""Bearer token authentication headers."""
return {"Authorization": "Bearer your-token"}
{% elif scheme.type == "http" and scheme.scheme == "basic" %}
@pytest.fixture
def basic_headers():
"""Basic authentication headers."""
import base64
credentials = "username:password"
encoded = base64.b64encode(credentials.encode()).decode()
return {"Authorization": f"Basic {encoded}"}
{% endif %}
{% endfor %}
{% endif %}
{% if definitions %}
@pytest.fixture
def base_url():
"""Base URL for API requests."""
return BASE_URL
{% endif %}
def validate_response(response, status_code=None):
"""Validate API response.
Args:
response: The response object.
status_code: Expected status code (optional).
"""
if status_code:
assert response.status_code == status_code, \
f"Expected status {status_code}, got {response.status_code}"
assert response.headers.get("Content-Type", "").startswith("application/json"), \
"Response Content-Type is not JSON"
{% for endpoint in endpoints %}
{% set endpoint_id = endpoint.operation_id or (endpoint.method + "_" + endpoint.path.strip("/").replace("/", "_").replace("{", "").replace("}", "")) %}
def test_{{ endpoint_id }}(base_url{% if endpoint.parameters|selectattr("in", "equalto", "path")|list %}, {% for param in endpoint.parameters %}{% if param.in == "path" %}{{ param.name }}{% endif %}{% endfor %}{% endif %}):
"""Test {{ endpoint.summary or endpoint.path }} endpoint.
{{ endpoint.description or "" }}
"""
url = f"{base_url}{{ endpoint.path }}"
headers = {"Content-Type": "application/json"}
{% if endpoint.security %}
{% for security_requirement in endpoint.security %}
{% for scheme_name in security_requirement.keys() %}
{% if security_schemes[scheme_name].type == "apiKey" %}
headers["{{ security_schemes[scheme_name].name }}"] = "test-api-key"
{% elif security_schemes[scheme_name].type == "http" and security_schemes[scheme_name].scheme == "bearer" %}
headers["Authorization"] = "Bearer test-token"
{% endif %}
{% endfor %}
{% endfor %}
{% endif %}
{% if endpoint.method in ["post", "put", "patch"] %}
{% if endpoint.request_body %}
request_body = {}
{% else %}
request_body = {}
{% endif %}
response = requests.{{ endpoint["method"] }}(url, json=request_body, headers=headers)
{% else %}
response = requests.{{ endpoint["method"] }}(url, headers=headers{% if endpoint.parameters|selectattr("in", "equalto", "query")|list %}{% set query_params = endpoint.parameters|selectattr("in", "equalto", "query")|list %}{% if query_params %}, params={% for param in query_params %}"{{ param.name }}": "test"{% endfor %}{% endif %}{% endif %})
{% endif %}
validate_response(response{% if endpoint.responses["200"] %}, {{ endpoint.responses["200"] }}{% endif %})
{% if endpoint.responses %}
{% if endpoint.responses["200"] and endpoint.responses["200"].schema %}
try:
data = response.json()
except json.JSONDecodeError:
pytest.fail("Response is not valid JSON")
{% endif %}
{% endif %}
{% endfor %}

1
tests/__init__.py Normal file
View File

@@ -0,0 +1 @@
"""Tests package for API TestGen."""

View File

@@ -0,0 +1 @@
"""Integration tests for API TestGen."""

1
tests/unit/__init__.py Normal file
View File

@@ -0,0 +1 @@
"""Unit tests for API TestGen."""

182
tests/unit/test_auth.py Normal file
View File

@@ -0,0 +1,182 @@
"""Unit tests for the auth configuration module."""
import pytest
import tempfile
from pathlib import Path
from api_testgen.core.auth import AuthConfig, AuthType
from api_testgen.core.exceptions import AuthConfigError, MissingSecuritySchemeError
class TestAuthConfig:
"""Tests for AuthConfig class."""
def test_add_api_key(self):
"""Test adding API key authentication."""
auth = AuthConfig()
auth.add_api_key("test_api_key", header_name="X-API-Key", api_key="test123")
method = auth.get_auth_method("test_api_key")
assert method is not None
assert method["type"] == AuthType.API_KEY
assert method["header_name"] == "X-API-Key"
assert method["api_key"] == "test123"
def test_add_bearer_token(self):
"""Test adding Bearer token authentication."""
auth = AuthConfig()
auth.add_bearer("test_bearer", token="abc123", token_prefix="Bearer")
method = auth.get_auth_method("test_bearer")
assert method is not None
assert method["type"] == AuthType.BEARER
assert method["token"] == "abc123"
assert method["token_prefix"] == "Bearer"
def test_add_basic_auth(self):
"""Test adding Basic authentication."""
auth = AuthConfig()
auth.add_basic("test_basic", username="user", password="pass")
method = auth.get_auth_method("test_basic")
assert method is not None
assert method["type"] == AuthType.BASIC
assert method["username"] == "user"
assert method["password"] == "pass"
def test_method_chaining(self):
"""Test that add methods return self for chaining."""
auth = AuthConfig()
result = auth.add_api_key("key1")
assert result is auth
result = auth.add_bearer("key2")
assert result is auth
result = auth.add_basic("key3")
assert result is auth
def test_get_all_methods(self):
"""Test getting all configured auth methods."""
auth = AuthConfig()
auth.add_api_key("api_key")
auth.add_bearer("bearer")
methods = auth.get_all_methods()
assert len(methods) == 2
assert "api_key" in methods
assert "bearer" in methods
def test_get_headers_api_key(self):
"""Test getting headers for API key auth."""
auth = AuthConfig()
auth.add_api_key("test", header_name="X-Custom-Key", api_key="mykey")
headers = auth.get_headers("test")
assert headers["X-Custom-Key"] == "mykey"
def test_get_headers_bearer(self):
"""Test getting headers for Bearer auth."""
auth = AuthConfig()
auth.add_bearer("test", token="mytoken", token_prefix="Bearer")
headers = auth.get_headers("test")
assert headers["Authorization"] == "Bearer mytoken"
def test_get_headers_basic(self):
"""Test getting headers for Basic auth."""
import base64
auth = AuthConfig()
auth.add_basic("test", username="user", password="pass")
headers = auth.get_headers("test")
expected = base64.b64encode(b"user:pass").decode()
assert headers["Authorization"] == f"Basic {expected}"
def test_get_headers_unconfigured_scheme_raises_error(self):
"""Test that getting headers for unconfigured scheme raises error."""
auth = AuthConfig()
with pytest.raises(AuthConfigError):
auth.get_headers("nonexistent")
def test_build_from_spec(self):
"""Test building auth config from OpenAPI security schemes."""
auth = AuthConfig()
security_schemes = {
"ApiKeyAuth": {"type": "apiKey", "name": "X-API-Key", "in": "header"},
"BearerAuth": {"type": "http", "scheme": "bearer"},
"BasicAuth": {"type": "http", "scheme": "basic"},
}
security_requirements = [
{"ApiKeyAuth": []},
{"BearerAuth": []},
]
auth.build_from_spec(security_schemes, security_requirements)
assert auth.get_auth_method("ApiKeyAuth") is not None
assert auth.get_auth_method("BearerAuth") is not None
assert auth.get_auth_method("BasicAuth") is None
def test_build_from_spec_missing_scheme_raises_error(self):
"""Test that missing security scheme raises error."""
auth = AuthConfig()
security_schemes = {
"ApiKeyAuth": {"type": "apiKey", "name": "X-API-Key", "in": "header"},
}
security_requirements = [
{"MissingScheme": []},
]
with pytest.raises(MissingSecuritySchemeError):
auth.build_from_spec(security_schemes, security_requirements)
def test_generate_pytest_auth_code(self):
"""Test generating pytest authentication code."""
auth = AuthConfig()
auth.add_api_key("test_key", header_name="X-Api-Key", api_key="key123")
code = auth.generate_auth_code("test_key", "pytest")
assert "X-Api-Key" in code
assert "key123" in code
def test_generate_jest_auth_code(self):
"""Test generating Jest authentication code."""
auth = AuthConfig()
auth.add_bearer("test_bearer", token="token123")
code = auth.generate_auth_code("test_bearer", "jest")
assert "Authorization" in code
assert "token123" in code
def test_generate_go_auth_code(self):
"""Test generating Go authentication code."""
auth = AuthConfig()
auth.add_api_key("test_key", header_name="X-Api-Key")
code = auth.generate_auth_code("test_key", "go")
assert "X-Api-Key" in code
def test_generate_auth_code_unconfigured_scheme(self):
"""Test generating auth code for unconfigured scheme returns empty."""
auth = AuthConfig()
code = auth.generate_auth_code("nonexistent", "pytest")
assert code == ""

View File

@@ -0,0 +1,223 @@
"""Unit tests for the generator modules."""
import pytest
import tempfile
import os
from pathlib import Path
from api_testgen.core import SpecParser
from api_testgen.generators import PytestGenerator, JestGenerator, GoGenerator
from api_testgen.mocks import MockServerGenerator
class TestPytestGenerator:
"""Tests for PytestGenerator class."""
@pytest.fixture
def parser(self, sample_openapi_spec):
"""Create a spec parser with sample spec."""
with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f:
import yaml
yaml.dump(sample_openapi_spec, f)
parser = SpecParser(Path(f.name))
parser.load()
return parser
def test_generate_creates_file(self, parser, tmp_path):
"""Test that generate creates a test file."""
generator = PytestGenerator(parser, output_dir=str(tmp_path))
files = generator.generate()
assert len(files) == 1
assert files[0].exists()
assert files[0].suffix == ".py"
def test_generate_content_contains_tests(self, parser, tmp_path):
"""Test that generated file contains test functions."""
generator = PytestGenerator(parser, output_dir=str(tmp_path))
files = generator.generate()
content = files[0].read_text()
assert "test_" in content
assert "BASE_URL" in content
assert "def test_" in content
def test_generate_with_custom_output_file(self, parser, tmp_path):
"""Test generating to a specific file path."""
generator = PytestGenerator(parser, output_dir=str(tmp_path))
files = generator.generate(output_file=str(tmp_path / "custom_test.py"))
assert len(files) == 1
assert files[0].name == "custom_test.py"
def test_generate_endpoint_test(self, parser):
"""Test generating a single endpoint test."""
generator = PytestGenerator(parser)
test_code = generator.generate_endpoint_tests("/pets", "get")
assert "test_" in test_code
assert "requests.get" in test_code
def test_generate_test_name(self, parser):
"""Test test name generation."""
generator = PytestGenerator(parser)
name = generator._generate_test_name({
"path": "/pets",
"method": "get",
"summary": "List all pets",
})
assert name == "get_pets" or "pets" in name
class TestJestGenerator:
"""Tests for JestGenerator class."""
@pytest.fixture
def parser(self, sample_openapi_spec):
"""Create a spec parser with sample spec."""
with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f:
import yaml
yaml.dump(sample_openapi_spec, f)
parser = SpecParser(Path(f.name))
parser.load()
return parser
def test_generate_creates_file(self, parser, tmp_path):
"""Test that generate creates a test file."""
generator = JestGenerator(parser, output_dir=str(tmp_path))
files = generator.generate()
assert len(files) == 1
assert files[0].exists()
assert files[0].suffix == ".js"
def test_generate_content_contains_describe(self, parser, tmp_path):
"""Test that generated file contains describe blocks."""
generator = JestGenerator(parser, output_dir=str(tmp_path))
files = generator.generate()
content = files[0].read_text()
assert "describe" in content
expect_in = "expect" in content or "toContain" in content or "toBe" in content
assert expect_in
def test_generate_endpoint_test(self, parser):
"""Test generating a single endpoint test."""
generator = JestGenerator(parser)
test_code = generator.generate_endpoint_tests("/pets", "get")
assert "describe" in test_code or "it(" in test_code
class TestGoGenerator:
"""Tests for GoGenerator class."""
@pytest.fixture
def parser(self, sample_openapi_spec):
"""Create a spec parser with sample spec."""
with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f:
import yaml
yaml.dump(sample_openapi_spec, f)
parser = SpecParser(Path(f.name))
parser.load()
return parser
def test_generate_creates_file(self, parser, tmp_path):
"""Test that generate creates a test file."""
generator = GoGenerator(parser, output_dir=str(tmp_path))
files = generator.generate()
assert len(files) == 1
assert files[0].exists()
assert files[0].name.endswith("_test.go")
def test_generate_content_contains_tests(self, parser, tmp_path):
"""Test that generated file contains test functions."""
generator = GoGenerator(parser, output_dir=str(tmp_path))
files = generator.generate()
content = files[0].read_text()
assert "func Test" in content
assert "http.Client" in content or "http.NewRequest" in content
def test_generate_with_custom_package(self, parser, tmp_path):
"""Test generating with custom package name."""
generator = GoGenerator(parser, output_dir=str(tmp_path), package_name="custompkg")
files = generator.generate()
content = files[0].read_text()
assert "package custompkg" in content
class TestMockServerGenerator:
"""Tests for MockServerGenerator class."""
@pytest.fixture
def parser(self, sample_openapi_spec):
"""Create a spec parser with sample spec."""
with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f:
import yaml
yaml.dump(sample_openapi_spec, f)
parser = SpecParser(Path(f.name))
parser.load()
return parser
def test_generate_prism_config(self, parser, tmp_path):
"""Test generating Prism configuration."""
generator = MockServerGenerator(parser, output_dir=str(tmp_path))
file = generator.generate_prism_config()
assert file.exists()
assert file.name == "prism-config.json"
content = file.read_text()
import json
config = json.loads(content)
assert "mock" in config
assert "port" in config["mock"]
def test_generate_docker_compose(self, parser, tmp_path):
"""Test generating Docker Compose configuration."""
generator = MockServerGenerator(parser, output_dir=str(tmp_path))
file = generator.generate_docker_compose()
assert file.exists()
assert file.name == "docker-compose.yml"
content = file.read_text()
assert "services" in content
assert "mock-server" in content
def test_generate_dockerfile(self, parser, tmp_path):
"""Test generating Dockerfile."""
generator = MockServerGenerator(parser, output_dir=str(tmp_path))
file = generator.generate_dockerfile()
assert file.exists()
assert file.name == "Dockerfile"
content = file.read_text()
assert "FROM" in content
assert "prism" in content.lower()
def test_generate_all(self, parser, tmp_path):
"""Test generating all mock server files."""
generator = MockServerGenerator(parser, output_dir=str(tmp_path))
files = generator.generate()
assert len(files) == 3
file_names = [f.name for f in files]
assert "prism-config.json" in file_names
assert "docker-compose.yml" in file_names
assert "Dockerfile" in file_names

View File

@@ -0,0 +1,173 @@
"""Unit tests for the spec parser module."""
import pytest
import tempfile
import json
from pathlib import Path
from api_testgen.core import SpecParser
from api_testgen.core.exceptions import InvalidOpenAPISpecError, UnsupportedVersionError
class TestSpecParser:
"""Tests for SpecParser class."""
def test_load_valid_openapi_30_spec(self, temp_spec_file):
"""Test loading a valid OpenAPI 3.0 specification."""
parser = SpecParser(temp_spec_file)
spec = parser.load()
assert spec is not None
assert parser.version == "3.0.0"
assert parser.base_path == ""
assert len(parser.servers) == 1
def test_load_valid_json_spec(self, temp_json_spec_file):
"""Test loading a valid JSON OpenAPI specification."""
parser = SpecParser(temp_json_spec_file)
spec = parser.load()
assert spec is not None
assert parser.version == "3.0.0"
def test_get_info(self, temp_spec_file):
"""Test extracting API info from spec."""
parser = SpecParser(temp_spec_file)
parser.load()
info = parser.get_info()
assert info["title"] == "Test Pet Store API"
assert info["version"] == "1.0.0"
assert "sample" in info["description"].lower()
def test_get_paths(self, temp_spec_file):
"""Test extracting paths from spec."""
parser = SpecParser(temp_spec_file)
parser.load()
paths = parser.get_paths()
assert "/pets" in paths
assert "/pets/{petId}" in paths
def test_get_endpoints(self, temp_spec_file):
"""Test extracting endpoints from spec."""
parser = SpecParser(temp_spec_file)
parser.load()
endpoints = parser.get_endpoints()
assert len(endpoints) == 4
assert any(e["method"] == "get" and e["path"] == "/pets" for e in endpoints)
assert any(e["method"] == "post" and e["path"] == "/pets" for e in endpoints)
assert any(e["method"] == "get" and e["path"] == "/pets/{petId}" for e in endpoints)
assert any(e["method"] == "delete" and e["path"] == "/pets/{petId}" for e in endpoints)
def test_get_security_schemes(self, temp_spec_file):
"""Test extracting security schemes from spec."""
parser = SpecParser(temp_spec_file)
parser.load()
schemes = parser.get_security_schemes()
assert "ApiKeyAuth" in schemes
assert "BearerAuth" in schemes
assert schemes["ApiKeyAuth"]["type"] == "apiKey"
assert schemes["BearerAuth"]["type"] == "http"
assert schemes["BearerAuth"]["scheme"] == "bearer"
def test_get_definitions(self, temp_spec_file):
"""Test extracting schema definitions from spec."""
parser = SpecParser(temp_spec_file)
parser.load()
definitions = parser.get_definitions()
assert "Pet" in definitions
assert definitions["Pet"]["type"] == "object"
def test_endpoint_with_parameters(self, temp_spec_file):
"""Test endpoint parameter extraction."""
parser = SpecParser(temp_spec_file)
parser.load()
endpoints = parser.get_endpoints()
pets_endpoint = next(e for e in endpoints if e["path"] == "/pets" and e["method"] == "get")
assert len(pets_endpoint["parameters"]) == 2
assert any(p["name"] == "limit" and p["in"] == "query" for p in pets_endpoint["parameters"])
assert any(p["name"] == "status" and p["in"] == "query" for p in pets_endpoint["parameters"])
def test_endpoint_with_path_parameters(self, temp_spec_file):
"""Test path parameter extraction."""
parser = SpecParser(temp_spec_file)
parser.load()
endpoints = parser.get_endpoints()
pet_endpoint = next(e for e in endpoints if e["path"] == "/pets/{petId}" and e["method"] == "get")
assert len(pet_endpoint["parameters"]) == 1
param = pet_endpoint["parameters"][0]
assert param["name"] == "petId"
assert param["in"] == "path"
assert param["required"] is True
def test_endpoint_with_request_body(self, temp_spec_file):
"""Test request body extraction for OpenAPI 3.0."""
parser = SpecParser(temp_spec_file)
parser.load()
endpoints = parser.get_endpoints()
create_endpoint = next(e for e in endpoints if e["path"] == "/pets" and e["method"] == "post")
assert create_endpoint["request_body"] is not None
assert create_endpoint["request_body"]["required"] is True
def test_endpoint_with_responses(self, temp_spec_file):
"""Test response extraction."""
parser = SpecParser(temp_spec_file)
parser.load()
endpoints = parser.get_endpoints()
pets_endpoint = next(e for e in endpoints if e["path"] == "/pets" and e["method"] == "get")
assert "200" in pets_endpoint["responses"]
assert pets_endpoint["responses"]["200"]["description"] == "A list of pets"
def test_to_dict(self, temp_spec_file):
"""Test dictionary representation of parsed spec."""
parser = SpecParser(temp_spec_file)
parser.load()
spec_dict = parser.to_dict()
assert "version" in spec_dict
assert "info" in spec_dict
assert "paths" in spec_dict
assert "endpoints" in spec_dict
assert "security_schemes" in spec_dict
assert "definitions" in spec_dict
def test_nonexistent_file_raises_error(self):
"""Test that nonexistent file raises InvalidOpenAPISpecError."""
parser = SpecParser("/nonexistent/path/spec.yaml")
with pytest.raises(InvalidOpenAPISpecError):
parser.load()
def test_invalid_yaml_raises_error(self):
"""Test that invalid YAML raises InvalidOpenAPISpecError."""
with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f:
f.write("invalid: yaml: content: [[[")
f.flush()
parser = SpecParser(Path(f.name))
with pytest.raises(InvalidOpenAPISpecError):
parser.load()
Path(f.name).unlink()