Compare commits
64 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| d75b6af809 | |||
| 6e3681177d | |||
| b074630e6b | |||
| 4c7ac24ecc | |||
| b57b670c4b | |||
| 2dea0d8fd0 | |||
| 2ce95e406a | |||
| 5a275b812b | |||
| b4076327d8 | |||
| 7125b6933d | |||
| 64cef11c7c | |||
| 944ea90346 | |||
| 9fb868c8f5 | |||
| e86adcfbfc | |||
| 8090d3eeba | |||
| 03ed9d92b2 | |||
| 627c0ec550 | |||
| fa7365ca37 | |||
| eabd05b6c4 | |||
| 3a893f2b3c | |||
| 6135b499c4 | |||
| 571a5309ba | |||
| 6d6f4a509f | |||
| 925e44ceb4 | |||
| 3dd57bf725 | |||
| 1f9d843207 | |||
| 48edd1a9e0 | |||
| 508e1e8261 | |||
| 326d82e2d8 | |||
| d38570a6c9 | |||
| 9906685345 | |||
| 4b98c93700 | |||
| fc0f538543 | |||
| 6210cd6606 | |||
| 656e27770d | |||
| 73c75e4646 | |||
| cfea40a938 | |||
| b6f4c80108 | |||
| ff0ab4b3ef | |||
| 7decd593d8 | |||
| 2b20fb2b46 | |||
| a9ce421924 | |||
| 578edafab3 | |||
| d18434b37a | |||
| a00741ef93 | |||
| 51b5e2898d | |||
| fa94abb0cc | |||
| 763828579b | |||
| d8ecd258e9 | |||
| 8639f988b8 | |||
| 6435e18aa2 | |||
| 0993900953 | |||
| 3525029e7e | |||
| 914ccb2e65 | |||
| a6176bc1fd | |||
| e35de6502a | |||
| b57f7e74da | |||
| 946b7e125a | |||
| 2a6449c20e | |||
| d85bac7d65 | |||
| 6f8f018f4f | |||
| 76eb92d124 | |||
| 2e5c9f9a3a | |||
| 96c0398323 |
@@ -1 +1,54 @@
|
|||||||
/app/.gitea/workflows/ci.yml
|
name: CI
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [main]
|
||||||
|
pull_request:
|
||||||
|
branches: [main]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
test:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: '3.11'
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: |
|
||||||
|
pip install -e ".[dev]"
|
||||||
|
|
||||||
|
- name: Run tests
|
||||||
|
run: python -m pytest tests/test_core.py tests/test_providers.py tests/test_testing.py -v --tb=short
|
||||||
|
|
||||||
|
lint:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: '3.11'
|
||||||
|
|
||||||
|
- name: Install linting tools
  # The requirement is quoted so the shell does not parse ">=0.1.0" as an
  # output redirection (which would install an unconstrained ruff and
  # create a stray file named "=0.1.0").
  run: pip install "ruff>=0.1.0"
|
||||||
|
|
||||||
|
- name: Run ruff linter
|
||||||
|
run: python -m ruff check src/promptforge/ tests/
|
||||||
|
|
||||||
|
type-check:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: '3.11'
|
||||||
|
|
||||||
|
- name: Install type checker
  # The requirement is quoted so the shell does not parse ">=1.0.0" as an
  # output redirection (which would install an unconstrained mypy and
  # create a stray file named "=1.0.0").
  run: pip install "mypy>=1.0.0"
