fix: resolve CI linting and type errors
This commit is contained in:
@@ -1,86 +1,141 @@
|
||||
"""Metrics collection for A/B testing."""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Optional
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class TestMetrics:
|
||||
test_id: str
|
||||
prompt_name: str
|
||||
provider: str
|
||||
model: str
|
||||
latency_ms: float
|
||||
success: bool
|
||||
tokens_used: int = 0
|
||||
cost_estimate: float = 0.0
|
||||
error_message: Optional[str] = None
|
||||
class MetricsSample:
|
||||
"""Single metrics sample from a test run."""
|
||||
|
||||
timestamp: datetime = field(default_factory=datetime.utcnow)
|
||||
latency_ms: float = 0.0
|
||||
tokens_prompt: int = 0
|
||||
tokens_completion: int = 0
|
||||
tokens_total: int = 0
|
||||
cost: float = 0.0
|
||||
validation_passed: bool = False
|
||||
validation_errors: List[str] = field(default_factory=list)
|
||||
custom_metrics: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ComparisonResult:
|
||||
prompt_name: str
|
||||
total_runs: int
|
||||
successful_runs: int
|
||||
failed_runs: int
|
||||
avg_latency_ms: float
|
||||
min_latency_ms: float
|
||||
max_latency_ms: float
|
||||
avg_tokens: float
|
||||
avg_cost: float
|
||||
success_rate: float
|
||||
all_metrics: List[TestMetrics] = field(default_factory=list)
|
||||
class MetricsSummary:
|
||||
"""Summary statistics for collected metrics."""
|
||||
|
||||
name: str
|
||||
count: int = 0
|
||||
latency: Dict[str, float] = field(default_factory=dict)
|
||||
tokens: Dict[str, float] = field(default_factory=dict)
|
||||
cost: Dict[str, float] = field(default_factory=dict)
|
||||
validation_pass_rate: float = 0.0
|
||||
samples: List[MetricsSample] = field(default_factory=list)
|
||||
|
||||
@classmethod
|
||||
def from_samples(cls, name: str, samples: List[MetricsSample]) -> "MetricsSummary":
|
||||
"""Create summary from list of samples."""
|
||||
if not samples:
|
||||
return cls(name=name)
|
||||
|
||||
latencies = [s.latency_ms for s in samples]
|
||||
tokens = [s.tokens_total for s in samples]
|
||||
costs = [s.cost for s in samples]
|
||||
valid_count = sum(1 for s in samples if s.validation_passed)
|
||||
|
||||
return cls(
|
||||
name=name,
|
||||
count=len(samples),
|
||||
latency={
|
||||
"min": min(latencies),
|
||||
"max": max(latencies),
|
||||
"avg": sum(latencies) / len(latencies),
|
||||
},
|
||||
tokens={
|
||||
"min": min(tokens),
|
||||
"max": max(tokens),
|
||||
"avg": sum(tokens) / len(tokens),
|
||||
},
|
||||
cost={
|
||||
"min": min(costs),
|
||||
"max": max(costs),
|
||||
"avg": sum(costs) / len(costs),
|
||||
},
|
||||
validation_pass_rate=valid_count / len(samples),
|
||||
samples=samples,
|
||||
)
|
||||
|
||||
|
||||
class MetricsCollector:
|
||||
"""Collect and aggregate metrics from test runs."""
|
||||
|
||||
def __init__(self):
|
||||
self.metrics: List[TestMetrics] = []
|
||||
"""Initialize metrics collector."""
|
||||
self._samples: List[MetricsSample] = []
|
||||
|
||||
def add(self, metrics: TestMetrics) -> None:
|
||||
self.metrics.append(metrics)
|
||||
def record(self, sample: MetricsSample) -> None:
|
||||
"""Record a metrics sample."""
|
||||
self._samples.append(sample)
|
||||
|
||||
def compare(self, prompt_name: str, metrics_list: List[TestMetrics]) -> ComparisonResult:
|
||||
if not metrics_list:
|
||||
return ComparisonResult(
|
||||
prompt_name=prompt_name,
|
||||
total_runs=0,
|
||||
successful_runs=0,
|
||||
failed_runs=0,
|
||||
avg_latency_ms=0,
|
||||
min_latency_ms=0,
|
||||
max_latency_ms=0,
|
||||
avg_tokens=0,
|
||||
avg_cost=0,
|
||||
success_rate=0,
|
||||
)
|
||||
|
||||
successful = [m for m in metrics_list if m.success]
|
||||
failed = [m for m in metrics_list if not m.success]
|
||||
|
||||
latencies = [m.latency_ms for m in successful]
|
||||
tokens = [m.tokens_used for m in successful]
|
||||
costs = [m.cost_estimate for m in successful]
|
||||
|
||||
return ComparisonResult(
|
||||
prompt_name=prompt_name,
|
||||
total_runs=len(metrics_list),
|
||||
successful_runs=len(successful),
|
||||
failed_runs=len(failed),
|
||||
avg_latency_ms=sum(latencies) / len(latencies) if latencies else 0,
|
||||
min_latency_ms=min(latencies) if latencies else 0,
|
||||
max_latency_ms=max(latencies) if latencies else 0,
|
||||
avg_tokens=sum(tokens) / len(tokens) if tokens else 0,
|
||||
avg_cost=sum(costs) / len(costs) if costs else 0,
|
||||
success_rate=len(successful) / len(metrics_list) if metrics_list else 0,
|
||||
all_metrics=metrics_list,
|
||||
def record_from_response(
|
||||
self,
|
||||
latency_ms: float,
|
||||
usage: Dict[str, int],
|
||||
validation_passed: bool = False,
|
||||
validation_errors: Optional[List[str]] = None,
|
||||
cost: float = 0.0,
|
||||
custom_metrics: Optional[Dict[str, Any]] = None,
|
||||
) -> MetricsSample:
|
||||
"""Record metrics from a provider response."""
|
||||
sample = MetricsSample(
|
||||
latency_ms=latency_ms,
|
||||
tokens_prompt=usage.get("prompt_tokens", 0),
|
||||
tokens_completion=usage.get("completion_tokens", 0),
|
||||
tokens_total=usage.get("total_tokens", usage.get("prompt_tokens", 0) + usage.get("completion_tokens", 0)),
|
||||
cost=cost,
|
||||
validation_passed=validation_passed,
|
||||
validation_errors=validation_errors or [],
|
||||
custom_metrics=custom_metrics or {},
|
||||
)
|
||||
self.record(sample)
|
||||
return sample
|
||||
|
||||
def get_summary(self) -> Dict[str, ComparisonResult]:
|
||||
by_prompt: Dict[str, List[TestMetrics]] = {}
|
||||
for m in self.metrics:
|
||||
if m.prompt_name not in by_prompt:
|
||||
by_prompt[m.prompt_name] = []
|
||||
by_prompt[m.prompt_name].append(m)
|
||||
def get_summary(self, name: str = "test") -> MetricsSummary:
|
||||
"""Get summary of all collected metrics."""
|
||||
return MetricsSummary.from_samples(name, self._samples)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear all collected samples."""
|
||||
self._samples.clear()
|
||||
|
||||
def get_samples(self) -> List[MetricsSample]:
|
||||
"""Get all collected samples."""
|
||||
return list(self._samples)
|
||||
|
||||
def compare(
|
||||
self,
|
||||
other: "MetricsCollector",
|
||||
) -> Dict[str, Any]:
|
||||
"""Compare metrics between two collectors.
|
||||
|
||||
Args:
|
||||
other: Another metrics collector to compare against.
|
||||
|
||||
Returns:
|
||||
Dictionary with comparison statistics.
|
||||
"""
|
||||
summary1 = self.get_summary("a")
|
||||
summary2 = other.get_summary("b")
|
||||
|
||||
return {
|
||||
name: self.compare(name, metrics)
|
||||
for name, metrics in by_prompt.items()
|
||||
"latency_delta_ms": summary2.latency.get("avg", 0) - summary1.latency.get("avg", 0),
|
||||
"tokens_delta": summary2.tokens.get("avg", 0) - summary1.tokens.get("avg", 0),
|
||||
"cost_delta": summary2.cost.get("avg", 0) - summary1.cost.get("avg", 0),
|
||||
"validation_pass_rate_delta": (
|
||||
summary2.validation_pass_rate - summary1.validation_pass_rate
|
||||
),
|
||||
"sample_count": {
|
||||
"a": summary1.count,
|
||||
"b": summary2.count,
|
||||
},
|
||||
}
|
||||
Reference in New Issue
Block a user