from typing import Optional, Any
from dataclasses import dataclass, field
from enum import Enum


class NodeType(Enum):
    """Kinds of node that can appear in the AST built by ``RegexParser``."""

    LITERAL = "literal"
    CHARACTER_CLASS = "character_class"
    POSITIVE_SET = "positive_set"
    NEGATIVE_SET = "negative_set"
    DOT = "dot"
    GROUP = "group"
    CAPTURING_GROUP = "capturing_group"
    NON_CAPTURING_GROUP = "non_capturing_group"
    NAMED_GROUP = "named_group"
    LOOKAHEAD = "lookahead"
    LOOKBEHIND = "lookbehind"
    NEGATIVE_LOOKAHEAD = "negative_lookahead"
    NEGATIVE_LOOKBEHIND = "negative_lookbehind"
    QUANTIFIER = "quantifier"
    ANCHOR_START = "anchor_start"
    ANCHOR_END = "anchor_end"
    WORD_BOUNDARY = "word_boundary"
    NON_WORD_BOUNDARY = "non_word_boundary"
    START_OF_STRING = "start_of_string"
    END_OF_STRING = "end_of_string"
    END_OF_STRING_Z = "end_of_string_z"
    ANY_NEWLINE = "any_newline"
    CONTROL_CHAR = "control_char"
    ESCAPED_CHAR = "escaped_char"
    HEX_ESCAPE = "hex_escape"
    OCTAL_ESCAPE = "octal_escape"
    UNICODE_PROPERTY = "unicode_property"
    BACKREFERENCE = "backreference"
    BRANCH = "branch"
    SEQUENCE = "sequence"
    DIGIT = "digit"
    NON_DIGIT = "non_digit"
    WORD_CHAR = "word_char"
    NON_WORD_CHAR = "non_word_char"
    WHITESPACE = "whitespace"
    NON_WHITESPACE = "non_whitespace"


@dataclass
class RegexNode:
    """Base class for regex AST nodes.

    Attributes:
        node_type: Which construct this node represents.
        children: Sub-nodes (a group's body, a quantifier's operand, the
            alternatives of a BRANCH, the items of a SEQUENCE, ...).
        raw: The slice of the source pattern this node was parsed from.
        position: Index in the source pattern where ``raw`` starts.
    """

    node_type: NodeType
    children: list["RegexNode"] = field(default_factory=list)
    raw: str = ""
    position: int = 0


@dataclass
class LiteralNode(RegexNode):
    r"""A run of literal characters or a single escaped character.

    For ``NodeType.LITERAL`` nodes ``value`` equals ``raw``.  For
    ``NodeType.ESCAPED_CHAR`` nodes ``value`` is the character the escape
    denotes (e.g. a newline for ``\n``, ``.`` for ``\.``).
    """

    value: str = ""


@dataclass
class CharacterClassNode(RegexNode):
    r"""A bracketed character class such as ``[a-z]`` or ``[^\d_]``.

    Attributes:
        negated: True for ``[^...]`` classes.
        ranges: ``(first, last)`` pairs, one per ``x-y`` range.
        characters: Individual member characters, escapes already decoded.

    Shorthand escapes inside the class (``\d``, ``\W``, ``\s``, ...) are
    stored as ``children`` nodes.
    """

    negated: bool = False
    ranges: list[tuple[str, str]] = field(default_factory=list)
    characters: str = ""


@dataclass
class QuantifierNode(RegexNode):
    """A quantifier (``*``, ``+``, ``?``, ``{n,m}``) applied to ``children[0]``.

    ``max_count`` is an int, or ``float('inf')`` for an unbounded repeat.
    ``raw`` and ``position`` cover only the quantifier text, not its operand.
    """

    min_count: Optional[int] = None
    max_count: Any = None
    is_lazy: bool = False
    is_possessive: bool = False


@dataclass
class GroupNode(RegexNode):
    """A parenthesised construct: capturing, named, non-capturing or lookaround.

    Attributes:
        name: Group name for ``(?<name>...)`` / ``(?P<name>...)``.
        group_index: 1-based capture number for capturing and named groups.
        is_non_capturing: True unless the group captures.
        flags: Inline flag letters for ``(?flags)`` / ``(?flags:...)``.
    """

    name: Optional[str] = None
    group_index: Optional[int] = None
    is_non_capturing: bool = False
    flags: str = ""


