This commit is contained in:
@@ -1,7 +1,6 @@
|
|||||||
"""Tests for Ollama client module."""
|
"""Tests for Ollama client module."""
|
||||||
|
|
||||||
import pytest
|
from unittest.mock import patch, MagicMock
|
||||||
from unittest.mock import Mock, patch
|
|
||||||
|
|
||||||
from shellgenius.ollama_client import OllamaClient, get_ollama_client
|
from shellgenius.ollama_client import OllamaClient, get_ollama_client
|
||||||
|
|
||||||
@@ -9,10 +8,13 @@ from shellgenius.ollama_client import OllamaClient, get_ollama_client
|
|||||||
class TestOllamaClient:
|
class TestOllamaClient:
|
||||||
def test_init(self):
|
def test_init(self):
|
||||||
"""Test client initialization."""
|
"""Test client initialization."""
|
||||||
with patch('shellgenius.ollama_client.get_config') as mock_config:
|
mock_config = MagicMock()
|
||||||
mock_config.return_value.ollama_host = "localhost:11434"
|
mock_config.get.side_effect = lambda key, default=None: {
|
||||||
mock_config.return_value.ollama_model = "codellama"
|
"ollama_host": "localhost:11434",
|
||||||
|
"ollama_model": "codellama",
|
||||||
|
}.get(key, default)
|
||||||
|
|
||||||
|
with patch('shellgenius.ollama_client.get_config', return_value=mock_config):
|
||||||
client = OllamaClient()
|
client = OllamaClient()
|
||||||
|
|
||||||
assert client.host == "localhost:11434"
|
assert client.host == "localhost:11434"
|
||||||
@@ -20,21 +22,27 @@ class TestOllamaClient:
|
|||||||
|
|
||||||
def test_is_available(self):
|
def test_is_available(self):
|
||||||
"""Test availability check."""
|
"""Test availability check."""
|
||||||
with patch('shellgenius.ollama_client.get_config') as mock_config:
|
mock_config = MagicMock()
|
||||||
mock_config.return_value.ollama_host = "localhost:11434"
|
mock_config.get.side_effect = lambda key, default=None: {
|
||||||
mock_config.return_value.ollama_model = "codellama"
|
"ollama_host": "localhost:11434",
|
||||||
|
"ollama_model": "codellama",
|
||||||
|
}.get(key, default)
|
||||||
|
|
||||||
|
with patch('shellgenius.ollama_client.get_config', return_value=mock_config):
|
||||||
client = OllamaClient()
|
client = OllamaClient()
|
||||||
|
|
||||||
with patch.object(client, 'list_models', return_value=["codellama"]):
|
with patch.object(client, 'list_models', return_value=["codellama"]):
|
||||||
assert client.is_available() == True
|
assert client.is_available() is True
|
||||||
|
|
||||||
def test_list_models(self):
|
def test_list_models(self):
|
||||||
"""Test listing models."""
|
"""Test listing models."""
|
||||||
with patch('shellgenius.ollama_client.get_config') as mock_config:
|
mock_config = MagicMock()
|
||||||
mock_config.return_value.ollama_host = "localhost:11434"
|
mock_config.get.side_effect = lambda key, default=None: {
|
||||||
mock_config.return_value.ollama_model = "codellama"
|
"ollama_host": "localhost:11434",
|
||||||
|
"ollama_model": "codellama",
|
||||||
|
}.get(key, default)
|
||||||
|
|
||||||
|
with patch('shellgenius.ollama_client.get_config', return_value=mock_config):
|
||||||
client = OllamaClient()
|
client = OllamaClient()
|
||||||
|
|
||||||
mock_response = {"models": [{"name": "codellama"}, {"name": "llama2"}]}
|
mock_response = {"models": [{"name": "codellama"}, {"name": "llama2"}]}
|
||||||
@@ -47,10 +55,13 @@ class TestOllamaClient:
|
|||||||
|
|
||||||
def test_generate(self):
|
def test_generate(self):
|
||||||
"""Test text generation."""
|
"""Test text generation."""
|
||||||
with patch('shellgenius.ollama_client.get_config') as mock_config:
|
mock_config = MagicMock()
|
||||||
mock_config.return_value.ollama_host = "localhost:11434"
|
mock_config.get.side_effect = lambda key, default=None: {
|
||||||
mock_config.return_value.ollama_model = "codellama"
|
"ollama_host": "localhost:11434",
|
||||||
|
"ollama_model": "codellama",
|
||||||
|
}.get(key, default)
|
||||||
|
|
||||||
|
with patch('shellgenius.ollama_client.get_config', return_value=mock_config):
|
||||||
client = OllamaClient()
|
client = OllamaClient()
|
||||||
|
|
||||||
mock_response = {"response": "Generated text"}
|
mock_response = {"response": "Generated text"}
|
||||||
@@ -58,7 +69,7 @@ class TestOllamaClient:
|
|||||||
with patch.object(client.client, 'generate', return_value=mock_response):
|
with patch.object(client.client, 'generate', return_value=mock_response):
|
||||||
result = client.generate("test prompt")
|
result = client.generate("test prompt")
|
||||||
|
|
||||||
assert result["success"] == True
|
assert result["success"] is True
|
||||||
assert "Generated text" in str(result["response"])
|
assert "Generated text" in str(result["response"])
|
||||||
|
|
||||||
|
|
||||||
@@ -66,8 +77,7 @@ class TestGetOllamaClient:
|
|||||||
def test_convenience_function(self):
|
def test_convenience_function(self):
|
||||||
"""Test the convenience function for getting client."""
|
"""Test the convenience function for getting client."""
|
||||||
with patch('shellgenius.ollama_client.get_config') as mock_config:
|
with patch('shellgenius.ollama_client.get_config') as mock_config:
|
||||||
mock_config.return_value.ollama_host = "localhost:11434"
|
mock_config.return_value = {}
|
||||||
mock_config.return_value.ollama_model = "custom-model"
|
|
||||||
|
|
||||||
client = get_ollama_client(host="custom:9999", model="custom-model")
|
client = get_ollama_client(host="custom:9999", model="custom-model")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user