diff --git a/tests/test_ollama_client.py b/tests/test_ollama_client.py new file mode 100644 index 0000000..4bbc1a0 --- /dev/null +++ b/tests/test_ollama_client.py @@ -0,0 +1,75 @@ +"""Tests for Ollama client module.""" + +import pytest +from unittest.mock import Mock, patch + +from shellgenius.ollama_client import OllamaClient, get_ollama_client + + +class TestOllamaClient: + def test_init(self): + """Test client initialization.""" + with patch('shellgenius.ollama_client.get_config') as mock_config: + mock_config.return_value.ollama_host = "localhost:11434" + mock_config.return_value.ollama_model = "codellama" + + client = OllamaClient() + + assert client.host == "localhost:11434" + assert client.model == "codellama" + + def test_is_available(self): + """Test availability check.""" + with patch('shellgenius.ollama_client.get_config') as mock_config: + mock_config.return_value.ollama_host = "localhost:11434" + mock_config.return_value.ollama_model = "codellama" + + client = OllamaClient() + + with patch.object(client, 'list_models', return_value=["codellama"]): + assert client.is_available() == True + + def test_list_models(self): + """Test listing models.""" + with patch('shellgenius.ollama_client.get_config') as mock_config: + mock_config.return_value.ollama_host = "localhost:11434" + mock_config.return_value.ollama_model = "codellama" + + client = OllamaClient() + + mock_response = {"models": [{"name": "codellama"}, {"name": "llama2"}]} + + with patch.object(client.client, 'list', return_value=mock_response): + models = client.list_models() + + assert len(models) == 2 + assert "codellama" in models + + def test_generate(self): + """Test text generation.""" + with patch('shellgenius.ollama_client.get_config') as mock_config: + mock_config.return_value.ollama_host = "localhost:11434" + mock_config.return_value.ollama_model = "codellama" + + client = OllamaClient() + + mock_response = {"response": "Generated text"} + + with patch.object(client.client, 'generate', return_value=mock_response): + result = client.generate("test prompt") + + assert result["success"] == True + assert "Generated text" in str(result["response"]) + + +class TestGetOllamaClient: + def test_convenience_function(self): + """Test the convenience function for getting client.""" + with patch('shellgenius.ollama_client.get_config') as mock_config: + mock_config.return_value.ollama_host = "localhost:11434" + mock_config.return_value.ollama_model = "custom-model" + + client = get_ollama_client(host="custom:9999", model="custom-model") + + assert client.host == "custom:9999" + assert client.model == "custom-model"