fix: add type annotations to parser.py
This commit is contained in:
336
regex_humanizer/parser/parser.py
Normal file
336
regex_humanizer/parser/parser.py
Normal file
@@ -0,0 +1,336 @@
|
||||
"""Parse tokens into an AST."""
|
||||
|
||||
import re
|
||||
from typing import List, Optional
|
||||
|
||||
from .ast import (
|
||||
ASTNode,
|
||||
Alternation,
|
||||
Anchor,
|
||||
Backreference,
|
||||
CharacterClass,
|
||||
Group,
|
||||
Literal,
|
||||
Quantifier,
|
||||
SpecialSequence,
|
||||
)
|
||||
from .tokenizer import Token, tokenize
|
||||
|
||||
|
||||
class ParseError(Exception):
    """Raised when a regex pattern cannot be parsed.

    Attributes:
        message: Human-readable description of the failure.
        position: Token/pattern offset where the failure occurred.
    """

    def __init__(self, message: str, position: int = 0):
        # Build the combined text first so Exception.args carries it.
        super().__init__(f"{message} at position {position}")
        self.message = message
        self.position = position
|
||||
|
||||
|
||||
def parse_quantifier(tokens: List[Token], index: int) -> tuple[Optional[Quantifier], int]:
    """Parse a quantifier (``+``, ``*``, ``?`` or ``{m,n}``) at *index*.

    Returns ``(quantifier, next_index)``, or ``(None, index)`` when the
    token at *index* does not start a quantifier.

    Raises:
        ParseError: for a malformed ``{...}`` quantifier.
    """
    if index >= len(tokens):
        return None, index

    token = tokens[index]

    # +, * and ? share one shape: a base type with an optional
    # _LAZY / _POSSESSIVE suffix.  Table-driven instead of three
    # near-identical branches.
    simple_bounds = {
        "PLUS": (1, Quantifier.MAX_UNBOUNDED),
        "STAR": (0, Quantifier.MAX_UNBOUNDED),
        "QUESTION": (0, 1),
    }
    base, _, modifier = token.type.partition("_")
    if base in simple_bounds and modifier in ("", "LAZY", "POSSESSIVE"):
        min_count, max_count = simple_bounds[base]
        return Quantifier(
            min=min_count,
            max=max_count,
            lazy=modifier == "LAZY",
            possessive=modifier == "POSSESSIVE",
            position=token.position,
        ), index + 1

    if token.type == "OPEN_BRACE":
        # Locate the matching CLOSE_BRACE and join the text between.
        close_index = next(
            (j for j in range(index + 1, len(tokens))
             if tokens[j].type == "CLOSE_BRACE"),
            None,
        )
        if close_index is None:
            raise ParseError("Invalid quantifier format", token.position)
        brace_content = "".join(t.value for t in tokens[index + 1:close_index])

        brace_match = re.match(r"^(\d+)(?:,(\d*))?$", brace_content)
        if not brace_match:
            raise ParseError("Invalid quantifier format", token.position)

        min_count = int(brace_match.group(1))
        max_part = brace_match.group(2)
        if max_part is None:
            # BUG FIX: "{m}" means exactly m matches.  The old code used a
            # truthiness test, so a missing ",n" part (group(2) is None)
            # was treated the same as "{m,}" (unbounded).
            max_count = min_count
        elif max_part == "":
            max_count = Quantifier.MAX_UNBOUNDED  # "{m,}"
        else:
            max_count = int(max_part)

        lazy = False
        possessive = False
        next_index = close_index + 1
        if next_index < len(tokens):
            if tokens[next_index].value == "?":
                lazy = True  # "{m,n}?"
                next_index += 1
            elif tokens[next_index].value == "+":
                # "{m,n}+" possessive suffix; previously left unconsumed,
                # so a stray "+" leaked into the surrounding sequence.
                possessive = True
                next_index += 1

        return Quantifier(
            min=min_count,
            max=max_count,
            lazy=lazy,
            possessive=possessive,
            position=token.position,
        ), next_index

    return None, index
