fix: resolve CI linting and type errors
This commit is contained in:
@@ -1,5 +1,9 @@
|
|||||||
|
"""Test command for A/B testing prompts."""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from typing import Any, Dict
|
||||||
import click
|
import click
|
||||||
|
|
||||||
from promptforge.core.prompt import Prompt
|
from promptforge.core.prompt import Prompt
|
||||||
from promptforge.core.config import get_config
|
from promptforge.core.config import get_config
|
||||||
from promptforge.providers import ProviderFactory
|
from promptforge.providers import ProviderFactory
|
||||||
@@ -30,16 +34,21 @@ def test(ctx, prompt_names: tuple, provider: str, iterations: int, output: str,
|
|||||||
selected_provider = provider or config.defaults.provider
|
selected_provider = provider or config.defaults.provider
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
provider_config: Dict[str, Any] = dict(config.providers.get(selected_provider, {}))
|
||||||
provider_instance = ProviderFactory.create(
|
provider_instance = ProviderFactory.create(
|
||||||
selected_provider,
|
selected_provider,
|
||||||
model=config.providers.get(selected_provider, {}).model if selected_provider in config.providers else None,
|
model=provider_config.get("model") if isinstance(provider_config, dict) else None,
|
||||||
temperature=config.providers.get(selected_provider, {}).temperature if selected_provider in config.providers else 0.7,
|
temperature=provider_config.get("temperature", 0.7) if isinstance(provider_config, dict) else 0.7,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
click.echo(f"Provider error: {e}", err=True)
|
click.echo(f"Provider error: {e}", err=True)
|
||||||
raise click.Abort()
|
raise click.Abort()
|
||||||
|
|
||||||
test_config = ABTestConfig(iterations=iterations, parallel=parallel)
|
test_config = ABTestConfig(
|
||||||
|
iterations=iterations,
|
||||||
|
parallel=parallel,
|
||||||
|
)
|
||||||
|
|
||||||
ab_test = ABTest(provider_instance, test_config)
|
ab_test = ABTest(provider_instance, test_config)
|
||||||
|
|
||||||
async def run_tests():
|
async def run_tests():
|
||||||
@@ -57,6 +66,7 @@ def test(ctx, prompt_names: tuple, provider: str, iterations: int, output: str,
|
|||||||
click.echo(f"Successful: {summary.successful_runs}/{summary.total_runs}")
|
click.echo(f"Successful: {summary.successful_runs}/{summary.total_runs}")
|
||||||
click.echo(f"Avg Latency: {summary.avg_latency_ms:.2f}ms")
|
click.echo(f"Avg Latency: {summary.avg_latency_ms:.2f}ms")
|
||||||
click.echo(f"Avg Tokens: {summary.avg_tokens:.0f}")
|
click.echo(f"Avg Tokens: {summary.avg_tokens:.0f}")
|
||||||
|
click.echo(f"Avg Cost: ${summary.avg_cost:.4f}")
|
||||||
|
|
||||||
if output == "json":
|
if output == "json":
|
||||||
import json
|
import json
|
||||||
@@ -66,7 +76,8 @@ def test(ctx, prompt_names: tuple, provider: str, iterations: int, output: str,
|
|||||||
"total_runs": s.total_runs,
|
"total_runs": s.total_runs,
|
||||||
"avg_latency_ms": s.avg_latency_ms,
|
"avg_latency_ms": s.avg_latency_ms,
|
||||||
"avg_tokens": s.avg_tokens,
|
"avg_tokens": s.avg_tokens,
|
||||||
|
"avg_cost": s.avg_cost,
|
||||||
}
|
}
|
||||||
for name, s in results.items()
|
for name, s in results.items()
|
||||||
}
|
}
|
||||||
click.echo(json.dumps(output_data, indent=2))
|
click.echo(json.dumps(output_data, indent=2))
|
||||||
|
|||||||
Reference in New Issue
Block a user