fix: resolve CI linting and type errors
This commit is contained in:
@@ -1,43 +1,249 @@
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional
|
||||
"""Response validation framework."""
|
||||
|
||||
import json
|
||||
|
||||
from ..core.prompt import ValidationRule
|
||||
from ..core.exceptions import ValidationError
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
|
||||
class Validator(ABC):
    """Abstract base class for validators.

    A validator inspects a response string and reports whether it is
    acceptable. Concrete subclasses must implement both ``validate`` and
    ``get_name``; the base class cannot be instantiated directly.
    """

    @abstractmethod
    def validate(self, response: str) -> Tuple[bool, Optional[str]]:
        """Validate a response.

        Args:
            response: The response to validate.

        Returns:
            Tuple of (is_valid, error_message). The validators in this
            module return ``None`` as the message when the response is valid.
        """

    @abstractmethod
    def get_name(self) -> str:
        """Get validator name.

        Returns:
            A short human-readable identifier for this validator.
        """
|
||||
|
||||
|
||||
class RegexValidator(Validator):
    """Validator that accepts a response when a regex is found anywhere in it."""

    def __init__(self, pattern: str, flags: int = 0):
        """Store the pattern and compile it once for reuse.

        Args:
            pattern: Regular expression source to search for.
            flags: ``re`` module flags (for example ``re.IGNORECASE``).
        """
        self.pattern = pattern
        self.flags = flags
        # Compiled up front so every validate() call reuses the same object.
        self._regex = re.compile(pattern, flags)

    def validate(self, response: str) -> Tuple[bool, Optional[str]]:
        """Search the response for the pattern.

        Args:
            response: Text to search.

        Returns:
            ``(True, None)`` when the pattern occurs somewhere in the
            response, otherwise ``(False, message)`` naming the pattern.
        """
        found = self._regex.search(response) is not None
        if found:
            return True, None
        return False, f"Response does not match pattern: {self.pattern}"

    def get_name(self) -> str:
        """Return the validator name, which embeds the pattern source."""
        return f"regex({self.pattern})"
