Add testing module (ab_test, metrics, validator)

New file: src/promptforge/testing/ab_test.py (79 lines)
import asyncio
|
||||||
|
import uuid
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import AsyncIterator, Dict, List, Optional
|
||||||
|
|
||||||
|
from .metrics import TestMetrics, MetricsCollector
|
||||||
|
from .results import TestResult, ComparisonResult
|
||||||
|
from ..core.prompt import Prompt
|
||||||
|
from ..providers.base import ProviderBase, ProviderResponse
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ABTestConfig:
    """Configuration for an A/B prompt test run.

    Raises:
        ValueError: if ``iterations`` is less than 1.
    """

    # How many times each prompt is executed when collecting metrics.
    iterations: int = 3
    # Presumably enables concurrent execution of iterations —
    # NOTE(review): confirm this flag is honored by the test runner.
    parallel: bool = False

    def __post_init__(self) -> None:
        # Guard early: a non-positive count would silently produce an
        # empty metrics list downstream instead of a clear error.
        if self.iterations < 1:
            raise ValueError(f"iterations must be >= 1, got {self.iterations}")
class ABTest:
    """Run repeated prompt executions against one provider and aggregate metrics.

    Each prompt is executed ``config.iterations`` times; per-run metrics are
    collected and reduced to a ``ComparisonResult`` by the metrics collector.
    When ``config.parallel`` is set, the iterations for a prompt run
    concurrently via ``asyncio.gather``.
    """

    def __init__(self, provider: ProviderBase, config: ABTestConfig):
        """Store the provider and config, and create a fresh metrics collector."""
        self.provider = provider
        self.config = config
        self.metrics_collector = MetricsCollector()

    async def run_single(self, prompt: Prompt, variables: Dict[str, str]) -> TestResult:
        """Execute *prompt* once and return a TestResult carrying its metrics.

        ``variables``, when non-empty, are substituted into ``prompt.content``
        with ``str.format``. Any exception — provider failure or a bad
        template — is captured as a failed TestResult rather than propagated,
        so batch runs are never aborted by a single failure.
        """
        test_id = str(uuid.uuid4())[:8]
        try:
            # Template substitution inside the try: a KeyError from a missing
            # variable is recorded as a failed run, same as a provider error.
            text = prompt.content.format(**variables) if variables else prompt.content
            response = await self.provider.complete(text)
            metrics = TestMetrics(
                test_id=test_id,
                prompt_name=prompt.name,
                provider=self.provider.name,
                model=self.provider.model,
                latency_ms=response.latency_ms,
                success=True,
                tokens_used=response.usage.get("total_tokens", 0) if response.usage else 0,
            )
            return TestResult(success=True, response=response.content, metrics=metrics)
        except Exception as e:
            # Failed runs still produce a metrics record (latency unknown -> 0)
            # so aggregation sees every attempt.
            metrics = TestMetrics(
                test_id=test_id,
                prompt_name=prompt.name,
                provider=self.provider.name,
                model=self.provider.model,
                latency_ms=0,
                success=False,
                error_message=str(e),
            )
            return TestResult(success=False, response="", metrics=metrics, error=str(e))

    async def _collect_metrics(self, prompt: Prompt, iterations: int) -> List[TestMetrics]:
        """Run *prompt* ``iterations`` times and return the per-run metrics.

        Honors ``config.parallel``: concurrent via ``asyncio.gather`` when set,
        otherwise strictly sequential. (Previously the flag was never read.)
        """
        if self.config.parallel:
            results = await asyncio.gather(
                *(self.run_single(prompt, {}) for _ in range(iterations))
            )
        else:
            results = [await self.run_single(prompt, {}) for _ in range(iterations)]
        return [result.metrics for result in results]

    async def run_comparison(self, prompts: List[Prompt]) -> Dict[str, ComparisonResult]:
        """Run every prompt ``config.iterations`` times.

        Returns a mapping of prompt name to its aggregated ComparisonResult.
        """
        results: Dict[str, ComparisonResult] = {}
        for prompt in prompts:
            all_metrics = await self._collect_metrics(prompt, self.config.iterations)
            results[prompt.name] = self.metrics_collector.compare(prompt.name, all_metrics)
        return results

    async def run_tests(self, prompt: Prompt, iterations: Optional[int] = None) -> ComparisonResult:
        """Run one prompt *iterations* times (default: ``config.iterations``)."""
        iterations = iterations or self.config.iterations
        all_metrics = await self._collect_metrics(prompt, iterations)
        return self.metrics_collector.compare(prompt.name, all_metrics)
Reference in New Issue
Block a user