Files
promptforge/app/src/promptforge/providers/factory.py
7000pctAUTO 0993900953
Some checks failed
CI / test (push) Has been cancelled
fix: resolve CI linting and type errors
2026-02-04 12:49:06 +00:00

75 lines
2.3 KiB
Python

from typing import Dict, Optional
from .base import ProviderBase
from .openai import OpenAIProvider
from .anthropic import AnthropicProvider
from .ollama import OllamaProvider
from ..core.exceptions import ProviderError
class ProviderFactory:
    """Factory and registry for LLM provider instances.

    Provider classes are looked up by a case-insensitive name. The built-in
    entries are ``"openai"``, ``"anthropic"`` and ``"ollama"``; further
    providers can be added at runtime with :meth:`register`.
    """

    # Lowercase provider name -> provider class. This dict is a class
    # attribute, so entries added via ``register`` are shared by every caller
    # of the factory (and by any subclass that does not override it).
    _providers: Dict[str, type[ProviderBase]] = {
        "openai": OpenAIProvider,
        "anthropic": AnthropicProvider,
        "ollama": OllamaProvider,
    }

    @classmethod
    def register(cls, name: str, provider_class: type) -> None:
        """Register (or replace) a provider class under ``name``.

        Args:
            name: Provider identifier. Stored lowercased, so later lookups are
                case-insensitive.
            provider_class: Provider class to register; must be a subclass of
                ``ProviderBase``.

        Raises:
            TypeError: If ``provider_class`` is not a class, or is a class
                that does not subclass ``ProviderBase``.
        """
        # Check for a class first: issubclass() raises its own generic
        # TypeError ("arg 1 must be a class") for non-class values, which
        # would hide the actual problem from the caller.
        if not (
            isinstance(provider_class, type)
            and issubclass(provider_class, ProviderBase)
        ):
            raise TypeError("Provider must be a subclass of ProviderBase")
        cls._providers[name.lower()] = provider_class

    @classmethod
    def create(
        cls,
        provider_name: str,
        api_key: Optional[str] = None,
        model: Optional[str] = None,
        temperature: float = 0.7,
        **kwargs,
    ) -> ProviderBase:
        """Instantiate the provider registered under ``provider_name``.

        Args:
            provider_name: Name of the provider to create (case-insensitive).
            api_key: API key passed through to the provider constructor.
            model: Model to use. When falsy (``None`` or ``""``), the
                provider class's ``_default_model`` attribute is used if it
                defines one, otherwise ``"gpt-4"``.
            temperature: Sampling temperature passed to the provider.
            **kwargs: Additional provider-specific constructor options.

        Returns:
            A new provider instance.

        Raises:
            ProviderError: If no provider is registered under
                ``provider_name``; the message lists the available names.
        """
        provider_class = cls.get_provider_class(provider_name)
        if provider_class is None:
            available = ", ".join(cls._providers)
            raise ProviderError(
                f"Unknown provider: {provider_name}. Available: {available}"
            )
        # NOTE(review): "gpt-4" is an OpenAI model name, yet it is used for
        # *any* provider lacking ``_default_model`` (including custom ones
        # added via ``register``). Confirm every non-OpenAI provider defines
        # ``_default_model``.
        resolved_model = model or getattr(provider_class, "_default_model", "gpt-4")
        return provider_class(
            api_key=api_key,
            model=resolved_model,
            temperature=temperature,
            **kwargs,
        )

    @classmethod
    def list_providers(cls) -> list[str]:
        """Return the registered provider names, in registration order."""
        return list(cls._providers.keys())

    @classmethod
    def get_provider_class(cls, name: str) -> Optional[type[ProviderBase]]:
        """Return the provider class registered under ``name``, or ``None``.

        The lookup is case-insensitive (``name`` is lowercased first).
        """
        return cls._providers.get(name.lower())