import random import string from typing import Optional from .parser import parse_regex, RegexNode, NodeType class TestCaseGenerator: """Generates matching and non-matching test cases for regex patterns.""" def __init__(self, flavor: str = "pcre"): self.flavor = flavor def generate_matching( self, pattern: str, count: int = 5, max_length: int = 50 ) -> list[str]: """Generate strings that match the pattern.""" try: ast = parse_regex(pattern, self.flavor) return self._generate_matching_from_ast(ast, count, max_length) except Exception: return self._generate_fallback_matching(pattern, count) def _generate_matching_from_ast( self, node: RegexNode, count: int, max_length: int ) -> list[str]: """Generate matching strings from AST.""" if node.node_type == NodeType.SEQUENCE: return self._generate_sequence(node.children, count, max_length) return [pattern_to_string(node, max_length) for _ in range(count)] def _generate_sequence( self, children: list[RegexNode], count: int, max_length: int ) -> list[str]: """Generate strings for a sequence of nodes.""" results = [] for _ in range(count): parts = [] for child in children: if len("".join(parts)) >= max_length: break part = generate_from_node(child, max_length - len("".join(parts))) if part is None: part = "" parts.append(part) results.append("".join(parts)) return results def _generate_fallback_matching( self, pattern: str, count: int ) -> list[str]: """Fallback matching generation using simple heuristics.""" results = [] for _ in range(count): result = "" in_class = False class_chars = [] for char in pattern: if char == '\\' and len(pattern) > 1: next_char = pattern[pattern.index(char) + 1] if next_char in 'dDsSwWbB': if next_char == 'd': result += random.choice(string.digits) elif next_char == 'D': result += random.choice(string.ascii_letters) elif next_char == 'w': result += random.choice(string.ascii_letters) elif next_char == 'W': result += random.choice(' !@#$%^&*()') elif next_char == 's': result += " " elif next_char == 'b': result += random.choice(string.ascii_letters) else: result += next_char elif char == '.': result += random.choice(string.ascii_letters) elif char in '*+?': continue elif char == '[': in_class = True class_chars = [] elif char == ']': in_class = False if class_chars: result += random.choice(class_chars) elif in_class: if char == '-' and class_chars: pass else: class_chars.append(char) elif char not in '()|^$\\{}': result += char if not result: result = "test" results.append(result[:20]) return results[:count] def generate_non_matching( self, pattern: str, count: int = 5, max_length: int = 50 ) -> list[str]: """Generate strings that do NOT match the pattern.""" try: ast = parse_regex(pattern, self.flavor) return self._generate_non_matching_from_ast(pattern, ast, count, max_length) except Exception: return self._generate_fallback_non_matching(pattern, count) def _generate_non_matching_from_ast( self, pattern: str, node: RegexNode, count: int, max_length: int ) -> list[str]: """Generate non-matching strings from AST.""" results = set() if node.node_type == NodeType.ANCHOR_START: return [s + "prefix" for s in results] or ["prefix_test"] if node.node_type == NodeType.ANCHOR_END: return ["suffix" + s for s in results] or ["test_suffix"] if node.node_type == NodeType.START_OF_STRING: return ["prefix" + s for s in results] or ["prefix_test"] if node.node_type == NodeType.END_OF_STRING: return [s + "suffix" for s in results] or ["test_suffix"] base_matching = self._generate_matching_from_ast(node, 10, max_length) for matching in base_matching: if len(results) >= count: break if len(matching) > 0: pos = random.randint(0, len(matching) - 1) original = matching[pos] replacement = get_replacement_char(original) if replacement != original: non_match = matching[:pos] + replacement + matching[pos + 1:] if not matches_pattern(pattern, non_match, self.flavor): results.add(non_match) if len(results) < count and matching: pos = random.randint(0, len(matching)) char_to_add = get_opposite_char_class(matching[pos - 1] if pos > 0 else 'a') non_match = matching[:pos] + char_to_add + matching[pos:] if not matches_pattern(pattern, non_match, self.flavor): results.add(non_match) if len(results) < count: for _ in range(count - len(results)): base = self._generate_fallback_non_matching(pattern, 1)[0] if self._generate_fallback_non_matching(pattern, 1) else "does_not_match_123" results.add(base + str(random.randint(100, 999))) return list(results)[:count] def _generate_fallback_non_matching( self, pattern: str, count: int ) -> list[str]: """Fallback non-matching generation.""" results = ["does_not_match", "completely_different", "!@#$%^&*()", "", "xyz123"] if pattern.startswith('^'): results.append("prefix_" + results[0]) if pattern.endswith('$'): results.append(results[0] + "_suffix") if '\\d' in pattern or '[0-9]' in pattern: results.append("abc_def") if '\\w' in pattern: results.append("!@#$%^&*") if '\\s' in pattern: results.append("nospacehere") dot_count = pattern.count('.') if dot_count > 0: results.append("x" * (dot_count + 1)) import re try: compiled = re.compile(pattern) filtered_results = [] for r in results: if compiled.search(r) is None: filtered_results.append(r) if filtered_results: return filtered_results[:count] except re.error: pass return results[:count] def generate_from_node(node: RegexNode, max_length: int) -> Optional[str]: """Generate a string from a single node.""" if node.node_type == NodeType.LITERAL: return node.value[:max_length] if node.value else None if node.node_type == NodeType.ESCAPED_CHAR: return node.value if node.value else None if node.node_type == NodeType.DOT: return random.choice(string.ascii_letters) if node.node_type in (NodeType.POSITIVE_SET, NodeType.NEGATIVE_SET): if node.node_type == NodeType.NEGATIVE_SET: all_chars = [] for start, end in node.ranges: all_chars.extend([chr(i) for i in range(ord(start), ord(end) + 1)]) all_chars.extend(node.characters) available = [c for c in string.ascii_letters if c not in all_chars] if available: return random.choice(available) return "!" if node.ranges: start, end = node.ranges[0] return chr(random.randint(ord(start), ord(end))) if node.characters: return random.choice(node.characters) return "a" if node.node_type in (NodeType.DIGIT, NodeType.NON_DIGIT): return random.choice(string.digits) if node.node_type in (NodeType.WORD_CHAR, NodeType.NON_WORD_CHAR): return random.choice(string.ascii_letters) if node.node_type in (NodeType.WHITESPACE, NodeType.NON_WHITESPACE): return " " if node.node_type == NodeType.QUANTIFIER: if node.children: child_str = generate_from_node(node.children[0], max_length) if child_str is None: child_str = "x" min_count = node.min_count if node.min_count else 0 max_count = min(node.max_count, 3) if node.max_count and node.max_count != float('inf') else 3 max_count = max(min_count, max_count) if min_count == 0 and max_count == 0: repeat = 0 elif min_count == 0: repeat = random.randint(1, max_count) else: repeat = random.randint(min_count, max_count) return (child_str * repeat)[:max_length] return None if node.node_type == NodeType.CAPTURING_GROUP: if node.children: return generate_from_node(node.children[0], max_length) return None if node.node_type == NodeType.NON_CAPTURING_GROUP: if node.children: return generate_from_node(node.children[0], max_length) return None if node.node_type == NodeType.NAMED_GROUP: if node.children: return generate_from_node(node.children[0], max_length) return None if node.node_type in (NodeType.LOOKAHEAD, NodeType.NEGATIVE_LOOKAHEAD): return "" if node.node_type in (NodeType.LOOKBEHIND, NodeType.NEGATIVE_LOOKBEHIND): return "" if node.node_type == NodeType.SEQUENCE: result = "" for child in node.children: if len(result) >= max_length: break part = generate_from_node(child, max_length - len(result)) if part: result += part return result if result else None if node.node_type == NodeType.BRANCH: if node.children: choices = [] for child in node.children: part = generate_from_node(child, max_length) if part: choices.append(part) if choices: return random.choice(choices) return None return None def pattern_to_string(node: RegexNode, max_length: int) -> str: """Convert a node to a representative string.""" result = generate_from_node(node, max_length) return result if result else "test" def get_replacement_char(original: str) -> str: """Get a replacement character different from the original.""" if original.isdigit(): return random.choice([c for c in string.digits if c != original]) if original.isalpha(): return random.choice([c for c in string.ascii_letters if c.lower() != original.lower()]) if original == ' ': return random.choice(['\\t', '\\n']) return 'x' def get_opposite_char_class(char: str) -> str: """Get a character from a different class.""" if char.isdigit(): return random.choice(string.ascii_letters) if char.isalpha(): return random.choice(string.digits) if char == ' ': return 'x' return '1' def matches_pattern(pattern: str, text: str, flavor: str) -> bool: """Check if text matches pattern.""" import re try: flags = 0 if flavor == "python": pass elif flavor == "javascript": flags = re.MULTILINE elif flavor == "pcre": flags = re.MULTILINE compiled = re.compile(pattern, flags) return compiled.search(text) is not None except re.error: return False def generate_test_cases( pattern: str, flavor: str = "pcre", matching_count: int = 5, non_matching_count: int = 5 ) -> dict: """Generate all test cases for a pattern.""" generator = TestCaseGenerator(flavor) return { "pattern": pattern, "flavor": flavor, "matching": generator.generate_matching(pattern, matching_count), "non_matching": generator.generate_non_matching(pattern, non_matching_count) }