diff --git a/app/shellgen/backends/factory.py b/app/shellgen/backends/factory.py
new file mode 100644
index 0000000..896b047
--- /dev/null
+++ b/app/shellgen/backends/factory.py
@@ -0,0 +1,59 @@
+"""Backend factory for easy switching between LLM backends."""
+
+from typing import Optional
+from .base import LLMBackend
+from .ollama import OllamaBackend
+from .llama_cpp import LlamaCppBackend
+
+
+class BackendFactory:
+    """Factory for creating LLM backend instances."""
+
+    @staticmethod
+    def create_backend(
+        backend_type: str = "ollama",
+        host: Optional[str] = None,
+        model: Optional[str] = None,
+        **kwargs,
+    ) -> LLMBackend:
+        """Create an LLM backend instance.
+
+        Args:
+            backend_type: Type of backend ("ollama" or "llama_cpp").
+            host: Backend-specific host/connection info.
+            model: Model name/path.
+            **kwargs: Additional backend-specific arguments.
+
+        Returns:
+            LLMBackend instance.
+        """
+        backend_type = backend_type.lower()
+
+        if backend_type == "ollama":
+            return OllamaBackend(
+                host=host or "localhost:11434",
+                model=model or "codellama",
+                temperature=kwargs.get("temperature", 0.1),
+                max_tokens=kwargs.get("max_tokens", 500),
+            )
+
+        elif backend_type == "llama_cpp":
+            return LlamaCppBackend(
+                model_path=model or "~/.cache/llama-cpp/models/",
+                n_ctx=kwargs.get("n_ctx", 2048),
+                n_threads=kwargs.get("n_threads", 4),
+                temperature=kwargs.get("temperature", 0.1),
+                max_tokens=kwargs.get("max_tokens", 500),
+            )
+
+        else:
+            raise ValueError(f"Unknown backend type: {backend_type}")
+
+    @staticmethod
+    def get_available_backends() -> list[str]:
+        """Get list of available backend types.
+
+        Returns:
+            List of backend type strings.
+        """
+        return ["ollama", "llama_cpp"]
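For reviewers, a quick usage sketch of the new factory (assumptions: `app` is importable as a package from the repo root, and the GGUF path below is a hypothetical example, not a file this PR ships):

    from app.shellgen.backends.factory import BackendFactory

    # Enumerate the registered backend types: ["ollama", "llama_cpp"].
    print(BackendFactory.get_available_backends())

    # Defaults: an OllamaBackend pointed at localhost:11434 running "codellama".
    backend = BackendFactory.create_backend()

    # Explicit llama.cpp backend; n_ctx/n_threads are forwarded via **kwargs.
    local = BackendFactory.create_backend(
        backend_type="LLAMA_CPP",  # case-insensitive, lowered internally
        model="~/models/codellama-7b.Q4_K_M.gguf",  # hypothetical model path
        n_ctx=4096,
        n_threads=8,
    )

Note that an unrecognized `backend_type` raises `ValueError` immediately, so callers can validate configuration at startup rather than at first generation.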