fix: add type annotations to parser.py
Some checks failed
CI / test (push) Failing after 11s
CI / build (push) Has been skipped

This commit is contained in:
2026-02-02 07:04:38 +00:00
parent e86a5dede4
commit 352813814d

View File

@@ -0,0 +1,336 @@
"""Parse tokens into an AST."""
import re
from typing import List, Optional
from .ast import (
ASTNode,
Alternation,
Anchor,
Backreference,
CharacterClass,
Group,
Literal,
Quantifier,
SpecialSequence,
)
from .tokenizer import Token, tokenize
class ParseError(Exception):
    """Raised when a regex pattern cannot be parsed.

    Attributes:
        message: Human-readable description of the failure.
        position: Index in the pattern where the failure occurred.
    """

    def __init__(self, message: str, position: int = 0):
        super().__init__(f"{message} at position {position}")
        self.message = message
        self.position = position
def parse_quantifier(tokens: List[Token], index: int) -> tuple[Optional[Quantifier], int]:
    """Parse a quantifier (``+``, ``*``, ``?`` or ``{m,n}``) at *index*.

    Returns the Quantifier node and the index of the first token after it,
    or ``(None, index)`` when no quantifier starts at *index*.

    Raises:
        ParseError: If a ``{...}`` quantifier is unclosed or malformed.
    """
    if index >= len(tokens):
        return None, index
    token = tokens[index]

    # +, *, ? and their lazy/possessive variants share one shape:
    # the token type is "<BASE>" or "<BASE>_<VARIANT>".
    bounds = {
        "PLUS": (1, Quantifier.MAX_UNBOUNDED),
        "STAR": (0, Quantifier.MAX_UNBOUNDED),
        "QUESTION": (0, 1),
    }
    base, _, variant = token.type.partition("_")
    if base in bounds and variant in ("", "LAZY", "POSSESSIVE"):
        lower, upper = bounds[base]
        return Quantifier(
            min=lower,
            max=upper,
            lazy=variant == "LAZY",
            possessive=variant == "POSSESSIVE",
            position=token.position,
        ), index + 1

    if token.type != "OPEN_BRACE":
        return None, index

    # {m} / {m,} / {m,n}: locate the matching close brace and join the
    # token text between the braces.
    close = next(
        (i for i in range(index + 1, len(tokens)) if tokens[i].type == "CLOSE_BRACE"),
        None,
    )
    body = "" if close is None else "".join(t.value for t in tokens[index + 1:close])
    if not body:
        raise ParseError("Invalid quantifier format", token.position)
    matched = re.match(r"^(\d+)(?:,(\d*))?$", body)
    if not matched:
        raise ParseError("Invalid quantifier format", token.position)
    lower = int(matched.group(1))
    upper_text = matched.group(2)
    upper = int(upper_text) if upper_text else Quantifier.MAX_UNBOUNDED
    next_index = close + 1
    lazy = False
    # A trailing "?" right after the braces marks the quantifier lazy.
    if next_index < len(tokens) and tokens[next_index].value == "?":
        lazy = True
        next_index += 1
    return Quantifier(min=lower, max=upper, lazy=lazy, position=token.position), next_index
