This commit is contained in:
137
app/tests/test_providers.py
Normal file
137
app/tests/test_providers.py
Normal file
@@ -0,0 +1,137 @@
|
|||||||
|
import pytest
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from promptforge.providers.base import ProviderBase, ProviderResponse
|
||||||
|
from promptforge.providers.factory import ProviderFactory
|
||||||
|
from promptforge.providers.openai import OpenAIProvider
|
||||||
|
from promptforge.providers.anthropic import AnthropicProvider
|
||||||
|
from promptforge.providers.ollama import OllamaProvider
|
||||||
|
from promptforge.core.exceptions import ProviderError
|
||||||
|
|
||||||
|
|
||||||
|
class TestProviderBase:
|
||||||
|
"""Tests for ProviderBase abstract class."""
|
||||||
|
|
||||||
|
def test_response_creation(self):
|
||||||
|
"""Test ProviderResponse creation."""
|
||||||
|
response = ProviderResponse(
|
||||||
|
content="Hello",
|
||||||
|
model="gpt-4",
|
||||||
|
provider="openai",
|
||||||
|
usage={"prompt_tokens": 5, "completion_tokens": 3, "total_tokens": 8},
|
||||||
|
latency_ms=100.5,
|
||||||
|
)
|
||||||
|
assert response.content == "Hello"
|
||||||
|
assert response.usage["total_tokens"] == 8
|
||||||
|
|
||||||
|
def test_provider_requires_implementation(self):
|
||||||
|
"""Test that ProviderBase requires implementation."""
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
_ = ProviderBase()
|
||||||
|
|
||||||
|
|
||||||
|
class TestProviderFactory:
|
||||||
|
"""Tests for ProviderFactory."""
|
||||||
|
|
||||||
|
def test_list_providers(self):
|
||||||
|
"""Test listing available providers."""
|
||||||
|
providers = ProviderFactory.list_providers()
|
||||||
|
assert "openai" in providers
|
||||||
|
assert "anthropic" in providers
|
||||||
|
assert "ollama" in providers
|
||||||
|
|
||||||
|
def test_create_openai(self):
|
||||||
|
"""Test creating OpenAI provider."""
|
||||||
|
provider = ProviderFactory.create("openai", model="gpt-4")
|
||||||
|
assert isinstance(provider, OpenAIProvider)
|
||||||
|
assert provider.model == "gpt-4"
|
||||||
|
|
||||||
|
def test_create_anthropic(self):
|
||||||
|
"""Test creating Anthropic provider."""
|
||||||
|
provider = ProviderFactory.create("anthropic", model="claude-3")
|
||||||
|
assert isinstance(provider, AnthropicProvider)
|
||||||
|
assert provider.model == "claude-3"
|
||||||
|
|
||||||
|
def test_create_ollama(self):
|
||||||
|
"""Test creating Ollama provider."""
|
||||||
|
provider = ProviderFactory.create("ollama", model="llama2")
|
||||||
|
assert isinstance(provider, OllamaProvider)
|
||||||
|
assert provider.model == "llama2"
|
||||||
|
|
||||||
|
def test_create_unknown_provider(self):
|
||||||
|
"""Test creating unknown provider raises error."""
|
||||||
|
with pytest.raises(ProviderError):
|
||||||
|
ProviderFactory.create("unknown")
|
||||||
|
|
||||||
|
def test_provider_temperature(self):
|
||||||
|
"""Test provider temperature setting."""
|
||||||
|
provider = ProviderFactory.create("openai", temperature=0.5)
|
||||||
|
assert provider.temperature == 0.5
|
||||||
|
|
||||||
|
|
||||||
|
class TestOpenAIProvider:
|
||||||
|
"""Tests for OpenAIProvider."""
|
||||||
|
|
||||||
|
def test_provider_name(self):
|
||||||
|
"""Test provider name."""
|
||||||
|
provider = OpenAIProvider()
|
||||||
|
assert provider.name == "openai"
|
||||||
|
|
||||||
|
def test_list_models(self):
|
||||||
|
"""Test listing available models."""
|
||||||
|
provider = OpenAIProvider()
|
||||||
|
models = provider.list_models()
|
||||||
|
assert "gpt-4" in models
|
||||||
|
assert "gpt-3.5-turbo" in models
|
||||||
|
|
||||||
|
def test_validate_api_key_missing(self):
|
||||||
|
"""Test API key validation when missing."""
|
||||||
|
provider = OpenAIProvider(api_key=None)
|
||||||
|
with patch.dict('os.environ', {}, clear=True):
|
||||||
|
assert provider.validate_api_key() is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestAnthropicProvider:
|
||||||
|
"""Tests for AnthropicProvider."""
|
||||||
|
|
||||||
|
def test_provider_name(self):
|
||||||
|
"""Test provider name."""
|
||||||
|
provider = AnthropicProvider()
|
||||||
|
assert provider.name == "anthropic"
|
||||||
|
|
||||||
|
def test_list_models(self):
|
||||||
|
"""Test listing available models."""
|
||||||
|
provider = AnthropicProvider()
|
||||||
|
models = provider.list_models()
|
||||||
|
assert "claude-3-sonnet-20240229" in models
|
||||||
|
|
||||||
|
def test_validate_api_key_missing(self):
|
||||||
|
"""Test API key validation when missing."""
|
||||||
|
provider = AnthropicProvider(api_key=None)
|
||||||
|
with patch.dict('os.environ', {}, clear=True):
|
||||||
|
assert provider.validate_api_key() is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestOllamaProvider:
|
||||||
|
"""Tests for OllamaProvider."""
|
||||||
|
|
||||||
|
def test_provider_name(self):
|
||||||
|
"""Test provider name."""
|
||||||
|
provider = OllamaProvider()
|
||||||
|
assert provider.name == "ollama"
|
||||||
|
|
||||||
|
def test_list_models(self):
|
||||||
|
"""Test listing available models."""
|
||||||
|
provider = OllamaProvider()
|
||||||
|
models = provider.list_models()
|
||||||
|
assert "llama2" in models
|
||||||
|
|
||||||
|
def test_validate_api_key_not_needed(self):
|
||||||
|
"""Test Ollama doesn't require API key."""
|
||||||
|
provider = OllamaProvider()
|
||||||
|
assert provider.validate_api_key() is True
|
||||||
|
|
||||||
|
def test_default_base_url(self):
|
||||||
|
"""Test default Ollama base URL."""
|
||||||
|
provider = OllamaProvider()
|
||||||
|
assert provider.base_url == "http://localhost:11434"
|
||||||
Reference in New Issue
Block a user