142 lines
5.0 KiB
Python
142 lines
5.0 KiB
Python
"""Tests for ollama_client.py."""
|
|
|
|
from unittest.mock import Mock, patch, MagicMock
|
|
|
|
import pytest
|
|
|
|
from src.ollama_client import (
|
|
OllamaClient,
|
|
OllamaError,
|
|
OllamaConnectionError,
|
|
OllamaModelNotFoundError,
|
|
OllamaModel,
|
|
)
|
|
from src.config import Config
|
|
|
|
|
|
class TestOllamaClient:
    """Tests for OllamaClient class.

    Every test runs with ``ollama.Client`` patched inside
    ``src.ollama_client`` (see the ``mock_client`` fixture), so no Ollama
    server is contacted.
    """

    @pytest.fixture
    def mock_client(self):
        """Patch ``ollama.Client`` as referenced by ``src.ollama_client``.

        Yields the patched class. Tests stub behaviour on
        ``mock_client.return_value``, i.e. the object returned when the code
        under test instantiates ``ollama.Client``.
        """
        with patch("src.ollama_client.ollama.Client") as mock:
            yield mock

    @pytest.fixture
    def config(self):
        """Create a test configuration pointing at a local Ollama host."""
        return Config(
            ollama_host="http://localhost:11434",
            model="llama3.2",
        )

    @pytest.fixture
    def ollama_client(self, config, mock_client):
        """Create an OllamaClient with mocked dependencies.

        Requests ``mock_client`` so the patch is already active when
        ``OllamaClient`` is constructed.
        """
        return OllamaClient(config)

    def test_client_initialization(self, config, mock_client):
        """Test that the client keeps the configuration it was given."""
        # mock_client is requested only to keep this test hermetic: if
        # OllamaClient.__init__ builds an ollama.Client, it gets the mock.
        client = OllamaClient(config)
        assert client.config.ollama_host == "http://localhost:11434"
        assert client.config.model == "llama3.2"

    def test_connect_with_retry_success(self, ollama_client, mock_client):
        """Test successful connection with retry."""
        mock_response = Mock()
        mock_response.models = []
        mock_client.return_value.ps.return_value = mock_response

        result = ollama_client.connect_with_retry(max_retries=1)
        assert result is True

    def test_connect_with_retry_failure(self, ollama_client, mock_client):
        """Test that a request error surfaces as OllamaConnectionError."""
        import ollama

        mock_client.return_value.ps.side_effect = ollama.RequestError("Connection refused")

        with pytest.raises(OllamaConnectionError):
            ollama_client.connect_with_retry(max_retries=1)

    def test_list_models(self, ollama_client, mock_client):
        """Test listing available models."""
        mock_model = Mock()
        mock_model.name = "llama3.2"
        mock_model.size = 4000000000
        mock_model.digest = "abc123"
        mock_model.modified_at = "2024-01-15T10:00:00"

        mock_response = Mock()
        mock_response.models = [mock_model]
        # NOTE(review): in the ollama Python client, `ps` reports *running*
        # models while `list` reports locally available ones; confirm that
        # list_models really calls `ps` rather than `list`.
        mock_client.return_value.ps.return_value = mock_response

        models = ollama_client.list_models()
        assert len(models) == 1
        assert models[0].name == "llama3.2"

    def test_generate_commit_message(self, ollama_client, mock_client):
        """Test generating a commit message."""
        mock_response = Mock()
        mock_response.message.content = "feat(api): add new endpoint"
        mock_client.return_value.chat.return_value = mock_response

        result = ollama_client.generate_commit_message("some diff content")
        assert result == "feat(api): add new endpoint"

    def test_generate_commit_message_with_language(self, ollama_client, mock_client):
        """Test generating a commit message with language context."""
        mock_response = Mock()
        mock_response.message.content = "feat(python): add helper function"
        mock_client.return_value.chat.return_value = mock_response

        result = ollama_client.generate_commit_message("diff content", language="python")
        # The language hint must reach the chat request...
        assert "python" in str(mock_client.return_value.chat.call_args)
        # ...and the model's reply must be returned to the caller.
        assert result == "feat(python): add helper function"

    def test_generate_changelog(self, ollama_client, mock_client):
        """Test generating a changelog."""
        mock_response = Mock()
        mock_response.message.content = "# Changelog\n\n## Features\n- New feature added"
        mock_client.return_value.chat.return_value = mock_response

        commits = [
            {"sha": "abc123", "message": "feat: new feature", "author": "Test User"}
        ]
        result = ollama_client.generate_changelog(commits)
        assert "Changelog" in result

    def test_generate_api_docs(self, ollama_client, mock_client):
        """Test generating API documentation."""
        mock_response = Mock()
        mock_response.message.content = "# API Documentation\n\n## GET /users\nReturns list of users"
        mock_client.return_value.chat.return_value = mock_response

        code_changes = {"src/api.py": "some diff content"}
        result = ollama_client.generate_api_docs(code_changes)
        assert "API Documentation" in result

    def test_model_not_found_error(self, ollama_client, mock_client):
        """Test that a 'model not found' response maps to OllamaModelNotFoundError."""
        import ollama

        mock_client.return_value.chat.side_effect = ollama.ResponseError("model not found")

        with pytest.raises(OllamaModelNotFoundError):
            ollama_client.generate_commit_message("diff content")
|
|
|
|
|
|
class TestOllamaModel:
    """Tests for OllamaModel dataclass."""

    def test_ollama_model_creation(self):
        """Test that an OllamaModel keeps the field values it was built with."""
        model = OllamaModel(
            name="llama3.2",
            size=4000000000,
            digest="abc123",
            modified_at="2024-01-15T10:00:00",
        )
        assert model.name == "llama3.2"
        assert model.size == 4000000000
        assert model.digest == "abc123"