fix: resolve CI linting and type errors
Some checks failed
CI / lint (push) Has been cancelled
CI / type-check (push) Has been cancelled
CI / test (push) Has been cancelled

2026-02-04 12:58:30 +00:00
parent 64cef11c7c
commit 7125b6933d


@@ -1,86 +1,141 @@
"""Metrics collection for A/B testing."""
from dataclasses import dataclass, field
from typing import Dict, List, Optional
from datetime import datetime
from typing import Any, Dict, List, Optional


 @dataclass
-class TestMetrics:
-    test_id: str
-    prompt_name: str
-    provider: str
-    model: str
-    latency_ms: float
-    success: bool
-    tokens_used: int = 0
-    cost_estimate: float = 0.0
-    error_message: Optional[str] = None
+class MetricsSample:
+    """Single metrics sample from a test run."""
+    timestamp: datetime = field(default_factory=datetime.utcnow)
+    latency_ms: float = 0.0
+    tokens_prompt: int = 0
+    tokens_completion: int = 0
+    tokens_total: int = 0
+    cost: float = 0.0
+    validation_passed: bool = False
+    validation_errors: List[str] = field(default_factory=list)
+    custom_metrics: Dict[str, Any] = field(default_factory=dict)


 @dataclass
-class ComparisonResult:
-    prompt_name: str
-    total_runs: int
-    successful_runs: int
-    failed_runs: int
-    avg_latency_ms: float
-    min_latency_ms: float
-    max_latency_ms: float
-    avg_tokens: float
-    avg_cost: float
-    success_rate: float
-    all_metrics: List[TestMetrics] = field(default_factory=list)
+class MetricsSummary:
+    """Summary statistics for collected metrics."""
+    name: str
+    count: int = 0
+    latency: Dict[str, float] = field(default_factory=dict)
+    tokens: Dict[str, float] = field(default_factory=dict)
+    cost: Dict[str, float] = field(default_factory=dict)
+    validation_pass_rate: float = 0.0
+    samples: List[MetricsSample] = field(default_factory=list)
+
+    @classmethod
+    def from_samples(cls, name: str, samples: List[MetricsSample]) -> "MetricsSummary":
+        """Create summary from list of samples."""
+        if not samples:
+            return cls(name=name)
+        latencies = [s.latency_ms for s in samples]
+        tokens = [s.tokens_total for s in samples]
+        costs = [s.cost for s in samples]
+        valid_count = sum(1 for s in samples if s.validation_passed)
+        return cls(
+            name=name,
+            count=len(samples),
+            latency={
+                "min": min(latencies),
+                "max": max(latencies),
+                "avg": sum(latencies) / len(latencies),
+            },
+            tokens={
+                "min": min(tokens),
+                "max": max(tokens),
+                "avg": sum(tokens) / len(tokens),
+            },
+            cost={
+                "min": min(costs),
+                "max": max(costs),
+                "avg": sum(costs) / len(costs),
+            },
+            validation_pass_rate=valid_count / len(samples),
+            samples=samples,
+        )


 class MetricsCollector:
+    """Collect and aggregate metrics from test runs."""
+
     def __init__(self):
-        self.metrics: List[TestMetrics] = []
+        """Initialize metrics collector."""
+        self._samples: List[MetricsSample] = []

-    def add(self, metrics: TestMetrics) -> None:
-        self.metrics.append(metrics)
+    def record(self, sample: MetricsSample) -> None:
+        """Record a metrics sample."""
+        self._samples.append(sample)

-    def compare(self, prompt_name: str, metrics_list: List[TestMetrics]) -> ComparisonResult:
-        if not metrics_list:
-            return ComparisonResult(
-                prompt_name=prompt_name,
-                total_runs=0,
-                successful_runs=0,
-                failed_runs=0,
-                avg_latency_ms=0,
-                min_latency_ms=0,
-                max_latency_ms=0,
-                avg_tokens=0,
-                avg_cost=0,
-                success_rate=0,
-            )
-        successful = [m for m in metrics_list if m.success]
-        failed = [m for m in metrics_list if not m.success]
-        latencies = [m.latency_ms for m in successful]
-        tokens = [m.tokens_used for m in successful]
-        costs = [m.cost_estimate for m in successful]
-        return ComparisonResult(
-            prompt_name=prompt_name,
-            total_runs=len(metrics_list),
-            successful_runs=len(successful),
-            failed_runs=len(failed),
-            avg_latency_ms=sum(latencies) / len(latencies) if latencies else 0,
-            min_latency_ms=min(latencies) if latencies else 0,
-            max_latency_ms=max(latencies) if latencies else 0,
-            avg_tokens=sum(tokens) / len(tokens) if tokens else 0,
-            avg_cost=sum(costs) / len(costs) if costs else 0,
-            success_rate=len(successful) / len(metrics_list) if metrics_list else 0,
-            all_metrics=metrics_list,
+    def record_from_response(
+        self,
+        latency_ms: float,
+        usage: Dict[str, int],
+        validation_passed: bool = False,
+        validation_errors: Optional[List[str]] = None,
+        cost: float = 0.0,
+        custom_metrics: Optional[Dict[str, Any]] = None,
+    ) -> MetricsSample:
+        """Record metrics from a provider response."""
+        sample = MetricsSample(
+            latency_ms=latency_ms,
+            tokens_prompt=usage.get("prompt_tokens", 0),
+            tokens_completion=usage.get("completion_tokens", 0),
+            tokens_total=usage.get("total_tokens", usage.get("prompt_tokens", 0) + usage.get("completion_tokens", 0)),
+            cost=cost,
+            validation_passed=validation_passed,
+            validation_errors=validation_errors or [],
+            custom_metrics=custom_metrics or {},
         )
+        self.record(sample)
+        return sample

-    def get_summary(self) -> Dict[str, ComparisonResult]:
-        by_prompt: Dict[str, List[TestMetrics]] = {}
-        for m in self.metrics:
-            if m.prompt_name not in by_prompt:
-                by_prompt[m.prompt_name] = []
-            by_prompt[m.prompt_name].append(m)
+    def get_summary(self, name: str = "test") -> MetricsSummary:
+        """Get summary of all collected metrics."""
+        return MetricsSummary.from_samples(name, self._samples)
+
+    def clear(self) -> None:
+        """Clear all collected samples."""
+        self._samples.clear()
+
+    def get_samples(self) -> List[MetricsSample]:
+        """Get all collected samples."""
+        return list(self._samples)
+
+    def compare(
+        self,
+        other: "MetricsCollector",
+    ) -> Dict[str, Any]:
+        """Compare metrics between two collectors.
+
+        Args:
+            other: Another metrics collector to compare against.
+
+        Returns:
+            Dictionary with comparison statistics.
+        """
+        summary1 = self.get_summary("a")
+        summary2 = other.get_summary("b")
         return {
-            name: self.compare(name, metrics)
-            for name, metrics in by_prompt.items()
-        }
+            "latency_delta_ms": summary2.latency.get("avg", 0) - summary1.latency.get("avg", 0),
+            "tokens_delta": summary2.tokens.get("avg", 0) - summary1.tokens.get("avg", 0),
+            "cost_delta": summary2.cost.get("avg", 0) - summary1.cost.get("avg", 0),
+            "validation_pass_rate_delta": (
+                summary2.validation_pass_rate - summary1.validation_pass_rate
+            ),
+            "sample_count": {
+                "a": summary1.count,
+                "b": summary2.count,
+            },
+        }
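
For orientation, here is a minimal usage sketch of the collector API as it stands after this commit. The import path is assumed for illustration (the file name is not shown above), and the numbers are invented:

# Usage sketch (assumed import path; the diff does not name the module).
from metrics import MetricsCollector

collector_a = MetricsCollector()
collector_b = MetricsCollector()

# Record one run per variant from provider-style usage dicts.
collector_a.record_from_response(
    latency_ms=812.0,
    usage={"prompt_tokens": 120, "completion_tokens": 48},
    validation_passed=True,
    cost=0.0021,
)
collector_b.record_from_response(
    latency_ms=640.0,
    usage={"prompt_tokens": 120, "completion_tokens": 31},
    validation_passed=True,
    cost=0.0014,
)

# Per-collector summary: min/avg/max latency, tokens, and cost.
summary_a = collector_a.get_summary("variant_a")
print(summary_a.latency["avg"], summary_a.validation_pass_rate)  # 812.0 1.0

# Deltas are computed as other minus self, so a negative
# latency_delta_ms means the other collector's variant was faster.
deltas = collector_a.compare(collector_b)
print(deltas["latency_delta_ms"])  # 640.0 - 812.0 == -172.0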
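
MetricsSummary.from_samples can also be driven directly with hand-built samples, e.g. when replaying logged runs rather than recording live responses; the same import-path assumption applies:

# Summary built from hand-constructed samples (replaying logged runs).
from metrics import MetricsSample, MetricsSummary

samples = [
    MetricsSample(latency_ms=500.0, tokens_total=100, cost=0.001, validation_passed=True),
    MetricsSample(latency_ms=700.0, tokens_total=140, cost=0.002, validation_passed=False),
]
summary = MetricsSummary.from_samples("replay", samples)
assert summary.count == 2
assert summary.latency == {"min": 500.0, "max": 700.0, "avg": 600.0}
assert summary.validation_pass_rate == 0.5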