153 lines
3.8 KiB
Python
153 lines
3.8 KiB
Python
"""Abstract Syntax Tree nodes for regex patterns."""
|
|
|
|
from abc import ABC, abstractmethod
|
|
from dataclasses import dataclass, field
|
|
from typing import List, Optional, Union
|
|
|
|
|
|
@dataclass
class ASTNode(ABC):
    """Abstract parent shared by every regex AST node.

    Concrete node types subclass this and implement ``to_regex``.
    """

    # Integer position attached to the node; defaults to 0.
    # NOTE(review): presumably the offset of the node in the source
    # pattern -- confirm against the parser that builds these nodes.
    position: int = 0

    @abstractmethod
    def to_regex(self) -> str:
        """Render this node as regex source text.

        Returns:
            The textual regex form of the node (supplied by subclasses).
        """
        ...
|
|
|
|
|
|
@dataclass
class Literal(ASTNode):
    """A literal character or string.

    ``to_regex`` puts a backslash in front of every regex metacharacter
    found in ``value`` and leaves all other characters untouched.
    """

    # Flag stored on the node. It does not change the rendered text: a
    # purely alphanumeric value contains no metacharacters, so escaping it
    # is a no-op either way.
    escaped: bool = False
    # The literal text this node stands for.
    value: str = ""

    def to_regex(self) -> str:
        """Return ``value`` with regex metacharacters backslash-escaped.

        Returns:
            The escaped text; an empty ``value`` yields an empty string.
        """
        metacharacters = r".^$*+?{}[]\|()"
        # Build the result with a single join instead of repeated ``+=``.
        # One pass covers both the alphanumeric and the general case,
        # because alphanumeric characters are never metacharacters.
        return "".join(
            "\\" + char if char in metacharacters else char
            for char in self.value
        )
|
|
|
|
|
|
@dataclass
class CharacterClass(ASTNode):
    """A character class like [a-z] or [^A-Z]."""

    # True renders a negated class (leading ``^`` inside the brackets).
    inverted: bool = False
    # Individual members of the class, rendered in list order.
    characters: List[str] = field(default_factory=list)
    # (start, end) pairs, each rendered as ``start-end`` after the members.
    ranges: List[tuple] = field(default_factory=list)

    def to_regex(self) -> str:
        """Render the class as ``[...]`` or ``[^...]``.

        Single members and range endpoints that are exactly one of
        ``^``, ``-``, ``]`` or a backslash are backslash-escaped so they
        cannot be misread as class syntax.  Any other item (including
        multi-character strings) is emitted unchanged.

        Returns:
            The bracketed character-class text.
        """
        special_chars = {'^', '-', ']', '\\'}

        def escape(item):
            # Whole-item membership test: only a lone special character
            # is escaped; everything else passes through untouched.
            return "\\" + item if item in special_chars else item

        pieces = [escape(char) for char in self.characters]
        # Range endpoints need the same protection as single members:
        # an unescaped backslash or ``]`` endpoint would change the
        # meaning of (or break) the emitted class.
        pieces.extend(
            f"{escape(start)}-{escape(end)}" for start, end in self.ranges
        )
        content = "".join(pieces)
        negated = "^" if self.inverted else ""
        return f"[{negated}{content}]"
