This commit is contained in:
184
local-llm-prompt-manager/tests/test_llm_clients.py
Normal file
184
local-llm-prompt-manager/tests/test_llm_clients.py
Normal file
@@ -0,0 +1,184 @@
|
||||
"""Tests for LLM clients."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from src.llm.llm_factory import LLMClientFactory
|
||||
from src.llm.lmstudio import LMStudioClient
|
||||
from src.llm.ollama import OllamaClient
|
||||
|
||||
|
||||
class TestOllamaClient:
    """Unit tests covering the Ollama HTTP client."""

    def test_client_creation(self):
        """A client built without arguments points at the default URL."""
        assert OllamaClient().url == "http://localhost:11434"

    def test_client_creation_with_custom_url(self):
        """An explicit URL overrides the default."""
        assert OllamaClient(url="http://custom:9000").url == "http://custom:9000"

    @patch('src.llm.ollama.requests.post')
    def test_generate(self, mock_post):
        """generate() returns the text found in the API's "response" field."""
        fake = MagicMock()
        fake.json.return_value = {"response": "Hello, World!"}
        mock_post.return_value = fake

        reply = OllamaClient().generate("Hello")

        assert reply == "Hello, World!"
        mock_post.assert_called_once()

    @patch('src.llm.ollama.requests.post')
    def test_generate_with_model(self, mock_post):
        """The requested model name is forwarded in the POST call."""
        fake = MagicMock()
        fake.json.return_value = {"response": "Test"}
        mock_post.return_value = fake

        OllamaClient().generate("Test prompt", model="custom-model")

        # The model name must appear somewhere in the request arguments.
        assert "custom-model" in str(mock_post.call_args)

    @patch('src.llm.ollama.requests.get')
    def test_test_connection_success(self, mock_get):
        """A 200 status from the server counts as a live connection."""
        mock_get.return_value = MagicMock(status_code=200)
        assert OllamaClient().test_connection() is True

    @patch('src.llm.ollama.requests.get')
    def test_test_connection_failure(self, mock_get):
        """A refused connection is reported as False, not raised."""
        import requests
        mock_get.side_effect = requests.exceptions.ConnectionError()
        assert OllamaClient().test_connection() is False

    @patch('src.llm.ollama.requests.get')
    def test_get_available_models(self, mock_get):
        """Model names are extracted from the server's model listing."""
        fake = MagicMock(status_code=200)
        fake.json.return_value = {
            "models": [{"name": "llama3.2"}, {"name": "codellama"}],
        }
        mock_get.return_value = fake

        listed = OllamaClient().get_available_models()

        assert "llama3.2" in listed
        assert "codellama" in listed
||||
|
||||
class TestLMStudioClient:
    """Unit tests covering the LM Studio HTTP client."""

    def test_client_creation(self):
        """A client built without arguments points at the default URL."""
        assert LMStudioClient().url == "http://localhost:1234"

    def test_client_creation_with_custom_url(self):
        """An explicit URL overrides the default."""
        assert LMStudioClient(url="http://custom:9000").url == "http://custom:9000"

    @patch('src.llm.lmstudio.requests.post')
    def test_generate(self, mock_post):
        """generate() unwraps the text of the first completion choice."""
        fake = MagicMock()
        fake.json.return_value = {"choices": [{"text": "Generated response"}]}
        mock_post.return_value = fake

        reply = LMStudioClient().generate("Hello")

        assert reply == "Generated response"
        mock_post.assert_called_once()

    @patch('src.llm.lmstudio.requests.get')
    def test_test_connection_success(self, mock_get):
        """A 200 status from the server counts as a live connection."""
        mock_get.return_value = MagicMock(status_code=200)
        assert LMStudioClient().test_connection() is True

    @patch('src.llm.lmstudio.requests.get')
    def test_test_connection_failure(self, mock_get):
        """A refused connection is reported as False, not raised."""
        import requests
        mock_get.side_effect = requests.exceptions.ConnectionError()
        assert LMStudioClient().test_connection() is False
|
||||
|
||||
class TestLLMClientFactory:
    """Unit tests for the LLM client factory."""

    @patch('src.config.get_config')
    def test_create_ollama(self, mock_config):
        """With no explicit provider, a configured "ollama" default is used."""
        cfg = MagicMock()
        cfg.default_provider = "ollama"
        cfg.ollama_url = "http://localhost:11434"
        mock_config.return_value = cfg

        assert isinstance(LLMClientFactory.create(), OllamaClient)

    @patch('src.config.get_config')
    def test_create_lmstudio(self, mock_config):
        """With no explicit provider, a configured "lmstudio" default is used."""
        cfg = MagicMock()
        cfg.default_provider = "lmstudio"
        cfg.lmstudio_url = "http://localhost:1234"
        mock_config.return_value = cfg

        assert isinstance(LLMClientFactory.create(), LMStudioClient)

    @patch('src.config.get_config')
    def test_create_with_provider(self, mock_config):
        """An explicit provider argument wins over the configured default."""
        cfg = MagicMock()
        cfg.ollama_url = "http://localhost:11434"
        cfg.lmstudio_url = "http://localhost:1234"
        cfg.default_provider = "lmstudio"
        mock_config.return_value = cfg

        assert isinstance(LLMClientFactory.create(provider="ollama"), OllamaClient)

    def test_list_providers(self):
        """Both built-in providers appear in the listing."""
        available = LLMClientFactory.list_providers()
        assert "ollama" in available
        assert "lmstudio" in available

    def test_create_unknown_provider(self):
        """An unrecognised provider name raises a descriptive ValueError."""
        with pytest.raises(ValueError) as exc_info:
            LLMClientFactory.create(provider="unknown")
        assert "Unknown provider" in str(exc_info.value)
|
||||
Reference in New Issue
Block a user