314 lines
9.7 KiB
Python
314 lines
9.7 KiB
Python
"""Authentication configuration for API TestGen."""
|
|
|
|
from typing import Any, Dict, List, Optional, Union
|
|
from enum import Enum
|
|
|
|
from .exceptions import AuthConfigError, MissingSecuritySchemeError
|
|
|
|
|
|
class AuthType(str, Enum):
    """Kinds of authentication an ``AuthConfig`` entry can describe.

    Members also subclass ``str``, so each one compares equal to its raw
    value (for example ``AuthType.API_KEY == "apiKey"``).
    """

    API_KEY = "apiKey"  # key sent in a named request header
    BEARER = "bearer"  # "<prefix> <token>" in the Authorization header
    BASIC = "basic"  # base64("username:password") in the Authorization header
    NONE = "none"  # declared for completeness; AuthConfig never stores it


class AuthConfig:
    """Registry of authentication settings used when generating tests.

    Entries are keyed by OpenAPI security-scheme name. Each entry is a plain
    dict holding a ``"type"`` (an ``AuthType``) plus the fields for that type.
    The ``add_*`` methods and ``build_from_spec`` return ``self`` so calls can
    be chained.
    """

    def __init__(self):
        """Create a configuration with no authentication methods."""
        self._auth_methods: Dict[str, Dict[str, Any]] = {}

    def add_api_key(
        self,
        scheme_name: str,
        header_name: str = "X-API-Key",
        api_key: str = "",
    ) -> "AuthConfig":
        """Register (or replace) an API-key scheme.

        Args:
            scheme_name: Name of the security scheme in the OpenAPI spec.
            header_name: Header that carries the key.
            api_key: The key value; may be left empty and filled in later.

        Returns:
            Self, for method chaining.
        """
        self._auth_methods[scheme_name] = {
            "type": AuthType.API_KEY,
            "header_name": header_name,
            "api_key": api_key,
        }
        return self

    def add_bearer(
        self,
        scheme_name: str,
        token: str = "",
        token_prefix: str = "Bearer",
    ) -> "AuthConfig":
        """Register (or replace) a Bearer-token scheme.

        Args:
            scheme_name: Name of the security scheme in the OpenAPI spec.
            token: The token value; may be left empty and filled in later.
            token_prefix: Word placed before the token in the header.

        Returns:
            Self, for method chaining.
        """
        self._auth_methods[scheme_name] = {
            "type": AuthType.BEARER,
            "token": token,
            "token_prefix": token_prefix,
        }
        return self

    def add_basic(
        self,
        scheme_name: str,
        username: str = "",
        password: str = "",
    ) -> "AuthConfig":
        """Register (or replace) an HTTP Basic scheme.

        Args:
            scheme_name: Name of the security scheme in the OpenAPI spec.
            username: The username; may be left empty and filled in later.
            password: The password; may be left empty and filled in later.

        Returns:
            Self, for method chaining.
        """
        self._auth_methods[scheme_name] = {
            "type": AuthType.BASIC,
            "username": username,
            "password": password,
        }
        return self

    def get_auth_method(self, scheme_name: str) -> Optional[Dict[str, Any]]:
        """Look up one configured scheme.

        Args:
            scheme_name: Name of the security scheme.

        Returns:
            The stored configuration dict (the live object, not a copy), or
            None when the scheme has not been configured.
        """
        return self._auth_methods.get(scheme_name)

    def get_all_methods(self) -> Dict[str, Dict[str, Any]]:
        """Return every configured scheme.

        Returns:
            A shallow copy of the registry: the outer dict is new, the
            per-scheme dicts are the stored ones.
        """
        return self._auth_methods.copy()

    def get_headers(self, scheme_name: str) -> Dict[str, str]:
        """Build the HTTP headers that authenticate with a scheme.

        Args:
            scheme_name: Name of the security scheme.

        Returns:
            Mapping of header name to header value; empty when the stored
            type is none of API key, Bearer or Basic.

        Raises:
            AuthConfigError: If the scheme is not configured.
        """
        method = self.get_auth_method(scheme_name)
        if method is None:
            raise AuthConfigError(f"Authentication scheme '{scheme_name}' not configured")

        auth_type = method["type"]
        if auth_type == AuthType.API_KEY:
            return {method["header_name"]: method["api_key"]}
        if auth_type == AuthType.BEARER:
            return {"Authorization": f"{method['token_prefix']} {method['token']}"}
        if auth_type == AuthType.BASIC:
            import base64

            credentials = f"{method['username']}:{method['password']}"
            encoded = base64.b64encode(credentials.encode()).decode()
            return {"Authorization": f"Basic {encoded}"}
        return {}

    def build_from_spec(
        self,
        security_schemes: Dict[str, Any],
        security_requirements: List[Dict[str, Any]],
    ) -> "AuthConfig":
        """Register every scheme an endpoint requires, using the spec.

        Schemes that are already configured are left untouched, so values
        supplied through the ``add_*`` methods take precedence.

        Args:
            security_schemes: Security schemes from the OpenAPI spec.
            security_requirements: Security requirements of the endpoint.

        Returns:
            Self, for method chaining.

        Raises:
            MissingSecuritySchemeError: If a required scheme that is not
                already configured is absent from ``security_schemes``.
        """
        for requirement in security_requirements:
            for scheme_name in requirement:
                if scheme_name in self._auth_methods:
                    continue  # explicit configuration wins over the spec
                if scheme_name not in security_schemes:
                    raise MissingSecuritySchemeError(
                        f"Security scheme '{scheme_name}' not found in spec"
                    )
                self._add_scheme_from_spec(scheme_name, security_schemes[scheme_name])

        return self

    def _add_scheme_from_spec(self, scheme_name: str, scheme: Dict[str, Any]) -> None:
        """Register one scheme from its OpenAPI definition.

        ``apiKey`` becomes an API-key entry, ``http`` with scheme ``bearer``
        or ``basic`` becomes the matching entry, and ``openIdConnect`` and
        ``oauth2`` are both treated as Bearer tokens. Any other definition is
        ignored without error.

        Args:
            scheme_name: Name of the security scheme.
            scheme: The security scheme definition from the OpenAPI spec.
        """
        scheme_type = scheme.get("type", "")

        if scheme_type == "apiKey":
            # NOTE(review): the definition's "in" field is not consulted, so
            # the key is always treated as a header -- confirm that query and
            # cookie keys are intentionally unsupported.
            self.add_api_key(
                scheme_name,
                header_name=scheme.get("name", "X-API-Key"),
            )
        elif scheme_type == "http":
            http_scheme = scheme.get("scheme", "").lower()
            if http_scheme == "bearer":
                self.add_bearer(scheme_name)
            elif http_scheme == "basic":
                self.add_basic(scheme_name)
        elif scheme_type in ("openIdConnect", "oauth2"):
            self.add_bearer(scheme_name)

    def generate_auth_code(self, scheme_name: str, framework: str = "pytest") -> str:
        """Generate an authentication helper snippet for a test framework.

        Args:
            scheme_name: Name of the security scheme.
            framework: Target test framework (pytest, jest, go).

        Returns:
            The code snippet, or an empty string when the scheme is not
            configured or the framework is not one of the three supported.
        """
        method = self.get_auth_method(scheme_name)
        if method is None:
            return ""

        generators = {
            "pytest": self._generate_pytest_auth,
            "jest": self._generate_jest_auth,
            "go": self._generate_go_auth,
        }
        generate = generators.get(framework)
        return generate(method) if generate is not None else ""

    @staticmethod
    def _literal(value: Any) -> str:
        """Render a value as a double-quoted, escaped string literal.

        JSON string syntax is used because the escapes ``json.dumps`` emits
        are also valid in Python, JavaScript and Go string literals, so
        quotes and backslashes in a credential cannot break the generated
        source. Non-ASCII text is kept verbatim rather than as surrogate
        escape pairs.
        """
        import json

        return json.dumps(str(value), ensure_ascii=False)

    def _generate_pytest_auth(self, method: Dict[str, Any]) -> str:
        """Generate a pytest fixture that returns the auth headers.

        Args:
            method: Authentication method configuration.

        Returns:
            Fixture source code, or an empty string for an unhandled type.
        """
        literal = self._literal
        auth_type = method["type"]

        if auth_type == AuthType.API_KEY:
            header = literal(method["header_name"])
            value = literal(method["api_key"])
            return (
                "\n"
                "@pytest.fixture\n"
                "def api_key_headers():\n"
                f"    return {{{header}: {value}}}\n"
            )
        if auth_type == AuthType.BEARER:
            authorization = literal(f"{method['token_prefix']} {method['token']}")
            return (
                "\n"
                "@pytest.fixture\n"
                "def bearer_headers():\n"
                f'    return {{"Authorization": {authorization}}}\n'
            )
        if auth_type == AuthType.BASIC:
            # Emit one plain literal: a nested same-quote f-string in the
            # generated code only parses on Python 3.12 and later.
            credentials = literal(f"{method['username']}:{method['password']}")
            return (
                "\n"
                "import base64\n"
                "\n"
                "@pytest.fixture\n"
                "def basic_headers():\n"
                f"    credentials = {credentials}\n"
                "    encoded = base64.b64encode(credentials.encode()).decode()\n"
                '    return {"Authorization": f"Basic {encoded}"}\n'
            )
        return ""

    def _generate_jest_auth(self, method: Dict[str, Any]) -> str:
        """Generate a Jest helper that returns the auth headers.

        The generated code prefers environment variables and falls back to
        the configured values.

        Args:
            method: Authentication method configuration.

        Returns:
            Helper source code, or an empty string for an unhandled type.
        """
        literal = self._literal
        auth_type = method["type"]

        if auth_type == AuthType.API_KEY:
            header = literal(method["header_name"])
            value = literal(method["api_key"])
            return (
                "\n"
                "const getApiKeyHeaders = () => ({\n"
                f"  {header}: process.env.API_KEY || {value},\n"
                "});\n"
            )
        if auth_type == AuthType.BEARER:
            prefix = literal(method["token_prefix"])
            token = literal(method["token"])
            return (
                "\n"
                "const getBearerHeaders = () => ({\n"
                f"  Authorization: `${{process.env.TOKEN_PREFIX || {prefix}}} "
                f"${{process.env.TOKEN || {token}}}`,\n"
                "});\n"
            )
        if auth_type == AuthType.BASIC:
            username = literal(method["username"])
            password = literal(method["password"])
            return (
                "\n"
                "const getBasicHeaders = () => {\n"
                "  const credentials = Buffer.from("
                f"`${{process.env.USERNAME || {username}}}:"
                f"${{process.env.PASSWORD || {password}}}`"
                ").toString('base64');\n"
                "  return { Authorization: `Basic ${credentials}` };\n"
                "};\n"
            )
        return ""

    def _generate_go_auth(self, method: Dict[str, Any]) -> str:
        """Generate a Go helper that returns the auth headers.

        Only the API-key header name is taken from the configuration; the
        generated code reads secrets from the environment or from its own
        function arguments.

        Args:
            method: Authentication method configuration.

        Returns:
            Helper source code, or an empty string for an unhandled type.
        """
        auth_type = method["type"]

        if auth_type == AuthType.API_KEY:
            header = self._literal(method["header_name"])
            return (
                "\n"
                "func getAPIKeyHeaders() map[string]string {\n"
                "    return map[string]string{\n"
                f'        {header}: os.Getenv("API_KEY"),\n'
                "    }\n"
                "}\n"
            )
        if auth_type == AuthType.BEARER:
            return (
                "\n"
                "func getBearerHeaders() map[string]string {\n"
                "    return map[string]string{\n"
                '        "Authorization": fmt.Sprintf("%s %s", '
                'os.Getenv("TOKEN_PREFIX"), os.Getenv("TOKEN")),\n'
                "    }\n"
                "}\n"
            )
        if auth_type == AuthType.BASIC:
            return (
                "\n"
                "func getBasicHeaders(username, password string) map[string]string {\n"
                '    auth := username + ":" + password\n'
                "    encoded := base64.StdEncoding.EncodeToString([]byte(auth))\n"
                "    return map[string]string{\n"
                '        "Authorization": "Basic " + encoded,\n'
                "    }\n"
                "}\n"
            )
        return ""
|