fix: resolve CI linting and type errors
This commit is contained in:
@@ -1,4 +1,8 @@
|
||||
from typing import Optional
|
||||
"""Provider factory for instantiating LLM providers."""
|
||||
|
||||
from typing import Dict, Optional
|
||||
|
||||
from .base import ProviderBase
|
||||
from .openai import OpenAIProvider
|
||||
from .anthropic import AnthropicProvider
|
||||
from .ollama import OllamaProvider
|
||||
@@ -6,21 +10,67 @@ from ..core.exceptions import ProviderError
|
||||
|
||||
|
||||
class ProviderFactory:
|
||||
@staticmethod
|
||||
"""Factory for creating LLM provider instances."""
|
||||
|
||||
_providers: Dict[str, type] = {
|
||||
"openai": OpenAIProvider,
|
||||
"anthropic": AnthropicProvider,
|
||||
"ollama": OllamaProvider,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def register(cls, name: str, provider_class: type) -> None:
|
||||
"""Register a new provider.
|
||||
|
||||
Args:
|
||||
name: Provider identifier.
|
||||
provider_class: Provider class to register.
|
||||
"""
|
||||
if not issubclass(provider_class, ProviderBase):
|
||||
raise TypeError("Provider must be a subclass of ProviderBase")
|
||||
cls._providers[name.lower()] = provider_class
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
provider_name: str,
|
||||
api_key: Optional[str] = None,
|
||||
model: Optional[str] = None,
|
||||
temperature: float = 0.7,
|
||||
api_key: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
provider_name = provider_name.lower()
|
||||
) -> ProviderBase:
|
||||
"""Create a provider instance.
|
||||
|
||||
if provider_name in ("openai", "gpt-4", "gpt-3.5"):
|
||||
return OpenAIProvider(api_key=api_key, model=model or "gpt-4", temperature=temperature)
|
||||
elif provider_name in ("anthropic", "claude"):
|
||||
return AnthropicProvider(api_key=api_key, model=model or "claude-3-sonnet-20240229", temperature=temperature)
|
||||
elif provider_name in ("ollama", "local"):
|
||||
return OllamaProvider(model=model or "llama2", temperature=temperature, **kwargs)
|
||||
else:
|
||||
raise ProviderError(f"Unknown provider: {provider_name}")
|
||||
Args:
|
||||
provider_name: Name of the provider to create.
|
||||
api_key: API key for the provider.
|
||||
model: Model to use (uses default if not specified).
|
||||
temperature: Sampling temperature.
|
||||
**kwargs: Additional provider-specific options.
|
||||
|
||||
Returns:
|
||||
Provider instance.
|
||||
"""
|
||||
provider_class = cls._providers.get(provider_name.lower())
|
||||
if provider_class is None:
|
||||
available = ", ".join(cls._providers.keys())
|
||||
raise ProviderError(
|
||||
f"Unknown provider: {provider_name}. Available: {available}"
|
||||
)
|
||||
|
||||
return provider_class(
|
||||
api_key=api_key,
|
||||
model=model or getattr(provider_class, "_default_model", "gpt-4"),
|
||||
temperature=temperature,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def list_providers(cls) -> list[str]:
|
||||
"""List available provider names."""
|
||||
return list(cls._providers.keys())
|
||||
|
||||
@classmethod
|
||||
def get_provider_class(cls, name: str) -> Optional[type]:
|
||||
"""Get provider class by name."""
|
||||
return cls._providers.get(name.lower())
|
||||
|
||||
Reference in New Issue
Block a user