|
||||||
|
|
||||||
|
- name: Run mypy type checker
|
||||||
|
run: python -m mypy src/promptforge/ --ignore-missing-imports
|
||||||
|
|||||||
43
.gitignore
vendored
43
.gitignore
vendored
@@ -1 +1,42 @@
|
|||||||
/app/.gitignore
|
# Project specific
|
||||||
|
# Ignore the contents of prompts/ but keep the placeholder file. The wildcard
# form is required: git cannot re-include a file whose parent directory is
# itself excluded, so "prompts/" followed by "!prompts/.gitkeep" has no effect.
prompts/*
!prompts/.gitkeep
|
||||||
|
# Ignore the contents of configs/ except the shared project config. The
# wildcard form is required: git cannot re-include a file whose parent
# directory is itself excluded.
configs/*
!configs/promptforge.yaml
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
*.so
|
||||||
|
.Python
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
.env
|
||||||
|
.env.local
|
||||||
|
config.yaml
|
||||||
|
.history/
|
||||||
|
*.log
|
||||||
|
.pytest_cache/
|
||||||
|
.coverage
|
||||||
|
htmlcov/
|
||||||
|
.venv/
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
.idea/
|
||||||
|
.vscode/
|
||||||
|
*.swp
|
||||||
|
*.swo
|
||||||
|
*~
|
||||||
|
.DS_Store
|
||||||
|
|||||||
227
README.md
227
README.md
@@ -1 +1,226 @@
|
|||||||
/app/README.md
|
# PromptForge
|
||||||
|
|
||||||
|
A CLI tool for versioning, testing, and sharing AI prompts across different LLM providers. Treat prompts as code with git integration, A/B testing, templating, and a shared prompt registry.
|
||||||
|
|
||||||
|
## Features
|
||||||
|
|
||||||
|
- **Version Control**: Track prompt changes with Git, create branches for A/B testing variations
|
||||||
|
- **Multi-Provider Support**: Unified API for OpenAI, Anthropic, and Ollama
|
||||||
|
- **Prompt Templating**: Jinja2-based variable substitution with type validation
|
||||||
|
- **A/B Testing Framework**: Compare prompt variations with statistical analysis
|
||||||
|
- **Output Validation**: Validate LLM responses against JSON schemas or regex patterns
|
||||||
|
- **Prompt Registry**: Share prompts via local and remote registries
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install promptforge
|
||||||
|
```
|
||||||
|
|
||||||
|
Or from source:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git clone https://github.com/yourusername/promptforge.git
|
||||||
|
cd promptforge
|
||||||
|
pip install -e .
|
||||||
|
```
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
1. Initialize a new PromptForge project:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pf init
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Create your first prompt:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pf prompt create "Summarizer" -c "Summarize the following text: {{text}}"
|
||||||
|
```
|
||||||
|
|
||||||
|
3. Run the prompt:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pf run Summarizer -v text="Your long text here..."
|
||||||
|
```
|
||||||
|
|
||||||
|
## Configuration
|
||||||
|
|
||||||
|
Create a `configs/promptforge.yaml` file:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
providers:
|
||||||
|
openai:
|
||||||
|
api_key: ${OPENAI_API_KEY}
|
||||||
|
model: gpt-4
|
||||||
|
temperature: 0.7
|
||||||
|
anthropic:
|
||||||
|
api_key: ${ANTHROPIC_API_KEY}
|
||||||
|
model: claude-3-sonnet-20240229
|
||||||
|
ollama:
|
||||||
|
base_url: http://localhost:11434
|
||||||
|
model: llama2
|
||||||
|
|
||||||
|
defaults:
|
||||||
|
provider: openai
|
||||||
|
output_format: text
|
||||||
|
```
|
||||||
|
|
||||||
|
## Creating Prompts
|
||||||
|
|
||||||
|
Prompts are YAML files with front matter:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
---
|
||||||
|
name: Code Explainer
|
||||||
|
description: Explain code snippets
|
||||||
|
version: "1.0.0"
|
||||||
|
provider: openai
|
||||||
|
tags: [coding, education]
|
||||||
|
variables:
|
||||||
|
- name: language
|
||||||
|
type: choice
|
||||||
|
required: true
|
||||||
|
choices: [python, javascript, rust, go]
|
||||||
|
- name: code
|
||||||
|
type: string
|
||||||
|
required: true
|
||||||
|
validation:
|
||||||
|
- type: regex
|
||||||
|
pattern: "(def|function|fn|func)"
|
||||||
|
---
|
||||||
|
Explain this {{language}} code:
|
||||||
|
|
||||||
|
{{code}}
|
||||||
|
|
||||||
|
Focus on:
|
||||||
|
- What the code does
|
||||||
|
- Key functions/classes used
|
||||||
|
- Any potential improvements
|
||||||
|
```
|
||||||
|
|
||||||
|
## Running Prompts
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Run with variables
|
||||||
|
pf run "Code Explainer" -v language=python -v code="def hello(): print('world')"
|
||||||
|
|
||||||
|
# Use a different provider
|
||||||
|
pf run "Code Explainer" -p anthropic -v language=rust -v code="..."
|
||||||
|
|
||||||
|
# Output as JSON
|
||||||
|
pf run "Code Explainer" -o json -v ...
|
||||||
|
```
|
||||||
|
|
||||||
|
## Version Control
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Create a version commit
|
||||||
|
pf version create "Added validation rules"
|
||||||
|
|
||||||
|
# View history
|
||||||
|
pf version history
|
||||||
|
|
||||||
|
# Create a branch for A/B testing
|
||||||
|
pf version branch test-variation-a
|
||||||
|
|
||||||
|
# List all branches
|
||||||
|
pf version list
|
||||||
|
```
|
||||||
|
|
||||||
|
## A/B Testing
|
||||||
|
|
||||||
|
Compare prompt variations:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Test a single prompt
|
||||||
|
pf test "Code Explainer" --iterations 5
|
||||||
|
|
||||||
|
# Compare multiple prompts
|
||||||
|
pf test "Prompt A" "Prompt B" --iterations 3
|
||||||
|
|
||||||
|
# Run in parallel
|
||||||
|
pf test "Prompt A" "Prompt B" --parallel
|
||||||
|
```
|
||||||
|
|
||||||
|
## Output Validation
|
||||||
|
|
||||||
|
Add validation rules to your prompts:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
validation:
|
||||||
|
- type: regex
|
||||||
|
pattern: "^\\d+\\. .+"
|
||||||
|
message: "Response must be numbered list"
|
||||||
|
|
||||||
|
- type: json
|
||||||
|
schema:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
summary:
|
||||||
|
type: string
|
||||||
|
minLength: 10
|
||||||
|
keywords:
|
||||||
|
type: array
|
||||||
|
items:
|
||||||
|
type: string
|
||||||
|
```
|
||||||
|
|
||||||
|
## Prompt Registry
|
||||||
|
|
||||||
|
### Local Registry
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# List local prompts
|
||||||
|
pf registry list
|
||||||
|
|
||||||
|
# Add prompt to registry
|
||||||
|
pf registry add "Code Explainer" --author "Your Name"
|
||||||
|
|
||||||
|
# Search registry
|
||||||
|
pf registry search "python"
|
||||||
|
```
|
||||||
|
|
||||||
|
### Remote Registry
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Pull a prompt from remote
|
||||||
|
pf registry pull <entry_id>
|
||||||
|
|
||||||
|
# Publish your prompt
|
||||||
|
pf registry publish "Code Explainer"
|
||||||
|
```
|
||||||
|
|
||||||
|
## API Reference
|
||||||
|
|
||||||
|
### Core Classes
|
||||||
|
|
||||||
|
- `Prompt`: Main prompt model with YAML serialization
|
||||||
|
- `TemplateEngine`: Jinja2-based template rendering
|
||||||
|
- `GitManager`: Git integration for version control
|
||||||
|
- `ProviderBase`: Abstract interface for LLM providers
|
||||||
|
|
||||||
|
### Providers
|
||||||
|
|
||||||
|
- `OpenAIProvider`: OpenAI GPT models
|
||||||
|
- `AnthropicProvider`: Anthropic Claude models
|
||||||
|
- `OllamaProvider`: Local Ollama models
|
||||||
|
|
||||||
|
### Testing
|
||||||
|
|
||||||
|
- `ABTest`: A/B test runner
|
||||||
|
- `Validator`: Response validation framework
|
||||||
|
- `MetricsCollector`: Metrics aggregation
|
||||||
|
|
||||||
|
## Contributing
|
||||||
|
|
||||||
|
1. Fork the repository
|
||||||
|
2. Create a feature branch
|
||||||
|
3. Make your changes
|
||||||
|
4. Run tests: `pytest tests/ -v`
|
||||||
|
5. Submit a pull request
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
MIT License - see LICENSE file for details.
|
||||||
|
|||||||
140
app/src/promptforge/cli/commands/prompt.py
Normal file
140
app/src/promptforge/cli/commands/prompt.py
Normal file
@@ -0,0 +1,140 @@
|
|||||||
|
import click
|
||||||
|
|
||||||
|
from promptforge.core.prompt import Prompt, PromptVariable, VariableType
|
||||||
|
from promptforge.core.template import TemplateEngine
|
||||||
|
from promptforge.core.git_manager import GitManager
|
||||||
|
|
||||||
|
|
||||||
|
# Command group that the prompt subcommands below attach to via
# ``@prompt.command(...)``. The docstring is what click shows as help text.
@click.group()
def prompt():
    """Manage prompts."""
|
||||||
|
|
||||||
|
|
||||||
|
@prompt.command("create")
@click.argument("name")
@click.option("--content", "-c", help="Prompt content")
@click.option("--description", "-d", help="Prompt description")
@click.option("--provider", "-p", help="Default provider")
@click.option("--tag", "-t", multiple=True, help="Tags for the prompt")
@click.pass_obj
def create(ctx, name: str, content: str, description: str, provider: str, tag: tuple):
    """Create a new prompt."""
    # ctx is the object stored on the click context; it is indexed with
    # "prompts_dir" and the result is used as a path below.
    prompts_dir = ctx["prompts_dir"]

    # Without --content, read the whole prompt body from stdin until EOF.
    if not content:
        click.echo("Enter prompt content (end with Ctrl+D on new line):")
        content = click.get_text_stream("stdin").read()

    # Only consult the template engine when the text contains braces at all;
    # each discovered variable is then described interactively.
    variables = []
    if "{" in content and "}" in content:
        template_engine = TemplateEngine()
        for var_name in template_engine.get_variables(content):
            var_type = click.prompt(
                f"Variable '{var_name}' type",
                type=click.Choice(["string", "integer", "float", "boolean", "choice"]),
                default="string",
            )
            is_required = click.confirm(f"Is '{var_name}' required?", default=True)
            # A default value is only requested for optional variables.
            default_val = None
            if not is_required:
                default_val = click.prompt(f"Default value for '{var_name}'", default="")
            variables.append(
                PromptVariable(
                    name=var_name,
                    type=VariableType(var_type),
                    required=is_required,
                    default=default_val,
                )
            )

    # Named ``new_prompt`` so the local does not shadow the ``prompt`` group.
    new_prompt = Prompt(
        name=name,
        description=description,
        content=content,
        provider=provider,
        tags=list(tag),
        variables=variables,
    )

    filepath = new_prompt.save(prompts_dir)
    click.echo(f"Created prompt: {filepath}")

    # Committing is best-effort: the prompt file is already saved, so a git
    # failure must not abort the command. It is reported on stderr instead of
    # being swallowed silently.
    git_manager = GitManager(prompts_dir)
    if (prompts_dir / ".git").exists():
        try:
            git_manager.commit(f"Add prompt: {name}")
            click.echo("Committed to git")
        except Exception as exc:
            click.echo(f"Warning: could not commit to git: {exc}", err=True)
|
||||||
|
|
||||||
|
|
||||||
|
@prompt.command("list")
@click.pass_obj
def list_prompts(ctx):
    """List all prompts."""
    found = Prompt.list(ctx["prompts_dir"])

    # Nothing stored yet: point the user at the create command and stop.
    if not found:
        click.echo("No prompts found. Create one with 'pf prompt create'")
        return

    for item in found:
        # A check mark flags prompts that have a provider set; a hollow
        # circle marks the rest.
        marker = "✓" if item.provider else "○"
        click.echo(f"{marker} {item.name} (v{item.version})")
        if item.description:
            click.echo(f" {item.description}")
|
||||||
|
|
||||||
|
|
||||||
|
@prompt.command("show")
@click.argument("name")
@click.pass_obj
def show(ctx, name: str):
    """Show prompt details."""
    candidates = Prompt.list(ctx["prompts_dir"])

    # First prompt whose name matches exactly, or None when there is none.
    match = next((item for item in candidates if item.name == name), None)
    if not match:
        click.echo(f"Prompt '{name}' not found", err=True)
        raise click.Abort()

    click.echo(f"Name: {match.name}")
    click.echo(f"Version: {match.version}")

    # Optional metadata is printed only when it is set.
    if match.description:
        click.echo(f"Description: {match.description}")
    if match.provider:
        click.echo(f"Provider: {match.provider}")
    if match.tags:
        joined_tags = ", ".join(match.tags)
        click.echo(f"Tags: {joined_tags}")

    created = match.created_at.isoformat()
    click.echo(f"Created: {created}")
    updated = match.updated_at.isoformat()
    click.echo(f"Updated: {updated}")

    # Blank line, then the prompt body under a separator.
    click.echo("")
    click.echo("--- Content ---")
    click.echo(match.content)
|
||||||
|
|
||||||
|
|
||||||
|
@prompt.command("delete")
@click.argument("name")
@click.option("--yes", "-y", is_flag=True, help="Skip confirmation")
@click.pass_obj
def delete(ctx, name: str, yes: bool):
    """Delete a prompt."""
    prompts_dir = ctx["prompts_dir"]

    # The prompt must be known before anything is removed.
    target = next((item for item in Prompt.list(prompts_dir) if item.name == name), None)
    if not target:
        click.echo(f"Prompt '{name}' not found", err=True)
        raise click.Abort()

    # Ask for confirmation unless --yes was given.
    if not (yes or click.confirm(f"Delete prompt '{name}'?")):
        raise click.Abort()

    # File name derived from the prompt name: lower-cased, with spaces and
    # slashes replaced by underscores.
    slug = name.lower().replace(' ', '_').replace('/', '_')
    filepath = prompts_dir / f"{slug}.yaml"

    if not filepath.exists():
        click.echo("Prompt file not found", err=True)
        raise click.Abort()

    filepath.unlink()
    click.echo(f"Deleted prompt: {name}")
|
||||||
109
app/src/promptforge/cli/commands/registry.py
Normal file
109
app/src/promptforge/cli/commands/registry.py
Normal file
@@ -0,0 +1,109 @@
|
|||||||
|
import click
|
||||||
|
|
||||||
|
from promptforge.registry import LocalRegistry, RemoteRegistry, RegistryEntry
|
||||||
|
from promptforge.core.prompt import Prompt
|
||||||
|
|
||||||
|
|
||||||
|
# Command group that the registry subcommands below attach to via
# ``@registry.command(...)``. The docstring is what click shows as help text.
@click.group()
def registry():
    """Manage prompt registry."""
|
||||||
|
|
||||||
|
|
||||||
|
@registry.command("list")
@click.option("--tag", help="Filter by tag")
@click.option("--limit", default=20, help="Maximum results")
@click.pass_obj
def registry_list(ctx, tag: str, limit: int):
    """List prompts in local registry."""
    # Filtering by tag and capping at ``limit`` are delegated to the registry.
    local = LocalRegistry()
    listed = local.list(tag=tag, limit=limit)

    if not listed:
        click.echo("No prompts in registry")
        return

    for item in listed:
        click.echo(f"{item.name} (v{item.version})")
        if item.description:
            click.echo(f" {item.description}")
|
||||||
|
|
||||||
|
|
||||||
|
@registry.command("add")
@click.argument("prompt_name")
@click.option("--author", help="Author name")
@click.pass_obj
def registry_add(ctx, prompt_name: str, author: str):
    """Add a prompt to the local registry."""
    stored = Prompt.list(ctx["prompts_dir"])

    # Resolve the prompt by exact name before touching the registry.
    source = next((item for item in stored if item.name == prompt_name), None)
    if not source:
        click.echo(f"Prompt '{prompt_name}' not found", err=True)
        raise click.Abort()

    # Wrap the prompt in a registry entry and store it locally.
    new_entry = RegistryEntry.from_prompt(source, author=author)
    LocalRegistry().add(new_entry)

    click.echo(f"Added '{prompt_name}' to registry")
|
||||||
|
|
||||||
|
|
||||||
|
@registry.command("search")
@click.argument("query")
@click.option("--limit", default=20, help="Maximum results")
@click.pass_obj
def registry_search(ctx, query: str, limit: int):
    """Search local registry."""
    hits = LocalRegistry().search(query)

    if not hits:
        click.echo("No results found")
        return

    # The limit is applied here, after the search, by slicing the results.
    for hit in hits[:limit]:
        found = hit.entry
        click.echo(f"{found.name} (score: {hit.relevance_score})")
        if found.description:
            click.echo(f" {found.description}")
|
||||||
|
|
||||||
|
|
||||||
|
@registry.command("pull")
@click.argument("entry_id")
@click.pass_obj
def registry_pull(ctx, entry_id: str):
    """Pull a prompt from remote registry."""
    remote = RemoteRegistry()
    local = LocalRegistry()

    # A falsy result from pull() is reported as a missing entry.
    pulled = remote.pull(entry_id, local)
    if not pulled:
        click.echo("Entry not found", err=True)
        raise click.Abort()

    click.echo(f"Pulled entry {entry_id}")
|
||||||
|
|
||||||
|
|
||||||
|
@registry.command("publish")
@click.argument("prompt_name")
@click.pass_obj
def registry_publish(ctx, prompt_name: str):
    """Publish a prompt to remote registry."""
    stored = Prompt.list(ctx["prompts_dir"])

    # Resolve the prompt by exact name; abort when it does not exist.
    source = next((item for item in stored if item.name == prompt_name), None)
    if not source:
        click.echo(f"Prompt '{prompt_name}' not found", err=True)
        raise click.Abort()

    outgoing = RegistryEntry.from_prompt(source)
    remote = RemoteRegistry()

    # Any failure while publishing is shown to the user and aborts the command.
    try:
        result = remote.publish(outgoing)
        click.echo(f"Published '{prompt_name}' as {result.id}")
    except Exception as err:
        click.echo(f"Publish error: {err}", err=True)
        raise click.Abort()
|
||||||
97
app/src/promptforge/cli/commands/run.py
Normal file
97
app/src/promptforge/cli/commands/run.py
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
import asyncio
|
||||||
|
from typing import Any, Dict
|
||||||
|
import click
|
||||||
|
|
||||||
|
from promptforge.core.prompt import Prompt
|
||||||
|
from promptforge.core.template import TemplateEngine
|
||||||
|
from promptforge.core.config import get_config
|
||||||
|
from promptforge.providers import ProviderFactory
|
||||||
|
|
||||||
|
|
||||||
|
@click.command()
@click.argument("name")
@click.option("--provider", "-p", help="Override provider")
@click.option("--var", "-v", multiple=True, help="Variables in key=value format")
@click.option("--output", "-o", type=click.Choice(["text", "json", "yaml"]), default="text")
@click.option("--stream/--no-stream", default=True, help="Stream response")
@click.pass_obj
def run(ctx, name: str, provider: str, var: tuple, output: str, stream: bool):
    """Run a prompt with variables."""
    prompts_dir = ctx["prompts_dir"]
    prompts = Prompt.list(prompts_dir)

    prompt = next((p for p in prompts if p.name == name), None)
    if not prompt:
        click.echo(f"Prompt '{name}' not found", err=True)
        raise click.Abort()

    # Parse the repeated --var options. Only the first "=" separates key from
    # value, so values may themselves contain "=".
    variables = {}
    for v in var:
        key, sep, val = v.partition("=")
        if not sep:
            # Entries without "=" used to be dropped silently; they are still
            # skipped, but the user is now told about it.
            click.echo(f"Warning: ignoring variable '{v}' (expected key=value)", err=True)
            continue
        variables[key.strip()] = val.strip()

    template_engine = TemplateEngine()
    try:
        rendered = template_engine.render(
            prompt.content,
            variables,
            prompt.variables,
        )
    except Exception as e:
        click.echo(f"Template error: {e}", err=True)
        raise click.Abort() from e

    # Provider precedence: command-line override, then the prompt's own
    # provider, then the configured default.
    config = get_config()
    selected_provider = provider or prompt.provider or config.defaults.provider

    try:
        # dict(...) always yields a plain dict, so no isinstance guard is needed.
        provider_config: Dict[str, Any] = dict(config.providers.get(selected_provider, {}))
        provider_instance = ProviderFactory.create(
            selected_provider,
            model=provider_config.get("model"),
            temperature=provider_config.get("temperature", 0.7),
        )
    except Exception as e:
        click.echo(f"Provider error: {e}", err=True)
        raise click.Abort() from e

    async def execute():
        # Streaming echoes chunks as they arrive and keeps them so the full
        # text is available afterwards; otherwise the completed response is
        # printed in one go.
        if stream:
            full_response = []
            async for chunk in provider_instance.stream_complete(rendered):
                click.echo(chunk, nl=False)
                full_response.append(chunk)
            response = "".join(full_response)
        else:
            result = await provider_instance.complete(rendered)
            response = result.content
            click.echo(response)

        # NOTE(review): the JSON document is printed after the raw text, not
        # instead of it, and "yaml" is accepted by --output but has no branch
        # here, so it currently behaves like "text".
        if output == "json":
            import json

            click.echo("\n" + json.dumps({"response": response}, indent=2))

        if prompt.validation_rules:
            validate_response(prompt, response)

    try:
        asyncio.run(execute())
    except Exception as e:
        click.echo(f"Execution error: {e}", err=True)
        raise click.Abort() from e
|
||||||
|
|
||||||
|
|
||||||
|
def validate_response(prompt: Prompt, response: str):
    """Check *response* against each of the prompt's validation rules.

    A "regex" rule passes when its pattern (an empty pattern if unset) is
    found anywhere in the response; a "json" rule passes when the response
    parses as JSON. Failures are written as warnings to stderr; nothing is
    raised or returned, and rules of any other type are ignored.
    """
    import json
    import re

    for rule in prompt.validation_rules:
        kind = rule.type
        if kind == "regex":
            found = re.search(rule.pattern or "", response)
            if not found:
                click.echo("Warning: Response failed regex validation", err=True)
        elif kind == "json":
            try:
                json.loads(response)
            except json.JSONDecodeError:
                click.echo("Warning: Response is not valid JSON", err=True)
|
||||||
81
app/src/promptforge/cli/commands/test.py
Normal file
81
app/src/promptforge/cli/commands/test.py
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
import asyncio
|
||||||
|
from typing import Any, Dict
|
||||||
|
import click
|
||||||
|
|
||||||
|
from promptforge.core.prompt import Prompt
|
||||||
|
from promptforge.core.config import get_config
|
||||||
|
from promptforge.providers import ProviderFactory
|
||||||
|
from promptforge.testing import ABTest, ABTestConfig
|
||||||
|
|
||||||
|
|
||||||
|
@click.command()
@click.argument("prompt_names", nargs=-1, required=True)
@click.option("--provider", "-p", help="Provider to use")
@click.option("--iterations", "-i", default=3, help="Number of test iterations")
@click.option("--output", "-o", type=click.Choice(["text", "json"]), default="text")
@click.option("--parallel", is_flag=True, help="Run iterations in parallel")
@click.pass_obj
def test(ctx, prompt_names: tuple, provider: str, iterations: int, output: str, parallel: bool):
    """Test prompts with A/B testing."""
    available = Prompt.list(ctx["prompts_dir"])

    # Resolve every requested name up front; the first unknown name aborts.
    selected_prompts = []
    for wanted in prompt_names:
        match = next((item for item in available if item.name == wanted), None)
        if not match:
            click.echo(f"Prompt '{wanted}' not found", err=True)
            raise click.Abort()
        selected_prompts.append(match)

    config = get_config()
    selected_provider = provider or config.defaults.provider

    try:
        # dict(...) always produces a plain dict, so .get() is safe here.
        provider_config: Dict[str, Any] = dict(config.providers.get(selected_provider, {}))
        model_name = provider_config.get("model")
        temperature = provider_config.get("temperature", 0.7)
        provider_instance = ProviderFactory.create(
            selected_provider,
            model=model_name,
            temperature=temperature,
        )
    except Exception as err:
        click.echo(f"Provider error: {err}", err=True)
        raise click.Abort()

    ab_test = ABTest(provider_instance, ABTestConfig(iterations=iterations, parallel=parallel))

    async def _compare():
        # Thin coroutine wrapper so the comparison can be driven by asyncio.run().
        return await ab_test.run_comparison(selected_prompts)

    try:
        results = asyncio.run(_compare())
    except Exception as err:
        click.echo(f"Test error: {err}", err=True)
        raise click.Abort()

    # The human-readable summary is always printed, one section per result.
    for label, summary in results.items():
        click.echo(f"\n=== {label} ===")
        click.echo(f"Successful: {summary.successful_runs}/{summary.total_runs}")
        click.echo(f"Avg Latency: {summary.avg_latency_ms:.2f}ms")
        click.echo(f"Avg Tokens: {summary.avg_tokens:.0f}")
        click.echo(f"Avg Cost: ${summary.avg_cost:.4f}")

    # With --output json the same figures are printed again as a JSON document.
    if output == "json":
        import json

        output_data = {}
        for label, summary in results.items():
            output_data[label] = {
                "successful_runs": summary.successful_runs,
                "total_runs": summary.total_runs,
                "avg_latency_ms": summary.avg_latency_ms,
                "avg_tokens": summary.avg_tokens,
                "avg_cost": summary.avg_cost,
            }
        click.echo(json.dumps(output_data, indent=2))
|
||||||
147
app/src/promptforge/cli/commands/version.py
Normal file
147
app/src/promptforge/cli/commands/version.py
Normal file
@@ -0,0 +1,147 @@
|
|||||||
|
import click
|
||||||
|
|
||||||
|
from promptforge.core.git_manager import GitManager
|
||||||
|
|
||||||
|
|
||||||
|
# Command group that the version subcommands below attach to via
# ``@version.command(...)``. The docstring is what click shows as help text.
@click.group()
def version():
    """Manage prompt versions."""
|
||||||
|
|
||||||
|
|
||||||
|
@version.command("history")
@click.argument("prompt_name", required=False)
@click.pass_obj
def history(ctx, prompt_name: str):
    """Show version history."""
    prompts_dir = ctx["prompts_dir"]
    git_manager = GitManager(prompts_dir)

    if not (prompts_dir / ".git").exists():
        click.echo("Git not initialized. Run 'pf init' first.", err=True)
        raise click.Abort()

    try:
        if prompt_name:
            # Derive the file name the same way the prompt delete command
            # does (spaces AND slashes become underscores), so names that
            # contain "/" resolve to the right file.
            filename = f"{prompt_name.lower().replace(' ', '_').replace('/', '_')}.yaml"
            for commit in git_manager.get_file_history(filename):
                # File history entries are indexed like mappings with
                # "sha", "message", "date" and "author" keys.
                click.echo(f"{commit['sha'][:7]} - {commit['message']}")
                click.echo(f" {commit['date']} by {commit['author']}")
        else:
            # The repository-wide log is only fetched when it is shown.
            for commit in git_manager.log():
                # hexsha and message may arrive as bytes; normalise to str.
                hexsha = commit.hexsha
                if isinstance(hexsha, bytes):
                    hexsha = hexsha.decode('utf-8')
                message = commit.message
                if isinstance(message, bytes):
                    message = message.decode('utf-8')
                click.echo(f"{hexsha[:7]} - {message.strip()}")
                click.echo(f" {commit.author.name} - {commit.committed_datetime.isoformat()}")
    except Exception as e:
        click.echo(f"Error: {e}", err=True)
        raise click.Abort() from e
|
||||||
|
|
||||||
|
|
||||||
|
@version.command("create")
@click.argument("message")
@click.option("--author", help="Commit author")
@click.pass_obj
def create(ctx, message: str, author: str):
    """Create a version commit."""
    prompts_dir = ctx["prompts_dir"]
    manager = GitManager(prompts_dir)

    # Initialise through the manager when no .git directory exists yet.
    has_repo = (prompts_dir / ".git").exists()
    if not has_repo:
        manager.init()

    try:
        manager.commit(message, author=author)
        click.echo(f"Created version: {message}")
    except Exception as err:
        click.echo(f"Error: {err}", err=True)
        raise click.Abort()
|
||||||
|
|
||||||
|
|
||||||
|
@version.command("branch")
@click.argument("branch_name")
@click.pass_obj
def branch(ctx, branch_name: str):
    """Create a branch for prompt variations."""
    manager = GitManager(ctx["prompts_dir"])

    # Any failure from the git layer is shown to the user and aborts.
    try:
        manager.create_branch(branch_name)
        click.echo(f"Created branch: {branch_name}")
    except Exception as err:
        click.echo(f"Error: {err}", err=True)
        raise click.Abort()
|
||||||
|
|
||||||
|
|
||||||
|
@version.command("switch")
@click.argument("branch_name")
@click.pass_obj
def switch(ctx, branch_name: str):
    """Switch to a branch."""
    manager = GitManager(ctx["prompts_dir"])

    # Any failure from the git layer is shown to the user and aborts.
    try:
        manager.switch_branch(branch_name)
        click.echo(f"Switched to branch: {branch_name}")
    except Exception as err:
        click.echo(f"Error: {err}", err=True)
        raise click.Abort()
|
||||||
|
|
||||||
|
|
||||||
|
@version.command("list")
@click.pass_obj
def list_branches(ctx):
    """List all branches."""
    manager = GitManager(ctx["prompts_dir"])

    try:
        # Every branch is printed with the same "* " prefix.
        for entry in manager.list_branches():
            click.echo(f"* {entry}")
    except Exception as err:
        click.echo(f"Error: {err}", err=True)
        raise click.Abort()
|
||||||
|
|
||||||
|
|
||||||
|
@version.command("diff")
@click.pass_obj
def diff(ctx):
    """Show uncommitted changes."""
    manager = GitManager(ctx["prompts_dir"])

    try:
        pending = manager.diff()
        # An empty diff is replaced by a short notice.
        click.echo(pending if pending else "No uncommitted changes")
    except Exception as err:
        click.echo(f"Error: {err}", err=True)
        raise click.Abort()
|
||||||
|
|
||||||
|
|
||||||
|
@version.command("status")
@click.pass_obj
def status(ctx):
    """Show git status."""
    manager = GitManager(ctx["prompts_dir"])

    try:
        # Print the status text exactly as the manager returns it.
        click.echo(manager.status())
    except Exception as exc:
        click.echo(f"Error: {exc}", err=True)
        raise click.Abort()
|
||||||
22
app/src/promptforge/cli/main.py
Normal file
22
app/src/promptforge/cli/main.py
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
import click
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from promptforge import __version__
|
||||||
|
|
||||||
|
|
||||||
|
@click.group()
@click.version_option(version=__version__, prog_name="PromptForge")
@click.option(
    "--prompts-dir",
    type=click.Path(exists=False, file_okay=False, dir_okay=True, path_type=Path),
    help="Directory containing prompts",
)
@click.pass_context
def main(ctx: click.Context, prompts_dir: Path):
    """PromptForge - AI Prompt Versioning, Testing & Registry."""
    # Make sure ctx.obj is a dict so the value below can be stored on it.
    ctx.ensure_object(dict)

    # Without --prompts-dir, fall back to "<current working dir>/prompts".
    if not prompts_dir:
        prompts_dir = Path.cwd() / "prompts"
    ctx.obj["prompts_dir"] = prompts_dir


# Allow running this module directly as a script.
if __name__ == "__main__":
    main()
|
||||||
202
app/src/promptforge/core/git_manager.py
Normal file
202
app/src/promptforge/core/git_manager.py
Normal file
@@ -0,0 +1,202 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, List, Optional
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from git import Repo, Commit, GitCommandError
|
||||||
|
|
||||||
|
from .exceptions import GitError
|
||||||
|
|
||||||
|
|
||||||
|
class GitManager:
    """Manage git operations for prompt directories.

    The repository handle is opened lazily: ``self.repo`` stays ``None`` until
    :meth:`init` creates a repository or another method locates an existing
    one through :meth:`_ensure_repo`.
    """

    def __init__(self, prompts_dir: Path):
        """Initialize git manager for a prompts directory.

        Args:
            prompts_dir: Path to the prompts directory.
        """
        self.prompts_dir = Path(prompts_dir)
        self.repo: Optional[Repo] = None

    def init(self) -> bool:
        """Initialize a git repository in the prompts directory.

        Creates the directory when missing, writes a default ``.gitignore``
        (unless one exists) and records it in an initial commit.

        Returns:
            True if repository was created, False if it already exists.

        Raises:
            GitError: If any step of the initialization fails.
        """
        try:
            self.prompts_dir.mkdir(parents=True, exist_ok=True)

            if self._is_git_repo():
                return False

            self.repo = Repo.init(str(self.prompts_dir))
            self._configure_gitignore()
            # Stage the ignore file explicitly; otherwise the initial commit
            # would be created from an empty index and not contain it.
            self.repo.index.add([".gitignore"])
            self.repo.index.commit("Initial commit: PromptForge repository")
            return True
        except Exception as e:
            raise GitError(f"Failed to initialize git repository: {e}") from e

    def _is_git_repo(self) -> bool:
        """Check if prompts_dir is a (non-bare) git repository.

        On success the opened repository is cached in ``self.repo``. A bare
        repository is reported as "not a repo" and is deliberately not
        cached, so later ``_ensure_repo`` calls keep rejecting it.
        """
        try:
            candidate = Repo(str(self.prompts_dir))
        except Exception:
            # Best effort: any failure to open means "not a usable repo".
            return False
        if candidate.bare:
            return False
        self.repo = candidate
        return True

    def _configure_gitignore(self) -> None:
        """Create .gitignore for prompts directory if it does not exist yet."""
        gitignore_path = self.prompts_dir / ".gitignore"
        if not gitignore_path.exists():
            gitignore_path.write_text("*.lock\n.temp*\n")

    @staticmethod
    def _parse_author(author: Optional[str]) -> Any:
        """Turn a ``"Name <email>"`` string into a ``git.Actor``.

        Returns ``None`` for an empty/missing author so that GitPython falls
        back to the configured identity. A string without ``<...>`` is used
        as the name with an empty e-mail address.
        """
        if not author:
            return None
        # Local import keeps this change self-contained in the class.
        from git import Actor

        name, bracket, remainder = author.partition("<")
        email = remainder.partition(">")[0].strip() if bracket else ""
        return Actor(name.strip(), email)

    def commit(self, message: str, author: Optional[str] = None) -> Commit:
        """Commit all changes to prompts.

        Args:
            message: Commit message.
            author: Author string (e.g., "Name <email>").

        Returns:
            Created commit object.

        Raises:
            GitError: If commit fails.
        """
        self._ensure_repo()
        assert self.repo is not None
        try:
            self.repo.index.add(["*"])
            # index.commit() expects an Actor, not a plain string.
            return self.repo.index.commit(
                message, author=self._parse_author(author)
            )
        except GitCommandError as e:
            raise GitError(f"Failed to commit changes: {e}") from e

    def log(self, max_count: int = 20) -> List[Commit]:
        """Get commit history.

        Args:
            max_count: Maximum number of commits to return.

        Returns:
            List of commit objects.

        Raises:
            GitError: If the history cannot be read.
        """
        self._ensure_repo()
        assert self.repo is not None
        try:
            return list(self.repo.iter_commits(max_count=max_count))
        except GitCommandError as e:
            raise GitError(f"Failed to get commit log: {e}") from e

    def show_commit(self, commit_sha: str) -> str:
        """Show content of a specific commit.

        Args:
            commit_sha: SHA of the commit.

        Returns:
            Commit diff as string (empty string when the diff is empty).

        Raises:
            GitError: If the commit or its parent cannot be resolved.
        """
        self._ensure_repo()
        assert self.repo is not None
        try:
            commit = self.repo.commit(commit_sha)
            # Diff against the first parent of the requested commit.
            parent_ref = "HEAD~1" if commit_sha == "HEAD" else f"{commit_sha}^"
            diff_result = commit.diff(parent_ref)
            return str(diff_result) if diff_result else ""
        except Exception as e:
            raise GitError(f"Failed to show commit: {e}") from e

    def create_branch(self, branch_name: str) -> None:
        """Create a new branch for prompt variations.

        Args:
            branch_name: Name of the new branch.

        Raises:
            GitError: If branch creation fails.
        """
        self._ensure_repo()
        assert self.repo is not None
        try:
            self.repo.create_head(branch_name)
        except (GitCommandError, OSError) as e:
            # NOTE(review): GitPython appears to raise OSError when the ref
            # already exists; it is mapped to GitError as documented.
            raise GitError(f"Failed to create branch: {e}") from e

    def switch_branch(self, branch_name: str) -> None:
        """Switch to a different branch.

        Args:
            branch_name: Name of the branch to switch to.

        Raises:
            GitError: If branch switch fails (including an unknown branch).
        """
        self._ensure_repo()
        assert self.repo is not None
        try:
            self.repo.heads[branch_name].checkout()
        except (GitCommandError, IndexError) as e:
            # Looking up a missing name in ``repo.heads`` raises IndexError.
            raise GitError(f"Failed to switch branch: {e}") from e

    def list_branches(self) -> List[str]:
        """List all branches.

        Returns:
            List of branch names.
        """
        self._ensure_repo()
        assert self.repo is not None
        return [head.name for head in self.repo.heads]

    def status(self) -> str:
        """Get git status.

        Returns:
            Status string.
        """
        self._ensure_repo()
        assert self.repo is not None
        return self.repo.git.status()

    def diff(self) -> str:
        """Show uncommitted changes.

        Returns:
            Diff as string.
        """
        self._ensure_repo()
        assert self.repo is not None
        return self.repo.git.diff()

    def get_file_history(self, filename: str) -> List[dict]:
        """Get commit history for a specific file.

        Args:
            filename: Name of the file.

        Returns:
            List of commit info dictionaries with the keys ``sha``,
            ``author``, ``date`` (ISO string) and ``message``. If git fails
            part-way, the entries collected so far are returned.
        """
        self._ensure_repo()
        assert self.repo is not None
        commits = []
        try:
            for commit in self.repo.iter_commits("--all", filename):
                commits.append({
                    "sha": commit.hexsha,
                    "author": str(commit.author),
                    "date": datetime.fromtimestamp(commit.authored_date).isoformat(),
                    "message": commit.message,
                })
        except GitCommandError:
            # Deliberate best effort: history lookup never raises for git
            # command failures, it just returns what was gathered.
            pass
        return commits

    def _ensure_repo(self) -> None:
        """Ensure repository is initialized.

        Raises:
            GitError: If no repository is cached and none can be opened.
        """
        if self.repo is None and not self._is_git_repo():
            raise GitError("Git repository not initialized. Run 'pf init' first.")
|
||||||
161
app/src/promptforge/core/prompt.py
Normal file
161
app/src/promptforge/core/prompt.py
Normal file
@@ -0,0 +1,161 @@
|
|||||||
|
import hashlib
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
|
|
||||||
|
class VariableType(str, Enum):
    """Supported variable types."""

    # Members are also ``str`` instances, so each compares equal to its value.
    STRING = "string"
    INTEGER = "integer"
    FLOAT = "float"
    BOOLEAN = "boolean"
    # The only type for which PromptVariable accepts a ``choices`` list.
    CHOICE = "choice"
|
||||||
|
|
||||||
|
|
||||||
|
class PromptVariable(BaseModel):
    """Definition of a template variable."""

    # Identifier of the variable.
    name: str
    # Declared kind of value; plain string unless stated otherwise.
    type: VariableType = VariableType.STRING
    description: Optional[str] = None
    required: bool = True
    default: Optional[Any] = None
    # Allowed values; only accepted when ``type`` is CHOICE (see validator).
    choices: Optional[List[str]] = None

    @field_validator('choices')
    @classmethod
    def validate_choices(cls, v, info):
        # Reject a choices list on any variable that is not a CHOICE.
        if v is None:
            return v
        if info.data.get('type') == VariableType.CHOICE:
            return v
        raise ValueError("choices only valid for CHOICE type")
|
||||||
|
|
||||||
|
|
||||||
|
class ValidationRule(BaseModel):
    """Validation rule for prompt output."""

    # Kind of rule; the accepted values are not defined in this model.
    type: str
    # Optional pattern for the rule (presumably a regex - confirm with the
    # code that evaluates rules).
    pattern: Optional[str] = None
    # Optional JSON schema, stored as a plain dictionary.
    json_schema: Optional[Dict[str, Any]] = None
    # Optional human-readable message attached to the rule.
    message: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class Prompt(BaseModel):
    """Prompt model with metadata and template."""

    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    name: str
    description: Optional[str] = None
    # Template text of the prompt (the part after the YAML front matter).
    content: str
    variables: List[PromptVariable] = Field(default_factory=list)
    validation_rules: List[ValidationRule] = Field(default_factory=list)
    provider: Optional[str] = None
    tags: List[str] = Field(default_factory=list)
    version: str = "1.0.0"
    # NOTE(review): utcnow() yields naive datetimes and is deprecated since
    # Python 3.12; kept so the serialized timestamp format does not change.
    created_at: datetime = Field(default_factory=datetime.utcnow)
    updated_at: datetime = Field(default_factory=datetime.utcnow)
    # validate_default=True is required: pydantic skips validators for
    # defaulted fields, so without it compute_hash would never run and the
    # hash would stay empty unless passed explicitly.
    hash: str = Field(default="", validate_default=True)

    @field_validator('hash', mode='before')
    @classmethod
    def compute_hash(cls, v, info):
        # Derive the hash from the content when none was supplied. MD5 is
        # used as a content fingerprint here, not for security.
        if not v:
            content = info.data.get('content', '')
            return hashlib.md5(content.encode()).hexdigest()
        return v

    def to_dict(self, exclude_none: bool = False) -> Dict[str, Any]:
        """Export prompt to dictionary.

        Args:
            exclude_none: Drop fields whose value is ``None``.

        Returns:
            Field dictionary with ``created_at``/``updated_at`` as ISO strings.
        """
        data = self.model_dump(exclude_none=exclude_none)
        data['created_at'] = self.created_at.isoformat()
        data['updated_at'] = self.updated_at.isoformat()
        return data

    @classmethod
    def from_yaml(cls, yaml_content: str) -> "Prompt":
        """Parse prompt from YAML with front matter.

        Text that does not start with ``---`` is treated as content only.
        Otherwise everything up to the next line starting with ``---`` is
        parsed as YAML metadata and the rest becomes the content.

        Raises:
            ValueError: If the front matter is not a YAML mapping.
        """
        text = yaml_content.strip()

        metadata: Dict[str, Any] = {}
        prompt_content = text
        if text.startswith('---'):
            # Partitioning after the opening marker also handles an empty
            # front matter block ("---\n---\nbody").
            header, _, body = text[3:].partition('\n---')
            loaded = yaml.safe_load(header) or {}
            if not isinstance(loaded, dict):
                raise ValueError("Prompt front matter must be a YAML mapping")
            metadata = loaded
            prompt_content = body.strip()

        # ``or`` fallbacks tolerate keys that are present but null in YAML.
        data = {
            'name': metadata.get('name') or 'Untitled',
            'description': metadata.get('description'),
            'content': prompt_content,
            'variables': [
                PromptVariable(**v) for v in metadata.get('variables') or []
            ],
            'validation_rules': [
                ValidationRule(**r) for r in metadata.get('validation') or []
            ],
            'provider': metadata.get('provider'),
            'tags': metadata.get('tags') or [],
            # YAML may parse e.g. ``version: 1.0`` as a float.
            'version': str(metadata.get('version') or '1.0.0'),
        }
        return cls(**data)

    def to_yaml(self) -> str:
        """Export prompt to YAML front matter format."""
        def var_to_dict(v):
            d = v.model_dump()
            # Store the plain string so safe_load can read the file back.
            d['type'] = v.type.value
            return d

        metadata = {
            'name': self.name,
            'description': self.description,
            'provider': self.provider,
            'tags': self.tags,
            'version': self.version,
            'variables': [var_to_dict(v) for v in self.variables],
            'validation': [r.model_dump() for r in self.validation_rules],
        }
        yaml_str = yaml.dump(metadata, default_flow_style=False, allow_unicode=True)
        return f"---\n{yaml_str}---\n{self.content}"

    def save(self, prompts_dir: Path) -> Path:
        """Save prompt to file.

        The file name is derived from the prompt name (lower-cased, spaces
        and slashes replaced by underscores) with a ``.yaml`` suffix.

        Args:
            prompts_dir: Directory to save prompt in.

        Returns:
            Path to saved file.
        """
        prompts_dir.mkdir(parents=True, exist_ok=True)
        filename = self.name.lower().replace(' ', '_').replace('/', '_') + '.yaml'
        filepath = prompts_dir / filename
        # to_yaml() emits raw Unicode (allow_unicode=True), so the encoding
        # must not depend on the platform default.
        with open(filepath, 'w', encoding='utf-8') as f:
            f.write(self.to_yaml())
        return filepath

    @classmethod
    def load(cls, filepath: Path) -> "Prompt":
        """Load prompt from a UTF-8 encoded file."""
        with open(filepath, 'r', encoding='utf-8') as f:
            content = f.read()
        return cls.from_yaml(content)

    @classmethod
    def list(cls, prompts_dir: Path) -> List["Prompt"]:
        """List all prompts in directory, sorted by name.

        Only ``*.yaml`` files directly inside ``prompts_dir`` are read. Files
        that fail to load are skipped; a missing directory yields ``[]``.
        """
        prompts: List["Prompt"] = []
        if not prompts_dir.exists():
            return prompts
        for filepath in prompts_dir.glob('*.yaml'):
            try:
                prompts.append(cls.load(filepath))
            except Exception:
                # Deliberate best effort: one broken file must not hide the rest.
                continue
        return sorted(prompts, key=lambda p: p.name)
|
||||||
150
app/src/promptforge/providers/anthropic.py
Normal file
150
app/src/promptforge/providers/anthropic.py
Normal file
@@ -0,0 +1,150 @@
|
|||||||
|
import time
|
||||||
|
from typing import AsyncIterator, Optional
|
||||||
|
|
||||||
|
from anthropic import Anthropic, APIError, RateLimitError
|
||||||
|
|
||||||
|
from .base import ProviderBase, ProviderResponse
|
||||||
|
from ..core.exceptions import ProviderError
|
||||||
|
|
||||||
|
|
||||||
|
class AnthropicProvider(ProviderBase):
    """Anthropic Claude models provider.

    NOTE(review): the async methods drive the synchronous ``Anthropic``
    client, so a request blocks the event loop while it is in flight.
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        model: str = "claude-3-sonnet-20240229",
        temperature: float = 0.7,
        **kwargs,
    ):
        """Initialize Anthropic provider.

        Args:
            api_key: API key; when omitted, ANTHROPIC_API_KEY is read lazily.
            model: Model identifier to use.
            temperature: Sampling temperature.
            **kwargs: Extra options stored by the base class.
        """
        super().__init__(api_key, model, temperature, **kwargs)
        self._client: Optional[Anthropic] = None

    @property
    def name(self) -> str:
        return "anthropic"

    def _get_client(self) -> Anthropic:
        """Get or create Anthropic client.

        Raises:
            ProviderError: If no API key is configured.
        """
        if self._client is None:
            api_key = self.api_key or self._get_api_key_from_env()
            if not api_key:
                raise ProviderError(
                    "Anthropic API key not configured. "
                    "Set ANTHROPIC_API_KEY env var or pass api_key parameter."
                )
            self._client = Anthropic(api_key=api_key)
        return self._client

    def _get_api_key_from_env(self) -> Optional[str]:
        """Read the API key from the ANTHROPIC_API_KEY environment variable."""
        import os
        return os.environ.get("ANTHROPIC_API_KEY")

    def _build_request(
        self,
        prompt: str,
        system_prompt: Optional[str],
        max_tokens: Optional[int],
    ) -> dict:
        """Assemble the keyword arguments shared by both request styles."""
        request = {
            "model": self.model,
            "max_tokens": max_tokens or 4096,
            "temperature": self.temperature,
            "messages": [{"role": "user", "content": prompt}],
        }
        # Only send ``system`` when there is one; an explicit None is not a
        # valid value for that parameter.
        if system_prompt:
            request["system"] = system_prompt
        return request

    async def complete(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        max_tokens: Optional[int] = None,
        **kwargs,
    ) -> ProviderResponse:
        """Send completion request to Anthropic.

        Args:
            prompt: User prompt to send.
            system_prompt: Optional system instructions.
            max_tokens: Maximum tokens in response (4096 when omitted).
            **kwargs: Forwarded to ``client.messages.create``.

        Returns:
            ProviderResponse with the concatenated text blocks, token usage,
            latency in milliseconds and the stop reason.

        Raises:
            ProviderError: On rate limiting or any other Anthropic API error.
        """
        start_time = time.time()

        try:
            client = self._get_client()
            request = self._build_request(prompt, system_prompt, max_tokens)
            response = client.messages.create(**request, **kwargs)  # type: ignore[arg-type]

            latency_ms = (time.time() - start_time) * 1000

            # Keep only the text blocks of the reply.
            content = "".join(
                block.text for block in response.content if block.type == "text"
            )

            return ProviderResponse(
                content=content,
                model=self.model,
                provider=self.name,
                usage={
                    "input_tokens": response.usage.input_tokens,
                    "output_tokens": response.usage.output_tokens,
                },
                latency_ms=latency_ms,
                metadata={
                    "stop_reason": response.stop_reason,
                },
            )
        # RateLimitError is a subclass of APIError, so it must be handled
        # first or its dedicated message is never produced.
        except RateLimitError as e:
            raise ProviderError(f"Anthropic rate limit exceeded: {e}") from e
        except APIError as e:
            raise ProviderError(f"Anthropic API error: {e}") from e

    async def stream_complete(  # type: ignore[override]
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        max_tokens: Optional[int] = None,
        **kwargs,
    ) -> AsyncIterator[str]:
        """Stream completion from Anthropic.

        Yields:
            Text chunks as they arrive.

        Raises:
            ProviderError: On any Anthropic API error.
        """
        try:
            client = self._get_client()
            request = self._build_request(prompt, system_prompt, max_tokens)
            with client.messages.stream(**request, **kwargs) as stream:  # type: ignore[arg-type]
                for text in stream.text_stream:
                    yield text
        except APIError as e:
            raise ProviderError(f"Anthropic API error: {e}") from e

    def validate_api_key(self) -> bool:
        """Validate Anthropic API key.

        Only checks that a key is available and that a client object can be
        constructed with it; no request is sent.
        """
        try:
            api_key = self.api_key or self._get_api_key_from_env()
            if not api_key:
                return False
            Anthropic(api_key=api_key)
            return True
        except Exception:
            return False

    def list_models(self) -> list[str]:
        """List available Anthropic models (static, hard-coded list)."""
        return [
            "claude-3-opus-20240229",
            "claude-3-sonnet-20240229",
            "claude-3-haiku-20240307",
            "claude-2.1",
            "claude-2.0",
            "claude-instant-1.2",
        ]
|
||||||
104
app/src/promptforge/providers/base.py
Normal file
104
app/src/promptforge/providers/base.py
Normal file
@@ -0,0 +1,104 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any, AsyncIterator, Dict, List, Optional
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ProviderResponse:
    """Result of a single provider call."""

    # Generated text.
    content: str
    # Model identifier the request was made with.
    model: str
    # Name of the provider that produced the response.
    provider: str
    # Token counters; the key names differ between providers.
    usage: Dict[str, int] = field(default_factory=dict)
    # Wall-clock duration of the request in milliseconds.
    latency_ms: float = 0.0
    # Provider-specific extras (e.g. a stop reason or a "done" flag).
    metadata: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
class ProviderBase(ABC):
    """Abstract base class for LLM providers."""

    def __init__(
        self,
        api_key: Optional[str] = None,
        model: str = "gpt-4",
        temperature: float = 0.7,
        **kwargs,
    ):
        """Initialize provider.

        Args:
            api_key: API key for authentication.
            model: Model identifier to use.
            temperature: Sampling temperature (stored as given, not validated).
            **kwargs: Additional provider-specific options, kept in
                ``self.kwargs``.
        """
        self.api_key = api_key
        self.model = model
        self.temperature = temperature
        self.kwargs = kwargs

    @property
    @abstractmethod
    def name(self) -> str:
        """Provider name identifier."""

    @abstractmethod
    async def complete(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        max_tokens: Optional[int] = None,
        **kwargs,
    ) -> "ProviderResponse":
        """Send a completion request.

        Args:
            prompt: User prompt to send.
            system_prompt: Optional system instructions.
            max_tokens: Maximum tokens in response.
            **kwargs: Additional provider-specific parameters.

        Returns:
            ProviderResponse with the generated content.
        """

    @abstractmethod
    async def stream_complete(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        max_tokens: Optional[int] = None,
        **kwargs,
    ) -> AsyncIterator[str]:
        """Stream completions incrementally.

        Args:
            prompt: User prompt to send.
            system_prompt: Optional system instructions.
            max_tokens: Maximum tokens in response.
            **kwargs: Additional provider-specific parameters.

        Yields:
            Chunks of generated content.
        """
        # Implementations are async generators. The unreachable ``yield``
        # makes this declaration one too, so overrides match its signature
        # instead of overriding a coroutine.
        raise NotImplementedError
        yield ""  # pragma: no cover

    @abstractmethod
    def validate_api_key(self) -> bool:
        """Validate that the API key is configured correctly."""

    @abstractmethod
    def list_models(self) -> List[str]:
        """List available models for this provider."""

    def _get_system_prompt(self, prompt: str) -> Optional[str]:
        """Extract system prompt from prompt if using special syntax.

        Returns the stripped text before the first ``---`` in ``prompt``, or
        ``None`` when the prompt contains no ``---`` separator.
        """
        head, separator, _ = prompt.partition("---")
        return head.strip() if separator else None
|
||||||
74
app/src/promptforge/providers/factory.py
Normal file
74
app/src/promptforge/providers/factory.py
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
from typing import Dict, Optional
|
||||||
|
|
||||||
|
from .base import ProviderBase
|
||||||
|
from .openai import OpenAIProvider
|
||||||
|
from .anthropic import AnthropicProvider
|
||||||
|
from .ollama import OllamaProvider
|
||||||
|
from ..core.exceptions import ProviderError
|
||||||
|
|
||||||
|
|
||||||
|
class ProviderFactory:
    """Factory for creating LLM provider instances."""

    # Registry of provider classes keyed by lower-case provider name.
    _providers: Dict[str, type] = {
        "openai": OpenAIProvider,
        "anthropic": AnthropicProvider,
        "ollama": OllamaProvider,
    }

    @classmethod
    def register(cls, name: str, provider_class: type) -> None:
        """Register a new provider.

        Args:
            name: Provider identifier (stored lower-cased).
            provider_class: Provider class to register.

        Raises:
            TypeError: If ``provider_class`` is not a ProviderBase subclass.
        """
        is_provider = isinstance(provider_class, type) and issubclass(
            provider_class, ProviderBase
        )
        if not is_provider:
            raise TypeError("Provider must be a subclass of ProviderBase")
        cls._providers[name.lower()] = provider_class

    @classmethod
    def create(
        cls,
        provider_name: str,
        api_key: Optional[str] = None,
        model: Optional[str] = None,
        temperature: float = 0.7,
        **kwargs,
    ) -> ProviderBase:
        """Create a provider instance.

        Args:
            provider_name: Name of the provider to create (case-insensitive).
            api_key: API key for the provider.
            model: Model to use. When omitted, the provider class's
                ``_default_model`` attribute is used if it exists; otherwise
                the provider's own constructor default applies.
            temperature: Sampling temperature.
            **kwargs: Additional provider-specific options.

        Returns:
            Provider instance.

        Raises:
            ProviderError: If no provider is registered under that name.
        """
        provider_class = cls._providers.get(provider_name.lower())
        if provider_class is None:
            available = ", ".join(cls._providers.keys())
            raise ProviderError(
                f"Unknown provider: {provider_name}. Available: {available}"
            )

        init_kwargs = {"api_key": api_key, "temperature": temperature, **kwargs}
        # Do not force "gpt-4" onto every provider: pass ``model`` only when
        # one is known, so each constructor keeps its own default.
        chosen_model = model or getattr(provider_class, "_default_model", None)
        if chosen_model:
            init_kwargs["model"] = chosen_model
        return provider_class(**init_kwargs)

    @classmethod
    def list_providers(cls) -> list[str]:
        """List available provider names."""
        return list(cls._providers.keys())

    @classmethod
    def get_provider_class(cls, name: str) -> Optional[type]:
        """Get provider class by name, or ``None`` if it is not registered."""
        return cls._providers.get(name.lower())
|
||||||
177
app/src/promptforge/providers/ollama.py
Normal file
177
app/src/promptforge/providers/ollama.py
Normal file
@@ -0,0 +1,177 @@
|
|||||||
|
import json
|
||||||
|
import time
|
||||||
|
from typing import Any, AsyncIterator, Dict, Optional
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from .base import ProviderBase, ProviderResponse
|
||||||
|
from ..core.exceptions import ProviderError
|
||||||
|
|
||||||
|
|
||||||
|
class OllamaProvider(ProviderBase):
|
||||||
|
"""Ollama local model provider."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
model: str = "llama2",
|
||||||
|
temperature: float = 0.7,
|
||||||
|
base_url: str = "http://localhost:11434",
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""Initialize Ollama provider."""
|
||||||
|
super().__init__(api_key, model, temperature, **kwargs)
|
||||||
|
self.base_url = base_url.rstrip('/')
|
||||||
|
self._client: Optional[httpx.AsyncClient] = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return "ollama"
|
||||||
|
|
||||||
|
def _get_client(self) -> httpx.AsyncClient:
|
||||||
|
"""Get or create HTTP client."""
|
||||||
|
if self._client is None:
|
||||||
|
self._client = httpx.AsyncClient(timeout=120.0)
|
||||||
|
return self._client
|
||||||
|
|
||||||
|
def _get_api_url(self, endpoint: str) -> str:
|
||||||
|
"""Get full URL for an endpoint."""
|
||||||
|
return f"{self.base_url}/{endpoint.lstrip('/')}"
|
||||||
|
|
||||||
|
async def complete(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
system_prompt: Optional[str] = None,
|
||||||
|
max_tokens: Optional[int] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> ProviderResponse:
|
||||||
|
"""Send completion request to Ollama."""
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
try:
|
||||||
|
client = self._get_client()
|
||||||
|
|
||||||
|
messages = []
|
||||||
|
if system_prompt:
|
||||||
|
messages.append({"role": "system", "content": system_prompt})
|
||||||
|
messages.append({"role": "user", "content": prompt})
|
||||||
|
|
||||||
|
payload: Dict[str, Any] = {
|
||||||
|
"model": self.model,
|
||||||
|
"messages": messages,
|
||||||
|
"stream": False,
|
||||||
|
"options": {
|
||||||
|
"temperature": self.temperature,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if max_tokens:
|
||||||
|
payload["options"]["num_predict"] = max_tokens
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
self._get_api_url("/api/chat"),
|
||||||
|
json=payload,
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
latency_ms = (time.time() - start_time) * 1000
|
||||||
|
|
||||||
|
content = ""
|
||||||
|
for msg in data.get("message", {}).get("content", ""):
|
||||||
|
if isinstance(msg, str):
|
||||||
|
content += msg
|
||||||
|
elif isinstance(msg, dict):
|
||||||
|
content += msg.get("content", "")
|
||||||
|
|
||||||
|
return ProviderResponse(
|
||||||
|
content=content,
|
||||||
|
model=self.model,
|
||||||
|
provider=self.name,
|
||||||
|
usage={
|
||||||
|
"prompt_tokens": data.get("prompt_eval_count", 0),
|
||||||
|
"completion_tokens": data.get("eval_count", 0),
|
||||||
|
"total_tokens": data.get("prompt_eval_count", 0) + data.get("eval_count", 0),
|
||||||
|
},
|
||||||
|
latency_ms=latency_ms,
|
||||||
|
metadata={
|
||||||
|
"done": data.get("done", False),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
except httpx.HTTPError as e:
|
||||||
|
raise ProviderError(f"Ollama connection error: {e}")
|
||||||
|
|
||||||
|
async def stream_complete( # type: ignore[override]
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
system_prompt: Optional[str] = None,
|
||||||
|
max_tokens: Optional[int] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> AsyncIterator[str]:
|
||||||
|
"""Stream completion from Ollama."""
|
||||||
|
try:
|
||||||
|
client = self._get_client()
|
||||||
|
|
||||||
|
messages = []
|
||||||
|
if system_prompt:
|
||||||
|
messages.append({"role": "system", "content": system_prompt})
|
||||||
|
messages.append({"role": "user", "content": prompt})
|
||||||
|
|
||||||
|
payload: Dict[str, Any] = {
|
||||||
|
"model": self.model,
|
||||||
|
"messages": messages,
|
||||||
|
"stream": True,
|
||||||
|
"options": {
|
||||||
|
"temperature": self.temperature,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if max_tokens:
|
||||||
|
payload["options"]["num_predict"] = max_tokens
|
||||||
|
|
||||||
|
async with client.stream(
|
||||||
|
"POST",
|
||||||
|
self._get_api_url("/api/chat"),
|
||||||
|
json=payload,
|
||||||
|
) as response:
|
||||||
|
async for line in response.aiter_lines():
|
||||||
|
if line:
|
||||||
|
data = json.loads(line)
|
||||||
|
if "message" in data:
|
||||||
|
content = data["message"].get("content", "")
|
||||||
|
if content:
|
||||||
|
yield content
|
||||||
|
except httpx.HTTPError as e:
|
||||||
|
raise ProviderError(f"Ollama connection error: {e}")
|
||||||
|
|
||||||
|
async def pull_model(self, model: Optional[str] = None) -> bool:
|
||||||
|
"""Pull a model from Ollama registry."""
|
||||||
|
try:
|
||||||
|
client = self._get_client()
|
||||||
|
target_model = model or self.model
|
||||||
|
|
||||||
|
async with client.stream(
|
||||||
|
"POST",
|
||||||
|
self._get_api_url("/api/pull"),
|
||||||
|
json={"name": target_model, "stream": False},
|
||||||
|
) as response:
|
||||||
|
response.raise_for_status()
|
||||||
|
return True
|
||||||
|
except httpx.HTTPError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def validate_api_key(self) -> bool:
    """Report the API key as valid.

    This provider performs no key check, so the result is
    unconditionally True.
    """
    key_is_valid = True
    return key_is_valid
|
||||||
|
|
||||||
|
def list_models(self) -> list[str]:
    """Return the built-in list of Ollama model names.

    The names are hard-coded here; the server is not queried. A new
    list is returned on every call.
    """
    known_models = (
        "llama2",
        "llama2-uncensored",
        "mistral",
        "mixtral",
        "codellama",
        "deepseek-coder",
        "neural-chat",
    )
    return list(known_models)
|
||||||
149
app/src/promptforge/providers/openai.py
Normal file
149
app/src/promptforge/providers/openai.py
Normal file
@@ -0,0 +1,149 @@
|
|||||||
|
import time
|
||||||
|
from typing import AsyncIterator, Optional
|
||||||
|
|
||||||
|
from openai import AsyncOpenAI, APIError, RateLimitError, APIConnectionError
|
||||||
|
|
||||||
|
from .base import ProviderBase, ProviderResponse
|
||||||
|
from ..core.exceptions import ProviderError
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIProvider(ProviderBase):
    """OpenAI GPT models provider."""

    def __init__(
        self,
        api_key: Optional[str] = None,
        model: str = "gpt-4",
        temperature: float = 0.7,
        base_url: Optional[str] = None,
        **kwargs,
    ):
        """Initialize OpenAI provider.

        Args:
            api_key: API key; when omitted, ``OPENAI_API_KEY`` is read
                from the environment on first client use.
            model: Model name sent with every request.
            temperature: Sampling temperature sent with every request.
            base_url: Optional alternative API base URL.
            **kwargs: Forwarded to the base provider initializer.
        """
        super().__init__(api_key, model, temperature, **kwargs)
        self.base_url = base_url
        self._client: Optional[AsyncOpenAI] = None

    @property
    def name(self) -> str:
        """Provider identifier."""
        return "openai"

    def _get_client(self) -> AsyncOpenAI:
        """Get or create OpenAI client.

        Raises:
            ProviderError: If no API key is configured.
        """
        if self._client is None:
            api_key = self.api_key or self._get_api_key_from_env()
            if not api_key:
                raise ProviderError(
                    "OpenAI API key not configured. "
                    "Set OPENAI_API_KEY env var or pass api_key parameter."
                )
            self._client = AsyncOpenAI(
                api_key=api_key,
                base_url=self.base_url,
            )
        return self._client

    def _get_api_key_from_env(self) -> Optional[str]:
        """Return ``OPENAI_API_KEY`` from the environment, if set."""
        import os

        return os.environ.get("OPENAI_API_KEY")

    @staticmethod
    def _build_messages(prompt: str, system_prompt: Optional[str]) -> list:
        """Build the chat message list: optional system message, then user."""
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})
        return messages

    @staticmethod
    def _to_provider_error(error: APIError) -> ProviderError:
        """Map an OpenAI SDK error to a ProviderError.

        The specific subclasses are tested before the generic case. They
        derive from ``APIError``, so a leading ``except APIError`` clause
        would make their handlers unreachable.
        """
        if isinstance(error, RateLimitError):
            return ProviderError(f"OpenAI rate limit exceeded: {error}")
        if isinstance(error, APIConnectionError):
            return ProviderError(f"OpenAI connection error: {error}")
        return ProviderError(f"OpenAI API error: {error}")

    async def complete(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        max_tokens: Optional[int] = None,
        **kwargs,
    ) -> ProviderResponse:
        """Send completion request to OpenAI.

        Args:
            prompt: User message text.
            system_prompt: Optional system message.
            max_tokens: Optional completion token limit.
            **kwargs: Extra arguments forwarded to the chat completions call.

        Returns:
            ProviderResponse with content, token usage, latency and
            finish reason.

        Raises:
            ProviderError: On missing API key, API errors, rate limiting,
                connection failures, or a response without choices.
        """
        start_time = time.time()

        try:
            client = self._get_client()
            response = await client.chat.completions.create(  # type: ignore[arg-type]
                model=self.model,
                messages=self._build_messages(prompt, system_prompt),  # type: ignore[union-attr]
                temperature=self.temperature,
                max_tokens=max_tokens,
                **kwargs,
            )
        except APIError as e:
            raise self._to_provider_error(e) from e

        latency_ms = (time.time() - start_time) * 1000

        if not response.choices:
            raise ProviderError("OpenAI API error: response contained no choices")
        choice = response.choices[0]

        # usage is optional on the response object; report zeros rather
        # than failing with AttributeError when it is absent.
        usage = response.usage
        usage_counts = {
            "prompt_tokens": usage.prompt_tokens if usage else 0,
            "completion_tokens": usage.completion_tokens if usage else 0,
            "total_tokens": usage.total_tokens if usage else 0,
        }

        return ProviderResponse(
            content=choice.message.content or "",
            model=self.model,
            provider=self.name,
            usage=usage_counts,
            latency_ms=latency_ms,
            metadata={
                "finish_reason": choice.finish_reason,
            },
        )

    async def stream_complete(  # type: ignore[override]
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        max_tokens: Optional[int] = None,
        **kwargs,
    ) -> AsyncIterator[str]:
        """Stream completion from OpenAI.

        Args:
            prompt: User message text.
            system_prompt: Optional system message.
            max_tokens: Optional completion token limit.
            **kwargs: Extra arguments forwarded to the chat completions call.

        Yields:
            Non-empty content deltas in arrival order.

        Raises:
            ProviderError: On missing API key or any OpenAI API error.
        """
        try:
            client = self._get_client()
            stream = await client.chat.completions.create(  # type: ignore[arg-type]
                model=self.model,
                messages=self._build_messages(prompt, system_prompt),  # type: ignore[union-attr]
                temperature=self.temperature,
                max_tokens=max_tokens,
                stream=True,
                **kwargs,
            )

            async for chunk in stream:  # type: ignore[union-attr]
                # A chunk may carry no choices; skip it instead of
                # raising IndexError.
                if not chunk.choices:
                    continue
                delta_text = chunk.choices[0].delta.content
                if delta_text:
                    yield delta_text
        except APIError as e:
            raise self._to_provider_error(e) from e

    def validate_api_key(self) -> bool:
        """Validate OpenAI API key.

        Only checks that a key is available and that a client can be
        constructed with it; no request is sent.
        """
        try:
            api_key = self.api_key or self._get_api_key_from_env()
            if not api_key:
                return False
            _ = AsyncOpenAI(api_key=api_key)
            return True
        except Exception:
            return False

    def list_models(self) -> list[str]:
        """List available OpenAI models (hard-coded; the API is not queried)."""
        return [
            "gpt-4",
            "gpt-4-turbo",
            "gpt-4o",
            "gpt-3.5-turbo",
            "gpt-3.5-turbo-16k",
        ]
|
||||||
153
app/src/promptforge/registry/local.py
Normal file
153
app/src/promptforge/registry/local.py
Normal file
@@ -0,0 +1,153 @@
|
|||||||
|
import json
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
from .models import RegistryEntry, RegistrySearchResult
|
||||||
|
from ..core.exceptions import RegistryError
|
||||||
|
|
||||||
|
|
||||||
|
class LocalRegistry:
    """Local prompt registry stored as JSON files."""

    def __init__(self, registry_path: Optional[str] = None):
        """Initialize local registry.

        Creates the registry directory if it does not exist.

        Args:
            registry_path: Path to registry directory. Defaults to ~/.promptforge/registry
        """
        self.registry_path = Path(registry_path or self._default_path())
        self.registry_path.mkdir(parents=True, exist_ok=True)
        self._index_file = self.registry_path / "index.json"

    def _default_path(self) -> str:
        """Return the default registry directory under the user's home."""
        import os

        return os.path.expanduser("~/.promptforge/registry")

    def _load_index(self) -> Dict[str, RegistryEntry]:
        """Load registry index.

        Returns:
            Mapping of entry ID to entry; empty when no index file exists.

        Raises:
            RegistryError: If the index file cannot be read or parsed.
        """
        if not self._index_file.exists():
            return {}
        try:
            with open(self._index_file, "r", encoding="utf-8") as f:
                data = json.load(f)
            return {k: RegistryEntry(**v) for k, v in data.items()}
        except Exception as e:
            raise RegistryError(f"Failed to load registry index: {e}") from e

    def _save_index(self, index: Dict[str, RegistryEntry]) -> None:
        """Save registry index.

        Raises:
            RegistryError: If the index cannot be serialized or written.
        """
        try:
            # mode="json" converts datetime fields to strings; the default
            # dump keeps datetime objects, which json cannot encode.
            data = {k: v.model_dump(mode="json") for k, v in index.items()}
            serialized = json.dumps(data, indent=2)

            # Serialize first and swap the file in afterwards, so a failure
            # never leaves a truncated index behind.
            tmp_file = self._index_file.with_name(self._index_file.name + ".tmp")
            tmp_file.write_text(serialized, encoding="utf-8")
            tmp_file.replace(self._index_file)
        except Exception as e:
            raise RegistryError(f"Failed to save registry index: {e}") from e

    def add(self, entry: RegistryEntry) -> None:
        """Add an entry to the registry.

        Assigns an ID when the entry has none, sets ``added_at`` if unset,
        refreshes ``updated_at``, and replaces any existing entry with the
        same ID. The passed entry object is modified in place.

        Args:
            entry: Registry entry to add.
        """
        index = self._load_index()
        entry.id = str(entry.id or uuid.uuid4())
        entry.added_at = entry.added_at or datetime.utcnow()
        entry.updated_at = datetime.utcnow()
        index[entry.id] = entry
        self._save_index(index)

    def remove(self, entry_id: str) -> bool:
        """Remove an entry from the registry.

        Args:
            entry_id: ID of entry to remove.

        Returns:
            True if entry was removed, False if not found.
        """
        index = self._load_index()
        if entry_id not in index:
            return False
        del index[entry_id]
        self._save_index(index)
        return True

    def get(self, entry_id: str) -> Optional[RegistryEntry]:
        """Get an entry by ID.

        Args:
            entry_id: Entry ID.

        Returns:
            Registry entry or None if not found.
        """
        return self._load_index().get(entry_id)

    def list(
        self,
        tag: Optional[str] = None,
        author: Optional[str] = None,
        limit: int = 50,
    ) -> List[RegistryEntry]:
        """List entries in the registry.

        Args:
            tag: Filter by tag.
            author: Filter by author.
            limit: Maximum results to return.

        Returns:
            List of matching entries, most recently added first.
        """
        index = self._load_index()
        results = [
            e
            for e in index.values()
            if (not tag or tag in e.tags) and (not author or e.author == author)
        ]
        results.sort(key=lambda e: e.added_at or datetime.min, reverse=True)
        return results[:limit]

    def search(self, query: str) -> List[RegistrySearchResult]:
        """Search registry entries.

        Matching is a case-insensitive substring test. A name match scores
        10, a description match 5 and a tag match 3; scores add up. Every
        entry in the registry is considered.

        Args:
            query: Search query.

        Returns:
            List of matching entries with relevance scores, best first.
        """
        index = self._load_index()
        # Newest first, so ties in relevance keep that order after the
        # stable sort below.
        entries = sorted(
            index.values(),
            key=lambda e: e.added_at or datetime.min,
            reverse=True,
        )

        query_lower = query.lower()
        results = []
        for entry in entries:
            score = 0
            if query_lower in entry.name.lower():
                score += 10
            if entry.description and query_lower in entry.description.lower():
                score += 5
            # Lowercase the tag too, so tags match case-insensitively like
            # name and description do.
            if any(query_lower in tag.lower() for tag in entry.tags):
                score += 3

            if score > 0:
                results.append(
                    RegistrySearchResult(entry=entry, relevance_score=score)
                )

        return sorted(results, key=lambda r: r.relevance_score, reverse=True)

    def count(self) -> int:
        """Get total number of entries."""
        return len(self._load_index())
|
||||||
86
app/src/promptforge/registry/models.py
Normal file
86
app/src/promptforge/registry/models.py
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ..core.prompt import Prompt
|
||||||
|
|
||||||
|
|
||||||
|
class RegistryEntry(BaseModel):
    """A single prompt record held in a registry."""

    # Identity and descriptive metadata.
    id: Optional[str] = Field(default_factory=lambda: str(uuid4()))
    name: str
    description: Optional[str] = None
    content: str
    author: Optional[str] = None
    version: str = "1.0.0"
    tags: List[str] = Field(default_factory=list)
    provider: Optional[str] = None

    # Variable and validation-rule definitions kept as plain dicts.
    variables: List[Dict[str, Any]] = Field(default_factory=list)
    validation_rules: List[Dict[str, Any]] = Field(default_factory=list)

    # Popularity counters.
    downloads: int = 0
    likes: int = 0
    rating: float = 0.0

    # Publication state and timestamps.
    is_local: bool = True
    is_published: bool = False
    added_at: Optional[datetime] = None
    updated_at: Optional[datetime] = None

    def to_prompt_content(self) -> str:
        """Render this entry as prompt YAML.

        Builds a ``Prompt`` from the entry's fields and returns the
        result of its ``to_yaml()``.
        """
        from ..core.prompt import Prompt, PromptVariable, ValidationRule

        prompt_id = "" if not self.id else str(self.id)
        prompt_fields = {
            "id": prompt_id,
            "name": self.name,
            "description": self.description,
            "content": self.content,
            "variables": [PromptVariable(**spec) for spec in self.variables],
            "validation_rules": [
                ValidationRule(**spec) for spec in self.validation_rules
            ],
            "provider": self.provider,
            "tags": self.tags,
            "version": self.version,
        }
        return Prompt(**prompt_fields).to_yaml()

    @classmethod
    def from_prompt(cls, prompt: "Prompt", author: Optional[str] = None) -> "RegistryEntry":
        """Build a local registry entry from a Prompt.

        Args:
            prompt: Prompt whose fields are copied into the entry.
            author: Optional author name to record.

        Returns:
            A new entry marked local, with ``added_at`` set to the
            current UTC time.
        """
        variable_dicts = [variable.model_dump() for variable in prompt.variables]
        rule_dicts = [rule.model_dump() for rule in prompt.validation_rules]
        return cls(
            id=str(prompt.id),
            name=prompt.name,
            description=prompt.description,
            content=prompt.content,
            author=author,
            version=prompt.version,
            tags=prompt.tags,
            provider=prompt.provider,
            variables=variable_dicts,
            validation_rules=rule_dicts,
            is_local=True,
            added_at=datetime.utcnow(),
        )
|
||||||
|
|
||||||
|
|
||||||
|
class RegistrySearchResult(BaseModel):
    """One hit returned by a registry search."""

    # The matched registry entry.
    entry: RegistryEntry
    # Relevance of the match; defaults to 0.0.
    relevance_score: float = 0.0
    # Mapping of string keys to lists of strings; empty by default.
    highlights: Dict[str, List[str]] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
class RegistryStats(BaseModel):
    """Aggregate counters describing a registry."""

    # Entry counts; all default to zero.
    total_entries: int = 0
    local_entries: int = 0
    published_entries: int = 0

    # Ranked summaries stored as plain dicts; empty by default.
    popular_tags: List[Dict[str, Any]] = Field(default_factory=list)
    top_authors: List[Dict[str, Any]] = Field(default_factory=list)
|
||||||
147
app/src/promptforge/registry/remote.py
Normal file
147
app/src/promptforge/registry/remote.py
Normal file
@@ -0,0 +1,147 @@
|
|||||||
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from .models import RegistryEntry, RegistrySearchResult
|
||||||
|
from ..core.exceptions import RegistryError
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .local import LocalRegistry
|
||||||
|
|
||||||
|
|
||||||
|
class RemoteRegistry:
    """Remote prompt registry accessed via HTTP API."""

    def __init__(
        self,
        base_url: str = "https://registry.promptforge.io",
        api_key: Optional[str] = None,
        timeout: float = 30.0,
    ):
        """Initialize remote registry.

        Args:
            base_url: Base URL of the registry API.
            api_key: API key for authentication.
            timeout: Per-request timeout in seconds. Without one, a stalled
                server would block the caller indefinitely.
        """
        self.base_url = base_url.rstrip('/')
        self.api_key = api_key
        self.timeout = timeout
        self._session = requests.Session()
        if api_key:
            self._session.headers.update({"Authorization": f"Bearer {api_key}"})

    def _request(
        self,
        method: str,
        endpoint: str,
        data: Optional[Dict] = None,
        params: Optional[Dict] = None,
    ) -> Any:
        """Make HTTP request to registry.

        Args:
            method: HTTP method.
            endpoint: API endpoint, appended to ``<base_url>/api/v1``.
            data: Request data, sent as the JSON body.
            params: Query-string parameters; URL-encoded by ``requests``.

        Returns:
            Response JSON.

        Raises:
            RegistryError: If request fails.
        """
        url = f"{self.base_url}/api/v1{endpoint}"
        try:
            response = self._session.request(
                method,
                url,
                json=data,
                params=params,
                timeout=self.timeout,
            )
            response.raise_for_status()
            return response.json()
        except requests.HTTPError as e:
            # Fall back to the exception text if no response is attached.
            detail = e.response.text if e.response is not None else str(e)
            raise RegistryError(f"Registry API error: {detail}") from e
        except requests.RequestException as e:
            raise RegistryError(f"Registry connection error: {e}") from e

    def search(self, query: str, limit: int = 20) -> List[RegistrySearchResult]:
        """Search remote registry.

        Args:
            query: Search query.
            limit: Maximum results.

        Returns:
            List of matching entries.
        """
        # Sent as params so characters such as '&', '#' or spaces in the
        # query are escaped instead of corrupting the URL.
        data = self._request("GET", "/search", params={"q": query, "limit": limit})
        results = []
        for item in data.get("results", []):
            entry = RegistryEntry(**item)
            results.append(
                RegistrySearchResult(
                    entry=entry,
                    relevance_score=item.get("score", 0),
                )
            )
        return results

    def get(self, entry_id: str) -> Optional[RegistryEntry]:
        """Get entry by ID.

        Any RegistryError (including connection failures) is reported
        as None.

        Args:
            entry_id: Entry ID.

        Returns:
            Registry entry or None.
        """
        try:
            data = self._request("GET", f"/entries/{entry_id}")
            return RegistryEntry(**data)
        except RegistryError:
            return None

    def pull(self, entry_id: str, local_registry: "LocalRegistry") -> bool:
        """Pull entry from remote to local registry.

        Args:
            entry_id: Entry ID to pull.
            local_registry: Local registry to save to.

        Returns:
            True if successful, False if the entry could not be fetched.
        """
        entry = self.get(entry_id)
        if entry is None:
            return False
        entry.is_local = True
        local_registry.add(entry)
        return True

    def publish(self, entry: RegistryEntry) -> RegistryEntry:
        """Publish entry to remote registry.

        Args:
            entry: Entry to publish.

        Returns:
            Published entry with server-assigned ID.
        """
        # mode="json" turns datetime fields into strings; the default dump
        # keeps datetime objects, which the JSON body encoder rejects.
        data = self._request("POST", "/entries", entry.model_dump(mode="json"))
        return RegistryEntry(**data)

    def list_popular(self, limit: int = 10) -> List[RegistryEntry]:
        """List popular entries.

        Args:
            limit: Maximum results.

        Returns:
            List of popular entries.
        """
        data = self._request("GET", "/popular", params={"limit": limit})
        return [RegistryEntry(**item) for item in data.get("entries", [])]

    def validate_connection(self) -> bool:
        """Validate connection to remote registry.

        Returns:
            True if connection successful.
        """
        try:
            self._request("GET", "/health")
            return True
        except RegistryError:
            return False
|
||||||
198
app/src/promptforge/testing/ab_test.py
Normal file
198
app/src/promptforge/testing/ab_test.py
Normal file
@@ -0,0 +1,198 @@
|
|||||||
|
import time
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from ..core.prompt import Prompt
|
||||||
|
from ..providers import ProviderBase, ProviderResponse
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ABTestConfig:
    """Settings that control how an A/B test is executed."""

    # Number of runs per prompt.
    iterations: int = 3
    # Optional provider name.
    provider: Optional[str] = None
    # Optional completion token limit.
    max_tokens: Optional[int] = None
    # Sampling temperature.
    temperature: float = 0.7
    # Parallel-execution flag.
    parallel: bool = False
    # Free-form extra data; a fresh dict per instance.
    metadata: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ABTestResult:
    """Outcome of one test iteration."""

    # The prompt that was run and the provider's reply.
    prompt: Prompt
    response: ProviderResponse
    # Variables passed for this run.
    variables: Dict[str, Any]
    # Iteration number of this run.
    iteration: int
    # Validation outcome and any error messages.
    passed_validation: bool = False
    validation_errors: List[str] = field(default_factory=list)
    # Latency of the run in milliseconds.
    latency_ms: float = 0.0
    # Creation time (naive UTC).
    timestamp: datetime = field(default_factory=datetime.utcnow)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ABTestSummary:
    """Aggregated results for all runs of one prompt."""

    # Name of the tested prompt and the configuration used.
    prompt_name: str
    config: ABTestConfig
    # Run counts.
    total_runs: int
    successful_runs: int
    failed_runs: int
    # Averages over the runs.
    avg_latency_ms: float
    avg_tokens: float
    avg_cost: float
    # Individual run results.
    results: List[ABTestResult]
    # Creation time (naive UTC).
    timestamp: datetime = field(default_factory=datetime.utcnow)
|
||||||
|
|
||||||
|
|
||||||
|
class ABTest:
    """A/B test runner for comparing prompt variations."""

    # Per-token rates used by _estimate_cost, keyed by model name.
    _COST_PER_TOKEN = {
        "gpt-4": 0.00003,
        "gpt-4-turbo": 0.00001,
        "gpt-3.5-turbo": 0.0000005,
        "claude-3-sonnet-20240229": 0.000003,
        "claude-3-opus-20240229": 0.000015,
    }
    # Rate applied to models missing from the table above.
    _DEFAULT_COST_PER_TOKEN = 0.000001

    def __init__(
        self,
        provider: ProviderBase,
        config: Optional[ABTestConfig] = None,
    ):
        """Initialize A/B test runner.

        Args:
            provider: LLM provider to use.
            config: Test configuration.
        """
        self.provider = provider
        self.config = config or ABTestConfig()

    async def run(
        self,
        prompt: Prompt,
        variables: Dict[str, Any],
    ) -> ABTestSummary:
        """Run A/B test on a prompt.

        Executes ``config.iterations`` runs, concurrently when
        ``config.parallel`` is set and one after another otherwise. A
        run that raises is recorded as a failed result instead of
        aborting the test.

        Args:
            prompt: Prompt to test.
            variables: Variables to substitute.

        Returns:
            ABTestSummary with all test results.
        """
        import asyncio

        iterations = range(1, self.config.iterations + 1)

        if self.config.parallel:
            # return_exceptions=True keeps one failing run from cancelling
            # the others; failures come back as exception objects.
            outcomes = await asyncio.gather(
                *(self._run_single(prompt, variables, i) for i in iterations),
                return_exceptions=True,
            )
        else:
            outcomes = []
            for i in iterations:
                try:
                    outcomes.append(await self._run_single(prompt, variables, i))
                except Exception as exc:
                    outcomes.append(exc)

        results: List[ABTestResult] = []
        latencies = []
        total_tokens = []
        for i, outcome in zip(iterations, outcomes):
            if isinstance(outcome, BaseException):
                results.append(self._failed_result(prompt, variables, i, outcome))
                continue
            results.append(outcome)
            latencies.append(outcome.latency_ms)
            usage = outcome.response.usage or {}
            total_tokens.append(usage.get("total_tokens", 0))

        # A run counts as successful when it passed validation or
        # produced any content.
        successful = sum(
            1 for r in results if r.passed_validation or r.response.content
        )
        avg_latency = sum(latencies) / len(latencies) if latencies else 0
        avg_tokens = sum(total_tokens) / len(total_tokens) if total_tokens else 0

        return ABTestSummary(
            prompt_name=prompt.name,
            config=self.config,
            total_runs=self.config.iterations,
            successful_runs=successful,
            failed_runs=self.config.iterations - successful,
            avg_latency_ms=avg_latency,
            avg_tokens=avg_tokens,
            avg_cost=self._estimate_cost(avg_tokens),
            results=results,
        )

    def _failed_result(
        self,
        prompt: Prompt,
        variables: Dict[str, Any],
        iteration: int,
        error: BaseException,
    ) -> ABTestResult:
        """Build the placeholder result recorded for a run that raised."""
        return ABTestResult(
            prompt=prompt,
            response=ProviderResponse(
                content="",
                model=prompt.provider or self.provider.name,
                provider=self.provider.name,
            ),
            variables=variables,
            iteration=iteration,
            passed_validation=False,
            # Keep the generic message first and add the cause, so the
            # reason for the failure is not lost.
            validation_errors=[
                "Test execution failed",
                f"{type(error).__name__}: {error}",
            ],
        )

    async def run_comparison(
        self,
        prompts: List[Prompt],
        shared_variables: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, ABTestSummary]:
        """Run tests on multiple prompts for comparison.

        Prompts are tested one after another. Summaries are keyed by
        prompt name, so a later prompt with the same name replaces an
        earlier one.

        Args:
            prompts: List of prompts to compare.
            shared_variables: Variables shared across all prompts.

        Returns:
            Dictionary mapping prompt names to their summaries.
        """
        shared_variables = shared_variables or {}
        summaries = {}
        for prompt in prompts:
            variables = self._merge_variables(prompt, shared_variables)
            summaries[prompt.name] = await self.run(prompt, variables)
        return summaries

    async def _run_single(
        self,
        prompt: Prompt,
        variables: Dict[str, Any],
        iteration: int,
    ) -> ABTestResult:
        """Run a single test iteration.

        Renders the prompt template, sends it to the provider and
        measures the duration of the provider call.
        """
        from ..core.template import TemplateEngine

        template_engine = TemplateEngine()
        rendered = template_engine.render(prompt.content, variables, prompt.variables)

        start_time = time.time()
        response = await self.provider.complete(
            prompt=rendered,
            max_tokens=self.config.max_tokens,
        )
        latency_ms = (time.time() - start_time) * 1000

        return ABTestResult(
            prompt=prompt,
            response=response,
            variables=variables,
            iteration=iteration,
            latency_ms=latency_ms,
        )

    def _merge_variables(
        self,
        prompt: Prompt,
        shared: Dict[str, Any],
    ) -> Dict[str, Any]:
        """Merge shared variables with prompt-specific ones.

        Shared values take precedence; a prompt variable's default is
        used only when the name is absent and the default is not None.
        """
        variables = shared.copy()
        for var in prompt.variables:
            if var.name not in variables and var.default is not None:
                variables[var.name] = var.default
        return variables

    def _estimate_cost(self, tokens: float) -> float:
        """Estimate cost based on token usage.

        Multiplies the token count by the rate for ``self.provider.model``,
        or by the default rate for models not in the table.
        """
        rate = self._COST_PER_TOKEN.get(
            self.provider.model, self._DEFAULT_COST_PER_TOKEN
        )
        return tokens * rate
|
||||||
248
app/src/promptforge/testing/validator.py
Normal file
248
app/src/promptforge/testing/validator.py
Normal file
@@ -0,0 +1,248 @@
|
|||||||
|
import json
|
||||||
|
import re
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class Validator(ABC):
    """Interface that every response validator implements."""

    @abstractmethod
    def validate(self, response: str) -> Tuple[bool, Optional[str]]:
        """Check a response.

        Args:
            response: The response text to check.

        Returns:
            A ``(is_valid, error_message)`` pair.
        """

    @abstractmethod
    def get_name(self) -> str:
        """Return the validator's name."""
|
||||||
|
|
||||||
|
|
||||||
|
class RegexValidator(Validator):
    """Validator that accepts a response when a regex matches anywhere in it."""

    def __init__(self, pattern: str, flags: int = 0):
        """Compile the pattern once for later use.

        Args:
            pattern: Regular expression source.
            flags: ``re`` module flags (e.g., re.IGNORECASE).
        """
        self.pattern = pattern
        self.flags = flags
        self._regex = re.compile(pattern, flags)

    def validate(self, response: str) -> Tuple[bool, Optional[str]]:
        """Return ``(True, None)`` on a match, else ``(False, message)``."""
        found = self._regex.search(response)
        if found:
            return True, None
        return False, f"Response does not match pattern: {self.pattern}"

    def get_name(self) -> str:
        """Return the name in the form ``regex(<pattern>)``."""
        return "regex({})".format(self.pattern)
|
||||||
|
|
||||||
|
|
||||||
|
class JSONSchemaValidator(Validator):
    """Validates JSON responses against a schema.

    Supports a subset of JSON Schema: ``type`` (a single name or a list
    of names), ``properties``, ``required`` (either the object-level list
    of names or a per-property ``required: true`` flag), ``items``,
    ``enum``, ``minLength``/``maxLength`` and ``minimum``/``maximum``.
    """

    # Python classes accepted for each schema type name.
    _TYPE_CHECKS = {
        "array": (list,),
        "object": (dict,),
        "string": (str,),
        "number": (int, float),
        "boolean": (bool,),
        "integer": (int,),
        "null": (type(None),),
    }

    def __init__(self, schema: Dict[str, Any]):
        """Initialize JSON schema validator.

        Args:
            schema: JSON schema to validate against.
        """
        self.schema = schema

    def validate(self, response: str) -> Tuple[bool, Optional[str]]:
        """Validate JSON response against schema.

        Returns:
            ``(True, None)`` when the response parses and satisfies the
            schema, otherwise ``(False, message)`` with all violations
            joined by ``"; "``.
        """
        try:
            data = json.loads(response)
        except json.JSONDecodeError as e:
            return False, f"Invalid JSON: {e}"

        errors = self._validate_object(data, self.schema, "")
        if errors:
            return False, "; ".join(errors)
        return True, None

    @classmethod
    def _matches_type(cls, data: Any, type_name: Any) -> bool:
        """Return whether ``data`` satisfies one schema type name."""
        classes = cls._TYPE_CHECKS.get(type_name) if isinstance(type_name, str) else None
        if classes is None:
            # Unknown type names are not enforced.
            return True
        # bool subclasses int in Python, but a JSON boolean is not a number.
        if isinstance(data, bool) and type_name in ("number", "integer"):
            return False
        return isinstance(data, classes)

    @staticmethod
    def _is_number(data: Any) -> bool:
        """Return whether ``data`` is an int or float but not a bool."""
        return isinstance(data, (int, float)) and not isinstance(data, bool)

    def _validate_object(
        self,
        data: Any,
        schema: Dict[str, Any],
        path: str,
    ) -> List[str]:
        """Recursively validate against schema.

        Args:
            data: Parsed JSON value to check.
            schema: Schema fragment that applies to ``data``.
            path: Location of ``data`` in the document, used as the
                prefix of error messages.

        Returns:
            List of violation messages; empty when ``data`` is valid.
        """
        errors: List[str] = []

        if "type" in schema:
            expected_type = schema["type"]
            # "type" may be one name or a list of alternatives.
            allowed = expected_type if isinstance(expected_type, list) else [expected_type]
            if allowed and not any(self._matches_type(data, t) for t in allowed):
                type_name = " or ".join(str(t) for t in allowed)
                actual_type = type(data).__name__
                errors.append(f"{path}: expected {type_name}, got {actual_type}")
                # Further checks are meaningless on a value of the wrong type.
                return errors

        if isinstance(data, dict):
            properties = schema.get("properties", {})
            required = schema.get("required")
            required_names = required if isinstance(required, list) else []

            for prop, prop_schema in properties.items():
                if prop in data:
                    errors.extend(
                        self._validate_object(data[prop], prop_schema, f"{path}.{prop}")
                    )
                # Only a literal True is the per-property flag; a list here
                # is the nested object's own list of required names.
                elif prop_schema.get("required") is True or prop in required_names:
                    errors.append(f"{path}.{prop}: required property missing")

            # Required names that have no entry under "properties".
            for prop in required_names:
                if prop not in properties and prop not in data:
                    errors.append(f"{path}.{prop}: required property missing")

        items_schema = schema.get("items")
        if isinstance(items_schema, dict) and isinstance(data, list):
            for index, item in enumerate(data):
                errors.extend(
                    self._validate_object(item, items_schema, f"{path}[{index}]")
                )

        if "enum" in schema and data not in schema["enum"]:
            errors.append(f"{path}: value must be one of {schema['enum']}")

        if "minLength" in schema and isinstance(data, str):
            if len(data) < schema["minLength"]:
                errors.append(f"{path}: string too short (min {schema['minLength']})")

        if "maxLength" in schema and isinstance(data, str):
            if len(data) > schema["maxLength"]:
                errors.append(f"{path}: string too long (max {schema['maxLength']})")

        if "minimum" in schema and self._is_number(data):
            if data < schema["minimum"]:
                errors.append(f"{path}: value below minimum ({schema['minimum']})")

        if "maximum" in schema and self._is_number(data):
            if data > schema["maximum"]:
                errors.append(f"{path}: value above maximum ({schema['maximum']})")

        return errors

    def get_name(self) -> str:
        """Return the validator name."""
        return "json-schema"
|
||||||
|
|
||||||
|
|
||||||
|
class LengthValidator(Validator):
    """Validates response length constraints."""

    def __init__(
        self,
        min_length: Optional[int] = None,
        max_length: Optional[int] = None,
    ):
        """Initialize length validator.

        Args:
            min_length: Minimum number of characters.
            max_length: Maximum number of characters.
        """
        self.min_length = min_length
        self.max_length = max_length

    def validate(self, response: str) -> Tuple[bool, Optional[str]]:
        """Validate response length.

        Returns:
            ``(True, None)`` when the length is within the configured
            bounds, otherwise ``(False, message)``. A bound of None is
            not checked.
        """
        if self.min_length is not None and len(response) < self.min_length:
            return False, f"Response too short (min {self.min_length} chars)"
        if self.max_length is not None and len(response) > self.max_length:
            return False, f"Response too long (max {self.max_length} chars)"
        return True, None

    def get_name(self) -> str:
        """Return a name listing the configured bounds.

        Bounds are tested with ``is not None``, as in ``validate``, so a
        bound of 0 (for example ``max_length=0``) still appears in the name.
        """
        parts = ["length"]
        if self.min_length is not None:
            parts.append(f"min={self.min_length}")
        if self.max_length is not None:
            parts.append(f"max={self.max_length}")
        return "(" + ", ".join(parts) + ")"
|
||||||
|
|
||||||
|
|
||||||
|
class ContainsValidator(Validator):
    """Checks that a response includes expected substrings."""

    def __init__(
        self,
        required_strings: List[str],
        all_required: bool = False,
        case_sensitive: bool = False,
    ):
        """Store the substrings to look for and the matching options.

        Args:
            required_strings: Substrings searched for in the response.
            all_required: When True every substring must occur; otherwise
                a single occurrence of any of them is enough.
            case_sensitive: When False both sides are lower-cased before
                comparison.
        """
        self.required_strings = required_strings
        self.all_required = all_required
        self.case_sensitive = case_sensitive

    def validate(self, response: str) -> Tuple[bool, Optional[str]]:
        """Check the response for the configured substrings.

        Returns:
            ``(True, None)`` on success, otherwise ``(False, message)``.
        """
        expected = self.required_strings
        fold_case = not self.case_sensitive
        haystack = response.lower() if fold_case else response

        absent = [
            needle
            for needle in expected
            if (needle.lower() if fold_case else needle) not in haystack
        ]

        if self.all_required:
            if absent:
                return False, f"Missing required content: {', '.join(absent)}"
            return True, None

        # "any" mode: fail only when none of the expected strings was found.
        if len(absent) == len(expected):
            return False, "Response does not contain any expected content"
        return True, None

    def get_name(self) -> str:
        """Return a name containing the matching mode and the string list."""
        if self.all_required:
            mode = "all"
        else:
            mode = "any"
        return f"contains({mode}, {self.required_strings})"
|
||||||
|
|
||||||
|
|
||||||
|
class CompositeValidator(Validator):
    """Aggregates several validators into a single one."""

    def __init__(self, validators: List[Validator], mode: str = "all"):
        """Store the wrapped validators and the combination mode.

        Args:
            validators: Validators whose results are combined.
            mode: "all" requires every validator to pass (AND); any other
                value is handled as "any" (OR).
        """
        self.validators = validators
        self.mode = mode

    def validate(self, response: str) -> Tuple[bool, Optional[str]]:
        """Run every wrapped validator and combine the outcomes.

        Returns:
            ``(True, None)`` on success, otherwise ``(False, message)``.
        """
        # Every validator is executed up front, regardless of mode.
        outcomes = [validator.validate(response) for validator in self.validators]

        if self.mode != "all":
            if any(passed for passed, _ in outcomes):
                return True, None
            return False, "No validator passed"

        failures = [message for passed, message in outcomes if not passed]
        if not failures:
            return True, None
        # Failures that carry no message are left out of the joined text.
        return False, "; ".join(message for message in failures if message)

    def get_name(self) -> str:
        """Return a name containing the mode and the wrapped validators' names."""
        child_names = []
        for validator in self.validators:
            child_names.append(validator.get_name())
        return f"composite({self.mode}, {child_names})"
|
||||||
183
app/tests/test_core.py
Normal file
183
app/tests/test_core.py
Normal file
@@ -0,0 +1,183 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from promptforge.core.prompt import Prompt, PromptVariable, VariableType
|
||||||
|
from promptforge.core.template import TemplateEngine
|
||||||
|
from promptforge.core.exceptions import MissingVariableError, InvalidPromptError
|
||||||
|
|
||||||
|
|
||||||
|
class TestPromptModel:
    """Unit tests covering the Prompt model."""

    def test_prompt_creation(self):
        """A directly constructed prompt exposes name, variables and version."""
        name_variable = PromptVariable(
            name="name", type=VariableType.STRING, required=True
        )
        prompt = Prompt(
            name="Test Prompt",
            content="Hello, {{name}}!",
            variables=[name_variable],
        )
        assert prompt.name == "Test Prompt"
        assert len(prompt.variables) == 1
        assert prompt.version == "1.0.0"

    def test_prompt_from_yaml(self, sample_prompt_yaml):
        """Parsing the sample YAML yields the expected metadata and variables."""
        parsed = Prompt.from_yaml(sample_prompt_yaml)
        assert parsed.name == "Test Prompt"
        assert parsed.description == "A test prompt for unit testing"
        assert len(parsed.variables) == 2
        first_variable = parsed.variables[0]
        assert first_variable.name == "name"
        assert first_variable.required is True

    def test_prompt_to_yaml(self, sample_prompt_yaml):
        """Exported YAML contains the document marker and the prompt name."""
        exported = Prompt.from_yaml(sample_prompt_yaml).to_yaml()
        assert "---\n" in exported
        assert "name: Test Prompt" in exported

    def test_prompt_save_and_load(self, temp_prompts_dir, sample_prompt_yaml):
        """A saved prompt can be loaded back with the same name and content."""
        original = Prompt.from_yaml(sample_prompt_yaml)
        saved_path = original.save(temp_prompts_dir)

        assert saved_path.exists()
        reloaded = Prompt.load(saved_path)
        assert reloaded.name == original.name
        assert reloaded.content == original.content

    def test_prompt_list(self, temp_prompts_dir):
        """Listing reflects prompts saved into the directory."""
        before = Prompt.list(temp_prompts_dir)
        assert len(before) == 0

        Prompt(name="Test1", content="Test").save(temp_prompts_dir)

        after = Prompt.list(temp_prompts_dir)
        assert len(after) == 1
        assert after[0].name == "Test1"

    def test_variable_types(self):
        """Variables keep the type and choices they were created with."""
        string_var = PromptVariable(name="text", type=VariableType.STRING)
        int_var = PromptVariable(name="number", type=VariableType.INTEGER)
        bool_var = PromptVariable(name="flag", type=VariableType.BOOLEAN)
        choice_var = PromptVariable(
            name="choice",
            type=VariableType.CHOICE,
            choices=["a", "b", "c"],
        )

        expected_types = [
            (string_var, VariableType.STRING),
            (int_var, VariableType.INTEGER),
            (bool_var, VariableType.BOOLEAN),
        ]
        for variable, expected in expected_types:
            assert variable.type == expected
        assert choice_var.choices == ["a", "b", "c"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestTemplateEngine:
    """Unit tests covering TemplateEngine."""

    def test_get_variables(self):
        """Both placeholders in the template are reported."""
        template = "Hello, {{name}}! You have {{count}} items."
        found = TemplateEngine().get_variables(template)
        assert "name" in found
        assert "count" in found

    def test_render_basic(self):
        """A single placeholder is substituted."""
        rendered = TemplateEngine().render("Hello, {{name}}!", {"name": "World"})
        assert rendered == "Hello, World!"

    def test_render_multiple_vars(self, sample_prompt_content):
        """Several placeholders are substituted in one render call."""
        values = {"name": "Alice", "count": 5}
        rendered = TemplateEngine().render(sample_prompt_content, values)
        assert rendered == "Hello, Alice! You have 5 messages."

    def test_render_missing_required(self):
        """Rendering without a needed variable raises."""
        engine = TemplateEngine()
        template = "Hello, {{name}}!"
        with pytest.raises(Exception):  # StrictUndefined raises on missing vars
            engine.render(template, {})

    def test_render_with_defaults(self):
        """A variable's default is used when no value is supplied."""
        engine = TemplateEngine()
        declared = [PromptVariable(name="name", required=False, default="Guest")]
        rendered = engine.render("Hello, {{name}}!", {}, declared)
        assert rendered == "Hello, Guest!"

    def test_validate_variables_valid(self):
        """Correctly typed values produce no validation errors."""
        engine = TemplateEngine()
        declared = [
            PromptVariable(name="name", type=VariableType.STRING, required=True),
            PromptVariable(name="count", type=VariableType.INTEGER, required=False),
        ]
        problems = engine.validate_variables({"name": "Alice", "count": 5}, declared)
        assert len(problems) == 0

    def test_validate_variables_missing_required(self):
        """A missing required variable is reported by name."""
        engine = TemplateEngine()
        declared = [
            PromptVariable(name="name", type=VariableType.STRING, required=True),
        ]
        problems = engine.validate_variables({}, declared)
        assert len(problems) == 1
        assert "name" in problems[0]

    def test_validate_variables_type_error(self):
        """A non-integer value for an integer variable is reported."""
        engine = TemplateEngine()
        declared = [
            PromptVariable(name="count", type=VariableType.INTEGER, required=True),
        ]
        problems = engine.validate_variables({"count": "not a number"}, declared)
        assert len(problems) == 1
        assert "integer" in problems[0].lower()

    def test_validate_choices(self):
        """A value outside the declared choices is reported."""
        engine = TemplateEngine()
        color = PromptVariable(
            name="color",
            type=VariableType.CHOICE,
            required=True,
            choices=["red", "green", "blue"],
        )
        problems = engine.validate_variables({"color": "yellow"}, [color])
        assert len(problems) == 1
        assert "one of" in problems[0].lower()
|
||||||
|
|
||||||
|
|
||||||
|
class TestExceptions:
    """Unit tests covering the custom exception classes."""

    def test_missing_variable_error(self):
        """The message passed to MissingVariableError appears in str()."""
        rendered = str(MissingVariableError("Missing: name"))
        assert "name" in rendered

    def test_invalid_prompt_error(self):
        """The message passed to InvalidPromptError appears in str()."""
        rendered = str(InvalidPromptError("Invalid YAML"))
        assert "YAML" in rendered
|
||||||
137
app/tests/test_providers.py
Normal file
137
app/tests/test_providers.py
Normal file
@@ -0,0 +1,137 @@
|
|||||||
|
import pytest
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from promptforge.providers.base import ProviderBase, ProviderResponse
|
||||||
|
from promptforge.providers.factory import ProviderFactory
|
||||||
|
from promptforge.providers.openai import OpenAIProvider
|
||||||
|
from promptforge.providers.anthropic import AnthropicProvider
|
||||||
|
from promptforge.providers.ollama import OllamaProvider
|
||||||
|
from promptforge.core.exceptions import ProviderError
|
||||||
|
|
||||||
|
|
||||||
|
class TestProviderBase:
    """Unit tests covering ProviderBase and ProviderResponse."""

    def test_response_creation(self):
        """A ProviderResponse keeps the content and usage it was built with."""
        token_usage = {"prompt_tokens": 5, "completion_tokens": 3, "total_tokens": 8}
        response = ProviderResponse(
            content="Hello",
            model="gpt-4",
            provider="openai",
            usage=token_usage,
            latency_ms=100.5,
        )
        assert response.content == "Hello"
        assert response.usage["total_tokens"] == 8

    def test_provider_requires_implementation(self):
        """Instantiating ProviderBase directly raises TypeError."""
        with pytest.raises(TypeError):
            ProviderBase()
|
||||||
|
|
||||||
|
|
||||||
|
class TestProviderFactory:
    """Unit tests covering ProviderFactory."""

    def test_list_providers(self):
        """The three provider names are listed."""
        available = ProviderFactory.list_providers()
        assert "openai" in available
        assert "anthropic" in available
        assert "ollama" in available

    def test_create_openai(self):
        """The "openai" key produces an OpenAIProvider with the given model."""
        created = ProviderFactory.create("openai", model="gpt-4")
        assert isinstance(created, OpenAIProvider)
        assert created.model == "gpt-4"

    def test_create_anthropic(self):
        """The "anthropic" key produces an AnthropicProvider with the given model."""
        created = ProviderFactory.create("anthropic", model="claude-3")
        assert isinstance(created, AnthropicProvider)
        assert created.model == "claude-3"

    def test_create_ollama(self):
        """The "ollama" key produces an OllamaProvider with the given model."""
        created = ProviderFactory.create("ollama", model="llama2")
        assert isinstance(created, OllamaProvider)
        assert created.model == "llama2"

    def test_create_unknown_provider(self):
        """An unrecognised provider key raises ProviderError."""
        with pytest.raises(ProviderError):
            ProviderFactory.create("unknown")

    def test_provider_temperature(self):
        """The temperature keyword is exposed on the created provider."""
        created = ProviderFactory.create("openai", temperature=0.5)
        assert created.temperature == 0.5
|
||||||
|
|
||||||
|
|
||||||
|
class TestOpenAIProvider:
    """Unit tests covering OpenAIProvider."""

    def test_provider_name(self):
        """The provider reports the name "openai"."""
        assert OpenAIProvider().name == "openai"

    def test_list_models(self):
        """The model list includes gpt-4 and gpt-3.5-turbo."""
        available = OpenAIProvider().list_models()
        assert "gpt-4" in available
        assert "gpt-3.5-turbo" in available

    def test_validate_api_key_missing(self):
        """Validation fails with no key argument and an emptied environment."""
        keyless = OpenAIProvider(api_key=None)
        with patch.dict('os.environ', {}, clear=True):
            outcome = keyless.validate_api_key()
            assert outcome is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestAnthropicProvider:
    """Unit tests covering AnthropicProvider."""

    def test_provider_name(self):
        """The provider reports the name "anthropic"."""
        assert AnthropicProvider().name == "anthropic"

    def test_list_models(self):
        """The model list includes claude-3-sonnet-20240229."""
        available = AnthropicProvider().list_models()
        assert "claude-3-sonnet-20240229" in available

    def test_validate_api_key_missing(self):
        """Validation fails with no key argument and an emptied environment."""
        keyless = AnthropicProvider(api_key=None)
        with patch.dict('os.environ', {}, clear=True):
            outcome = keyless.validate_api_key()
            assert outcome is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestOllamaProvider:
    """Unit tests covering OllamaProvider."""

    def test_provider_name(self):
        """The provider reports the name "ollama"."""
        assert OllamaProvider().name == "ollama"

    def test_list_models(self):
        """The model list includes llama2."""
        available = OllamaProvider().list_models()
        assert "llama2" in available

    def test_validate_api_key_not_needed(self):
        """Key validation succeeds on a provider built without a key."""
        outcome = OllamaProvider().validate_api_key()
        assert outcome is True

    def test_default_base_url(self):
        """A provider built without arguments points at localhost:11434."""
        assert OllamaProvider().base_url == "http://localhost:11434"
|
||||||
239
app/tests/test_testing.py
Normal file
239
app/tests/test_testing.py
Normal file
@@ -0,0 +1,239 @@
|
|||||||
|
from promptforge.testing.validator import (
|
||||||
|
RegexValidator,
|
||||||
|
JSONSchemaValidator,
|
||||||
|
LengthValidator,
|
||||||
|
ContainsValidator,
|
||||||
|
CompositeValidator,
|
||||||
|
)
|
||||||
|
from promptforge.testing.metrics import MetricsCollector, MetricsSample
|
||||||
|
from promptforge.testing.results import TestSessionResults, ResultFormatter
|
||||||
|
from promptforge.testing.ab_test import ABTestConfig
|
||||||
|
|
||||||
|
|
||||||
|
class TestRegexValidator:
    """Unit tests covering RegexValidator."""

    def test_valid_pattern(self):
        """A matching response is valid and carries no error."""
        outcome = RegexValidator(r"^Hello.*").validate("Hello, World!")
        is_valid, error = outcome
        assert is_valid is True
        assert error is None

    def test_invalid_pattern(self):
        """A non-matching response is invalid and carries an error."""
        outcome = RegexValidator(r"^Hello.*").validate("Goodbye")
        is_valid, error = outcome
        assert is_valid is False
        assert error is not None

    def test_case_insensitive(self):
        """With the ignore-case flag an upper-case response still matches."""
        ignoring_case = RegexValidator(r"hello", flags=2)  # re.IGNORECASE
        is_valid, _ = ignoring_case.validate("HELLO")
        assert is_valid is True
|
||||||
|
|
||||||
|
|
||||||
|
class TestJSONSchemaValidator:
    """Unit tests covering JSONSchemaValidator."""

    def test_valid_json(self):
        """A document matching the schema is accepted."""
        person_schema = {
            "type": "object",
            "properties": {
                "name": {"type": "string"},
                "age": {"type": "number"},
            },
        }
        checker = JSONSchemaValidator(person_schema)
        is_valid, _ = checker.validate('{"name": "Alice", "age": 30}')
        assert is_valid is True

    def test_invalid_json(self):
        """Unparseable input is rejected with an error mentioning JSON."""
        checker = JSONSchemaValidator({})
        is_valid, error = checker.validate("not json")
        assert is_valid is False
        assert "JSON" in error

    def test_type_mismatch(self):
        """A string where an integer is declared is rejected."""
        count_schema = {
            "type": "object",
            "properties": {
                "count": {"type": "integer"},
            },
        }
        checker = JSONSchemaValidator(count_schema)
        is_valid, _ = checker.validate('{"count": "not a number"}')
        assert is_valid is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestLengthValidator:
    """Unit tests covering LengthValidator."""

    def test_within_bounds(self):
        """A response between the two bounds is accepted."""
        bounded = LengthValidator(min_length=5, max_length=100)
        is_valid, _ = bounded.validate("Hello, World!")
        assert is_valid is True

    def test_too_short(self):
        """A response under the minimum is rejected with a "short" error."""
        at_least_ten = LengthValidator(min_length=10)
        is_valid, error = at_least_ten.validate("Hi")
        assert is_valid is False
        assert "short" in error

    def test_too_long(self):
        """A response over the maximum is rejected with a "long" error."""
        at_most_five = LengthValidator(max_length=5)
        is_valid, error = at_most_five.validate("Hello, World!")
        assert is_valid is False
        assert "long" in error
|
||||||
|
|
||||||
|
|
||||||
|
class TestContainsValidator:
    """Unit tests covering ContainsValidator."""

    def test_contains_string(self):
        """A response containing the expected words is accepted."""
        checker = ContainsValidator(required_strings=["hello", "world"])
        is_valid, _ = checker.validate("Say hello to the world")
        assert is_valid is True

    def test_missing_content(self):
        """A response containing none of the expected words is rejected."""
        checker = ContainsValidator(required_strings=["hello", "world"])
        is_valid, _ = checker.validate("Just some random text")
        assert is_valid is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestCompositeValidator:
    """Unit tests covering CompositeValidator."""

    def test_all_mode(self):
        """In "all" mode a response passing every validator is accepted."""
        starts_with_hello = RegexValidator(r"^Hello.*")
        long_enough = LengthValidator(min_length=5)
        combined = CompositeValidator([starts_with_hello, long_enough], mode="all")
        is_valid, _ = combined.validate("Hello, World!")
        assert is_valid is True

    def test_any_mode(self):
        """In "any" mode a response passing one validator is accepted."""
        starts_with_hello = RegexValidator(r"^Hello.*")
        starts_with_goodbye = RegexValidator(r"^Goodbye.*")
        combined = CompositeValidator(
            [starts_with_hello, starts_with_goodbye], mode="any"
        )
        is_valid, _ = combined.validate("Goodbye, World!")
        assert is_valid is True
|
||||||
|
|
||||||
|
|
||||||
|
class TestMetricsCollector:
    """Unit tests covering MetricsCollector."""

    def test_record_sample(self):
        """One recorded sample is reflected in the summary count and average."""
        collector = MetricsCollector()
        collector.record(
            MetricsSample(
                latency_ms=100.0,
                tokens_total=50,
                validation_passed=True,
            )
        )

        summary = collector.get_summary()
        assert summary.count == 1
        assert summary.latency["avg"] == 100.0

    def test_record_from_response(self):
        """Prompt and completion tokens are summed into tokens_total."""
        token_usage = {"prompt_tokens": 10, "completion_tokens": 5}
        recorded = MetricsCollector().record_from_response(
            latency_ms=50.0,
            usage=token_usage,
            validation_passed=True,
        )
        assert recorded.tokens_total == 15

    def test_clear_samples(self):
        """After clear() the summary reports zero samples."""
        collector = MetricsCollector()
        collector.record(MetricsSample())
        collector.clear()

        assert collector.get_summary().count == 0

    def test_compare_collectors(self):
        """Comparing a 100 ms collector with a 200 ms one gives a 100 ms delta."""
        baseline = MetricsCollector()
        baseline.record(MetricsSample(latency_ms=100.0))

        candidate = MetricsCollector()
        candidate.record(MetricsSample(latency_ms=200.0))

        comparison = baseline.compare(candidate)
        assert comparison["latency_delta_ms"] == 100.0
|
||||||
|
|
||||||
|
|
||||||
|
class TestResultFormatter:
    """Unit tests covering ResultFormatter."""

    def test_format_text(self):
        """Text output includes the session name and a PASS marker."""
        from promptforge.testing.results import TestResult

        session = TestSessionResults(test_id="test-123", name="My Test")
        passing = TestResult(
            test_id="1",
            prompt_name="prompt1",
            provider="openai",
            success=True,
            response="Hello",
        )
        session.results.append(passing)

        text = ResultFormatter.format_text(session)
        assert "My Test" in text
        assert "PASS" in text

    def test_format_json(self):
        """JSON output parses and carries the session name and a results key."""
        session = TestSessionResults(test_id="test-123", name="My Test")

        serialized = ResultFormatter.format_json(session)
        import json

        payload = json.loads(serialized)
        assert payload["name"] == "My Test"
        assert "results" in payload
|
||||||
|
|
||||||
|
|
||||||
|
class TestABTestConfig:
    """Tests for ABTestConfig."""

    def test_default_config(self):
        """A config built without arguments exposes the expected defaults."""
        config = ABTestConfig()
        assert config.iterations == 3
        assert config.parallel is False
        assert config.temperature == 0.7

    def test_custom_config(self):
        """Every value passed to the constructor is exposed on the config."""
        config = ABTestConfig(
            iterations=5,
            parallel=True,
            temperature=0.5,
        )
        assert config.iterations == 5
        assert config.parallel is True
        # temperature is passed above; assert it so the override is actually
        # covered, mirroring the three fields checked in test_default_config.
        assert config.temperature == 0.5
|
||||||
@@ -1 +1,64 @@
|
|||||||
/app/pyproject.toml
|
[build-system]
|
||||||
|
requires = ["setuptools>=61.0", "wheel"]
|
||||||
|
build-backend = "setuptools.build_meta"
|
||||||
|
|
||||||
|
[project]
|
||||||
|
name = "promptforge"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = "A CLI tool for versioning, testing, and sharing AI prompts across different LLM providers"
|
||||||
|
readme = "README.md"
|
||||||
|
license = {text = "MIT"}
|
||||||
|
requires-python = ">=3.10"
|
||||||
|
authors = [
|
||||||
|
{name = "PromptForge Team"}
|
||||||
|
]
|
||||||
|
keywords = ["cli", "prompt", "ai", "llm", "versioning"]
|
||||||
|
classifiers = [
|
||||||
|
"Development Status :: 4 - Beta",
|
||||||
|
"Environment :: Console",
|
||||||
|
"Intended Audience :: Developers",
|
||||||
|
"License :: OSI Approved :: MIT License",
|
||||||
|
"Programming Language :: Python :: 3",
|
||||||
|
"Programming Language :: Python :: 3.10",
|
||||||
|
"Programming Language :: Python :: 3.11",
|
||||||
|
"Programming Language :: Python :: 3.12",
|
||||||
|
]
|
||||||
|
dependencies = [
|
||||||
|
"click>=8.1.0",
|
||||||
|
"pyyaml>=6.0.1",
|
||||||
|
"gitpython>=3.1.40",
|
||||||
|
"openai>=1.3.0",
|
||||||
|
"anthropic>=0.18.0",
|
||||||
|
"rich>=13.6.0",
|
||||||
|
"jinja2>=3.1.2",
|
||||||
|
"pydantic>=2.5.0",
|
||||||
|
"requests>=2.31.0",
|
||||||
|
]
|
||||||
|
|
||||||
|
[project.optional-dependencies]
|
||||||
|
dev = [
|
||||||
|
"pytest>=7.4.0",
|
||||||
|
"pytest-cov>=4.1.0",
|
||||||
|
"black>=23.0.0",
|
||||||
|
"ruff>=0.1.0",
|
||||||
|
]
|
||||||
|
|
||||||
|
[project.scripts]
|
||||||
|
pf = "promptforge.cli.main:main"
|
||||||
|
|
||||||
|
[tool.pytest.ini_options]
|
||||||
|
testpaths = ["tests"]
|
||||||
|
python_files = ["test_*.py"]
|
||||||
|
python_functions = ["test_*"]
|
||||||
|
addopts = "-v --tb=short"
|
||||||
|
|
||||||
|
[tool.black]
|
||||||
|
line-length = 100
|
||||||
|
target-version = ['py310']
|
||||||
|
|
||||||
|
[tool.ruff]
|
||||||
|
line-length = 100
|
||||||
|
target-version = "py310"
|
||||||
|
|
||||||
|
[tool.setuptools.packages.find]
|
||||||
|
where = ["."]
|
||||||
|
|||||||
@@ -1 +1,11 @@
|
|||||||
/app/requirements.txt
|
pyyaml>=6.0.1
|
||||||
|
click>=8.1.0
|
||||||
|
pytest>=7.4.0
|
||||||
|
pytest-cov>=4.1.0
|
||||||
|
gitpython>=3.1.40
|
||||||
|
openai>=1.3.0
|
||||||
|
anthropic>=0.18.0
|
||||||
|
rich>=13.6.0
|
||||||
|
jinja2>=3.1.2
|
||||||
|
pydantic>=2.5.0
|
||||||
|
requests>=2.31.0
|
||||||
|
|||||||
@@ -1 +1,7 @@
|
|||||||
/app/setup.cfg
|
[flake8]
|
||||||
|
max-line-length = 100
|
||||||
|
exclude = .git,__pycache__,build,dist
|
||||||
|
|
||||||
|
[isort]
|
||||||
|
profile = black
|
||||||
|
line_length = 100
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
import os
|
"""PromptForge - A CLI tool for versioning, testing, and sharing AI prompts."""
|
||||||
import sys
|
|
||||||
|
|
||||||
__version__ = "0.1.0"
|
__version__ = "0.1.0"
|
||||||
@@ -0,0 +1,5 @@
|
|||||||
|
"""PromptForge CLI interface."""
|
||||||
|
|
||||||
|
from .main import main
|
||||||
|
|
||||||
|
__all__ = ["main"]
|
||||||
|
|||||||
@@ -0,0 +1,10 @@
|
|||||||
|
"""CLI command imports."""
|
||||||
|
|
||||||
|
from .init import init
|
||||||
|
from .prompt import prompt
|
||||||
|
from .run import run
|
||||||
|
from .test import test
|
||||||
|
from .registry import registry
|
||||||
|
from .version import version
|
||||||
|
|
||||||
|
__all__ = ["init", "prompt", "run", "test", "registry", "version"]
|
||||||
|
|||||||
@@ -1,5 +1,8 @@
|
|||||||
|
"""Init command for initializing prompt repositories."""
|
||||||
|
|
||||||
import click
|
import click
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from promptforge.core.git_manager import GitManager
|
from promptforge.core.git_manager import GitManager
|
||||||
|
|
||||||
|
|
||||||
@@ -8,7 +11,10 @@ from promptforge.core.git_manager import GitManager
|
|||||||
@click.option("--force", is_flag=True, help="Force reinitialization")
|
@click.option("--force", is_flag=True, help="Force reinitialization")
|
||||||
@click.pass_obj
|
@click.pass_obj
|
||||||
def init(ctx, directory: str, force: bool):
|
def init(ctx, directory: str, force: bool):
|
||||||
"""Initialize a new PromptForge repository."""
|
"""Initialize a new PromptForge repository.
|
||||||
|
|
||||||
|
Creates a prompts directory with git version control.
|
||||||
|
"""
|
||||||
prompts_dir = Path(directory) / "prompts"
|
prompts_dir = Path(directory) / "prompts"
|
||||||
git_manager = GitManager(prompts_dir)
|
git_manager = GitManager(prompts_dir)
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import sys
|
"""Prompt management commands."""
|
||||||
|
|
||||||
import click
|
import click
|
||||||
from pathlib import Path
|
|
||||||
from datetime import datetime
|
|
||||||
from promptforge.core.prompt import Prompt, PromptVariable, VariableType
|
from promptforge.core.prompt import Prompt, PromptVariable, VariableType
|
||||||
from promptforge.core.template import TemplateEngine
|
from promptforge.core.template import TemplateEngine
|
||||||
from promptforge.core.git_manager import GitManager
|
from promptforge.core.git_manager import GitManager
|
||||||
@@ -109,7 +109,10 @@ def show(ctx, name: str):
|
|||||||
click.echo(f"Provider: {prompt.provider}")
|
click.echo(f"Provider: {prompt.provider}")
|
||||||
if prompt.tags:
|
if prompt.tags:
|
||||||
click.echo(f"Tags: {', '.join(prompt.tags)}")
|
click.echo(f"Tags: {', '.join(prompt.tags)}")
|
||||||
click.echo(f"\n--- Content ---")
|
click.echo(f"Created: {prompt.created_at.isoformat()}")
|
||||||
|
click.echo(f"Updated: {prompt.updated_at.isoformat()}")
|
||||||
|
click.echo("")
|
||||||
|
click.echo("--- Content ---")
|
||||||
click.echo(prompt.content)
|
click.echo(prompt.content)
|
||||||
|
|
||||||
|
|
||||||
@@ -135,5 +138,5 @@ def delete(ctx, name: str, yes: bool):
|
|||||||
filepath.unlink()
|
filepath.unlink()
|
||||||
click.echo(f"Deleted prompt: {name}")
|
click.echo(f"Deleted prompt: {name}")
|
||||||
else:
|
else:
|
||||||
click.echo(f"Prompt file not found", err=True)
|
click.echo("Prompt file not found", err=True)
|
||||||
raise click.Abort()
|
raise click.Abort()
|
||||||
@@ -1,4 +1,7 @@
|
|||||||
|
"""Registry commands for sharing prompts."""
|
||||||
|
|
||||||
import click
|
import click
|
||||||
|
|
||||||
from promptforge.registry import LocalRegistry, RemoteRegistry, RegistryEntry
|
from promptforge.registry import LocalRegistry, RemoteRegistry, RegistryEntry
|
||||||
from promptforge.core.prompt import Prompt
|
from promptforge.core.prompt import Prompt
|
||||||
|
|
||||||
@@ -51,8 +54,9 @@ def registry_add(ctx, prompt_name: str, author: str):
|
|||||||
|
|
||||||
@registry.command("search")
|
@registry.command("search")
|
||||||
@click.argument("query")
|
@click.argument("query")
|
||||||
|
@click.option("--limit", default=20, help="Maximum results")
|
||||||
@click.pass_obj
|
@click.pass_obj
|
||||||
def registry_search(ctx, query: str):
|
def registry_search(ctx, query: str, limit: int):
|
||||||
"""Search local registry."""
|
"""Search local registry."""
|
||||||
registry = LocalRegistry()
|
registry = LocalRegistry()
|
||||||
results = registry.search(query)
|
results = registry.search(query)
|
||||||
@@ -61,7 +65,7 @@ def registry_search(ctx, query: str):
|
|||||||
click.echo("No results found")
|
click.echo("No results found")
|
||||||
return
|
return
|
||||||
|
|
||||||
for result in results:
|
for result in results[:limit]:
|
||||||
entry = result.entry
|
entry = result.entry
|
||||||
click.echo(f"{entry.name} (score: {result.relevance_score})")
|
click.echo(f"{entry.name} (score: {result.relevance_score})")
|
||||||
if entry.description:
|
if entry.description:
|
||||||
@@ -79,7 +83,7 @@ def registry_pull(ctx, entry_id: str):
|
|||||||
if remote.pull(entry_id, local):
|
if remote.pull(entry_id, local):
|
||||||
click.echo(f"Pulled entry {entry_id}")
|
click.echo(f"Pulled entry {entry_id}")
|
||||||
else:
|
else:
|
||||||
click.echo(f"Entry not found", err=True)
|
click.echo("Entry not found", err=True)
|
||||||
raise click.Abort()
|
raise click.Abort()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,11 +1,13 @@
|
|||||||
|
"""Run command for executing prompts."""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from typing import Any, Dict
|
||||||
import click
|
import click
|
||||||
from pathlib import Path
|
|
||||||
from promptforge.core.prompt import Prompt
|
from promptforge.core.prompt import Prompt
|
||||||
from promptforge.core.template import TemplateEngine
|
from promptforge.core.template import TemplateEngine
|
||||||
from promptforge.core.config import get_config
|
from promptforge.core.config import get_config
|
||||||
from promptforge.providers import ProviderFactory
|
from promptforge.providers import ProviderFactory
|
||||||
from promptforge.testing.validator import Validator
|
|
||||||
|
|
||||||
|
|
||||||
@click.command()
|
@click.command()
|
||||||
@@ -33,7 +35,11 @@ def run(ctx, name: str, provider: str, var: tuple, output: str, stream: bool):
|
|||||||
|
|
||||||
template_engine = TemplateEngine()
|
template_engine = TemplateEngine()
|
||||||
try:
|
try:
|
||||||
rendered = template_engine.render(prompt.content, variables, prompt.variables)
|
rendered = template_engine.render(
|
||||||
|
prompt.content,
|
||||||
|
variables,
|
||||||
|
prompt.variables,
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
click.echo(f"Template error: {e}", err=True)
|
click.echo(f"Template error: {e}", err=True)
|
||||||
raise click.Abort()
|
raise click.Abort()
|
||||||
@@ -42,10 +48,11 @@ def run(ctx, name: str, provider: str, var: tuple, output: str, stream: bool):
|
|||||||
selected_provider = provider or prompt.provider or config.defaults.provider
|
selected_provider = provider or prompt.provider or config.defaults.provider
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
provider_config: Dict[str, Any] = dict(config.providers.get(selected_provider, {}))
|
||||||
provider_instance = ProviderFactory.create(
|
provider_instance = ProviderFactory.create(
|
||||||
selected_provider,
|
selected_provider,
|
||||||
model=config.providers.get(selected_provider, {}).model if selected_provider in config.providers else None,
|
model=provider_config.get("model") if isinstance(provider_config, dict) else None,
|
||||||
temperature=config.providers.get(selected_provider, {}).temperature if selected_provider in config.providers else 0.7,
|
temperature=provider_config.get("temperature", 0.7) if isinstance(provider_config, dict) else 0.7,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
click.echo(f"Provider error: {e}", err=True)
|
click.echo(f"Provider error: {e}", err=True)
|
||||||
@@ -67,18 +74,26 @@ def run(ctx, name: str, provider: str, var: tuple, output: str, stream: bool):
|
|||||||
import json
|
import json
|
||||||
click.echo("\n" + json.dumps({"response": response}, indent=2))
|
click.echo("\n" + json.dumps({"response": response}, indent=2))
|
||||||
|
|
||||||
|
if prompt.validation_rules:
|
||||||
|
validate_response(prompt, response)
|
||||||
|
|
||||||
|
try:
|
||||||
asyncio.run(execute())
|
asyncio.run(execute())
|
||||||
|
except Exception as e:
|
||||||
|
click.echo(f"Execution error: {e}", err=True)
|
||||||
|
raise click.Abort()
|
||||||
|
|
||||||
|
|
||||||
def validate_response(prompt: Prompt, response: str):
|
def validate_response(prompt: Prompt, response: str):
|
||||||
|
"""Validate response against rules."""
|
||||||
for rule in prompt.validation_rules:
|
for rule in prompt.validation_rules:
|
||||||
if rule.type == "regex":
|
if rule.type == "regex":
|
||||||
import re
|
import re
|
||||||
if not re.search(rule.pattern or "", response):
|
if not re.search(rule.pattern or "", response):
|
||||||
click.echo(f"Warning: Response failed regex validation", err=True)
|
click.echo("Warning: Response failed regex validation", err=True)
|
||||||
elif rule.type == "json":
|
elif rule.type == "json":
|
||||||
try:
|
try:
|
||||||
import json
|
import json
|
||||||
json.loads(response)
|
json.loads(response)
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
click.echo(f"Warning: Response is not valid JSON", err=True)
|
click.echo("Warning: Response is not valid JSON", err=True)
|
||||||
|
|||||||
@@ -1,5 +1,9 @@
|
|||||||
|
"""Test command for A/B testing prompts."""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from typing import Any, Dict
|
||||||
import click
|
import click
|
||||||
|
|
||||||
from promptforge.core.prompt import Prompt
|
from promptforge.core.prompt import Prompt
|
||||||
from promptforge.core.config import get_config
|
from promptforge.core.config import get_config
|
||||||
from promptforge.providers import ProviderFactory
|
from promptforge.providers import ProviderFactory
|
||||||
@@ -30,16 +34,21 @@ def test(ctx, prompt_names: tuple, provider: str, iterations: int, output: str,
|
|||||||
selected_provider = provider or config.defaults.provider
|
selected_provider = provider or config.defaults.provider
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
provider_config: Dict[str, Any] = dict(config.providers.get(selected_provider, {}))
|
||||||
provider_instance = ProviderFactory.create(
|
provider_instance = ProviderFactory.create(
|
||||||
selected_provider,
|
selected_provider,
|
||||||
model=config.providers.get(selected_provider, {}).model if selected_provider in config.providers else None,
|
model=provider_config.get("model") if isinstance(provider_config, dict) else None,
|
||||||
temperature=config.providers.get(selected_provider, {}).temperature if selected_provider in config.providers else 0.7,
|
temperature=provider_config.get("temperature", 0.7) if isinstance(provider_config, dict) else 0.7,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
click.echo(f"Provider error: {e}", err=True)
|
click.echo(f"Provider error: {e}", err=True)
|
||||||
raise click.Abort()
|
raise click.Abort()
|
||||||
|
|
||||||
test_config = ABTestConfig(iterations=iterations, parallel=parallel)
|
test_config = ABTestConfig(
|
||||||
|
iterations=iterations,
|
||||||
|
parallel=parallel,
|
||||||
|
)
|
||||||
|
|
||||||
ab_test = ABTest(provider_instance, test_config)
|
ab_test = ABTest(provider_instance, test_config)
|
||||||
|
|
||||||
async def run_tests():
|
async def run_tests():
|
||||||
@@ -57,6 +66,7 @@ def test(ctx, prompt_names: tuple, provider: str, iterations: int, output: str,
|
|||||||
click.echo(f"Successful: {summary.successful_runs}/{summary.total_runs}")
|
click.echo(f"Successful: {summary.successful_runs}/{summary.total_runs}")
|
||||||
click.echo(f"Avg Latency: {summary.avg_latency_ms:.2f}ms")
|
click.echo(f"Avg Latency: {summary.avg_latency_ms:.2f}ms")
|
||||||
click.echo(f"Avg Tokens: {summary.avg_tokens:.0f}")
|
click.echo(f"Avg Tokens: {summary.avg_tokens:.0f}")
|
||||||
|
click.echo(f"Avg Cost: ${summary.avg_cost:.4f}")
|
||||||
|
|
||||||
if output == "json":
|
if output == "json":
|
||||||
import json
|
import json
|
||||||
@@ -66,6 +76,7 @@ def test(ctx, prompt_names: tuple, provider: str, iterations: int, output: str,
|
|||||||
"total_runs": s.total_runs,
|
"total_runs": s.total_runs,
|
||||||
"avg_latency_ms": s.avg_latency_ms,
|
"avg_latency_ms": s.avg_latency_ms,
|
||||||
"avg_tokens": s.avg_tokens,
|
"avg_tokens": s.avg_tokens,
|
||||||
|
"avg_cost": s.avg_cost,
|
||||||
}
|
}
|
||||||
for name, s in results.items()
|
for name, s in results.items()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,4 +1,7 @@
|
|||||||
|
"""Version control commands."""
|
||||||
|
|
||||||
import click
|
import click
|
||||||
|
|
||||||
from promptforge.core.git_manager import GitManager
|
from promptforge.core.git_manager import GitManager
|
||||||
|
|
||||||
|
|
||||||
@@ -32,7 +35,13 @@ def history(ctx, prompt_name: str):
|
|||||||
click.echo(f" {commit['date']} by {commit['author']}")
|
click.echo(f" {commit['date']} by {commit['author']}")
|
||||||
else:
|
else:
|
||||||
for commit in commits:
|
for commit in commits:
|
||||||
click.echo(f"{commit.hexsha[:7]} - {commit.message.strip()}")
|
hexsha = commit.hexsha
|
||||||
|
if isinstance(hexsha, bytes):
|
||||||
|
hexsha = hexsha.decode('utf-8')
|
||||||
|
message = commit.message
|
||||||
|
if isinstance(message, bytes):
|
||||||
|
message = message.decode('utf-8')
|
||||||
|
click.echo(f"{hexsha[:7]} - {message.strip()}")
|
||||||
click.echo(f" {commit.author.name} - {commit.committed_datetime.isoformat()}")
|
click.echo(f" {commit.author.name} - {commit.committed_datetime.isoformat()}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
click.echo(f"Error: {e}", err=True)
|
click.echo(f"Error: {e}", err=True)
|
||||||
@@ -75,6 +84,22 @@ def branch(ctx, branch_name: str):
|
|||||||
raise click.Abort()
|
raise click.Abort()
|
||||||
|
|
||||||
|
|
||||||
|
@version.command("switch")
|
||||||
|
@click.argument("branch_name")
|
||||||
|
@click.pass_obj
|
||||||
|
def switch(ctx, branch_name: str):
|
||||||
|
"""Switch to a branch."""
|
||||||
|
prompts_dir = ctx["prompts_dir"]
|
||||||
|
git_manager = GitManager(prompts_dir)
|
||||||
|
|
||||||
|
try:
|
||||||
|
git_manager.switch_branch(branch_name)
|
||||||
|
click.echo(f"Switched to branch: {branch_name}")
|
||||||
|
except Exception as e:
|
||||||
|
click.echo(f"Error: {e}", err=True)
|
||||||
|
raise click.Abort()
|
||||||
|
|
||||||
|
|
||||||
@version.command("list")
|
@version.command("list")
|
||||||
@click.pass_obj
|
@click.pass_obj
|
||||||
def list_branches(ctx):
|
def list_branches(ctx):
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import sys
|
"""Main CLI entry point."""
|
||||||
|
|
||||||
import click
|
import click
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,28 @@
|
|||||||
|
"""Core prompt management modules."""
|
||||||
|
|
||||||
|
from .prompt import Prompt, PromptVariable
|
||||||
|
from .template import TemplateEngine
|
||||||
|
from .config import Config
|
||||||
|
from .git_manager import GitManager
|
||||||
|
from .exceptions import (
|
||||||
|
InvalidPromptError,
|
||||||
|
ProviderError,
|
||||||
|
ValidationError,
|
||||||
|
GitError,
|
||||||
|
RegistryError,
|
||||||
|
MissingVariableError,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"Prompt",
|
||||||
|
"PromptVariable",
|
||||||
|
"TemplateEngine",
|
||||||
|
"Config",
|
||||||
|
"GitManager",
|
||||||
|
"InvalidPromptError",
|
||||||
|
"ProviderError",
|
||||||
|
"ValidationError",
|
||||||
|
"GitError",
|
||||||
|
"RegistryError",
|
||||||
|
"MissingVariableError",
|
||||||
|
]
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
"""Configuration management for PromptForge."""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
@@ -8,6 +10,8 @@ from pydantic import BaseModel, Field
|
|||||||
|
|
||||||
|
|
||||||
class ProviderConfig(BaseModel):
|
class ProviderConfig(BaseModel):
|
||||||
|
"""Configuration for an LLM provider."""
|
||||||
|
|
||||||
api_key: Optional[str] = None
|
api_key: Optional[str] = None
|
||||||
model: str = "gpt-4"
|
model: str = "gpt-4"
|
||||||
temperature: float = 0.7
|
temperature: float = 0.7
|
||||||
@@ -15,21 +19,29 @@ class ProviderConfig(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class RegistryConfig(BaseModel):
|
class RegistryConfig(BaseModel):
|
||||||
|
"""Configuration for the prompt registry."""
|
||||||
|
|
||||||
local_path: str = "~/.promptforge/registry"
|
local_path: str = "~/.promptforge/registry"
|
||||||
remote_url: str = "https://registry.promptforge.io"
|
remote_url: str = "https://registry.promptforge.io"
|
||||||
|
|
||||||
|
|
||||||
class DefaultsConfig(BaseModel):
|
class DefaultsConfig(BaseModel):
|
||||||
|
"""Default settings for PromptForge."""
|
||||||
|
|
||||||
provider: str = "openai"
|
provider: str = "openai"
|
||||||
output_format: str = "text"
|
output_format: str = "text"
|
||||||
|
|
||||||
|
|
||||||
class ValidationConfig(BaseModel):
|
class ValidationConfig(BaseModel):
|
||||||
|
"""Validation settings."""
|
||||||
|
|
||||||
strict_mode: bool = False
|
strict_mode: bool = False
|
||||||
max_retries: int = 3
|
max_retries: int = 3
|
||||||
|
|
||||||
|
|
||||||
class Config(BaseModel):
|
class Config(BaseModel):
|
||||||
|
"""Main configuration for PromptForge."""
|
||||||
|
|
||||||
providers: Dict[str, ProviderConfig] = Field(default_factory=dict)
|
providers: Dict[str, ProviderConfig] = Field(default_factory=dict)
|
||||||
registry: RegistryConfig = Field(default_factory=RegistryConfig)
|
registry: RegistryConfig = Field(default_factory=RegistryConfig)
|
||||||
defaults: DefaultsConfig = Field(default_factory=DefaultsConfig)
|
defaults: DefaultsConfig = Field(default_factory=DefaultsConfig)
|
||||||
@@ -37,6 +49,7 @@ class Config(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
def _expand_env_vars(value: Any) -> Any:
|
def _expand_env_vars(value: Any) -> Any:
|
||||||
|
"""Expand environment variables in a value."""
|
||||||
if isinstance(value, str):
|
if isinstance(value, str):
|
||||||
if value.startswith("${") and value.endswith("}"):
|
if value.startswith("${") and value.endswith("}"):
|
||||||
env_var = value[2:-1]
|
env_var = value[2:-1]
|
||||||
@@ -45,6 +58,7 @@ def _expand_env_vars(value: Any) -> Any:
|
|||||||
|
|
||||||
|
|
||||||
def _process_config(config_dict: Dict[str, Any]) -> Dict[str, Any]:
|
def _process_config(config_dict: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""Process configuration dictionary, expanding environment variables."""
|
||||||
processed = {}
|
processed = {}
|
||||||
for key, value in config_dict.items():
|
for key, value in config_dict.items():
|
||||||
if isinstance(value, dict):
|
if isinstance(value, dict):
|
||||||
@@ -55,6 +69,14 @@ def _process_config(config_dict: Dict[str, Any]) -> Dict[str, Any]:
|
|||||||
|
|
||||||
|
|
||||||
def load_config(config_path: Optional[Path] = None) -> Config:
|
def load_config(config_path: Optional[Path] = None) -> Config:
|
||||||
|
"""Load configuration from file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config_path: Path to configuration file. If None, looks in standard locations.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Config object with all settings.
|
||||||
|
"""
|
||||||
if config_path is None:
|
if config_path is None:
|
||||||
config_path = Path.cwd() / "configs" / "promptforge.yaml"
|
config_path = Path.cwd() / "configs" / "promptforge.yaml"
|
||||||
|
|
||||||
@@ -70,4 +92,5 @@ def load_config(config_path: Optional[Path] = None) -> Config:
|
|||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
def get_config() -> Config:
|
def get_config() -> Config:
|
||||||
|
"""Get cached configuration."""
|
||||||
return load_config()
|
return load_config()
|
||||||
@@ -1,38 +1,49 @@
|
|||||||
|
"""Custom exceptions for PromptForge."""
|
||||||
|
|
||||||
|
|
||||||
class PromptForgeError(Exception):
|
class PromptForgeError(Exception):
|
||||||
"""Base exception for PromptForge errors."""
|
"""Base exception for PromptForge errors."""
|
||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class InvalidPromptError(PromptForgeError):
|
class InvalidPromptError(PromptForgeError):
|
||||||
"""Raised when a prompt YAML is malformed."""
|
"""Raised when a prompt YAML is malformed."""
|
||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class ProviderError(PromptForgeError):
|
class ProviderError(PromptForgeError):
|
||||||
"""Raised when LLM API operations fail."""
|
"""Raised when LLM API operations fail."""
|
||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class ValidationError(PromptForgeError):
|
class ValidationError(PromptForgeError):
|
||||||
"""Raised when response validation fails."""
|
"""Raised when response validation fails."""
|
||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class GitError(PromptForgeError):
|
class GitError(PromptForgeError):
|
||||||
"""Raised when git operations fail."""
|
"""Raised when git operations fail."""
|
||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class RegistryError(PromptForgeError):
|
class RegistryError(PromptForgeError):
|
||||||
"""Raised when registry operations fail."""
|
"""Raised when registry operations fail."""
|
||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class MissingVariableError(PromptForgeError):
|
class MissingVariableError(PromptForgeError):
|
||||||
"""Raised when a required template variable is missing."""
|
"""Raised when a required template variable is missing."""
|
||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class ConfigurationError(PromptForgeError):
|
class ConfigurationError(PromptForgeError):
|
||||||
"""Raised when configuration is invalid."""
|
"""Raised when configuration is invalid."""
|
||||||
|
|
||||||
pass
|
pass
|
||||||
@@ -1,5 +1,7 @@
|
|||||||
|
"""Git integration for prompt version control."""
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Optional
|
from typing import Any, List, Optional
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from git import Repo, Commit, GitCommandError
|
from git import Repo, Commit, GitCommandError
|
||||||
@@ -8,11 +10,23 @@ from .exceptions import GitError
|
|||||||
|
|
||||||
|
|
||||||
class GitManager:
|
class GitManager:
|
||||||
|
"""Manage git operations for prompt directories."""
|
||||||
|
|
||||||
def __init__(self, prompts_dir: Path):
|
def __init__(self, prompts_dir: Path):
|
||||||
|
"""Initialize git manager for a prompts directory.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompts_dir: Path to the prompts directory.
|
||||||
|
"""
|
||||||
self.prompts_dir = Path(prompts_dir)
|
self.prompts_dir = Path(prompts_dir)
|
||||||
self.repo: Optional[Repo] = None
|
self.repo: Optional[Repo] = None
|
||||||
|
|
||||||
def init(self) -> bool:
|
def init(self) -> bool:
|
||||||
|
"""Initialize a git repository in the prompts directory.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if repository was created, False if it already exists.
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
if not self.prompts_dir.exists():
|
if not self.prompts_dir.exists():
|
||||||
self.prompts_dir.mkdir(parents=True, exist_ok=True)
|
self.prompts_dir.mkdir(parents=True, exist_ok=True)
|
||||||
@@ -28,6 +42,7 @@ class GitManager:
|
|||||||
raise GitError(f"Failed to initialize git repository: {e}")
|
raise GitError(f"Failed to initialize git repository: {e}")
|
||||||
|
|
||||||
def _is_git_repo(self) -> bool:
|
def _is_git_repo(self) -> bool:
|
||||||
|
"""Check if prompts_dir is a git repository."""
|
||||||
try:
|
try:
|
||||||
self.repo = Repo(str(self.prompts_dir))
|
self.repo = Repo(str(self.prompts_dir))
|
||||||
return not self.repo.bare
|
return not self.repo.bare
|
||||||
@@ -35,53 +50,140 @@ class GitManager:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
def _configure_gitignore(self) -> None:
|
def _configure_gitignore(self) -> None:
|
||||||
|
"""Create .gitignore for prompts directory."""
|
||||||
gitignore_path = self.prompts_dir / ".gitignore"
|
gitignore_path = self.prompts_dir / ".gitignore"
|
||||||
if not gitignore_path.exists():
|
if not gitignore_path.exists():
|
||||||
gitignore_path.write_text("*.lock\n.temp*\n")
|
gitignore_path.write_text("*.lock\\.temp*\\n")
|
||||||
|
|
||||||
def commit(self, message: str, author: Optional[str] = None) -> Commit:
|
def commit(self, message: str, author: Optional[str] = None) -> Commit:
|
||||||
|
"""Commit all changes to prompts.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: Commit message.
|
||||||
|
author: Author string (e.g., "Name <email>").
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Created commit object.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
GitError: If commit fails.
|
||||||
|
"""
|
||||||
self._ensure_repo()
|
self._ensure_repo()
|
||||||
|
assert self.repo is not None
|
||||||
try:
|
try:
|
||||||
self.repo.index.add(["*"])
|
self.repo.index.add(["*"])
|
||||||
return self.repo.index.commit(message, author=author)
|
author_arg: Any = author # type: ignore[assignment]
|
||||||
|
return self.repo.index.commit(message, author=author_arg)
|
||||||
except GitCommandError as e:
|
except GitCommandError as e:
|
||||||
raise GitError(f"Failed to commit changes: {e}")
|
raise GitError(f"Failed to commit changes: {e}")
|
||||||
|
|
||||||
def log(self, max_count: int = 20) -> List[Commit]:
|
def log(self, max_count: int = 20) -> List[Commit]:
|
||||||
|
"""Get commit history.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_count: Maximum number of commits to return.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of commit objects.
|
||||||
|
"""
|
||||||
self._ensure_repo()
|
self._ensure_repo()
|
||||||
|
assert self.repo is not None
|
||||||
try:
|
try:
|
||||||
return list(self.repo.iter_commits(max_count=max_count))
|
return list(self.repo.iter_commits(max_count=max_count))
|
||||||
except GitCommandError as e:
|
except GitCommandError as e:
|
||||||
raise GitError(f"Failed to get commit log: {e}")
|
raise GitError(f"Failed to get commit log: {e}")
|
||||||
|
|
||||||
def create_branch(self, branch_name: str) -> None:
|
def show_commit(self, commit_sha: str) -> str:
|
||||||
|
"""Show content of a specific commit.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
commit_sha: SHA of the commit.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Commit diff as string.
|
||||||
|
"""
|
||||||
self._ensure_repo()
|
self._ensure_repo()
|
||||||
|
assert self.repo is not None
|
||||||
|
try:
|
||||||
|
commit = self.repo.commit(commit_sha)
|
||||||
|
diff_result = commit.diff("HEAD~1" if commit_sha == "HEAD" else f"{commit_sha}^")
|
||||||
|
return str(diff_result) if diff_result else ""
|
||||||
|
except Exception as e:
|
||||||
|
raise GitError(f"Failed to show commit: {e}")
|
||||||
|
|
||||||
|
def create_branch(self, branch_name: str) -> None:
|
||||||
|
"""Create a new branch for prompt variations.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
branch_name: Name of the new branch.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
GitError: If branch creation fails.
|
||||||
|
"""
|
||||||
|
self._ensure_repo()
|
||||||
|
assert self.repo is not None
|
||||||
try:
|
try:
|
||||||
self.repo.create_head(branch_name)
|
self.repo.create_head(branch_name)
|
||||||
except GitCommandError as e:
|
except GitCommandError as e:
|
||||||
raise GitError(f"Failed to create branch: {e}")
|
raise GitError(f"Failed to create branch: {e}")
|
||||||
|
|
||||||
def switch_branch(self, branch_name: str) -> None:
|
def switch_branch(self, branch_name: str) -> None:
|
||||||
|
"""Switch to a different branch.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
branch_name: Name of the branch to switch to.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
GitError: If branch switch fails.
|
||||||
|
"""
|
||||||
self._ensure_repo()
|
self._ensure_repo()
|
||||||
|
assert self.repo is not None
|
||||||
try:
|
try:
|
||||||
self.repo.heads[branch_name].checkout()
|
self.repo.heads[branch_name].checkout()
|
||||||
except GitCommandError as e:
|
except GitCommandError as e:
|
||||||
raise GitError(f"Failed to switch branch: {e}")
|
raise GitError(f"Failed to switch branch: {e}")
|
||||||
|
|
||||||
def list_branches(self) -> List[str]:
|
def list_branches(self) -> List[str]:
|
||||||
|
"""List all branches.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of branch names.
|
||||||
|
"""
|
||||||
self._ensure_repo()
|
self._ensure_repo()
|
||||||
|
assert self.repo is not None
|
||||||
return [head.name for head in self.repo.heads]
|
return [head.name for head in self.repo.heads]
|
||||||
|
|
||||||
def status(self) -> str:
|
def status(self) -> str:
|
||||||
|
"""Get git status.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Status string.
|
||||||
|
"""
|
||||||
self._ensure_repo()
|
self._ensure_repo()
|
||||||
|
assert self.repo is not None
|
||||||
return self.repo.git.status()
|
return self.repo.git.status()
|
||||||
|
|
||||||
def diff(self) -> str:
|
def diff(self) -> str:
|
||||||
|
"""Show uncommitted changes.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Diff as string.
|
||||||
|
"""
|
||||||
self._ensure_repo()
|
self._ensure_repo()
|
||||||
|
assert self.repo is not None
|
||||||
return self.repo.git.diff()
|
return self.repo.git.diff()
|
||||||
|
|
||||||
def get_file_history(self, filename: str) -> List[dict]:
|
def get_file_history(self, filename: str) -> List[dict]:
|
||||||
|
"""Get commit history for a specific file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filename: Name of the file.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of commit info dictionaries.
|
||||||
|
"""
|
||||||
self._ensure_repo()
|
self._ensure_repo()
|
||||||
|
assert self.repo is not None
|
||||||
commits = []
|
commits = []
|
||||||
try:
|
try:
|
||||||
for commit in self.repo.iter_commits("--all", filename):
|
for commit in self.repo.iter_commits("--all", filename):
|
||||||
@@ -96,6 +198,7 @@ class GitManager:
|
|||||||
return commits
|
return commits
|
||||||
|
|
||||||
def _ensure_repo(self) -> None:
|
def _ensure_repo(self) -> None:
|
||||||
|
"""Ensure repository is initialized."""
|
||||||
if self.repo is None:
|
if self.repo is None:
|
||||||
if not self._is_git_repo():
|
if not self._is_git_repo():
|
||||||
raise GitError("Git repository not initialized. Run 'pf init' first.")
|
raise GitError("Git repository not initialized. Run 'pf init' first.")
|
||||||
@@ -1,3 +1,5 @@
|
|||||||
|
"""Prompt model and management."""
|
||||||
|
|
||||||
import hashlib
|
import hashlib
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
@@ -10,6 +12,8 @@ from pydantic import BaseModel, Field, field_validator
|
|||||||
|
|
||||||
|
|
||||||
class VariableType(str, Enum):
|
class VariableType(str, Enum):
|
||||||
|
"""Supported variable types."""
|
||||||
|
|
||||||
STRING = "string"
|
STRING = "string"
|
||||||
INTEGER = "integer"
|
INTEGER = "integer"
|
||||||
FLOAT = "float"
|
FLOAT = "float"
|
||||||
@@ -18,6 +22,8 @@ class VariableType(str, Enum):
|
|||||||
|
|
||||||
|
|
||||||
class PromptVariable(BaseModel):
|
class PromptVariable(BaseModel):
|
||||||
|
"""Definition of a template variable."""
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
type: VariableType = VariableType.STRING
|
type: VariableType = VariableType.STRING
|
||||||
description: Optional[str] = None
|
description: Optional[str] = None
|
||||||
@@ -25,8 +31,17 @@ class PromptVariable(BaseModel):
|
|||||||
default: Optional[Any] = None
|
default: Optional[Any] = None
|
||||||
choices: Optional[List[str]] = None
|
choices: Optional[List[str]] = None
|
||||||
|
|
||||||
|
@field_validator('choices')
|
||||||
|
@classmethod
|
||||||
|
def validate_choices(cls, v, info):
|
||||||
|
if v is not None and info.data.get('type') != VariableType.CHOICE:
|
||||||
|
raise ValueError("choices only valid for CHOICE type")
|
||||||
|
return v
|
||||||
|
|
||||||
|
|
||||||
class ValidationRule(BaseModel):
|
class ValidationRule(BaseModel):
|
||||||
|
"""Validation rule for prompt output."""
|
||||||
|
|
||||||
type: str
|
type: str
|
||||||
pattern: Optional[str] = None
|
pattern: Optional[str] = None
|
||||||
json_schema: Optional[Dict[str, Any]] = None
|
json_schema: Optional[Dict[str, Any]] = None
|
||||||
@@ -34,6 +49,8 @@ class ValidationRule(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class Prompt(BaseModel):
|
class Prompt(BaseModel):
|
||||||
|
"""Prompt model with metadata and template."""
|
||||||
|
|
||||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
||||||
name: str
|
name: str
|
||||||
description: Optional[str] = None
|
description: Optional[str] = None
|
||||||
@@ -55,12 +72,20 @@ class Prompt(BaseModel):
|
|||||||
return hashlib.md5(content.encode()).hexdigest()
|
return hashlib.md5(content.encode()).hexdigest()
|
||||||
return v
|
return v
|
||||||
|
|
||||||
|
def to_dict(self, exclude_none: bool = False) -> Dict[str, Any]:
|
||||||
|
"""Export prompt to dictionary."""
|
||||||
|
data = super().model_dump(exclude_none=exclude_none)
|
||||||
|
data['created_at'] = self.created_at.isoformat()
|
||||||
|
data['updated_at'] = self.updated_at.isoformat()
|
||||||
|
return data
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_yaml(cls, yaml_content: str) -> "Prompt":
|
def from_yaml(cls, yaml_content: str) -> "Prompt":
|
||||||
|
"""Parse prompt from YAML with front matter."""
|
||||||
content = yaml_content.strip()
|
content = yaml_content.strip()
|
||||||
|
|
||||||
if not content.startswith('---'):
|
if not content.startswith('---'):
|
||||||
metadata = {}
|
metadata: Dict[str, Any] = {}
|
||||||
prompt_content = content
|
prompt_content = content
|
||||||
else:
|
else:
|
||||||
parts = content[4:].split('\n---', 1)
|
parts = content[4:].split('\n---', 1)
|
||||||
@@ -80,6 +105,7 @@ class Prompt(BaseModel):
|
|||||||
return cls(**data)
|
return cls(**data)
|
||||||
|
|
||||||
def to_yaml(self) -> str:
|
def to_yaml(self) -> str:
|
||||||
|
"""Export prompt to YAML front matter format."""
|
||||||
def var_to_dict(v):
|
def var_to_dict(v):
|
||||||
d = v.model_dump()
|
d = v.model_dump()
|
||||||
d['type'] = v.type.value
|
d['type'] = v.type.value
|
||||||
@@ -101,6 +127,14 @@ class Prompt(BaseModel):
|
|||||||
return f"---\n{yaml_str}---\n{self.content}"
|
return f"---\n{yaml_str}---\n{self.content}"
|
||||||
|
|
||||||
def save(self, prompts_dir: Path) -> Path:
|
def save(self, prompts_dir: Path) -> Path:
|
||||||
|
"""Save prompt to file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompts_dir: Directory to save prompt in.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path to saved file.
|
||||||
|
"""
|
||||||
prompts_dir.mkdir(parents=True, exist_ok=True)
|
prompts_dir.mkdir(parents=True, exist_ok=True)
|
||||||
filename = self.name.lower().replace(' ', '_').replace('/', '_') + '.yaml'
|
filename = self.name.lower().replace(' ', '_').replace('/', '_') + '.yaml'
|
||||||
filepath = prompts_dir / filename
|
filepath = prompts_dir / filename
|
||||||
@@ -110,13 +144,15 @@ class Prompt(BaseModel):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load(cls, filepath: Path) -> "Prompt":
|
def load(cls, filepath: Path) -> "Prompt":
|
||||||
|
"""Load prompt from file."""
|
||||||
with open(filepath, 'r') as f:
|
with open(filepath, 'r') as f:
|
||||||
content = f.read()
|
content = f.read()
|
||||||
return cls.from_yaml(content)
|
return cls.from_yaml(content)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def list(cls, prompts_dir: Path) -> List["Prompt"]:
|
def list(cls, prompts_dir: Path) -> List["Prompt"]:
|
||||||
prompts = []
|
"""List all prompts in directory."""
|
||||||
|
prompts: List["Prompt"] = []
|
||||||
if not prompts_dir.exists():
|
if not prompts_dir.exists():
|
||||||
return prompts
|
return prompts
|
||||||
for filepath in prompts_dir.glob('*.yaml'):
|
for filepath in prompts_dir.glob('*.yaml'):
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
"""Jinja2-based template engine for prompt rendering."""
|
||||||
|
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
from jinja2 import Environment, BaseLoader, TemplateSyntaxError, StrictUndefined, UndefinedError
|
from jinja2 import Environment, BaseLoader, TemplateSyntaxError, StrictUndefined, UndefinedError
|
||||||
from jinja2.exceptions import TemplateError
|
from jinja2.exceptions import TemplateError
|
||||||
@@ -7,6 +9,8 @@ from .exceptions import MissingVariableError, InvalidPromptError
|
|||||||
|
|
||||||
|
|
||||||
class TemplateEngine:
|
class TemplateEngine:
|
||||||
|
"""Jinja2 template engine for prompt rendering with variable substitution."""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.env = Environment(
|
self.env = Environment(
|
||||||
loader=BaseLoader(),
|
loader=BaseLoader(),
|
||||||
@@ -17,6 +21,14 @@ class TemplateEngine:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def get_variables(self, content: str) -> List[str]:
|
def get_variables(self, content: str) -> List[str]:
|
||||||
|
"""Extract variable names from template content.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: Template content with {{variable}} syntax.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of variable names found.
|
||||||
|
"""
|
||||||
from jinja2 import meta
|
from jinja2 import meta
|
||||||
ast = self.env.parse(content)
|
ast = self.env.parse(content)
|
||||||
return sorted(meta.find_undeclared_variables(ast))
|
return sorted(meta.find_undeclared_variables(ast))
|
||||||
@@ -27,6 +39,20 @@ class TemplateEngine:
|
|||||||
variables: Optional[Dict[str, Any]] = None,
|
variables: Optional[Dict[str, Any]] = None,
|
||||||
required_variables: Optional[List[PromptVariable]] = None,
|
required_variables: Optional[List[PromptVariable]] = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
|
"""Render template with variable substitution.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: Template content.
|
||||||
|
variables: Dictionary of variable values.
|
||||||
|
required_variables: List of variable definitions for validation.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Rendered content.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
MissingVariableError: If a required variable is missing.
|
||||||
|
InvalidPromptError: If template has syntax errors.
|
||||||
|
"""
|
||||||
variables = variables.copy() if variables else {}
|
variables = variables.copy() if variables else {}
|
||||||
required_variables = required_variables or []
|
required_variables = required_variables or []
|
||||||
|
|
||||||
@@ -57,3 +83,75 @@ class TemplateEngine:
|
|||||||
raise InvalidPromptError(f"Template syntax error: {e.message}")
|
raise InvalidPromptError(f"Template syntax error: {e.message}")
|
||||||
except (UndefinedError, TemplateError) as e:
|
except (UndefinedError, TemplateError) as e:
|
||||||
raise InvalidPromptError(f"Template rendering error: {e}")
|
raise InvalidPromptError(f"Template rendering error: {e}")
|
||||||
|
|
||||||
|
def validate_variables(
|
||||||
|
self,
|
||||||
|
variables: Dict[str, Any],
|
||||||
|
required_variables: List[PromptVariable],
|
||||||
|
) -> List[str]:
|
||||||
|
"""Validate variable values against definitions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
variables: Provided variable values.
|
||||||
|
required_variables: Variable definitions.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of validation error messages (empty if valid).
|
||||||
|
"""
|
||||||
|
errors = []
|
||||||
|
|
||||||
|
for var in required_variables:
|
||||||
|
if var.name not in variables:
|
||||||
|
if var.required and var.default is None:
|
||||||
|
errors.append(f"Missing required variable: {var.name}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
value = variables[var.name]
|
||||||
|
var_type = var.type.value
|
||||||
|
|
||||||
|
if var_type == "string":
|
||||||
|
if not isinstance(value, str):
|
||||||
|
errors.append(f"{var.name}: must be a string")
|
||||||
|
elif var_type == "integer":
|
||||||
|
try:
|
||||||
|
int(value)
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
errors.append(f"{var.name}: must be an integer")
|
||||||
|
elif var_type == "float":
|
||||||
|
try:
|
||||||
|
float(value)
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
errors.append(f"{var.name}: must be a number")
|
||||||
|
elif var_type == "boolean":
|
||||||
|
if isinstance(value, str):
|
||||||
|
if value.lower() not in ("true", "false", "0", "1"):
|
||||||
|
errors.append(f"{var.name}: must be a boolean")
|
||||||
|
elif var_type == "choice":
|
||||||
|
if var.choices and value not in var.choices:
|
||||||
|
errors.append(
|
||||||
|
f"{var.name}: must be one of {', '.join(var.choices)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return errors
|
||||||
|
|
||||||
|
def render_prompt(
|
||||||
|
self,
|
||||||
|
prompt_content: str,
|
||||||
|
variables: Dict[str, Any],
|
||||||
|
variable_definitions: List[PromptVariable],
|
||||||
|
) -> str:
|
||||||
|
"""Render a complete prompt with validation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt_content: Raw prompt template content.
|
||||||
|
variables: Variable values.
|
||||||
|
variable_definitions: Variable definitions from prompt.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Rendered prompt string.
|
||||||
|
"""
|
||||||
|
errors = self.validate_variables(variables, variable_definitions)
|
||||||
|
if errors:
|
||||||
|
raise MissingVariableError(f"Variable validation failed: {', '.join(errors)}")
|
||||||
|
|
||||||
|
return self.render(prompt_content, variables, variable_definitions)
|
||||||
|
|||||||
@@ -1,6 +1,16 @@
|
|||||||
|
"""LLM Provider abstraction layer."""
|
||||||
|
|
||||||
|
from .base import ProviderBase, ProviderResponse
|
||||||
from .factory import ProviderFactory
|
from .factory import ProviderFactory
|
||||||
from .openai import OpenAIProvider
|
from .openai import OpenAIProvider
|
||||||
from .anthropic import AnthropicProvider
|
from .anthropic import AnthropicProvider
|
||||||
from .ollama import OllamaProvider
|
from .ollama import OllamaProvider
|
||||||
|
|
||||||
__all__ = ["ProviderFactory", "OpenAIProvider", "AnthropicProvider", "OllamaProvider"]
|
__all__ = [
|
||||||
|
"ProviderBase",
|
||||||
|
"ProviderResponse",
|
||||||
|
"ProviderFactory",
|
||||||
|
"OpenAIProvider",
|
||||||
|
"AnthropicProvider",
|
||||||
|
"OllamaProvider",
|
||||||
|
]
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import asyncio
|
"""Anthropic provider implementation."""
|
||||||
|
|
||||||
import time
|
import time
|
||||||
from typing import Any, AsyncIterator, Dict, Optional
|
from typing import AsyncIterator, Optional
|
||||||
|
|
||||||
from anthropic import Anthropic, APIError, RateLimitError
|
from anthropic import Anthropic, APIError, RateLimitError
|
||||||
|
|
||||||
@@ -9,6 +10,8 @@ from ..core.exceptions import ProviderError
|
|||||||
|
|
||||||
|
|
||||||
class AnthropicProvider(ProviderBase):
|
class AnthropicProvider(ProviderBase):
|
||||||
|
"""Anthropic Claude models provider."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
@@ -16,6 +19,7 @@ class AnthropicProvider(ProviderBase):
|
|||||||
temperature: float = 0.7,
|
temperature: float = 0.7,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
"""Initialize Anthropic provider."""
|
||||||
super().__init__(api_key, model, temperature, **kwargs)
|
super().__init__(api_key, model, temperature, **kwargs)
|
||||||
self._client: Optional[Anthropic] = None
|
self._client: Optional[Anthropic] = None
|
||||||
|
|
||||||
@@ -24,6 +28,7 @@ class AnthropicProvider(ProviderBase):
|
|||||||
return "anthropic"
|
return "anthropic"
|
||||||
|
|
||||||
def _get_client(self) -> Anthropic:
|
def _get_client(self) -> Anthropic:
|
||||||
|
"""Get or create Anthropic client."""
|
||||||
if self._client is None:
|
if self._client is None:
|
||||||
api_key = self.api_key or self._get_api_key_from_env()
|
api_key = self.api_key or self._get_api_key_from_env()
|
||||||
if not api_key:
|
if not api_key:
|
||||||
@@ -45,25 +50,37 @@ class AnthropicProvider(ProviderBase):
|
|||||||
max_tokens: Optional[int] = None,
|
max_tokens: Optional[int] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> ProviderResponse:
|
) -> ProviderResponse:
|
||||||
|
"""Send completion request to Anthropic."""
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
client = self._get_client()
|
client = self._get_client()
|
||||||
|
|
||||||
messages = [{"role": "user", "content": prompt}]
|
if system_prompt:
|
||||||
|
system_message = system_prompt
|
||||||
|
user_message = prompt
|
||||||
|
else:
|
||||||
|
system_message = None
|
||||||
|
user_message = prompt
|
||||||
|
|
||||||
response = client.messages.create(
|
response = client.messages.create( # type: ignore[arg-type]
|
||||||
model=self.model,
|
model=self.model,
|
||||||
messages=messages,
|
|
||||||
temperature=self.temperature,
|
|
||||||
max_tokens=max_tokens or 4096,
|
max_tokens=max_tokens or 4096,
|
||||||
|
temperature=self.temperature,
|
||||||
|
system=system_message, # type: ignore[arg-type]
|
||||||
|
messages=[{"role": "user", "content": user_message}],
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
latency_ms = (time.time() - start_time) * 1000
|
latency_ms = (time.time() - start_time) * 1000
|
||||||
|
|
||||||
|
content = ""
|
||||||
|
for block in response.content:
|
||||||
|
if block.type == "text":
|
||||||
|
content += block.text
|
||||||
|
|
||||||
return ProviderResponse(
|
return ProviderResponse(
|
||||||
content=response.content[0].text,
|
content=content,
|
||||||
model=self.model,
|
model=self.model,
|
||||||
provider=self.name,
|
provider=self.name,
|
||||||
usage={
|
usage={
|
||||||
@@ -71,29 +88,39 @@ class AnthropicProvider(ProviderBase):
|
|||||||
"output_tokens": response.usage.output_tokens,
|
"output_tokens": response.usage.output_tokens,
|
||||||
},
|
},
|
||||||
latency_ms=latency_ms,
|
latency_ms=latency_ms,
|
||||||
|
metadata={
|
||||||
|
"stop_reason": response.stop_reason,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
except APIError as e:
|
except APIError as e:
|
||||||
raise ProviderError(f"Anthropic API error: {e}")
|
raise ProviderError(f"Anthropic API error: {e}")
|
||||||
except RateLimitError as e:
|
except RateLimitError as e:
|
||||||
raise ProviderError(f"Anthropic rate limit exceeded: {e}")
|
raise ProviderError(f"Anthropic rate limit exceeded: {e}")
|
||||||
|
|
||||||
async def stream_complete(
|
async def stream_complete( # type: ignore[override]
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
system_prompt: Optional[str] = None,
|
system_prompt: Optional[str] = None,
|
||||||
max_tokens: Optional[int] = None,
|
max_tokens: Optional[int] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> AsyncIterator[str]:
|
) -> AsyncIterator[str]:
|
||||||
|
"""Stream completion from Anthropic."""
|
||||||
try:
|
try:
|
||||||
client = self._get_client()
|
client = self._get_client()
|
||||||
|
|
||||||
messages = [{"role": "user", "content": prompt}]
|
if system_prompt:
|
||||||
|
system_message = system_prompt
|
||||||
|
user_message = prompt
|
||||||
|
else:
|
||||||
|
system_message = None
|
||||||
|
user_message = prompt
|
||||||
|
|
||||||
with client.messages.stream(
|
with client.messages.stream( # type: ignore[arg-type]
|
||||||
model=self.model,
|
model=self.model,
|
||||||
messages=messages,
|
|
||||||
temperature=self.temperature,
|
|
||||||
max_tokens=max_tokens or 4096,
|
max_tokens=max_tokens or 4096,
|
||||||
|
temperature=self.temperature,
|
||||||
|
system=system_message, # type: ignore[arg-type]
|
||||||
|
messages=[{"role": "user", "content": user_message}],
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) as stream:
|
) as stream:
|
||||||
for text in stream.text_stream:
|
for text in stream.text_stream:
|
||||||
@@ -102,12 +129,24 @@ class AnthropicProvider(ProviderBase):
|
|||||||
raise ProviderError(f"Anthropic API error: {e}")
|
raise ProviderError(f"Anthropic API error: {e}")
|
||||||
|
|
||||||
def validate_api_key(self) -> bool:
|
def validate_api_key(self) -> bool:
|
||||||
|
"""Validate Anthropic API key."""
|
||||||
try:
|
try:
|
||||||
import os
|
import os
|
||||||
api_key = self.api_key or os.environ.get("ANTHROPIC_API_KEY")
|
api_key = self.api_key or os.environ.get("ANTHROPIC_API_KEY")
|
||||||
if not api_key:
|
if not api_key:
|
||||||
return False
|
return False
|
||||||
client = Anthropic(api_key=api_key)
|
_ = Anthropic(api_key=api_key)
|
||||||
return True
|
return True
|
||||||
except Exception:
|
except Exception:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def list_models(self) -> list[str]:
|
||||||
|
"""List available Anthropic models."""
|
||||||
|
return [
|
||||||
|
"claude-3-opus-20240229",
|
||||||
|
"claude-3-sonnet-20240229",
|
||||||
|
"claude-3-haiku-20240307",
|
||||||
|
"claude-2.1",
|
||||||
|
"claude-2.0",
|
||||||
|
"claude-instant-1.2",
|
||||||
|
]
|
||||||
|
|||||||
@@ -1,26 +1,25 @@
|
|||||||
|
"""Base provider interface."""
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any, AsyncIterator, Dict, Optional
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any, AsyncIterator, Dict, List, Optional
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
class ProviderResponse:
|
class ProviderResponse:
|
||||||
def __init__(
|
"""Response from an LLM provider."""
|
||||||
self,
|
|
||||||
content: str,
|
content: str
|
||||||
model: str,
|
model: str
|
||||||
provider: str,
|
provider: str
|
||||||
usage: Optional[Dict[str, Any]] = None,
|
usage: Dict[str, int] = field(default_factory=dict)
|
||||||
latency_ms: float = 0.0,
|
latency_ms: float = 0.0
|
||||||
metadata: Optional[Dict[str, Any]] = None,
|
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||||
):
|
|
||||||
self.content = content
|
|
||||||
self.model = model
|
|
||||||
self.provider = provider
|
|
||||||
self.usage = usage or {}
|
|
||||||
self.latency_ms = latency_ms
|
|
||||||
self.metadata = metadata or {}
|
|
||||||
|
|
||||||
|
|
||||||
class ProviderBase(ABC):
|
class ProviderBase(ABC):
|
||||||
|
"""Abstract base class for LLM providers."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
@@ -28,14 +27,23 @@ class ProviderBase(ABC):
|
|||||||
temperature: float = 0.7,
|
temperature: float = 0.7,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
"""Initialize provider.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
api_key: API key for authentication.
|
||||||
|
model: Model identifier to use.
|
||||||
|
temperature: Sampling temperature (0.0-1.0).
|
||||||
|
**kwargs: Additional provider-specific options.
|
||||||
|
"""
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
self.model = model
|
self.model = model
|
||||||
self.temperature = temperature
|
self.temperature = temperature
|
||||||
self.extra_kwargs = kwargs
|
self.kwargs = kwargs
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
|
"""Provider name identifier."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@@ -46,6 +54,17 @@ class ProviderBase(ABC):
|
|||||||
max_tokens: Optional[int] = None,
|
max_tokens: Optional[int] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> ProviderResponse:
|
) -> ProviderResponse:
|
||||||
|
"""Send a completion request.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt: User prompt to send.
|
||||||
|
system_prompt: Optional system instructions.
|
||||||
|
max_tokens: Maximum tokens in response.
|
||||||
|
**kwargs: Additional provider-specific parameters.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ProviderResponse with the generated content.
|
||||||
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@@ -56,7 +75,32 @@ class ProviderBase(ABC):
|
|||||||
max_tokens: Optional[int] = None,
|
max_tokens: Optional[int] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> AsyncIterator[str]:
|
) -> AsyncIterator[str]:
|
||||||
|
"""Stream completions incrementally.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt: User prompt to send.
|
||||||
|
system_prompt: Optional system instructions.
|
||||||
|
max_tokens: Maximum tokens in response.
|
||||||
|
**kwargs: Additional provider-specific parameters.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
Chunks of generated content.
|
||||||
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
def validate_api_key(self) -> bool:
|
def validate_api_key(self) -> bool:
|
||||||
return True
|
"""Validate that the API key is configured correctly."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def list_models(self) -> List[str]:
|
||||||
|
"""List available models for this provider."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _get_system_prompt(self, prompt: str) -> Optional[str]:
|
||||||
|
"""Extract system prompt from prompt if using special syntax."""
|
||||||
|
if "---" in prompt:
|
||||||
|
parts = prompt.split("---", 1)
|
||||||
|
return parts[0].strip()
|
||||||
|
return None
|
||||||
|
|||||||
@@ -1,4 +1,8 @@
|
|||||||
from typing import Optional
|
"""Provider factory for instantiating LLM providers."""
|
||||||
|
|
||||||
|
from typing import Dict, Optional
|
||||||
|
|
||||||
|
from .base import ProviderBase
|
||||||
from .openai import OpenAIProvider
|
from .openai import OpenAIProvider
|
||||||
from .anthropic import AnthropicProvider
|
from .anthropic import AnthropicProvider
|
||||||
from .ollama import OllamaProvider
|
from .ollama import OllamaProvider
|
||||||
@@ -6,21 +10,67 @@ from ..core.exceptions import ProviderError
|
|||||||
|
|
||||||
|
|
||||||
class ProviderFactory:
|
class ProviderFactory:
|
||||||
@staticmethod
|
"""Factory for creating LLM provider instances."""
|
||||||
|
|
||||||
|
_providers: Dict[str, type] = {
|
||||||
|
"openai": OpenAIProvider,
|
||||||
|
"anthropic": AnthropicProvider,
|
||||||
|
"ollama": OllamaProvider,
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def register(cls, name: str, provider_class: type) -> None:
|
||||||
|
"""Register a new provider.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Provider identifier.
|
||||||
|
provider_class: Provider class to register.
|
||||||
|
"""
|
||||||
|
if not issubclass(provider_class, ProviderBase):
|
||||||
|
raise TypeError("Provider must be a subclass of ProviderBase")
|
||||||
|
cls._providers[name.lower()] = provider_class
|
||||||
|
|
||||||
|
@classmethod
|
||||||
def create(
|
def create(
|
||||||
|
cls,
|
||||||
provider_name: str,
|
provider_name: str,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
model: Optional[str] = None,
|
model: Optional[str] = None,
|
||||||
temperature: float = 0.7,
|
temperature: float = 0.7,
|
||||||
api_key: Optional[str] = None,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
) -> ProviderBase:
|
||||||
provider_name = provider_name.lower()
|
"""Create a provider instance.
|
||||||
|
|
||||||
if provider_name in ("openai", "gpt-4", "gpt-3.5"):
|
Args:
|
||||||
return OpenAIProvider(api_key=api_key, model=model or "gpt-4", temperature=temperature)
|
provider_name: Name of the provider to create.
|
||||||
elif provider_name in ("anthropic", "claude"):
|
api_key: API key for the provider.
|
||||||
return AnthropicProvider(api_key=api_key, model=model or "claude-3-sonnet-20240229", temperature=temperature)
|
model: Model to use (uses default if not specified).
|
||||||
elif provider_name in ("ollama", "local"):
|
temperature: Sampling temperature.
|
||||||
return OllamaProvider(model=model or "llama2", temperature=temperature, **kwargs)
|
**kwargs: Additional provider-specific options.
|
||||||
else:
|
|
||||||
raise ProviderError(f"Unknown provider: {provider_name}")
|
Returns:
|
||||||
|
Provider instance.
|
||||||
|
"""
|
||||||
|
provider_class = cls._providers.get(provider_name.lower())
|
||||||
|
if provider_class is None:
|
||||||
|
available = ", ".join(cls._providers.keys())
|
||||||
|
raise ProviderError(
|
||||||
|
f"Unknown provider: {provider_name}. Available: {available}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return provider_class(
|
||||||
|
api_key=api_key,
|
||||||
|
model=model or getattr(provider_class, "_default_model", "gpt-4"),
|
||||||
|
temperature=temperature,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def list_providers(cls) -> list[str]:
|
||||||
|
"""List available provider names."""
|
||||||
|
return list(cls._providers.keys())
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_provider_class(cls, name: str) -> Optional[type]:
|
||||||
|
"""Get provider class by name."""
|
||||||
|
return cls._providers.get(name.lower())
|
||||||
|
|||||||
@@ -1,6 +1,9 @@
|
|||||||
import asyncio
|
"""Ollama provider implementation for local models."""
|
||||||
|
|
||||||
|
import json
|
||||||
import time
|
import time
|
||||||
from typing import Any, AsyncIterator, Dict, Optional
|
from typing import Any, AsyncIterator, Dict, Optional
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
from .base import ProviderBase, ProviderResponse
|
from .base import ProviderBase, ProviderResponse
|
||||||
@@ -8,20 +11,35 @@ from ..core.exceptions import ProviderError
|
|||||||
|
|
||||||
|
|
||||||
class OllamaProvider(ProviderBase):
|
class OllamaProvider(ProviderBase):
|
||||||
|
"""Ollama local model provider."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
model: str = "llama2",
|
model: str = "llama2",
|
||||||
temperature: float = 0.7,
|
temperature: float = 0.7,
|
||||||
base_url: str = "http://localhost:11434",
|
base_url: str = "http://localhost:11434",
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__(None, model, temperature, **kwargs)
|
"""Initialize Ollama provider."""
|
||||||
|
super().__init__(api_key, model, temperature, **kwargs)
|
||||||
self.base_url = base_url.rstrip('/')
|
self.base_url = base_url.rstrip('/')
|
||||||
|
self._client: Optional[httpx.AsyncClient] = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
return "ollama"
|
return "ollama"
|
||||||
|
|
||||||
|
def _get_client(self) -> httpx.AsyncClient:
|
||||||
|
"""Get or create HTTP client."""
|
||||||
|
if self._client is None:
|
||||||
|
self._client = httpx.AsyncClient(timeout=120.0)
|
||||||
|
return self._client
|
||||||
|
|
||||||
|
def _get_api_url(self, endpoint: str) -> str:
|
||||||
|
"""Get full URL for an endpoint."""
|
||||||
|
return f"{self.base_url}/{endpoint.lstrip('/')}"
|
||||||
|
|
||||||
async def complete(
|
async def complete(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
@@ -29,78 +47,133 @@ class OllamaProvider(ProviderBase):
|
|||||||
max_tokens: Optional[int] = None,
|
max_tokens: Optional[int] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> ProviderResponse:
|
) -> ProviderResponse:
|
||||||
|
"""Send completion request to Ollama."""
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with httpx.AsyncClient() as client:
|
client = self._get_client()
|
||||||
payload = {
|
|
||||||
|
messages = []
|
||||||
|
if system_prompt:
|
||||||
|
messages.append({"role": "system", "content": system_prompt})
|
||||||
|
messages.append({"role": "user", "content": prompt})
|
||||||
|
|
||||||
|
payload: Dict[str, Any] = {
|
||||||
"model": self.model,
|
"model": self.model,
|
||||||
"prompt": prompt,
|
"messages": messages,
|
||||||
"stream": False,
|
"stream": False,
|
||||||
"options": {
|
"options": {
|
||||||
"temperature": self.temperature,
|
"temperature": self.temperature,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
}
|
|
||||||
if max_tokens:
|
if max_tokens:
|
||||||
payload["options"]["num_predict"] = max_tokens
|
payload["options"]["num_predict"] = max_tokens
|
||||||
|
|
||||||
response = await client.post(
|
response = await client.post(
|
||||||
f"{self.base_url}/api/generate",
|
self._get_api_url("/api/chat"),
|
||||||
json=payload,
|
json=payload,
|
||||||
timeout=120.0
|
|
||||||
)
|
)
|
||||||
|
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
data = response.json()
|
data = response.json()
|
||||||
|
|
||||||
latency_ms = (time.time() - start_time) * 1000
|
latency_ms = (time.time() - start_time) * 1000
|
||||||
|
|
||||||
|
content = ""
|
||||||
|
for msg in data.get("message", {}).get("content", ""):
|
||||||
|
if isinstance(msg, str):
|
||||||
|
content += msg
|
||||||
|
elif isinstance(msg, dict):
|
||||||
|
content += msg.get("content", "")
|
||||||
|
|
||||||
return ProviderResponse(
|
return ProviderResponse(
|
||||||
content=data.get("response", ""),
|
content=content,
|
||||||
model=self.model,
|
model=self.model,
|
||||||
provider=self.name,
|
provider=self.name,
|
||||||
|
usage={
|
||||||
|
"prompt_tokens": data.get("prompt_eval_count", 0),
|
||||||
|
"completion_tokens": data.get("eval_count", 0),
|
||||||
|
"total_tokens": data.get("prompt_eval_count", 0) + data.get("eval_count", 0),
|
||||||
|
},
|
||||||
latency_ms=latency_ms,
|
latency_ms=latency_ms,
|
||||||
|
metadata={
|
||||||
|
"done": data.get("done", False),
|
||||||
|
},
|
||||||
)
|
)
|
||||||
except httpx.HTTPStatusError as e:
|
except httpx.HTTPError as e:
|
||||||
raise ProviderError(f"Ollama HTTP error: {e}")
|
|
||||||
except httpx.RequestError as e:
|
|
||||||
raise ProviderError(f"Ollama connection error: {e}")
|
raise ProviderError(f"Ollama connection error: {e}")
|
||||||
|
|
||||||
async def stream_complete(
|
async def stream_complete( # type: ignore[override]
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
system_prompt: Optional[str] = None,
|
system_prompt: Optional[str] = None,
|
||||||
max_tokens: Optional[int] = None,
|
max_tokens: Optional[int] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> AsyncIterator[str]:
|
) -> AsyncIterator[str]:
|
||||||
|
"""Stream completion from Ollama."""
|
||||||
try:
|
try:
|
||||||
async with httpx.AsyncClient() as client:
|
client = self._get_client()
|
||||||
payload = {
|
|
||||||
|
messages = []
|
||||||
|
if system_prompt:
|
||||||
|
messages.append({"role": "system", "content": system_prompt})
|
||||||
|
messages.append({"role": "user", "content": prompt})
|
||||||
|
|
||||||
|
payload: Dict[str, Any] = {
|
||||||
"model": self.model,
|
"model": self.model,
|
||||||
"prompt": prompt,
|
"messages": messages,
|
||||||
"stream": True,
|
"stream": True,
|
||||||
"options": {
|
"options": {
|
||||||
"temperature": self.temperature,
|
"temperature": self.temperature,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
}
|
|
||||||
if max_tokens:
|
if max_tokens:
|
||||||
payload["options"]["num_predict"] = max_tokens
|
payload["options"]["num_predict"] = max_tokens
|
||||||
|
|
||||||
async with client.stream(
|
async with client.stream(
|
||||||
"POST",
|
"POST",
|
||||||
f"{self.base_url}/api/generate",
|
self._get_api_url("/api/chat"),
|
||||||
json=payload,
|
json=payload,
|
||||||
timeout=120.0
|
|
||||||
) as response:
|
) as response:
|
||||||
async for line in response.aiter_lines():
|
async for line in response.aiter_lines():
|
||||||
import json
|
if line:
|
||||||
data = json.loads(line)
|
data = json.loads(line)
|
||||||
if "response" in data:
|
if "message" in data:
|
||||||
yield data["response"]
|
content = data["message"].get("content", "")
|
||||||
except httpx.HTTPStatusError as e:
|
if content:
|
||||||
raise ProviderError(f"Ollama HTTP error: {e}")
|
yield content
|
||||||
except httpx.RequestError as e:
|
except httpx.HTTPError as e:
|
||||||
raise ProviderError(f"Ollama connection error: {e}")
|
raise ProviderError(f"Ollama connection error: {e}")
|
||||||
|
|
||||||
def validate_api_key(self) -> bool:
|
async def pull_model(self, model: Optional[str] = None) -> bool:
|
||||||
|
"""Pull a model from Ollama registry."""
|
||||||
|
try:
|
||||||
|
client = self._get_client()
|
||||||
|
target_model = model or self.model
|
||||||
|
|
||||||
|
async with client.stream(
|
||||||
|
"POST",
|
||||||
|
self._get_api_url("/api/pull"),
|
||||||
|
json={"name": target_model, "stream": False},
|
||||||
|
) as response:
|
||||||
|
response.raise_for_status()
|
||||||
return True
|
return True
|
||||||
|
except httpx.HTTPError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def validate_api_key(self) -> bool:
|
||||||
|
"""Ollama doesn't use API keys, always returns True."""
|
||||||
|
return True
|
||||||
|
|
||||||
|
def list_models(self) -> list[str]:
|
||||||
|
"""List available Ollama models."""
|
||||||
|
return [
|
||||||
|
"llama2",
|
||||||
|
"llama2-uncensored",
|
||||||
|
"mistral",
|
||||||
|
"mixtral",
|
||||||
|
"codellama",
|
||||||
|
"deepseek-coder",
|
||||||
|
"neural-chat",
|
||||||
|
]
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import asyncio
|
"""OpenAI provider implementation."""
|
||||||
|
|
||||||
import time
|
import time
|
||||||
from typing import Any, AsyncIterator, Dict, Optional
|
from typing import AsyncIterator, Optional
|
||||||
|
|
||||||
from openai import AsyncOpenAI, APIError, RateLimitError, APIConnectionError
|
from openai import AsyncOpenAI, APIError, RateLimitError, APIConnectionError
|
||||||
|
|
||||||
@@ -9,6 +10,8 @@ from ..core.exceptions import ProviderError
|
|||||||
|
|
||||||
|
|
||||||
class OpenAIProvider(ProviderBase):
|
class OpenAIProvider(ProviderBase):
|
||||||
|
"""OpenAI GPT models provider."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
@@ -17,6 +20,7 @@ class OpenAIProvider(ProviderBase):
|
|||||||
base_url: Optional[str] = None,
|
base_url: Optional[str] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
"""Initialize OpenAI provider."""
|
||||||
super().__init__(api_key, model, temperature, **kwargs)
|
super().__init__(api_key, model, temperature, **kwargs)
|
||||||
self.base_url = base_url
|
self.base_url = base_url
|
||||||
self._client: Optional[AsyncOpenAI] = None
|
self._client: Optional[AsyncOpenAI] = None
|
||||||
@@ -26,6 +30,7 @@ class OpenAIProvider(ProviderBase):
|
|||||||
return "openai"
|
return "openai"
|
||||||
|
|
||||||
def _get_client(self) -> AsyncOpenAI:
|
def _get_client(self) -> AsyncOpenAI:
|
||||||
|
"""Get or create OpenAI client."""
|
||||||
if self._client is None:
|
if self._client is None:
|
||||||
api_key = self.api_key or self._get_api_key_from_env()
|
api_key = self.api_key or self._get_api_key_from_env()
|
||||||
if not api_key:
|
if not api_key:
|
||||||
@@ -33,7 +38,10 @@ class OpenAIProvider(ProviderBase):
|
|||||||
"OpenAI API key not configured. "
|
"OpenAI API key not configured. "
|
||||||
"Set OPENAI_API_KEY env var or pass api_key parameter."
|
"Set OPENAI_API_KEY env var or pass api_key parameter."
|
||||||
)
|
)
|
||||||
self._client = AsyncOpenAI(api_key=api_key, base_url=self.base_url)
|
self._client = AsyncOpenAI(
|
||||||
|
api_key=api_key,
|
||||||
|
base_url=self.base_url,
|
||||||
|
)
|
||||||
return self._client
|
return self._client
|
||||||
|
|
||||||
def _get_api_key_from_env(self) -> Optional[str]:
|
def _get_api_key_from_env(self) -> Optional[str]:
|
||||||
@@ -47,6 +55,7 @@ class OpenAIProvider(ProviderBase):
|
|||||||
max_tokens: Optional[int] = None,
|
max_tokens: Optional[int] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> ProviderResponse:
|
) -> ProviderResponse:
|
||||||
|
"""Send completion request to OpenAI."""
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -57,9 +66,9 @@ class OpenAIProvider(ProviderBase):
|
|||||||
messages.append({"role": "system", "content": system_prompt})
|
messages.append({"role": "system", "content": system_prompt})
|
||||||
messages.append({"role": "user", "content": prompt})
|
messages.append({"role": "user", "content": prompt})
|
||||||
|
|
||||||
response = await client.chat.completions.create(
|
response = await client.chat.completions.create( # type: ignore[arg-type]
|
||||||
model=self.model,
|
model=self.model,
|
||||||
messages=messages,
|
messages=messages, # type: ignore[arg-type]
|
||||||
temperature=self.temperature,
|
temperature=self.temperature,
|
||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
@@ -72,12 +81,14 @@ class OpenAIProvider(ProviderBase):
|
|||||||
model=self.model,
|
model=self.model,
|
||||||
provider=self.name,
|
provider=self.name,
|
||||||
usage={
|
usage={
|
||||||
"prompt_tokens": response.usage.prompt_tokens,
|
"prompt_tokens": response.usage.prompt_tokens, # type: ignore[union-attr]
|
||||||
"completion_tokens": response.usage.completion_tokens,
|
"completion_tokens": response.usage.completion_tokens, # type: ignore[union-attr]
|
||||||
"total_tokens": response.usage.total_tokens,
|
"total_tokens": response.usage.total_tokens, # type: ignore[union-attr]
|
||||||
},
|
},
|
||||||
latency_ms=latency_ms,
|
latency_ms=latency_ms,
|
||||||
metadata={"finish_reason": response.choices[0].finish_reason},
|
metadata={
|
||||||
|
"finish_reason": response.choices[0].finish_reason,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
except APIError as e:
|
except APIError as e:
|
||||||
raise ProviderError(f"OpenAI API error: {e}")
|
raise ProviderError(f"OpenAI API error: {e}")
|
||||||
@@ -86,13 +97,14 @@ class OpenAIProvider(ProviderBase):
|
|||||||
except APIConnectionError as e:
|
except APIConnectionError as e:
|
||||||
raise ProviderError(f"OpenAI connection error: {e}")
|
raise ProviderError(f"OpenAI connection error: {e}")
|
||||||
|
|
||||||
async def stream_complete(
|
async def stream_complete( # type: ignore[override]
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
system_prompt: Optional[str] = None,
|
system_prompt: Optional[str] = None,
|
||||||
max_tokens: Optional[int] = None,
|
max_tokens: Optional[int] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> AsyncIterator[str]:
|
) -> AsyncIterator[str]:
|
||||||
|
"""Stream completion from OpenAI."""
|
||||||
try:
|
try:
|
||||||
client = self._get_client()
|
client = self._get_client()
|
||||||
|
|
||||||
@@ -101,28 +113,39 @@ class OpenAIProvider(ProviderBase):
|
|||||||
messages.append({"role": "system", "content": system_prompt})
|
messages.append({"role": "system", "content": system_prompt})
|
||||||
messages.append({"role": "user", "content": prompt})
|
messages.append({"role": "user", "content": prompt})
|
||||||
|
|
||||||
stream = await client.chat.completions.create(
|
stream = await client.chat.completions.create( # type: ignore[arg-type]
|
||||||
model=self.model,
|
model=self.model,
|
||||||
messages=messages,
|
messages=messages, # type: ignore[arg-type]
|
||||||
temperature=self.temperature,
|
temperature=self.temperature,
|
||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
stream=True,
|
stream=True,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
async for chunk in stream:
|
async for chunk in stream: # type: ignore[union-attr]
|
||||||
if chunk.choices[0].delta.content:
|
if chunk.choices[0].delta.content:
|
||||||
yield chunk.choices[0].delta.content
|
yield chunk.choices[0].delta.content
|
||||||
except APIError as e:
|
except APIError as e:
|
||||||
raise ProviderError(f"OpenAI API error: {e}")
|
raise ProviderError(f"OpenAI API error: {e}")
|
||||||
|
|
||||||
def validate_api_key(self) -> bool:
|
def validate_api_key(self) -> bool:
|
||||||
|
"""Validate OpenAI API key."""
|
||||||
try:
|
try:
|
||||||
import os
|
import os
|
||||||
api_key = self.api_key or os.environ.get("OPENAI_API_KEY")
|
api_key = self.api_key or os.environ.get("OPENAI_API_KEY")
|
||||||
if not api_key:
|
if not api_key:
|
||||||
return False
|
return False
|
||||||
client = AsyncOpenAI(api_key=api_key)
|
_ = AsyncOpenAI(api_key=api_key)
|
||||||
return True
|
return True
|
||||||
except Exception:
|
except Exception:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def list_models(self) -> list[str]:
|
||||||
|
"""List available OpenAI models."""
|
||||||
|
return [
|
||||||
|
"gpt-4",
|
||||||
|
"gpt-4-turbo",
|
||||||
|
"gpt-4o",
|
||||||
|
"gpt-3.5-turbo",
|
||||||
|
"gpt-3.5-turbo-16k",
|
||||||
|
]
|
||||||
|
|||||||
@@ -1,5 +1,12 @@
|
|||||||
|
"""Prompt registry modules."""
|
||||||
|
|
||||||
from .local import LocalRegistry
|
from .local import LocalRegistry
|
||||||
from .remote import RemoteRegistry
|
from .remote import RemoteRegistry
|
||||||
from .models import RegistryEntry, SearchResult
|
from .models import RegistryEntry, RegistrySearchResult
|
||||||
|
|
||||||
__all__ = ["LocalRegistry", "RemoteRegistry", "RegistryEntry", "SearchResult"]
|
__all__ = [
|
||||||
|
"LocalRegistry",
|
||||||
|
"RemoteRegistry",
|
||||||
|
"RegistryEntry",
|
||||||
|
"RegistrySearchResult",
|
||||||
|
]
|
||||||
|
|||||||
@@ -1,67 +1,155 @@
|
|||||||
import os
|
"""Local prompt registry."""
|
||||||
from pathlib import Path
|
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
from .models import RegistryEntry, SearchResult
|
import json
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
from .models import RegistryEntry, RegistrySearchResult
|
||||||
from ..core.exceptions import RegistryError
|
from ..core.exceptions import RegistryError
|
||||||
|
|
||||||
|
|
||||||
class LocalRegistry:
|
class LocalRegistry:
|
||||||
|
"""Local prompt registry stored as JSON files."""
|
||||||
|
|
||||||
def __init__(self, registry_path: Optional[str] = None):
|
def __init__(self, registry_path: Optional[str] = None):
|
||||||
self.registry_path = Path(registry_path or os.path.expanduser("~/.promptforge/registry"))
|
"""Initialize local registry.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
registry_path: Path to registry directory. Defaults to ~/.promptforge/registry
|
||||||
|
"""
|
||||||
|
self.registry_path = Path(registry_path or self._default_path())
|
||||||
self.registry_path.mkdir(parents=True, exist_ok=True)
|
self.registry_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
self._index_file = self.registry_path / "index.json"
|
||||||
|
|
||||||
|
def _default_path(self) -> str:
|
||||||
|
import os
|
||||||
|
return os.path.expanduser("~/.promptforge/registry")
|
||||||
|
|
||||||
|
def _load_index(self) -> Dict[str, RegistryEntry]:
|
||||||
|
"""Load registry index."""
|
||||||
|
if not self._index_file.exists():
|
||||||
|
return {}
|
||||||
|
try:
|
||||||
|
with open(self._index_file, 'r') as f:
|
||||||
|
data = json.load(f)
|
||||||
|
return {
|
||||||
|
k: RegistryEntry(**v) for k, v in data.items()
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
raise RegistryError(f"Failed to load registry index: {e}")
|
||||||
|
|
||||||
|
def _save_index(self, index: Dict[str, RegistryEntry]) -> None:
|
||||||
|
"""Save registry index."""
|
||||||
|
try:
|
||||||
|
data = {k: v.model_dump() for k, v in index.items()}
|
||||||
|
with open(self._index_file, 'w') as f:
|
||||||
|
json.dump(data, f, indent=2)
|
||||||
|
except Exception as e:
|
||||||
|
raise RegistryError(f"Failed to save registry index: {e}")
|
||||||
|
|
||||||
def add(self, entry: RegistryEntry) -> None:
|
def add(self, entry: RegistryEntry) -> None:
|
||||||
try:
|
"""Add an entry to the registry.
|
||||||
entry.to_file(self.registry_path)
|
|
||||||
except Exception as e:
|
|
||||||
raise RegistryError(f"Failed to add entry to registry: {e}")
|
|
||||||
|
|
||||||
def list(self, tag: Optional[str] = None, limit: int = 20) -> List[RegistryEntry]:
|
Args:
|
||||||
entries = []
|
entry: Registry entry to add.
|
||||||
for filepath in self.registry_path.glob("*.yaml"):
|
"""
|
||||||
try:
|
index = self._load_index()
|
||||||
entry = RegistryEntry.from_file(filepath)
|
entry.id = str(entry.id or uuid.uuid4())
|
||||||
if tag is None or tag in entry.tags:
|
entry.added_at = entry.added_at or datetime.utcnow()
|
||||||
entries.append(entry)
|
entry.updated_at = datetime.utcnow()
|
||||||
except Exception:
|
index[entry.id] = entry
|
||||||
continue
|
self._save_index(index)
|
||||||
return entries[:limit]
|
|
||||||
|
def remove(self, entry_id: str) -> bool:
|
||||||
|
"""Remove an entry from the registry.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
entry_id: ID of entry to remove.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if entry was removed, False if not found.
|
||||||
|
"""
|
||||||
|
index = self._load_index()
|
||||||
|
if entry_id not in index:
|
||||||
|
return False
|
||||||
|
del index[entry_id]
|
||||||
|
self._save_index(index)
|
||||||
|
return True
|
||||||
|
|
||||||
def get(self, entry_id: str) -> Optional[RegistryEntry]:
|
def get(self, entry_id: str) -> Optional[RegistryEntry]:
|
||||||
filepath = self.registry_path / f"{entry_id}.yaml"
|
"""Get an entry by ID.
|
||||||
if filepath.exists():
|
|
||||||
return RegistryEntry.from_file(filepath)
|
|
||||||
return None
|
|
||||||
|
|
||||||
def search(self, query: str) -> List[SearchResult]:
|
Args:
|
||||||
|
entry_id: Entry ID.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Registry entry or None if not found.
|
||||||
|
"""
|
||||||
|
index = self._load_index()
|
||||||
|
return index.get(entry_id)
|
||||||
|
|
||||||
|
def list(
|
||||||
|
self,
|
||||||
|
tag: Optional[str] = None,
|
||||||
|
author: Optional[str] = None,
|
||||||
|
limit: int = 50,
|
||||||
|
) -> List[RegistryEntry]:
|
||||||
|
"""List entries in the registry.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tag: Filter by tag.
|
||||||
|
author: Filter by author.
|
||||||
|
limit: Maximum results to return.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of matching entries.
|
||||||
|
"""
|
||||||
|
index = self._load_index()
|
||||||
|
results = list(index.values())
|
||||||
|
|
||||||
|
if tag:
|
||||||
|
results = [e for e in results if tag in e.tags]
|
||||||
|
|
||||||
|
if author:
|
||||||
|
results = [e for e in results if e.author == author]
|
||||||
|
|
||||||
|
results.sort(key=lambda e: e.added_at or datetime.min, reverse=True)
|
||||||
|
return results[:limit]
|
||||||
|
|
||||||
|
def search(self, query: str) -> List[RegistrySearchResult]:
|
||||||
|
"""Search registry entries.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: Search query.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of matching entries with relevance scores.
|
||||||
|
"""
|
||||||
|
entries = self.list(limit=100)
|
||||||
results = []
|
results = []
|
||||||
query_lower = query.lower()
|
|
||||||
|
|
||||||
for entry in self.list():
|
query_lower = query.lower()
|
||||||
score = 0.0
|
for entry in entries:
|
||||||
|
score = 0
|
||||||
|
|
||||||
if query_lower in entry.name.lower():
|
if query_lower in entry.name.lower():
|
||||||
score += 2.0
|
score += 10
|
||||||
if entry.description and query_lower in entry.description.lower():
|
if entry.description and query_lower in entry.description.lower():
|
||||||
score += 1.0
|
score += 5
|
||||||
if any(query_lower in tag.lower() for tag in entry.tags):
|
if any(query_lower in tag for tag in entry.tags):
|
||||||
score += 0.5
|
score += 3
|
||||||
|
|
||||||
if score > 0:
|
if score > 0:
|
||||||
results.append(SearchResult(entry, score))
|
results.append(RegistrySearchResult(
|
||||||
|
entry=entry,
|
||||||
|
relevance_score=score,
|
||||||
|
))
|
||||||
|
|
||||||
return sorted(results, key=lambda r: r.relevance_score, reverse=True)
|
return sorted(results, key=lambda r: r.relevance_score, reverse=True)
|
||||||
|
|
||||||
def delete(self, entry_id: str) -> bool:
|
def count(self) -> int:
|
||||||
filepath = self.registry_path / f"{entry_id}.yaml"
|
"""Get total number of entries."""
|
||||||
if filepath.exists():
|
index = self._load_index()
|
||||||
filepath.unlink()
|
return len(index)
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def import_prompt(self, filepath: Path) -> RegistryEntry:
|
|
||||||
from ..core.prompt import Prompt
|
|
||||||
prompt = Prompt.load(filepath)
|
|
||||||
entry = RegistryEntry.from_prompt(prompt)
|
|
||||||
self.add(entry)
|
|
||||||
return entry
|
|
||||||
|
|||||||
@@ -1,55 +1,88 @@
|
|||||||
import uuid
|
"""Registry data models."""
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||||
from typing import Any, Dict, List, Optional
|
from uuid import uuid4
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
from ..core.prompt import Prompt
|
from ..core.prompt import Prompt
|
||||||
|
|
||||||
|
|
||||||
class RegistryEntry(BaseModel):
|
class RegistryEntry(BaseModel):
|
||||||
id: str = Field(default_factory=lambda: str(uuid.uuid4())[:8])
|
"""Entry in the prompt registry."""
|
||||||
|
|
||||||
|
id: Optional[str] = Field(default_factory=lambda: str(uuid4()))
|
||||||
name: str
|
name: str
|
||||||
description: Optional[str] = None
|
description: Optional[str] = None
|
||||||
content: str
|
content: str
|
||||||
version: str = "1.0.0"
|
|
||||||
author: Optional[str] = None
|
author: Optional[str] = None
|
||||||
provider: Optional[str] = None
|
version: str = "1.0.0"
|
||||||
tags: List[str] = Field(default_factory=list)
|
tags: List[str] = Field(default_factory=list)
|
||||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
provider: Optional[str] = None
|
||||||
updated_at: datetime = Field(default_factory=datetime.utcnow)
|
variables: List[Dict[str, Any]] = Field(default_factory=list)
|
||||||
|
validation_rules: List[Dict[str, Any]] = Field(default_factory=list)
|
||||||
downloads: int = 0
|
downloads: int = 0
|
||||||
|
likes: int = 0
|
||||||
rating: float = 0.0
|
rating: float = 0.0
|
||||||
|
is_local: bool = True
|
||||||
|
is_published: bool = False
|
||||||
|
added_at: Optional[datetime] = None
|
||||||
|
updated_at: Optional[datetime] = None
|
||||||
|
|
||||||
|
def to_prompt_content(self) -> str:
|
||||||
|
"""Convert entry to prompt YAML content."""
|
||||||
|
from ..core.prompt import Prompt, PromptVariable, ValidationRule
|
||||||
|
|
||||||
|
variables = [PromptVariable(**v) for v in self.variables]
|
||||||
|
validation_rules = [ValidationRule(**r) for r in self.validation_rules]
|
||||||
|
|
||||||
|
prompt = Prompt(
|
||||||
|
id=str(self.id) if self.id else "",
|
||||||
|
name=self.name,
|
||||||
|
description=self.description,
|
||||||
|
content=self.content,
|
||||||
|
variables=variables,
|
||||||
|
validation_rules=validation_rules,
|
||||||
|
provider=self.provider,
|
||||||
|
tags=self.tags,
|
||||||
|
version=self.version,
|
||||||
|
)
|
||||||
|
return prompt.to_yaml()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_prompt(cls, prompt: Prompt, author: Optional[str] = None) -> "RegistryEntry":
|
def from_prompt(cls, prompt: "Prompt", author: Optional[str] = None) -> "RegistryEntry":
|
||||||
|
"""Create registry entry from a Prompt."""
|
||||||
return cls(
|
return cls(
|
||||||
|
id=str(prompt.id),
|
||||||
name=prompt.name,
|
name=prompt.name,
|
||||||
description=prompt.description,
|
description=prompt.description,
|
||||||
content=prompt.content,
|
content=prompt.content,
|
||||||
version=prompt.version,
|
|
||||||
author=author,
|
author=author,
|
||||||
provider=prompt.provider,
|
version=prompt.version,
|
||||||
tags=prompt.tags,
|
tags=prompt.tags,
|
||||||
|
provider=prompt.provider,
|
||||||
|
variables=[v.model_dump() for v in prompt.variables],
|
||||||
|
validation_rules=[r.model_dump() for r in prompt.validation_rules],
|
||||||
|
is_local=True,
|
||||||
|
added_at=datetime.utcnow(),
|
||||||
)
|
)
|
||||||
|
|
||||||
def to_file(self, registry_dir: Path) -> Path:
|
|
||||||
registry_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
filepath = registry_dir / f"{self.id}.yaml"
|
|
||||||
with open(filepath, 'w') as f:
|
|
||||||
f.write(self.model_dump_json(indent=2))
|
|
||||||
return filepath
|
|
||||||
|
|
||||||
@classmethod
|
class RegistrySearchResult(BaseModel):
|
||||||
def from_file(cls, filepath: Path) -> "RegistryEntry":
|
"""Search result from registry."""
|
||||||
import json
|
|
||||||
with open(filepath, 'r') as f:
|
entry: RegistryEntry
|
||||||
data = json.load(f)
|
relevance_score: float = 0.0
|
||||||
return cls(**data)
|
highlights: Dict[str, List[str]] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
class SearchResult:
|
class RegistryStats(BaseModel):
|
||||||
def __init__(self, entry: RegistryEntry, relevance_score: float = 1.0):
|
"""Registry statistics."""
|
||||||
self.entry = entry
|
|
||||||
self.relevance_score = relevance_score
|
total_entries: int = 0
|
||||||
|
local_entries: int = 0
|
||||||
|
published_entries: int = 0
|
||||||
|
popular_tags: List[Dict[str, Any]] = Field(default_factory=list)
|
||||||
|
top_authors: List[Dict[str, Any]] = Field(default_factory=list)
|
||||||
|
|||||||
@@ -1,53 +1,149 @@
|
|||||||
from typing import List, Optional
|
"""Remote registry via HTTP."""
|
||||||
|
|
||||||
from .local import LocalRegistry
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||||
from .models import RegistryEntry, SearchResult
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from .models import RegistryEntry, RegistrySearchResult
|
||||||
from ..core.exceptions import RegistryError
|
from ..core.exceptions import RegistryError
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .local import LocalRegistry
|
||||||
|
|
||||||
|
|
||||||
class RemoteRegistry:
|
class RemoteRegistry:
|
||||||
def __init__(self, base_url: str = "https://registry.promptforge.io"):
|
"""Remote prompt registry accessed via HTTP API."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
base_url: str = "https://registry.promptforge.io",
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
):
|
||||||
|
"""Initialize remote registry.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
base_url: Base URL of the registry API.
|
||||||
|
api_key: API key for authentication.
|
||||||
|
"""
|
||||||
self.base_url = base_url.rstrip('/')
|
self.base_url = base_url.rstrip('/')
|
||||||
|
self.api_key = api_key
|
||||||
|
self._session = requests.Session()
|
||||||
|
if api_key:
|
||||||
|
self._session.headers.update({"Authorization": f"Bearer {api_key}"})
|
||||||
|
|
||||||
|
def _request(
|
||||||
|
self,
|
||||||
|
method: str,
|
||||||
|
endpoint: str,
|
||||||
|
data: Optional[Dict] = None,
|
||||||
|
) -> Any:
|
||||||
|
"""Make HTTP request to registry.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
method: HTTP method.
|
||||||
|
endpoint: API endpoint.
|
||||||
|
data: Request data.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Response JSON.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RegistryError: If request fails.
|
||||||
|
"""
|
||||||
|
url = f"{self.base_url}/api/v1{endpoint}"
|
||||||
|
try:
|
||||||
|
response = self._session.request(method, url, json=data)
|
||||||
|
response.raise_for_status()
|
||||||
|
return response.json()
|
||||||
|
except requests.HTTPError as e:
|
||||||
|
raise RegistryError(f"Registry API error: {e.response.text}")
|
||||||
|
except requests.RequestException as e:
|
||||||
|
raise RegistryError(f"Registry connection error: {e}")
|
||||||
|
|
||||||
|
def search(self, query: str, limit: int = 20) -> List[RegistrySearchResult]:
|
||||||
|
"""Search remote registry.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: Search query.
|
||||||
|
limit: Maximum results.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of matching entries.
|
||||||
|
"""
|
||||||
|
data = self._request("GET", f"/search?q={query}&limit={limit}")
|
||||||
|
results = []
|
||||||
|
for item in data.get("results", []):
|
||||||
|
entry = RegistryEntry(**item)
|
||||||
|
results.append(RegistrySearchResult(
|
||||||
|
entry=entry,
|
||||||
|
relevance_score=item.get("score", 0),
|
||||||
|
))
|
||||||
|
return results
|
||||||
|
|
||||||
|
def get(self, entry_id: str) -> Optional[RegistryEntry]:
|
||||||
|
"""Get entry by ID.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
entry_id: Entry ID.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Registry entry or None.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
data = self._request("GET", f"/entries/{entry_id}")
|
||||||
|
return RegistryEntry(**data)
|
||||||
|
except RegistryError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def pull(self, entry_id: str, local_registry: "LocalRegistry") -> bool:
|
||||||
|
"""Pull entry from remote to local registry.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
entry_id: Entry ID to pull.
|
||||||
|
local_registry: Local registry to save to.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if successful.
|
||||||
|
"""
|
||||||
|
entry = self.get(entry_id)
|
||||||
|
if entry is None:
|
||||||
|
return False
|
||||||
|
entry.is_local = True
|
||||||
|
local_registry.add(entry)
|
||||||
|
return True
|
||||||
|
|
||||||
def publish(self, entry: RegistryEntry) -> RegistryEntry:
|
def publish(self, entry: RegistryEntry) -> RegistryEntry:
|
||||||
try:
|
"""Publish entry to remote registry.
|
||||||
import requests
|
|
||||||
response = requests.post(
|
|
||||||
f"{self.base_url}/api/entries",
|
|
||||||
json=entry.model_dump(),
|
|
||||||
timeout=30
|
|
||||||
)
|
|
||||||
response.raise_for_status()
|
|
||||||
return RegistryEntry(**response.json())
|
|
||||||
except Exception as e:
|
|
||||||
raise RegistryError(f"Failed to publish to remote registry: {e}")
|
|
||||||
|
|
||||||
def pull(self, entry_id: str, local: LocalRegistry) -> bool:
|
Args:
|
||||||
|
entry: Entry to publish.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Published entry with server-assigned ID.
|
||||||
|
"""
|
||||||
|
data = self._request("POST", "/entries", entry.model_dump())
|
||||||
|
return RegistryEntry(**data)
|
||||||
|
|
||||||
|
def list_popular(self, limit: int = 10) -> List[RegistryEntry]:
|
||||||
|
"""List popular entries.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
limit: Maximum results.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of popular entries.
|
||||||
|
"""
|
||||||
|
data = self._request("GET", f"/popular?limit={limit}")
|
||||||
|
return [RegistryEntry(**item) for item in data.get("entries", [])]
|
||||||
|
|
||||||
|
def validate_connection(self) -> bool:
|
||||||
|
"""Validate connection to remote registry.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if connection successful.
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
import requests
|
self._request("GET", "/health")
|
||||||
response = requests.get(
|
|
||||||
f"{self.base_url}/api/entries/{entry_id}",
|
|
||||||
timeout=30
|
|
||||||
)
|
|
||||||
if response.status_code == 404:
|
|
||||||
return False
|
|
||||||
response.raise_for_status()
|
|
||||||
entry = RegistryEntry(**response.json())
|
|
||||||
local.add(entry)
|
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except RegistryError:
|
||||||
raise RegistryError(f"Failed to pull from remote registry: {e}")
|
return False
|
||||||
|
|
||||||
def search(self, query: str) -> List[SearchResult]:
|
|
||||||
try:
|
|
||||||
import requests
|
|
||||||
response = requests.get(
|
|
||||||
f"{self.base_url}/api/search",
|
|
||||||
params={"q": query},
|
|
||||||
timeout=30
|
|
||||||
)
|
|
||||||
response.raise_for_status()
|
|
||||||
data = response.json()
|
|
||||||
return [SearchResult(RegistryEntry(**e), 1.0) for e in data.get("results", [])]
|
|
||||||
except Exception:
|
|
||||||
return []
|
|
||||||
|
|||||||
@@ -1,5 +1,19 @@
|
|||||||
from .ab_test import ABTest, ABTestConfig
|
"""A/B testing and validation modules."""
|
||||||
from .metrics import TestMetrics, MetricsCollector
|
|
||||||
from .results import TestResult, ComparisonResult
|
|
||||||
|
|
||||||
__all__ = ["ABTest", "ABTestConfig", "TestMetrics", "MetricsCollector", "TestResult", "ComparisonResult"]
|
from .ab_test import ABTest, ABTestConfig, ABTestResult
|
||||||
|
from .validator import Validator, JSONSchemaValidator, RegexValidator, CompositeValidator
|
||||||
|
from .metrics import MetricsCollector
|
||||||
|
from .results import TestSessionResults, ResultFormatter
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"ABTest",
|
||||||
|
"ABTestConfig",
|
||||||
|
"ABTestResult",
|
||||||
|
"Validator",
|
||||||
|
"JSONSchemaValidator",
|
||||||
|
"RegexValidator",
|
||||||
|
"CompositeValidator",
|
||||||
|
"MetricsCollector",
|
||||||
|
"TestSessionResults",
|
||||||
|
"ResultFormatter",
|
||||||
|
]
|
||||||
|
|||||||
@@ -1,79 +1,200 @@
|
|||||||
import asyncio
|
"""A/B testing framework for comparing prompt variations."""
|
||||||
import uuid
|
|
||||||
from dataclasses import dataclass, field
|
import time
|
||||||
from typing import AsyncIterator, Dict, List, Optional
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from .metrics import TestMetrics, MetricsCollector
|
|
||||||
from .results import TestResult, ComparisonResult
|
|
||||||
from ..core.prompt import Prompt
|
from ..core.prompt import Prompt
|
||||||
from ..providers.base import ProviderBase, ProviderResponse
|
from ..providers import ProviderBase, ProviderResponse
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ABTestConfig:
|
class ABTestConfig:
|
||||||
|
"""Configuration for A/B test."""
|
||||||
|
|
||||||
iterations: int = 3
|
iterations: int = 3
|
||||||
|
provider: Optional[str] = None
|
||||||
|
max_tokens: Optional[int] = None
|
||||||
|
temperature: float = 0.7
|
||||||
parallel: bool = False
|
parallel: bool = False
|
||||||
|
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ABTestResult:
|
||||||
|
"""Result of a single test run."""
|
||||||
|
|
||||||
|
prompt: Prompt
|
||||||
|
response: ProviderResponse
|
||||||
|
variables: Dict[str, Any]
|
||||||
|
iteration: int
|
||||||
|
passed_validation: bool = False
|
||||||
|
validation_errors: List[str] = field(default_factory=list)
|
||||||
|
latency_ms: float = 0.0
|
||||||
|
timestamp: datetime = field(default_factory=datetime.utcnow)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ABTestSummary:
|
||||||
|
"""Summary of A/B test results."""
|
||||||
|
|
||||||
|
prompt_name: str
|
||||||
|
config: ABTestConfig
|
||||||
|
total_runs: int
|
||||||
|
successful_runs: int
|
||||||
|
failed_runs: int
|
||||||
|
avg_latency_ms: float
|
||||||
|
avg_tokens: float
|
||||||
|
avg_cost: float
|
||||||
|
results: List[ABTestResult]
|
||||||
|
timestamp: datetime = field(default_factory=datetime.utcnow)
|
||||||
|
|
||||||
|
|
||||||
class ABTest:
|
class ABTest:
|
||||||
def __init__(self, provider: ProviderBase, config: ABTestConfig):
|
"""A/B test runner for comparing prompt variations."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
provider: ProviderBase,
|
||||||
|
config: Optional[ABTestConfig] = None,
|
||||||
|
):
|
||||||
|
"""Initialize A/B test runner.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider: LLM provider to use.
|
||||||
|
config: Test configuration.
|
||||||
|
"""
|
||||||
self.provider = provider
|
self.provider = provider
|
||||||
self.config = config
|
self.config = config or ABTestConfig()
|
||||||
self.metrics_collector = MetricsCollector()
|
|
||||||
|
|
||||||
async def run_single(self, prompt: Prompt, variables: Dict[str, str]) -> TestResult:
|
async def run(
|
||||||
test_id = str(uuid.uuid4())[:8]
|
self,
|
||||||
|
prompt: Prompt,
|
||||||
|
variables: Dict[str, Any],
|
||||||
|
) -> ABTestSummary:
|
||||||
|
"""Run A/B test on a prompt.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt: Prompt to test.
|
||||||
|
variables: Variables to substitute.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ABTestSummary with all test results.
|
||||||
|
"""
|
||||||
|
results: List[ABTestResult] = []
|
||||||
|
latencies = []
|
||||||
|
total_tokens = []
|
||||||
|
|
||||||
|
for i in range(self.config.iterations):
|
||||||
try:
|
try:
|
||||||
response = await self.provider.complete(
|
result = await self._run_single(prompt, variables, i + 1)
|
||||||
prompt.content.format(**variables) if variables else prompt.content
|
results.append(result)
|
||||||
)
|
latencies.append(result.latency_ms)
|
||||||
|
total_tokens.append(result.response.usage.get("total_tokens", 0))
|
||||||
metrics = TestMetrics(
|
except Exception:
|
||||||
test_id=test_id,
|
results.append(ABTestResult(
|
||||||
prompt_name=prompt.name,
|
prompt=prompt,
|
||||||
|
response=ProviderResponse(
|
||||||
|
content="",
|
||||||
|
model=prompt.provider or self.provider.name,
|
||||||
provider=self.provider.name,
|
provider=self.provider.name,
|
||||||
model=self.provider.model,
|
),
|
||||||
latency_ms=response.latency_ms,
|
variables=variables,
|
||||||
success=True,
|
iteration=i + 1,
|
||||||
tokens_used=response.usage.get("total_tokens", 0) if response.usage else 0,
|
passed_validation=False,
|
||||||
)
|
validation_errors=["Test execution failed"],
|
||||||
|
))
|
||||||
|
|
||||||
return TestResult(success=True, response=response.content, metrics=metrics)
|
successful = sum(1 for r in results if r.passed_validation or r.response.content)
|
||||||
|
|
||||||
except Exception as e:
|
avg_latency = sum(latencies) / len(latencies) if latencies else 0
|
||||||
metrics = TestMetrics(
|
avg_tokens = sum(total_tokens) / len(total_tokens) if total_tokens else 0
|
||||||
test_id=test_id,
|
|
||||||
|
return ABTestSummary(
|
||||||
prompt_name=prompt.name,
|
prompt_name=prompt.name,
|
||||||
provider=self.provider.name,
|
config=self.config,
|
||||||
model=self.provider.model,
|
total_runs=self.config.iterations,
|
||||||
latency_ms=0,
|
successful_runs=successful,
|
||||||
success=False,
|
failed_runs=self.config.iterations - successful,
|
||||||
error_message=str(e),
|
avg_latency_ms=avg_latency,
|
||||||
|
avg_tokens=avg_tokens,
|
||||||
|
avg_cost=self._estimate_cost(avg_tokens),
|
||||||
|
results=results,
|
||||||
)
|
)
|
||||||
return TestResult(success=False, response="", metrics=metrics, error=str(e))
|
|
||||||
|
|
||||||
async def run_comparison(self, prompts: List[Prompt]) -> Dict[str, ComparisonResult]:
|
async def run_comparison(
|
||||||
results = {}
|
self,
|
||||||
|
prompts: List[Prompt],
|
||||||
|
shared_variables: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> Dict[str, ABTestSummary]:
|
||||||
|
"""Run tests on multiple prompts for comparison.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompts: List of prompts to compare.
|
||||||
|
shared_variables: Variables shared across all prompts.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary mapping prompt names to their summaries.
|
||||||
|
"""
|
||||||
|
shared_variables = shared_variables or {}
|
||||||
|
summaries = {}
|
||||||
|
|
||||||
for prompt in prompts:
|
for prompt in prompts:
|
||||||
all_metrics: List[TestMetrics] = []
|
variables = self._merge_variables(prompt, shared_variables)
|
||||||
|
summary = await self.run(prompt, variables)
|
||||||
|
summaries[prompt.name] = summary
|
||||||
|
|
||||||
for _ in range(self.config.iterations):
|
return summaries
|
||||||
result = await self.run_single(prompt, {})
|
|
||||||
all_metrics.append(result.metrics)
|
|
||||||
|
|
||||||
comparison = self.metrics_collector.compare(prompt.name, all_metrics)
|
async def _run_single(
|
||||||
results[prompt.name] = comparison
|
self,
|
||||||
|
prompt: Prompt,
|
||||||
|
variables: Dict[str, Any],
|
||||||
|
iteration: int,
|
||||||
|
) -> ABTestResult:
|
||||||
|
"""Run a single test iteration."""
|
||||||
|
from ..core.template import TemplateEngine
|
||||||
|
template_engine = TemplateEngine()
|
||||||
|
|
||||||
return results
|
rendered = template_engine.render(prompt.content, variables, prompt.variables)
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
async def run_tests(self, prompt: Prompt, iterations: Optional[int] = None) -> ComparisonResult:
|
response = await self.provider.complete(
|
||||||
iterations = iterations or self.config.iterations
|
prompt=rendered,
|
||||||
all_metrics: List[TestMetrics] = []
|
max_tokens=self.config.max_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
for _ in range(iterations):
|
latency_ms = (time.time() - start_time) * 1000
|
||||||
result = await self.run_single(prompt, {})
|
|
||||||
all_metrics.append(result.metrics)
|
|
||||||
|
|
||||||
return self.metrics_collector.compare(prompt.name, all_metrics)
|
return ABTestResult(
|
||||||
|
prompt=prompt,
|
||||||
|
response=response,
|
||||||
|
variables=variables,
|
||||||
|
iteration=iteration,
|
||||||
|
latency_ms=latency_ms,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _merge_variables(
|
||||||
|
self,
|
||||||
|
prompt: Prompt,
|
||||||
|
shared: Dict[str, Any],
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Merge shared variables with prompt-specific ones."""
|
||||||
|
variables = shared.copy()
|
||||||
|
for var in prompt.variables:
|
||||||
|
if var.name not in variables and var.default is not None:
|
||||||
|
variables[var.name] = var.default
|
||||||
|
return variables
|
||||||
|
|
||||||
|
def _estimate_cost(self, tokens: float) -> float:
|
||||||
|
"""Estimate cost based on token usage."""
|
||||||
|
rates = {
|
||||||
|
"gpt-4": 0.00003,
|
||||||
|
"gpt-4-turbo": 0.00001,
|
||||||
|
"gpt-3.5-turbo": 0.0000005,
|
||||||
|
"claude-3-sonnet-20240229": 0.000003,
|
||||||
|
"claude-3-opus-20240229": 0.000015,
|
||||||
|
}
|
||||||
|
rate = rates.get(self.provider.model, 0.000001)
|
||||||
|
return tokens * rate
|
||||||
|
|||||||
@@ -1,86 +1,141 @@
|
|||||||
|
"""Metrics collection for A/B testing."""
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Dict, List, Optional
|
from datetime import datetime
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TestMetrics:
|
class MetricsSample:
|
||||||
test_id: str
|
"""Single metrics sample from a test run."""
|
||||||
prompt_name: str
|
|
||||||
provider: str
|
timestamp: datetime = field(default_factory=datetime.utcnow)
|
||||||
model: str
|
latency_ms: float = 0.0
|
||||||
latency_ms: float
|
tokens_prompt: int = 0
|
||||||
success: bool
|
tokens_completion: int = 0
|
||||||
tokens_used: int = 0
|
tokens_total: int = 0
|
||||||
cost_estimate: float = 0.0
|
cost: float = 0.0
|
||||||
error_message: Optional[str] = None
|
validation_passed: bool = False
|
||||||
|
validation_errors: List[str] = field(default_factory=list)
|
||||||
|
custom_metrics: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ComparisonResult:
|
class MetricsSummary:
|
||||||
prompt_name: str
|
"""Summary statistics for collected metrics."""
|
||||||
total_runs: int
|
|
||||||
successful_runs: int
|
name: str
|
||||||
failed_runs: int
|
count: int = 0
|
||||||
avg_latency_ms: float
|
latency: Dict[str, float] = field(default_factory=dict)
|
||||||
min_latency_ms: float
|
tokens: Dict[str, float] = field(default_factory=dict)
|
||||||
max_latency_ms: float
|
cost: Dict[str, float] = field(default_factory=dict)
|
||||||
avg_tokens: float
|
validation_pass_rate: float = 0.0
|
||||||
avg_cost: float
|
samples: List[MetricsSample] = field(default_factory=list)
|
||||||
success_rate: float
|
|
||||||
all_metrics: List[TestMetrics] = field(default_factory=list)
|
@classmethod
|
||||||
|
def from_samples(cls, name: str, samples: List[MetricsSample]) -> "MetricsSummary":
|
||||||
|
"""Create summary from list of samples."""
|
||||||
|
if not samples:
|
||||||
|
return cls(name=name)
|
||||||
|
|
||||||
|
latencies = [s.latency_ms for s in samples]
|
||||||
|
tokens = [s.tokens_total for s in samples]
|
||||||
|
costs = [s.cost for s in samples]
|
||||||
|
valid_count = sum(1 for s in samples if s.validation_passed)
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
name=name,
|
||||||
|
count=len(samples),
|
||||||
|
latency={
|
||||||
|
"min": min(latencies),
|
||||||
|
"max": max(latencies),
|
||||||
|
"avg": sum(latencies) / len(latencies),
|
||||||
|
},
|
||||||
|
tokens={
|
||||||
|
"min": min(tokens),
|
||||||
|
"max": max(tokens),
|
||||||
|
"avg": sum(tokens) / len(tokens),
|
||||||
|
},
|
||||||
|
cost={
|
||||||
|
"min": min(costs),
|
||||||
|
"max": max(costs),
|
||||||
|
"avg": sum(costs) / len(costs),
|
||||||
|
},
|
||||||
|
validation_pass_rate=valid_count / len(samples),
|
||||||
|
samples=samples,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class MetricsCollector:
|
class MetricsCollector:
|
||||||
|
"""Collect and aggregate metrics from test runs."""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.metrics: List[TestMetrics] = []
|
"""Initialize metrics collector."""
|
||||||
|
self._samples: List[MetricsSample] = []
|
||||||
|
|
||||||
def add(self, metrics: TestMetrics) -> None:
|
def record(self, sample: MetricsSample) -> None:
|
||||||
self.metrics.append(metrics)
|
"""Record a metrics sample."""
|
||||||
|
self._samples.append(sample)
|
||||||
|
|
||||||
def compare(self, prompt_name: str, metrics_list: List[TestMetrics]) -> ComparisonResult:
|
def record_from_response(
|
||||||
if not metrics_list:
|
self,
|
||||||
return ComparisonResult(
|
latency_ms: float,
|
||||||
prompt_name=prompt_name,
|
usage: Dict[str, int],
|
||||||
total_runs=0,
|
validation_passed: bool = False,
|
||||||
successful_runs=0,
|
validation_errors: Optional[List[str]] = None,
|
||||||
failed_runs=0,
|
cost: float = 0.0,
|
||||||
avg_latency_ms=0,
|
custom_metrics: Optional[Dict[str, Any]] = None,
|
||||||
min_latency_ms=0,
|
) -> MetricsSample:
|
||||||
max_latency_ms=0,
|
"""Record metrics from a provider response."""
|
||||||
avg_tokens=0,
|
sample = MetricsSample(
|
||||||
avg_cost=0,
|
latency_ms=latency_ms,
|
||||||
success_rate=0,
|
tokens_prompt=usage.get("prompt_tokens", 0),
|
||||||
|
tokens_completion=usage.get("completion_tokens", 0),
|
||||||
|
tokens_total=usage.get("total_tokens", usage.get("prompt_tokens", 0) + usage.get("completion_tokens", 0)),
|
||||||
|
cost=cost,
|
||||||
|
validation_passed=validation_passed,
|
||||||
|
validation_errors=validation_errors or [],
|
||||||
|
custom_metrics=custom_metrics or {},
|
||||||
)
|
)
|
||||||
|
self.record(sample)
|
||||||
|
return sample
|
||||||
|
|
||||||
successful = [m for m in metrics_list if m.success]
|
def get_summary(self, name: str = "test") -> MetricsSummary:
|
||||||
failed = [m for m in metrics_list if not m.success]
|
"""Get summary of all collected metrics."""
|
||||||
|
return MetricsSummary.from_samples(name, self._samples)
|
||||||
|
|
||||||
latencies = [m.latency_ms for m in successful]
|
def clear(self) -> None:
|
||||||
tokens = [m.tokens_used for m in successful]
|
"""Clear all collected samples."""
|
||||||
costs = [m.cost_estimate for m in successful]
|
self._samples.clear()
|
||||||
|
|
||||||
return ComparisonResult(
|
def get_samples(self) -> List[MetricsSample]:
|
||||||
prompt_name=prompt_name,
|
"""Get all collected samples."""
|
||||||
total_runs=len(metrics_list),
|
return list(self._samples)
|
||||||
successful_runs=len(successful),
|
|
||||||
failed_runs=len(failed),
|
|
||||||
avg_latency_ms=sum(latencies) / len(latencies) if latencies else 0,
|
|
||||||
min_latency_ms=min(latencies) if latencies else 0,
|
|
||||||
max_latency_ms=max(latencies) if latencies else 0,
|
|
||||||
avg_tokens=sum(tokens) / len(tokens) if tokens else 0,
|
|
||||||
avg_cost=sum(costs) / len(costs) if costs else 0,
|
|
||||||
success_rate=len(successful) / len(metrics_list) if metrics_list else 0,
|
|
||||||
all_metrics=metrics_list,
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_summary(self) -> Dict[str, ComparisonResult]:
|
def compare(
|
||||||
by_prompt: Dict[str, List[TestMetrics]] = {}
|
self,
|
||||||
for m in self.metrics:
|
other: "MetricsCollector",
|
||||||
if m.prompt_name not in by_prompt:
|
) -> Dict[str, Any]:
|
||||||
by_prompt[m.prompt_name] = []
|
"""Compare metrics between two collectors.
|
||||||
by_prompt[m.prompt_name].append(m)
|
|
||||||
|
Args:
|
||||||
|
other: Another metrics collector to compare against.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with comparison statistics.
|
||||||
|
"""
|
||||||
|
summary1 = self.get_summary("a")
|
||||||
|
summary2 = other.get_summary("b")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
name: self.compare(name, metrics)
|
"latency_delta_ms": summary2.latency.get("avg", 0) - summary1.latency.get("avg", 0),
|
||||||
for name, metrics in by_prompt.items()
|
"tokens_delta": summary2.tokens.get("avg", 0) - summary1.tokens.get("avg", 0),
|
||||||
|
"cost_delta": summary2.cost.get("avg", 0) - summary1.cost.get("avg", 0),
|
||||||
|
"validation_pass_rate_delta": (
|
||||||
|
summary2.validation_pass_rate - summary1.validation_pass_rate
|
||||||
|
),
|
||||||
|
"sample_count": {
|
||||||
|
"a": summary1.count,
|
||||||
|
"b": summary2.count,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
@@ -1,33 +1,139 @@
|
|||||||
from dataclasses import dataclass
|
"""Test results and formatting."""
|
||||||
from typing import Dict, List, Optional
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from .ab_test import ABTestSummary
|
||||||
|
from .metrics import MetricsSummary
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TestResult:
|
class TestResult:
|
||||||
|
"""Result of a single test."""
|
||||||
|
|
||||||
|
test_id: str
|
||||||
|
prompt_name: str
|
||||||
|
provider: str
|
||||||
success: bool
|
success: bool
|
||||||
response: str
|
response: str
|
||||||
metrics: "TestMetrics"
|
metrics: Dict[str, Any] = field(default_factory=dict)
|
||||||
error: Optional[str] = None
|
validation_results: Dict[str, bool] = field(default_factory=dict)
|
||||||
|
error_message: Optional[str] = None
|
||||||
|
timestamp: datetime = field(default_factory=datetime.utcnow)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ComparisonResult:
|
class TestSessionResults:
|
||||||
prompt_name: str
|
"""Collection of test results."""
|
||||||
total_runs: int
|
|
||||||
successful_runs: int
|
|
||||||
failed_runs: int
|
|
||||||
avg_latency_ms: float
|
|
||||||
min_latency_ms: float
|
|
||||||
max_latency_ms: float
|
|
||||||
avg_tokens: float
|
|
||||||
avg_cost: float
|
|
||||||
success_rate: float
|
|
||||||
all_metrics: List["TestMetrics"]
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class TestReport:
|
|
||||||
test_id: str
|
test_id: str
|
||||||
timestamp: str
|
name: str
|
||||||
results: Dict[str, ComparisonResult]
|
results: List[TestResult] = field(default_factory=list)
|
||||||
summary: Dict[str, float]
|
metrics: MetricsSummary = field(default_factory=lambda: MetricsSummary(name=""))
|
||||||
|
ab_comparisons: Dict[str, ABTestSummary] = field(default_factory=dict)
|
||||||
|
start_time: datetime = field(default_factory=datetime.utcnow)
|
||||||
|
end_time: Optional[datetime] = None
|
||||||
|
|
||||||
|
__test__ = False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def duration_seconds(self) -> float:
|
||||||
|
"""Get test duration in seconds."""
|
||||||
|
if self.end_time is None:
|
||||||
|
return 0.0
|
||||||
|
return (self.end_time - self.start_time).total_seconds()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def success_count(self) -> int:
|
||||||
|
"""Count of successful tests."""
|
||||||
|
return sum(1 for r in self.results if r.success)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def failure_count(self) -> int:
|
||||||
|
"""Count of failed tests."""
|
||||||
|
return sum(1 for r in self.results if not r.success)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pass_rate(self) -> float:
|
||||||
|
"""Calculate pass rate."""
|
||||||
|
if not self.results:
|
||||||
|
return 0.0
|
||||||
|
return self.success_count / len(self.results)
|
||||||
|
|
||||||
|
|
||||||
|
class ResultFormatter:
|
||||||
|
"""Format test results for display."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def format_text(results: TestSessionResults) -> str:
|
||||||
|
"""Format results as plain text."""
|
||||||
|
lines = [
|
||||||
|
f"Test Results: {results.name}",
|
||||||
|
f"Duration: {results.duration_seconds:.2f}s",
|
||||||
|
f"Passed: {results.success_count}/{len(results.results)} ({results.pass_rate:.1%})",
|
||||||
|
"",
|
||||||
|
]
|
||||||
|
|
||||||
|
for result in results.results:
|
||||||
|
status = "PASS" if result.success else "FAIL"
|
||||||
|
lines.append(f"[{status}] {result.prompt_name}")
|
||||||
|
if result.error_message:
|
||||||
|
lines.append(f" Error: {result.error_message}")
|
||||||
|
if result.metrics:
|
||||||
|
metrics_str = ", ".join(f"{k}: {v}" for k, v in result.metrics.items())
|
||||||
|
lines.append(f" Metrics: {metrics_str}")
|
||||||
|
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def format_json(results: TestSessionResults) -> str:
|
||||||
|
"""Format results as JSON."""
|
||||||
|
import json
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
def serialize(obj):
|
||||||
|
if isinstance(obj, datetime):
|
||||||
|
return obj.isoformat()
|
||||||
|
raise TypeError(f"Object of type {type(obj)} is not JSON serializable")
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"test_id": results.test_id,
|
||||||
|
"name": results.name,
|
||||||
|
"duration_seconds": results.duration_seconds,
|
||||||
|
"summary": {
|
||||||
|
"total": len(results.results),
|
||||||
|
"passed": results.success_count,
|
||||||
|
"failed": results.failure_count,
|
||||||
|
"pass_rate": results.pass_rate,
|
||||||
|
},
|
||||||
|
"results": [
|
||||||
|
{
|
||||||
|
"test_id": r.test_id,
|
||||||
|
"prompt_name": r.prompt_name,
|
||||||
|
"provider": r.provider,
|
||||||
|
"success": r.success,
|
||||||
|
"response": r.response[:500] if r.response else "",
|
||||||
|
"metrics": r.metrics,
|
||||||
|
"validation_results": r.validation_results,
|
||||||
|
"error_message": r.error_message,
|
||||||
|
"timestamp": r.timestamp.isoformat(),
|
||||||
|
}
|
||||||
|
for r in results.results
|
||||||
|
],
|
||||||
|
}
|
||||||
|
return json.dumps(data, default=serialize, indent=2)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def format_ab_comparison(comparisons: Dict[str, ABTestSummary]) -> str:
|
||||||
|
"""Format A/B test comparisons."""
|
||||||
|
lines = ["A/B Test Comparison", "=" * 40]
|
||||||
|
|
||||||
|
for name, summary in comparisons.items():
|
||||||
|
lines.append(f"\nPrompt: {name}")
|
||||||
|
lines.append(f" Runs: {summary.successful_runs}/{summary.total_runs}")
|
||||||
|
lines.append(f" Avg Latency: {summary.avg_latency_ms:.2f}ms")
|
||||||
|
lines.append(f" Avg Tokens: {summary.avg_tokens:.0f}")
|
||||||
|
lines.append(f" Avg Cost: ${summary.avg_cost:.4f}")
|
||||||
|
|
||||||
|
return "\n".join(lines)
|
||||||
|
|||||||
@@ -1,43 +1,249 @@
|
|||||||
import re
|
"""Response validation framework."""
|
||||||
from typing import Any, Dict, List, Optional
|
|
||||||
import json
|
import json
|
||||||
|
import re
|
||||||
from ..core.prompt import ValidationRule
|
from abc import ABC, abstractmethod
|
||||||
from ..core.exceptions import ValidationError
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
|
||||||
class Validator:
|
class Validator(ABC):
|
||||||
def __init__(self, rules: Optional[List[ValidationRule]] = None):
|
"""Abstract base class for validators."""
|
||||||
self.rules = rules or []
|
|
||||||
|
|
||||||
def validate(self, response: str) -> List[str]:
|
@abstractmethod
|
||||||
|
def validate(self, response: str) -> Tuple[bool, Optional[str]]:
|
||||||
|
"""Validate a response.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
response: The response to validate.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (is_valid, error_message).
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_name(self) -> str:
|
||||||
|
"""Get validator name."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class RegexValidator(Validator):
|
||||||
|
"""Validates responses against regex patterns."""
|
||||||
|
|
||||||
|
def __init__(self, pattern: str, flags: int = 0):
|
||||||
|
"""Initialize regex validator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pattern: Regex pattern to match.
|
||||||
|
flags: Regex flags (e.g., re.IGNORECASE).
|
||||||
|
"""
|
||||||
|
self.pattern = pattern
|
||||||
|
self.flags = flags
|
||||||
|
self._regex = re.compile(pattern, flags)
|
||||||
|
|
||||||
|
def validate(self, response: str) -> Tuple[bool, Optional[str]]:
|
||||||
|
"""Validate response matches regex pattern."""
|
||||||
|
if not self._regex.search(response):
|
||||||
|
return False, f"Response does not match pattern: {self.pattern}"
|
||||||
|
return True, None
|
||||||
|
|
||||||
|
def get_name(self) -> str:
|
||||||
|
return f"regex({self.pattern})"
|
||||||
|
|
||||||
|
|
||||||
|
class JSONSchemaValidator(Validator):
|
||||||
|
"""Validates JSON responses against a schema."""
|
||||||
|
|
||||||
|
def __init__(self, schema: Dict[str, Any]):
|
||||||
|
"""Initialize JSON schema validator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
schema: JSON schema to validate against.
|
||||||
|
"""
|
||||||
|
self.schema = schema
|
||||||
|
|
||||||
|
def validate(self, response: str) -> Tuple[bool, Optional[str]]:
|
||||||
|
"""Validate JSON response against schema."""
|
||||||
|
try:
|
||||||
|
data = json.loads(response)
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
return False, f"Invalid JSON: {e}"
|
||||||
|
|
||||||
|
errors = self._validate_object(data, self.schema, "")
|
||||||
|
if errors:
|
||||||
|
return False, "; ".join(errors)
|
||||||
|
return True, None
|
||||||
|
|
||||||
|
def _validate_object(
|
||||||
|
self,
|
||||||
|
data: Any,
|
||||||
|
schema: Dict[str, Any],
|
||||||
|
path: str,
|
||||||
|
) -> List[str]:
|
||||||
|
"""Recursively validate against schema."""
|
||||||
errors = []
|
errors = []
|
||||||
|
|
||||||
for rule in self.rules:
|
if "type" in schema:
|
||||||
if rule.type == "regex":
|
expected_type = schema["type"]
|
||||||
if rule.pattern and not re.search(rule.pattern, response):
|
type_checks = {
|
||||||
errors.append(rule.message or f"Response failed regex validation")
|
"array": (list, "array"),
|
||||||
|
"object": (dict, "object"),
|
||||||
|
"string": (str, "string"),
|
||||||
|
"number": ((int, float), "number"),
|
||||||
|
"boolean": (bool, "boolean"),
|
||||||
|
"integer": ((int,), "integer"),
|
||||||
|
}
|
||||||
|
if expected_type in type_checks:
|
||||||
|
expected_class, type_name = type_checks[expected_type]
|
||||||
|
if not isinstance(data, expected_class): # type: ignore[arg-type]
|
||||||
|
actual_type = type(data).__name__
|
||||||
|
errors.append(f"{path}: expected {type_name}, got {actual_type}")
|
||||||
|
return errors
|
||||||
|
|
||||||
elif rule.type == "json":
|
if "properties" in schema and isinstance(data, dict):
|
||||||
try:
|
for prop, prop_schema in schema["properties"].items():
|
||||||
json.loads(response)
|
if prop in data:
|
||||||
except json.JSONDecodeError:
|
errors.extend(
|
||||||
errors.append(rule.message or "Response is not valid JSON")
|
self._validate_object(data[prop], prop_schema, f"{path}.{prop}")
|
||||||
|
)
|
||||||
|
elif prop_schema.get("required", False):
|
||||||
|
errors.append(f"{path}.{prop}: required property missing")
|
||||||
|
|
||||||
elif rule.type == "length":
|
if "enum" in schema and data not in schema["enum"]:
|
||||||
min_len = rule.json_schema.get("minLength", 0) if rule.json_schema else 0
|
errors.append(f"{path}: value must be one of {schema['enum']}")
|
||||||
max_len = rule.json_schema.get("maxLength", float("inf")) if rule.json_schema else float("inf")
|
|
||||||
if len(response) < min_len or len(response) > max_len:
|
if "minLength" in schema and isinstance(data, str):
|
||||||
errors.append(rule.message or f"Response length must be between {min_len} and {max_len}")
|
if len(data) < schema["minLength"]:
|
||||||
|
errors.append(f"{path}: string too short (min {schema['minLength']})")
|
||||||
|
|
||||||
|
if "maxLength" in schema and isinstance(data, str):
|
||||||
|
if len(data) > schema["maxLength"]:
|
||||||
|
errors.append(f"{path}: string too long (max {schema['maxLength']})")
|
||||||
|
|
||||||
|
if "minimum" in schema and isinstance(data, (int, float)):
|
||||||
|
if data < schema["minimum"]:
|
||||||
|
errors.append(f"{path}: value below minimum ({schema['minimum']})")
|
||||||
|
|
||||||
|
if "maximum" in schema and isinstance(data, (int, float)):
|
||||||
|
if data > schema["maximum"]:
|
||||||
|
errors.append(f"{path}: value above maximum ({schema['maximum']})")
|
||||||
|
|
||||||
return errors
|
return errors
|
||||||
|
|
||||||
def is_valid(self, response: str) -> bool:
|
def get_name(self) -> str:
|
||||||
return len(self.validate(response)) == 0
|
return "json-schema"
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def from_prompt_rules(rules: List[Dict[str, Any]]) -> "Validator":
|
class LengthValidator(Validator):
|
||||||
validation_rules = []
|
"""Validates response length constraints."""
|
||||||
for rule_data in rules:
|
|
||||||
validation_rules.append(ValidationRule(**rule_data))
|
def __init__(
|
||||||
return Validator(validation_rules)
|
self,
|
||||||
|
min_length: Optional[int] = None,
|
||||||
|
max_length: Optional[int] = None,
|
||||||
|
):
|
||||||
|
"""Initialize length validator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
min_length: Minimum number of characters.
|
||||||
|
max_length: Maximum number of characters.
|
||||||
|
"""
|
||||||
|
self.min_length = min_length
|
||||||
|
self.max_length = max_length
|
||||||
|
|
||||||
|
def validate(self, response: str) -> Tuple[bool, Optional[str]]:
|
||||||
|
"""Validate response length."""
|
||||||
|
if self.min_length is not None and len(response) < self.min_length:
|
||||||
|
return False, f"Response too short (min {self.min_length} chars)"
|
||||||
|
if self.max_length is not None and len(response) > self.max_length:
|
||||||
|
return False, f"Response too long (max {self.max_length} chars)"
|
||||||
|
return True, None
|
||||||
|
|
||||||
|
def get_name(self) -> str:
|
||||||
|
parts = ["length"]
|
||||||
|
if self.min_length:
|
||||||
|
parts.append(f"min={self.min_length}")
|
||||||
|
if self.max_length:
|
||||||
|
parts.append(f"max={self.max_length}")
|
||||||
|
return "(" + ", ".join(parts) + ")"
|
||||||
|
|
||||||
|
|
||||||
|
class ContainsValidator(Validator):
|
||||||
|
"""Validates response contains expected content."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
required_strings: List[str],
|
||||||
|
all_required: bool = False,
|
||||||
|
case_sensitive: bool = False,
|
||||||
|
):
|
||||||
|
"""Initialize contains validator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
required_strings: Strings that must be present.
|
||||||
|
all_required: If True, all strings must be present.
|
||||||
|
case_sensitive: Whether to match case.
|
||||||
|
"""
|
||||||
|
self.required_strings = required_strings
|
||||||
|
self.all_required = all_required
|
||||||
|
self.case_sensitive = case_sensitive
|
||||||
|
|
||||||
|
def validate(self, response: str) -> Tuple[bool, Optional[str]]:
|
||||||
|
"""Validate response contains required strings."""
|
||||||
|
strings = self.required_strings
|
||||||
|
response_lower = response.lower() if not self.case_sensitive else response
|
||||||
|
|
||||||
|
missing = []
|
||||||
|
for s in strings:
|
||||||
|
check_str = s.lower() if not self.case_sensitive else s
|
||||||
|
if check_str not in response_lower:
|
||||||
|
missing.append(s)
|
||||||
|
|
||||||
|
if self.all_required:
|
||||||
|
if missing:
|
||||||
|
return False, f"Missing required content: {', '.join(missing)}"
|
||||||
|
else:
|
||||||
|
if len(missing) == len(strings):
|
||||||
|
return False, "Response does not contain any expected content"
|
||||||
|
|
||||||
|
return True, None
|
||||||
|
|
||||||
|
def get_name(self) -> str:
|
||||||
|
mode = "all" if self.all_required else "any"
|
||||||
|
return f"contains({mode}, {self.required_strings})"
|
||||||
|
|
||||||
|
|
||||||
|
class CompositeValidator(Validator):
|
||||||
|
"""Combines multiple validators."""
|
||||||
|
|
||||||
|
def __init__(self, validators: List[Validator], mode: str = "all"):
|
||||||
|
"""Initialize composite validator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
validators: List of validators to combine.
|
||||||
|
mode: "all" (AND) or "any" (OR) behavior.
|
||||||
|
"""
|
||||||
|
self.validators = validators
|
||||||
|
self.mode = mode
|
||||||
|
|
||||||
|
def validate(self, response: str) -> Tuple[bool, Optional[str]]:
|
||||||
|
"""Validate using all validators."""
|
||||||
|
results = [v.validate(response) for v in self.validators]
|
||||||
|
errors = []
|
||||||
|
|
||||||
|
if self.mode == "all":
|
||||||
|
for valid, error in results:
|
||||||
|
if not valid:
|
||||||
|
errors.append(error)
|
||||||
|
if errors:
|
||||||
|
return False, "; ".join(e for e in errors if e)
|
||||||
|
return True, None
|
||||||
|
else:
|
||||||
|
for valid, _ in results:
|
||||||
|
if valid:
|
||||||
|
return True, None
|
||||||
|
return False, "No validator passed"
|
||||||
|
|
||||||
|
def get_name(self) -> str:
|
||||||
|
names = [v.get_name() for v in self.validators]
|
||||||
|
return f"composite({self.mode}, {names})"
|
||||||
|
|||||||
Reference in New Issue
Block a user