fix: resolve CI linting and type errors
Some checks failed
CI / test (push) Has been cancelled
CI / lint (push) Has been cancelled
CI / type-check (push) Has been cancelled

This commit is contained in:
2026-02-04 12:58:04 +00:00
parent 326d82e2d8
commit 508e1e8261

View File

@@ -1,5 +1,9 @@
"""Test command for A/B testing prompts."""
import asyncio import asyncio
from typing import Any, Dict
import click import click
from promptforge.core.prompt import Prompt from promptforge.core.prompt import Prompt
from promptforge.core.config import get_config from promptforge.core.config import get_config
from promptforge.providers import ProviderFactory from promptforge.providers import ProviderFactory
@@ -30,16 +34,21 @@ def test(ctx, prompt_names: tuple, provider: str, iterations: int, output: str,
selected_provider = provider or config.defaults.provider selected_provider = provider or config.defaults.provider
try: try:
provider_config: Dict[str, Any] = dict(config.providers.get(selected_provider, {}))
provider_instance = ProviderFactory.create( provider_instance = ProviderFactory.create(
selected_provider, selected_provider,
model=config.providers.get(selected_provider, {}).model if selected_provider in config.providers else None, model=provider_config.get("model") if isinstance(provider_config, dict) else None,
temperature=config.providers.get(selected_provider, {}).temperature if selected_provider in config.providers else 0.7, temperature=provider_config.get("temperature", 0.7) if isinstance(provider_config, dict) else 0.7,
) )
except Exception as e: except Exception as e:
click.echo(f"Provider error: {e}", err=True) click.echo(f"Provider error: {e}", err=True)
raise click.Abort() raise click.Abort()
test_config = ABTestConfig(iterations=iterations, parallel=parallel) test_config = ABTestConfig(
iterations=iterations,
parallel=parallel,
)
ab_test = ABTest(provider_instance, test_config) ab_test = ABTest(provider_instance, test_config)
async def run_tests(): async def run_tests():
@@ -57,6 +66,7 @@ def test(ctx, prompt_names: tuple, provider: str, iterations: int, output: str,
click.echo(f"Successful: {summary.successful_runs}/{summary.total_runs}") click.echo(f"Successful: {summary.successful_runs}/{summary.total_runs}")
click.echo(f"Avg Latency: {summary.avg_latency_ms:.2f}ms") click.echo(f"Avg Latency: {summary.avg_latency_ms:.2f}ms")
click.echo(f"Avg Tokens: {summary.avg_tokens:.0f}") click.echo(f"Avg Tokens: {summary.avg_tokens:.0f}")
click.echo(f"Avg Cost: ${summary.avg_cost:.4f}")
if output == "json": if output == "json":
import json import json
@@ -66,7 +76,8 @@ def test(ctx, prompt_names: tuple, provider: str, iterations: int, output: str,
"total_runs": s.total_runs, "total_runs": s.total_runs,
"avg_latency_ms": s.avg_latency_ms, "avg_latency_ms": s.avg_latency_ms,
"avg_tokens": s.avg_tokens, "avg_tokens": s.avg_tokens,
"avg_cost": s.avg_cost,
} }
for name, s in results.items() for name, s in results.items()
} }
click.echo(json.dumps(output_data, indent=2)) click.echo(json.dumps(output_data, indent=2))