From 64cef11c7c962085bd43d971fa5a926519448f8c Mon Sep 17 00:00:00 2001 From: 7000pctAUTO Date: Wed, 4 Feb 2026 12:58:30 +0000 Subject: [PATCH] fix: resolve CI linting and type errors --- src/promptforge/testing/validator.py | 268 +++++++++++++++++++++++---- 1 file changed, 237 insertions(+), 31 deletions(-) diff --git a/src/promptforge/testing/validator.py b/src/promptforge/testing/validator.py index 0a09f13..434747f 100644 --- a/src/promptforge/testing/validator.py +++ b/src/promptforge/testing/validator.py @@ -1,43 +1,249 @@ -import re -from typing import Any, Dict, List, Optional +"""Response validation framework.""" + import json - -from ..core.prompt import ValidationRule -from ..core.exceptions import ValidationError +import re +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Optional, Tuple -class Validator: - def __init__(self, rules: Optional[List[ValidationRule]] = None): - self.rules = rules or [] +class Validator(ABC): + """Abstract base class for validators.""" - def validate(self, response: str) -> List[str]: + @abstractmethod + def validate(self, response: str) -> Tuple[bool, Optional[str]]: + """Validate a response. + + Args: + response: The response to validate. + + Returns: + Tuple of (is_valid, error_message). + """ + pass + + @abstractmethod + def get_name(self) -> str: + """Get validator name.""" + pass + + +class RegexValidator(Validator): + """Validates responses against regex patterns.""" + + def __init__(self, pattern: str, flags: int = 0): + """Initialize regex validator. + + Args: + pattern: Regex pattern to match. + flags: Regex flags (e.g., re.IGNORECASE). + """ + self.pattern = pattern + self.flags = flags + self._regex = re.compile(pattern, flags) + + def validate(self, response: str) -> Tuple[bool, Optional[str]]: + """Validate response matches regex pattern.""" + if not self._regex.search(response): + return False, f"Response does not match pattern: {self.pattern}" + return True, None + + def get_name(self) -> str: + return f"regex({self.pattern})" + + +class JSONSchemaValidator(Validator): + """Validates JSON responses against a schema.""" + + def __init__(self, schema: Dict[str, Any]): + """Initialize JSON schema validator. + + Args: + schema: JSON schema to validate against. + """ + self.schema = schema + + def validate(self, response: str) -> Tuple[bool, Optional[str]]: + """Validate JSON response against schema.""" + try: + data = json.loads(response) + except json.JSONDecodeError as e: + return False, f"Invalid JSON: {e}" + + errors = self._validate_object(data, self.schema, "") + if errors: + return False, "; ".join(errors) + return True, None + + def _validate_object( + self, + data: Any, + schema: Dict[str, Any], + path: str, + ) -> List[str]: + """Recursively validate against schema.""" errors = [] - for rule in self.rules: - if rule.type == "regex": - if rule.pattern and not re.search(rule.pattern, response): - errors.append(rule.message or f"Response failed regex validation") + if "type" in schema: + expected_type = schema["type"] + type_checks = { + "array": (list, "array"), + "object": (dict, "object"), + "string": (str, "string"), + "number": ((int, float), "number"), + "boolean": (bool, "boolean"), + "integer": ((int,), "integer"), + } + if expected_type in type_checks: + expected_class, type_name = type_checks[expected_type] + if not isinstance(data, expected_class): # type: ignore[arg-type] + actual_type = type(data).__name__ + errors.append(f"{path}: expected {type_name}, got {actual_type}") + return errors - elif rule.type == "json": - try: - json.loads(response) - except json.JSONDecodeError: - errors.append(rule.message or "Response is not valid JSON") + if "properties" in schema and isinstance(data, dict): + for prop, prop_schema in schema["properties"].items(): + if prop in data: + errors.extend( + self._validate_object(data[prop], prop_schema, f"{path}.{prop}") + ) + elif prop_schema.get("required", False): + errors.append(f"{path}.{prop}: required property missing") - elif rule.type == "length": - min_len = rule.json_schema.get("minLength", 0) if rule.json_schema else 0 - max_len = rule.json_schema.get("maxLength", float("inf")) if rule.json_schema else float("inf") - if len(response) < min_len or len(response) > max_len: - errors.append(rule.message or f"Response length must be between {min_len} and {max_len}") + if "enum" in schema and data not in schema["enum"]: + errors.append(f"{path}: value must be one of {schema['enum']}") + + if "minLength" in schema and isinstance(data, str): + if len(data) < schema["minLength"]: + errors.append(f"{path}: string too short (min {schema['minLength']})") + + if "maxLength" in schema and isinstance(data, str): + if len(data) > schema["maxLength"]: + errors.append(f"{path}: string too long (max {schema['maxLength']})") + + if "minimum" in schema and isinstance(data, (int, float)): + if data < schema["minimum"]: + errors.append(f"{path}: value below minimum ({schema['minimum']})") + + if "maximum" in schema and isinstance(data, (int, float)): + if data > schema["maximum"]: + errors.append(f"{path}: value above maximum ({schema['maximum']})") return errors - def is_valid(self, response: str) -> bool: - return len(self.validate(response)) == 0 + def get_name(self) -> str: + return "json-schema" - @staticmethod - def from_prompt_rules(rules: List[Dict[str, Any]]) -> "Validator": - validation_rules = [] - for rule_data in rules: - validation_rules.append(ValidationRule(**rule_data)) - return Validator(validation_rules) \ No newline at end of file + +class LengthValidator(Validator): + """Validates response length constraints.""" + + def __init__( + self, + min_length: Optional[int] = None, + max_length: Optional[int] = None, + ): + """Initialize length validator. + + Args: + min_length: Minimum number of characters. + max_length: Maximum number of characters. + """ + self.min_length = min_length + self.max_length = max_length + + def validate(self, response: str) -> Tuple[bool, Optional[str]]: + """Validate response length.""" + if self.min_length is not None and len(response) < self.min_length: + return False, f"Response too short (min {self.min_length} chars)" + if self.max_length is not None and len(response) > self.max_length: + return False, f"Response too long (max {self.max_length} chars)" + return True, None + + def get_name(self) -> str: + parts = ["length"] + if self.min_length: + parts.append(f"min={self.min_length}") + if self.max_length: + parts.append(f"max={self.max_length}") + return "(" + ", ".join(parts) + ")" + + +class ContainsValidator(Validator): + """Validates response contains expected content.""" + + def __init__( + self, + required_strings: List[str], + all_required: bool = False, + case_sensitive: bool = False, + ): + """Initialize contains validator. + + Args: + required_strings: Strings that must be present. + all_required: If True, all strings must be present. + case_sensitive: Whether to match case. + """ + self.required_strings = required_strings + self.all_required = all_required + self.case_sensitive = case_sensitive + + def validate(self, response: str) -> Tuple[bool, Optional[str]]: + """Validate response contains required strings.""" + strings = self.required_strings + response_lower = response.lower() if not self.case_sensitive else response + + missing = [] + for s in strings: + check_str = s.lower() if not self.case_sensitive else s + if check_str not in response_lower: + missing.append(s) + + if self.all_required: + if missing: + return False, f"Missing required content: {', '.join(missing)}" + else: + if len(missing) == len(strings): + return False, "Response does not contain any expected content" + + return True, None + + def get_name(self) -> str: + mode = "all" if self.all_required else "any" + return f"contains({mode}, {self.required_strings})" + + +class CompositeValidator(Validator): + """Combines multiple validators.""" + + def __init__(self, validators: List[Validator], mode: str = "all"): + """Initialize composite validator. + + Args: + validators: List of validators to combine. + mode: "all" (AND) or "any" (OR) behavior. + """ + self.validators = validators + self.mode = mode + + def validate(self, response: str) -> Tuple[bool, Optional[str]]: + """Validate using all validators.""" + results = [v.validate(response) for v in self.validators] + errors = [] + + if self.mode == "all": + for valid, error in results: + if not valid: + errors.append(error) + if errors: + return False, "; ".join(e for e in errors if e) + return True, None + else: + for valid, _ in results: + if valid: + return True, None + return False, "No validator passed" + + def get_name(self) -> str: + names = [v.get_name() for v in self.validators] + return f"composite({self.mode}, {names})"