class RegexParser:
    """Recursive-descent parser that turns a regex pattern into an AST.

    Problems found while parsing are collected instead of raised; read them
    with ``get_errors()`` after calling ``parse()``.
    """

    # Characters that end a run of plain literal text.
    _LITERAL_STOP = frozenset("\\[()|*+?.^${")
    _HEX_DIGITS = frozenset("0123456789abcdefABCDEF")
    _OCTAL_DIGITS = frozenset("01234567")
    # Letters accepted in inline-flag groups such as (?i) or (?im-s:...).
    _INLINE_FLAG_CHARS = frozenset("aiLmsuxnUJD-^")

    # Unescaped metacharacters that become a node on their own.
    _SINGLE_CHAR_NODES = {
        '.': NodeType.DOT,
        '^': NodeType.ANCHOR_START,
        '$': NodeType.ANCHOR_END,
    }
    # Escapes outside a class that denote a set or a zero-width assertion.
    _SIMPLE_ESCAPES = {
        'd': NodeType.DIGIT,
        'D': NodeType.NON_DIGIT,
        'w': NodeType.WORD_CHAR,
        'W': NodeType.NON_WORD_CHAR,
        's': NodeType.WHITESPACE,
        'S': NodeType.NON_WHITESPACE,
        'b': NodeType.WORD_BOUNDARY,
        'B': NodeType.NON_WORD_BOUNDARY,
        'A': NodeType.START_OF_STRING,
        'z': NodeType.END_OF_STRING,
        'Z': NodeType.END_OF_STRING_Z,
        'R': NodeType.ANY_NEWLINE,
    }
    # Shorthand escapes that are allowed inside a character class.
    _CLASS_SHORTHANDS = {
        'd': NodeType.DIGIT,
        'D': NodeType.NON_DIGIT,
        'w': NodeType.WORD_CHAR,
        'W': NodeType.NON_WORD_CHAR,
        's': NodeType.WHITESPACE,
        'S': NodeType.NON_WHITESPACE,
    }
    # Escapes that decode to a control character.
    _CONTROL_ESCAPES = {'n': '\n', 'r': '\r', 't': '\t', 'f': '\f'}
    # Inside a class, \b means backspace rather than a word boundary.
    _CLASS_CHAR_ESCAPES = {'n': '\n', 'r': '\r', 't': '\t', 'f': '\f', 'b': '\b'}

    def __init__(self, pattern: str, flavor: str = "pcre"):
        self.pattern = pattern
        self.flavor = flavor  # stored for callers; not consulted while parsing
        self.pos = 0
        self.length = len(pattern)
        self._errors: list[str] = []
        self._group_count = 0

    def parse(self) -> RegexNode:
        """Parse the entire pattern and return the root SEQUENCE node."""
        self.pos = 0
        self._errors = []
        self._group_count = 0
        result = self._parse_sequence()
        if self.pos < self.length:
            # Only an unmatched ')' can stop the top-level sequence early.
            remaining = self.pattern[self.pos:]
            self._errors.append(f"Unexpected content at position {self.pos}: {remaining[:20]}")
        return result

    # ------------------------------------------------------------------
    # Sequences and alternation
    # ------------------------------------------------------------------

    def _parse_sequence(self) -> RegexNode:
        """Parse up to an unmatched ')' or the end of the pattern.

        Returns a SEQUENCE node.  If the span contains top-level '|', the
        SEQUENCE holds a single BRANCH whose children are one SEQUENCE per
        alternative; otherwise it holds the parsed items directly.
        """
        start_pos = self.pos
        alternatives: list[RegexNode] = []
        alt_start = self.pos
        items = self._parse_items()
        while self.pos < self.length and self.pattern[self.pos] == '|':
            alternatives.append(self._make_sequence(items, alt_start))
            self.pos += 1  # consume '|'
            alt_start = self.pos
            items = self._parse_items()

        if alternatives:
            alternatives.append(self._make_sequence(items, alt_start))
            children = [RegexNode(
                node_type=NodeType.BRANCH,
                children=alternatives,
                raw=self.pattern[start_pos:self.pos],
                position=start_pos,
            )]
        else:
            children = items
        return self._make_sequence(children, start_pos)

    def _make_sequence(self, children: list[RegexNode], start_pos: int) -> RegexNode:
        """Wrap *children* in a SEQUENCE spanning pattern[start_pos:self.pos]."""
        return RegexNode(
            node_type=NodeType.SEQUENCE,
            children=children,
            raw=self.pattern[start_pos:self.pos],
            position=start_pos,
        )

    def _parse_items(self) -> list[RegexNode]:
        """Parse items of one alternative, stopping at '|', ')' or the end.

        Every branch advances ``self.pos``, so the loop always terminates.
        """
        items: list[RegexNode] = []
        while self.pos < self.length:
            char = self.pattern[self.pos]
            if char in ')|':
                break
            if char in '*+?{':
                self._apply_quantifier(items, char)
                continue
            node = self._parse_atom()
            if node is not None:
                items.append(node)
        return items

    def _parse_atom(self) -> Optional[RegexNode]:
        """Parse one non-quantifier element at the current position."""
        char = self.pattern[self.pos]
        if char == '\\':
            return self._parse_escape()
        if char == '[':
            return self._parse_character_class()
        if char == '(':
            return self._parse_group()
        single = self._SINGLE_CHAR_NODES.get(char)
        if single is not None:
            node = RegexNode(node_type=single, raw=char, position=self.pos)
            self.pos += 1
            return node
        return self._parse_literal_run()

    def _parse_literal_run(self) -> LiteralNode:
        """Consume consecutive plain characters as one LITERAL node."""
        start = self.pos
        self.pos += 1  # the first character is known to be literal here
        while self.pos < self.length and self.pattern[self.pos] not in self._LITERAL_STOP:
            self.pos += 1
        text = self.pattern[start:self.pos]
        return LiteralNode(node_type=NodeType.LITERAL, value=text, raw=text, position=start)

    # ------------------------------------------------------------------
    # Quantifiers
    # ------------------------------------------------------------------

    def _apply_quantifier(self, items: list[RegexNode], char: str) -> None:
        """Wrap the last item of *items* in a quantifier starting at self.pos."""
        brace = self._scan_brace_quantifier(self.pos) if char == '{' else None
        if char == '{' and brace is None:
            # A '{' that does not start {n}, {n,}, {n,m} or {,m} is literal.
            items.append(LiteralNode(node_type=NodeType.LITERAL, value='{', raw='{', position=self.pos))
            self.pos += 1
            return
        if not items:
            self._errors.append(f"Quantifier '{char}' without preceding element at position {self.pos}")
            self.pos = brace[2] if brace is not None else self.pos + 1
            return
        target = self._pop_quantifiable(items)
        quantified = self._parse_quantifier(char, target)
        items.append(quantified if quantified is not None else target)

    @staticmethod
    def _pop_quantifiable(items: list[RegexNode]) -> RegexNode:
        """Pop the operand of a quantifier from *items*.

        A quantifier binds only to the last character of a literal run, so a
        multi-character LITERAL is split and its head is left in *items*.
        """
        prev = items.pop()
        if (isinstance(prev, LiteralNode) and prev.node_type is NodeType.LITERAL
                and len(prev.value) > 1):
            head, tail = prev.value[:-1], prev.value[-1]
            items.append(LiteralNode(node_type=NodeType.LITERAL, value=head, raw=head,
                                     position=prev.position))
            return LiteralNode(node_type=NodeType.LITERAL, value=tail, raw=tail,
                               position=prev.position + len(head))
        return prev

    def _scan_brace_quantifier(self, pos: int) -> Optional[tuple[int, Any, int]]:
        """Check for a counted quantifier with its '{' at *pos*.

        Returns ``(min_count, max_count, index after '}')`` for ``{n}``,
        ``{n,}``, ``{n,m}`` or ``{,m}``, or None if the braces do not form a
        quantifier.  An omitted maximum is ``float('inf')``.
        """
        end = self.pattern.find('}', pos + 1)
        if end == -1:
            return None
        low, comma, high = self.pattern[pos + 1:end].partition(',')
        low, high = low.strip(), high.strip()
        if (low and not low.isdecimal()) or (high and not high.isdecimal()):
            return None
        if not low and not high:
            return None  # '{}' and '{,}' are not quantifiers
        if not comma:
            count = int(low)
            return count, count, end + 1
        min_count = int(low) if low else 0
        max_count = int(high) if high else float('inf')
        return min_count, max_count, end + 1

    def _parse_quantifier(self, char: str, node: Optional[RegexNode]) -> Optional[RegexNode]:
        """Parse the quantifier at self.pos and apply it to *node*.

        Returns None (without consuming input) if *node* is None or the text
        is not a quantifier.  A trailing '?' marks it lazy, '+' possessive.
        """
        if node is None:
            return None
        start_pos = self.pos
        if char == '*':
            min_count, max_count = 0, float('inf')
            self.pos += 1
        elif char == '+':
            min_count, max_count = 1, float('inf')
            self.pos += 1
        elif char == '?':
            min_count, max_count = 0, 1
            self.pos += 1
        elif char == '{':
            scanned = self._scan_brace_quantifier(self.pos)
            if scanned is None:
                return None
            min_count, max_count, self.pos = scanned
            if max_count < min_count:
                self._errors.append(f"Quantifier minimum exceeds maximum at position {start_pos}")
        else:
            return None

        is_lazy = is_possessive = False
        if self.pos < self.length:
            modifier = self.pattern[self.pos]
            if modifier == '?':
                is_lazy = True
                self.pos += 1
            elif modifier == '+':
                is_possessive = True
                self.pos += 1

        return QuantifierNode(
            node_type=NodeType.QUANTIFIER,
            children=[node],
            raw=self.pattern[start_pos:self.pos],
            position=start_pos,
            min_count=min_count,
            max_count=max_count,
            is_lazy=is_lazy,
            is_possessive=is_possessive,
        )

    # ------------------------------------------------------------------
    # Escapes
    # ------------------------------------------------------------------

    def _escape_node(self, node_type: NodeType, start: int) -> RegexNode:
        """Build a node for the escape spanning pattern[start:self.pos]."""
        return RegexNode(node_type=node_type, raw=self.pattern[start:self.pos], position=start)

    def _parse_escape(self) -> Optional[RegexNode]:
        """Parse the escape sequence whose backslash is at self.pos.

        A trailing backslash is reported as an error and consumed, and None
        is returned.
        """
        start = self.pos
        if start + 1 >= self.length:
            self._errors.append(f"Trailing backslash at position {start}")
            self.pos = self.length
            return None
        char = self.pattern[start + 1]
        self.pos = start + 2

        simple = self._SIMPLE_ESCAPES.get(char)
        if simple is not None:
            return self._escape_node(simple, start)

        if char == '0':
            # \0 may be followed by up to two more octal digits.
            for _ in range(2):
                if self.pos < self.length and self.pattern[self.pos] in self._OCTAL_DIGITS:
                    self.pos += 1
                else:
                    break
            return self._escape_node(NodeType.OCTAL_ESCAPE, start)

        if char == 'x':
            if self.pattern.startswith('{', self.pos):
                end = self.pattern.find('}', self.pos + 1)
                digits = self.pattern[self.pos + 1:end] if end != -1 else ''
                if digits and set(digits) <= self._HEX_DIGITS:
                    self.pos = end + 1
                    return self._escape_node(NodeType.HEX_ESCAPE, start)
            else:
                digits = self.pattern[self.pos:self.pos + 2]
                if len(digits) == 2 and set(digits) <= self._HEX_DIGITS:
                    self.pos += 2
                    return self._escape_node(NodeType.HEX_ESCAPE, start)

        if char == 'u':
            digits = self.pattern[self.pos:self.pos + 4]
            if len(digits) == 4 and set(digits) <= self._HEX_DIGITS:
                self.pos += 4
                # \uXXXX is a code-point escape, not a Unicode property.
                return self._escape_node(NodeType.HEX_ESCAPE, start)

        if char in 'pP':
            if self.pattern.startswith('{', self.pos):
                end = self.pattern.find('}', self.pos + 1)
                if end != -1:
                    self.pos = end + 1
                    return self._escape_node(NodeType.UNICODE_PROPERTY, start)
            elif self.pos < self.length and self.pattern[self.pos].isalpha():
                self.pos += 1  # single-letter form such as \pL
                return self._escape_node(NodeType.UNICODE_PROPERTY, start)

        if char == 'c' and self.pos < self.length:
            self.pos += 1
            return self._escape_node(NodeType.CONTROL_CHAR, start)

        if char.isdigit():
            while self.pos < self.length and self.pattern[self.pos].isdigit():
                self.pos += 1
            return self._escape_node(NodeType.BACKREFERENCE, start)

        # Anything else (\., \*, \n, unknown letters, ...) is one character.
        return LiteralNode(
            node_type=NodeType.ESCAPED_CHAR,
            value=self._CONTROL_ESCAPES.get(char, char),
            raw=self.pattern[start:self.pos],
            position=start,
        )

    # ------------------------------------------------------------------
    # Character classes
    # ------------------------------------------------------------------

    def _read_class_atom(self) -> "str | RegexNode":
        """Read one member of a character class.

        Returns the decoded character, or a node for a shorthand escape such
        as \\d.
        """
        char = self.pattern[self.pos]
        if char != '\\' or self.pos + 1 >= self.length:
            self.pos += 1
            return char
        start = self.pos
        escaped = self.pattern[start + 1]
        self.pos += 2
        shorthand = self._CLASS_SHORTHANDS.get(escaped)
        if shorthand is not None:
            return self._escape_node(shorthand, start)
        if escaped == 'x':
            digits = self.pattern[self.pos:self.pos + 2]
            if len(digits) == 2 and set(digits) <= self._HEX_DIGITS:
                self.pos += 2
                return chr(int(digits, 16))
        return self._CLASS_CHAR_ESCAPES.get(escaped, escaped)

    def _parse_character_class(self) -> Optional[RegexNode]:
        """Parse a character class like [a-z] or [^a-z] starting at self.pos.

        A ']' directly after '[' or '[^' is a member, not the terminator.
        A '-' forms a range only between two single characters and not
        before the closing ']'.
        """
        if self.pos >= self.length or self.pattern[self.pos] != '[':
            return None
        start_pos = self.pos
        self.pos += 1
        negated = False
        if self.pos < self.length and self.pattern[self.pos] == '^':
            negated = True
            self.pos += 1

        ranges: list[tuple[str, str]] = []
        members: list[str] = []
        shorthands: list[RegexNode] = []
        first = True
        closed = False
        while self.pos < self.length:
            if self.pattern[self.pos] == ']' and not first:
                self.pos += 1
                closed = True
                break
            first = False
            atom = self._read_class_atom()
            if isinstance(atom, RegexNode):
                shorthands.append(atom)
                continue
            is_range = (self.pos + 1 < self.length
                        and self.pattern[self.pos] == '-'
                        and self.pattern[self.pos + 1] != ']')
            if not is_range:
                members.append(atom)
                continue
            self.pos += 1  # consume the '-'
            range_end = self._read_class_atom()
            if isinstance(range_end, RegexNode):
                # e.g. [a-\d]: not a real range; keep both sides literally.
                members.extend((atom, '-'))
                shorthands.append(range_end)
                continue
            if range_end < atom:
                self._errors.append(
                    f"Invalid range {atom}-{range_end} in character class at position {start_pos}")
            ranges.append((atom, range_end))

        if not closed:
            self._errors.append(f"Unterminated character class at position {start_pos}")

        return CharacterClassNode(
            node_type=NodeType.NEGATIVE_SET if negated else NodeType.POSITIVE_SET,
            children=shorthands,
            negated=negated,
            ranges=ranges,
            characters="".join(members),
            raw=self.pattern[start_pos:self.pos],
            position=start_pos,
        )

    # ------------------------------------------------------------------
    # Groups
    # ------------------------------------------------------------------

    def _parse_group(self) -> Optional[RegexNode]:
        """Parse a group such as (...), (?:...), (?<n>...), (?=...) or (?i).

        Returns None for (?#...) comments, which produce no node.
        """
        if self.pos >= self.length or self.pattern[self.pos] != '(':
            return None
        start_pos = self.pos
        self.pos += 1
        if not self.pattern.startswith('?', self.pos):
            self._group_count += 1
            return self._finish_group(start_pos, NodeType.CAPTURING_GROUP,
                                      group_index=self._group_count)

        self.pos += 1  # consume '?'
        pattern, pos = self.pattern, self.pos
        if pattern.startswith('=', pos):
            self.pos += 1
            return self._finish_group(start_pos, NodeType.LOOKAHEAD)
        if pattern.startswith('!', pos):
            self.pos += 1
            return self._finish_group(start_pos, NodeType.NEGATIVE_LOOKAHEAD)
        if pattern.startswith('<=', pos):
            self.pos += 2
            return self._finish_group(start_pos, NodeType.LOOKBEHIND)
        if pattern.startswith('<!', pos):
            self.pos += 2
            return self._finish_group(start_pos, NodeType.NEGATIVE_LOOKBEHIND)
        if pattern.startswith('P<', pos) or pattern.startswith('<', pos):
            self.pos += 2 if pattern.startswith('P', pos) else 1
            name = self._read_group_name('>')
            self._group_count += 1
            return self._finish_group(start_pos, NodeType.NAMED_GROUP, name=name,
                                      group_index=self._group_count)
        if pattern.startswith('P=', pos):
            # (?P=name) is a named backreference, not a group.
            self.pos += 2
            self._read_group_name(')')
            return RegexNode(node_type=NodeType.BACKREFERENCE,
                             raw=pattern[start_pos:self.pos], position=start_pos)
        if pattern.startswith(':', pos) or pattern.startswith('>', pos):
            # (?>...) atomic groups are represented as non-capturing groups.
            self.pos += 1
            return self._finish_group(start_pos, NodeType.NON_CAPTURING_GROUP)
        if pattern.startswith('#', pos):
            end = pattern.find(')', pos)
            if end == -1:
                self._errors.append(f"Unterminated comment starting at position {start_pos}")
                self.pos = self.length
            else:
                self.pos = end + 1
            return None
        if pos < self.length and pattern[pos] in self._INLINE_FLAG_CHARS:
            return self._parse_inline_flags(start_pos)

        self._errors.append(f"Unknown group construct at position {start_pos}")
        return self._finish_group(start_pos, NodeType.NON_CAPTURING_GROUP)

    def _parse_inline_flags(self, start_pos: int) -> RegexNode:
        """Parse (?flags) or (?flags:...) with self.pos at the first flag."""
        flags_start = self.pos
        while self.pos < self.length and self.pattern[self.pos] in self._INLINE_FLAG_CHARS:
            self.pos += 1
        flags = self.pattern[flags_start:self.pos]
        if self.pattern.startswith(')', self.pos):
            self.pos += 1
            return GroupNode(
                node_type=NodeType.NON_CAPTURING_GROUP,
                raw=self.pattern[start_pos:self.pos],
                position=start_pos,
                is_non_capturing=True,
                flags=flags,
            )
        if self.pattern.startswith(':', self.pos):
            self.pos += 1
        else:
            self._errors.append(f"Invalid inline flags at position {flags_start}")
        return self._finish_group(start_pos, NodeType.NON_CAPTURING_GROUP, flags=flags)

    def _read_group_name(self, terminator: str) -> str:
        """Read a name up to *terminator* and consume the terminator."""
        end = self.pattern.find(terminator, self.pos)
        if end == -1:
            self._errors.append(f"Unterminated group name at position {self.pos}")
            name = self.pattern[self.pos:]
            self.pos = self.length
            return name
        name = self.pattern[self.pos:end]
        self.pos = end + 1
        return name

    def _finish_group(self, start_pos: int, node_type: NodeType, *,
                      name: Optional[str] = None, group_index: Optional[int] = None,
                      flags: str = "") -> GroupNode:
        """Parse a group body, consume its ')' and build the GroupNode."""
        body = self._parse_sequence()
        if self.pos < self.length and self.pattern[self.pos] == ')':
            self.pos += 1
        else:
            self._errors.append(f"Unclosed group starting at position {start_pos}")
        return GroupNode(
            node_type=node_type,
            children=[body],
            raw=self.pattern[start_pos:self.pos],
            position=start_pos,
            name=name,
            group_index=group_index,
            is_non_capturing=node_type not in (NodeType.CAPTURING_GROUP, NodeType.NAMED_GROUP),
            flags=flags,
        )

    def get_errors(self) -> list[str]:
        """Return a copy of the errors collected by the last ``parse()``."""
        return list(self._errors)


def parse_regex(pattern: str, flavor: str = "pcre") -> RegexNode:
    """Parse *pattern* into an AST and return its root SEQUENCE node.

    Parse errors are not surfaced by this helper; use ``RegexParser``
    directly and call ``get_errors()`` to inspect them.
    """
    return RegexParser(pattern, flavor).parse()