Add test suite
This commit is contained in:
57
tests/test_providers.py
Normal file
57
tests/test_providers.py
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
import pytest
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import Mock, AsyncMock, patch
|
||||||
|
|
||||||
|
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
|
||||||
|
|
||||||
|
from promptforge.providers.factory import ProviderFactory
|
||||||
|
from promptforge.providers.openai import OpenAIProvider
|
||||||
|
from promptforge.providers.anthropic import AnthropicProvider
|
||||||
|
from promptforge.providers.ollama import OllamaProvider
|
||||||
|
|
||||||
|
|
||||||
|
class TestProviderFactory:
|
||||||
|
def test_create_openai(self):
|
||||||
|
provider = ProviderFactory.create("openai", model="gpt-4")
|
||||||
|
assert isinstance(provider, OpenAIProvider)
|
||||||
|
assert provider.model == "gpt-4"
|
||||||
|
|
||||||
|
def test_create_anthropic(self):
|
||||||
|
provider = ProviderFactory.create("anthropic", model="claude-3")
|
||||||
|
assert isinstance(provider, AnthropicProvider)
|
||||||
|
|
||||||
|
def test_create_ollama(self):
|
||||||
|
provider = ProviderFactory.create("ollama", model="llama2")
|
||||||
|
assert isinstance(provider, OllamaProvider)
|
||||||
|
|
||||||
|
def test_create_unknown(self):
|
||||||
|
with pytest.raises(Exception):
|
||||||
|
ProviderFactory.create("unknown")
|
||||||
|
|
||||||
|
|
||||||
|
class TestOpenAIProvider:
|
||||||
|
def test_provider_name(self):
|
||||||
|
provider = OpenAIProvider(api_key="test-key")
|
||||||
|
assert provider.name == "openai"
|
||||||
|
|
||||||
|
def test_provider_defaults(self):
|
||||||
|
provider = OpenAIProvider()
|
||||||
|
assert provider.model == "gpt-4"
|
||||||
|
assert provider.temperature == 0.7
|
||||||
|
|
||||||
|
|
||||||
|
class TestAnthropicProvider:
|
||||||
|
def test_provider_name(self):
|
||||||
|
provider = AnthropicProvider(api_key="test-key")
|
||||||
|
assert provider.name == "anthropic"
|
||||||
|
|
||||||
|
|
||||||
|
class TestOllamaProvider:
|
||||||
|
def test_provider_name(self):
|
||||||
|
provider = OllamaProvider()
|
||||||
|
assert provider.name == "ollama"
|
||||||
|
|
||||||
|
def test_provider_defaults(self):
|
||||||
|
provider = OllamaProvider(model="llama2")
|
||||||
|
assert provider.model == "llama2"
|
||||||
Reference in New Issue
Block a user