fix: resolve CI linting and type errors
Some checks failed
CI / test (push) Has been cancelled
CI / lint (push) Has been cancelled
CI / type-check (push) Has been cancelled

This commit is contained in:
2026-02-04 12:58:30 +00:00
parent 944ea90346
commit 64cef11c7c

View File

@@ -1,43 +1,249 @@
import re
from typing import Any, Dict, List, Optional
"""Response validation framework."""
import json
from ..core.prompt import ValidationRule
from ..core.exceptions import ValidationError
import re
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple
class Validator:
def __init__(self, rules: Optional[List[ValidationRule]] = None):
self.rules = rules or []
class Validator(ABC):
"""Abstract base class for validators."""
def validate(self, response: str) -> List[str]:
@abstractmethod
def validate(self, response: str) -> Tuple[bool, Optional[str]]:
"""Validate a response.
Args:
response: The response to validate.
Returns:
Tuple of (is_valid, error_message).
"""
pass
@abstractmethod
def get_name(self) -> str:
"""Get validator name."""
pass
class RegexValidator(Validator):
"""Validates responses against regex patterns."""
def __init__(self, pattern: str, flags: int = 0):
"""Initialize regex validator.
Args:
pattern: Regex pattern to match.
flags: Regex flags (e.g., re.IGNORECASE).
"""
self.pattern = pattern
self.flags = flags
self._regex = re.compile(pattern, flags)
def validate(self, response: str) -> Tuple[bool, Optional[str]]:
"""Validate response matches regex pattern."""
if not self._regex.search(response):
return False, f"Response does not match pattern: {self.pattern}"
return True, None
def get_name(self) -> str:
return f"regex({self.pattern})"
class JSONSchemaValidator(Validator):
"""Validates JSON responses against a schema."""
def __init__(self, schema: Dict[str, Any]):
"""Initialize JSON schema validator.
Args:
schema: JSON schema to validate against.
"""
self.schema = schema
def validate(self, response: str) -> Tuple[bool, Optional[str]]:
"""Validate JSON response against schema."""
try:
data = json.loads(response)
except json.JSONDecodeError as e:
return False, f"Invalid JSON: {e}"
errors = self._validate_object(data, self.schema, "")
if errors:
return False, "; ".join(errors)
return True, None
def _validate_object(
self,
data: Any,
schema: Dict[str, Any],
path: str,
) -> List[str]:
"""Recursively validate against schema."""
errors = []
for rule in self.rules:
if rule.type == "regex":
if rule.pattern and not re.search(rule.pattern, response):
errors.append(rule.message or f"Response failed regex validation")
if "type" in schema:
expected_type = schema["type"]
type_checks = {
"array": (list, "array"),
"object": (dict, "object"),
"string": (str, "string"),
"number": ((int, float), "number"),
"boolean": (bool, "boolean"),
"integer": ((int,), "integer"),
}
if expected_type in type_checks:
expected_class, type_name = type_checks[expected_type]
if not isinstance(data, expected_class): # type: ignore[arg-type]
actual_type = type(data).__name__
errors.append(f"{path}: expected {type_name}, got {actual_type}")
return errors
elif rule.type == "json":
try:
json.loads(response)
except json.JSONDecodeError:
errors.append(rule.message or "Response is not valid JSON")
if "properties" in schema and isinstance(data, dict):
for prop, prop_schema in schema["properties"].items():
if prop in data:
errors.extend(
self._validate_object(data[prop], prop_schema, f"{path}.{prop}")
)
elif prop_schema.get("required", False):
errors.append(f"{path}.{prop}: required property missing")
elif rule.type == "length":
min_len = rule.json_schema.get("minLength", 0) if rule.json_schema else 0
max_len = rule.json_schema.get("maxLength", float("inf")) if rule.json_schema else float("inf")
if len(response) < min_len or len(response) > max_len:
errors.append(rule.message or f"Response length must be between {min_len} and {max_len}")
if "enum" in schema and data not in schema["enum"]:
errors.append(f"{path}: value must be one of {schema['enum']}")
if "minLength" in schema and isinstance(data, str):
if len(data) < schema["minLength"]:
errors.append(f"{path}: string too short (min {schema['minLength']})")
if "maxLength" in schema and isinstance(data, str):
if len(data) > schema["maxLength"]:
errors.append(f"{path}: string too long (max {schema['maxLength']})")
if "minimum" in schema and isinstance(data, (int, float)):
if data < schema["minimum"]:
errors.append(f"{path}: value below minimum ({schema['minimum']})")
if "maximum" in schema and isinstance(data, (int, float)):
if data > schema["maximum"]:
errors.append(f"{path}: value above maximum ({schema['maximum']})")
return errors
def is_valid(self, response: str) -> bool:
return len(self.validate(response)) == 0
def get_name(self) -> str:
return "json-schema"
@staticmethod
def from_prompt_rules(rules: List[Dict[str, Any]]) -> "Validator":
validation_rules = []
for rule_data in rules:
validation_rules.append(ValidationRule(**rule_data))
return Validator(validation_rules)
class LengthValidator(Validator):
"""Validates response length constraints."""
def __init__(
self,
min_length: Optional[int] = None,
max_length: Optional[int] = None,
):
"""Initialize length validator.
Args:
min_length: Minimum number of characters.
max_length: Maximum number of characters.
"""
self.min_length = min_length
self.max_length = max_length
def validate(self, response: str) -> Tuple[bool, Optional[str]]:
"""Validate response length."""
if self.min_length is not None and len(response) < self.min_length:
return False, f"Response too short (min {self.min_length} chars)"
if self.max_length is not None and len(response) > self.max_length:
return False, f"Response too long (max {self.max_length} chars)"
return True, None
def get_name(self) -> str:
parts = ["length"]
if self.min_length:
parts.append(f"min={self.min_length}")
if self.max_length:
parts.append(f"max={self.max_length}")
return "(" + ", ".join(parts) + ")"
class ContainsValidator(Validator):
"""Validates response contains expected content."""
def __init__(
self,
required_strings: List[str],
all_required: bool = False,
case_sensitive: bool = False,
):
"""Initialize contains validator.
Args:
required_strings: Strings that must be present.
all_required: If True, all strings must be present.
case_sensitive: Whether to match case.
"""
self.required_strings = required_strings
self.all_required = all_required
self.case_sensitive = case_sensitive
def validate(self, response: str) -> Tuple[bool, Optional[str]]:
"""Validate response contains required strings."""
strings = self.required_strings
response_lower = response.lower() if not self.case_sensitive else response
missing = []
for s in strings:
check_str = s.lower() if not self.case_sensitive else s
if check_str not in response_lower:
missing.append(s)
if self.all_required:
if missing:
return False, f"Missing required content: {', '.join(missing)}"
else:
if len(missing) == len(strings):
return False, "Response does not contain any expected content"
return True, None
def get_name(self) -> str:
mode = "all" if self.all_required else "any"
return f"contains({mode}, {self.required_strings})"
class CompositeValidator(Validator):
"""Combines multiple validators."""
def __init__(self, validators: List[Validator], mode: str = "all"):
"""Initialize composite validator.
Args:
validators: List of validators to combine.
mode: "all" (AND) or "any" (OR) behavior.
"""
self.validators = validators
self.mode = mode
def validate(self, response: str) -> Tuple[bool, Optional[str]]:
"""Validate using all validators."""
results = [v.validate(response) for v in self.validators]
errors = []
if self.mode == "all":
for valid, error in results:
if not valid:
errors.append(error)
if errors:
return False, "; ".join(e for e in errors if e)
return True, None
else:
for valid, _ in results:
if valid:
return True, None
return False, "No validator passed"
def get_name(self) -> str:
names = [v.get_name() for v in self.validators]
return f"composite({self.mode}, {names})"