fix: resolve CI linting and type errors
This commit is contained in:
@@ -1,5 +1,9 @@
|
||||
"""Test command for A/B testing prompts."""
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Dict
|
||||
import click
|
||||
|
||||
from promptforge.core.prompt import Prompt
|
||||
from promptforge.core.config import get_config
|
||||
from promptforge.providers import ProviderFactory
|
||||
@@ -30,16 +34,21 @@ def test(ctx, prompt_names: tuple, provider: str, iterations: int, output: str,
|
||||
selected_provider = provider or config.defaults.provider
|
||||
|
||||
try:
|
||||
provider_config: Dict[str, Any] = dict(config.providers.get(selected_provider, {}))
|
||||
provider_instance = ProviderFactory.create(
|
||||
selected_provider,
|
||||
model=config.providers.get(selected_provider, {}).model if selected_provider in config.providers else None,
|
||||
temperature=config.providers.get(selected_provider, {}).temperature if selected_provider in config.providers else 0.7,
|
||||
model=provider_config.get("model") if isinstance(provider_config, dict) else None,
|
||||
temperature=provider_config.get("temperature", 0.7) if isinstance(provider_config, dict) else 0.7,
|
||||
)
|
||||
except Exception as e:
|
||||
click.echo(f"Provider error: {e}", err=True)
|
||||
raise click.Abort()
|
||||
|
||||
test_config = ABTestConfig(iterations=iterations, parallel=parallel)
|
||||
test_config = ABTestConfig(
|
||||
iterations=iterations,
|
||||
parallel=parallel,
|
||||
)
|
||||
|
||||
ab_test = ABTest(provider_instance, test_config)
|
||||
|
||||
async def run_tests():
|
||||
@@ -57,6 +66,7 @@ def test(ctx, prompt_names: tuple, provider: str, iterations: int, output: str,
|
||||
click.echo(f"Successful: {summary.successful_runs}/{summary.total_runs}")
|
||||
click.echo(f"Avg Latency: {summary.avg_latency_ms:.2f}ms")
|
||||
click.echo(f"Avg Tokens: {summary.avg_tokens:.0f}")
|
||||
click.echo(f"Avg Cost: ${summary.avg_cost:.4f}")
|
||||
|
||||
if output == "json":
|
||||
import json
|
||||
@@ -66,7 +76,8 @@ def test(ctx, prompt_names: tuple, provider: str, iterations: int, output: str,
|
||||
"total_runs": s.total_runs,
|
||||
"avg_latency_ms": s.avg_latency_ms,
|
||||
"avg_tokens": s.avg_tokens,
|
||||
"avg_cost": s.avg_cost,
|
||||
}
|
||||
for name, s in results.items()
|
||||
}
|
||||
click.echo(json.dumps(output_data, indent=2))
|
||||
click.echo(json.dumps(output_data, indent=2))
|
||||
|
||||
Reference in New Issue
Block a user