"""Tests for the regex parser."""

import pytest
from regex_humanizer.parser import (
    RegexParser,
    parse_regex,
    NodeType,
    RegexNode,
    LiteralNode,
    CharacterClassNode,
    QuantifierNode,
    GroupNode,
)


class TestRegexParser:
    """Test cases for RegexParser.

    Each test feeds one pattern to ``RegexParser``, calls ``parse()`` and
    inspects the returned root node (and, where relevant, its children).
    """

    def test_parse_simple_literal(self):
        """A plain string parses to a sequence holding one literal node."""
        parser = RegexParser("hello")
        ast = parser.parse()

        assert ast.node_type == NodeType.SEQUENCE
        assert len(ast.children) == 1
        assert ast.children[0].node_type == NodeType.LITERAL
        assert ast.children[0].value == "hello"

    def test_parse_digit_shorthand(self):
        """``\\d+`` parses to a single quantifier with a minimum of one."""
        parser = RegexParser("\\d+")
        ast = parser.parse()

        assert ast.node_type == NodeType.SEQUENCE
        assert len(ast.children) == 1
        quantifier = ast.children[0]
        assert quantifier.node_type == NodeType.QUANTIFIER
        assert quantifier.min_count == 1

    def test_parse_character_class(self):
        """``[a-z]`` parses to a positive set."""
        parser = RegexParser("[a-z]")
        ast = parser.parse()

        assert ast.node_type == NodeType.SEQUENCE
        assert len(ast.children) == 1
        child = ast.children[0]
        assert child.node_type == NodeType.POSITIVE_SET
        # Accept either representation: "a" stored as an explicit character,
        # or the span stored as at least one range.
        assert "a" in child.characters or len(child.ranges) > 0

    def test_parse_negated_character_class(self):
        """``[^0-9]`` parses to a negative set flagged as negated."""
        parser = RegexParser("[^0-9]")
        ast = parser.parse()

        assert ast.node_type == NodeType.SEQUENCE
        child = ast.children[0]
        assert child.node_type == NodeType.NEGATIVE_SET
        assert child.negated is True

    def test_parse_plus_quantifier(self):
        """``a+`` yields a quantifier with bounds (1, infinity)."""
        parser = RegexParser("a+")
        ast = parser.parse()

        assert ast.node_type == NodeType.SEQUENCE
        assert len(ast.children) == 1
        quantifier = ast.children[0]
        assert quantifier.node_type == NodeType.QUANTIFIER
        assert quantifier.min_count == 1
        assert quantifier.max_count == float('inf')

    def test_parse_star_quantifier(self):
        """``b*`` yields a quantifier with bounds (0, infinity)."""
        parser = RegexParser("b*")
        ast = parser.parse()

        assert ast.node_type == NodeType.SEQUENCE
        assert len(ast.children) == 1
        quantifier = ast.children[0]
        assert quantifier.node_type == NodeType.QUANTIFIER
        assert quantifier.min_count == 0
        assert quantifier.max_count == float('inf')

    def test_parse_question_quantifier(self):
        """``c?`` yields a quantifier with bounds (0, 1)."""
        parser = RegexParser("c?")
        ast = parser.parse()

        quantifier = ast.children[0]
        assert quantifier.node_type == NodeType.QUANTIFIER
        assert quantifier.min_count == 0
        assert quantifier.max_count == 1

    def test_parse_range_quantifier(self):
        """``a{2,5}`` yields a quantifier with bounds (2, 5)."""
        parser = RegexParser("a{2,5}")
        ast = parser.parse()

        quantifier = ast.children[0]
        assert quantifier.node_type == NodeType.QUANTIFIER
        assert quantifier.min_count == 2
        assert quantifier.max_count == 5

    def test_parse_lazy_quantifier(self):
        """``a+?`` yields a quantifier marked as lazy."""
        parser = RegexParser("a+?")
        ast = parser.parse()

        quantifier = ast.children[0]
        assert quantifier.node_type == NodeType.QUANTIFIER
        assert quantifier.is_lazy is True

    def test_parse_capturing_group(self):
        """``(hello)`` parses to a single capturing group."""
        parser = RegexParser("(hello)")
        ast = parser.parse()

        assert ast.node_type == NodeType.SEQUENCE
        assert len(ast.children) == 1
        group = ast.children[0]
        assert group.node_type == NodeType.CAPTURING_GROUP
        assert not group.is_non_capturing

    def test_parse_non_capturing_group(self):
        """``(?:hello)`` parses to a non-capturing group."""
        parser = RegexParser("(?:hello)")
        ast = parser.parse()

        group = ast.children[0]
        assert group.node_type == NodeType.NON_CAPTURING_GROUP
        assert group.is_non_capturing is True

    def test_parse_named_group(self):
        """``(?P<name>hello)`` parses to a named group called ``name``."""
        # The pattern must spell out the group name: the assertion below
        # expects the parser to report exactly the identifier between the
        # angle brackets.
        parser = RegexParser("(?P<name>hello)")
        ast = parser.parse()

        group = ast.children[0]
        assert group.node_type == NodeType.NAMED_GROUP
        assert group.name == "name"

    def test_parse_positive_lookahead(self):
        """``(?=test)`` parses to a lookahead node."""
        parser = RegexParser("(?=test)")
        ast = parser.parse()

        group = ast.children[0]
        assert group.node_type == NodeType.LOOKAHEAD

    def test_parse_negative_lookahead(self):
        """``(?!test)`` parses to a negative lookahead node."""
        parser = RegexParser("(?!test)")
        ast = parser.parse()

        group = ast.children[0]
        assert group.node_type == NodeType.NEGATIVE_LOOKAHEAD

    def test_parse_lookbehind(self):
        """``(?<=test)`` parses to a lookbehind node."""
        parser = RegexParser("(?<=test)")
        ast = parser.parse()

        group = ast.children[0]
        assert group.node_type == NodeType.LOOKBEHIND

    def test_parse_anchor_start(self):
        """A leading ``^`` becomes the first child, typed as a start anchor."""
        parser = RegexParser("^start")
        ast = parser.parse()

        assert ast.children[0].node_type == NodeType.ANCHOR_START

    def test_parse_anchor_end(self):
        """A trailing ``$`` becomes the last child, typed as an end anchor."""
        parser = RegexParser("end$")
        ast = parser.parse()

        anchor = ast.children[-1]
        assert anchor.node_type == NodeType.ANCHOR_END

    def test_parse_word_boundary(self):
        """``\\b`` produces at least one word-boundary child."""
        parser = RegexParser("\\bword\\b")
        ast = parser.parse()

        assert any(child.node_type == NodeType.WORD_BOUNDARY for child in ast.children)

    def test_parse_dot(self):
        """``.`` parses to a dot (any character) node."""
        parser = RegexParser(".")
        ast = parser.parse()

        assert ast.children[0].node_type == NodeType.DOT

    def test_parse_complex_pattern(self):
        """A URL-style pattern parses to a non-empty sequence."""
        pattern = r"^(?:http|https)://[\w.-]+\.(?:com|org|net)$"
        parser = RegexParser(pattern)
        ast = parser.parse()

        assert ast.node_type == NodeType.SEQUENCE
        assert len(ast.children) > 0

    def test_parse_alternation(self):
        """Alternation with ``|`` still yields a sequence with children."""
        parser = RegexParser("cat|dog")
        ast = parser.parse()

        assert ast.node_type == NodeType.SEQUENCE
        assert len(ast.children) >= 1

    def test_parse_escaped_character(self):
        """An escaped metacharacter produces at least one child node."""
        parser = RegexParser("\\.")
        ast = parser.parse()

        assert len(ast.children) > 0

    def test_parse_whitespace_shorthand(self):
        """``\\s+`` parses to a quantifier inside a sequence."""
        parser = RegexParser("\\s+")
        ast = parser.parse()

        assert ast.node_type == NodeType.SEQUENCE
        quantifier = ast.children[0]
        assert quantifier.node_type == NodeType.QUANTIFIER

    def test_parse_word_char_shorthand(self):
        """``\\w*`` parses to a sequence root."""
        parser = RegexParser("\\w*")
        ast = parser.parse()

        assert ast.node_type == NodeType.SEQUENCE

    def test_parse_hex_escape(self):
        """A hex escape (``\\x41``) produces at least one child node."""
        parser = RegexParser("\\x41")
        ast = parser.parse()

        assert len(ast.children) > 0

    def test_parse_backreference(self):
        """A group followed by ``\\1`` produces at least one child node."""
        parser = RegexParser("(a)\\1")
        ast = parser.parse()

        assert len(ast.children) > 0

    def test_parse_empty_group(self):
        """``()`` parses to a capturing group."""
        parser = RegexParser("()")
        ast = parser.parse()

        group = ast.children[0]
        assert group.node_type == NodeType.CAPTURING_GROUP

    def test_parse_nested_groups(self):
        """Nested groups collapse to a single top-level child."""
        parser = RegexParser("((a)(b))")
        ast = parser.parse()

        assert len(ast.children) == 1

    def test_errors_empty(self):
        """Parsing a valid pattern records no errors."""
        parser = RegexParser("hello")
        parser.parse()

        assert len(parser.get_errors()) == 0

    def test_parse_email_pattern(self):
        """A typical email pattern parses to a non-empty sequence."""
        pattern = r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}"
        parser = RegexParser(pattern)
        ast = parser.parse()

        assert ast.node_type == NodeType.SEQUENCE
        assert len(ast.children) > 0

    def test_parse_phone_pattern(self):
        """A phone-number pattern parses to a sequence root."""
        pattern = r"\(?\d{3}\)?[-.\s]?\d{3}[-.\s]?\d{4}"
        parser = RegexParser(pattern)
        ast = parser.parse()

        assert ast.node_type == NodeType.SEQUENCE

    def test_parse_date_pattern(self):
        """An ISO-style date pattern parses to a sequence root."""
        pattern = r"\d{4}-\d{2}-\d{2}"
        parser = RegexParser(pattern)
        ast = parser.parse()

        assert ast.node_type == NodeType.SEQUENCE


class TestNodeTypes:
    """Test node type enumeration."""

    def test_node_type_values(self):
        """Every expected member name is available on ``NodeType``."""
        expected_types = [
            "LITERAL", "CHARACTER_CLASS", "POSITIVE_SET", "NEGATIVE_SET",
            "DOT", "GROUP", "CAPTURING_GROUP", "NON_CAPTURING_GROUP",
            "NAMED_GROUP", "LOOKAHEAD", "LOOKBEHIND", "NEGATIVE_LOOKAHEAD",
            "NEGATIVE_LOOKBEHIND", "QUANTIFIER", "ANCHOR_START", "ANCHOR_END",
            "WORD_BOUNDARY", "START_OF_STRING", "END_OF_STRING", "DIGIT",
            "NON_DIGIT", "WORD_CHAR", "WHITESPACE", "BACKREFERENCE",
        ]

        # Check names one at a time so a failure reports which one is missing.
        for type_name in expected_types:
            assert hasattr(NodeType, type_name), f"Missing node type: {type_name}"