"""Tests for promptforge.testing: validators, metrics, results, and A/B config."""
import re

from promptforge.testing.ab_test import ABTestConfig
from promptforge.testing.metrics import MetricsCollector, MetricsSample
from promptforge.testing.results import ResultFormatter, TestSessionResults
from promptforge.testing.validator import (
    CompositeValidator,
    ContainsValidator,
    JSONSchemaValidator,
    LengthValidator,
    RegexValidator,
)
|
|
|
|
|
|
class TestRegexValidator:
    """Tests for RegexValidator."""

    def test_valid_pattern(self):
        """An input matching the pattern passes with no error message."""
        validator = RegexValidator(r"^Hello.*")
        is_valid, error = validator.validate("Hello, World!")
        assert is_valid is True
        assert error is None

    def test_invalid_pattern(self):
        """An input that does not match fails and carries an error message."""
        validator = RegexValidator(r"^Hello.*")
        is_valid, error = validator.validate("Goodbye")
        assert is_valid is False
        assert error is not None

    def test_case_insensitive(self):
        """Regex flags are honored: IGNORECASE lets 'hello' match 'HELLO'."""
        # Use the named constant rather than the magic number 2 so the
        # intent is explicit and robust to re's internal flag values.
        validator = RegexValidator(r"hello", flags=re.IGNORECASE)
        is_valid, _ = validator.validate("HELLO")
        assert is_valid is True
|
|
|
|
|
|
class TestJSONSchemaValidator:
    """Tests for JSONSchemaValidator."""

    def test_valid_json(self):
        """A document conforming to the schema validates successfully."""
        person_schema = {
            "type": "object",
            "properties": {
                "name": {"type": "string"},
                "age": {"type": "number"},
            },
        }
        is_valid, _error = JSONSchemaValidator(person_schema).validate(
            '{"name": "Alice", "age": 30}'
        )
        assert is_valid is True

    def test_invalid_json(self):
        """Unparseable input fails, and the error mentions JSON."""
        is_valid, error = JSONSchemaValidator({}).validate("not json")
        assert is_valid is False
        assert "JSON" in error

    def test_type_mismatch(self):
        """A value of the wrong type for its schema property is rejected."""
        count_schema = {
            "type": "object",
            "properties": {
                "count": {"type": "integer"},
            },
        }
        is_valid, _error = JSONSchemaValidator(count_schema).validate(
            '{"count": "not a number"}'
        )
        assert is_valid is False
|
|
|
|
|
|
class TestLengthValidator:
    """Tests for LengthValidator."""

    def test_within_bounds(self):
        """A string inside [min_length, max_length] passes."""
        bounded = LengthValidator(min_length=5, max_length=100)
        ok, _ = bounded.validate("Hello, World!")
        assert ok is True

    def test_too_short(self):
        """A string below min_length fails with a 'short' error."""
        min_only = LengthValidator(min_length=10)
        ok, message = min_only.validate("Hi")
        assert ok is False
        assert "short" in message

    def test_too_long(self):
        """A string above max_length fails with a 'long' error."""
        max_only = LengthValidator(max_length=5)
        ok, message = max_only.validate("Hello, World!")
        assert ok is False
        assert "long" in message
|
|
|
|
|
|
class TestContainsValidator:
    """Tests for ContainsValidator."""

    def test_contains_string(self):
        """Passes when every required substring is present."""
        checker = ContainsValidator(required_strings=["hello", "world"])
        ok, _ = checker.validate("Say hello to the world")
        assert ok is True

    def test_missing_content(self):
        """Fails when any required substring is absent."""
        checker = ContainsValidator(required_strings=["hello", "world"])
        ok, _ = checker.validate("Just some random text")
        assert ok is False
|
|
|
|
|
|
class TestCompositeValidator:
    """Tests for CompositeValidator."""

    def test_all_mode(self):
        """In 'all' mode, every child validator must pass."""
        combined = CompositeValidator(
            [RegexValidator(r"^Hello.*"), LengthValidator(min_length=5)],
            mode="all",
        )
        ok, _ = combined.validate("Hello, World!")
        assert ok is True

    def test_any_mode(self):
        """In 'any' mode, one passing child validator suffices."""
        combined = CompositeValidator(
            [RegexValidator(r"^Hello.*"), RegexValidator(r"^Goodbye.*")],
            mode="any",
        )
        ok, _ = combined.validate("Goodbye, World!")
        assert ok is True
|
|
|
|
|
|
class TestMetricsCollector:
    """Tests for MetricsCollector."""

    def test_record_sample(self):
        """A recorded sample shows up in the summary count and latency avg."""
        collector = MetricsCollector()
        collector.record(
            MetricsSample(
                latency_ms=100.0,
                tokens_total=50,
                validation_passed=True,
            )
        )
        summary = collector.get_summary()
        assert summary.count == 1
        assert summary.latency["avg"] == 100.0

    def test_record_from_response(self):
        """Token totals are derived from the provider usage dict."""
        sample = MetricsCollector().record_from_response(
            latency_ms=50.0,
            usage={"prompt_tokens": 10, "completion_tokens": 5},
            validation_passed=True,
        )
        # 10 prompt + 5 completion tokens
        assert sample.tokens_total == 15

    def test_clear_samples(self):
        """clear() drops all previously recorded samples."""
        collector = MetricsCollector()
        collector.record(MetricsSample())
        collector.clear()
        assert collector.get_summary().count == 0

    def test_compare_collectors(self):
        """compare() reports the latency delta between two collectors."""
        fast, slow = MetricsCollector(), MetricsCollector()
        fast.record(MetricsSample(latency_ms=100.0))
        slow.record(MetricsSample(latency_ms=200.0))
        delta = fast.compare(slow)
        assert delta["latency_delta_ms"] == 100.0
|
|
|
|
|
|
class TestResultFormatter:
    """Tests for ResultFormatter."""

    def test_format_text(self):
        """Text output includes the session name and a PASS marker."""
        from promptforge.testing.results import TestResult

        session = TestSessionResults(test_id="test-123", name="My Test")
        passing = TestResult(
            test_id="1",
            prompt_name="prompt1",
            provider="openai",
            success=True,
            response="Hello",
        )
        session.results.append(passing)

        rendered = ResultFormatter.format_text(session)
        assert "My Test" in rendered
        assert "PASS" in rendered

    def test_format_json(self):
        """JSON output round-trips and exposes name and results keys."""
        import json

        session = TestSessionResults(test_id="test-123", name="My Test")
        payload = json.loads(ResultFormatter.format_json(session))
        assert payload["name"] == "My Test"
        assert "results" in payload
|
|
|
|
|
|
class TestABTestConfig:
    """Tests for ABTestConfig."""

    def test_default_config(self):
        """Defaults: 3 iterations, sequential execution, temperature 0.7."""
        config = ABTestConfig()
        assert config.iterations == 3
        assert config.parallel is False
        assert config.temperature == 0.7

    def test_custom_config(self):
        """Explicit constructor arguments override every default."""
        config = ABTestConfig(
            iterations=5,
            parallel=True,
            temperature=0.5,
        )
        assert config.iterations == 5
        assert config.parallel is True
        # Previously unasserted: the custom temperature must be stored too,
        # otherwise this test silently ignores one of the three overrides.
        assert config.temperature == 0.5
|