|
||||
|
||||
|
||||
def parse_character_class(tokens: List[Token], index: int) -> tuple[CharacterClass, int]:
    """Parse a character class (``[...]``) starting at *index*.

    Returns ``(node, index just past the closing bracket)``.

    Raises:
        ParseError: if tokens[index] is not OPEN_BRACKET or the class is
            never closed.
    """
    if index >= len(tokens) or tokens[index].type != "OPEN_BRACKET":
        raise ParseError("Expected character class",
                         tokens[index].position if index < len(tokens) else 0)

    bracket_token = tokens[index]
    inverted = False
    characters: List[str] = []
    ranges: List[tuple[str, str]] = []

    i = index + 1

    # A leading '^' inverts the class.
    if i < len(tokens) and tokens[i].type == "LITERAL" and tokens[i].value == "^":
        inverted = True
        i += 1

    def _range_end(pos: int) -> Optional[str]:
        """Return the end character of a range 'X-Y' starting at *pos*,
        or None when the next two tokens do not form a range.  The end
        may be escaped; it must not be the closing bracket, so '[a-]'
        keeps '-' literal instead of swallowing the ']'."""
        if pos + 2 >= len(tokens) or tokens[pos + 1].type != "MINUS":
            return None
        end_token = tokens[pos + 2]
        if end_token.type == "CLOSE_BRACKET":
            return None
        # BUG FIX: the old code compared the end token's *value* to the
        # string "ESCAPED" (almost never true) instead of its type, so
        # an escaped range end kept its backslash.
        return end_token.value[1] if end_token.type == "ESCAPED" else end_token.value

    while i < len(tokens) and tokens[i].type != "CLOSE_BRACKET":
        token = tokens[i]

        if token.type == "MINUS":
            # A '-' that is not part of a range is skipped.
            i += 1
            continue

        start_char = token.value[1] if token.type == "ESCAPED" else token.value
        end_char = _range_end(i) if token.type in ("ESCAPED", "LITERAL") else None
        if end_char is not None:
            ranges.append((start_char, end_char))
            i += 3  # start, '-', end
        else:
            characters.append(start_char)
            i += 1

    if i >= len(tokens):
        raise ParseError("Unclosed character class", bracket_token.position)

    return CharacterClass(
        inverted=inverted,
        characters=characters,
        ranges=ranges,
        position=bracket_token.position,
    ), i + 1
|
||||
|
||||
|
||||
def _parse_group_body(
    tokens: List[Token],
    start: int,
    open_position: int,
    label: str = "group",
) -> tuple[List[ASTNode], int]:
    """Parse group contents up to the matching CLOSE_GROUP.

    Handles '|' alternation inside the group.  Returns
    ``(content_nodes, index just past the closing parenthesis)``.

    Raises:
        ParseError: if the group is never closed, or the parser cannot
            make progress on a token.
    """
    options: List[List[ASTNode]] = []
    current_option: List[ASTNode] = []
    first_alternation_index: Optional[int] = None
    i = start

    while i < len(tokens):
        token = tokens[i]
        if token.type == "ALTERNATION":
            options.append(current_option)
            current_option = []
            if first_alternation_index is None:
                first_alternation_index = i
            i += 1
        elif token.type == "CLOSE_GROUP":
            if current_option or first_alternation_index is not None:
                options.append(current_option)
            if len(options) > 1:
                # first_alternation_index is always set once len(options) > 1.
                alternation = Alternation(
                    options=options,
                    position=tokens[first_alternation_index].position,  # type: ignore[index]
                )
                return [alternation], i + 1
            return current_option, i + 1
        else:
            nodes, next_i = parse_sequence(tokens, i)
            if next_i == i:
                # parse_sequence consumed nothing (e.g. a stray ']'):
                # raise instead of looping forever.
                raise ParseError("Unexpected token in group", token.position)
            current_option.extend(nodes)
            i = next_i

    raise ParseError(f"Unclosed {label}", open_position)


