diff --git a/app/src/promptforge/testing/ab_test.py b/app/src/promptforge/testing/ab_test.py
new file mode 100644
index 0000000..fac96c0
--- /dev/null
+++ b/app/src/promptforge/testing/ab_test.py
@@ -0,0 +1,198 @@
+import time
+from dataclasses import dataclass, field
+from datetime import datetime, timezone
+from typing import Any, Dict, List, Optional
+
+from ..core.prompt import Prompt
+from ..providers import ProviderBase, ProviderResponse
+
+
+@dataclass
+class ABTestConfig:
+    """Configuration for an A/B test."""
+
+    iterations: int = 3
+    provider: Optional[str] = None
+    max_tokens: Optional[int] = None
+    temperature: float = 0.7
+    parallel: bool = False
+    metadata: Dict[str, Any] = field(default_factory=dict)
+
+
+@dataclass
+class ABTestResult:
+    """Result of a single test run."""
+
+    prompt: Prompt
+    response: ProviderResponse
+    variables: Dict[str, Any]
+    iteration: int
+    passed_validation: bool = False
+    validation_errors: List[str] = field(default_factory=list)
+    latency_ms: float = 0.0
+    timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
+
+
+@dataclass
+class ABTestSummary:
+    """Summary of A/B test results."""
+
+    prompt_name: str
+    config: ABTestConfig
+    total_runs: int
+    successful_runs: int
+    failed_runs: int
+    avg_latency_ms: float
+    avg_tokens: float
+    avg_cost: float
+    results: List[ABTestResult]
+    timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
+
+
+class ABTest:
+    """A/B test runner for comparing prompt variations."""
+
+    def __init__(
+        self,
+        provider: ProviderBase,
+        config: Optional[ABTestConfig] = None,
+    ):
+        """Initialize the A/B test runner.
+
+        Args:
+            provider: LLM provider to use.
+            config: Test configuration.
+        """
+        self.provider = provider
+        self.config = config or ABTestConfig()
+
+    async def run(
+        self,
+        prompt: Prompt,
+        variables: Dict[str, Any],
+    ) -> ABTestSummary:
+        """Run an A/B test on a prompt.
+
+        Args:
+            prompt: Prompt to test.
+            variables: Variables to substitute.
+
+        Returns:
+            ABTestSummary with all test results.
+        """
+        results: List[ABTestResult] = []
+        latencies = []
+        total_tokens = []
+
+        for i in range(self.config.iterations):
+            try:
+                result = await self._run_single(prompt, variables, i + 1)
+                results.append(result)
+                latencies.append(result.latency_ms)
+                total_tokens.append(result.response.usage.get("total_tokens", 0))
+            except Exception as exc:
+                results.append(ABTestResult(
+                    prompt=prompt,
+                    response=ProviderResponse(
+                        content="",
+                        model=prompt.provider or self.provider.name,
+                        provider=self.provider.name,
+                    ),
+                    variables=variables,
+                    iteration=i + 1,
+                    passed_validation=False,
+                    validation_errors=[f"Test execution failed: {exc}"],
+                ))
+
+        successful = sum(1 for r in results if r.passed_validation or r.response.content)
+
+        avg_latency = sum(latencies) / len(latencies) if latencies else 0.0
+        avg_tokens = sum(total_tokens) / len(total_tokens) if total_tokens else 0.0
+
+        return ABTestSummary(
+            prompt_name=prompt.name,
+            config=self.config,
+            total_runs=self.config.iterations,
+            successful_runs=successful,
+            failed_runs=self.config.iterations - successful,
+            avg_latency_ms=avg_latency,
+            avg_tokens=avg_tokens,
+            avg_cost=self._estimate_cost(avg_tokens),
+            results=results,
+        )
+
+    async def run_comparison(
+        self,
+        prompts: List[Prompt],
+        shared_variables: Optional[Dict[str, Any]] = None,
+    ) -> Dict[str, ABTestSummary]:
+        """Run tests on multiple prompts for comparison.
+
+        Args:
+            prompts: List of prompts to compare.
+            shared_variables: Variables shared across all prompts.
+
+        Returns:
+            Dictionary mapping prompt names to their summaries.
+        """
+        shared_variables = shared_variables or {}
+        summaries = {}
+
+        for prompt in prompts:
+            variables = self._merge_variables(prompt, shared_variables)
+            summary = await self.run(prompt, variables)
+            summaries[prompt.name] = summary
+
+        return summaries
+
+    async def _run_single(
+        self,
+        prompt: Prompt,
+        variables: Dict[str, Any],
+        iteration: int,
+    ) -> ABTestResult:
+        """Run a single test iteration."""
+        from ..core.template import TemplateEngine
+        template_engine = TemplateEngine()
+
+        rendered = template_engine.render(prompt.content, variables, prompt.variables)
+        start_time = time.time()
+
+        response = await self.provider.complete(
+            prompt=rendered,
+            max_tokens=self.config.max_tokens,
+        )
+
+        latency_ms = (time.time() - start_time) * 1000
+
+        return ABTestResult(
+            prompt=prompt,
+            response=response,
+            variables=variables,
+            iteration=iteration,
+            latency_ms=latency_ms,
+        )
+
+    def _merge_variables(
+        self,
+        prompt: Prompt,
+        shared: Dict[str, Any],
+    ) -> Dict[str, Any]:
+        """Merge shared variables with prompt-specific defaults."""
+        variables = shared.copy()
+        for var in prompt.variables:
+            if var.name not in variables and var.default is not None:
+                variables[var.name] = var.default
+        return variables
+
+    def _estimate_cost(self, tokens: float) -> float:
+        """Estimate cost in USD based on average token usage."""
+        rates = {
+            "gpt-4": 0.00003,
+            "gpt-4-turbo": 0.00001,
+            "gpt-3.5-turbo": 0.0000005,
+            "claude-3-sonnet-20240229": 0.000003,
+            "claude-3-opus-20240229": 0.000015,
+        }
+        rate = rates.get(self.provider.model, 0.000001)
+        return tokens * rate
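
A minimal usage sketch for reviewers, showing the runner's flow end to end. It assumes `ProviderBase` can be subclassed with just an async `complete` method plus `name`/`model` attributes, that `ProviderResponse` defaults `usage` to an empty mapping, and that `Prompt` accepts `name`, `content`, and `variables` keywords; the `EchoProvider` stub and the `{{text}}` template syntax are hypothetical illustrations, not APIs confirmed by this diff:

```python
import asyncio

from promptforge.core.prompt import Prompt
from promptforge.providers import ProviderBase, ProviderResponse
from promptforge.testing.ab_test import ABTest, ABTestConfig


class EchoProvider(ProviderBase):
    """Hypothetical stub; the real ProviderBase ABC may require more."""

    name = "echo"
    model = "echo-1"

    async def complete(self, prompt: str, max_tokens=None) -> ProviderResponse:
        # Echo a truncated prompt back so every run produces non-empty
        # content and therefore counts as successful in the summary.
        return ProviderResponse(content=prompt[:80], model=self.model, provider=self.name)


async def main() -> None:
    runner = ABTest(EchoProvider(), config=ABTestConfig(iterations=5))
    # Prompt construction and the {{...}} placeholder syntax are assumed;
    # adjust to the real core.prompt / core.template APIs.
    prompts = [
        Prompt(name="terse", content="Summarize in one line: {{text}}", variables=[]),
        Prompt(name="detailed", content="Summarize as bullet points: {{text}}", variables=[]),
    ]
    summaries = await runner.run_comparison(
        prompts,
        shared_variables={"text": "PromptForge adds an A/B test runner."},
    )
    for name, summary in summaries.items():
        print(f"{name}: {summary.successful_runs}/{summary.total_runs} ok, "
              f"avg {summary.avg_latency_ms:.1f} ms, ~${summary.avg_cost:.6f}/run")


asyncio.run(main())
```

Note that `_merge_variables` fills in per-prompt variable defaults before each run, so prompts in a comparison only need `shared_variables` for the placeholders they actually have in common.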