fix: resolve CI linting and type errors
This commit is contained in:
@@ -1,79 +1,200 @@
|
|||||||
import asyncio
|
"""A/B testing framework for comparing prompt variations."""
|
||||||
import uuid
|
|
||||||
from dataclasses import dataclass, field
|
import time
|
||||||
from typing import AsyncIterator, Dict, List, Optional
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from .metrics import TestMetrics, MetricsCollector
|
|
||||||
from .results import TestResult, ComparisonResult
|
|
||||||
from ..core.prompt import Prompt
|
from ..core.prompt import Prompt
|
||||||
from ..providers.base import ProviderBase, ProviderResponse
|
from ..providers import ProviderBase, ProviderResponse
|
||||||
|
|
||||||
|
|
||||||
@dataclass
class ABTestConfig:
    """Configuration options for an A/B test run.

    All attributes have defaults, so ``ABTestConfig()`` is a valid
    out-of-the-box configuration.
    """

    # Number of times each prompt is executed.
    iterations: int = 3
    # Optional provider-name override; None means "use the runner's provider".
    provider: Optional[str] = None
    # Cap on completion tokens; None leaves the provider's default in place.
    max_tokens: Optional[int] = None
    # Sampling temperature for the provider.
    # NOTE(review): not visibly forwarded by the runner — confirm usage.
    temperature: float = 0.7
    # Whether iterations may run concurrently.
    # NOTE(review): not consulted by the visible runner code — confirm.
    parallel: bool = False
    # Free-form extra data attached to the test.
    metadata: Dict[str, Any] = field(default_factory=dict)
|
@dataclass
class ABTestResult:
    """Outcome of a single test iteration for one prompt."""

    # Prompt that was executed.
    prompt: Prompt
    # Raw provider response for this iteration.
    response: ProviderResponse
    # Variable values substituted into the prompt.
    variables: Dict[str, Any]
    # 1-based iteration number within the test.
    iteration: int
    # True when the response passed validation.
    passed_validation: bool = False
    # Messages describing validation failures, if any.
    validation_errors: List[str] = field(default_factory=list)
    # Wall-clock latency of the provider call, in milliseconds.
    latency_ms: float = 0.0
    # When the iteration completed (naive UTC).
    # NOTE(review): datetime.utcnow is deprecated since Python 3.12;
    # consider datetime.now(timezone.utc) — changes naive→aware, so verify
    # downstream consumers first.
    timestamp: datetime = field(default_factory=datetime.utcnow)
|
@dataclass
class ABTestSummary:
    """Aggregated statistics for a completed A/B test."""

    # Name of the prompt that was tested.
    prompt_name: str
    # Configuration the test ran with.
    config: ABTestConfig
    # Total iterations attempted.
    total_runs: int
    # Iterations that produced a usable response.
    successful_runs: int
    # Iterations that errored or produced no content.
    failed_runs: int
    # Mean provider latency across successful iterations, in milliseconds.
    avg_latency_ms: float
    # Mean total-token usage across successful iterations.
    avg_tokens: float
    # Estimated mean cost derived from avg_tokens.
    avg_cost: float
    # Per-iteration results, including failure placeholders.
    results: List[ABTestResult]
    # When the summary was produced (naive UTC).
    # NOTE(review): datetime.utcnow is deprecated since Python 3.12.
    timestamp: datetime = field(default_factory=datetime.utcnow)