def parse_group(tokens: List[Token], index: int) -> tuple[Group, int]:
    """Parse a group starting at *index*.

    Supports capturing, non-capturing, named and lookaround/comment
    groups; all of them may contain '|' alternation.

    Returns ``(group, index just past the closing parenthesis)``.

    Raises:
        ParseError: if tokens[index] does not start a group or the group
            is never closed.
    """
    if index >= len(tokens):
        raise ParseError("Expected group start", 0)

    group_token = tokens[index]
    kind = group_token.type

    if kind == "NON_CAPTURING":
        content, next_index = _parse_group_body(
            tokens, index + 1, group_token.position, "non-capturing group")
        return Group(content=content, capturing=False,
                     position=group_token.position), next_index

    if kind == "NAMED_GROUP":
        content, next_index = _parse_group_body(
            tokens, index + 1, group_token.position, "named group")
        return Group(content=content, capturing=True, name=group_token.extra,
                     position=group_token.position), next_index

    if kind in ("POSITIVE_LOOKAHEAD", "NEGATIVE_LOOKAHEAD",
                "POSITIVE_LOOKBEHIND", "NEGATIVE_LOOKBEHIND", "COMMENT"):
        # NOTE(review): lookarounds/comments are flattened into plain
        # non-capturing groups, losing their zero-width semantics; a
        # dedicated AST node would be needed to preserve them.
        content, next_index = _parse_group_body(
            tokens, index + 1, group_token.position)
        return Group(content=content, capturing=False,
                     position=group_token.position), next_index

    if kind == "OPEN_GROUP":
        content, next_index = _parse_group_body(
            tokens, index + 1, group_token.position)
        return Group(content=content, capturing=True,
                     position=group_token.position), next_index

    raise ParseError("Expected group start", group_token.position)
|
||||
|
||||
|
||||
def parse_sequence(tokens: List[Token], index: int) -> tuple[List[ASTNode], int]:
    """Parse a sequence of tokens until end of group, alternation, or pattern.

    Returns ``(nodes, next_index)`` where next_index points at the token
    that ended the sequence (CLOSE_GROUP / CLOSE_BRACKET / ALTERNATION)
    or at len(tokens).
    """
    # Token types that can begin a quantifier; used to decide how a run
    # of adjacent literals must be split.
    quantifier_starts = (
        "PLUS", "PLUS_LAZY", "PLUS_POSSESSIVE",
        "STAR", "STAR_LAZY", "STAR_POSSESSIVE",
        "QUESTION", "QUESTION_LAZY", "QUESTION_POSSESSIVE",
        "OPEN_BRACE",
    )
    nodes: List[ASTNode] = []
    i = index

    while i < len(tokens):
        token = tokens[i]

        # These end the current sequence; the caller handles them.
        if token.type in ("CLOSE_GROUP", "CLOSE_BRACKET", "ALTERNATION"):
            break

        if token.type == "ANCHOR_START":
            nodes.append(Anchor(kind="^", position=token.position))
            i += 1
        elif token.type == "ANCHOR_END":
            nodes.append(Anchor(kind="$", position=token.position))
            i += 1
        elif token.type == "WORD_BOUNDARY":
            nodes.append(Anchor(kind=r"\b", position=token.position))
            i += 1
        elif token.type == "NON_WORD_BOUNDARY":
            nodes.append(Anchor(kind=r"\B", position=token.position))
            i += 1
        elif token.type in ("DIGIT", "NON_DIGIT", "WHITESPACE", "NON_WHITESPACE",
                            "WORD_CHAR", "NON_WORD_CHAR"):
            nodes.append(SpecialSequence(sequence=token.value, position=token.position))
            i += 1
        elif token.type == "ANY_CHAR":
            nodes.append(SpecialSequence(sequence=".", position=token.position))
            i += 1
        elif token.type == "OPEN_BRACKET":
            char_class, next_i = parse_character_class(tokens, i)
            nodes.append(char_class)
            i = next_i
        elif token.type in ("OPEN_GROUP", "NON_CAPTURING", "NAMED_GROUP",
                            "POSITIVE_LOOKAHEAD", "NEGATIVE_LOOKAHEAD",
                            "POSITIVE_LOOKBEHIND", "NEGATIVE_LOOKBEHIND",
                            "COMMENT"):
            # BUG FIX: only OPEN_GROUP/NON_CAPTURING used to reach
            # parse_group; named groups and lookarounds fell through to
            # the fallback branch and were mis-parsed as literals, even
            # though parse_group already handles these token types.
            group, next_i = parse_group(tokens, i)
            nodes.append(group)
            i = next_i
        elif token.type == "BACKREFERENCE":
            ref = int(token.extra) if token.extra else 1
            nodes.append(Backreference(reference=ref, position=token.position))
            i += 1
        elif token.type == "NAMED_BACKREFERENCE":
            nodes.append(Backreference(reference=token.extra or "", position=token.position))
            i += 1
        elif token.type == "ESCAPED":
            char = token.value[1]
            nodes.append(Literal(value=char, escaped=True, position=token.position))
            i += 1
        elif token.type == "LITERAL":
            # Merge a run of adjacent literals into one node for readability.
            run_start = i
            i += 1
            while i < len(tokens) and tokens[i].type == "LITERAL":
                i += 1
            run = tokens[run_start:i]
            if (len(run) > 1 and i < len(tokens)
                    and tokens[i].type in quantifier_starts):
                # BUG FIX: a quantifier binds only the LAST literal
                # ("ab+" is 'a' then 'b+'), so keep that one separate
                # instead of quantifying the whole merged run.
                head = "".join(t.value for t in run[:-1])
                nodes.append(Literal(value=head, escaped=False,
                                     position=run[0].position))
                nodes.append(Literal(value=run[-1].value, escaped=False,
                                     position=run[-1].position))
            else:
                nodes.append(Literal(value="".join(t.value for t in run),
                                     escaped=False, position=run[0].position))
        else:
            # Unknown token: preserve it as a literal (original fallback).
            nodes.append(Literal(value=token.value, position=token.position))
            i += 1

        # After each atom, check for a trailing quantifier and wrap the
        # most recent node with it.
        if i < len(tokens):
            quant_node, next_i = parse_quantifier(tokens, i)
            if quant_node and nodes:
                nodes[-1] = quantifier_wrap(nodes[-1], quant_node)
                i = next_i

    return nodes, i
