This commit is contained in:
104
app/src/promptforge/providers/base.py
Normal file
104
app/src/promptforge/providers/base.py
Normal file
@@ -0,0 +1,104 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any, AsyncIterator, Dict, List, Optional
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ProviderResponse:
|
||||||
|
"""Response from an LLM provider."""
|
||||||
|
|
||||||
|
content: str
|
||||||
|
model: str
|
||||||
|
provider: str
|
||||||
|
usage: Dict[str, int] = field(default_factory=dict)
|
||||||
|
latency_ms: float = 0.0
|
||||||
|
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
class ProviderBase(ABC):
|
||||||
|
"""Abstract base class for LLM providers."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
model: str = "gpt-4",
|
||||||
|
temperature: float = 0.7,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""Initialize provider.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
api_key: API key for authentication.
|
||||||
|
model: Model identifier to use.
|
||||||
|
temperature: Sampling temperature (0.0-1.0).
|
||||||
|
**kwargs: Additional provider-specific options.
|
||||||
|
"""
|
||||||
|
self.api_key = api_key
|
||||||
|
self.model = model
|
||||||
|
self.temperature = temperature
|
||||||
|
self.kwargs = kwargs
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def name(self) -> str:
|
||||||
|
"""Provider name identifier."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def complete(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
system_prompt: Optional[str] = None,
|
||||||
|
max_tokens: Optional[int] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> ProviderResponse:
|
||||||
|
"""Send a completion request.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt: User prompt to send.
|
||||||
|
system_prompt: Optional system instructions.
|
||||||
|
max_tokens: Maximum tokens in response.
|
||||||
|
**kwargs: Additional provider-specific parameters.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ProviderResponse with the generated content.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def stream_complete(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
system_prompt: Optional[str] = None,
|
||||||
|
max_tokens: Optional[int] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> AsyncIterator[str]:
|
||||||
|
"""Stream completions incrementally.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt: User prompt to send.
|
||||||
|
system_prompt: Optional system instructions.
|
||||||
|
max_tokens: Maximum tokens in response.
|
||||||
|
**kwargs: Additional provider-specific parameters.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
Chunks of generated content.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def validate_api_key(self) -> bool:
|
||||||
|
"""Validate that the API key is configured correctly."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def list_models(self) -> List[str]:
|
||||||
|
"""List available models for this provider."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _get_system_prompt(self, prompt: str) -> Optional[str]:
|
||||||
|
"""Extract system prompt from prompt if using special syntax."""
|
||||||
|
if "---" in prompt:
|
||||||
|
parts = prompt.split("---", 1)
|
||||||
|
return parts[0].strip()
|
||||||
|
return None
|
||||||
Reference in New Issue
Block a user