"""Generate example strings for a regular expression.

The pattern is parsed into AST nodes by ``parse_regex`` and every node type
has a small generator that produces a piece of sample text; the pieces are
concatenated into a full example.  ``generate_match_examples`` works the other
way round and extracts real matches from a caller-supplied string.
"""

import random
import re
import string
from typing import List

from ..parser import (
    Alternation,
    Anchor,
    ASTNode,
    Backreference,
    CharacterClass,
    Group,
    Literal,
    Quantifier,
    SpecialSequence,
    parse_regex,
)


DIGITS = string.digits
UPPERCASE = string.ascii_uppercase
LOWERCASE = string.ascii_lowercase
WHITESPACE = " \t\n\r"
PUNCTUATION = "!@#$%^&*()_+-=[]{}|;:,.<>?"

# "_" is a word character, so it must not be offered as a sample for ``\W``.
_NON_WORD_PUNCTUATION = PUNCTUATION.replace("_", "")

# Candidate characters for each consuming special sequence.  Shared by every
# generator so that plain and quantified sequences sample from the same pools.
_SEQUENCE_POOLS = {
    ".": string.ascii_letters + string.digits + "!@#$",
    r"\d": DIGITS,
    r"\D": UPPERCASE + LOWERCASE + PUNCTUATION + WHITESPACE,
    r"\w": string.ascii_letters + string.digits + "_",
    r"\W": _NON_WORD_PUNCTUATION + WHITESPACE,
    r"\s": WHITESPACE,
    r"\S": string.ascii_letters + string.digits + PUNCTUATION,
}

# Sequences that match a position rather than a character.
_ZERO_WIDTH_SEQUENCES = frozenset({r"\b", r"\B", r"^", r"$"})

# How character-class members that are control whitespace are rendered.
_WHITESPACE_ESCAPES = {"\t": "\\t", "\n": "\\n", "\r": "\\r"}

# Sample inputs searched as a fallback when random generation yields too few
# distinct examples.
_FALLBACK_TEST_STRINGS = (
    "abc123",
    "test@example.com",
    "hello world",
    "123-456-7890",
    "https://example.com",
    "foo bar baz",
    "12345",
    "ABCdef",
    "word1 word2",
    "special!@#chars",
)


def generate_literal_example(node: Literal) -> str:
    """Return the literal's own text; a literal is its only example."""
    return node.value


def generate_character_class_example(node: CharacterClass) -> str:
    """Return one randomly chosen member of a character class.

    Candidates are the explicit ``node.characters`` plus, for every
    ``(start, end)`` pair in ``node.ranges``, at most the first ten code
    points of that range.  Returns ``""`` when there are no candidates.

    NOTE(review): ``-``, ``\\``, ``]``, space, tab, newline and carriage
    return are returned in backslash-escaped form, which looks intended for
    display rather than for literal matching -- confirm with callers.
    NOTE(review): any negation information on the node is not consulted
    here -- confirm against the parser.
    """
    options = []

    for char in node.characters:
        # Substring test kept as in the original implementation.
        if char in r"-\] ":
            options.append("\\" + char)
        else:
            options.append(_WHITESPACE_ESCAPES.get(char, char))

    for start, end in node.ranges:
        first = ord(start)
        # Cap every range at ten candidates so wide ranges do not dominate.
        stop = min(ord(end) + 1, first + 10)
        options.extend(chr(code) for code in range(first, stop))

    if not options:
        return ""

    return random.choice(options)


def generate_special_sequence_example(node: SpecialSequence) -> str:
    """Return a sample character for a special sequence such as ``\\d``.

    Zero-width sequences (``\\b``, ``\\B``, ``^``, ``$``) produce ``""``.
    A sequence that is not recognised is echoed back unchanged.
    """
    sequence = node.sequence
    if sequence in _ZERO_WIDTH_SEQUENCES:
        return ""

    pool = _SEQUENCE_POOLS.get(sequence)
    if pool is None:
        return sequence

    # Draw only from the pool that is actually needed.
    return random.choice(pool)


def generate_anchor_example(node: Anchor) -> str:
    """Anchors consume no characters, so they contribute nothing."""
    return ""


def _pick_repeat_count(node: Quantifier) -> int:
    """Choose how many repetitions to emit for *node*.

    Unbounded quantifiers are sampled from a small window above the minimum
    (``0..4`` when the minimum is zero, ``min..min + 3`` otherwise); bounded
    ones are sampled uniformly from ``min..max``.
    """
    low, high = node.min, node.max

    if high == Quantifier.MAX_UNBOUNDED:
        upper = 4 if low == 0 else low + 3
        return random.randint(low, upper)
    if low == high:
        return low
    return random.randint(low, high)


