663 lines
24 KiB
Python
663 lines
24 KiB
Python
from typing import Optional, Any
|
|
from dataclasses import dataclass, field
|
|
from enum import Enum
|
|
|
|
|
|
class NodeType(Enum):
    """Kinds of node that can appear in the regex AST.

    Each member's value is a stable lowercase string identifier. Several
    members are looked up by value (``NodeType("digit")``) when shorthand
    escapes are parsed, so the string values are part of the contract.
    """

    # Plain text and character sets.
    LITERAL = "literal"
    CHARACTER_CLASS = "character_class"
    POSITIVE_SET = "positive_set"
    NEGATIVE_SET = "negative_set"
    DOT = "dot"

    # Grouping constructs.
    GROUP = "group"
    CAPTURING_GROUP = "capturing_group"
    NON_CAPTURING_GROUP = "non_capturing_group"
    NAMED_GROUP = "named_group"

    # Lookaround assertions.
    LOOKAHEAD = "lookahead"
    LOOKBEHIND = "lookbehind"
    NEGATIVE_LOOKAHEAD = "negative_lookahead"
    NEGATIVE_LOOKBEHIND = "negative_lookbehind"

    # Repetition.
    QUANTIFIER = "quantifier"

    # Anchors and boundaries.
    ANCHOR_START = "anchor_start"
    ANCHOR_END = "anchor_end"
    WORD_BOUNDARY = "word_boundary"
    NON_WORD_BOUNDARY = "non_word_boundary"
    START_OF_STRING = "start_of_string"
    END_OF_STRING = "end_of_string"
    END_OF_STRING_Z = "end_of_string_z"

    # Escapes.
    ANY_NEWLINE = "any_newline"
    CONTROL_CHAR = "control_char"
    ESCAPED_CHAR = "escaped_char"
    HEX_ESCAPE = "hex_escape"
    OCTAL_ESCAPE = "octal_escape"
    UNICODE_PROPERTY = "unicode_property"
    BACKREFERENCE = "backreference"

    # Structural nodes.
    BRANCH = "branch"
    SEQUENCE = "sequence"

    # Shorthand character classes (\d, \D, \w, \W, \s, \S).
    DIGIT = "digit"
    NON_DIGIT = "non_digit"
    WORD_CHAR = "word_char"
    NON_WORD_CHAR = "non_word_char"
    WHITESPACE = "whitespace"
    NON_WHITESPACE = "non_whitespace"
|
|
|
|
|
|
@dataclass
class RegexNode:
    """Base class for regex AST nodes.

    Subclasses add construct-specific fields; every node carries the four
    fields below.
    """

    # Which kind of construct this node represents.
    node_type: NodeType
    # Nested nodes (e.g. the body of a group or the target of a quantifier).
    children: list["RegexNode"] = field(default_factory=list)
    # Text associated with the node; the parser fills this from the pattern.
    raw: str = ""
    # Index into the pattern string recorded by the parser for this node.
    position: int = 0
|
|
|
|
|
|
@dataclass
class LiteralNode(RegexNode):
    """Represents a literal character or string.

    Also used by the parser for escaped characters, in which case
    ``node_type`` is ``NodeType.ESCAPED_CHAR``.
    """

    # The literal text this node stands for.
    value: str = ""
|
|
|
|
|
|
@dataclass
class CharacterClassNode(RegexNode):
    """Represents a bracketed character class like ``[a-z]`` or ``[^0-9_]``."""

    # True when the class starts with '^'.
    negated: bool = False
    # (start, end) pairs, one per 'x-y' range inside the brackets.
    ranges: list[tuple[str, str]] = field(default_factory=list)
    # Individual member characters that are not part of a range.
    characters: str = ""
|
|
|
|
|
|
@dataclass
class QuantifierNode(RegexNode):
    """Represents a quantifier like ``*``, ``+``, ``?`` or ``{n,m}``.

    The quantified element is stored as the single entry of ``children``.
    """

    # Minimum number of repetitions.
    min_count: Optional[int] = None
    # Maximum number of repetitions; typed Any because the parser stores
    # float('inf') for the unbounded quantifiers '*' and '+'.
    max_count: Any = None
    # True when the quantifier is followed by '?'.
    is_lazy: bool = False
    # True when the quantifier is followed by '+'.
    is_possessive: bool = False
