"""Backend factory for easy switching between LLM backends."""
|
|
|
|
from typing import Optional
|
|
from .base import LLMBackend
|
|
from .ollama import OllamaBackend
|
|
from .llama_cpp import LlamaCppBackend
|
|
|
|
|
|
class BackendFactory:
    """Factory for creating LLM backend instances."""

    @staticmethod
    def create_backend(
        backend_type: str = "ollama",
        host: Optional[str] = None,
        model: Optional[str] = None,
        **kwargs,
    ) -> LLMBackend:
        """Create an LLM backend instance.

        Args:
            backend_type: Type of backend ("ollama" or "llama_cpp").
            host: Backend-specific host/connection info.
            model: Model name/path.
            **kwargs: Additional backend-specific arguments.

        Returns:
            LLMBackend instance.

        Raises:
            ValueError: If ``backend_type`` is not a supported backend.
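
        Example:
            A minimal usage sketch; the values shown simply mirror the
            defaults applied in the body below (Ollama on
            ``localhost:11434`` with the ``codellama`` model).

            >>> backend = BackendFactory.create_backend(
            ...     backend_type="ollama",
            ...     host="localhost:11434",
            ...     model="codellama",
            ... )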
        """
        backend_type = backend_type.lower()

        if backend_type == "ollama":
            return OllamaBackend(
                host=host or "localhost:11434",
                model=model or "codellama",
                temperature=kwargs.get("temperature", 0.1),
                max_tokens=kwargs.get("max_tokens", 500),
            )
        elif backend_type == "llama_cpp":
            return LlamaCppBackend(
                model_path=model or "~/.cache/llama-cpp/models/",
                n_ctx=kwargs.get("n_ctx", 2048),
                n_threads=kwargs.get("n_threads", 4),
                temperature=kwargs.get("temperature", 0.1),
                max_tokens=kwargs.get("max_tokens", 500),
            )
        else:
            raise ValueError(f"Unknown backend type: {backend_type}")

    @staticmethod
    def get_available_backends() -> list[str]:
        """Get list of available backend types.

        Returns:
            List of backend type strings.
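
        Example:
            A quick sketch of the expected output, taken from the literal
            list returned below.

            >>> BackendFactory.get_available_backends()
            ['ollama', 'llama_cpp']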
        """
        return ["ollama", "llama_cpp"]