fix: resolve CI linting and type errors
Some checks failed
CI / lint (push) Has been cancelled
CI / type-check (push) Has been cancelled
CI / test (push) Has been cancelled

This commit is contained in:
2026-02-04 12:58:30 +00:00
parent 64cef11c7c
commit 7125b6933d

View File

@@ -1,86 +1,141 @@
"""Metrics collection for A/B testing."""
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional
@dataclass
class MetricsSample:
    """Single metrics sample from one A/B test run.

    Captures latency, token usage, cost, and validation outcome for a
    single provider call, plus arbitrary user-defined metrics.
    """

    # Aware UTC creation timestamp. datetime.utcnow() is deprecated since
    # Python 3.12 and returns a naive datetime; use an aware one instead.
    timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    latency_ms: float = 0.0          # end-to-end request latency
    tokens_prompt: int = 0           # tokens in the prompt
    tokens_completion: int = 0       # tokens in the completion
    tokens_total: int = 0            # prompt + completion (as reported by provider)
    cost: float = 0.0                # estimated cost of the call
    validation_passed: bool = False  # whether response validation succeeded
    validation_errors: List[str] = field(default_factory=list)
    custom_metrics: Dict[str, Any] = field(default_factory=dict)
@dataclass
class MetricsSummary:
    """Summary statistics for a collection of metrics samples."""

    name: str
    count: int = 0
    # Each stats dict holds "min"/"max"/"avg" keys.
    latency: Dict[str, float] = field(default_factory=dict)
    tokens: Dict[str, float] = field(default_factory=dict)
    cost: Dict[str, float] = field(default_factory=dict)
    validation_pass_rate: float = 0.0
    samples: List[MetricsSample] = field(default_factory=list)

    @classmethod
    def from_samples(cls, name: str, samples: List[MetricsSample]) -> "MetricsSummary":
        """Create a summary from a list of samples.

        Args:
            name: Label for the summary (e.g. the variant name).
            samples: Samples to aggregate; may be empty.

        Returns:
            A MetricsSummary with min/max/avg stats for latency, tokens,
            and cost, plus the validation pass rate. An empty input
            yields a zeroed summary.
        """
        if not samples:
            return cls(name=name)

        def _stats(values: List[float]) -> Dict[str, float]:
            # Shared min/max/avg computation; coerce to float so the
            # Dict[str, float] annotation holds for int inputs (tokens).
            return {
                "min": float(min(values)),
                "max": float(max(values)),
                "avg": sum(values) / len(values),
            }

        passed = sum(1 for s in samples if s.validation_passed)
        return cls(
            name=name,
            count=len(samples),
            latency=_stats([s.latency_ms for s in samples]),
            tokens=_stats([s.tokens_total for s in samples]),
            cost=_stats([s.cost for s in samples]),
            validation_pass_rate=passed / len(samples),
            # Copy so later mutation of the caller's list does not
            # silently change this summary.
            samples=list(samples),
        )
class MetricsCollector:
    """Collect and aggregate metrics samples from test runs."""

    def __init__(self) -> None:
        """Initialize an empty collector."""
        self._samples: List[MetricsSample] = []

    def record(self, sample: MetricsSample) -> None:
        """Record a single metrics sample."""
        self._samples.append(sample)

    def record_from_response(
        self,
        latency_ms: float,
        usage: Dict[str, int],
        validation_passed: bool = False,
        validation_errors: Optional[List[str]] = None,
        cost: float = 0.0,
        custom_metrics: Optional[Dict[str, Any]] = None,
    ) -> MetricsSample:
        """Build and record a sample from a provider response.

        Args:
            latency_ms: End-to-end request latency in milliseconds.
            usage: Token-usage mapping; recognizes "prompt_tokens",
                "completion_tokens", and "total_tokens" keys.
            validation_passed: Whether response validation succeeded.
            validation_errors: Validation error messages, if any.
            cost: Estimated cost of the call.
            custom_metrics: Arbitrary extra metrics to attach.

        Returns:
            The recorded MetricsSample.
        """
        prompt = usage.get("prompt_tokens", 0)
        completion = usage.get("completion_tokens", 0)
        sample = MetricsSample(
            latency_ms=latency_ms,
            tokens_prompt=prompt,
            tokens_completion=completion,
            # Fall back to prompt + completion when the provider does not
            # report an explicit total.
            tokens_total=usage.get("total_tokens", prompt + completion),
            cost=cost,
            validation_passed=validation_passed,
            validation_errors=validation_errors or [],
            custom_metrics=custom_metrics or {},
        )
        self.record(sample)
        return sample

    def get_summary(self, name: str = "test") -> MetricsSummary:
        """Summarize all collected samples under *name*.

        A snapshot copy of the samples is handed to the summary so that
        later record()/clear() calls do not silently mutate a summary
        that was already returned.
        """
        return MetricsSummary.from_samples(name, list(self._samples))

    def clear(self) -> None:
        """Discard all collected samples."""
        self._samples.clear()

    def get_samples(self) -> List[MetricsSample]:
        """Return a copy of all collected samples."""
        return list(self._samples)

    def compare(self, other: "MetricsCollector") -> Dict[str, Any]:
        """Compare average metrics against another collector.

        Args:
            other: Collector to compare against; it is treated as
                variant "b", with self as variant "a".

        Returns:
            Dictionary of (b - a) deltas for average latency, tokens,
            cost, and validation pass rate, plus per-variant sample
            counts.
        """
        summary_a = self.get_summary("a")
        summary_b = other.get_summary("b")
        return {
            "latency_delta_ms": summary_b.latency.get("avg", 0) - summary_a.latency.get("avg", 0),
            "tokens_delta": summary_b.tokens.get("avg", 0) - summary_a.tokens.get("avg", 0),
            "cost_delta": summary_b.cost.get("avg", 0) - summary_a.cost.get("avg", 0),
            "validation_pass_rate_delta": (
                summary_b.validation_pass_rate - summary_a.validation_pass_rate
            ),
            "sample_count": {
                "a": summary_a.count,
                "b": summary_b.count,
            },
        }