fix: resolve CI linting and type errors
Some checks are pending
CI / test (push) Has started running
Some checks are pending
CI / test (push) Has started running
This commit is contained in:
81
app/src/promptforge/cli/commands/test.py
Normal file
81
app/src/promptforge/cli/commands/test.py
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
import asyncio
|
||||||
|
from typing import Any, Dict
|
||||||
|
import click
|
||||||
|
|
||||||
|
from promptforge.core.prompt import Prompt
|
||||||
|
from promptforge.core.config import get_config
|
||||||
|
from promptforge.providers import ProviderFactory
|
||||||
|
from promptforge.testing import ABTest, ABTestConfig
|
||||||
|
|
||||||
|
|
||||||
|
@click.command()
@click.argument("prompt_names", nargs=-1, required=True)
@click.option("--provider", "-p", help="Provider to use")
@click.option("--iterations", "-i", default=3, help="Number of test iterations")
@click.option("--output", "-o", type=click.Choice(["text", "json"]), default="text")
@click.option("--parallel", is_flag=True, help="Run iterations in parallel")
@click.pass_obj
def test(ctx, prompt_names: tuple, provider: str, iterations: int, output: str, parallel: bool):
    """Test prompts with A/B testing.

    Resolves each name in PROMPT_NAMES against the configured prompts
    directory, builds a provider instance from configuration (or the
    ``--provider`` override), runs the A/B comparison, and reports a
    per-prompt summary as either text or JSON.

    Aborts with a message on stderr if a prompt name is unknown, the
    provider cannot be created, or the test run itself fails.
    """
    prompts_dir = ctx["prompts_dir"]
    prompts = Prompt.list(prompts_dir)

    # Resolve every requested name up front so a typo fails fast,
    # before any provider work happens.
    selected_prompts = []
    for name in prompt_names:
        prompt = next((p for p in prompts if p.name == name), None)
        if not prompt:
            click.echo(f"Prompt '{name}' not found", err=True)
            raise click.Abort()
        selected_prompts.append(prompt)

    config = get_config()
    selected_provider = provider or config.defaults.provider

    try:
        # dict() normalizes the config entry, so the repeated
        # `isinstance(provider_config, dict)` checks the original
        # carried are unnecessary.
        provider_config: Dict[str, Any] = dict(config.providers.get(selected_provider, {}))
        provider_instance = ProviderFactory.create(
            selected_provider,
            model=provider_config.get("model"),
            temperature=provider_config.get("temperature", 0.7),
        )
    except Exception as e:
        click.echo(f"Provider error: {e}", err=True)
        raise click.Abort() from e

    test_config = ABTestConfig(
        iterations=iterations,
        parallel=parallel,
    )
    ab_test = ABTest(provider_instance, test_config)

    try:
        # No wrapper coroutine needed: pass the comparison coroutine
        # straight to asyncio.run().
        results = asyncio.run(ab_test.run_comparison(selected_prompts))
    except Exception as e:
        click.echo(f"Test error: {e}", err=True)
        raise click.Abort() from e

    if output == "json":
        # Bug fix: the original printed the text report unconditionally
        # and then *also* emitted JSON, so `--output json` produced mixed
        # text+JSON output. Emit exactly one format.
        import json

        output_data = {
            name: {
                "successful_runs": s.successful_runs,
                "total_runs": s.total_runs,
                "avg_latency_ms": s.avg_latency_ms,
                "avg_tokens": s.avg_tokens,
                "avg_cost": s.avg_cost,
            }
            for name, s in results.items()
        }
        click.echo(json.dumps(output_data, indent=2))
    else:
        for name, summary in results.items():
            click.echo(f"\n=== {name} ===")
            click.echo(f"Successful: {summary.successful_runs}/{summary.total_runs}")
            click.echo(f"Avg Latency: {summary.avg_latency_ms:.2f}ms")
            click.echo(f"Avg Tokens: {summary.avg_tokens:.0f}")
            click.echo(f"Avg Cost: ${summary.avg_cost:.4f}")
|
||||||
Reference in New Issue
Block a user