diff --git a/.gitea/workflows/ci.yml b/.gitea/workflows/ci.yml new file mode 100644 index 0000000..53f9499 --- /dev/null +++ b/.gitea/workflows/ci.yml @@ -0,0 +1,62 @@ +name: CI + +on: + push: + branches: [main, master] + pull_request: + branches: [main, master] + +jobs: + test: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.9", "3.10", "3.11"] + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -e . + + - name: Run unit tests + run: pytest tests/unit/ -v + + - name: Run integration tests + run: pytest tests/integration/ -v + + - name: Run tests with coverage + run: python -m pytest tests/ --cov=api_testgen + + - name: Upload coverage report + uses: codecov/codecov-action@v4 + with: + files: ./coverage.xml + fail_ci_if_error: false + + lint: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Install linting tools + run: pip install ruff mypy + + - name: Run ruff + run: ruff check . + + - name: Run mypy + run: mypy api_testgen/ --ignore-missing-imports diff --git a/api_testgen/__init__.py b/api_testgen/__init__.py new file mode 100644 index 0000000..0b050e4 --- /dev/null +++ b/api_testgen/__init__.py @@ -0,0 +1,3 @@ +"""API TestGen - OpenAPI Specification Test Generator.""" + +__version__ = "0.1.0" diff --git a/api_testgen/cli/__init__.py b/api_testgen/cli/__init__.py new file mode 100644 index 0000000..2bef01e --- /dev/null +++ b/api_testgen/cli/__init__.py @@ -0,0 +1 @@ +"""CLI module for API TestGen.""" diff --git a/api_testgen/cli/main.py b/api_testgen/cli/main.py new file mode 100644 index 0000000..d228999 --- /dev/null +++ b/api_testgen/cli/main.py @@ -0,0 +1,277 @@ +"""CLI interface for API TestGen.""" + +from pathlib import Path +from typing import Optional + +import click +import yaml + +from ..core import SpecParser, AuthConfig +from ..core.exceptions import InvalidOpenAPISpecError, UnsupportedVersionError +from ..generators import PytestGenerator, JestGenerator, GoGenerator +from ..mocks import MockServerGenerator + + +@click.group() +@click.version_option(version="0.1.0") +@click.option( + "--spec", + "-s", + type=click.Path(exists=True, file_okay=True, dir_okay=False), + help="Path to OpenAPI specification file", +) +@click.option( + "--output", + "-o", + type=click.Path(file_okay=False, dir_okay=True), + help="Output directory for generated files", +) +@click.option( + "--mock-url", + default="http://localhost:4010", + help="URL of the mock server", +) +@click.pass_context +def main( + ctx: click.Context, + spec: str, + output: str, + mock_url: str, +): + """API TestGen - Generate integration tests from OpenAPI specifications.""" + ctx.ensure_object(dict) + ctx.obj["spec"] = spec + ctx.obj["output"] = output or "./generated" + ctx.obj["mock_url"] = mock_url + + +@main.command("parse") +@click.pass_context +def parse_spec(ctx: click.Context): + """Parse and validate an OpenAPI specification.""" + spec_path = ctx.obj["spec"] + + if not spec_path: + click.echo("Error: --spec option is required", err=True) + raise click.Abort() + + try: + parser = SpecParser(spec_path) + spec = parser.load() + + info = parser.get_info() + endpoints = parser.get_endpoints() + security_schemes = parser.get_security_schemes() + + click.echo(f"API: {info['title']} v{info['version']}") + click.echo(f"OpenAPI Version: {parser.version}") + click.echo(f"Base Path: {parser.base_path}") + click.echo(f"Endpoints: {len(endpoints)}") + click.echo(f"Security Schemes: {len(security_schemes)}") + click.echo() + + for endpoint in endpoints: + click.echo(f" {endpoint['method'].upper():6} {endpoint['path']}") + + ctx.obj["parser"] = parser + + except InvalidOpenAPISpecError as e: + click.echo(f"Error: {e}", err=True) + raise click.Abort() + except UnsupportedVersionError as e: + click.echo(f"Error: {e}", err=True) + raise click.Abort() + + +@main.command("generate") +@click.argument("framework", type=click.Choice(["pytest", "jest", "go"])) +@click.option( + "--output-file", + "-f", + type=click.Path(file_okay=True, dir_okay=False), + help="Specific output file path", +) +@click.option( + "--package-name", + default="apitest", + help="Go package name (only for go framework)", +) +@click.pass_context +def generate_tests( + ctx: click.Context, + framework: str, + output_file: str, + package_name: str, +): + """Generate test files for a framework (pytest, jest, or go).""" + spec_path = ctx.obj["spec"] + + if not spec_path: + click.echo("Error: --spec option is required", err=True) + raise click.Abort() + + output_dir = ctx.obj["output"] + mock_url = ctx.obj["mock_url"] + + try: + parser = SpecParser(spec_path) + parser.load() + + if framework == "pytest": + generator = PytestGenerator(parser, output_dir, mock_url) + files = generator.generate(output_file) + + elif framework == "jest": + generator = JestGenerator(parser, output_dir, mock_url) + files = generator.generate(output_file) + + elif framework == "go": + generator = GoGenerator(parser, output_dir, mock_url, package_name) + files = generator.generate(output_file) + + click.echo(f"Generated {len(files)} test file(s):") + for f in files: + click.echo(f" - {f}") + + except Exception as e: + click.echo(f"Error: {e}", err=True) + raise click.Abort() + + +@main.command("mock") +@click.option( + "--no-prism-config", + is_flag=True, + help="Skip generating prism-config.json", +) +@click.option( + "--no-docker-compose", + is_flag=True, + help="Skip generating docker-compose.yml", +) +@click.option( + "--no-dockerfile", + is_flag=True, + help="Skip generating Dockerfile", +) +@click.pass_context +def generate_mock( + ctx: click.Context, + no_prism_config: bool, + no_docker_compose: bool, + no_dockerfile: bool, +): + """Generate mock server configuration files.""" + spec_path = ctx.obj["spec"] + + if not spec_path: + click.echo("Error: --spec option is required", err=True) + raise click.Abort() + + output_dir = ctx.obj["output"] + + try: + parser = SpecParser(spec_path) + parser.load() + + generator = MockServerGenerator(parser, output_dir) + + files = generator.generate( + prism_config=not no_prism_config, + docker_compose=not no_docker_compose, + dockerfile=not no_dockerfile, + ) + + click.echo(f"Generated {len(files)} mock server file(s):") + for f in files: + click.echo(f" - {f}") + + except Exception as e: + click.echo(f"Error: {e}", err=True) + raise click.Abort() + + +@main.command("all") +@click.argument("framework", type=click.Choice(["pytest", "jest", "go"])) +@click.pass_context +def generate_all( + ctx: click.Context, + framework: str, +): + """Generate test files and mock server configuration.""" + spec_path = ctx.obj["spec"] + + if not spec_path: + click.echo("Error: --spec option is required", err=True) + raise click.Abort() + + output_dir = ctx.obj["output"] + mock_url = ctx.obj["mock_url"] + + try: + parser = SpecParser(spec_path) + parser.load() + + click.echo("Generating tests...") + if framework == "pytest": + generator = PytestGenerator(parser, output_dir, mock_url) + files = generator.generate() + + elif framework == "jest": + generator = JestGenerator(parser, output_dir, mock_url) + files = generator.generate() + + elif framework == "go": + generator = GoGenerator(parser, output_dir, mock_url) + files = generator.generate() + + click.echo(f"Generated {len(files)} test file(s)") + + click.echo("Generating mock server configuration...") + mock_generator = MockServerGenerator(parser, output_dir) + mock_files = mock_generator.generate() + + click.echo(f"Generated {len(mock_files)} mock server file(s)") + + click.echo("\nAll files generated successfully!") + + except Exception as e: + click.echo(f"Error: {e}", err=True) + raise click.Abort() + + +@main.command("auth") +@click.argument("scheme_name") +@click.option("--type", "auth_type", type=click.Choice(["apiKey", "bearer", "basic"]), help="Authentication type") +@click.option("--header", help="Header name for API key", default="X-API-Key") +@click.option("--token", help="Bearer token or API key value") +@click.option("--username", help="Username for Basic auth") +@click.option("--password", help="Password for Basic auth") +@click.pass_context +def configure_auth( + ctx: click.Context, + scheme_name: str, + auth_type: str, + header: str, + token: str, + username: str, + password: str, +): + """Configure authentication for a security scheme.""" + auth_config = AuthConfig() + + if auth_type == "apiKey": + auth_config.add_api_key(scheme_name, header, token or "") + elif auth_type == "bearer": + auth_config.add_bearer(scheme_name, token or "") + elif auth_type == "basic": + auth_config.add_basic(scheme_name, username or "", password or "") + + click.echo(f"Authentication scheme '{scheme_name}' configured:") + methods = auth_config.get_all_methods() + for name, method in methods.items(): + click.echo(f" - {name}: {method['type'].value}") + + +if __name__ == "__main__": + main() diff --git a/api_testgen/core/__init__.py b/api_testgen/core/__init__.py new file mode 100644 index 0000000..7718543 --- /dev/null +++ b/api_testgen/core/__init__.py @@ -0,0 +1,6 @@ +"""Core module for API TestGen.""" + +from .spec_parser import SpecParser +from .auth import AuthConfig + +__all__ = ["SpecParser", "AuthConfig"] diff --git a/api_testgen/core/auth.py b/api_testgen/core/auth.py new file mode 100644 index 0000000..71411cb --- /dev/null +++ b/api_testgen/core/auth.py @@ -0,0 +1,313 @@ +"""Authentication configuration for API TestGen.""" + +from typing import Any, Dict, List, Optional, Union +from enum import Enum + +from .exceptions import AuthConfigError, MissingSecuritySchemeError + + +class AuthType(str, Enum): + """Types of authentication.""" + API_KEY = "apiKey" + BEARER = "bearer" + BASIC = "basic" + NONE = "none" + + +class AuthConfig: + """Authentication configuration for generated tests.""" + + def __init__(self): + """Initialize authentication configuration.""" + self._auth_methods: Dict[str, Dict[str, Any]] = {} + + def add_api_key( + self, + scheme_name: str, + header_name: str = "X-API-Key", + api_key: str = "", + ) -> "AuthConfig": + """Add API key authentication. + + Args: + scheme_name: Name of the security scheme in OpenAPI spec. + header_name: Name of the header containing the API key. + api_key: The API key value (can be set later). + + Returns: + Self for method chaining. + """ + self._auth_methods[scheme_name] = { + "type": AuthType.API_KEY, + "header_name": header_name, + "api_key": api_key, + } + return self + + def add_bearer( + self, + scheme_name: str, + token: str = "", + token_prefix: str = "Bearer", + ) -> "AuthConfig": + """Add Bearer token authentication. + + Args: + scheme_name: Name of the security scheme in OpenAPI spec. + token: The Bearer token value (can be set later). + token_prefix: The token prefix (default: Bearer). + + Returns: + Self for method chaining. + """ + self._auth_methods[scheme_name] = { + "type": AuthType.BEARER, + "token": token, + "token_prefix": token_prefix, + } + return self + + def add_basic( + self, + scheme_name: str, + username: str = "", + password: str = "", + ) -> "AuthConfig": + """Add Basic authentication. + + Args: + scheme_name: Name of the security scheme in OpenAPI spec. + username: The username (can be set later). + password: The password (can be set later). + + Returns: + Self for method chaining. + """ + self._auth_methods[scheme_name] = { + "type": AuthType.BASIC, + "username": username, + "password": password, + } + return self + + def get_auth_method(self, scheme_name: str) -> Optional[Dict[str, Any]]: + """Get authentication method by scheme name. + + Args: + scheme_name: Name of the security scheme. + + Returns: + Authentication method configuration or None. + """ + return self._auth_methods.get(scheme_name) + + def get_all_methods(self) -> Dict[str, Dict[str, Any]]: + """Get all configured authentication methods. + + Returns: + Dictionary of scheme names and their configurations. + """ + return self._auth_methods.copy() + + def get_headers(self, scheme_name: str) -> Dict[str, str]: + """Get authentication headers for a scheme. + + Args: + scheme_name: Name of the security scheme. + + Returns: + Dictionary of header names and values. + + Raises: + AuthConfigError: If scheme is not configured. + """ + method = self.get_auth_method(scheme_name) + if not method: + raise AuthConfigError(f"Authentication scheme '{scheme_name}' not configured") + + if method["type"] == AuthType.API_KEY: + return {method["header_name"]: method["api_key"]} + elif method["type"] == AuthType.BEARER: + return {"Authorization": f"{method['token_prefix']} {method['token']}"} + elif method["type"] == AuthType.BASIC: + import base64 + credentials = f"{method['username']}:{method['password']}" + encoded = base64.b64encode(credentials.encode()).decode() + return {"Authorization": f"Basic {encoded}"} + return {} + + def build_from_spec( + self, + security_schemes: Dict[str, Any], + security_requirements: List[Dict[str, Any]], + ) -> "AuthConfig": + """Build auth configuration from OpenAPI security schemes. + + Args: + security_schemes: Security schemes from OpenAPI spec. + security_requirements: Security requirements from endpoint. + + Returns: + Self for method chaining. + + Raises: + MissingSecuritySchemeError: If required scheme is not defined. + """ + for requirement in security_requirements: + for scheme_name in requirement.keys(): + if scheme_name not in self._auth_methods: + if scheme_name not in security_schemes: + raise MissingSecuritySchemeError( + f"Security scheme '{scheme_name}' not found in spec" + ) + scheme = security_schemes[scheme_name] + self._add_scheme_from_spec(scheme_name, scheme) + + return self + + def _add_scheme_from_spec(self, scheme_name: str, scheme: Dict[str, Any]) -> None: + """Add authentication scheme from OpenAPI spec definition. + + Args: + scheme_name: Name of the security scheme. + scheme: The security scheme definition from OpenAPI spec. + """ + scheme_type = scheme.get("type", "") + + if scheme_type == "apiKey": + self.add_api_key( + scheme_name, + header_name=scheme.get("name", "X-API-Key"), + ) + elif scheme_type == "http": + scheme_scheme = scheme.get("scheme", "").lower() + if scheme_scheme == "bearer": + self.add_bearer(scheme_name) + elif scheme_scheme == "basic": + self.add_basic(scheme_name) + elif scheme_type == "openIdConnect": + self.add_bearer(scheme_name) + elif scheme_type == "oauth2": + self.add_bearer(scheme_name) + + def generate_auth_code(self, scheme_name: str, framework: str = "pytest") -> str: + """Generate authentication code for a test framework. + + Args: + scheme_name: Name of the security scheme. + framework: Target test framework (pytest, jest, go). + + Returns: + String containing authentication code snippet. + """ + method = self.get_auth_method(scheme_name) + if not method: + return "" + + if framework == "pytest": + return self._generate_pytest_auth(method) + elif framework == "jest": + return self._generate_jest_auth(method) + elif framework == "go": + return self._generate_go_auth(method) + return "" + + def _generate_pytest_auth(self, method: Dict[str, Any]) -> str: + """Generate pytest authentication code. + + Args: + method: Authentication method configuration. + + Returns: + String containing pytest auth code. + """ + if method["type"] == AuthType.API_KEY: + return f''' +@pytest.fixture +def api_key_headers(): + return {{"{method['header_name']}": "{method['api_key']}"}} +''' + elif method["type"] == AuthType.BEARER: + return f''' +@pytest.fixture +def bearer_headers(): + return {{"Authorization": "{method['token_prefix']} {method['token']}"}} +''' + elif method["type"] == AuthType.BASIC: + return f''' +import base64 + +@pytest.fixture +def basic_headers(): + credentials = f"{{"{method['username']}"}}:{{"{method['password']}"}}" + encoded = base64.b64encode(credentials.encode()).decode() + return {{"Authorization": f"Basic {{encoded}}"}} +''' + return "" + + def _generate_jest_auth(self, method: Dict[str, Any]) -> str: + """Generate Jest authentication code. + + Args: + method: Authentication method configuration. + + Returns: + String containing Jest auth code. + """ + if method["type"] == AuthType.API_KEY: + return f''' +const getApiKeyHeaders = () => ({{ + "{method['header_name']}": process.env.API_KEY || "{method['api_key']}", +}}); +''' + elif method["type"] == AuthType.BEARER: + return f''' +const getBearerHeaders = () => ({{ + Authorization: `${{process.env.TOKEN_PREFIX || "{method['token_prefix']}"}} ${{process.env.TOKEN || "{method['token']}"}}`, +}}); +''' + elif method["type"] == AuthType.BASIC: + return f''' +const getBasicHeaders = () => {{ + const credentials = Buffer.from(`${{process.env.USERNAME || "{method['username']}"}}:${{process.env.PASSWORD || "{method['password']}"}}`).toString('base64'); + return {{ Authorization: `Basic ${{credentials}}` }}; +}}; +''' + return "" + + def _generate_go_auth(self, method: Dict[str, Any]) -> str: + """Generate Go authentication code. + + Args: + method: Authentication method configuration. + + Returns: + String containing Go auth code. + """ + if method["type"] == AuthType.API_KEY: + return f''' +func getAPIKeyHeaders() map[string]string {{ + return map[string]string{{ + "{method['header_name']}": os.Getenv("API_KEY"), + }} +}} +''' + elif method["type"] == AuthType.BEARER: + return f''' +func getBearerHeaders() map[string]string {{ + return map[string]string{{ + "Authorization": fmt.Sprintf("%s %s", os.Getenv("TOKEN_PREFIX"), os.Getenv("TOKEN")), + }} +}} +''' + elif method["type"] == AuthType.BASIC: + return f''' +func getBasicHeaders(username, password string) map[string]string {{ + auth := username + ":" + password + encoded := base64.StdEncoding.EncodeToString([]byte(auth)) + return map[string]string{{ + "Authorization": "Basic " + encoded, + }} +}} +''' + return "" diff --git a/api_testgen/core/exceptions.py b/api_testgen/core/exceptions.py new file mode 100644 index 0000000..5f55cc4 --- /dev/null +++ b/api_testgen/core/exceptions.py @@ -0,0 +1,36 @@ +"""Custom exceptions for API TestGen.""" + + +class SpecParserError(Exception): + """Base exception for spec parser errors.""" + pass + + +class InvalidOpenAPISpecError(SpecParserError): + """Raised when OpenAPI specification is invalid.""" + pass + + +class UnsupportedVersionError(SpecParserError): + """Raised when OpenAPI version is not supported.""" + pass + + +class AuthConfigError(Exception): + """Base exception for auth configuration errors.""" + pass + + +class MissingSecuritySchemeError(AuthConfigError): + """Raised when security scheme is not defined in spec.""" + pass + + +class GeneratorError(Exception): + """Base exception for generator errors.""" + pass + + +class TemplateRenderError(GeneratorError): + """Raised when template rendering fails.""" + pass diff --git a/api_testgen/core/spec_parser.py b/api_testgen/core/spec_parser.py new file mode 100644 index 0000000..4da87db --- /dev/null +++ b/api_testgen/core/spec_parser.py @@ -0,0 +1,307 @@ +"""OpenAPI Specification Parser.""" + +import yaml +import json +from pathlib import Path +from typing import Any, Dict, List, Optional, Union + +from openapi_spec_validator import validate +from openapi_spec_validator.versions import consts as validator_consts + +from .exceptions import InvalidOpenAPISpecError, UnsupportedVersionError + + +class SpecParser: + """Parse and validate OpenAPI/Swagger specifications.""" + + SUPPORTED_VERSIONS = ["2.0", "3.0.0", "3.0.1", "3.0.2", "3.0.3", "3.1.0"] + + def __init__(self, spec_path: Union[str, Path]): + """Initialize the spec parser. + + Args: + spec_path: Path to OpenAPI specification file (YAML or JSON). + """ + self.spec_path = Path(spec_path) + self.spec: Dict[str, Any] = {} + self.version: str = "" + self.base_path: str = "" + self.servers: List[Dict[str, str]] = [] + + def load(self) -> Dict[str, Any]: + """Load and validate the OpenAPI specification. + + Returns: + The parsed specification dictionary. + + Raises: + InvalidOpenAPISpecError: If the specification is invalid. + UnsupportedVersionError: If the OpenAPI version is not supported. + """ + self.spec = self._load_file() + self._validate() + self._extract_metadata() + return self.spec + + def _load_file(self) -> Dict[str, Any]: + """Load the specification file. + + Returns: + Parsed specification dictionary. + + Raises: + InvalidOpenAPISpecError: If the file cannot be loaded or parsed. + """ + if not self.spec_path.exists(): + raise InvalidOpenAPISpecError(f"Specification file not found: {self.spec_path}") + + try: + with open(self.spec_path, 'r', encoding='utf-8') as f: + if self.spec_path.suffix in ['.yaml', '.yml']: + return yaml.safe_load(f) or {} + elif self.spec_path.suffix == '.json': + return json.load(f) + else: + return yaml.safe_load(f) or {} + except (yaml.YAMLError, json.JSONDecodeError) as e: + raise InvalidOpenAPISpecError(f"Failed to parse specification: {e}") + + def _validate(self) -> None: + """Validate the specification. + + Raises: + InvalidOpenAPISpecError: If the specification is invalid. + UnsupportedVersionError: If the OpenAPI version is not supported. + """ + try: + validate(self.spec) + except Exception as e: + raise InvalidOpenAPISpecError(f"Invalid OpenAPI specification: {e}") + + version = self._get_version() + if version not in self.SUPPORTED_VERSIONS: + raise UnsupportedVersionError( + f"Unsupported OpenAPI version: {version}. " + f"Supported versions: {', '.join(self.SUPPORTED_VERSIONS)}" + ) + + def _get_version(self) -> str: + """Extract the OpenAPI version from the spec. + + Returns: + The OpenAPI version string. + """ + if "openapi" in self.spec: + return self.spec["openapi"] + elif "swagger" in self.spec: + return self.spec["swagger"] + return "2.0" + + def _extract_metadata(self) -> None: + """Extract metadata from the specification.""" + self.version = self._get_version() + self.base_path = self.spec.get("basePath", "") + self.servers = self.spec.get("servers", []) + + def get_paths(self) -> Dict[str, Any]: + """Get all paths from the specification. + + Returns: + Dictionary of paths and their operations. + """ + return self.spec.get("paths", {}) + + def get_endpoints(self) -> List[Dict[str, Any]]: + """Extract all endpoints from the specification. + + Returns: + List of endpoint dictionaries with path, method, and details. + """ + endpoints = [] + paths = self.get_paths() + + for path, path_item in paths.items(): + for method, operation in path_item.items(): + if method.lower() in ["get", "post", "put", "patch", "delete", "options", "head"]: + endpoint = { + "path": path, + "method": method.lower(), + "operation_id": operation.get("operationId", ""), + "summary": operation.get("summary", ""), + "description": operation.get("description", ""), + "tags": operation.get("tags", []), + "parameters": self._extract_parameters(path_item, operation), + "request_body": self._extract_request_body(operation), + "responses": self._extract_responses(operation), + "security": self._extract_security(operation), + } + endpoints.append(endpoint) + + return endpoints + + def _extract_parameters(self, path_item: Dict, operation: Dict) -> List[Dict[str, Any]]: + """Extract parameters from path item and operation. + + Args: + path_item: The path item dictionary. + operation: The operation dictionary. + + Returns: + List of parameter dictionaries. + """ + parameters = [] + + for param in path_item.get("parameters", []): + if param.get("in") != "body": + parameters.append(self._normalize_parameter(param)) + + for param in operation.get("parameters", []): + if param.get("in") != "body": + parameters.append(self._normalize_parameter(param)) + + return parameters + + def _normalize_parameter(self, param: Dict[str, Any]) -> Dict[str, Any]: + """Normalize a parameter. + + Args: + param: The parameter dictionary. + + Returns: + Normalized parameter dictionary. + """ + return { + "name": param.get("name", ""), + "in": param.get("in", ""), + "description": param.get("description", ""), + "required": param.get("required", False), + "schema": param.get("schema", {}), + "type": param.get("type", ""), + "enum": param.get("enum", []), + "default": param.get("default"), + } + + def _extract_request_body(self, operation: Dict) -> Optional[Dict[str, Any]]: + """Extract request body from operation. + + Args: + operation: The operation dictionary. + + Returns: + Request body dictionary or None. + """ + if self.version.startswith("3."): + request_body = operation.get("requestBody", {}) + if not request_body: + return None + + content = request_body.get("content", {}) + media_types = list(content.keys()) + + return { + "description": request_body.get("description", ""), + "required": request_body.get("required", False), + "media_types": media_types, + "schema": content.get(media_types[0], {}).get("schema", {}) if media_types else {}, + } + else: + params = operation.get("parameters", []) + for param in params: + if param.get("in") == "body": + return { + "description": param.get("description", ""), + "required": param.get("required", False), + "schema": param.get("schema", {}), + } + return None + + def _extract_responses(self, operation: Dict) -> Dict[str, Any]: + """Extract responses from operation. + + Args: + operation: The operation dictionary. + + Returns: + Dictionary of response status codes and their details. + """ + responses = {} + + for status_code, response in operation.get("responses", {}).items(): + content = response.get("content", {}) + + if self.version.startswith("3."): + media_types = list(content.keys()) + schema = content.get(media_types[0], {}).get("schema", {}) if media_types else {} + else: + schema = response.get("schema", {}) + + responses[status_code] = { + "description": response.get("description", ""), + "schema": schema, + "media_type": list(content.keys())[0] if content else "application/json", + } + + return responses + + def _extract_security(self, operation: Dict) -> List[Dict[str, Any]]: + """Extract security requirements from operation. + + Args: + operation: The operation dictionary. + + Returns: + List of security requirement dictionaries. + """ + return operation.get("security", self.spec.get("security", [])) + + def get_security_schemes(self) -> Dict[str, Any]: + """Get security schemes from the specification. + + Returns: + Dictionary of security scheme names and their definitions. + """ + if self.version.startswith("3."): + return self.spec.get("components", {}).get("securitySchemes", {}) + else: + return self.spec.get("securityDefinitions", {}) + + def get_definitions(self) -> Dict[str, Any]: + """Get schema definitions from the specification. + + Returns: + Dictionary of schema definitions. + """ + if self.version.startswith("3."): + return self.spec.get("components", {}).get("schemas", {}) + else: + return self.spec.get("definitions", {}) + + def get_info(self) -> Dict[str, str]: + """Get API info from the specification. + + Returns: + Dictionary with title, version, and description. + """ + info = self.spec.get("info", {}) + return { + "title": info.get("title", "API"), + "version": info.get("version", "1.0.0"), + "description": info.get("description", ""), + } + + def to_dict(self) -> Dict[str, Any]: + """Convert the spec to a dictionary. + + Returns: + Dictionary representation of the parsed spec. + """ + return { + "version": self.version, + "base_path": self.base_path, + "servers": self.servers, + "info": self.get_info(), + "paths": self.get_paths(), + "endpoints": self.get_endpoints(), + "security_schemes": self.get_security_schemes(), + "definitions": self.get_definitions(), + } diff --git a/api_testgen/generators/__init__.py b/api_testgen/generators/__init__.py new file mode 100644 index 0000000..ec17d19 --- /dev/null +++ b/api_testgen/generators/__init__.py @@ -0,0 +1,7 @@ +"""Generators module for API TestGen.""" + +from .pytest import PytestGenerator +from .jest import JestGenerator +from .go import GoGenerator + +__all__ = ["PytestGenerator", "JestGenerator", "GoGenerator"] diff --git a/api_testgen/generators/go.py b/api_testgen/generators/go.py new file mode 100644 index 0000000..dba899b --- /dev/null +++ b/api_testgen/generators/go.py @@ -0,0 +1,229 @@ +"""Go test generator.""" + +import re +from pathlib import Path +from typing import Any, Dict, List, Optional + +from jinja2 import Environment, FileSystemLoader, TemplateSyntaxError, UndefinedError + +from ..core import SpecParser, AuthConfig +from ..core.exceptions import GeneratorError, TemplateRenderError + + +class GoGenerator: + """Generate Go-compatible test files.""" + + def __init__( + self, + spec_parser: SpecParser, + output_dir: str = "tests", + mock_server_url: str = "http://localhost:4010", + package_name: str = "apitest", + ): + """Initialize the go generator. + + Args: + spec_parser: The OpenAPI specification parser. + output_dir: Directory for generated test files. + mock_server_url: URL of the mock server for testing. + package_name: Go package name for test files. + """ + self.spec_parser = spec_parser + self.output_dir = Path(output_dir) + self.mock_server_url = mock_server_url + self.package_name = package_name + self.env = Environment( + loader=FileSystemLoader(str(Path(__file__).parent.parent.parent / "templates" / "go")), + trim_blocks=True, + lstrip_blocks=True, + ) + + def generate(self, output_file: Optional[str] = None) -> List[Path]: + """Generate Go test files. + + Args: + output_file: Optional specific output file path. + + Returns: + List of generated file paths. + """ + self.output_dir.mkdir(parents=True, exist_ok=True) + + endpoints = self.spec_parser.get_endpoints() + info = self.spec_parser.get_info() + + grouped_endpoints = self._group_endpoints_by_path(endpoints) + + context = { + "package_name": self.package_name, + "api_title": info["title"], + "api_version": info["version"], + "endpoints": endpoints, + "grouped_endpoints": grouped_endpoints, + "mock_server_url": self.mock_server_url, + "security_schemes": self.spec_parser.get_security_schemes(), + "definitions": self.spec_parser.get_definitions(), + } + + generated_files = [] + + try: + template = self.env.get_template("api_test.go.j2") + content = template.render(context) + + if output_file: + output_path = Path(output_file) + else: + safe_name = re.sub(r"[^a-zA-Z0-9_]", "_", info["title"].lower()) + output_path = self.output_dir / f"{safe_name}_test.go" + + output_path.parent.mkdir(parents=True, exist_ok=True) + output_path.write_text(content) + generated_files.append(output_path) + + except (TemplateSyntaxError, UndefinedError) as e: + raise TemplateRenderError(f"Failed to render Go template: {e}") + + return generated_files + + def _group_endpoints_by_path(self, endpoints: List[Dict[str, Any]]) -> Dict[str, List[Dict[str, Any]]]: + """Group endpoints by their path. + + Args: + endpoints: List of endpoint dictionaries. + + Returns: + Dictionary mapping paths to their endpoints. + """ + grouped = {} + for endpoint in endpoints: + path = endpoint["path"] + if path not in grouped: + grouped[path] = [] + grouped[path].append(endpoint) + return grouped + + def generate_endpoint_tests(self, endpoint_path: str, method: str) -> str: + """Generate test for a specific endpoint. + + Args: + endpoint_path: The API endpoint path. + method: The HTTP method. + + Returns: + String containing the test code. + """ + endpoints = self.spec_parser.get_endpoints() + + for endpoint in endpoints: + if endpoint["path"] == endpoint_path and endpoint["method"] == method.lower(): + return self._generate_single_test(endpoint) + + return "" + + def _generate_single_test(self, endpoint: Dict[str, Any]) -> str: + """Generate test code for a single endpoint. + + Args: + endpoint: The endpoint dictionary. + + Returns: + String containing the test code. + """ + test_name = self._generate_test_name(endpoint) + params = self._generate_params(endpoint) + url_params = self._generate_url_params(endpoint) + + test_code = f''' +func Test{test_name}(t *testing.T) {{ + client := &http.Client{{Timeout: 10 * time.Second}} + url := baseURL + "{endpoint['path']}" + + var req *http.Request + var err error + + {params} + + req, err = http.NewRequest("{endpoint['method'].upper()}", url, nil) + if err != nil {{ + t.Fatal(err) + }} + + for k, v := range getAuthHeaders() {{ + req.Header.Set(k, v) + }} + + resp, err := client.Do(req) + if err != nil {{ + t.Fatalf("Request failed: %v", err) + }} + defer resp.Body.Close() + + if !contains([]int{{200, 201, 204}}, resp.StatusCode) {{ + t.Errorf("Expected status code in [200, 201, 204], got %d", resp.StatusCode) + }} +}} +''' + return test_code + + def _generate_test_name(self, endpoint: Dict[str, Any]) -> str: + """Generate a valid Go test function name. + + Args: + endpoint: The endpoint dictionary. + + Returns: + A valid Go function name. + """ + path = endpoint["path"] + method = endpoint["method"] + + name = re.sub(r"[^a-zA-Z0-9]", "_", path.strip("/")) + name = re.sub(r"_+/", "_", name) + name = re.sub(r"_+$", "", name) + name = name.title().replace("_", "") + + return f"{method.capitalize()}{name}" if name else f"{method.capitalize()}Default" + + def _generate_params(self, endpoint: Dict[str, Any]) -> str: + """Generate parameter variables for test. + + Args: + endpoint: The endpoint dictionary. + + Returns: + String containing parameter declarations. + """ + params = [] + + for param in endpoint.get("parameters", []): + param_name = param["name"] + + if param["in"] == "path": + params.append(f'{param_name} := "test_{param_name}"') + params.append(f'url = strings.Replace(url, "{{'+param_name+'}}", {param_name}, 1)') + + elif param["in"] == "query": + params.append(f'q := url.Values{{{param_name}: []string{{"test"}}}}') + params.append(f'url += "?" + q.Encode()') + + return "\n ".join(params) if params else "" + + def _generate_url_params(self, endpoint: Dict[str, Any]) -> str: + """Generate URL parameters. + + Args: + endpoint: The endpoint dictionary. + + Returns: + String containing URL parameter handling. + """ + path_params = [p for p in endpoint.get("parameters", []) if p["in"] == "path"] + query_params = [p for p in endpoint.get("parameters", []) if p["in"] == "query"] + + parts = [] + + for param in path_params: + parts.append(f'strings.Replace(url, "{{'+param['name']+'}}", "test_' + param['name'] + '", 1)') + + return "" diff --git a/api_testgen/generators/jest.py b/api_testgen/generators/jest.py new file mode 100644 index 0000000..5798577 --- /dev/null +++ b/api_testgen/generators/jest.py @@ -0,0 +1,169 @@ +"""Jest test generator.""" + +import re +from pathlib import Path +from typing import Any, Dict, List, Optional + +from jinja2 import Environment, FileSystemLoader, TemplateSyntaxError, UndefinedError + +from ..core import SpecParser, AuthConfig +from ..core.exceptions import GeneratorError, TemplateRenderError + + +class JestGenerator: + """Generate Jest-compatible integration test templates.""" + + def __init__( + self, + spec_parser: SpecParser, + output_dir: str = "tests", + mock_server_url: str = "http://localhost:4010", + ): + """Initialize the jest generator. + + Args: + spec_parser: The OpenAPI specification parser. + output_dir: Directory for generated test files. + mock_server_url: URL of the mock server for testing. + """ + self.spec_parser = spec_parser + self.output_dir = Path(output_dir) + self.mock_server_url = mock_server_url + self.env = Environment( + loader=FileSystemLoader(str(Path(__file__).parent.parent.parent / "templates" / "jest")), + trim_blocks=True, + lstrip_blocks=True, + ) + + def generate(self, output_file: Optional[str] = None) -> List[Path]: + """Generate Jest test files. + + Args: + output_file: Optional specific output file path. + + Returns: + List of generated file paths. + """ + self.output_dir.mkdir(parents=True, exist_ok=True) + + endpoints = self.spec_parser.get_endpoints() + info = self.spec_parser.get_info() + + context = { + "api_title": info["title"], + "api_version": info["version"], + "endpoints": endpoints, + "mock_server_url": self.mock_server_url, + "security_schemes": self.spec_parser.get_security_schemes(), + "definitions": self.spec_parser.get_definitions(), + } + + generated_files = [] + + try: + template = self.env.get_template("api.test.js.j2") + content = template.render(context) + + if output_file: + output_path = Path(output_file) + else: + safe_name = re.sub(r"[^a-zA-Z0-9_]", "_", info["title"].lower()) + output_path = self.output_dir / f"{safe_name}.test.js" + + output_path.parent.mkdir(parents=True, exist_ok=True) + output_path.write_text(content) + generated_files.append(output_path) + + except (TemplateSyntaxError, UndefinedError) as e: + raise TemplateRenderError(f"Failed to render Jest template: {e}") + + return generated_files + + def generate_endpoint_tests(self, endpoint_path: str, method: str) -> str: + """Generate test for a specific endpoint. + + Args: + endpoint_path: The API endpoint path. + method: The HTTP method. + + Returns: + String containing the test code. + """ + endpoints = self.spec_parser.get_endpoints() + + for endpoint in endpoints: + if endpoint["path"] == endpoint_path and endpoint["method"] == method.lower(): + return self._generate_single_test(endpoint) + + return "" + + def _generate_single_test(self, endpoint: Dict[str, Any]) -> str: + """Generate test code for a single endpoint. + + Args: + endpoint: The endpoint dictionary. + + Returns: + String containing the test code. + """ + test_name = self._generate_test_name(endpoint) + describe_name = endpoint["summary"] or endpoint["path"] + + params = self._generate_params(endpoint) + + endpoint_path = endpoint["path"] + endpoint_method = endpoint["method"] + + test_code = f''' +describe('{describe_name}', () => {{ + it('should {endpoint_method.upper()} {endpoint_path}', async () => {{ + const response = await request(baseUrl) + .{endpoint_method}('{endpoint_path}'{params}); + + expect([200, 201, 204]).toContain(response.status); + }}); +}}); +''' + return test_code + + def _generate_test_name(self, endpoint: Dict[str, Any]) -> str: + """Generate a valid test function name. + + Args: + endpoint: The endpoint dictionary. + + Returns: + A valid JavaScript function name. + """ + path = endpoint["path"] + method = endpoint["method"] + + name = re.sub(r"[^a-zA-Z0-9]", "_", path.strip("/")) + name = re.sub(r"_+/", "_", name) + name = re.sub(r"_+$", "", name) + + return f"{method}_{name}" if name else f"{method}_default" + + def _generate_params(self, endpoint: Dict[str, Any]) -> str: + """Generate parameters for test request. + + Args: + endpoint: The endpoint dictionary. + + Returns: + String containing parameter chain. + """ + parts = [] + + for param in endpoint.get("parameters", []): + param_name = param["name"] + + if param["in"] == "path": + parts.append(f'{param_name}="test_{param_name}"') + + elif param["in"] == "query": + parts.append(f'{param_name}') + + if parts: + return ", {" + ", ".join(parts) + "}" + return "" diff --git a/api_testgen/generators/pytest.py b/api_testgen/generators/pytest.py new file mode 100644 index 0000000..4a00d61 --- /dev/null +++ b/api_testgen/generators/pytest.py @@ -0,0 +1,199 @@ +"""Pytest test generator.""" + +import re +from pathlib import Path +from typing import Any, Dict, List, Optional + +from jinja2 import Environment, FileSystemLoader, TemplateSyntaxError, UndefinedError + +from ..core import SpecParser, AuthConfig +from ..core.exceptions import GeneratorError, TemplateRenderError + + +class PytestGenerator: + """Generate pytest-compatible integration test templates.""" + + def __init__( + self, + spec_parser: SpecParser, + output_dir: str = "tests", + mock_server_url: str = "http://localhost:4010", + ): + """Initialize the pytest generator. + + Args: + spec_parser: The OpenAPI specification parser. + output_dir: Directory for generated test files. + mock_server_url: URL of the mock server for testing. + """ + self.spec_parser = spec_parser + self.output_dir = Path(output_dir) + self.mock_server_url = mock_server_url + self.env = Environment( + loader=FileSystemLoader(str(Path(__file__).parent.parent.parent / "templates" / "pytest")), + trim_blocks=True, + lstrip_blocks=True, + ) + + def generate(self, output_file: Optional[str] = None) -> List[Path]: + """Generate pytest test files. + + Args: + output_file: Optional specific output file path. + + Returns: + List of generated file paths. + """ + self.output_dir.mkdir(parents=True, exist_ok=True) + + endpoints = self.spec_parser.get_endpoints() + info = self.spec_parser.get_info() + + context = { + "api_title": info["title"], + "api_version": info["version"], + "endpoints": endpoints, + "mock_server_url": self.mock_server_url, + "security_schemes": self.spec_parser.get_security_schemes(), + "definitions": self.spec_parser.get_definitions(), + } + + generated_files = [] + + try: + template = self.env.get_template("test_base.py.j2") + content = template.render(context) + + if output_file: + output_path = Path(output_file) + else: + safe_name = re.sub(r"[^a-zA-Z0-9_]", "_", info["title"].lower()) + output_path = self.output_dir / f"test_{safe_name}.py" + + output_path.parent.mkdir(parents=True, exist_ok=True) + output_path.write_text(content) + generated_files.append(output_path) + + except (TemplateSyntaxError, UndefinedError) as e: + raise TemplateRenderError(f"Failed to render pytest template: {e}") + + return generated_files + + def generate_endpoint_tests(self, endpoint_path: str, method: str) -> str: + """Generate test for a specific endpoint. + + Args: + endpoint_path: The API endpoint path. + method: The HTTP method. + + Returns: + String containing the test code. + """ + endpoints = self.spec_parser.get_endpoints() + + for endpoint in endpoints: + if endpoint["path"] == endpoint_path and endpoint["method"] == method.lower(): + return self._generate_single_test(endpoint) + + return "" + + def _generate_single_test(self, endpoint: Dict[str, Any]) -> str: + """Generate test code for a single endpoint. + + Args: + endpoint: The endpoint dictionary. + + Returns: + String containing the test code. + """ + test_name = self._generate_test_name(endpoint) + params = self._generate_parameters(endpoint) + auth_headers = self._generate_auth_headers(endpoint) + + test_code = f''' +def test_{test_name}(base_url, {params}): + """Test {endpoint["summary"] or endpoint["path"]} endpoint.""" + url = f"{{base_url}}{endpoint['path']}" + + headers = {{"Content-Type": "application/json"}} + {auth_headers} + + response = requests.{endpoint['method']}(url, json={{}} if method == "POST" else None, headers=headers) + + assert response.status_code in [200, 201, 204] +''' + return test_code + + def _generate_test_name(self, endpoint: Dict[str, Any]) -> str: + """Generate a valid test function name. + + Args: + endpoint: The endpoint dictionary. + + Returns: + A valid Python function name. + """ + path = endpoint["path"] + method = endpoint["method"] + + name = re.sub(r"[^a-zA-Z0-9]", "_", path.strip("/")) + name = re.sub(r"_+/", "_", name) + name = re.sub(r"_+$", "", name) + + return f"{method}_{name}" if name else f"{method}_default" + + def _generate_parameters(self, endpoint: Dict[str, Any]) -> str: + """Generate parameters for test function. + + Args: + endpoint: The endpoint dictionary. + + Returns: + String containing parameter declarations. + """ + params = [] + + for param in endpoint.get("parameters", []): + param_name = param["name"] + + if param["in"] == "path": + params.append(f'{param_name}="test_{param_name}"') + + elif param["in"] == "query": + params.append(f'{param_name}=None') + + return ", ".join(params) + + def _generate_auth_headers(self, endpoint: Dict[str, Any]) -> str: + """Generate authentication headers for endpoint. + + Args: + endpoint: The endpoint dictionary. + + Returns: + String containing header assignments. + """ + security_requirements = endpoint.get("security", []) + + if not security_requirements: + return "" + + schemes = self.spec_parser.get_security_schemes() + auth_config = AuthConfig() + + try: + auth_config.build_from_spec(schemes, security_requirements) + except Exception: + return "" + + headers = [] + + for scheme_name in security_requirements[0].keys(): + method = auth_config.get_auth_method(scheme_name) + if method: + if method["type"] == "apiKey": + headers.append(f'headers["{method["header_name"]}"] = "test_api_key"') + elif method["type"] == "bearer": + headers.append('headers["Authorization"] = "Bearer test_token"') + + return "\n ".join(headers) if headers else "" diff --git a/api_testgen/mocks/__init__.py b/api_testgen/mocks/__init__.py new file mode 100644 index 0000000..e828c9f --- /dev/null +++ b/api_testgen/mocks/__init__.py @@ -0,0 +1,5 @@ +"""Mocks module for API TestGen.""" + +from .generator import MockServerGenerator + +__all__ = ["MockServerGenerator"] diff --git a/api_testgen/mocks/generator.py b/api_testgen/mocks/generator.py new file mode 100644 index 0000000..4c0e099 --- /dev/null +++ b/api_testgen/mocks/generator.py @@ -0,0 +1,278 @@ +"""Mock server generator for Prism/OpenAPI mock support.""" + +import json +from pathlib import Path +from typing import Any, Dict, List, Optional + +from ..core import SpecParser + + +class MockServerGenerator: + """Generate Prism/OpenAPI mock server configurations.""" + + DEFAULT_PORT = 4010 + DEFAULT_HOST = "0.0.0.0" + + def __init__( + self, + spec_parser: SpecParser, + output_dir: str = ".", + ): + """Initialize the mock server generator. + + Args: + spec_parser: The OpenAPI specification parser. + output_dir: Directory for generated configuration files. + """ + self.spec_parser = spec_parser + self.output_dir = Path(output_dir) + + def generate( + self, + prism_config: bool = True, + docker_compose: bool = True, + dockerfile: bool = True, + ) -> List[Path]: + """Generate mock server configuration files. + + Args: + prism_config: Whether to generate prism-config.json. + docker_compose: Whether to generate docker-compose.yml. + dockerfile: Whether to generate Dockerfile. + + Returns: + List of generated file paths. + """ + self.output_dir.mkdir(parents=True, exist_ok=True) + + generated_files = [] + + if prism_config: + generated_files.append(self.generate_prism_config()) + + if docker_compose: + generated_files.append(self.generate_docker_compose()) + + if dockerfile: + generated_files.append(self.generate_dockerfile()) + + return generated_files + + def generate_prism_config(self, output_file: Optional[str] = None) -> Path: + """Generate Prism mock server configuration. + + Args: + output_file: Optional output file path. + + Returns: + Path to the generated file. + """ + if output_file: + output_path = Path(output_file) + else: + output_path = self.output_dir / "prism-config.json" + + config = { + "mock": { + "host": self.DEFAULT_HOST, + "port": self.DEFAULT_PORT, + "cors": { + "enabled": True, + "allowOrigin": "*", + "allowMethods": ["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS", "HEAD"], + "allowHeaders": ["Content-Type", "Authorization", "X-API-Key"], + }, + "Validation": { + "request": True, + "response": True, + }, + }, + "logging": { + "level": "info", + "format": "json", + }, + } + + output_path.write_text(json.dumps(config, indent=2)) + return output_path + + def generate_docker_compose(self, output_file: Optional[str] = None) -> Path: + """Generate Docker Compose configuration for mock server. + + Args: + output_file: Optional output file path. + + Returns: + Path to the generated file. + """ + if output_file: + output_path = Path(output_file) + else: + output_path = self.output_dir / "docker-compose.yml" + + spec_info = self.spec_parser.get_info() + + compose_content = f'''version: '3.8' + +services: + mock-server: + image: stoplight/prism:latest + container_name: "{spec_info['title']}-mock-server" + command: > + mock + --spec /app/openapi.yaml + --port {self.DEFAULT_PORT} + --host {self.DEFAULT_HOST} + ports: + - "{self.DEFAULT_PORT}:{self.DEFAULT_PORT}" + volumes: + - ./:/app + restart: unless-stopped + healthcheck: + test: ["CMD", "wget", "-q", "--spider", f"http://localhost:{self.DEFAULT_PORT}/health"] + interval: 30s + timeout: 10s + retries: 3 + + mock-server-https: + image: stoplight/prism:latest + container_name: "{spec_info['title']}-mock-server-https" + command: > + mock + --spec /app/openapi.yaml + --port {self.DEFAULT_PORT + 1} + --host {self.DEFAULT_HOST} + ports: + - "{self.DEFAULT_PORT + 1}:{self.DEFAULT_PORT + 1}" + volumes: + - ./:/app + restart: unless-stopped +''' + + output_path.write_text(compose_content) + return output_path + + def generate_dockerfile(self, output_file: Optional[str] = None) -> Path: + """Generate Dockerfile for mock server. + + Args: + output_file: Optional output file path. + + Returns: + Path to the generated file. + """ + if output_file: + output_path = Path(output_file) + else: + output_path = self.output_dir / "Dockerfile" + + spec_info = self.spec_parser.get_info() + + dockerfile_content = f'''FROM stoplight/prism:latest + +LABEL maintainer="developer@example.com" +LABEL description="Mock server for {spec_info['title']} API" + +WORKDIR /app + +COPY openapi.yaml . + +EXPOSE {self.DEFAULT_PORT} + +CMD ["mock", "--spec", "openapi.yaml", "--port", "{self.DEFAULT_PORT}", "--host", "0.0.0.0"] +''' + + output_path.write_text(dockerfile_content) + return output_path + + def generate_start_script(self, output_file: Optional[str] = None) -> Path: + """Generate shell script to start mock server. + + Args: + output_file: Optional output file path. + + Returns: + Path to the generated file. + """ + if output_file: + output_path = Path(output_file) + else: + output_path = self.output_dir / "start-mock-server.sh" + + script_content = f'''#!/bin/bash + +set -e + +SCRIPT_DIR="$(cd "$(dirname "${{BASH_SOURCE[0]}}")" && pwd)" +cd "$SCRIPT_DIR" + +PORT={self.DEFAULT_PORT} +HOST="0.0.0.0" + +echo "Starting mock server for API..." +echo "Mock server will be available at: http://localhost:$PORT" + +docker compose up -d mock-server + +echo "Mock server started successfully!" +echo "To stop: docker compose down" +''' + + output_path.write_text(script_content) + output_path.chmod(0o755) + return output_path + + def generate_health_check(self) -> Path: + """Generate health check endpoint configuration. + + Returns: + Path to the generated file. + """ + output_path = self.output_dir / "health-check.json" + + spec_info = self.spec_parser.get_info() + + health_check = { + "openapi": self.spec_parser.version, + "info": { + "title": f"{spec_info['title']} - Health Check", + "version": spec_info["version"], + "description": "Health check endpoint for the mock server", + }, + "paths": { + "/health": { + "get": { + "summary": "Health check endpoint", + "description": "Returns the health status of the mock server", + "operationId": "healthCheck", + "responses": { + "200": { + "description": "Mock server is healthy", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "status": { + "type": "string", + "enum": ["healthy"], + }, + "api": { + "type": "string", + }, + "version": { + "type": "string", + }, + }, + }, + } + }, + } + }, + } + } + }, + } + + output_path.write_text(json.dumps(health_check, indent=2)) + return output_path diff --git a/examples/petstore.yaml b/examples/petstore.yaml new file mode 100644 index 0000000..a34f18a --- /dev/null +++ b/examples/petstore.yaml @@ -0,0 +1,297 @@ +openapi: 3.0.0 +info: + title: Pet Store API + description: A sample API for managing pets in a store + version: 1.0.0 +servers: + - url: https://api.petstore.example.com/v1 + description: Production server + - url: http://localhost:4010 + description: Mock server for testing +tags: + - name: pets + description: Operations on pets + - name: store + description: Store management operations + - name: users + description: User management operations + +paths: + /pets: + get: + summary: List all pets + description: Returns a list of pets with optional filtering + operationId: listPets + tags: + - pets + parameters: + - name: status + in: query + description: Filter pets by status + schema: + type: string + enum: + - available + - pending + - sold + - name: limit + in: query + description: Maximum number of pets to return + schema: + type: integer + default: 20 + - name: offset + in: query + description: Number of pets to skip + schema: + type: integer + default: 0 + responses: + '200': + description: A list of pets + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/Pet' + post: + summary: Create a new pet + description: Creates a new pet in the store + operationId: createPet + tags: + - pets + security: + - BearerAuth: [] + - ApiKeyAuth: [] + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/PetInput' + responses: + '201': + description: Pet created successfully + content: + application/json: + schema: + $ref: '#/components/schemas/Pet' + '400': + description: Invalid input + /pets/{petId}: + get: + summary: Get a pet by ID + description: Returns a single pet by its ID + operationId: getPetById + tags: + - pets + parameters: + - name: petId + in: path + description: ID of the pet to retrieve + required: true + schema: + type: string + responses: + '200': + description: A single pet + content: + application/json: + schema: + $ref: '#/components/schemas/Pet' + '404': + description: Pet not found + put: + summary: Update a pet + description: Updates an existing pet in the store + operationId: updatePet + tags: + - pets + security: + - BearerAuth: [] + parameters: + - name: petId + in: path + description: ID of the pet to update + required: true + schema: + type: string + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/PetInput' + responses: + '200': + description: Pet updated successfully + content: + application/json: + schema: + $ref: '#/components/schemas/Pet' + '404': + description: Pet not found + delete: + summary: Delete a pet + description: Deletes a pet from the store + operationId: deletePet + tags: + - pets + security: + - ApiKeyAuth: [] + parameters: + - name: petId + in: path + description: ID of the pet to delete + required: true + schema: + type: string + responses: + '204': + description: Pet deleted successfully + '404': + description: Pet not found + /store/inventory: + get: + summary: Get store inventory + description: Returns pet inventory by status + operationId: getInventory + tags: + - store + security: + - ApiKeyAuth: [] + responses: + '200': + description: Inventory counts by status + content: + application/json: + schema: + type: object + additionalProperties: + type: integer + /users: + get: + summary: List all users + description: Returns a list of users + operationId: listUsers + tags: + - users + responses: + '200': + description: A list of users + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/User' + post: + summary: Create a user + description: Creates a new user + operationId: createUser + tags: + - users + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/UserInput' + responses: + '201': + description: User created successfully + +components: + securitySchemes: + ApiKeyAuth: + type: apiKey + name: X-API-Key + in: header + BearerAuth: + type: http + scheme: bearer + bearerFormat: JWT + schemas: + Pet: + type: object + required: + - name + - status + properties: + id: + type: string + format: uuid + description: Unique identifier + name: + type: string + description: Name of the pet + status: + type: string + enum: + - available + - pending + - sold + description: Status of the pet + tags: + type: array + items: + type: string + description: Tags associated with the pet + photoUrls: + type: array + items: + type: string + format: uri + description: URLs of pet photos + PetInput: + type: object + required: + - name + - status + properties: + name: + type: string + description: Name of the pet + status: + type: string + enum: + - available + - pending + - sold + description: Status of the pet + tags: + type: array + items: + type: string + User: + type: object + properties: + id: + type: string + format: uuid + username: + type: string + description: Unique username + email: + type: string + format: email + firstName: + type: string + lastName: + type: string + UserInput: + type: object + required: + - username + - email + properties: + username: + type: string + email: + type: string + format: email + firstName: + type: string + lastName: + type: string diff --git a/templates/go/api_test.go.j2 b/templates/go/api_test.go.j2 new file mode 100644 index 0000000..365a97d --- /dev/null +++ b/templates/go/api_test.go.j2 @@ -0,0 +1,129 @@ +package {{ package_name }} + +import ( + "encoding/json" + "fmt" + "net/http" + "os" + "strings" + "testing" + "time" +) + +const ( + baseURL = "{{ mock_server_url }}" + testTimeout = 10 * time.Second +) + +{% if security_schemes %} +{% for scheme_name, scheme in security_schemes.items() %} +{% if scheme.type == "apiKey" %} +func get{{ scheme_name|capitalize }}Headers() map[string]string { + return map[string]string{ + "{{ scheme.name }}": os.Getenv("API_KEY"), + } +} + +{% elif scheme.type == "http" and scheme.scheme == "bearer" %} +func getBearerHeaders() map[string]string { + return map[string]string{ + "Authorization": fmt.Sprintf("Bearer %s", os.Getenv("TOKEN")), + } +} + +{% elif scheme.type == "http" and scheme.scheme == "basic" %} +func getBasicHeaders(username, password string) map[string]string { + auth := username + ":" + password + encoded := base64.StdEncoding.EncodeToString([]byte(auth)) + return map[string]string{ + "Authorization": "Basic " + encoded, + } +} + +{% endif %} +{% endfor %} +{% endif %} + +func contains(slice []int, item int) bool { + for _, s := range slice { + if s == item { + return true + } + } + return false +} + +{% for path, path_endpoints in grouped_endpoints.items() %} +{% for endpoint in path_endpoints %} +{% set test_name = (endpoint.method + "_" + path.strip("/").replace("/", "_").replace("{", "").replace("}", "")).title().replace("_", "") %} +func Test{{ test_name }}(t *testing.T) { + client := &http.Client{Timeout: testTimeout} + url := baseURL + "{{ path }}" + + {% for param in endpoint.parameters %} + {% if param.in == "path" %} + url = strings.Replace(url, "{+{{ param.name }}}", "{{ param.name }}", 1) + {% endif %} + {% endfor %} + + {% if endpoint.method in ["post", "put", "patch"] %} + body := `{}` + req, err := http.NewRequest("{{ endpoint.method.upper() }}", url, strings.NewReader(body)) + {% else %} + req, err := http.NewRequest("{{ endpoint.method.upper() }}", url, nil) + {% endif %} + + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + + req.Header.Set("Content-Type", "application/json") + + {% if endpoint.security %} + {% set scheme_name = endpoint.security[0].keys()|first %} + {% set scheme = security_schemes[scheme_name] %} + {% if scheme.type == "apiKey" %} + for k, v := range get{{ scheme_name|capitalize }}Headers() { + req.Header.Set(k, v) + } + {% elif scheme.type == "http" and scheme.scheme == "bearer" %} + for k, v := range getBearerHeaders() { + req.Header.Set(k, v) + } + {% endif %} + {% endif %} + + resp, err := client.Do(req) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + defer resp.Body.Close() + + if !contains([]int{200, 201, 204}, resp.StatusCode) { + t.Errorf("Expected status in [200, 201, 204], got %d", resp.StatusCode) + } + + if resp.StatusCode == 200 || resp.StatusCode == 201 { + var data interface{} + if err := json.NewDecoder(resp.Body).Decode(&data); err != nil { + t.Errorf("Failed to decode response: %v", err) + } + } +} + +{% endfor %} +{% endfor %} + +func TestAPIHealth(t *testing.T) { + client := &http.Client{Timeout: testTimeout} + + resp, err := client.Get(baseURL + "/health") + if err != nil { + t.Fatalf("Health check failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != 200 { + t.Errorf("Expected status 200, got %d", resp.StatusCode) + } +} diff --git a/templates/jest/api.test.js.j2 b/templates/jest/api.test.js.j2 new file mode 100644 index 0000000..ad4a40a --- /dev/null +++ b/templates/jest/api.test.js.j2 @@ -0,0 +1,101 @@ +/** + * Generated Jest tests for {{ api_title }} API v{{ api_version }} + * + * This file was auto-generated by API TestGen. + */ + +const request = require('supertest'); + +const BASE_URL = process.env.MOCK_SERVER_URL || '{{ mock_server_url }}'; + +{% if security_schemes %} +{% for scheme_name, scheme in security_schemes.items() %} +{% if scheme.type == "apiKey" %} +const get{{ scheme_name|capitalize }}Headers = () => ({ + "{{ scheme.name }}": process.env.API_KEY || 'your-api-key', +}); + +{% elif scheme.type == "http" and scheme.scheme == "bearer" %} +const getBearerHeaders = () => ({ + Authorization: `Bearer ${process.env.TOKEN || 'your-token'}`, +}); + +{% elif scheme.type == "http" and scheme.scheme == "basic" %} +const getBasicHeaders = () => { + const credentials = Buffer.from(`${process.env.USERNAME || 'username'}:${process.env.PASSWORD || 'password'}`).toString('base64'); + return { Authorization: `Basic ${credentials}` }; +}; + +{% endif %} +{% endfor %} +{% endif %} + +const validateResponse = (response, expectedStatus) => { + expect(response.headers['content-type']).toMatch(/application\/json/); + if (expectedStatus) { + expect(response.status).toBe(expectedStatus); + } +}; + +describe('{{ api_title }} API', () => { + {% for endpoint in endpoints %} + {% set endpoint_id = endpoint.operation_id or (endpoint.method + "_" + endpoint.path.strip("/").replace("/", "_").replace("{", "").replace("}", "")) %} + describe('{{ endpoint.method.upper() }} {{ endpoint.path }}', () => { + {{ endpoint.description or "" }} + + it('should return valid response', async () => { + {% if endpoint.security %} + {% set scheme_name = endpoint.security[0].keys()|first %} + {% set scheme = security_schemes[scheme_name] %} + {% if scheme.type == "apiKey" %} + const headers = get{{ scheme_name|capitalize }}Headers(); + {% elif scheme.type == "http" and scheme.scheme == "bearer" %} + const headers = getBearerHeaders(); + {% elif scheme.type == "http" and scheme.scheme == "basic" %} + const headers = getBasicHeaders(); + {% else %} + const headers = {}; + {% endif %} + {% else %} + const headers = { 'Content-Type': 'application/json' }; + {% endif %} + + {% if endpoint.method in ["post", "put", "patch"] %} + const body = {}; + + const response = await request(BASE_URL) + .{{ endpoint["method"] }}('{{ endpoint["path"] }}') + .send(body) + .set(headers); + {% else %} + {% if endpoint.parameters|selectattr("in", "equalto", "query")|list %} + {% set query_params = endpoint.parameters|selectattr("in", "equalto", "query")|list %} + const queryParams = { + {% for param in query_params %} + {{ param.name }}: 'test'{% if not loop.last %}, + {% endif %} + {% endfor %} + }; + + const response = await request(BASE_URL) + .{{ endpoint["method"] }}('{{ endpoint["path"] }}') + .query(queryParams) + .set(headers); + {% else %} + const response = await request(BASE_URL) + .{{ endpoint["method"] }}('{{ endpoint["path"] }}') + .set(headers); + {% endif %} + {% endif %} + + expect([200, 201, 204]).toContain(response.status); + {% if endpoint.responses["200"] %} + validateResponse(response, {{ endpoint.responses["200"] }}); + {% else %} + validateResponse(response); + {% endif %} + }); + }); + + {% endfor %} +}); diff --git a/templates/pytest/test_base.py.j2 b/templates/pytest/test_base.py.j2 new file mode 100644 index 0000000..e60f3c4 --- /dev/null +++ b/templates/pytest/test_base.py.j2 @@ -0,0 +1,115 @@ +""" +Generated pytest tests for {{ api_title }} API v{{ api_version }} + +This file was auto-generated by API TestGen. +""" +import pytest +import requests +import json +from jsonschema import validate, ValidationError + + +BASE_URL = "{{ mock_server_url }}" + + +{% if security_schemes %} +{% for scheme_name, scheme in security_schemes.items() %} +{% if scheme.type == "apiKey" %} +@pytest.fixture +def {{ scheme_name }}_headers(): + """API Key authentication headers.""" + return {"{{ scheme.name }}": "your-api-key"} + + +{% elif scheme.type == "http" and scheme.scheme == "bearer" %} +@pytest.fixture +def bearer_headers(): + """Bearer token authentication headers.""" + return {"Authorization": "Bearer your-token"} + + +{% elif scheme.type == "http" and scheme.scheme == "basic" %} +@pytest.fixture +def basic_headers(): + """Basic authentication headers.""" + import base64 + credentials = "username:password" + encoded = base64.b64encode(credentials.encode()).decode() + return {"Authorization": f"Basic {encoded}"} + + +{% endif %} +{% endfor %} +{% endif %} + + +{% if definitions %} +@pytest.fixture +def base_url(): + """Base URL for API requests.""" + return BASE_URL + + +{% endif %} +def validate_response(response, status_code=None): + """Validate API response. + + Args: + response: The response object. + status_code: Expected status code (optional). + """ + if status_code: + assert response.status_code == status_code, \ + f"Expected status {status_code}, got {response.status_code}" + + assert response.headers.get("Content-Type", "").startswith("application/json"), \ + "Response Content-Type is not JSON" + + +{% for endpoint in endpoints %} +{% set endpoint_id = endpoint.operation_id or (endpoint.method + "_" + endpoint.path.strip("/").replace("/", "_").replace("{", "").replace("}", "")) %} +def test_{{ endpoint_id }}(base_url{% if endpoint.parameters|selectattr("in", "equalto", "path")|list %}, {% for param in endpoint.parameters %}{% if param.in == "path" %}{{ param.name }}{% endif %}{% endfor %}{% endif %}): + """Test {{ endpoint.summary or endpoint.path }} endpoint. + + {{ endpoint.description or "" }} + """ + url = f"{base_url}{{ endpoint.path }}" + + headers = {"Content-Type": "application/json"} + + {% if endpoint.security %} + {% for security_requirement in endpoint.security %} + {% for scheme_name in security_requirement.keys() %} + {% if security_schemes[scheme_name].type == "apiKey" %} + headers["{{ security_schemes[scheme_name].name }}"] = "test-api-key" + {% elif security_schemes[scheme_name].type == "http" and security_schemes[scheme_name].scheme == "bearer" %} + headers["Authorization"] = "Bearer test-token" + {% endif %} + {% endfor %} + {% endfor %} + {% endif %} + + {% if endpoint.method in ["post", "put", "patch"] %} + {% if endpoint.request_body %} + request_body = {} + {% else %} + request_body = {} + {% endif %} + + response = requests.{{ endpoint["method"] }}(url, json=request_body, headers=headers) + {% else %} + response = requests.{{ endpoint["method"] }}(url, headers=headers{% if endpoint.parameters|selectattr("in", "equalto", "query")|list %}{% set query_params = endpoint.parameters|selectattr("in", "equalto", "query")|list %}{% if query_params %}, params={% for param in query_params %}"{{ param.name }}": "test"{% endfor %}{% endif %}{% endif %}) + {% endif %} + + validate_response(response{% if endpoint.responses["200"] %}, {{ endpoint.responses["200"] }}{% endif %}) + + {% if endpoint.responses %} + {% if endpoint.responses["200"] and endpoint.responses["200"].schema %} + try: + data = response.json() + except json.JSONDecodeError: + pytest.fail("Response is not valid JSON") + + {% endif %} + {% endif %} +{% endfor %} diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..4aad81f --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +"""Tests package for API TestGen.""" diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py new file mode 100644 index 0000000..ae41a8c --- /dev/null +++ b/tests/integration/__init__.py @@ -0,0 +1 @@ +"""Integration tests for API TestGen.""" diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 0000000..7cda33a --- /dev/null +++ b/tests/unit/__init__.py @@ -0,0 +1 @@ +"""Unit tests for API TestGen.""" diff --git a/tests/unit/test_auth.py b/tests/unit/test_auth.py new file mode 100644 index 0000000..b854dcc --- /dev/null +++ b/tests/unit/test_auth.py @@ -0,0 +1,182 @@ +"""Unit tests for the auth configuration module.""" +import pytest +import tempfile +from pathlib import Path + +from api_testgen.core.auth import AuthConfig, AuthType +from api_testgen.core.exceptions import AuthConfigError, MissingSecuritySchemeError + + +class TestAuthConfig: + """Tests for AuthConfig class.""" + + def test_add_api_key(self): + """Test adding API key authentication.""" + auth = AuthConfig() + auth.add_api_key("test_api_key", header_name="X-API-Key", api_key="test123") + + method = auth.get_auth_method("test_api_key") + + assert method is not None + assert method["type"] == AuthType.API_KEY + assert method["header_name"] == "X-API-Key" + assert method["api_key"] == "test123" + + def test_add_bearer_token(self): + """Test adding Bearer token authentication.""" + auth = AuthConfig() + auth.add_bearer("test_bearer", token="abc123", token_prefix="Bearer") + + method = auth.get_auth_method("test_bearer") + + assert method is not None + assert method["type"] == AuthType.BEARER + assert method["token"] == "abc123" + assert method["token_prefix"] == "Bearer" + + def test_add_basic_auth(self): + """Test adding Basic authentication.""" + auth = AuthConfig() + auth.add_basic("test_basic", username="user", password="pass") + + method = auth.get_auth_method("test_basic") + + assert method is not None + assert method["type"] == AuthType.BASIC + assert method["username"] == "user" + assert method["password"] == "pass" + + def test_method_chaining(self): + """Test that add methods return self for chaining.""" + auth = AuthConfig() + + result = auth.add_api_key("key1") + assert result is auth + + result = auth.add_bearer("key2") + assert result is auth + + result = auth.add_basic("key3") + assert result is auth + + def test_get_all_methods(self): + """Test getting all configured auth methods.""" + auth = AuthConfig() + auth.add_api_key("api_key") + auth.add_bearer("bearer") + + methods = auth.get_all_methods() + + assert len(methods) == 2 + assert "api_key" in methods + assert "bearer" in methods + + def test_get_headers_api_key(self): + """Test getting headers for API key auth.""" + auth = AuthConfig() + auth.add_api_key("test", header_name="X-Custom-Key", api_key="mykey") + + headers = auth.get_headers("test") + + assert headers["X-Custom-Key"] == "mykey" + + def test_get_headers_bearer(self): + """Test getting headers for Bearer auth.""" + auth = AuthConfig() + auth.add_bearer("test", token="mytoken", token_prefix="Bearer") + + headers = auth.get_headers("test") + + assert headers["Authorization"] == "Bearer mytoken" + + def test_get_headers_basic(self): + """Test getting headers for Basic auth.""" + import base64 + + auth = AuthConfig() + auth.add_basic("test", username="user", password="pass") + + headers = auth.get_headers("test") + + expected = base64.b64encode(b"user:pass").decode() + assert headers["Authorization"] == f"Basic {expected}" + + def test_get_headers_unconfigured_scheme_raises_error(self): + """Test that getting headers for unconfigured scheme raises error.""" + auth = AuthConfig() + + with pytest.raises(AuthConfigError): + auth.get_headers("nonexistent") + + def test_build_from_spec(self): + """Test building auth config from OpenAPI security schemes.""" + auth = AuthConfig() + + security_schemes = { + "ApiKeyAuth": {"type": "apiKey", "name": "X-API-Key", "in": "header"}, + "BearerAuth": {"type": "http", "scheme": "bearer"}, + "BasicAuth": {"type": "http", "scheme": "basic"}, + } + + security_requirements = [ + {"ApiKeyAuth": []}, + {"BearerAuth": []}, + ] + + auth.build_from_spec(security_schemes, security_requirements) + + assert auth.get_auth_method("ApiKeyAuth") is not None + assert auth.get_auth_method("BearerAuth") is not None + assert auth.get_auth_method("BasicAuth") is None + + def test_build_from_spec_missing_scheme_raises_error(self): + """Test that missing security scheme raises error.""" + auth = AuthConfig() + + security_schemes = { + "ApiKeyAuth": {"type": "apiKey", "name": "X-API-Key", "in": "header"}, + } + + security_requirements = [ + {"MissingScheme": []}, + ] + + with pytest.raises(MissingSecuritySchemeError): + auth.build_from_spec(security_schemes, security_requirements) + + def test_generate_pytest_auth_code(self): + """Test generating pytest authentication code.""" + auth = AuthConfig() + auth.add_api_key("test_key", header_name="X-Api-Key", api_key="key123") + + code = auth.generate_auth_code("test_key", "pytest") + + assert "X-Api-Key" in code + assert "key123" in code + + def test_generate_jest_auth_code(self): + """Test generating Jest authentication code.""" + auth = AuthConfig() + auth.add_bearer("test_bearer", token="token123") + + code = auth.generate_auth_code("test_bearer", "jest") + + assert "Authorization" in code + assert "token123" in code + + def test_generate_go_auth_code(self): + """Test generating Go authentication code.""" + auth = AuthConfig() + auth.add_api_key("test_key", header_name="X-Api-Key") + + code = auth.generate_auth_code("test_key", "go") + + assert "X-Api-Key" in code + + def test_generate_auth_code_unconfigured_scheme(self): + """Test generating auth code for unconfigured scheme returns empty.""" + auth = AuthConfig() + + code = auth.generate_auth_code("nonexistent", "pytest") + + assert code == "" diff --git a/tests/unit/test_generators.py b/tests/unit/test_generators.py new file mode 100644 index 0000000..d0be78d --- /dev/null +++ b/tests/unit/test_generators.py @@ -0,0 +1,223 @@ +"""Unit tests for the generator modules.""" +import pytest +import tempfile +import os +from pathlib import Path + +from api_testgen.core import SpecParser +from api_testgen.generators import PytestGenerator, JestGenerator, GoGenerator +from api_testgen.mocks import MockServerGenerator + + +class TestPytestGenerator: + """Tests for PytestGenerator class.""" + + @pytest.fixture + def parser(self, sample_openapi_spec): + """Create a spec parser with sample spec.""" + with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f: + import yaml + yaml.dump(sample_openapi_spec, f) + parser = SpecParser(Path(f.name)) + parser.load() + return parser + + def test_generate_creates_file(self, parser, tmp_path): + """Test that generate creates a test file.""" + generator = PytestGenerator(parser, output_dir=str(tmp_path)) + files = generator.generate() + + assert len(files) == 1 + assert files[0].exists() + assert files[0].suffix == ".py" + + def test_generate_content_contains_tests(self, parser, tmp_path): + """Test that generated file contains test functions.""" + generator = PytestGenerator(parser, output_dir=str(tmp_path)) + files = generator.generate() + + content = files[0].read_text() + + assert "test_" in content + assert "BASE_URL" in content + assert "def test_" in content + + def test_generate_with_custom_output_file(self, parser, tmp_path): + """Test generating to a specific file path.""" + generator = PytestGenerator(parser, output_dir=str(tmp_path)) + files = generator.generate(output_file=str(tmp_path / "custom_test.py")) + + assert len(files) == 1 + assert files[0].name == "custom_test.py" + + def test_generate_endpoint_test(self, parser): + """Test generating a single endpoint test.""" + generator = PytestGenerator(parser) + + test_code = generator.generate_endpoint_tests("/pets", "get") + + assert "test_" in test_code + assert "requests.get" in test_code + + def test_generate_test_name(self, parser): + """Test test name generation.""" + generator = PytestGenerator(parser) + + name = generator._generate_test_name({ + "path": "/pets", + "method": "get", + "summary": "List all pets", + }) + + assert name == "get_pets" or "pets" in name + + +class TestJestGenerator: + """Tests for JestGenerator class.""" + + @pytest.fixture + def parser(self, sample_openapi_spec): + """Create a spec parser with sample spec.""" + with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f: + import yaml + yaml.dump(sample_openapi_spec, f) + parser = SpecParser(Path(f.name)) + parser.load() + return parser + + def test_generate_creates_file(self, parser, tmp_path): + """Test that generate creates a test file.""" + generator = JestGenerator(parser, output_dir=str(tmp_path)) + files = generator.generate() + + assert len(files) == 1 + assert files[0].exists() + assert files[0].suffix == ".js" + + def test_generate_content_contains_describe(self, parser, tmp_path): + """Test that generated file contains describe blocks.""" + generator = JestGenerator(parser, output_dir=str(tmp_path)) + files = generator.generate() + + content = files[0].read_text() + + assert "describe" in content + expect_in = "expect" in content or "toContain" in content or "toBe" in content + assert expect_in + + def test_generate_endpoint_test(self, parser): + """Test generating a single endpoint test.""" + generator = JestGenerator(parser) + + test_code = generator.generate_endpoint_tests("/pets", "get") + + assert "describe" in test_code or "it(" in test_code + + +class TestGoGenerator: + """Tests for GoGenerator class.""" + + @pytest.fixture + def parser(self, sample_openapi_spec): + """Create a spec parser with sample spec.""" + with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f: + import yaml + yaml.dump(sample_openapi_spec, f) + parser = SpecParser(Path(f.name)) + parser.load() + return parser + + def test_generate_creates_file(self, parser, tmp_path): + """Test that generate creates a test file.""" + generator = GoGenerator(parser, output_dir=str(tmp_path)) + files = generator.generate() + + assert len(files) == 1 + assert files[0].exists() + assert files[0].name.endswith("_test.go") + + def test_generate_content_contains_tests(self, parser, tmp_path): + """Test that generated file contains test functions.""" + generator = GoGenerator(parser, output_dir=str(tmp_path)) + files = generator.generate() + + content = files[0].read_text() + + assert "func Test" in content + assert "http.Client" in content or "http.NewRequest" in content + + def test_generate_with_custom_package(self, parser, tmp_path): + """Test generating with custom package name.""" + generator = GoGenerator(parser, output_dir=str(tmp_path), package_name="custompkg") + files = generator.generate() + + content = files[0].read_text() + + assert "package custompkg" in content + + +class TestMockServerGenerator: + """Tests for MockServerGenerator class.""" + + @pytest.fixture + def parser(self, sample_openapi_spec): + """Create a spec parser with sample spec.""" + with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f: + import yaml + yaml.dump(sample_openapi_spec, f) + parser = SpecParser(Path(f.name)) + parser.load() + return parser + + def test_generate_prism_config(self, parser, tmp_path): + """Test generating Prism configuration.""" + generator = MockServerGenerator(parser, output_dir=str(tmp_path)) + file = generator.generate_prism_config() + + assert file.exists() + assert file.name == "prism-config.json" + + content = file.read_text() + import json + config = json.loads(content) + + assert "mock" in config + assert "port" in config["mock"] + + def test_generate_docker_compose(self, parser, tmp_path): + """Test generating Docker Compose configuration.""" + generator = MockServerGenerator(parser, output_dir=str(tmp_path)) + file = generator.generate_docker_compose() + + assert file.exists() + assert file.name == "docker-compose.yml" + + content = file.read_text() + + assert "services" in content + assert "mock-server" in content + + def test_generate_dockerfile(self, parser, tmp_path): + """Test generating Dockerfile.""" + generator = MockServerGenerator(parser, output_dir=str(tmp_path)) + file = generator.generate_dockerfile() + + assert file.exists() + assert file.name == "Dockerfile" + + content = file.read_text() + + assert "FROM" in content + assert "prism" in content.lower() + + def test_generate_all(self, parser, tmp_path): + """Test generating all mock server files.""" + generator = MockServerGenerator(parser, output_dir=str(tmp_path)) + files = generator.generate() + + assert len(files) == 3 + + file_names = [f.name for f in files] + assert "prism-config.json" in file_names + assert "docker-compose.yml" in file_names + assert "Dockerfile" in file_names diff --git a/tests/unit/test_spec_parser.py b/tests/unit/test_spec_parser.py new file mode 100644 index 0000000..24dc453 --- /dev/null +++ b/tests/unit/test_spec_parser.py @@ -0,0 +1,173 @@ +"""Unit tests for the spec parser module.""" +import pytest +import tempfile +import json +from pathlib import Path + +from api_testgen.core import SpecParser +from api_testgen.core.exceptions import InvalidOpenAPISpecError, UnsupportedVersionError + + +class TestSpecParser: + """Tests for SpecParser class.""" + + def test_load_valid_openapi_30_spec(self, temp_spec_file): + """Test loading a valid OpenAPI 3.0 specification.""" + parser = SpecParser(temp_spec_file) + spec = parser.load() + + assert spec is not None + assert parser.version == "3.0.0" + assert parser.base_path == "" + assert len(parser.servers) == 1 + + def test_load_valid_json_spec(self, temp_json_spec_file): + """Test loading a valid JSON OpenAPI specification.""" + parser = SpecParser(temp_json_spec_file) + spec = parser.load() + + assert spec is not None + assert parser.version == "3.0.0" + + def test_get_info(self, temp_spec_file): + """Test extracting API info from spec.""" + parser = SpecParser(temp_spec_file) + parser.load() + + info = parser.get_info() + + assert info["title"] == "Test Pet Store API" + assert info["version"] == "1.0.0" + assert "sample" in info["description"].lower() + + def test_get_paths(self, temp_spec_file): + """Test extracting paths from spec.""" + parser = SpecParser(temp_spec_file) + parser.load() + + paths = parser.get_paths() + + assert "/pets" in paths + assert "/pets/{petId}" in paths + + def test_get_endpoints(self, temp_spec_file): + """Test extracting endpoints from spec.""" + parser = SpecParser(temp_spec_file) + parser.load() + + endpoints = parser.get_endpoints() + + assert len(endpoints) == 4 + assert any(e["method"] == "get" and e["path"] == "/pets" for e in endpoints) + assert any(e["method"] == "post" and e["path"] == "/pets" for e in endpoints) + assert any(e["method"] == "get" and e["path"] == "/pets/{petId}" for e in endpoints) + assert any(e["method"] == "delete" and e["path"] == "/pets/{petId}" for e in endpoints) + + def test_get_security_schemes(self, temp_spec_file): + """Test extracting security schemes from spec.""" + parser = SpecParser(temp_spec_file) + parser.load() + + schemes = parser.get_security_schemes() + + assert "ApiKeyAuth" in schemes + assert "BearerAuth" in schemes + assert schemes["ApiKeyAuth"]["type"] == "apiKey" + assert schemes["BearerAuth"]["type"] == "http" + assert schemes["BearerAuth"]["scheme"] == "bearer" + + def test_get_definitions(self, temp_spec_file): + """Test extracting schema definitions from spec.""" + parser = SpecParser(temp_spec_file) + parser.load() + + definitions = parser.get_definitions() + + assert "Pet" in definitions + assert definitions["Pet"]["type"] == "object" + + def test_endpoint_with_parameters(self, temp_spec_file): + """Test endpoint parameter extraction.""" + parser = SpecParser(temp_spec_file) + parser.load() + + endpoints = parser.get_endpoints() + + pets_endpoint = next(e for e in endpoints if e["path"] == "/pets" and e["method"] == "get") + + assert len(pets_endpoint["parameters"]) == 2 + assert any(p["name"] == "limit" and p["in"] == "query" for p in pets_endpoint["parameters"]) + assert any(p["name"] == "status" and p["in"] == "query" for p in pets_endpoint["parameters"]) + + def test_endpoint_with_path_parameters(self, temp_spec_file): + """Test path parameter extraction.""" + parser = SpecParser(temp_spec_file) + parser.load() + + endpoints = parser.get_endpoints() + + pet_endpoint = next(e for e in endpoints if e["path"] == "/pets/{petId}" and e["method"] == "get") + + assert len(pet_endpoint["parameters"]) == 1 + param = pet_endpoint["parameters"][0] + assert param["name"] == "petId" + assert param["in"] == "path" + assert param["required"] is True + + def test_endpoint_with_request_body(self, temp_spec_file): + """Test request body extraction for OpenAPI 3.0.""" + parser = SpecParser(temp_spec_file) + parser.load() + + endpoints = parser.get_endpoints() + + create_endpoint = next(e for e in endpoints if e["path"] == "/pets" and e["method"] == "post") + + assert create_endpoint["request_body"] is not None + assert create_endpoint["request_body"]["required"] is True + + def test_endpoint_with_responses(self, temp_spec_file): + """Test response extraction.""" + parser = SpecParser(temp_spec_file) + parser.load() + + endpoints = parser.get_endpoints() + + pets_endpoint = next(e for e in endpoints if e["path"] == "/pets" and e["method"] == "get") + + assert "200" in pets_endpoint["responses"] + assert pets_endpoint["responses"]["200"]["description"] == "A list of pets" + + def test_to_dict(self, temp_spec_file): + """Test dictionary representation of parsed spec.""" + parser = SpecParser(temp_spec_file) + parser.load() + + spec_dict = parser.to_dict() + + assert "version" in spec_dict + assert "info" in spec_dict + assert "paths" in spec_dict + assert "endpoints" in spec_dict + assert "security_schemes" in spec_dict + assert "definitions" in spec_dict + + def test_nonexistent_file_raises_error(self): + """Test that nonexistent file raises InvalidOpenAPISpecError.""" + parser = SpecParser("/nonexistent/path/spec.yaml") + + with pytest.raises(InvalidOpenAPISpecError): + parser.load() + + def test_invalid_yaml_raises_error(self): + """Test that invalid YAML raises InvalidOpenAPISpecError.""" + with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f: + f.write("invalid: yaml: content: [[[") + f.flush() + + parser = SpecParser(Path(f.name)) + + with pytest.raises(InvalidOpenAPISpecError): + parser.load() + + Path(f.name).unlink()