|
||||
|
||||
|
||||
def quantifier_wrap(node: ASTNode, quantifier: Quantifier) -> Quantifier:
    """Wrap a node with a quantifier.

    Sets *node* as the quantifier's child and returns the (mutated)
    *quantifier* instance, which replaces the node in the AST.
    """
    quantifier.child = node
    return quantifier
|
||||
|
||||
|
||||
def parse_alternation(tokens: List[Token], index: int) -> tuple[Alternation, int]:
    """Parse an alternation (``a|b|...``) starting at *index*.

    Returns ``(alternation, next_index)``.  A trailing empty branch
    ("a|") is preserved as an empty option instead of being dropped.

    Raises:
        ParseError: if the parser cannot make progress on a token.
    """
    options: List[List[ASTNode]] = []
    current_option: List[ASTNode] = []
    i = index

    while i < len(tokens):
        token = tokens[i]

        if token.type == "ALTERNATION":
            options.append(current_option)
            current_option = []
            i += 1
        elif token.type == "CLOSE_GROUP":
            # A ')' ends the alternation at this level; leave it for the
            # caller to judge (do not consume it).
            break
        else:
            nodes, next_i = parse_sequence(tokens, i)
            if next_i == i:
                # parse_sequence consumed nothing (e.g. a stray ']'):
                # raise instead of looping forever.
                raise ParseError("Unexpected token in alternation", token.position)
            current_option.extend(nodes)
            i = next_i

    # BUG FIX: always keep the final option, even when it is empty, so
    # "a|" yields two branches ([a] and []).  The old truthiness test
    # silently dropped the empty branch.
    options.append(current_option)

    position = tokens[index].position if index < len(tokens) else 0
    return Alternation(options=options, position=position), i
|
||||
|
||||
|
||||
def parse_regex(pattern: str) -> List[ASTNode]:
    """Parse a regex *pattern* string into a list of AST nodes.

    Raises:
        ParseError: if the pattern cannot be fully parsed (including
            trailing tokens after a top-level alternation).
    """
    tokens = tokenize(pattern)
    nodes, index = parse_sequence(tokens, 0)

    if index < len(tokens) and tokens[index].type == "ALTERNATION":
        # BUG FIX: the already-parsed first branch used to be discarded:
        # parse_alternation started *at* the '|' and immediately appended
        # an empty current option, so "a|b" parsed as (|b).  Re-parse
        # from the start so every branch is represented.
        alternation, index = parse_alternation(tokens, 0)
        nodes = [alternation]

    if index < len(tokens):
        remaining = "".join(t.value for t in tokens[index:])
        raise ParseError(
            f"Unexpected token at position {index}: {remaining!r}",
            tokens[index].position,
        )

    return nodes
|
||||
Reference in New Issue
Block a user