diff --git a/app/src/promptforge/providers/factory.py b/app/src/promptforge/providers/factory.py new file mode 100644 index 0000000..a9f3004 --- /dev/null +++ b/app/src/promptforge/providers/factory.py @@ -0,0 +1,74 @@ +from typing import Dict, Optional + +from .base import ProviderBase +from .openai import OpenAIProvider +from .anthropic import AnthropicProvider +from .ollama import OllamaProvider +from ..core.exceptions import ProviderError + + +class ProviderFactory: + """Factory for creating LLM provider instances.""" + + _providers: Dict[str, type] = { + "openai": OpenAIProvider, + "anthropic": AnthropicProvider, + "ollama": OllamaProvider, + } + + @classmethod + def register(cls, name: str, provider_class: type) -> None: + """Register a new provider. + + Args: + name: Provider identifier. + provider_class: Provider class to register. + """ + if not issubclass(provider_class, ProviderBase): + raise TypeError("Provider must be a subclass of ProviderBase") + cls._providers[name.lower()] = provider_class + + @classmethod + def create( + cls, + provider_name: str, + api_key: Optional[str] = None, + model: Optional[str] = None, + temperature: float = 0.7, + **kwargs, + ) -> ProviderBase: + """Create a provider instance. + + Args: + provider_name: Name of the provider to create. + api_key: API key for the provider. + model: Model to use (uses default if not specified). + temperature: Sampling temperature. + **kwargs: Additional provider-specific options. + + Returns: + Provider instance. + """ + provider_class = cls._providers.get(provider_name.lower()) + if provider_class is None: + available = ", ".join(cls._providers.keys()) + raise ProviderError( + f"Unknown provider: {provider_name}. Available: {available}" + ) + + return provider_class( + api_key=api_key, + model=model or getattr(provider_class, "_default_model", "gpt-4"), + temperature=temperature, + **kwargs, + ) + + @classmethod + def list_providers(cls) -> list[str]: + """List available provider names.""" + return list(cls._providers.keys()) + + @classmethod + def get_provider_class(cls, name: str) -> Optional[type]: + """Get provider class by name.""" + return cls._providers.get(name.lower())