This commit is contained in:
104
app/src/promptforge/providers/base.py
Normal file
104
app/src/promptforge/providers/base.py
Normal file
@@ -0,0 +1,104 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, AsyncIterator, Dict, List, Optional
|
||||
|
||||
|
||||
@dataclass
class ProviderResponse:
    """Normalized result returned by every LLM provider implementation.

    Carries the generated text plus enough bookkeeping (model, provider,
    token usage, latency) for callers to log and compare providers.
    """

    # Generated text returned by the provider.
    content: str
    # Identifier of the model that produced the response.
    model: str
    # Name of the provider backend (e.g. the value of ProviderBase.name).
    provider: str
    # Token accounting, e.g. {"prompt_tokens": ..., "completion_tokens": ...};
    # empty when the provider reports no usage.
    usage: Dict[str, int] = field(default_factory=dict)
    # Wall-clock request latency in milliseconds; 0.0 when not measured.
    latency_ms: float = 0.0
    # Provider-specific extras that don't fit the fields above.
    metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
class ProviderBase(ABC):
    """Abstract base class for LLM providers.

    Concrete providers implement the abstract methods below so that callers
    get a uniform completion / streaming interface regardless of the
    backing API.
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        model: str = "gpt-4",
        temperature: float = 0.7,
        **kwargs,
    ):
        """Initialize provider.

        Args:
            api_key: API key for authentication.
            model: Model identifier to use.
            temperature: Sampling temperature (0.0-1.0).
            **kwargs: Additional provider-specific options.
        """
        self.api_key = api_key
        self.model = model
        self.temperature = temperature
        # Unrecognized options are stashed verbatim for subclasses to use.
        self.kwargs = kwargs

    @property
    @abstractmethod
    def name(self) -> str:
        """Provider name identifier."""

    @abstractmethod
    async def complete(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        max_tokens: Optional[int] = None,
        **kwargs,
    ) -> ProviderResponse:
        """Send a completion request.

        Args:
            prompt: User prompt to send.
            system_prompt: Optional system instructions.
            max_tokens: Maximum tokens in response.
            **kwargs: Additional provider-specific parameters.

        Returns:
            ProviderResponse with the generated content.
        """

    @abstractmethod
    async def stream_complete(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        max_tokens: Optional[int] = None,
        **kwargs,
    ) -> AsyncIterator[str]:
        """Stream completions incrementally.

        Args:
            prompt: User prompt to send.
            system_prompt: Optional system instructions.
            max_tokens: Maximum tokens in response.
            **kwargs: Additional provider-specific parameters.

        Yields:
            Chunks of generated content.
        """

    @abstractmethod
    def validate_api_key(self) -> bool:
        """Validate that the API key is configured correctly."""

    @abstractmethod
    def list_models(self) -> List[str]:
        """List available models for this provider."""

    def _get_system_prompt(self, prompt: str) -> Optional[str]:
        """Extract a system prompt embedded in *prompt* via ``---`` syntax.

        The convention is ``"<system>---<user>"``: everything before the
        first ``---`` is treated as system instructions.

        Args:
            prompt: Raw prompt text, possibly containing a ``---`` separator.

        Returns:
            The stripped system portion, or ``None`` when there is no
            separator or the text before it is empty/whitespace-only.
        """
        if "---" in prompt:
            system_part, _ = prompt.split("---", 1)
            # An empty or whitespace-only segment before "---" means no real
            # system prompt was given; report None rather than "" so callers
            # checking `is not None` don't send an empty system message.
            return system_part.strip() or None
        return None
|
||||
Reference in New Issue
Block a user