Add testing module (ab_test, metrics, validator)

2026-02-04 12:32:35 +00:00
parent 820bd9801c
commit ad4c01dac7


@@ -0,0 +1,79 @@
import asyncio
import uuid
from dataclasses import dataclass, field
from typing import AsyncIterator, Dict, List, Optional

from .metrics import TestMetrics, MetricsCollector
from .results import TestResult, ComparisonResult
from ..core.prompt import Prompt
from ..providers.base import ProviderBase, ProviderResponse


@dataclass
class ABTestConfig:
    iterations: int = 3
    parallel: bool = False


class ABTest:
    def __init__(self, provider: ProviderBase, config: ABTestConfig):
        self.provider = provider
        self.config = config
        self.metrics_collector = MetricsCollector()

    async def run_single(self, prompt: Prompt, variables: Dict[str, str]) -> TestResult:
        # Short, human-readable identifier for this individual run.
        test_id = str(uuid.uuid4())[:8]
        try:
            # Render the prompt with the supplied variables (if any) before sending it to the provider.
            response = await self.provider.complete(
                prompt.content.format(**variables) if variables else prompt.content
            )
            metrics = TestMetrics(
                test_id=test_id,
                prompt_name=prompt.name,
                provider=self.provider.name,
                model=self.provider.model,
                latency_ms=response.latency_ms,
                success=True,
                tokens_used=response.usage.get("total_tokens", 0) if response.usage else 0,
            )
            return TestResult(success=True, response=response.content, metrics=metrics)
        except Exception as e:
            # Record the failed run with its error message so it still appears in the metrics.
            metrics = TestMetrics(
                test_id=test_id,
                prompt_name=prompt.name,
                provider=self.provider.name,
                model=self.provider.model,
                latency_ms=0,
                success=False,
                error_message=str(e),
            )
            return TestResult(success=False, response="", metrics=metrics, error=str(e))

    async def run_comparison(self, prompts: List[Prompt]) -> Dict[str, ComparisonResult]:
        # Run each prompt for the configured number of iterations and aggregate per prompt.
        results = {}
        for prompt in prompts:
            all_metrics: List[TestMetrics] = []
            for _ in range(self.config.iterations):
                result = await self.run_single(prompt, {})
                all_metrics.append(result.metrics)
            comparison = self.metrics_collector.compare(prompt.name, all_metrics)
            results[prompt.name] = comparison
        return results

    async def run_tests(self, prompt: Prompt, iterations: Optional[int] = None) -> ComparisonResult:
        # Run a single prompt repeatedly and return its aggregated metrics.
        iterations = iterations or self.config.iterations
        all_metrics: List[TestMetrics] = []
        for _ in range(iterations):
            result = await self.run_single(prompt, {})
            all_metrics.append(result.metrics)
        return self.metrics_collector.compare(prompt.name, all_metrics)
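
A minimal usage sketch of the new module, assuming the package root is importable as prompt_lab and that an OpenAIProvider implementation of ProviderBase plus a Prompt(name=..., content=...) constructor exist; none of those names appear in this diff:

import asyncio

# Hypothetical import paths: the diff only shows relative imports, so the
# package root ("prompt_lab") and the concrete provider are assumptions.
from prompt_lab.core.prompt import Prompt
from prompt_lab.providers.openai import OpenAIProvider
from prompt_lab.testing.ab_test import ABTest, ABTestConfig


async def main() -> None:
    provider = OpenAIProvider(model="gpt-4o-mini")  # assumed constructor
    test = ABTest(provider, ABTestConfig(iterations=5))

    # run_comparison() sends each prompt as-is (it passes empty variables),
    # so these prompts contain no format placeholders.
    prompts = [
        Prompt(name="terse", content="Explain recursion in one sentence."),
        Prompt(name="detailed", content="Explain recursion with a short example."),
    ]

    comparison = await test.run_comparison(prompts)
    for name, result in comparison.items():
        print(name, result)


asyncio.run(main())

run_tests() covers the single-prompt case; run_comparison() loops it over several prompts and keys the aggregated ComparisonResult by prompt name.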