|
||||
|
||||
|
||||
class JSONSchemaValidator(Validator):
    """Validates JSON responses against a schema.

    Only a subset of JSON Schema is understood: ``type`` (a single name or a
    list of names), ``properties``, ``required`` (either a list of property
    names on the object schema, or the boolean ``True`` on an individual
    property schema), ``enum``, ``minLength``, ``maxLength``, ``minimum`` and
    ``maximum``. Any other keyword is ignored.
    """

    # Schema type name -> Python classes that json.loads produces for it.
    # Built once at class creation instead of on every recursive call.
    _TYPE_CLASSES: Dict[str, Tuple[type, ...]] = {
        "array": (list,),
        "object": (dict,),
        "string": (str,),
        "number": (int, float),
        "boolean": (bool,),
        "integer": (int,),
        "null": (type(None),),
    }

    def __init__(self, schema: Dict[str, Any]):
        """Initialize JSON schema validator.

        Args:
            schema: JSON schema to validate against.
        """
        self.schema = schema

    def validate(self, response: str) -> Tuple[bool, Optional[str]]:
        """Parse the response as JSON and check it against the schema.

        Args:
            response: Text expected to contain a JSON document.

        Returns:
            ``(True, None)`` when the document satisfies the schema.
            ``(False, message)`` otherwise, where ``message`` is either the
            JSON parse error or all schema violations joined with ``"; "``.
        """
        try:
            data = json.loads(response)
        except json.JSONDecodeError as e:
            return False, f"Invalid JSON: {e}"

        errors = self._validate_object(data, self.schema, "")
        if errors:
            return False, "; ".join(errors)
        return True, None

    @classmethod
    def _matches_type(cls, data: Any, type_name: Any) -> bool:
        """Return whether ``data`` is an instance of the named schema type."""
        if not isinstance(type_name, str):
            return True  # malformed type entry: nothing to enforce
        classes = cls._TYPE_CLASSES.get(type_name)
        if classes is None:
            return True  # unknown type names are not enforced
        # bool is a subclass of int in Python, but a JSON boolean is neither
        # an "integer" nor a "number".
        if isinstance(data, bool) and bool not in classes:
            return False
        return isinstance(data, classes)

    def _validate_object(
        self,
        data: Any,
        schema: Dict[str, Any],
        path: str,
    ) -> List[str]:
        """Recursively validate against schema.

        Args:
            data: Decoded JSON value to check.
            schema: Schema fragment that applies to ``data``.
            path: Dotted location of ``data`` in the document; empty at the
                root. It prefixes every error message.

        Returns:
            A list of error messages, empty when ``data`` conforms.
        """
        errors: List[str] = []

        declared = schema.get("type")
        if declared is not None:
            type_names = declared if isinstance(declared, list) else [declared]
            if type_names and not any(
                self._matches_type(data, name) for name in type_names
            ):
                expected = " or ".join(str(name) for name in type_names)
                actual_type = type(data).__name__
                errors.append(f"{path}: expected {expected}, got {actual_type}")
                # The remaining keywords assume the right type; stop here.
                return errors

        if isinstance(data, dict):
            properties = schema.get("properties") or {}
            required = schema.get("required")
            # Standard form: a list of required names on the object schema.
            required_names = (
                [name for name in required if isinstance(name, str)]
                if isinstance(required, list)
                else []
            )

            for prop, prop_schema in properties.items():
                if prop in data:
                    errors.extend(
                        self._validate_object(data[prop], prop_schema, f"{path}.{prop}")
                    )
                # Legacy form: ``"required": True`` on the property schema.
                # ``is True`` keeps a nested object's own ``required`` list
                # from being mistaken for this flag.
                elif prop in required_names or prop_schema.get("required") is True:
                    errors.append(f"{path}.{prop}: required property missing")

            # Required names that have no entry under "properties".
            for name in required_names:
                if name not in properties and name not in data:
                    errors.append(f"{path}.{name}: required property missing")

        if "enum" in schema and data not in schema["enum"]:
            errors.append(f"{path}: value must be one of {schema['enum']}")

        if isinstance(data, str):
            if "minLength" in schema and len(data) < schema["minLength"]:
                errors.append(f"{path}: string too short (min {schema['minLength']})")
            if "maxLength" in schema and len(data) > schema["maxLength"]:
                errors.append(f"{path}: string too long (max {schema['maxLength']})")

        # Booleans are excluded so True/False are never range-checked as 1/0.
        if isinstance(data, (int, float)) and not isinstance(data, bool):
            if "minimum" in schema and data < schema["minimum"]:
                errors.append(f"{path}: value below minimum ({schema['minimum']})")
            if "maximum" in schema and data > schema["maximum"]:
                errors.append(f"{path}: value above maximum ({schema['maximum']})")

        return errors

    def get_name(self) -> str:
        """Return the fixed validator name ``"json-schema"``."""
        return "json-schema"
|
||||
|
||||
class LengthValidator(Validator):
    """Validates response length constraints."""

    def __init__(
        self,
        min_length: Optional[int] = None,
        max_length: Optional[int] = None,
    ):
        """Initialize length validator.

        Args:
            min_length: Minimum number of characters, or ``None`` for no
                lower bound.
            max_length: Maximum number of characters, or ``None`` for no
                upper bound.
        """
        self.min_length = min_length
        self.max_length = max_length

    def validate(self, response: str) -> Tuple[bool, Optional[str]]:
        """Validate response length.

        Args:
            response: Text whose ``len()`` is checked against the bounds.

        Returns:
            ``(True, None)`` when the length is within both bounds.
            ``(False, message)`` for the first violated bound; the lower
            bound is checked first.
        """
        if self.min_length is not None and len(response) < self.min_length:
            return False, f"Response too short (min {self.min_length} chars)"
        if self.max_length is not None and len(response) > self.max_length:
            return False, f"Response too long (max {self.max_length} chars)"
        return True, None

    def get_name(self) -> str:
        """Return a name listing the configured bounds.

        Returns:
            A string such as ``"(length, min=2, max=10)"``; bounds that are
            ``None`` are left out.
        """
        parts = ["length"]
        # Compare against None rather than truthiness: a bound of 0 is a real
        # constraint that validate() enforces, so it must appear in the name.
        if self.min_length is not None:
            parts.append(f"min={self.min_length}")
        if self.max_length is not None:
            parts.append(f"max={self.max_length}")
        return "(" + ", ".join(parts) + ")"