def parse_character_class(tokens: List[Token], index: int) -> tuple[CharacterClass, int]:
    """Parse a bracketed character class (``[...]``) starting at *index*.

    Returns the CharacterClass node and the index of the first token after
    the closing bracket.

    Raises:
        ParseError: If *index* is not at an opening bracket or the class
            is never closed.
    """
    if index >= len(tokens) or tokens[index].type != "OPEN_BRACKET":
        raise ParseError("Expected character class", tokens[index].position if index < len(tokens) else 0)
    bracket_token = tokens[index]
    inverted = False
    characters: List[str] = []
    ranges: List[tuple[str, str]] = []
    i = index + 1
    # A leading "^" negates the class.  NOTE(review): assumes the
    # tokenizer emits "^" inside a class as a LITERAL token — confirm.
    if i < len(tokens) and tokens[i].type == "LITERAL" and tokens[i].value == "^":
        inverted = True
        i += 1

    def range_end(pos: int) -> Optional[str]:
        """Return the end character of a range ``x-y`` starting at pos, else None."""
        if pos + 2 >= len(tokens) or tokens[pos + 1].type != "MINUS":
            return None
        end_token = tokens[pos + 2]
        # A "-" immediately before "]" is a literal dash, not a range
        # (previously "[a-]" produced a bogus ('a', ']') range and then
        # read as an unclosed class).
        if end_token.type == "CLOSE_BRACKET":
            return None
        # BUG FIX: an escaped range end (e.g. [a-\]]) was detected by
        # comparing the *character* to the string "ESCAPED" instead of
        # checking the token type, so the backslash was never stripped.
        if end_token.type == "ESCAPED":
            return end_token.value[1]
        return end_token.value

    while i < len(tokens) and tokens[i].type != "CLOSE_BRACKET":
        token = tokens[i]
        if token.type == "ESCAPED":
            start = token.value[1]  # character after the backslash
            end = range_end(i)
            if end is not None:
                ranges.append((start, end))
                i += 3
            else:
                characters.append(start)
                i += 1
        elif token.type == "MINUS":
            # BUG FIX: a dash that is not part of a range (e.g. "[-a]"
            # or "[a-]") is a literal "-"; it was silently dropped.
            characters.append(token.value)
            i += 1
        elif token.type == "LITERAL":
            end = range_end(i)
            if end is not None:
                ranges.append((token.value, end))
                i += 3
            else:
                characters.append(token.value)
                i += 1
        else:
            # DIGIT and any other token type contribute their raw text
            # (same treatment the original gave both branches).
            characters.append(token.value)
            i += 1
    if i >= len(tokens):
        raise ParseError("Unclosed character class", bracket_token.position)
    return CharacterClass(
        inverted=inverted,
        characters=characters,
        ranges=ranges,
        position=bracket_token.position,
    ), i + 1
def parse_group(tokens: List[Token], index: int) -> tuple[Group, int]:
    """Parse a parenthesised group starting at *index*.

    Handles capturing, non-capturing, named and lookaround/comment
    groups.  Alternation (``|``) is supported inside every group kind —
    previously only plain capturing groups accepted it, and e.g.
    ``(?:a|b)`` raised "Unclosed non-capturing group".

    Returns the Group node and the index of the first token after the
    closing parenthesis.

    Raises:
        ParseError: If *index* is not at a group-opening token or the
            group is never closed.
    """
    if index >= len(tokens):
        raise ParseError("Expected group start", 0)
    group_token = tokens[index]
    kind = group_token.type
    if kind == "NON_CAPTURING":
        content, next_index = _parse_group_body(
            tokens, index + 1, group_token, "Unclosed non-capturing group")
        return Group(content=content, capturing=False, position=group_token.position), next_index
    if kind == "NAMED_GROUP":
        content, next_index = _parse_group_body(
            tokens, index + 1, group_token, "Unclosed named group")
        return Group(content=content, capturing=True, name=group_token.extra,
                     position=group_token.position), next_index
    if kind in ("POSITIVE_LOOKAHEAD", "NEGATIVE_LOOKAHEAD",
                "POSITIVE_LOOKBEHIND", "NEGATIVE_LOOKBEHIND",
                "COMMENT"):
        # NOTE(review): lookaround/comment kinds collapse to a plain
        # non-capturing Group here, as in the original — the kind is lost.
        content, next_index = _parse_group_body(
            tokens, index + 1, group_token, "Unclosed group")
        return Group(content=content, capturing=False, position=group_token.position), next_index
    if kind == "OPEN_GROUP":
        if index + 1 >= len(tokens):
            raise ParseError("Empty group", group_token.position)
        content, next_index = _parse_group_body(
            tokens, index + 1, group_token, "Unclosed group")
        return Group(content=content, capturing=True, position=group_token.position), next_index
    raise ParseError("Expected group start", group_token.position)


