diff --git a/app/shellgen/backends/llama_cpp.py b/app/shellgen/backends/llama_cpp.py
new file mode 100644
index 0000000..ebe2811
--- /dev/null
+++ b/app/shellgen/backends/llama_cpp.py
@@ -0,0 +1,112 @@
+"""Llama.cpp backend implementation."""
+
+import os
+from typing import Optional, TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from llama_cpp import Llama
+
+from .base import LLMBackend
+
+
+class LlamaCppBackend(LLMBackend):
+    """Llama.cpp Python bindings backend."""
+
+    def __init__(
+        self,
+        model_path: str = "~/.cache/llama-cpp/models/",
+        n_ctx: int = 2048,
+        n_threads: int = 4,
+        temperature: float = 0.1,
+        max_tokens: int = 500,
+    ):
+        """Initialize the Llama.cpp backend.
+
+        Args:
+            model_path: Path to the model file.
+            n_ctx: Context window size.
+            n_threads: Number of threads to use.
+            temperature: Generation temperature.
+            max_tokens: Maximum tokens to generate.
+        """
+        self.model_path = model_path
+        self.n_ctx = n_ctx
+        self.n_threads = n_threads
+        self.temperature = temperature
+        self.max_tokens = max_tokens
+        self._llm: Optional["Llama"] = None
+
+    def _load_model(self) -> "Llama":
+        """Load the model if not already loaded.
+
+        Returns:
+            Loaded Llama instance.
+        """
+        from llama_cpp import Llama as LlamaClass
+        if self._llm is None:
+            # Expand "~" to the user's home directory before loading.
+            expanded_path = os.path.expanduser(self.model_path)
+            self._llm = LlamaClass(
+                model_path=expanded_path,
+                n_ctx=self.n_ctx,
+                n_threads=self.n_threads,
+            )
+        return self._llm
+
+    def generate(self, prompt: str) -> str:
+        """Generate response using llama-cpp-python.
+
+        Args:
+            prompt: The prompt to send.
+
+        Returns:
+            Generated response text.
+        """
+        try:
+            llm = self._load_model()
+            response = llm(
+                prompt,
+                max_tokens=self.max_tokens,
+                temperature=self.temperature,
+                stop=["</s>", "###"],
+            )
+
+            return response["choices"][0]["text"]
+
+        except Exception as e:
+            raise ConnectionError(f"Llama.cpp error: {e}") from e
+
+    def is_available(self) -> bool:
+        """Check if the model can be loaded.
+
+        Returns:
+            True if the backend is available.
+        """
+        try:
+            self._load_model()
+            return True
+        except Exception:
+            return False
+
+    def get_model_name(self) -> str:
+        """Get the model name from the path.
+
+        Returns:
+            Model name string.
+        """
+        return self.model_path.split("/")[-1]
+
+    def set_model(self, model: str) -> None:
+        """Set the model path.
+
+        Args:
+            model: Path to the model.
+        """
+        self.model_path = model
+        self._llm = None
+
+    def close(self) -> None:
+        """Clean up model resources."""
+        if self._llm is not None:
+            del self._llm
+            self._llm = None
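
Usage sketch (not part of the diff): a minimal example of how the new backend might be exercised through the LLMBackend interface from app/shellgen/backends/base.py. The import path, the GGUF filename, and the prompt are placeholder assumptions for illustration, not part of this change; only the methods defined above are used.

    from shellgen.backends.llama_cpp import LlamaCppBackend

    # Point the backend at a local GGUF file; this filename is a placeholder.
    backend = LlamaCppBackend(
        model_path="~/.cache/llama-cpp/models/mistral-7b-instruct-v0.2.Q4_K_M.gguf",
        n_ctx=2048,
        n_threads=4,
    )

    if backend.is_available():
        print(f"Model: {backend.get_model_name()}")
        try:
            # generate() lazily loads the model on first use and returns raw text.
            output = backend.generate("List the five largest files in the current directory.")
            print(output.strip())
        finally:
            backend.close()
    else:
        print("llama.cpp backend unavailable (model missing or llama-cpp-python not installed)")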