From 2a6449c20e18e78fa78ea3186f1e40373ca67f2b Mon Sep 17 00:00:00 2001
From: 7000pctAUTO
Date: Wed, 4 Feb 2026 12:49:00 +0000
Subject: [PATCH] fix: resolve CI linting and type errors

---
 app/src/promptforge/cli/commands/test.py | 81 ++++++++++++++++++++++++
 1 file changed, 81 insertions(+)
 create mode 100644 app/src/promptforge/cli/commands/test.py

diff --git a/app/src/promptforge/cli/commands/test.py b/app/src/promptforge/cli/commands/test.py
new file mode 100644
index 0000000..549f008
--- /dev/null
+++ b/app/src/promptforge/cli/commands/test.py
@@ -0,0 +1,81 @@
+import asyncio
+import json
+from typing import Any, Dict, Tuple
+
+import click
+
+from promptforge.core.prompt import Prompt
+from promptforge.core.config import get_config
+from promptforge.providers import ProviderFactory
+from promptforge.testing import ABTest, ABTestConfig
+
+
+@click.command()
+@click.argument("prompt_names", nargs=-1, required=True)
+@click.option("--provider", "-p", help="Provider to use")
+@click.option("--iterations", "-i", default=3, help="Number of test iterations")
+@click.option("--output", "-o", type=click.Choice(["text", "json"]), default="text")
+@click.option("--parallel", is_flag=True, help="Run iterations in parallel")
+@click.pass_obj
+def test(ctx, prompt_names: Tuple[str, ...], provider: str, iterations: int, output: str, parallel: bool):
+    """Test prompts with A/B testing."""
+    prompts_dir = ctx["prompts_dir"]
+    prompts = Prompt.list(prompts_dir)
+
+    # Resolve each requested prompt by name; abort on the first miss.
+    selected_prompts = []
+    for name in prompt_names:
+        prompt = next((p for p in prompts if p.name == name), None)
+        if not prompt:
+            click.echo(f"Prompt '{name}' not found", err=True)
+            raise click.Abort()
+        selected_prompts.append(prompt)
+
+    config = get_config()
+    selected_provider = provider or config.defaults.provider
+
+    try:
+        # dict() always yields a Dict[str, Any], so no isinstance checks are needed.
+        provider_config: Dict[str, Any] = dict(config.providers.get(selected_provider, {}))
+        provider_instance = ProviderFactory.create(
+            selected_provider,
+            model=provider_config.get("model"),
+            temperature=provider_config.get("temperature", 0.7),
+        )
+    except Exception as e:
+        click.echo(f"Provider error: {e}", err=True)
+        raise click.Abort() from e
+
+    test_config = ABTestConfig(
+        iterations=iterations,
+        parallel=parallel,
+    )
+
+    ab_test = ABTest(provider_instance, test_config)
+
+    try:
+        results = asyncio.run(ab_test.run_comparison(selected_prompts))
+    except Exception as e:
+        click.echo(f"Test error: {e}", err=True)
+        raise click.Abort() from e
+
+    if output == "json":
+        output_data = {
+            name: {
+                "successful_runs": s.successful_runs,
+                "total_runs": s.total_runs,
+                "avg_latency_ms": s.avg_latency_ms,
+                "avg_tokens": s.avg_tokens,
+                "avg_cost": s.avg_cost,
+            }
+            for name, s in results.items()
+        }
+        click.echo(json.dumps(output_data, indent=2))
+        return
+
+    for name, summary in results.items():
+        click.echo(f"\n=== {name} ===")
+        click.echo(f"Successful: {summary.successful_runs}/{summary.total_runs}")
+        click.echo(f"Avg Latency: {summary.avg_latency_ms:.2f}ms")
+        click.echo(f"Avg Tokens: {summary.avg_tokens:.0f}")
+        click.echo(f"Avg Cost: ${summary.avg_cost:.4f}")