From 7125b6933d6cd4739f404ed8e04595cba95b1866 Mon Sep 17 00:00:00 2001
From: 7000pctAUTO
Date: Wed, 4 Feb 2026 12:58:30 +0000
Subject: [PATCH] fix: resolve CI linting and type errors

---
 src/promptforge/testing/metrics.py | 193 ++++++++++++++++++-----------
 1 file changed, 124 insertions(+), 69 deletions(-)

diff --git a/src/promptforge/testing/metrics.py b/src/promptforge/testing/metrics.py
index 8076788..7a14e69 100644
--- a/src/promptforge/testing/metrics.py
+++ b/src/promptforge/testing/metrics.py
@@ -1,86 +1,141 @@
+"""Metrics collection for A/B testing."""
+
 from dataclasses import dataclass, field
-from typing import Dict, List, Optional
+from datetime import datetime
+from typing import Any, Dict, List, Optional
 
 
 @dataclass
-class TestMetrics:
-    test_id: str
-    prompt_name: str
-    provider: str
-    model: str
-    latency_ms: float
-    success: bool
-    tokens_used: int = 0
-    cost_estimate: float = 0.0
-    error_message: Optional[str] = None
+class MetricsSample:
+    """Single metrics sample from a test run."""
+
+    timestamp: datetime = field(default_factory=datetime.utcnow)
+    latency_ms: float = 0.0
+    tokens_prompt: int = 0
+    tokens_completion: int = 0
+    tokens_total: int = 0
+    cost: float = 0.0
+    validation_passed: bool = False
+    validation_errors: List[str] = field(default_factory=list)
+    custom_metrics: Dict[str, Any] = field(default_factory=dict)
 
 
 @dataclass
-class ComparisonResult:
-    prompt_name: str
-    total_runs: int
-    successful_runs: int
-    failed_runs: int
-    avg_latency_ms: float
-    min_latency_ms: float
-    max_latency_ms: float
-    avg_tokens: float
-    avg_cost: float
-    success_rate: float
-    all_metrics: List[TestMetrics] = field(default_factory=list)
+class MetricsSummary:
+    """Summary statistics for collected metrics."""
+
+    name: str
+    count: int = 0
+    latency: Dict[str, float] = field(default_factory=dict)
+    tokens: Dict[str, float] = field(default_factory=dict)
+    cost: Dict[str, float] = field(default_factory=dict)
+    validation_pass_rate: float = 0.0
+    samples: List[MetricsSample] = field(default_factory=list)
+
+    @classmethod
+    def from_samples(cls, name: str, samples: List[MetricsSample]) -> "MetricsSummary":
+        """Create summary from list of samples."""
+        if not samples:
+            return cls(name=name)
+
+        latencies = [s.latency_ms for s in samples]
+        tokens = [s.tokens_total for s in samples]
+        costs = [s.cost for s in samples]
+        valid_count = sum(1 for s in samples if s.validation_passed)
+
+        return cls(
+            name=name,
+            count=len(samples),
+            latency={
+                "min": min(latencies),
+                "max": max(latencies),
+                "avg": sum(latencies) / len(latencies),
+            },
+            tokens={
+                "min": min(tokens),
+                "max": max(tokens),
+                "avg": sum(tokens) / len(tokens),
+            },
+            cost={
+                "min": min(costs),
+                "max": max(costs),
+                "avg": sum(costs) / len(costs),
+            },
+            validation_pass_rate=valid_count / len(samples),
+            samples=samples,
+        )
 
 
 class MetricsCollector:
+    """Collect and aggregate metrics from test runs."""
+
     def __init__(self):
-        self.metrics: List[TestMetrics] = []
+        """Initialize metrics collector."""
+        self._samples: List[MetricsSample] = []
 
-    def add(self, metrics: TestMetrics) -> None:
-        self.metrics.append(metrics)
+    def record(self, sample: MetricsSample) -> None:
+        """Record a metrics sample."""
+        self._samples.append(sample)
 
-    def compare(self, prompt_name: str, metrics_list: List[TestMetrics]) -> ComparisonResult:
-        if not metrics_list:
-            return ComparisonResult(
-                prompt_name=prompt_name,
-                total_runs=0,
-                successful_runs=0,
-                failed_runs=0,
-                avg_latency_ms=0,
-                min_latency_ms=0,
-                max_latency_ms=0,
-                avg_tokens=0,
-                avg_cost=0,
-                success_rate=0,
-            )
-
-        successful = [m for m in metrics_list if m.success]
-        failed = [m for m in metrics_list if not m.success]
-
-        latencies = [m.latency_ms for m in successful]
-        tokens = [m.tokens_used for m in successful]
-        costs = [m.cost_estimate for m in successful]
-
-        return ComparisonResult(
-            prompt_name=prompt_name,
-            total_runs=len(metrics_list),
-            successful_runs=len(successful),
-            failed_runs=len(failed),
-            avg_latency_ms=sum(latencies) / len(latencies) if latencies else 0,
-            min_latency_ms=min(latencies) if latencies else 0,
-            max_latency_ms=max(latencies) if latencies else 0,
-            avg_tokens=sum(tokens) / len(tokens) if tokens else 0,
-            avg_cost=sum(costs) / len(costs) if costs else 0,
-            success_rate=len(successful) / len(metrics_list) if metrics_list else 0,
-            all_metrics=metrics_list,
+    def record_from_response(
+        self,
+        latency_ms: float,
+        usage: Dict[str, int],
+        validation_passed: bool = False,
+        validation_errors: Optional[List[str]] = None,
+        cost: float = 0.0,
+        custom_metrics: Optional[Dict[str, Any]] = None,
+    ) -> MetricsSample:
+        """Record metrics from a provider response."""
+        sample = MetricsSample(
+            latency_ms=latency_ms,
+            tokens_prompt=usage.get("prompt_tokens", 0),
+            tokens_completion=usage.get("completion_tokens", 0),
+            tokens_total=usage.get("total_tokens", usage.get("prompt_tokens", 0) + usage.get("completion_tokens", 0)),
+            cost=cost,
+            validation_passed=validation_passed,
+            validation_errors=validation_errors or [],
+            custom_metrics=custom_metrics or {},
        )
+        self.record(sample)
+        return sample
 
-    def get_summary(self) -> Dict[str, ComparisonResult]:
-        by_prompt: Dict[str, List[TestMetrics]] = {}
-        for m in self.metrics:
-            if m.prompt_name not in by_prompt:
-                by_prompt[m.prompt_name] = []
-            by_prompt[m.prompt_name].append(m)
+    def get_summary(self, name: str = "test") -> MetricsSummary:
+        """Get summary of all collected metrics."""
+        return MetricsSummary.from_samples(name, self._samples)
+
+    def clear(self) -> None:
+        """Clear all collected samples."""
+        self._samples.clear()
+
+    def get_samples(self) -> List[MetricsSample]:
+        """Get all collected samples."""
+        return list(self._samples)
+
+    def compare(
+        self,
+        other: "MetricsCollector",
+    ) -> Dict[str, Any]:
+        """Compare metrics between two collectors.
+
+        Args:
+            other: Another metrics collector to compare against.
+
+        Returns:
+            Dictionary with comparison statistics.
+        """
+        summary1 = self.get_summary("a")
+        summary2 = other.get_summary("b")
 
         return {
-            name: self.compare(name, metrics)
-            for name, metrics in by_prompt.items()
-        }
\ No newline at end of file
+            "latency_delta_ms": summary2.latency.get("avg", 0) - summary1.latency.get("avg", 0),
+            "tokens_delta": summary2.tokens.get("avg", 0) - summary1.tokens.get("avg", 0),
+            "cost_delta": summary2.cost.get("avg", 0) - summary1.cost.get("avg", 0),
+            "validation_pass_rate_delta": (
+                summary2.validation_pass_rate - summary1.validation_pass_rate
+            ),
+            "sample_count": {
+                "a": summary1.count,
+                "b": summary2.count,
+            },
+        }
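
A minimal usage sketch of the reworked collector API follows (not part of the patch; the numbers are illustrative). It assumes the package is importable as promptforge.testing.metrics per the src/ layout above, and that provider usage dicts carry OpenAI-style prompt_tokens / completion_tokens / total_tokens keys, which is what record_from_response reads.

    from promptforge.testing.metrics import MetricsCollector

    # One collector per prompt variant under test.
    variant_a = MetricsCollector()
    variant_b = MetricsCollector()

    # Hypothetical response from variant A; total_tokens supplied by the provider.
    variant_a.record_from_response(
        latency_ms=812.4,
        usage={"prompt_tokens": 120, "completion_tokens": 85, "total_tokens": 205},
        validation_passed=True,
        cost=0.0031,
    )

    # Hypothetical response from variant B; the total is derived from the two parts.
    variant_b.record_from_response(
        latency_ms=640.0,
        usage={"prompt_tokens": 118, "completion_tokens": 60},
        validation_passed=False,
        validation_errors=["missing 'summary' field"],
        cost=0.0024,
    )

    # Per-variant summary: min/avg/max latency, tokens, cost, validation pass rate.
    print(variant_a.get_summary("variant-a").latency)

    # Deltas are reported as other minus self ("b" minus "a").
    print(variant_a.compare(variant_b)["latency_delta_ms"])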