"""Test case generator for regex patterns."""

import random
import re
import string
from typing import Optional, Callable
from .parser import parse_regex, RegexNode, NodeType

# Repetition cap applied to quantifiers whose upper bound is open-ended.
_MAX_REPEAT = 3
# Samples produced by the heuristic fallback are truncated to this length.
_FALLBACK_MAX_LENGTH = 20
# Escapes that stand for a control character rather than a character class.
_CONTROL_ESCAPES = {'n': '\n', 't': '\t', 'r': '\r'}


class TestCaseGenerator:
    """Generates matching and non-matching test cases for regex patterns."""

    def __init__(self, flavor: str = "pcre"):
        """Remember the regex flavor passed on to the parser and matcher."""
        self.flavor = flavor

    def generate_matching(
        self,
        pattern: str,
        count: int = 5,
        max_length: int = 50
    ) -> list[str]:
        """Generate strings that match the pattern.

        The pattern is parsed with ``parse_regex`` and samples are built from
        the resulting AST. If parsing or generation raises, a character-level
        heuristic is used instead (which ignores ``max_length``).
        """
        try:
            ast = parse_regex(pattern, self.flavor)
            return self._generate_matching_from_ast(ast, count, max_length)
        except Exception:
            # Deliberate best-effort: any parser/generator failure degrades
            # to the heuristic instead of propagating to the caller.
            return self._generate_fallback_matching(pattern, count)

    def _generate_matching_from_ast(
        self,
        node: RegexNode,
        count: int,
        max_length: int
    ) -> list[str]:
        """Generate ``count`` matching strings from an AST node."""
        if node.node_type == NodeType.SEQUENCE:
            return self._generate_sequence(node.children, count, max_length)
        return [pattern_to_string(node, max_length) for _ in range(count)]

    def _generate_sequence(
        self,
        children: list[RegexNode],
        count: int,
        max_length: int
    ) -> list[str]:
        """Generate ``count`` strings by concatenating a sample per child.

        Children are visited in order until the accumulated length reaches
        ``max_length``; a child that yields ``None`` contributes nothing.
        """
        results = []
        for _ in range(count):
            parts = []
            used = 0  # running length, avoids re-joining parts per child
            for child in children:
                if used >= max_length:
                    break
                part = generate_from_node(child, max_length - used) or ""
                parts.append(part)
                used += len(part)
            results.append("".join(parts))
        return results

    def _generate_fallback_matching(
        self,
        pattern: str,
        count: int
    ) -> list[str]:
        """Fallback matching generation using simple heuristics.

        Each sample is produced by a single left-to-right scan of the pattern
        text (see ``_fallback_sample``), truncated to 20 characters. An empty
        sample is replaced by ``"test"``.
        """
        results = []
        for _ in range(count):
            sample = _fallback_sample(pattern) or "test"
            results.append(sample[:_FALLBACK_MAX_LENGTH])
        return results

    def generate_non_matching(
        self,
        pattern: str,
        count: int = 5,
        max_length: int = 50
    ) -> list[str]:
        """Generate strings that do NOT match the pattern.

        Uses AST-based mutation of matching samples when the pattern parses,
        otherwise falls back to a list of canned candidates.
        """
        try:
            ast = parse_regex(pattern, self.flavor)
            return self._generate_non_matching_from_ast(
                ast, count, max_length, pattern
            )
        except Exception:
            # Deliberate best-effort, mirroring generate_matching.
            return self._generate_fallback_non_matching(pattern, count)

    def _generate_non_matching_from_ast(
        self,
        node: RegexNode,
        count: int,
        max_length: int,
        pattern: Optional[str] = None
    ) -> list[str]:
        """Generate non-matching strings by mutating matching samples.

        Args:
            node: Root AST node of the pattern.
            count: Maximum number of strings to return.
            max_length: Length budget used when generating the base samples.
            pattern: Source text of the pattern. When given, every candidate
                is checked with ``matches_pattern`` and dropped if it still
                matches. When ``None`` candidates are returned unverified.

        Returns:
            Up to ``count`` distinct strings, in the order they were produced.
            A bare anchor node yields a single canned string.
        """
        if node.node_type in (NodeType.ANCHOR_START, NodeType.START_OF_STRING):
            return ["prefix_test"]
        if node.node_type in (NodeType.ANCHOR_END, NodeType.END_OF_STRING):
            return ["test_suffix"]

        def still_matches(text: str) -> bool:
            return pattern is not None and matches_pattern(
                pattern, text, self.flavor
            )

        # dict keys give de-duplication while keeping insertion order.
        results: dict[str, None] = {}

        for matching in self._generate_matching_from_ast(node, 10, max_length):
            if len(results) >= count:
                break
            if not matching:
                continue

            # Mutation 1: swap one character for a different one.
            pos = random.randint(0, len(matching) - 1)
            original = matching[pos]
            replacement = get_replacement_char(original)
            if replacement != original:
                mutated = matching[:pos] + replacement + matching[pos + 1:]
                if not still_matches(mutated):
                    results[mutated] = None

            # Mutation 2: insert a character of a different class.
            if len(results) < count:
                pos = random.randint(0, len(matching))
                neighbour = matching[pos - 1] if pos > 0 else 'a'
                mutated = (
                    matching[:pos]
                    + get_opposite_char_class(neighbour)
                    + matching[pos:]
                )
                if not still_matches(mutated):
                    results[mutated] = None

        if len(results) < count:
            # Pad with canned candidates plus a random numeric suffix. The
            # attempt budget bounds the loop when candidates keep colliding
            # or keep matching the pattern.
            canned = (
                self._generate_fallback_non_matching(pattern, 1)
                if pattern is not None else []
            )
            base = canned[0] if canned else "does_not_match_123"
            for _ in range(count * 10):
                if len(results) >= count:
                    break
                padded = base + str(random.randint(100, 999))
                if not still_matches(padded):
                    results[padded] = None

        return list(results)[:count]

    def _generate_fallback_non_matching(
        self,
        pattern: str,
        count: int
    ) -> list[str]:
        """Fallback non-matching generation.

        Starts from a fixed candidate list, adds candidates suggested by
        features seen in the pattern text, then keeps only those that Python's
        ``re`` does not find a match in. If the pattern does not compile, or
        every candidate matches, the unfiltered list is returned.
        """
        results = ["does_not_match", "completely_different", "!@#$%^&*()", "", "xyz123"]

        if pattern.startswith('^'):
            results.append("prefix_" + results[0])

        if pattern.endswith('$'):
            results.append(results[0] + "_suffix")

        if '\\d' in pattern or '[0-9]' in pattern:
            results.append("abc_def")

        if '\\w' in pattern:
            results.append("!@#$%^&*")

        if '\\s' in pattern:
            results.append("nospacehere")

        dot_count = pattern.count('.')
        if dot_count > 0:
            results.append("x" * (dot_count + 1))

        try:
            compiled = re.compile(pattern)
        except re.error:
            return results[:count]

        filtered_results = [r for r in results if compiled.search(r) is None]
        if filtered_results:
            return filtered_results[:count]
        return results[:count]


