fix: resolve CI linting and type errors
This commit is contained in:
@@ -1,11 +1,13 @@
|
||||
"""Run command for executing prompts."""
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Dict
|
||||
import click
|
||||
from pathlib import Path
|
||||
|
||||
from promptforge.core.prompt import Prompt
|
||||
from promptforge.core.template import TemplateEngine
|
||||
from promptforge.core.config import get_config
|
||||
from promptforge.providers import ProviderFactory
|
||||
from promptforge.testing.validator import Validator
|
||||
|
||||
|
||||
@click.command()
|
||||
@@ -33,7 +35,11 @@ def run(ctx, name: str, provider: str, var: tuple, output: str, stream: bool):
|
||||
|
||||
template_engine = TemplateEngine()
|
||||
try:
|
||||
rendered = template_engine.render(prompt.content, variables, prompt.variables)
|
||||
rendered = template_engine.render(
|
||||
prompt.content,
|
||||
variables,
|
||||
prompt.variables,
|
||||
)
|
||||
except Exception as e:
|
||||
click.echo(f"Template error: {e}", err=True)
|
||||
raise click.Abort()
|
||||
@@ -42,10 +48,11 @@ def run(ctx, name: str, provider: str, var: tuple, output: str, stream: bool):
|
||||
selected_provider = provider or prompt.provider or config.defaults.provider
|
||||
|
||||
try:
|
||||
provider_config: Dict[str, Any] = dict(config.providers.get(selected_provider, {}))
|
||||
provider_instance = ProviderFactory.create(
|
||||
selected_provider,
|
||||
model=config.providers.get(selected_provider, {}).model if selected_provider in config.providers else None,
|
||||
temperature=config.providers.get(selected_provider, {}).temperature if selected_provider in config.providers else 0.7,
|
||||
model=provider_config.get("model") if isinstance(provider_config, dict) else None,
|
||||
temperature=provider_config.get("temperature", 0.7) if isinstance(provider_config, dict) else 0.7,
|
||||
)
|
||||
except Exception as e:
|
||||
click.echo(f"Provider error: {e}", err=True)
|
||||
@@ -67,18 +74,26 @@ def run(ctx, name: str, provider: str, var: tuple, output: str, stream: bool):
|
||||
import json
|
||||
click.echo("\n" + json.dumps({"response": response}, indent=2))
|
||||
|
||||
asyncio.run(execute())
|
||||
if prompt.validation_rules:
|
||||
validate_response(prompt, response)
|
||||
|
||||
try:
|
||||
asyncio.run(execute())
|
||||
except Exception as e:
|
||||
click.echo(f"Execution error: {e}", err=True)
|
||||
raise click.Abort()
|
||||
|
||||
|
||||
def validate_response(prompt: Prompt, response: str):
|
||||
"""Validate response against rules."""
|
||||
for rule in prompt.validation_rules:
|
||||
if rule.type == "regex":
|
||||
import re
|
||||
if not re.search(rule.pattern or "", response):
|
||||
click.echo(f"Warning: Response failed regex validation", err=True)
|
||||
click.echo("Warning: Response failed regex validation", err=True)
|
||||
elif rule.type == "json":
|
||||
try:
|
||||
import json
|
||||
json.loads(response)
|
||||
except json.JSONDecodeError:
|
||||
click.echo(f"Warning: Response is not valid JSON", err=True)
|
||||
click.echo("Warning: Response is not valid JSON", err=True)
|
||||
|
||||
Reference in New Issue
Block a user