From 578edafab3002b6da6109a5b23471fe033660abf Mon Sep 17 00:00:00 2001 From: 7000pctAUTO Date: Wed, 4 Feb 2026 12:49:15 +0000 Subject: [PATCH] fix: resolve CI linting and type errors --- app/tests/test_providers.py | 137 ++++++++++++++++++++++++++++++++++++ 1 file changed, 137 insertions(+) create mode 100644 app/tests/test_providers.py diff --git a/app/tests/test_providers.py b/app/tests/test_providers.py new file mode 100644 index 0000000..c8b303c --- /dev/null +++ b/app/tests/test_providers.py @@ -0,0 +1,137 @@ +import pytest +from unittest.mock import patch + +from promptforge.providers.base import ProviderBase, ProviderResponse +from promptforge.providers.factory import ProviderFactory +from promptforge.providers.openai import OpenAIProvider +from promptforge.providers.anthropic import AnthropicProvider +from promptforge.providers.ollama import OllamaProvider +from promptforge.core.exceptions import ProviderError + + +class TestProviderBase: + """Tests for ProviderBase abstract class.""" + + def test_response_creation(self): + """Test ProviderResponse creation.""" + response = ProviderResponse( + content="Hello", + model="gpt-4", + provider="openai", + usage={"prompt_tokens": 5, "completion_tokens": 3, "total_tokens": 8}, + latency_ms=100.5, + ) + assert response.content == "Hello" + assert response.usage["total_tokens"] == 8 + + def test_provider_requires_implementation(self): + """Test that ProviderBase requires implementation.""" + with pytest.raises(TypeError): + _ = ProviderBase() + + +class TestProviderFactory: + """Tests for ProviderFactory.""" + + def test_list_providers(self): + """Test listing available providers.""" + providers = ProviderFactory.list_providers() + assert "openai" in providers + assert "anthropic" in providers + assert "ollama" in providers + + def test_create_openai(self): + """Test creating OpenAI provider.""" + provider = ProviderFactory.create("openai", model="gpt-4") + assert isinstance(provider, OpenAIProvider) + assert provider.model == "gpt-4" + + def test_create_anthropic(self): + """Test creating Anthropic provider.""" + provider = ProviderFactory.create("anthropic", model="claude-3") + assert isinstance(provider, AnthropicProvider) + assert provider.model == "claude-3" + + def test_create_ollama(self): + """Test creating Ollama provider.""" + provider = ProviderFactory.create("ollama", model="llama2") + assert isinstance(provider, OllamaProvider) + assert provider.model == "llama2" + + def test_create_unknown_provider(self): + """Test creating unknown provider raises error.""" + with pytest.raises(ProviderError): + ProviderFactory.create("unknown") + + def test_provider_temperature(self): + """Test provider temperature setting.""" + provider = ProviderFactory.create("openai", temperature=0.5) + assert provider.temperature == 0.5 + + +class TestOpenAIProvider: + """Tests for OpenAIProvider.""" + + def test_provider_name(self): + """Test provider name.""" + provider = OpenAIProvider() + assert provider.name == "openai" + + def test_list_models(self): + """Test listing available models.""" + provider = OpenAIProvider() + models = provider.list_models() + assert "gpt-4" in models + assert "gpt-3.5-turbo" in models + + def test_validate_api_key_missing(self): + """Test API key validation when missing.""" + provider = OpenAIProvider(api_key=None) + with patch.dict('os.environ', {}, clear=True): + assert provider.validate_api_key() is False + + +class TestAnthropicProvider: + """Tests for AnthropicProvider.""" + + def test_provider_name(self): + """Test provider name.""" + provider = AnthropicProvider() + assert provider.name == "anthropic" + + def test_list_models(self): + """Test listing available models.""" + provider = AnthropicProvider() + models = provider.list_models() + assert "claude-3-sonnet-20240229" in models + + def test_validate_api_key_missing(self): + """Test API key validation when missing.""" + provider = AnthropicProvider(api_key=None) + with patch.dict('os.environ', {}, clear=True): + assert provider.validate_api_key() is False + + +class TestOllamaProvider: + """Tests for OllamaProvider.""" + + def test_provider_name(self): + """Test provider name.""" + provider = OllamaProvider() + assert provider.name == "ollama" + + def test_list_models(self): + """Test listing available models.""" + provider = OllamaProvider() + models = provider.list_models() + assert "llama2" in models + + def test_validate_api_key_not_needed(self): + """Test Ollama doesn't require API key.""" + provider = OllamaProvider() + assert provider.validate_api_key() is True + + def test_default_base_url(self): + """Test default Ollama base URL.""" + provider = OllamaProvider() + assert provider.base_url == "http://localhost:11434"