180 lines
5.6 KiB
Python
180 lines
5.6 KiB
Python
"""Tests for LLM backends."""
|
|
|
|
import pytest
|
|
from unittest.mock import Mock, patch, MagicMock
|
|
|
|
|
|
def check_llama_cpp_available():
    """Report whether the optional ``llama_cpp`` package can be imported.

    The package is actually imported (not merely located), so any
    ``ImportError`` raised while importing it counts as "unavailable".
    Exceptions other than ``ImportError`` propagate to the caller.

    Returns:
        bool: ``True`` if ``import llama_cpp`` succeeds, otherwise ``False``.
    """
    available = True
    try:
        import llama_cpp  # noqa: F401
    except ImportError:
        available = False
    return available
|
|
|
|
|
|
# Probe for the optional llama_cpp dependency once, at module import time.
# Consumed by ``pytest.mark.skipif`` so tests that need the real package are
# skipped when it cannot be imported.
llama_cpp_available = check_llama_cpp_available()
|
|
|
|
|
|
class TestOllamaBackend:
    """Unit tests for ``shellgen.backends.ollama.OllamaBackend``.

    Each test imports the backend inside its own body, so a missing or broken
    backend module fails only the tests that touch it. Tests that exercise
    the client replace the module-level ``ollama_client`` name with a mock.

    NOTE(review): the initialization / model-name tests build ``OllamaBackend``
    without patching ``ollama_client``; that is only safe if the constructor
    does not contact a server — confirm against the backend implementation.
    """

    def test_ollama_backend_initialization(self):
        """Constructor keyword arguments end up on the instance."""
        from shellgen.backends.ollama import OllamaBackend

        ollama = OllamaBackend(
            host="localhost:11434",
            model="codellama",
            temperature=0.1,
            max_tokens=500,
        )

        assert ollama.host == "localhost:11434"
        assert ollama._model == "codellama"
        assert ollama.temperature == 0.1
        assert ollama.max_tokens == 500

    @patch('shellgen.backends.ollama.ollama_client')
    def test_ollama_generate(self, mock_client):
        """generate() hands back the ``response`` text from a mocked client."""
        from shellgen.backends.ollama import OllamaBackend

        # The patched module's Client(...) yields this fake client instance.
        fake_client = MagicMock()
        fake_client.generate.return_value = {"response": "ls -la"}
        mock_client.Client.return_value = fake_client

        result = OllamaBackend().generate("list files")

        assert result == "ls -la"
        fake_client.generate.assert_called_once()

    def test_ollama_is_available_true(self):
        """is_available() reports True when the mocked Ollama client responds."""
        from shellgen.backends.ollama import OllamaBackend

        with patch('shellgen.backends.ollama.ollama_client') as mock_client:
            # Stub ps(); presumably the call is_available() uses to probe the server.
            mock_client.Client.return_value.ps.return_value = {}
            ollama = OllamaBackend()
            assert ollama.is_available() is True

    def test_ollama_get_model_name(self):
        """get_model_name() echoes the model passed to the constructor."""
        from shellgen.backends.ollama import OllamaBackend

        ollama = OllamaBackend(model="llama2")
        assert ollama.get_model_name() == "llama2"

    def test_ollama_set_model(self):
        """set_model() replaces the model stored on the instance."""
        from shellgen.backends.ollama import OllamaBackend

        ollama = OllamaBackend()
        ollama.set_model("mistral")
        assert ollama._model == "mistral"
|
|
|
|
|
|
class TestLlamaCppBackend:
    """Unit tests for ``shellgen.backends.llama_cpp.LlamaCppBackend``.

    Only ``test_llama_cpp_generate`` needs the real ``llama_cpp`` package; it
    is skipped via the module-level ``llama_cpp_available`` flag otherwise.
    """

    def test_llama_cpp_backend_initialization(self):
        """Constructor keyword arguments end up on the instance."""
        from shellgen.backends.llama_cpp import LlamaCppBackend

        # The path is expected back verbatim (the "~" is not expanded).
        model_file = "~/.cache/llama-cpp/model.gguf"
        llama = LlamaCppBackend(
            model_path=model_file,
            n_ctx=2048,
            n_threads=4,
            temperature=0.1,
            max_tokens=500,
        )

        assert llama.model_path == model_file
        assert llama.n_ctx == 2048
        assert llama.n_threads == 4

    @pytest.mark.skipif(not llama_cpp_available, reason="llama_cpp not installed")
    @patch('llama_cpp.Llama')
    def test_llama_cpp_generate(self, mock_llama):
        """generate() output contains the text of the mocked completion.

        NOTE(review): patching ``llama_cpp.Llama`` only takes effect if the
        backend looks ``Llama`` up through the ``llama_cpp`` module at call
        time; a module-level ``from llama_cpp import Llama`` in the backend
        would bypass this mock — confirm.
        """
        from shellgen.backends.llama_cpp import LlamaCppBackend

        # Llama(...) returns a callable fake model; calling it yields a
        # completion-style payload.
        fake_model = MagicMock()
        fake_model.return_value = {"choices": [{"text": "find . -name '*.py'"}]}
        mock_llama.return_value = fake_model

        result = LlamaCppBackend().generate("find python files")

        assert "find" in result

    def test_llama_cpp_get_model_name(self):
        """get_model_name() includes the model file's name."""
        from shellgen.backends.llama_cpp import LlamaCppBackend

        llama = LlamaCppBackend(
            model_path="~/.cache/llama-cpp/models/codellama.gguf"
        )
        assert "codellama" in llama.get_model_name()

    def test_llama_cpp_set_model(self):
        """set_model() replaces the stored model path."""
        from shellgen.backends.llama_cpp import LlamaCppBackend

        llama = LlamaCppBackend()
        llama.set_model("new-model.gguf")
        assert llama.model_path == "new-model.gguf"
|
|
|
|
|
|
class TestBackendFactory:
    """Unit tests for ``shellgen.backends.factory.BackendFactory``."""

    def test_factory_create_ollama(self):
        """create_backend('ollama') builds an OllamaBackend with the given model."""
        from shellgen.backends.factory import BackendFactory
        from shellgen.backends.ollama import OllamaBackend

        created = BackendFactory.create_backend(
            backend_type="ollama",
            host="localhost:11434",
            model="codellama",
        )

        assert isinstance(created, OllamaBackend)
        assert created._model == "codellama"

    def test_factory_create_llama_cpp(self):
        """create_backend('llama_cpp') builds a LlamaCppBackend."""
        from shellgen.backends.factory import BackendFactory
        from shellgen.backends.llama_cpp import LlamaCppBackend

        created = BackendFactory.create_backend(
            backend_type="llama_cpp",
            model="model.gguf",
        )

        assert isinstance(created, LlamaCppBackend)

    def test_factory_get_available_backends(self):
        """Both built-in backend names are advertised by the factory."""
        from shellgen.backends.factory import BackendFactory

        available = BackendFactory.get_available_backends()

        for expected_name in ("ollama", "llama_cpp"):
            assert expected_name in available

    def test_factory_unknown_backend(self):
        """An unrecognised backend type is rejected with ValueError."""
        from shellgen.backends.factory import BackendFactory

        with pytest.raises(ValueError):
            BackendFactory.create_backend(backend_type="unknown")
|