This commit is contained in:
74
app/src/promptforge/providers/factory.py
Normal file
74
app/src/promptforge/providers/factory.py
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
from typing import Dict, Optional
|
||||||
|
|
||||||
|
from .base import ProviderBase
|
||||||
|
from .openai import OpenAIProvider
|
||||||
|
from .anthropic import AnthropicProvider
|
||||||
|
from .ollama import OllamaProvider
|
||||||
|
from ..core.exceptions import ProviderError
|
||||||
|
|
||||||
|
|
||||||
|
class ProviderFactory:
|
||||||
|
"""Factory for creating LLM provider instances."""
|
||||||
|
|
||||||
|
_providers: Dict[str, type] = {
|
||||||
|
"openai": OpenAIProvider,
|
||||||
|
"anthropic": AnthropicProvider,
|
||||||
|
"ollama": OllamaProvider,
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def register(cls, name: str, provider_class: type) -> None:
|
||||||
|
"""Register a new provider.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Provider identifier.
|
||||||
|
provider_class: Provider class to register.
|
||||||
|
"""
|
||||||
|
if not issubclass(provider_class, ProviderBase):
|
||||||
|
raise TypeError("Provider must be a subclass of ProviderBase")
|
||||||
|
cls._providers[name.lower()] = provider_class
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create(
|
||||||
|
cls,
|
||||||
|
provider_name: str,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
model: Optional[str] = None,
|
||||||
|
temperature: float = 0.7,
|
||||||
|
**kwargs,
|
||||||
|
) -> ProviderBase:
|
||||||
|
"""Create a provider instance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider_name: Name of the provider to create.
|
||||||
|
api_key: API key for the provider.
|
||||||
|
model: Model to use (uses default if not specified).
|
||||||
|
temperature: Sampling temperature.
|
||||||
|
**kwargs: Additional provider-specific options.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Provider instance.
|
||||||
|
"""
|
||||||
|
provider_class = cls._providers.get(provider_name.lower())
|
||||||
|
if provider_class is None:
|
||||||
|
available = ", ".join(cls._providers.keys())
|
||||||
|
raise ProviderError(
|
||||||
|
f"Unknown provider: {provider_name}. Available: {available}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return provider_class(
|
||||||
|
api_key=api_key,
|
||||||
|
model=model or getattr(provider_class, "_default_model", "gpt-4"),
|
||||||
|
temperature=temperature,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def list_providers(cls) -> list[str]:
|
||||||
|
"""List available provider names."""
|
||||||
|
return list(cls._providers.keys())
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_provider_class(cls, name: str) -> Optional[type]:
|
||||||
|
"""Get provider class by name."""
|
||||||
|
return cls._providers.get(name.lower())
|
||||||
Reference in New Issue
Block a user