diff --git a/src/llm/llm_factory.py b/src/llm/llm_factory.py new file mode 100644 index 0000000..e748b5e --- /dev/null +++ b/src/llm/llm_factory.py @@ -0,0 +1,39 @@ +"""LLM client factory for creating clients based on provider.""" + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from .base import LLMClient + + +class LLMClientFactory: + """Factory for creating LLM clients.""" + + @staticmethod + def create(provider: str = None, url: str = None) -> "LLMClient": + """Create an LLM client for the specified provider.""" + from ..config import get_config + from .lmstudio import LMStudioClient + from .ollama import OllamaClient + + config = get_config() + provider = provider or config.default_provider + + if provider.lower() == "ollama": + return OllamaClient(url) + elif provider.lower() == "lmstudio": + return LMStudioClient(url) + else: + raise ValueError(f"Unknown provider: {provider}") + + @staticmethod + def get_default_client() -> "LLMClient": + """Get the default LLM client based on configuration.""" + from ..config import get_config + config = get_config() + return LLMClientFactory.create(config.default_provider) + + @staticmethod + def list_providers() -> list[str]: + """List available providers.""" + return ["ollama", "lmstudio"]