fix: resolve CI linting and type errors
Some checks failed
CI / test (push) Has been cancelled
CI / lint (push) Has been cancelled
CI / type-check (push) Has been cancelled

This commit is contained in:
2026-02-04 12:58:27 +00:00
parent 8090d3eeba
commit e86adcfbfc

View File

@@ -1,4 +1,8 @@
from typing import Optional """Provider factory for instantiating LLM providers."""
from typing import Dict, Optional
from .base import ProviderBase
from .openai import OpenAIProvider from .openai import OpenAIProvider
from .anthropic import AnthropicProvider from .anthropic import AnthropicProvider
from .ollama import OllamaProvider from .ollama import OllamaProvider
@@ -6,21 +10,67 @@ from ..core.exceptions import ProviderError
class ProviderFactory: class ProviderFactory:
@staticmethod """Factory for creating LLM provider instances."""
_providers: Dict[str, type] = {
"openai": OpenAIProvider,
"anthropic": AnthropicProvider,
"ollama": OllamaProvider,
}
@classmethod
def register(cls, name: str, provider_class: type) -> None:
"""Register a new provider.
Args:
name: Provider identifier.
provider_class: Provider class to register.
"""
if not issubclass(provider_class, ProviderBase):
raise TypeError("Provider must be a subclass of ProviderBase")
cls._providers[name.lower()] = provider_class
@classmethod
def create( def create(
cls,
provider_name: str, provider_name: str,
api_key: Optional[str] = None,
model: Optional[str] = None, model: Optional[str] = None,
temperature: float = 0.7, temperature: float = 0.7,
api_key: Optional[str] = None,
**kwargs, **kwargs,
): ) -> ProviderBase:
provider_name = provider_name.lower() """Create a provider instance.
if provider_name in ("openai", "gpt-4", "gpt-3.5"): Args:
return OpenAIProvider(api_key=api_key, model=model or "gpt-4", temperature=temperature) provider_name: Name of the provider to create.
elif provider_name in ("anthropic", "claude"): api_key: API key for the provider.
return AnthropicProvider(api_key=api_key, model=model or "claude-3-sonnet-20240229", temperature=temperature) model: Model to use (uses default if not specified).
elif provider_name in ("ollama", "local"): temperature: Sampling temperature.
return OllamaProvider(model=model or "llama2", temperature=temperature, **kwargs) **kwargs: Additional provider-specific options.
else:
raise ProviderError(f"Unknown provider: {provider_name}") Returns:
Provider instance.
"""
provider_class = cls._providers.get(provider_name.lower())
if provider_class is None:
available = ", ".join(cls._providers.keys())
raise ProviderError(
f"Unknown provider: {provider_name}. Available: {available}"
)
return provider_class(
api_key=api_key,
model=model or getattr(provider_class, "_default_model", "gpt-4"),
temperature=temperature,
**kwargs,
)
@classmethod
def list_providers(cls) -> list[str]:
"""List available provider names."""
return list(cls._providers.keys())
@classmethod
def get_provider_class(cls, name: str) -> Optional[type]:
"""Get provider class by name."""
return cls._providers.get(name.lower())