diff --git a/src/promptforge/providers/factory.py b/src/promptforge/providers/factory.py index 9e04515..c8919de 100644 --- a/src/promptforge/providers/factory.py +++ b/src/promptforge/providers/factory.py @@ -1,4 +1,8 @@ -from typing import Optional +"""Provider factory for instantiating LLM providers.""" + +from typing import Dict, Optional + +from .base import ProviderBase from .openai import OpenAIProvider from .anthropic import AnthropicProvider from .ollama import OllamaProvider @@ -6,21 +10,67 @@ from ..core.exceptions import ProviderError class ProviderFactory: - @staticmethod + """Factory for creating LLM provider instances.""" + + _providers: Dict[str, type] = { + "openai": OpenAIProvider, + "anthropic": AnthropicProvider, + "ollama": OllamaProvider, + } + + @classmethod + def register(cls, name: str, provider_class: type) -> None: + """Register a new provider. + + Args: + name: Provider identifier. + provider_class: Provider class to register. + """ + if not issubclass(provider_class, ProviderBase): + raise TypeError("Provider must be a subclass of ProviderBase") + cls._providers[name.lower()] = provider_class + + @classmethod def create( + cls, provider_name: str, + api_key: Optional[str] = None, model: Optional[str] = None, temperature: float = 0.7, - api_key: Optional[str] = None, **kwargs, - ): - provider_name = provider_name.lower() + ) -> ProviderBase: + """Create a provider instance. - if provider_name in ("openai", "gpt-4", "gpt-3.5"): - return OpenAIProvider(api_key=api_key, model=model or "gpt-4", temperature=temperature) - elif provider_name in ("anthropic", "claude"): - return AnthropicProvider(api_key=api_key, model=model or "claude-3-sonnet-20240229", temperature=temperature) - elif provider_name in ("ollama", "local"): - return OllamaProvider(model=model or "llama2", temperature=temperature, **kwargs) - else: - raise ProviderError(f"Unknown provider: {provider_name}") \ No newline at end of file + Args: + provider_name: Name of the provider to create. + api_key: API key for the provider. + model: Model to use (uses default if not specified). + temperature: Sampling temperature. + **kwargs: Additional provider-specific options. + + Returns: + Provider instance. + """ + provider_class = cls._providers.get(provider_name.lower()) + if provider_class is None: + available = ", ".join(cls._providers.keys()) + raise ProviderError( + f"Unknown provider: {provider_name}. Available: {available}" + ) + + return provider_class( + api_key=api_key, + model=model or getattr(provider_class, "_default_model", "gpt-4"), + temperature=temperature, + **kwargs, + ) + + @classmethod + def list_providers(cls) -> list[str]: + """List available provider names.""" + return list(cls._providers.keys()) + + @classmethod + def get_provider_class(cls, name: str) -> Optional[type]: + """Get provider class by name.""" + return cls._providers.get(name.lower())