def generate_quantifier_example(node: Quantifier) -> str:
    """Return the quantified child repeated a valid number of times.

    Every repetition is generated independently, so ``[a-z]+`` can yield
    ``"kqa"`` instead of one character repeated.  A quantifier without a
    child yields ``"*"``.
    """
    child = getattr(node, "child", None)
    if not child:
        return "*"

    count = _pick_repeat_count(node)
    return "".join(generate_node_example(child) for _ in range(count))


def generate_group_example(node: Group) -> str:
    """Concatenate the examples of every node in the group's content."""
    return "".join(generate_node_example(child) for child in node.content)


def generate_alternation_example(node: Alternation) -> str:
    """Pick one non-empty alternative at random and generate it.

    Returns ``""`` when there are no alternatives or all of them are empty.
    """
    candidates = [option for option in node.options if option]
    if not candidates:
        return ""

    chosen = random.choice(candidates)
    return "".join(generate_node_example(child) for child in chosen)


def generate_backreference_example(node: Backreference) -> str:
    """Return a fixed placeholder.

    The text captured by the referenced group is not tracked during
    generation, so the placeholder is not the text a real match would have.
    """
    return "[reference]"


# Checked in order with ``isinstance``; the order mirrors the original
# if/elif chain in case any of the node classes subclass one another.
_NODE_GENERATORS = (
    (Literal, generate_literal_example),
    (CharacterClass, generate_character_class_example),
    (SpecialSequence, generate_special_sequence_example),
    (Anchor, generate_anchor_example),
    (Quantifier, generate_quantifier_example),
    (Group, generate_group_example),
    (Alternation, generate_alternation_example),
    (Backreference, generate_backreference_example),
)


def generate_node_example(node: ASTNode) -> str:
    """Dispatch *node* to the generator for its type.

    Unknown node types contribute ``""``.
    """
    for node_type, generator in _NODE_GENERATORS:
        if isinstance(node, node_type):
            return generator(node)
    return ""


def generate_examples(pattern: str, count: int = 5, flavor: str = "pcre") -> List[str]:
    """Return up to *count* distinct, non-empty example strings for *pattern*.

    Examples are generated randomly from the parsed pattern (at most
    ``count * 3`` attempts).  If that does not produce enough distinct
    strings, the pattern is compiled with :mod:`re` and searched against a
    fixed list of sample inputs; non-empty matches are added.

    Args:
        pattern: The regular expression source.
        count: Maximum number of examples to return.
        flavor: Currently unused; accepted for interface compatibility.

    Returns:
        The examples in the order they were found, or ``[]`` if anything
        fails while parsing or generating.
    """
    try:
        nodes = parse_regex(pattern)
        # A dict serves as an insertion-ordered set, so the result order does
        # not depend on string hash randomisation.
        examples = {}

        for _ in range(count * 3):
            if len(examples) >= count:
                break

            example = "".join(generate_node_example(node) for node in nodes)
            if example:
                examples[example] = None

        if len(examples) < count:
            try:
                compiled = re.compile(pattern)
            except re.error:
                # The pattern may be valid for the parser but not for ``re``.
                compiled = None

            if compiled is not None:
                for sample in _FALLBACK_TEST_STRINGS:
                    if len(examples) >= count:
                        break
                    match = compiled.search(sample)
                    # Skip empty matches, consistent with generated examples.
                    if match and match.group(0):
                        examples[match.group(0)] = None

        return list(examples)[:count]
    except Exception:
        # Deliberate best-effort boundary: example generation must never
        # break the caller, whatever the parser or a generator raises.
        return []


def generate_match_examples(pattern: str, test_string: str, count: int = 5, flavor: str = "pcre") -> List[str]:
    """Return up to *count* distinct matches of *pattern* in *test_string*.

    Each entry is the full text of a match (``match.group(0)``), in order of
    first occurrence, regardless of capture groups in the pattern.

    Args:
        pattern: The regular expression source, compiled with :mod:`re`.
        test_string: The text to search.
        count: Maximum number of matches to return; ``count <= 0`` yields
            ``[]``.
        flavor: Currently unused; accepted for interface compatibility.

    Returns:
        The distinct matches, or ``[]`` if the pattern does not compile.
    """
    if count <= 0:
        return []

    try:
        compiled = re.compile(pattern)
    except re.error:
        return []

    unique_matches = []
    seen = set()

    # finditer exposes the whole match; findall would return only the
    # capture groups when the pattern contains any.
    for match in compiled.finditer(test_string):
        text = match.group(0)
        if text in seen:
            continue

        seen.add(text)
        unique_matches.append(text)
        if len(unique_matches) >= count:
            break

    return unique_matches