|
|
|
|
|
|
@dataclass
class Quantifier(ASTNode):
    """A quantifier like *, +, ?, {n}, {n,m}.

    The rendered text is the quantified ``child`` (when one is attached)
    followed by the repetition suffix and an optional lazy/possessive
    modifier.
    """

    # Sentinels kept as plain class attributes (no annotation), so the
    # dataclass machinery does not turn them into fields.
    MIN_UNBOUNDED = -1
    MAX_UNBOUNDED = -1

    # Minimum and maximum repetition counts; ``max == MAX_UNBOUNDED``
    # means "no upper limit".
    min: int = 0
    max: int = 1
    # Lazy adds a trailing ``?``; possessive adds a trailing ``+`` and
    # takes precedence when both flags are set.
    lazy: bool = False
    possessive: bool = False
    # The node being repeated, if any.
    child: Optional[ASTNode] = None

    def to_regex(self) -> str:
        """Render the quantified expression.

        Returns:
            ``child.to_regex()`` (empty when ``child`` is ``None``)
            followed by ``?``, ``*``, ``+``, ``{n}``, ``{n,}`` or
            ``{n,m}`` and the lazy/possessive modifier.  The child text
            is emitted as-is; no grouping is added around it.
        """
        unbounded = self.max == self.MAX_UNBOUNDED

        # Branch order matters: the shorthand forms are tried first,
        # then the exact count, then the open-ended and bounded ranges.
        if self.min == 0 and self.max == 1:
            base = "?"
        elif self.min == 0 and unbounded:
            base = "*"
        elif self.min == 1 and unbounded:
            base = "+"
        elif self.min == self.max:
            base = f"{{{self.min}}}"
        elif unbounded:
            base = f"{{{self.min},}}"
        else:
            base = f"{{{self.min},{self.max}}}"

        if self.possessive:
            modifier = "+"
        elif self.lazy:
            modifier = "?"
        else:
            modifier = ""

        # Include the quantified node so that it is not lost when the
        # tree is converted back to text.
        target = self.child.to_regex() if self.child is not None else ""
        return target + base + modifier
|
|
|
|
|
|
@dataclass
class Group(ASTNode):
    """A parenthesised group: named, capturing, or non-capturing."""

    # Child nodes rendered in order inside the parentheses.
    content: List[ASTNode] = field(default_factory=list)
    # False selects the ``(?:...)`` form, unless a name is set.
    capturing: bool = True
    # A truthy name selects the ``(?P<name>...)`` form and wins over
    # the ``capturing`` flag.
    name: Optional[str] = None

    def to_regex(self) -> str:
        """Render the group with the prefix that matches its kind.

        Returns:
            ``(?P<name>...)`` when ``name`` is truthy, ``(?:...)`` when
            the group is non-capturing, otherwise ``(...)``.
        """
        body = "".join(map(lambda member: member.to_regex(), self.content))

        if self.name:
            prefix = f"?P<{self.name}>"
        elif self.capturing:
            prefix = ""
        else:
            prefix = "?:"

        return f"({prefix}{body})"
|
|
|
|
|
|
@dataclass
class Alternation(ASTNode):
    """A choice between several branches, separated by ``|``."""

    # Each branch is a sequence of nodes rendered back to back.
    options: List[List[ASTNode]] = field(default_factory=list)

    def to_regex(self) -> str:
        """Render every branch and join the results with ``|``.

        Returns:
            The branches in list order separated by ``|``; an empty
            ``options`` list yields an empty string.
        """
        return "|".join(
            "".join(node.to_regex() for node in branch)
            for branch in self.options
        )
|
|
|
|
|
|
@dataclass
class Anchor(ASTNode):
    """An anchor token such as ^, $, \\b or \\B."""

    # The anchor's regex text, stored verbatim.
    kind: str = "^"

    def to_regex(self) -> str:
        """Return the stored anchor text unchanged."""
        return self.kind
|
|
|
|
|
|
@dataclass
class SpecialSequence(ASTNode):
    """An escape sequence token such as \\d, \\w or \\s."""

    # The sequence's regex text (backslash included), stored verbatim.
    sequence: str = "\\d"

    def to_regex(self) -> str:
        """Return the stored sequence text unchanged."""
        return self.sequence
|
|
|
|
|
|
@dataclass
class Backreference(ASTNode):
    """A reference to an earlier group, by number (\\1) or name (\\k<name>)."""

    # Group number (int) or group name (str) being referred to.
    reference: Union[int, str] = 1

    def to_regex(self) -> str:
        """Render the backreference.

        Returns:
            ``\\k<name>`` when ``reference`` is a string, otherwise a
            backslash followed by the reference value.
        """
        # NOTE(review): Group.to_regex emits Python-style (?P<name>...);
        # Python's ``re`` spells the matching named backreference
        # (?P=name), not \k<name> -- confirm the intended regex flavour.
        target = self.reference
        return f"\\k<{target}>" if isinstance(target, str) else f"\\{target}"
|