fix: resolve CI linting and type errors

2026-02-04 12:49:09 +00:00
parent d8ecd258e9
commit 763828579b


@@ -0,0 +1,198 @@
import time
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Dict, List, Optional

from ..core.prompt import Prompt
from ..providers import ProviderBase, ProviderResponse


@dataclass
class ABTestConfig:
    """Configuration for A/B test."""

    iterations: int = 3
    provider: Optional[str] = None
    max_tokens: Optional[int] = None
    temperature: float = 0.7
    parallel: bool = False
    metadata: Dict[str, Any] = field(default_factory=dict)


@dataclass
class ABTestResult:
    """Result of a single test run."""

    prompt: Prompt
    response: ProviderResponse
    variables: Dict[str, Any]
    iteration: int
    passed_validation: bool = False
    validation_errors: List[str] = field(default_factory=list)
    latency_ms: float = 0.0
    timestamp: datetime = field(default_factory=datetime.utcnow)


@dataclass
class ABTestSummary:
    """Summary of A/B test results."""

    prompt_name: str
    config: ABTestConfig
    total_runs: int
    successful_runs: int
    failed_runs: int
    avg_latency_ms: float
    avg_tokens: float
    avg_cost: float
    results: List[ABTestResult]
    timestamp: datetime = field(default_factory=datetime.utcnow)


class ABTest:
    """A/B test runner for comparing prompt variations."""

    def __init__(
        self,
        provider: ProviderBase,
        config: Optional[ABTestConfig] = None,
    ):
        """Initialize A/B test runner.

        Args:
            provider: LLM provider to use.
            config: Test configuration.
        """
        self.provider = provider
        self.config = config or ABTestConfig()

    async def run(
        self,
        prompt: Prompt,
        variables: Dict[str, Any],
    ) -> ABTestSummary:
        """Run A/B test on a prompt.

        Args:
            prompt: Prompt to test.
            variables: Variables to substitute.

        Returns:
            ABTestSummary with all test results.
        """
        results: List[ABTestResult] = []
        latencies: List[float] = []
        total_tokens: List[int] = []

        for i in range(self.config.iterations):
            try:
                result = await self._run_single(prompt, variables, i + 1)
                results.append(result)
                latencies.append(result.latency_ms)
                total_tokens.append(result.response.usage.get("total_tokens", 0))
            except Exception as exc:
                # Record the failure as an empty result so one bad iteration
                # does not abort the remaining iterations.
                results.append(ABTestResult(
                    prompt=prompt,
                    response=ProviderResponse(
                        content="",
                        model=prompt.provider or self.provider.name,
                        provider=self.provider.name,
                    ),
                    variables=variables,
                    iteration=i + 1,
                    passed_validation=False,
                    validation_errors=[f"Test execution failed: {exc}"],
                ))

        successful = sum(1 for r in results if r.passed_validation or r.response.content)
        avg_latency = sum(latencies) / len(latencies) if latencies else 0.0
        avg_tokens = sum(total_tokens) / len(total_tokens) if total_tokens else 0.0

        return ABTestSummary(
            prompt_name=prompt.name,
            config=self.config,
            total_runs=self.config.iterations,
            successful_runs=successful,
            failed_runs=self.config.iterations - successful,
            avg_latency_ms=avg_latency,
            avg_tokens=avg_tokens,
            avg_cost=self._estimate_cost(avg_tokens),
            results=results,
        )

    async def run_comparison(
        self,
        prompts: List[Prompt],
        shared_variables: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, ABTestSummary]:
        """Run tests on multiple prompts for comparison.

        Args:
            prompts: List of prompts to compare.
            shared_variables: Variables shared across all prompts.

        Returns:
            Dictionary mapping prompt names to their summaries.
        """
        shared_variables = shared_variables or {}
        summaries = {}
        for prompt in prompts:
            variables = self._merge_variables(prompt, shared_variables)
            summary = await self.run(prompt, variables)
            summaries[prompt.name] = summary
        return summaries

    async def _run_single(
        self,
        prompt: Prompt,
        variables: Dict[str, Any],
        iteration: int,
    ) -> ABTestResult:
        """Run a single test iteration."""
        # Imported locally rather than at module level (e.g. to avoid a
        # circular import with the core package).
        from ..core.template import TemplateEngine

        template_engine = TemplateEngine()
        rendered = template_engine.render(prompt.content, variables, prompt.variables)

        start_time = time.time()
        response = await self.provider.complete(
            prompt=rendered,
            max_tokens=self.config.max_tokens,
        )
        latency_ms = (time.time() - start_time) * 1000

        return ABTestResult(
            prompt=prompt,
            response=response,
            variables=variables,
            iteration=iteration,
            latency_ms=latency_ms,
        )

    def _merge_variables(
        self,
        prompt: Prompt,
        shared: Dict[str, Any],
    ) -> Dict[str, Any]:
        """Merge shared variables with prompt-specific ones."""
        variables = shared.copy()
        for var in prompt.variables:
            if var.name not in variables and var.default is not None:
                variables[var.name] = var.default
        return variables

    def _estimate_cost(self, tokens: float) -> float:
        """Estimate cost based on token usage."""
        # Approximate USD-per-token rates; unknown models fall back to a generic rate.
        rates = {
            "gpt-4": 0.00003,
            "gpt-4-turbo": 0.00001,
            "gpt-3.5-turbo": 0.0000005,
            "claude-3-sonnet-20240229": 0.000003,
            "claude-3-opus-20240229": 0.000015,
        }
        rate = rates.get(self.provider.model, 0.000001)
        return tokens * rate
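
A minimal usage sketch of the runner added in this file. The top-level package path, the OpenAIProvider class, and the Prompt constructor arguments below are illustrative assumptions, not part of this diff; only ABTest, ABTestConfig, and the run() call mirror the code above.

import asyncio

# Hypothetical import paths; substitute the project's actual package name
# and a concrete ProviderBase implementation.
from mypackage.core.prompt import Prompt
from mypackage.providers import OpenAIProvider
from mypackage.ab_test import ABTest, ABTestConfig


async def main() -> None:
    # Assumed constructor signatures for the provider and prompt.
    provider = OpenAIProvider(model="gpt-4-turbo")
    prompt = Prompt(name="greeting", content="Say hello to {{ name }}.", variables=[])

    runner = ABTest(provider, ABTestConfig(iterations=5, max_tokens=128))
    summary = await runner.run(prompt, {"name": "Ada"})

    print(summary.successful_runs, summary.failed_runs)
    print(summary.avg_latency_ms, summary.avg_tokens, summary.avg_cost)


asyncio.run(main())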