"""Abstract Syntax Tree nodes for regex patterns.""" from abc import ABC, abstractmethod from dataclasses import dataclass, field from typing import List, Optional, Union @dataclass class ASTNode(ABC): """Base class for all AST nodes.""" position: int = 0 @abstractmethod def to_regex(self) -> str: """Convert the node back to a regex pattern.""" pass @dataclass class Literal(ASTNode): """A literal character or string.""" escaped: bool = False value: str = "" def to_regex(self) -> str: if self.escaped or not self.value.isalnum(): escaped = "" for char in self.value: if char in r".^$*+?{}[]\|()": escaped += "\\" + char else: escaped += char return escaped return self.value @dataclass class CharacterClass(ASTNode): """A character class like [a-z] or [^A-Z].""" inverted: bool = False characters: List[str] = field(default_factory=list) ranges: List[tuple] = field(default_factory=list) def to_regex(self) -> str: content = "" special_chars = {'^', '-', ']', '\\'} for char in self.characters: if char in special_chars: content += "\\" + char else: content += char for start, end in self.ranges: content += f"{start}-{end}" negated = "^" if self.inverted else "" return f"[{negated}{content}]" @dataclass class Quantifier(ASTNode): """A quantifier like *, +, ?, {n}, {n,m}.""" MIN_UNBOUNDED = -1 MAX_UNBOUNDED = -1 min: int = 0 max: int = 1 lazy: bool = False possessive: bool = False child: Optional[ASTNode] = None def to_regex(self) -> str: base = "" if self.min == 0 and self.max == 1: base = "?" elif self.min == 0 and self.max == self.MAX_UNBOUNDED: base = "*" elif self.min == 1 and self.max == self.MAX_UNBOUNDED: base = "+" elif self.min == self.max: base = f"{{{self.min}}}" elif self.max == self.MAX_UNBOUNDED: base = f"{{{self.min},}}" else: base = f"{{{self.min},{self.max}}}" modifier = "?" if self.lazy else "" modifier = "+" if self.possessive else modifier return base + modifier @dataclass class Group(ASTNode): """A capturing or non-capturing group.""" content: List[ASTNode] = field(default_factory=list) capturing: bool = True name: Optional[str] = None def to_regex(self) -> str: inner = "".join(node.to_regex() for node in self.content) if self.name: return f"(?P<{self.name}>{inner})" elif not self.capturing: return f"(?:{inner})" return f"({inner})" @dataclass class Alternation(ASTNode): """An alternation (OR) construct.""" options: List[List[ASTNode]] = field(default_factory=list) def to_regex(self) -> str: option_strs = [] for option in self.options: option_strs.append("".join(node.to_regex() for node in option)) return "|".join(option_strs) @dataclass class Anchor(ASTNode): """An anchor like ^, $, \\b, \\B.""" kind: str = "^" def to_regex(self) -> str: return self.kind @dataclass class SpecialSequence(ASTNode): """A special sequence like \\d, \\w, \\s, etc.""" sequence: str = "\\d" def to_regex(self) -> str: return self.sequence @dataclass class Backreference(ASTNode): """A backreference like \\1 or \\k.""" reference: Union[int, str] = 1 def to_regex(self) -> str: if isinstance(self.reference, str): return f"\\k<{self.reference}>" return f"\\{self.reference}"