"""Tests for the Ollama client module.

The ``mock_config`` fixture used by most tests is not defined in this module;
presumably it is provided by a ``conftest.py`` (TODO confirm). ``test_init``
pins the values it is expected to carry: model ``qwen2.5-coder:3b`` and base
URL ``http://localhost:11434``.
"""

from unittest.mock import MagicMock, patch

import pytest

from git_commit_ai.core.ollama_client import OllamaClient, OllamaError, generate_diff_hash


class TestOllamaClientBasic:
    """Basic Ollama client tests."""

    def test_init(self, mock_config):
        """Test client initialization."""
        client = OllamaClient(mock_config)
        assert client.model == "qwen2.5-coder:3b"
        assert client.base_url == "http://localhost:11434"

    def test_model_setter(self, mock_config):
        """Test model setter."""
        client = OllamaClient(mock_config)
        client.model = "llama3:8b"
        assert client.model == "llama3:8b"

    def test_base_url_setter(self, mock_config):
        """Test base URL setter."""
        client = OllamaClient(mock_config)
        client.base_url = "http://localhost:11435"
        assert client.base_url == "http://localhost:11435"


class TestOllamaClientAvailability:
    """Tests for Ollama availability checks."""

    # NOTE(review): patching "requests.get" only intercepts calls made through
    # the ``requests`` module attribute (e.g. ``requests.get(...)``). If the
    # client ever binds ``get`` directly at import time, patch the name in
    # ``git_commit_ai.core.ollama_client`` instead — verify against the client.

    def test_is_available_true(self, mock_config):
        """Test is_available returns True when server is up."""
        with patch("requests.get") as mock_get:
            mock_response = MagicMock()
            mock_response.status_code = 200
            mock_get.return_value = mock_response

            client = OllamaClient(mock_config)
            assert client.is_available() is True

    def test_is_available_false(self, mock_config):
        """Test is_available returns False when server is down."""
        with patch("requests.get") as mock_get:
            # A plain Exception is used, so this only passes if the client
            # treats any error from the HTTP call as "not available".
            mock_get.side_effect = Exception("Connection refused")

            client = OllamaClient(mock_config)
            assert client.is_available() is False


class TestOllamaClientModels:
    """Tests for model-related functionality."""

    def test_list_models(self, mock_config):
        """Test listing available models."""
        with patch("requests.get") as mock_get:
            mock_response = MagicMock()
            mock_response.status_code = 200
            mock_response.json.return_value = {
                "models": [
                    {"name": "qwen2.5-coder:3b", "size": 2000000000},
                    {"name": "llama3:8b", "size": 4000000000},
                ]
            }
            mock_get.return_value = mock_response

            client = OllamaClient(mock_config)
            models = client.list_models()

            assert len(models) == 2
            assert models[0]["name"] == "qwen2.5-coder:3b"

    def test_check_model_exists_true(self, mock_config):
        """Test checking if model exists."""
        with patch.object(OllamaClient, "list_models") as mock_list:
            # The configured model (qwen2.5-coder:3b) is present.
            mock_list.return_value = [{"name": "qwen2.5-coder:3b", "size": 2000000000}]

            client = OllamaClient(mock_config)
            assert client.check_model_exists() is True

    def test_check_model_exists_false(self, mock_config):
        """Test checking if model doesn't exist."""
        with patch.object(OllamaClient, "list_models") as mock_list:
            # Only a different model is installed.
            mock_list.return_value = [{"name": "llama3:8b", "size": 4000000000}]

            client = OllamaClient(mock_config)
            assert client.check_model_exists() is False


class TestOllamaClientGeneration:
    """Tests for commit message generation."""

    def test_parse_commit_message_simple(self, mock_config):
        """Test parsing a simple commit message."""
        client = OllamaClient(mock_config)
        response = "feat: add new feature"
        parsed = client._parse_commit_message(response)
        assert parsed == "feat: add new feature"

    def test_parse_commit_message_with_quotes(self, mock_config):
        """Test parsing a quoted commit message."""
        client = OllamaClient(mock_config)
        response = '"feat: add new feature"'
        parsed = client._parse_commit_message(response)
        assert parsed == "feat: add new feature"

    def test_parse_commit_message_truncates_long(self, mock_config):
        """Test parsing truncates long messages."""
        client = OllamaClient(mock_config)
        long_message = "a" * 100
        parsed = client._parse_commit_message(long_message)
        assert len(parsed) <= 80


class TestGenerateDiffHash:
    """Tests for generate_diff_hash function."""

    def test_generate_diff_hash(self):
        """Test generating diff hash."""
        diff1 = "def hello():\n print('hi')"
        diff2 = "def hello():\n print('hi')"
        diff3 = "def goodbye():\n print('bye')"

        hash1 = generate_diff_hash(diff1)
        hash2 = generate_diff_hash(diff2)
        hash3 = generate_diff_hash(diff3)

        # Identical diffs must hash equally; different diffs must not.
        assert hash1 == hash2
        assert hash1 != hash3


class TestOllamaError:
    """Tests for OllamaError exception."""

    def test_ollama_error(self):
        """Test OllamaError is raised correctly and keeps its message."""
        with pytest.raises(OllamaError, match="Test error"):
            raise OllamaError("Test error")