|
||||
|
||||
|
||||
class ContainsValidator(Validator):
    """Validates response contains expected content."""

    def __init__(
        self,
        required_strings: List[str],
        all_required: bool = False,
        case_sensitive: bool = False,
    ):
        """Initialize contains validator.

        Args:
            required_strings: Substrings to look for in the response.
            all_required: If True, every string must be present; otherwise
                at least one must be.
            case_sensitive: Whether to match case.
        """
        self.required_strings = required_strings
        self.all_required = all_required
        self.case_sensitive = case_sensitive

    def _normalize(self, text: str) -> str:
        """Return ``text`` in the form used for substring comparison."""
        # casefold() rather than lower(): it is the Unicode-aware caseless
        # form, so e.g. "ß" and "SS" compare equal.
        return text if self.case_sensitive else text.casefold()

    def validate(self, response: str) -> Tuple[bool, Optional[str]]:
        """Validate response contains required strings.

        Args:
            response: Text to search.

        Returns:
            ``(True, None)`` on success. In "all" mode a failure lists the
            missing strings in their original spelling. In "any" mode a
            failure means none of the strings was found; with an empty
            ``required_strings`` list, "any" mode therefore always fails and
            "all" mode always passes.
        """
        haystack = self._normalize(response)
        missing = [
            s for s in self.required_strings if self._normalize(s) not in haystack
        ]

        if self.all_required:
            if missing:
                return False, f"Missing required content: {', '.join(missing)}"
        elif len(missing) == len(self.required_strings):
            return False, "Response does not contain any expected content"

        return True, None

    def get_name(self) -> str:
        """Return a name containing the mode and the list of strings."""
        mode = "all" if self.all_required else "any"
        return f"contains({mode}, {self.required_strings})"
|
||||
|
||||
|
||||
class CompositeValidator(Validator):
    """Combines multiple validators."""

    def __init__(self, validators: List[Validator], mode: str = "all"):
        """Initialize composite validator.

        Args:
            validators: List of validators to combine.
            mode: "all" (AND) or "any" (OR) behavior. Any value other than
                "all" is treated as "any".
        """
        self.validators = validators
        self.mode = mode

    def validate(self, response: str) -> Tuple[bool, Optional[str]]:
        """Run the combined validators against the response.

        Args:
            response: Text handed unchanged to each child validator.

        Returns:
            In "all" mode: ``(True, None)`` when every validator passes
            (including when there are none), otherwise ``(False, message)``
            with the non-empty error messages joined by ``"; "``.
            In "any" mode: ``(True, None)`` as soon as one validator passes,
            otherwise ``(False, "No validator passed")``.
        """
        if self.mode == "all":
            # Every validator must run so that all failures are reported.
            failures = [
                error
                for valid, error in (v.validate(response) for v in self.validators)
                if not valid
            ]
            if failures:
                return False, "; ".join(e for e in failures if e)
            return True, None

        # The generator lets any() stop at the first passing validator
        # instead of running the rest for nothing.
        if any(v.validate(response)[0] for v in self.validators):
            return True, None
        return False, "No validator passed"

    def get_name(self) -> str:
        """Return a name containing the mode and the child validator names."""
        names = [v.get_name() for v in self.validators]
        return f"composite({self.mode}, {names})"
|
||||
|
||||
Reference in New Issue
Block a user