"""Parse tokens into an AST.""" import re from typing import List, Optional from .ast import ( ASTNode, Alternation, Anchor, Backreference, CharacterClass, Group, Literal, Quantifier, SpecialSequence, ) from .tokenizer import Token, tokenize class ParseError(Exception): """Exception raised when parsing fails.""" def __init__(self, message: str, position: int = 0): self.message = message self.position = position super().__init__(f"{message} at position {position}") def parse_quantifier(tokens: List[Token], index: int) -> tuple[Optional[Quantifier], int]: """Parse a quantifier from tokens starting at index.""" if index >= len(tokens): return None, index token = tokens[index] min_count = 0 max_count = Quantifier.MAX_UNBOUNDED lazy = False possessive = False if token.type in ("PLUS", "PLUS_LAZY", "PLUS_POSSESSIVE"): min_count = 1 max_count = Quantifier.MAX_UNBOUNDED lazy = token.type == "PLUS_LAZY" possessive = token.type == "PLUS_POSSESSIVE" return Quantifier(min=min_count, max=max_count, lazy=lazy, possessive=possessive, position=token.position), index + 1 elif token.type in ("STAR", "STAR_LAZY", "STAR_POSSESSIVE"): min_count = 0 max_count = Quantifier.MAX_UNBOUNDED lazy = token.type == "STAR_LAZY" possessive = token.type == "STAR_POSSESSIVE" return Quantifier(min=min_count, max=max_count, lazy=lazy, possessive=possessive, position=token.position), index + 1 elif token.type in ("QUESTION", "QUESTION_LAZY", "QUESTION_POSSESSIVE"): min_count = 0 max_count = 1 lazy = token.type == "QUESTION_LAZY" possessive = token.type == "QUESTION_POSSESSIVE" return Quantifier(min=min_count, max=max_count, lazy=lazy, possessive=possessive, position=token.position), index + 1 elif token.type == "OPEN_BRACE": brace_content = "" brace_end = index for i in range(index + 1, len(tokens)): if tokens[i].type == "CLOSE_BRACE": brace_end = i brace_content = "".join(t.value for t in tokens[index + 1:i]) break if not brace_content: raise ParseError("Invalid quantifier format", tokens[index].position) brace_match = re.match(r"^(\d+)(?:,(\d*))?$", brace_content) if not brace_match: raise ParseError("Invalid quantifier format", tokens[index].position) min_count = int(brace_match.group(1)) max_count_str = brace_match.group(2) max_count = int(max_count_str) if max_count_str else Quantifier.MAX_UNBOUNDED next_index = brace_end + 1 if next_index < len(tokens) and tokens[next_index].value == "?": lazy = True next_index += 1 return Quantifier(min=min_count, max=max_count, lazy=lazy, position=tokens[index].position), next_index return None, index def parse_character_class(tokens: List[Token], index: int) -> tuple[CharacterClass, int]: """Parse a character class from tokens starting at index.""" if index >= len(tokens) or tokens[index].type != "OPEN_BRACKET": raise ParseError("Expected character class", tokens[index].position if index < len(tokens) else 0) bracket_token = tokens[index] inverted = False characters = [] ranges = [] i = index + 1 if i < len(tokens) and tokens[i].type == "LITERAL" and tokens[i].value == "^": inverted = True i += 1 while i < len(tokens) and tokens[i].type != "CLOSE_BRACKET": token = tokens[i] if token.type == "ESCAPED": char = token.value[1] if i + 2 < len(tokens) and tokens[i + 1].type == "MINUS": end_char = tokens[i + 2].value if end_char == "ESCAPED": end_char = end_char[1] ranges.append((char, end_char)) i += 3 else: characters.append(char) i += 1 elif token.type == "MINUS": i += 1 elif token.type == "DIGIT": characters.append(token.value) i += 1 elif token.type == "LITERAL": if i + 2 < len(tokens) and tokens[i + 1].type == "MINUS": end_char = tokens[i + 2].value ranges.append((token.value, end_char)) i += 3 else: characters.append(token.value) i += 1 else: characters.append(token.value) i += 1 if i >= len(tokens): raise ParseError("Unclosed character class", bracket_token.position) return CharacterClass( inverted=inverted, characters=characters, ranges=ranges, position=bracket_token.position ), i + 1 def parse_group(tokens: List[Token], index: int) -> tuple[Group, int]: """Parse a group from tokens starting at index.""" if index >= len(tokens): raise ParseError("Expected group start", 0) group_token = tokens[index] if tokens[index].type == "NON_CAPTURING": content, next_index = parse_sequence(tokens, index + 1) if next_index >= len(tokens) or tokens[next_index].type != "CLOSE_GROUP": raise ParseError("Unclosed non-capturing group", group_token.position) next_index += 1 return Group(content=content, capturing=False, position=group_token.position), next_index if tokens[index].type == "NAMED_GROUP": name = tokens[index].extra content, next_index = parse_sequence(tokens, index + 1) if next_index >= len(tokens) or tokens[next_index].type != "CLOSE_GROUP": raise ParseError("Unclosed named group", group_token.position) next_index += 1 return Group(content=content, capturing=True, name=name, position=group_token.position), next_index if tokens[index].type in ("POSITIVE_LOOKAHEAD", "NEGATIVE_LOOKAHEAD", "POSITIVE_LOOKBEHIND", "NEGATIVE_LOOKBEHIND", "COMMENT"): content, next_index = parse_sequence(tokens, index + 1) if next_index >= len(tokens) or tokens[next_index].type != "CLOSE_GROUP": raise ParseError("Unclosed group", group_token.position) next_index += 1 return Group(content=content, capturing=False, position=group_token.position), next_index if tokens[index].type == "OPEN_GROUP": i = index + 1 if i >= len(tokens): raise ParseError("Empty group", group_token.position) options: List[List[ASTNode]] = [] current_option: List[ASTNode] = [] first_alternation_index: Optional[int] = None while i < len(tokens): token = tokens[i] if token.type == "ALTERNATION": options.append(current_option) current_option = [] first_alternation_index = i i += 1 elif token.type == "CLOSE_GROUP": if current_option or first_alternation_index is not None: options.append(current_option) if len(options) > 1: alternation = Alternation(options=options, position=tokens[first_alternation_index].position) # type: ignore[index] return Group(content=[alternation], capturing=True, position=group_token.position), i + 1 else: return Group(content=current_option, capturing=True, position=group_token.position), i + 1 else: nodes, next_i = parse_sequence(tokens, i) current_option.extend(nodes) i = next_i raise ParseError("Unclosed group", group_token.position) raise ParseError("Expected group start", tokens[index].position if index < len(tokens) else 0) def parse_sequence(tokens: List[Token], index: int) -> tuple[List[ASTNode], int]: """Parse a sequence of tokens until end of group or pattern.""" nodes: List[ASTNode] = [] i = index while i < len(tokens): token = tokens[i] if token.type in ("CLOSE_GROUP", "CLOSE_BRACKET", "ALTERNATION"): break if token.type == "ANCHOR_START": nodes.append(Anchor(kind="^", position=token.position)) i += 1 elif token.type == "ANCHOR_END": nodes.append(Anchor(kind="$", position=token.position)) i += 1 elif token.type == "WORD_BOUNDARY": nodes.append(Anchor(kind=r"\b", position=token.position)) i += 1 elif token.type == "NON_WORD_BOUNDARY": nodes.append(Anchor(kind=r"\B", position=token.position)) i += 1 elif token.type in ("DIGIT", "NON_DIGIT", "WHITESPACE", "NON_WHITESPACE", "WORD_CHAR", "NON_WORD_CHAR"): nodes.append(SpecialSequence(sequence=token.value, position=token.position)) i += 1 elif token.type == "ANY_CHAR": nodes.append(SpecialSequence(sequence=".", position=token.position)) i += 1 elif token.type == "OPEN_BRACKET": char_class, next_i = parse_character_class(tokens, i) nodes.append(char_class) i = next_i elif token.type == "OPEN_GROUP": group, next_i = parse_group(tokens, i) nodes.append(group) i = next_i elif token.type == "NON_CAPTURING": group, next_i = parse_group(tokens, i) nodes.append(group) i = next_i elif token.type == "BACKREFERENCE": ref = int(token.extra) if token.extra else 1 nodes.append(Backreference(reference=ref, position=token.position)) i += 1 elif token.type == "NAMED_BACKREFERENCE": nodes.append(Backreference(reference=token.extra or "", position=token.position)) i += 1 elif token.type == "ESCAPED": char = token.value[1] nodes.append(Literal(value=char, escaped=True, position=token.position)) i += 1 elif token.type == "LITERAL": literal_value = token.value literal_position = token.position i += 1 while i < len(tokens) and tokens[i].type == "LITERAL": literal_value += tokens[i].value i += 1 nodes.append(Literal(value=literal_value, escaped=False, position=literal_position)) elif token.type == "ALTERNATION": break else: nodes.append(Literal(value=token.value, position=token.position)) i += 1 if i < len(tokens): quant_node, next_i = parse_quantifier(tokens, i) if quant_node and nodes: nodes[-1] = quantifier_wrap(nodes[-1], quant_node) i = next_i return nodes, i def quantifier_wrap(node: ASTNode, quantifier: Quantifier) -> Quantifier: """Wrap a node with a quantifier.""" quantifier.child = node return quantifier def parse_alternation(tokens: List[Token], index: int) -> tuple[Alternation, int]: """Parse an alternation from tokens.""" options: List[List[ASTNode]] = [] current_option: List[ASTNode] = [] i = index while i < len(tokens): token = tokens[i] if token.type == "ALTERNATION": options.append(current_option) current_option = [] i += 1 elif token.type == "CLOSE_GROUP": if current_option: options.append(current_option) alternation = Alternation(options=options, position=tokens[index].position) return alternation, i else: node, next_i = parse_sequence(tokens, i) current_option.extend(node) i = next_i if current_option: options.append(current_option) return Alternation(options=options, position=tokens[index].position), i def parse_regex(pattern: str) -> List[ASTNode]: """Parse a regex pattern into an AST.""" tokens = tokenize(pattern) nodes, index = parse_sequence(tokens, 0) if index < len(tokens) and tokens[index].type == "ALTERNATION": alternation, next_index = parse_alternation(tokens, index) return [alternation] if index < len(tokens): remaining = "".join(t.value for t in tokens[index:]) raise ParseError(f"Unexpected token at position {index}: {remaining!r}", tokens[index].position) return nodes