fix: resolve CI linting and type errors
Some checks failed
CI / test (push) Has been cancelled
CI / lint (push) Has been cancelled
CI / type-check (push) Has been cancelled

This commit is contained in:
2026-02-04 12:58:27 +00:00
parent 8090d3eeba
commit e86adcfbfc

View File

@@ -1,4 +1,8 @@
from typing import Optional
"""Provider factory for instantiating LLM providers."""
from typing import Dict, Optional
from .base import ProviderBase
from .openai import OpenAIProvider
from .anthropic import AnthropicProvider
from .ollama import OllamaProvider
@@ -6,21 +10,67 @@ from ..core.exceptions import ProviderError
class ProviderFactory:
@staticmethod
"""Factory for creating LLM provider instances."""
_providers: Dict[str, type] = {
"openai": OpenAIProvider,
"anthropic": AnthropicProvider,
"ollama": OllamaProvider,
}
@classmethod
def register(cls, name: str, provider_class: type) -> None:
"""Register a new provider.
Args:
name: Provider identifier.
provider_class: Provider class to register.
"""
if not issubclass(provider_class, ProviderBase):
raise TypeError("Provider must be a subclass of ProviderBase")
cls._providers[name.lower()] = provider_class
@classmethod
def create(
cls,
provider_name: str,
api_key: Optional[str] = None,
model: Optional[str] = None,
temperature: float = 0.7,
api_key: Optional[str] = None,
**kwargs,
):
provider_name = provider_name.lower()
) -> ProviderBase:
"""Create a provider instance.
if provider_name in ("openai", "gpt-4", "gpt-3.5"):
return OpenAIProvider(api_key=api_key, model=model or "gpt-4", temperature=temperature)
elif provider_name in ("anthropic", "claude"):
return AnthropicProvider(api_key=api_key, model=model or "claude-3-sonnet-20240229", temperature=temperature)
elif provider_name in ("ollama", "local"):
return OllamaProvider(model=model or "llama2", temperature=temperature, **kwargs)
else:
raise ProviderError(f"Unknown provider: {provider_name}")
Args:
provider_name: Name of the provider to create.
api_key: API key for the provider.
model: Model to use (uses default if not specified).
temperature: Sampling temperature.
**kwargs: Additional provider-specific options.
Returns:
Provider instance.
"""
provider_class = cls._providers.get(provider_name.lower())
if provider_class is None:
available = ", ".join(cls._providers.keys())
raise ProviderError(
f"Unknown provider: {provider_name}. Available: {available}"
)
return provider_class(
api_key=api_key,
model=model or getattr(provider_class, "_default_model", "gpt-4"),
temperature=temperature,
**kwargs,
)
@classmethod
def list_providers(cls) -> list[str]:
"""List available provider names."""
return list(cls._providers.keys())
@classmethod
def get_provider_class(cls, name: str) -> Optional[type]:
"""Get provider class by name."""
return cls._providers.get(name.lower())