def _sample_for_escape(code: str) -> str:
    """Return sample text for the escape sequence backslash + ``code``.

    Zero-width word-boundary escapes yield an empty string; unrecognised
    escapes yield ``code`` itself, i.e. the escaped literal.
    """
    if code == 'd':
        return random.choice(string.digits)
    if code in 'DwS':
        return random.choice(string.ascii_letters)
    if code == 'W':
        return random.choice(' !@#$%^&*()')
    if code == 's':
        return " "
    if code in 'bB':
        return ""
    return _CONTROL_ESCAPES.get(code, code)


def _fallback_sample(pattern: str) -> str:
    """Build one heuristic sample by scanning ``pattern`` left to right.

    This is not a parser. Known limitations: a ``{m,n}`` quantifier repeats
    only the single preceding piece (not a whole group), character-class
    ranges contribute only their end points, a leading ``^`` in a class is
    treated as an ordinary member, and alternation is ignored.
    """
    pieces: list[str] = []
    class_chars: list[str] = []
    in_class = False
    i, n = 0, len(pattern)

    while i < n:
        char = pattern[i]
        i += 1

        if char == '\\':
            if i >= n:
                continue  # dangling backslash: nothing follows it
            sample = _sample_for_escape(pattern[i])
            i += 1  # consume the escaped character as well
            if sample:
                (class_chars if in_class else pieces).append(sample)
        elif in_class:
            if char == ']':
                in_class = False
                if class_chars:
                    pieces.append(random.choice(class_chars))
            elif char == '-' and class_chars:
                pass  # range separator; the end points are kept as members
            else:
                class_chars.append(char)
        elif char == '[':
            in_class = True
            class_chars = []
        elif char == '.':
            pieces.append(random.choice(string.ascii_letters))
        elif char in '*+?':
            continue  # keep exactly one occurrence of the preceding piece
        elif char == '{':
            close = pattern.find('}', i)
            if close == -1:
                continue
            minimum = pattern[i:close].split(',')[0].strip()
            i = close + 1
            if minimum.isdigit() and pieces:
                # Samples are truncated later, so cap the repetition there.
                repeat = min(int(minimum), _FALLBACK_MAX_LENGTH)
                pieces[-1] = pieces[-1] * repeat
        elif char not in '()|^$}':
            pieces.append(char)

    return "".join(pieces)