||||||
class ABTest:
    """A/B test runner for comparing prompt variations.

    Executes a prompt (or several prompts) repeatedly against a single
    provider and aggregates latency, token usage, and estimated cost into
    an ``ABTestSummary``.
    """

    def __init__(
        self,
        provider: ProviderBase,
        config: Optional[ABTestConfig] = None,
    ):
        """Initialize A/B test runner.

        Args:
            provider: LLM provider to use.
            config: Test configuration; defaults to ``ABTestConfig()``.
        """
        self.provider = provider
        self.config = config or ABTestConfig()

    async def run(
        self,
        prompt: Prompt,
        variables: Dict[str, Any],
    ) -> ABTestSummary:
        """Run an A/B test on a prompt.

        Args:
            prompt: Prompt to test.
            variables: Variables to substitute into the prompt template.

        Returns:
            ABTestSummary with aggregated stats and all per-iteration results.
        """
        results: List[ABTestResult] = []
        latencies: List[float] = []
        token_counts: List[int] = []

        for i in range(self.config.iterations):
            try:
                result = await self._run_single(prompt, variables, i + 1)
            except Exception as exc:
                # Record the failure as a placeholder so len(results) always
                # matches config.iterations.  Include the actual error text
                # (the previous generic message discarded it).
                results.append(
                    ABTestResult(
                        prompt=prompt,
                        response=ProviderResponse(
                            content="",
                            model=prompt.provider or self.provider.name,
                            provider=self.provider.name,
                        ),
                        variables=variables,
                        iteration=i + 1,
                        passed_validation=False,
                        validation_errors=[f"Test execution failed: {exc}"],
                    )
                )
                continue

            # Bookkeeping happens OUTSIDE the try-block.  Previously a None
            # `usage` raised inside the try AFTER the result was appended,
            # so the handler appended a second entry for the same iteration.
            results.append(result)
            latencies.append(result.latency_ms)
            usage = result.response.usage or {}
            token_counts.append(usage.get("total_tokens", 0))

        # A run counts as successful if it validated or produced any content.
        successful = sum(
            1 for r in results if r.passed_validation or r.response.content
        )
        avg_latency = sum(latencies) / len(latencies) if latencies else 0
        avg_tokens = (
            sum(token_counts) / len(token_counts) if token_counts else 0
        )

        return ABTestSummary(
            prompt_name=prompt.name,
            config=self.config,
            total_runs=self.config.iterations,
            successful_runs=successful,
            failed_runs=self.config.iterations - successful,
            avg_latency_ms=avg_latency,
            avg_tokens=avg_tokens,
            avg_cost=self._estimate_cost(avg_tokens),
            results=results,
        )

    async def run_comparison(
        self,
        prompts: List[Prompt],
        shared_variables: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, ABTestSummary]:
        """Run tests on multiple prompts for comparison.

        Args:
            prompts: List of prompts to compare.
            shared_variables: Variables shared across all prompts.

        Returns:
            Dictionary mapping prompt names to their summaries.
        """
        shared_variables = shared_variables or {}
        summaries: Dict[str, ABTestSummary] = {}

        for prompt in prompts:
            variables = self._merge_variables(prompt, shared_variables)
            summaries[prompt.name] = await self.run(prompt, variables)

        return summaries

    async def _run_single(
        self,
        prompt: Prompt,
        variables: Dict[str, Any],
        iteration: int,
    ) -> ABTestResult:
        """Run a single test iteration and time the provider call."""
        # Local import — presumably to break a circular import at module
        # load time.  NOTE(review): confirm against the package layout.
        from ..core.template import TemplateEngine

        template_engine = TemplateEngine()
        rendered = template_engine.render(
            prompt.content, variables, prompt.variables
        )

        start_time = time.time()
        response = await self.provider.complete(
            prompt=rendered,
            max_tokens=self.config.max_tokens,
            # NOTE(review): config.temperature is not forwarded here —
            # confirm whether the provider should receive it.
        )
        latency_ms = (time.time() - start_time) * 1000

        return ABTestResult(
            prompt=prompt,
            response=response,
            variables=variables,
            iteration=iteration,
            latency_ms=latency_ms,
        )

    def _merge_variables(
        self,
        prompt: Prompt,
        shared: Dict[str, Any],
    ) -> Dict[str, Any]:
        """Merge shared variables with prompt-specific defaults.

        Shared values win; a prompt variable's default is used only when the
        name is absent from ``shared`` and the default is not None.
        """
        variables = shared.copy()
        for var in prompt.variables:
            if var.name not in variables and var.default is not None:
                variables[var.name] = var.default
        return variables

    def _estimate_cost(self, tokens: float) -> float:
        """Estimate cost in dollars from a token count.

        Uses a hard-coded per-token rate table keyed by model name; unknown
        models fall back to a conservative default rate.
        """
        rates = {
            "gpt-4": 0.00003,
            "gpt-4-turbo": 0.00001,
            "gpt-3.5-turbo": 0.0000005,
            "claude-3-sonnet-20240229": 0.000003,
            "claude-3-opus-20240229": 0.000015,
        }
        rate = rates.get(self.provider.model, 0.000001)
        return tokens * rate
|||||||
Reference in New Issue
Block a user