diff --git a/app/tests/test_testing.py b/app/tests/test_testing.py
new file mode 100644
index 0000000..edf29d0
--- /dev/null
+++ b/app/tests/test_testing.py
@@ -0,0 +1,239 @@
+from promptforge.testing.validator import (
+    RegexValidator,
+    JSONSchemaValidator,
+    LengthValidator,
+    ContainsValidator,
+    CompositeValidator,
+)
+from promptforge.testing.metrics import MetricsCollector, MetricsSample
+from promptforge.testing.results import TestSessionResults, ResultFormatter
+from promptforge.testing.ab_test import ABTestConfig
+
+
+class TestRegexValidator:
+    """Tests for RegexValidator: regex-based validation of response text."""
+
+    def test_valid_pattern(self):
+        """A string matching the pattern validates with no error message."""
+        validator = RegexValidator(r"^Hello.*")
+        is_valid, error = validator.validate("Hello, World!")
+        assert is_valid is True
+        assert error is None
+
+    def test_invalid_pattern(self):
+        """A non-matching string fails and reports a non-None error."""
+        validator = RegexValidator(r"^Hello.*")
+        is_valid, error = validator.validate("Goodbye")
+        assert is_valid is False
+        assert error is not None
+
+    def test_case_insensitive(self):
+        """flags=2 (the int value of re.IGNORECASE) enables case-insensitive matching."""
+        validator = RegexValidator(r"hello", flags=2)  # 2 == re.IGNORECASE
+        is_valid, _ = validator.validate("HELLO")
+        assert is_valid is True
+
+
+class TestJSONSchemaValidator:
+    """Tests for JSONSchemaValidator: JSON parsing plus schema conformance."""
+
+    def test_valid_json(self):
+        """Well-formed JSON that satisfies the schema passes validation."""
+        schema = {
+            "type": "object",
+            "properties": {
+                "name": {"type": "string"},
+                "age": {"type": "number"}
+            }
+        }
+        validator = JSONSchemaValidator(schema)
+        is_valid, error = validator.validate('{"name": "Alice", "age": 30}')
+        assert is_valid is True
+
+    def test_invalid_json(self):
+        """Unparseable input fails with an error message mentioning JSON."""
+        validator = JSONSchemaValidator({})
+        is_valid, error = validator.validate("not json")
+        assert is_valid is False
+        assert "JSON" in error
+
+    def test_type_mismatch(self):
+        """A string value where the schema requires an integer fails."""
+        schema = {
+            "type": "object",
+            "properties": {
+                "count": {"type": "integer"}
+            }
+        }
+        validator = JSONSchemaValidator(schema)
+        is_valid, error = validator.validate('{"count": "not a number"}')
+        assert is_valid is False
+
+
+class TestLengthValidator:
+    """Tests for LengthValidator min_length/max_length bounds checking."""
+
+    def test_within_bounds(self):
+        """A string within [min_length, max_length] passes."""
+        validator = LengthValidator(min_length=5, max_length=100)
+        is_valid, error = validator.validate("Hello, World!")
+        assert is_valid is True
+
+    def test_too_short(self):
+        """A string under min_length fails; error text mentions 'short'."""
+        validator = LengthValidator(min_length=10)
+        is_valid, error = validator.validate("Hi")
+        assert is_valid is False
+        assert "short" in error
+
+    def test_too_long(self):
+        """A string over max_length fails; error text mentions 'long'."""
+        validator = LengthValidator(max_length=5)
+        is_valid, error = validator.validate("Hello, World!")
+        assert is_valid is False
+        assert "long" in error
+
+
+class TestContainsValidator:
+    """Tests for ContainsValidator required-substring checks."""
+
+    def test_contains_string(self):
+        """Input containing every required string passes."""
+        validator = ContainsValidator(required_strings=["hello", "world"])
+        is_valid, error = validator.validate("Say hello to the world")
+        assert is_valid is True
+
+    def test_missing_content(self):
+        """Input missing the required strings fails."""
+        validator = ContainsValidator(required_strings=["hello", "world"])
+        is_valid, error = validator.validate("Just some random text")
+        assert is_valid is False
+
+
+class TestCompositeValidator:
+    """Tests for CompositeValidator combining validators in all/any mode."""
+
+    def test_all_mode(self):
+        """mode='all' (AND): passes when every child validator passes."""
+        validators = [
+            RegexValidator(r"^Hello.*"),
+            LengthValidator(min_length=5),
+        ]
+        composite = CompositeValidator(validators, mode="all")
+        is_valid, _ = composite.validate("Hello, World!")
+        assert is_valid is True
+
+    def test_any_mode(self):
+        """mode='any' (OR): passes when at least one child validator passes."""
+        validators = [
+            RegexValidator(r"^Hello.*"),
+            RegexValidator(r"^Goodbye.*"),
+        ]
+        composite = CompositeValidator(validators, mode="any")
+        is_valid, _ = composite.validate("Goodbye, World!")
+        assert is_valid is True
+
+
+class TestMetricsCollector:
+    """Tests for MetricsCollector: recording, summarizing, and comparing samples."""
+
+    def test_record_sample(self):
+        """A recorded sample is reflected in the summary count and latency avg."""
+        collector = MetricsCollector()
+        sample = MetricsSample(
+            latency_ms=100.0,
+            tokens_total=50,
+            validation_passed=True,
+        )
+        collector.record(sample)
+
+        summary = collector.get_summary()
+        assert summary.count == 1
+        assert summary.latency["avg"] == 100.0
+
+    def test_record_from_response(self):
+        """record_from_response sums prompt and completion tokens into tokens_total."""
+        collector = MetricsCollector()
+        sample = collector.record_from_response(
+            latency_ms=50.0,
+            usage={"prompt_tokens": 10, "completion_tokens": 5},
+            validation_passed=True,
+        )
+        assert sample.tokens_total == 15
+
+    def test_clear_samples(self):
+        """clear() resets the collector so the summary count returns to zero."""
+        collector = MetricsCollector()
+        collector.record(MetricsSample())
+        collector.clear()
+
+        summary = collector.get_summary()
+        assert summary.count == 0
+
+    def test_compare_collectors(self):
+        """compare() reports the latency delta (other minus self) between collectors."""
+        collector1 = MetricsCollector()
+        collector1.record(MetricsSample(latency_ms=100.0))
+
+        collector2 = MetricsCollector()
+        collector2.record(MetricsSample(latency_ms=200.0))
+
+        comparison = collector1.compare(collector2)
+        assert comparison["latency_delta_ms"] == 100.0
+
+
+class TestResultFormatter:
+    """Tests for ResultFormatter text and JSON rendering of session results."""
+
+    def test_format_text(self):
+        """Text output includes the session name and a PASS marker for a success."""
+        from promptforge.testing.results import TestResult
+        results = TestSessionResults(
+            test_id="test-123",
+            name="My Test",
+        )
+        results.results.append(TestResult(
+            test_id="1",
+            prompt_name="prompt1",
+            provider="openai",
+            success=True,
+            response="Hello",
+        ))
+
+        formatted = ResultFormatter.format_text(results)
+        assert "My Test" in formatted
+        assert "PASS" in formatted
+
+    def test_format_json(self):
+        """JSON output parses back and carries the name and a results key."""
+        results = TestSessionResults(
+            test_id="test-123",
+            name="My Test",
+        )
+
+        formatted = ResultFormatter.format_json(results)
+        import json
+        data = json.loads(formatted)
+        assert data["name"] == "My Test"
+        assert "results" in data
+
+
+class TestABTestConfig:
+    """Tests for ABTestConfig defaults and keyword overrides."""
+
+    def test_default_config(self):
+        """Defaults: 3 iterations, sequential execution, temperature 0.7."""
+        config = ABTestConfig()
+        assert config.iterations == 3
+        assert config.parallel is False
+        assert config.temperature == 0.7
+
+    def test_custom_config(self):
+        """Constructor keyword arguments are stored as given."""
+        config = ABTestConfig(
+            iterations=5,
+            parallel=True,
+            temperature=0.5,
+        )
+        assert config.iterations == 5
+        assert config.parallel is True