def generate_from_node(node: RegexNode, max_length: int) -> Optional[str]:
    """Generate a string from a single node.

    Returns ``None`` when the node type is not handled or the node has
    nothing to generate from (no value, no children). Lookaround nodes
    produce an empty string.
    """
    kind = node.node_type

    if kind == NodeType.LITERAL:
        return node.value[:max_length] if node.value else None

    if kind == NodeType.ESCAPED_CHAR:
        return node.value if node.value else None

    if kind == NodeType.DOT:
        return random.choice(string.ascii_letters)

    if kind == NodeType.NEGATIVE_SET:
        excluded = []
        for start, end in node.ranges:
            excluded.extend(chr(i) for i in range(ord(start), ord(end) + 1))
        excluded.extend(node.characters)
        available = [c for c in string.ascii_letters if c not in excluded]
        return random.choice(available) if available else "!"

    if kind == NodeType.POSITIVE_SET:
        if node.ranges:
            start, end = node.ranges[0]
            return chr(random.randint(ord(start), ord(end)))
        if node.characters:
            return random.choice(node.characters)
        return "a"

    if kind == NodeType.DIGIT:
        return random.choice(string.digits)

    if kind == NodeType.NON_DIGIT:
        return random.choice(string.ascii_letters)

    if kind == NodeType.WORD_CHAR:
        return random.choice(string.ascii_letters)

    if kind == NodeType.NON_WORD_CHAR:
        return random.choice('!@#$%^&*()')

    if kind == NodeType.WHITESPACE:
        return " "

    if kind == NodeType.NON_WHITESPACE:
        return random.choice(string.ascii_letters)

    if kind == NodeType.QUANTIFIER:
        if not node.children:
            return None
        child_str = generate_from_node(node.children[0], max_length)
        if child_str is None:
            child_str = "x"

        min_count = node.min_count if node.min_count else 0
        # NOTE(review): assumes an unbounded quantifier is represented by
        # None or float('inf') and that 0 is a genuine upper bound, as the
        # zero-repeat case below implies. Confirm against the parser.
        if node.max_count is None or node.max_count == float('inf'):
            max_count = _MAX_REPEAT
        else:
            max_count = min(node.max_count, _MAX_REPEAT)
        max_count = max(min_count, max_count)

        if max_count == 0:
            repeat = 0
        elif min_count == 0:
            # Optional content is always emitted at least once.
            repeat = random.randint(1, max_count)
        else:
            repeat = random.randint(min_count, max_count)

        return (child_str * repeat)[:max_length]

    if kind in (
        NodeType.CAPTURING_GROUP,
        NodeType.NON_CAPTURING_GROUP,
        NodeType.NAMED_GROUP,
    ):
        if node.children:
            return generate_from_node(node.children[0], max_length)
        return None

    if kind in (
        NodeType.LOOKAHEAD,
        NodeType.NEGATIVE_LOOKAHEAD,
        NodeType.LOOKBEHIND,
        NodeType.NEGATIVE_LOOKBEHIND,
    ):
        return ""

    if kind == NodeType.SEQUENCE:
        result = ""
        for child in node.children:
            if len(result) >= max_length:
                break
            part = generate_from_node(child, max_length - len(result))
            if part:
                result += part
        return result if result else None

    if kind == NodeType.BRANCH:
        # Sample every alternative, then pick one of the non-empty results.
        choices = []
        for child in node.children or []:
            part = generate_from_node(child, max_length)
            if part:
                choices.append(part)
        return random.choice(choices) if choices else None

    return None


def pattern_to_string(node: RegexNode, max_length: int) -> str:
    """Convert a node to a representative string.

    Falls back to ``"test"`` when the node generates nothing.
    """
    result = generate_from_node(node, max_length)
    return result if result else "test"


def get_replacement_char(original: str) -> str:
    """Get a replacement character different from the original.

    Digits map to another digit, letters to an ASCII letter that differs
    ignoring case, and anything else to ``'x'``.
    """
    if original.isdigit():
        return random.choice([c for c in string.digits if c != original])
    if original.isalpha():
        return random.choice(
            [c for c in string.ascii_letters if c.lower() != original.lower()]
        )
    if original == ' ':
        # NOTE(review): these are two-character literals (backslash + letter),
        # not real tab/newline characters - confirm this is intended.
        return random.choice(['\\t', '\\n'])
    return 'x'


def get_opposite_char_class(char: str) -> str:
    """Get a character from a different class.

    Digits yield a letter, letters yield a digit, a space yields ``'x'`` and
    everything else yields ``'1'``.
    """
    if char.isdigit():
        return random.choice(string.ascii_letters)
    if char.isalpha():
        return random.choice(string.digits)
    if char == ' ':
        return 'x'
    return '1'


def matches_pattern(pattern: str, text: str, flavor: str) -> bool:
    """Check if text matches pattern.

    Matching is done with Python's ``re.search``. ``re.MULTILINE`` is set for
    the ``"javascript"`` and ``"pcre"`` flavors; any other flavor uses no
    flags. A pattern that ``re`` cannot compile is reported as not matching.
    """
    flags = re.MULTILINE if flavor in ("javascript", "pcre") else 0
    try:
        return re.search(pattern, text, flags) is not None
    except re.error:
        return False


def generate_test_cases(
    pattern: str,
    flavor: str = "pcre",
    matching_count: int = 5,
    non_matching_count: int = 5
) -> dict:
    """Generate all test cases for a pattern.

    Returns a dict with the keys ``"pattern"``, ``"flavor"``, ``"matching"``
    and ``"non_matching"``.
    """
    generator = TestCaseGenerator(flavor)

    return {
        "pattern": pattern,
        "flavor": flavor,
        "matching": generator.generate_matching(pattern, matching_count),
        "non_matching": generator.generate_non_matching(pattern, non_matching_count)
    }