def _parse_group_body(tokens: List[Token], index: int, group_token: Token,
                      unclosed_message: str) -> tuple[List[ASTNode], int]:
    """Parse a group body up to its CLOSE_GROUP, honouring alternation.

    Returns the content nodes (a single Alternation node when the body
    contains ``|``) and the index just past the closing parenthesis.

    Raises:
        ParseError: *unclosed_message* if the body ends without a
            CLOSE_GROUP, or "Unexpected token in group" when no progress
            can be made on a stray token.
    """
    options: List[List[ASTNode]] = [[]]
    alternation_position: Optional[int] = None
    i = index
    while i < len(tokens):
        token = tokens[i]
        if token.type == "CLOSE_GROUP":
            if len(options) > 1:
                alternation = Alternation(options=options, position=alternation_position)
                return [alternation], i + 1
            return options[0], i + 1
        if token.type == "ALTERNATION":
            if alternation_position is None:
                alternation_position = token.position  # first "|" marks the node
            options.append([])
            i += 1
            continue
        nodes, next_i = parse_sequence(tokens, i)
        if next_i == i:
            # parse_sequence made no progress (stray token such as "]");
            # bail out rather than loop forever.
            raise ParseError("Unexpected token in group", token.position)
        options[-1].extend(nodes)
        i = next_i
    raise ParseError(unclosed_message, group_token.position)
def parse_sequence(tokens: List[Token], index: int) -> tuple[List[ASTNode], int]:
    """Parse a sequence of atoms until ``)``, ``]``, ``|`` or end of input.

    After each atom a quantifier is looked for and, when present, wraps
    that atom alone; for a run of literal characters it binds only to
    the last character (``ab+`` quantifies ``b``).

    Returns the parsed nodes and the index of the first unconsumed token.

    Raises:
        ParseError: Propagated from the sub-parsers (bad quantifier,
            unclosed class or group).
    """
    group_starts = (
        "OPEN_GROUP", "NON_CAPTURING", "NAMED_GROUP",
        "POSITIVE_LOOKAHEAD", "NEGATIVE_LOOKAHEAD",
        "POSITIVE_LOOKBEHIND", "NEGATIVE_LOOKBEHIND", "COMMENT",
    )
    quantifier_starts = (
        "PLUS", "PLUS_LAZY", "PLUS_POSSESSIVE",
        "STAR", "STAR_LAZY", "STAR_POSSESSIVE",
        "QUESTION", "QUESTION_LAZY", "QUESTION_POSSESSIVE",
        "OPEN_BRACE",
    )
    anchors = {"ANCHOR_START": "^", "ANCHOR_END": "$",
               "WORD_BOUNDARY": r"\b", "NON_WORD_BOUNDARY": r"\B"}
    special_sequences = ("DIGIT", "NON_DIGIT", "WHITESPACE", "NON_WHITESPACE",
                         "WORD_CHAR", "NON_WORD_CHAR")
    nodes: List[ASTNode] = []
    i = index
    while i < len(tokens):
        token = tokens[i]
        if token.type in ("CLOSE_GROUP", "CLOSE_BRACKET", "ALTERNATION"):
            break
        if token.type in anchors:
            nodes.append(Anchor(kind=anchors[token.type], position=token.position))
            i += 1
        elif token.type in special_sequences:
            nodes.append(SpecialSequence(sequence=token.value, position=token.position))
            i += 1
        elif token.type == "ANY_CHAR":
            nodes.append(SpecialSequence(sequence=".", position=token.position))
            i += 1
        elif token.type == "OPEN_BRACKET":
            char_class, i = parse_character_class(tokens, i)
            nodes.append(char_class)
        elif token.type in group_starts:
            # BUG FIX: named groups, lookarounds and comments previously
            # fell through to the literal fallback; route every group
            # opener to parse_group (which handles all these kinds).
            group, i = parse_group(tokens, i)
            nodes.append(group)
        elif token.type == "BACKREFERENCE":
            reference = int(token.extra) if token.extra else 1
            nodes.append(Backreference(reference=reference, position=token.position))
            i += 1
        elif token.type == "NAMED_BACKREFERENCE":
            nodes.append(Backreference(reference=token.extra or "", position=token.position))
            i += 1
        elif token.type == "ESCAPED":
            # token.value is "\x"; the payload is the char after the backslash.
            nodes.append(Literal(value=token.value[1], escaped=True, position=token.position))
            i += 1
        elif token.type == "LITERAL":
            run = [token]
            i += 1
            while i < len(tokens) and tokens[i].type == "LITERAL":
                run.append(tokens[i])
                i += 1
            if len(run) > 1 and i < len(tokens) and tokens[i].type in quantifier_starts:
                # A following quantifier binds only to the last character:
                # split it out of the coalesced literal run.
                nodes.append(Literal(value="".join(t.value for t in run[:-1]),
                                     escaped=False, position=run[0].position))
                run = run[-1:]
            nodes.append(Literal(value="".join(t.value for t in run),
                                 escaped=False, position=run[0].position))
        else:
            # Unknown token types degrade to literal text (original behavior).
            nodes.append(Literal(value=token.value, position=token.position))
            i += 1
        # BUG FIX: quantifiers were previously only looked for after the
        # loop ended, where i is past the end or on a break token, so
        # they were never applied ("a+" parsed "+" as a literal).
        if nodes and i < len(tokens) and tokens[i].type in quantifier_starts:
            try:
                quantifier, i = parse_quantifier(tokens, i)
            except ParseError:
                if tokens[i].type == "OPEN_BRACE":
                    # "{...}" that is not a valid quantifier is literal
                    # text; let the next iteration consume it.
                    quantifier = None
                else:
                    raise
            if quantifier is not None:
                nodes[-1] = quantifier_wrap(nodes[-1], quantifier)
    return nodes, i
