This commit is contained in:
239
app/tests/test_testing.py
Normal file
239
app/tests/test_testing.py
Normal file
@@ -0,0 +1,239 @@
|
||||
import re

from promptforge.testing.ab_test import ABTestConfig
from promptforge.testing.metrics import MetricsCollector, MetricsSample
from promptforge.testing.results import ResultFormatter, TestSessionResults
from promptforge.testing.validator import (
    CompositeValidator,
    ContainsValidator,
    JSONSchemaValidator,
    LengthValidator,
    RegexValidator,
)
|
||||
|
||||
|
||||
class TestRegexValidator:
    """Tests for RegexValidator."""

    def test_valid_pattern(self):
        """A response matching the pattern validates with no error."""
        validator = RegexValidator(r"^Hello.*")
        is_valid, error = validator.validate("Hello, World!")
        assert is_valid is True
        assert error is None

    def test_invalid_pattern(self):
        """A non-matching response fails and carries an error message."""
        validator = RegexValidator(r"^Hello.*")
        is_valid, error = validator.validate("Goodbye")
        assert is_valid is False
        assert error is not None

    def test_case_insensitive(self):
        """Regex flags are honoured (case-insensitive match here)."""
        # Use the named constant instead of the magic number 2 so the
        # intent is explicit and robust to any change in re's flag values.
        validator = RegexValidator(r"hello", flags=re.IGNORECASE)
        is_valid, _ = validator.validate("HELLO")
        assert is_valid is True
|
||||
|
||||
|
||||
class TestJSONSchemaValidator:
    """Tests for JSONSchemaValidator."""

    def test_valid_json(self):
        """A document conforming to the schema passes validation."""
        person_schema = {
            "type": "object",
            "properties": {
                "name": {"type": "string"},
                "age": {"type": "number"},
            },
        }
        ok, _error = JSONSchemaValidator(person_schema).validate(
            '{"name": "Alice", "age": 30}'
        )
        assert ok is True

    def test_invalid_json(self):
        """Unparseable input is rejected and the message mentions JSON."""
        ok, message = JSONSchemaValidator({}).validate("not json")
        assert ok is False
        assert "JSON" in message

    def test_type_mismatch(self):
        """A value whose type contradicts the schema fails validation."""
        counter_schema = {
            "type": "object",
            "properties": {
                "count": {"type": "integer"},
            },
        }
        ok, _error = JSONSchemaValidator(counter_schema).validate(
            '{"count": "not a number"}'
        )
        assert ok is False
|
||||
|
||||
|
||||
class TestLengthValidator:
    """Tests for LengthValidator."""

    def test_within_bounds(self):
        """A string whose length falls inside the bounds is accepted."""
        checker = LengthValidator(min_length=5, max_length=100)
        ok, _ = checker.validate("Hello, World!")
        assert ok is True

    def test_too_short(self):
        """A string below min_length is rejected with a 'short' message."""
        checker = LengthValidator(min_length=10)
        ok, message = checker.validate("Hi")
        assert ok is False
        assert "short" in message

    def test_too_long(self):
        """A string above max_length is rejected with a 'long' message."""
        checker = LengthValidator(max_length=5)
        ok, message = checker.validate("Hello, World!")
        assert ok is False
        assert "long" in message
|
||||
|
||||
|
||||
class TestContainsValidator:
    """Tests for ContainsValidator."""

    def test_contains_string(self):
        """Input containing every required string is accepted."""
        checker = ContainsValidator(required_strings=["hello", "world"])
        ok, _ = checker.validate("Say hello to the world")
        assert ok is True

    def test_missing_content(self):
        """Input lacking a required string is rejected."""
        checker = ContainsValidator(required_strings=["hello", "world"])
        ok, _ = checker.validate("Just some random text")
        assert ok is False
|
||||
|
||||
|
||||
class TestCompositeValidator:
    """Tests for CompositeValidator."""

    def test_all_mode(self):
        """In 'all' (AND) mode, input must satisfy every child validator."""
        combined = CompositeValidator(
            [
                RegexValidator(r"^Hello.*"),
                LengthValidator(min_length=5),
            ],
            mode="all",
        )
        ok, _ = combined.validate("Hello, World!")
        assert ok is True

    def test_any_mode(self):
        """In 'any' (OR) mode, satisfying one child validator suffices."""
        combined = CompositeValidator(
            [
                RegexValidator(r"^Hello.*"),
                RegexValidator(r"^Goodbye.*"),
            ],
            mode="any",
        )
        ok, _ = combined.validate("Goodbye, World!")
        assert ok is True
|
||||
|
||||
|
||||
class TestMetricsCollector:
    """Tests for MetricsCollector."""

    def test_record_sample(self):
        """A recorded sample shows up in the summary statistics."""
        collector = MetricsCollector()
        collector.record(
            MetricsSample(
                latency_ms=100.0,
                tokens_total=50,
                validation_passed=True,
            )
        )

        stats = collector.get_summary()
        assert stats.count == 1
        assert stats.latency["avg"] == 100.0

    def test_record_from_response(self):
        """Token counts from a provider usage dict are summed into a sample."""
        collector = MetricsCollector()
        recorded = collector.record_from_response(
            latency_ms=50.0,
            usage={"prompt_tokens": 10, "completion_tokens": 5},
            validation_passed=True,
        )
        # 10 prompt + 5 completion tokens.
        assert recorded.tokens_total == 15

    def test_clear_samples(self):
        """clear() empties the collector so the summary count drops to zero."""
        collector = MetricsCollector()
        collector.record(MetricsSample())
        collector.clear()

        stats = collector.get_summary()
        assert stats.count == 0

    def test_compare_collectors(self):
        """compare() reports the latency delta between two collectors."""
        baseline = MetricsCollector()
        baseline.record(MetricsSample(latency_ms=100.0))

        candidate = MetricsCollector()
        candidate.record(MetricsSample(latency_ms=200.0))

        diff = baseline.compare(candidate)
        assert diff["latency_delta_ms"] == 100.0
|
||||
|
||||
|
||||
class TestResultFormatter:
    """Tests for ResultFormatter."""

    def test_format_text(self):
        """Text output includes the session name and a PASS marker."""
        from promptforge.testing.results import TestResult

        session = TestSessionResults(
            test_id="test-123",
            name="My Test",
        )
        session.results.append(
            TestResult(
                test_id="1",
                prompt_name="prompt1",
                provider="openai",
                success=True,
                response="Hello",
            )
        )

        rendered = ResultFormatter.format_text(session)
        assert "My Test" in rendered
        assert "PASS" in rendered

    def test_format_json(self):
        """JSON output parses back and carries the name and results keys."""
        import json

        session = TestSessionResults(
            test_id="test-123",
            name="My Test",
        )

        rendered = ResultFormatter.format_json(session)
        payload = json.loads(rendered)
        assert payload["name"] == "My Test"
        assert "results" in payload
|
||||
|
||||
|
||||
class TestABTestConfig:
|
||||
"""Tests for ABTestConfig."""
|
||||
|
||||
def test_default_config(self):
|
||||
"""Test default A/B test configuration."""
|
||||
config = ABTestConfig()
|
||||
assert config.iterations == 3
|
||||
assert config.parallel is False
|
||||
assert config.temperature == 0.7
|
||||
|
||||
def test_custom_config(self):
|
||||
"""Test custom A/B test configuration."""
|
||||
config = ABTestConfig(
|
||||
iterations=5,
|
||||
parallel=True,
|
||||
temperature=0.5,
|
||||
)
|
||||
assert config.iterations == 5
|
||||
assert config.parallel is True
|
||||
Reference in New Issue
Block a user