"""Tests for the regex parser.""" import pytest from regex_humanizer.parser import ( RegexParser, parse_regex, NodeType, RegexNode, LiteralNode, CharacterClassNode, QuantifierNode, GroupNode, ) class TestRegexParser: """Test cases for RegexParser.""" def test_parse_simple_literal(self): """Test parsing a simple literal string.""" parser = RegexParser("hello") ast = parser.parse() assert ast.node_type == NodeType.SEQUENCE assert len(ast.children) == 1 assert ast.children[0].node_type == NodeType.LITERAL assert ast.children[0].value == "hello" def test_parse_digit_shorthand(self): """Test parsing digit shorthand character class.""" parser = RegexParser("\\d+") ast = parser.parse() assert ast.node_type == NodeType.SEQUENCE assert len(ast.children) == 1 quantifier = ast.children[0] assert quantifier.node_type == NodeType.QUANTIFIER assert quantifier.min_count == 1 def test_parse_character_class(self): """Test parsing a character class.""" parser = RegexParser("[a-z]") ast = parser.parse() assert ast.node_type == NodeType.SEQUENCE assert len(ast.children) == 1 child = ast.children[0] assert child.node_type == NodeType.POSITIVE_SET assert "a" in child.characters or len(child.ranges) > 0 def test_parse_negated_character_class(self): """Test parsing a negated character class.""" parser = RegexParser("[^0-9]") ast = parser.parse() assert ast.node_type == NodeType.SEQUENCE child = ast.children[0] assert child.node_type == NodeType.NEGATIVE_SET assert child.negated is True def test_parse_plus_quantifier(self): """Test parsing a plus quantifier.""" parser = RegexParser("a+") ast = parser.parse() assert ast.node_type == NodeType.SEQUENCE assert len(ast.children) == 1 quantifier = ast.children[0] assert quantifier.node_type == NodeType.QUANTIFIER assert quantifier.min_count == 1 assert quantifier.max_count == float('inf') def test_parse_star_quantifier(self): """Test parsing a star quantifier.""" parser = RegexParser("b*") ast = parser.parse() assert ast.node_type == NodeType.SEQUENCE assert len(ast.children) == 1 quantifier = ast.children[0] assert quantifier.node_type == NodeType.QUANTIFIER assert quantifier.min_count == 0 assert quantifier.max_count == float('inf') def test_parse_question_quantifier(self): """Test parsing a question mark quantifier.""" parser = RegexParser("c?") ast = parser.parse() quantifier = ast.children[0] assert quantifier.node_type == NodeType.QUANTIFIER assert quantifier.min_count == 0 assert quantifier.max_count == 1 def test_parse_range_quantifier(self): """Test parsing a range quantifier like {2,5}.""" parser = RegexParser("a{2,5}") ast = parser.parse() quantifier = ast.children[0] assert quantifier.node_type == NodeType.QUANTIFIER assert quantifier.min_count == 2 assert quantifier.max_count == 5 def test_parse_lazy_quantifier(self): """Test parsing a lazy quantifier.""" parser = RegexParser("a+?") ast = parser.parse() quantifier = ast.children[0] assert quantifier.node_type == NodeType.QUANTIFIER assert quantifier.is_lazy is True def test_parse_capturing_group(self): """Test parsing a capturing group.""" parser = RegexParser("(hello)") ast = parser.parse() assert ast.node_type == NodeType.SEQUENCE assert len(ast.children) == 1 group = ast.children[0] assert group.node_type == NodeType.CAPTURING_GROUP assert not group.is_non_capturing def test_parse_non_capturing_group(self): """Test parsing a non-capturing group.""" parser = RegexParser("(?:hello)") ast = parser.parse() group = ast.children[0] assert group.node_type == NodeType.NON_CAPTURING_GROUP assert group.is_non_capturing is True def test_parse_named_group(self): """Test parsing a named group.""" parser = RegexParser("(?Phello)") ast = parser.parse() group = ast.children[0] assert group.node_type == NodeType.NAMED_GROUP assert group.name == "name" def test_parse_positive_lookahead(self): """Test parsing a positive lookahead.""" parser = RegexParser("(?=test)") ast = parser.parse() group = ast.children[0] assert group.node_type == NodeType.LOOKAHEAD def test_parse_negative_lookahead(self): """Test parsing a negative lookahead.""" parser = RegexParser("(?!test)") ast = parser.parse() group = ast.children[0] assert group.node_type == NodeType.NEGATIVE_LOOKAHEAD def test_parse_lookbehind(self): """Test parsing a lookbehind.""" parser = RegexParser("(?<=test)") ast = parser.parse() group = ast.children[0] assert group.node_type == NodeType.LOOKBEHIND def test_parse_anchor_start(self): """Test parsing a start anchor.""" parser = RegexParser("^start") ast = parser.parse() assert ast.children[0].node_type == NodeType.ANCHOR_START def test_parse_anchor_end(self): """Test parsing an end anchor.""" parser = RegexParser("end$") ast = parser.parse() anchor = ast.children[-1] assert anchor.node_type == NodeType.ANCHOR_END def test_parse_word_boundary(self): """Test parsing a word boundary.""" parser = RegexParser("\\bword\\b") ast = parser.parse() assert any(child.node_type == NodeType.WORD_BOUNDARY for child in ast.children) def test_parse_dot(self): """Test parsing a dot (any character).""" parser = RegexParser(".") ast = parser.parse() assert ast.children[0].node_type == NodeType.DOT def test_parse_complex_pattern(self): """Test parsing a complex regex pattern.""" pattern = r"^(?:http|https)://[\w.-]+\.(?:com|org|net)$" parser = RegexParser(pattern) ast = parser.parse() assert ast.node_type == NodeType.SEQUENCE assert len(ast.children) > 0 def test_parse_alternation(self): """Test parsing alternation with pipe.""" parser = RegexParser("cat|dog") ast = parser.parse() assert ast.node_type == NodeType.SEQUENCE assert len(ast.children) >= 1 def test_parse_escaped_character(self): """Test parsing escaped characters.""" parser = RegexParser("\\.") ast = parser.parse() assert len(ast.children) > 0 def test_parse_whitespace_shorthand(self): """Test parsing whitespace shorthand.""" parser = RegexParser("\\s+") ast = parser.parse() assert ast.node_type == NodeType.SEQUENCE quantifier = ast.children[0] assert quantifier.node_type == NodeType.QUANTIFIER def test_parse_word_char_shorthand(self): """Test parsing word character shorthand.""" parser = RegexParser("\\w*") ast = parser.parse() assert ast.node_type == NodeType.SEQUENCE def test_parse_hex_escape(self): """Test parsing hex escape sequence.""" parser = RegexParser("\\x41") ast = parser.parse() assert len(ast.children) > 0 def test_parse_backreference(self): """Test parsing a backreference.""" parser = RegexParser("(a)\\1") ast = parser.parse() assert len(ast.children) > 0 def test_parse_empty_group(self): """Test parsing an empty group.""" parser = RegexParser("()") ast = parser.parse() group = ast.children[0] assert group.node_type == NodeType.CAPTURING_GROUP def test_parse_nested_groups(self): """Test parsing nested groups.""" parser = RegexParser("((a)(b))") ast = parser.parse() assert len(ast.children) == 1 def test_errors_empty(self): """Test that valid patterns have no errors.""" parser = RegexParser("hello") parser.parse() assert len(parser.get_errors()) == 0 def test_parse_email_pattern(self): """Test parsing a typical email regex pattern.""" pattern = r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}" parser = RegexParser(pattern) ast = parser.parse() assert ast.node_type == NodeType.SEQUENCE assert len(ast.children) > 0 def test_parse_phone_pattern(self): """Test parsing a phone number regex.""" pattern = r"\(?\d{3}\)?[-.\s]?\d{3}[-.\s]?\d{4}" parser = RegexParser(pattern) ast = parser.parse() assert ast.node_type == NodeType.SEQUENCE def test_parse_date_pattern(self): """Test parsing a date regex.""" pattern = r"\d{4}-\d{2}-\d{2}" parser = RegexParser(pattern) ast = parser.parse() assert ast.node_type == NodeType.SEQUENCE class TestNodeTypes: """Test node type enumeration.""" def test_node_type_values(self): """Test that all expected node types exist.""" expected_types = [ "LITERAL", "CHARACTER_CLASS", "POSITIVE_SET", "NEGATIVE_SET", "DOT", "GROUP", "CAPTURING_GROUP", "NON_CAPTURING_GROUP", "NAMED_GROUP", "LOOKAHEAD", "LOOKBEHIND", "NEGATIVE_LOOKAHEAD", "NEGATIVE_LOOKBEHIND", "QUANTIFIER", "ANCHOR_START", "ANCHOR_END", "WORD_BOUNDARY", "START_OF_STRING", "END_OF_STRING", "DIGIT", "NON_DIGIT", "WORD_CHAR", "WHITESPACE", "BACKREFERENCE", ] for type_name in expected_types: assert hasattr(NodeType, type_name), f"Missing node type: {type_name}"