|
|
|
|
|
|
@dataclass
class GroupNode(RegexNode):
    """Represents a parenthesised group (capturing, named, non-capturing or lookaround).

    The group body is stored as the single entry of ``children``.
    """

    # Group name for named groups, otherwise None.
    name: Optional[str] = None
    # Capture index of the group; None when no index has been assigned.
    group_index: Optional[int] = None
    # True for groups that do not capture (non-capturing groups, lookarounds).
    is_non_capturing: bool = False
|
|
|
|
|
|
class RegexParser:
    """Recursive-descent parser that turns a regex pattern into an AST.

    Usage: ``RegexParser(pattern).parse()`` returns the root node (always a
    ``NodeType.SEQUENCE``); ``get_errors()`` then lists any problems found.
    The parser is error-tolerant: it records diagnostics instead of raising.

    AST shape:
      * A pattern without ``|`` parses to ``SEQUENCE[elements...]``.
      * A pattern with ``|`` parses to ``SEQUENCE[BRANCH[SEQUENCE, ...]]``
        with one inner ``SEQUENCE`` per alternative, so the boundaries
        between alternatives are preserved.
      * A quantifier wraps only the element directly before it; a run of
        literal text is split so that ``abc*`` quantifies just ``c``.
      * Capturing and named groups receive a 1-based ``group_index`` in
        order of their opening parenthesis.
    """

    # Characters that end a run of plain literal text.
    _LITERAL_STOP = frozenset('()|*+?[\\.^{$')
    # Characters that introduce a quantifier.
    _QUANTIFIER_START = frozenset('*+?{')
    _HEX_DIGITS = frozenset('0123456789abcdefABCDEF')
    # Flag letters recognised in inline-flag groups such as (?i) or (?i:...).
    _INLINE_FLAGS = frozenset('imsxD')
    # Shorthand escapes mapped to the *value* of the matching NodeType member.
    _SHORTHAND_ESCAPES = {
        'd': 'digit',
        'D': 'non_digit',
        'w': 'word_char',
        'W': 'non_word_char',
        's': 'whitespace',
        'S': 'non_whitespace',
        'b': 'word_boundary',
        'B': 'non_word_boundary',
    }
    # Shorthand escapes that are meaningful inside a character class.
    _CLASS_SHORTHANDS = frozenset('dDwWsS')
    # Escapes that stand for a single literal character.
    _LITERAL_ESCAPES = {
        '.': '.',
        '*': '*',
        '+': '+',
        '?': '?',
        '^': '^',
        '$': '$',
        '|': '|',
        '(': '(',
        ')': ')',
        '[': '[',
        ']': ']',
        '{': '{',
        '}': '}',
        '\\': '\\',
        '-': '-',
        'n': '\n',
        'r': '\r',
        't': '\t',
    }

    def __init__(self, pattern: str, flavor: str = "pcre"):
        """Store the pattern to parse.

        Args:
            pattern: The regular expression source text.
            flavor: Name of the regex dialect; stored but not otherwise
                consulted by this class.
        """
        self.pattern = pattern
        self.flavor = flavor
        self.pos = 0
        self.length = len(pattern)
        self._errors: list[str] = []
        # Number of capturing groups opened so far (used for group_index).
        self._group_count = 0

    def parse(self) -> RegexNode:
        """Parse the entire pattern into an AST.

        Resets all parser state first, so it may be called repeatedly.

        Returns:
            The root ``SEQUENCE`` node. If parsing stopped early (for
            example at an unmatched ``)``), an "Unexpected content" error
            is recorded and the partial tree is still returned.
        """
        self.pos = 0
        self._errors = []
        self._group_count = 0
        result = self._parse_sequence()
        if self.pos < self.length:
            remaining = self.pattern[self.pos:]
            self._errors.append(f"Unexpected content at position {self.pos}: {remaining[:20]}")
        return result

    def _parse_sequence(self) -> RegexNode:
        """Parse elements (including ``|`` alternatives) up to ``)`` or end of input.

        Returns:
            A ``SEQUENCE`` node. When the text contains ``|`` its only child
            is a ``BRANCH`` holding one ``SEQUENCE`` per alternative.
        """
        start_pos = self.pos
        alternatives = [self._parse_alternative()]
        while self.pos < self.length and self.pattern[self.pos] == '|':
            self.pos += 1
            alternatives.append(self._parse_alternative())

        if len(alternatives) == 1:
            return alternatives[0]

        span = self.pattern[start_pos:self.pos]
        branch = RegexNode(
            node_type=NodeType.BRANCH,
            children=alternatives,
            raw=span,
            position=start_pos,
        )
        return RegexNode(
            node_type=NodeType.SEQUENCE,
            children=[branch],
            raw=span,
            position=start_pos,
        )

    def _parse_alternative(self) -> RegexNode:
        """Parse one alternative: elements up to ``|``, ``)`` or end of input."""
        children: list[RegexNode] = []
        start_pos = self.pos
        while self.pos < self.length:
            char = self.pattern[self.pos]
            if char == ')' or char == '|':
                break
            if char in self._QUANTIFIER_START:
                self._apply_quantifier(char, children)
                continue
            node = self._parse_atom(char)
            if node is not None:
                children.append(node)
        return RegexNode(
            node_type=NodeType.SEQUENCE,
            children=children,
            raw=self.pattern[start_pos:self.pos],
            position=start_pos,
        )

    def _parse_atom(self, char: str) -> Optional[RegexNode]:
        """Parse the single element starting at ``self.pos`` (whose character is ``char``).

        Always advances ``self.pos`` by at least one character so the
        calling loop is guaranteed to terminate.
        """
        start = self.pos
        if char == '\\':
            node = self._parse_escape()
            if node is None:
                # A lone backslash at the very end of the pattern: record it
                # and keep it as literal text instead of stalling on it.
                self._errors.append(f"Trailing backslash at position {start}")
                self.pos = start + 1
                node = LiteralNode(
                    node_type=NodeType.LITERAL,
                    value='\\',
                    raw='\\',
                    position=start,
                )
            return node
        if char == '[':
            return self._parse_character_class()
        if char == '(':
            return self._parse_group()

        if char == '.':
            kind = NodeType.DOT
        elif char == '^':
            kind = NodeType.ANCHOR_START
        elif char == '$':
            kind = NodeType.ANCHOR_END
        else:
            return self._parse_literal()
        self.pos = start + 1
        return RegexNode(node_type=kind, raw=char, position=start)

    def _parse_literal(self) -> RegexNode:
        """Parse a run of plain characters into a single ``LiteralNode``."""
        start = self.pos
        end = start + 1
        while end < self.length and self.pattern[end] not in self._LITERAL_STOP:
            end += 1
        # A quantifier applies only to the character right before it, so
        # leave that character out of this run; it becomes its own literal.
        if end - start > 1 and end < self.length and self.pattern[end] in self._QUANTIFIER_START:
            end -= 1
        self.pos = end
        text = self.pattern[start:end]
        return LiteralNode(
            node_type=NodeType.LITERAL,
            value=text,
            raw=text,
            position=start,
        )

    def _apply_quantifier(self, char: str, children: list) -> None:
        """Wrap the last element of ``children`` in the quantifier at ``self.pos``.

        Records an error (and skips the quantifier character) when there is
        nothing to quantify or when a ``{...}`` quantifier is malformed; in
        the latter case the preceding element is kept unchanged.
        """
        at = self.pos
        if not children:
            self._errors.append(f"Quantifier '{char}' without preceding element at position {at}")
            self.pos = at + 1
            return

        target = children.pop()
        node = self._parse_quantifier(char, target)
        if node is None:
            children.append(target)
            self._errors.append(f"Invalid quantifier at position {at}")
            self.pos = at + 1
            return
        children.append(node)

    def _parse_escape(self) -> Optional[RegexNode]:
        """Parse the escape sequence whose backslash is at ``self.pos``.

        Returns:
            The node for the escape, or ``None`` (without advancing) when
            the backslash is the last character of the pattern.
        """
        if self.pos + 1 >= self.length:
            return None

        start = self.pos
        char = self.pattern[start + 1]
        self.pos = start + 2

        shorthand = self._SHORTHAND_ESCAPES.get(char)
        if shorthand is not None:
            return RegexNode(
                node_type=NodeType(shorthand),
                raw=f'\\{char}',
                position=start,
            )

        if char in self._LITERAL_ESCAPES:
            return LiteralNode(
                node_type=NodeType.ESCAPED_CHAR,
                value=self._LITERAL_ESCAPES[char],
                raw=f'\\{char}',
                position=start,
            )

        if char == '0':
            return RegexNode(
                node_type=NodeType.OCTAL_ESCAPE,
                raw=f'\\{char}',
                position=start,
            )

        if char == 'x':
            digits = self.pattern[self.pos:self.pos + 2]
            if len(digits) == 2 and all(c in self._HEX_DIGITS for c in digits):
                self.pos += 2
                return RegexNode(
                    node_type=NodeType.HEX_ESCAPE,
                    raw=f'\\x{digits}',
                    position=start,
                )

        if char == 'u':
            digits = self.pattern[self.pos:self.pos + 4]
            if len(digits) == 4 and all(c in self._HEX_DIGITS for c in digits):
                self.pos += 4
                # NOTE(review): \uXXXX is a code-point escape, not a property;
                # it is tagged UNICODE_PROPERTY only because NodeType has no
                # dedicated member. Kept as-is for existing consumers.
                return RegexNode(
                    node_type=NodeType.UNICODE_PROPERTY,
                    raw=f'\\u{digits}',
                    position=start,
                )

        if char in ('p', 'P') and self.pos < self.length and self.pattern[self.pos] == '{':
            close = self.pattern.find('}', self.pos + 1)
            if close != -1:
                self.pos = close + 1
                return RegexNode(
                    node_type=NodeType.UNICODE_PROPERTY,
                    raw=self.pattern[start:self.pos],
                    position=start,
                )

        if char == 'c' and self.pos < self.length:
            ctrl_char = self.pattern[self.pos]
            self.pos += 1
            return RegexNode(
                node_type=NodeType.CONTROL_CHAR,
                raw=f'\\c{ctrl_char}',
                position=start,
            )

        if char.isdigit():
            end = self.pos
            while end < self.length and self.pattern[end].isdigit():
                end += 1
            self.pos = end
            return RegexNode(
                node_type=NodeType.BACKREFERENCE,
                raw=self.pattern[start:end],
                position=start,
            )

        # Unknown escape: the escaped character stands for itself.
        return LiteralNode(
            node_type=NodeType.ESCAPED_CHAR,
            value=char,
            raw=f'\\{char}',
            position=start,
        )

    def _parse_character_class(self) -> Optional[RegexNode]:
        """Parse a character class like ``[a-z]`` or ``[^a-z]``.

        Shorthand escapes inside the class (``\\d``, ``\\w``, ``\\s`` and
        their negations) are kept as child nodes; other members go into
        ``characters`` or ``ranges``.

        Returns:
            A ``CharacterClassNode``, or ``None`` when ``self.pos`` is not
            at ``[``. A missing ``]`` is recorded as an error.
        """
        if self.pos >= self.length or self.pattern[self.pos] != '[':
            return None

        start_pos = self.pos
        self.pos += 1

        negated = False
        if self.pos < self.length and self.pattern[self.pos] == '^':
            negated = True
            self.pos += 1

        ranges = []
        characters = ""
        shorthands = []
        # True only when the previous item was a single character, i.e. when
        # a following '-' may legitimately start a range from it.
        prev_is_char = False
        closed = False

        # A ']' directly after '[' or '[^' is a literal member, not the end.
        if self.pos < self.length and self.pattern[self.pos] == ']':
            characters += ']'
            prev_is_char = True
            self.pos += 1

        while self.pos < self.length:
            char = self.pattern[self.pos]

            if char == ']':
                self.pos += 1
                closed = True
                break

            if char == '\\':
                if self.pos + 1 >= self.length:
                    self.pos += 1
                    continue
                next_char = self.pattern[self.pos + 1]
                if next_char in self._CLASS_SHORTHANDS:
                    shorthands.append(RegexNode(
                        node_type=NodeType(self._SHORTHAND_ESCAPES[next_char]),
                        raw=f'\\{next_char}',
                        position=self.pos,
                    ))
                    prev_is_char = False
                else:
                    characters += self._LITERAL_ESCAPES.get(next_char, next_char)
                    prev_is_char = True
                self.pos += 2
            elif (
                char == '-'
                and prev_is_char
                and self.pos + 1 < self.length
                and self.pattern[self.pos + 1] != ']'
            ):
                end_char = self.pattern[self.pos + 1]
                self.pos += 2
                ranges.append((characters[-1], end_char))
                characters = characters[:-1]
                prev_is_char = False
            else:
                characters += char
                prev_is_char = True
                self.pos += 1

        if not closed:
            self._errors.append(f"Unterminated character class at position {start_pos}")

        return CharacterClassNode(
            node_type=NodeType.NEGATIVE_SET if negated else NodeType.POSITIVE_SET,
            children=shorthands,
            negated=negated,
            ranges=ranges,
            characters=characters,
            raw=self.pattern[start_pos:self.pos],
            position=start_pos,
        )

    def _parse_group(self) -> Optional[RegexNode]:
        """Parse a group like ``(...)``, ``(?:...)``, ``(?<name>...)`` or ``(?=...)``.

        The closing ``)`` is consumed for every kind of group; a missing
        one is recorded as an error.

        Returns:
            A ``GroupNode`` whose single child is the body ``SEQUENCE``
            (comment groups ``(?#...)`` have no children), or ``None`` when
            ``self.pos`` is not at ``(``.
        """
        if self.pos >= self.length or self.pattern[self.pos] != '(':
            return None

        start_pos = self.pos
        self.pos += 1
        text = self.pattern

        kind = NodeType.CAPTURING_GROUP
        name = None
        if text.startswith('?', self.pos):
            self.pos += 1
            if text.startswith('#', self.pos):
                return self._parse_comment_group(start_pos)
            if text.startswith('<=', self.pos):
                kind = NodeType.LOOKBEHIND
                self.pos += 2
            elif text.startswith('<!', self.pos):
                kind = NodeType.NEGATIVE_LOOKBEHIND
                self.pos += 2
            elif text.startswith('=', self.pos):
                kind = NodeType.LOOKAHEAD
                self.pos += 1
            elif text.startswith('!', self.pos):
                kind = NodeType.NEGATIVE_LOOKAHEAD
                self.pos += 1
            elif text.startswith(':', self.pos):
                kind = NodeType.NON_CAPTURING_GROUP
                self.pos += 1
            elif text.startswith('P<', self.pos):
                kind = NodeType.NAMED_GROUP
                name = self._read_group_name(self.pos + 2)
            elif text.startswith('<', self.pos):
                kind = NodeType.NAMED_GROUP
                name = self._read_group_name(self.pos + 1)
            elif self.pos < self.length and text[self.pos] in self._INLINE_FLAGS:
                # Inline flags: (?i) or the scoped form (?i:...).
                kind = NodeType.NON_CAPTURING_GROUP
                while self.pos < self.length and text[self.pos] in self._INLINE_FLAGS:
                    self.pos += 1
                if text.startswith(':', self.pos):
                    self.pos += 1
            else:
                # Unknown "(?x" form: report it, then parse what follows as
                # the body of an ordinary capturing group.
                self._errors.append(f"Unsupported group syntax at position {start_pos}")

        captures = kind in (NodeType.CAPTURING_GROUP, NodeType.NAMED_GROUP)
        group_index = None
        if captures:
            # Number groups by the order of their opening parenthesis.
            self._group_count += 1
            group_index = self._group_count

        body = self._parse_sequence()

        if self.pos < self.length and self.pattern[self.pos] == ')':
            self.pos += 1
        else:
            self._errors.append(f"Unterminated group starting at position {start_pos}")

        return GroupNode(
            node_type=kind,
            children=[body],
            raw=self.pattern[start_pos:self.pos],
            position=start_pos,
            name=name,
            group_index=group_index,
            is_non_capturing=not captures,
        )

    def _read_group_name(self, name_start: int) -> str:
        """Read a group name that ends at ``>`` and move past the ``>``."""
        close = self.pattern.find('>', name_start)
        if close == -1:
            self._errors.append(f"Unterminated group name at position {name_start}")
            self.pos = self.length
            return self.pattern[name_start:]
        self.pos = close + 1
        return self.pattern[name_start:close]

    def _parse_comment_group(self, start_pos: int) -> RegexNode:
        """Consume a ``(?#...)`` comment; ``self.pos`` is at the ``#``."""
        close = self.pattern.find(')', self.pos)
        if close == -1:
            self._errors.append(f"Unterminated group starting at position {start_pos}")
            self.pos = self.length
        else:
            self.pos = close + 1
        return GroupNode(
            node_type=NodeType.NON_CAPTURING_GROUP,
            raw=self.pattern[start_pos:self.pos],
            position=start_pos,
            is_non_capturing=True,
        )

    def _parse_quantifier(self, char: str, node: Optional[RegexNode]) -> Optional[RegexNode]:
        """Parse a quantifier like ``*``, ``+``, ``?`` or ``{n}``/``{n,}``/``{n,m}``.

        Args:
            char: The character at ``self.pos`` that starts the quantifier.
            node: The element being quantified.

        Returns:
            A ``QuantifierNode`` wrapping ``node``; ``max_count`` is
            ``float('inf')`` for unbounded repetition. Returns ``None``
            without advancing ``self.pos`` when ``node`` is ``None``,
            ``char`` is not a quantifier, or a ``{...}`` body is malformed.
        """
        if node is None:
            return None

        start_pos = self.pos
        if char == '*':
            min_count, max_count, after = 0, float('inf'), start_pos + 1
        elif char == '+':
            min_count, max_count, after = 1, float('inf'), start_pos + 1
        elif char == '?':
            min_count, max_count, after = 0, 1, start_pos + 1
        elif char == '{':
            bounds = self._read_braced_bounds(start_pos)
            if bounds is None:
                return None
            min_count, max_count, after = bounds
        else:
            return None
        self.pos = after

        # An optional trailing '?' makes it lazy, '+' makes it possessive.
        is_lazy = False
        is_possessive = False
        if self.pos < self.length:
            modifier = self.pattern[self.pos]
            if modifier == '?':
                is_lazy = True
                self.pos += 1
            elif modifier == '+':
                is_possessive = True
                self.pos += 1

        return QuantifierNode(
            node_type=NodeType.QUANTIFIER,
            children=[node],
            raw=self.pattern[start_pos:self.pos],
            position=start_pos,
            min_count=min_count,
            max_count=max_count,
            is_lazy=is_lazy,
            is_possessive=is_possessive,
        )

    def _read_braced_bounds(self, open_pos: int) -> Optional[tuple]:
        """Decode the ``{...}`` quantifier that opens at ``open_pos``.

        Returns:
            ``(min_count, max_count, index_after_closing_brace)``, or
            ``None`` when there is no closing brace or the contents are not
            of the form ``n``, ``n,`` or ``n,m``. Does not move ``self.pos``.
        """
        close = self.pattern.find('}', open_pos + 1)
        if close == -1:
            return None

        lower_text, comma, upper_text = self.pattern[open_pos + 1:close].partition(',')
        try:
            lower = int(lower_text)
            if not comma:
                upper = lower            # {n}: exactly n
            elif upper_text.strip():
                upper = int(upper_text)  # {n,m}
            else:
                upper = float('inf')     # {n,}: n or more
        except ValueError:
            return None
        if lower < 0:
            return None
        return lower, upper, close + 1

    def get_errors(self) -> list[str]:
        """Return the errors recorded by the most recent ``parse()`` call."""
        return self._errors
|
|
|
|
|
|
def parse_regex(pattern: str, flavor: str = "pcre") -> RegexNode:
    """Build the AST for ``pattern`` using a one-off ``RegexParser``.

    Args:
        pattern: The regular expression source text.
        flavor: Regex dialect name, passed through to ``RegexParser``.

    Returns:
        The root node produced by ``RegexParser.parse()``. Errors the
        parser recorded are not exposed by this helper.
    """
    return RegexParser(pattern, flavor).parse()
|