def quantifier_wrap(node: ASTNode, quantifier: Quantifier) -> Quantifier:
    """Attach *node* as the child of *quantifier* and return the quantifier.

    The quantifier is mutated in place; the object returned is the same
    instance that was passed in.
    """
    quantifier.child = node
    return quantifier
def parse_alternation(tokens: List[Token], index: int) -> tuple[Alternation, int]:
    """Parse ``|``-separated branches, starting at the first branch's content.

    Stops at a CLOSE_GROUP token (left unconsumed) or at end of input.
    Returns the Alternation node and the index of the first unconsumed
    token.

    Raises:
        ParseError: When no progress can be made on a stray token.
    """
    options: List[List[ASTNode]] = []
    current_option: List[ASTNode] = []
    # Guard the position lookup so an empty token list no longer raises
    # IndexError.
    position = tokens[index].position if index < len(tokens) else 0
    i = index
    while i < len(tokens):
        token = tokens[i]
        if token.type == "ALTERNATION":
            options.append(current_option)
            current_option = []
            i += 1
        elif token.type == "CLOSE_GROUP":
            # BUG FIX: an empty final branch (e.g. "a|" before ")") was
            # silently dropped; append unconditionally.
            options.append(current_option)
            return Alternation(options=options, position=position), i
        else:
            branch_nodes, next_i = parse_sequence(tokens, i)
            if next_i == i:
                # parse_sequence refused to consume this token; bail out
                # instead of looping forever.
                raise ParseError("Unexpected token in alternation", token.position)
            current_option.extend(branch_nodes)
            i = next_i
    # BUG FIX: same trailing-empty-branch fix for input that ends without
    # a closing parenthesis ("a|" at top level).
    options.append(current_option)
    return Alternation(options=options, position=position), i
def parse_regex(pattern: str) -> List[ASTNode]:
    """Parse *pattern* into a list of AST nodes.

    A top-level alternation yields a single Alternation node whose first
    branch is the sequence parsed before the first ``|``.

    Raises:
        ParseError: If tokens remain that cannot be parsed, plus any
            error propagated from the sub-parsers.
    """
    tokens = tokenize(pattern)
    nodes, index = parse_sequence(tokens, 0)
    if index < len(tokens) and tokens[index].type == "ALTERNATION":
        # BUG FIX: the nodes parsed before the first "|" were discarded
        # (parse_alternation was called *at* the "|", prepending a bogus
        # empty branch); build the alternation here with `nodes` as the
        # first branch instead.
        alternation_position = tokens[index].position
        options: List[List[ASTNode]] = [nodes]
        while index < len(tokens) and tokens[index].type == "ALTERNATION":
            index += 1  # consume the "|"
            branch, index = parse_sequence(tokens, index)
            options.append(branch)
        if index < len(tokens):
            remaining = "".join(t.value for t in tokens[index:])
            raise ParseError(f"Unexpected token at position {index}: {remaining!r}", tokens[index].position)
        return [Alternation(options=options, position=alternation_position)]
    if index < len(tokens):
        remaining = "".join(t.value for t in tokens[index:])
        raise ParseError(f"Unexpected token at position {index}: {remaining!r}", tokens[index].position)
    return nodes