diff --git a/src/promptforge/cli/commands/run.py b/src/promptforge/cli/commands/run.py index 024679f..2f54405 100644 --- a/src/promptforge/cli/commands/run.py +++ b/src/promptforge/cli/commands/run.py @@ -1,11 +1,13 @@ +"""Run command for executing prompts.""" + import asyncio +from typing import Any, Dict import click -from pathlib import Path + from promptforge.core.prompt import Prompt from promptforge.core.template import TemplateEngine from promptforge.core.config import get_config from promptforge.providers import ProviderFactory -from promptforge.testing.validator import Validator @click.command() @@ -33,7 +35,11 @@ def run(ctx, name: str, provider: str, var: tuple, output: str, stream: bool): template_engine = TemplateEngine() try: - rendered = template_engine.render(prompt.content, variables, prompt.variables) + rendered = template_engine.render( + prompt.content, + variables, + prompt.variables, + ) except Exception as e: click.echo(f"Template error: {e}", err=True) raise click.Abort() @@ -42,10 +48,11 @@ def run(ctx, name: str, provider: str, var: tuple, output: str, stream: bool): selected_provider = provider or prompt.provider or config.defaults.provider try: + provider_config: Dict[str, Any] = dict(config.providers.get(selected_provider, {})) provider_instance = ProviderFactory.create( selected_provider, - model=config.providers.get(selected_provider, {}).model if selected_provider in config.providers else None, - temperature=config.providers.get(selected_provider, {}).temperature if selected_provider in config.providers else 0.7, + model=provider_config.get("model") if isinstance(provider_config, dict) else None, + temperature=provider_config.get("temperature", 0.7) if isinstance(provider_config, dict) else 0.7, ) except Exception as e: click.echo(f"Provider error: {e}", err=True) @@ -67,18 +74,26 @@ def run(ctx, name: str, provider: str, var: tuple, output: str, stream: bool): import json click.echo("\n" + json.dumps({"response": response}, indent=2)) - asyncio.run(execute()) + if prompt.validation_rules: + validate_response(prompt, response) + + try: + asyncio.run(execute()) + except Exception as e: + click.echo(f"Execution error: {e}", err=True) + raise click.Abort() def validate_response(prompt: Prompt, response: str): + """Validate response against rules.""" for rule in prompt.validation_rules: if rule.type == "regex": import re if not re.search(rule.pattern or "", response): - click.echo(f"Warning: Response failed regex validation", err=True) + click.echo("Warning: Response failed regex validation", err=True) elif rule.type == "json": try: import json json.loads(response) except json.JSONDecodeError: - click.echo(f"Warning: Response is not valid JSON", err=True) \ No newline at end of file + click.echo("Warning: Response is not valid JSON", err=True)