"""Tests for the pattern matching engine."""

import re

import pytest

from src.pattern_matcher import (
    PatternMatcher,
    PatternLibrary,
    MatchLocation,
    LANGUAGE_EXTENSIONS,
    EXTENSION_TO_LANGUAGE,
)

# Exception types this suite accepts as "the pattern is not a valid regex".
# Kept in one place so every test that checks regex validity agrees.
_REGEX_ERRORS = (ValueError, re.error)


class TestPatternMatcher:
    """Tests for PatternMatcher class."""

    def test_simple_pattern_match(self):
        """Test basic regex pattern matching."""
        matcher = PatternMatcher(r"def\s+\w+")
        content = "def hello():\n pass\ndef world():"
        matches = matcher.find_matches(content, "test.py")

        assert len(matches) == 2
        assert matches[0].line_number == 1
        assert matches[1].line_number == 3

    def test_invalid_pattern_raises_error(self):
        """Test that invalid regex raises error when compiled."""
        # Construction is inside the guarded block as well, so the test holds
        # whether the pattern is validated on construction or on first use.
        with pytest.raises(_REGEX_ERRORS):
            matcher = PatternMatcher(r"[")
            _ = matcher.compiled

    def test_valid_pattern_no_error(self):
        """Test that a valid regex constructs and compiles without error."""
        matcher = PatternMatcher(r"[closed]")

        assert matcher is not None
        # Reading ``compiled`` forces the pattern to actually be compiled.
        assert matcher.compiled is not None

    def test_no_matches(self):
        """Test pattern with no matches."""
        matcher = PatternMatcher(r"notgoingtomatch")
        content = "def hello(): pass"
        matches = matcher.find_matches(content, "test.py")

        assert len(matches) == 0

    def test_matches_extension(self):
        """Test file extension matching."""
        matcher = PatternMatcher(r"test")

        assert matcher.matches_extension("test.py") is True
        assert matcher.matches_extension("test.js") is True
        assert matcher.matches_extension("test") is False
        assert matcher.matches_extension("test.txt") is False

    def test_get_language(self):
        """Test language detection from extension."""
        matcher = PatternMatcher(r"test")

        assert matcher.get_language("file.py") == "python"
        assert matcher.get_language("file.js") == "javascript"
        assert matcher.get_language("file.ts") == "typescript"
        assert matcher.get_language("file.go") == "go"
        assert matcher.get_language("file.rs") == "rust"
        assert matcher.get_language("file.java") == "java"
        assert matcher.get_language("file.unknown") is None

    def test_count_matches(self):
        """Test counting matches in content."""
        matcher = PatternMatcher(r"\b\w+\b")
        content = "hello world test one"
        count = matcher.count_matches(content)

        assert count == 4

    def test_has_matches(self):
        """Test checking if content has matches."""
        matcher = PatternMatcher(r"def\s+\w+")

        assert matcher.has_matches("def hello():") is True
        assert matcher.has_matches("class Hello:") is False

    def test_case_sensitive_match(self):
        """Test case-sensitive matching."""
        matcher = PatternMatcher(r"Hello")
        content = "Hello hello HELLO"

        matches = matcher.find_matches(content, "test.txt")
        assert len(matches) == 1

    def test_multiline_pattern(self):
        """Test that a pattern spanning multiple lines compiles with flags=0."""
        matcher = PatternMatcher(r"def\s+\w+\s*\n\s*:", flags=0)

        # Without this access the pattern would only be stored, never compiled.
        assert matcher.compiled is not None

    def test_special_characters_in_pattern(self):
        """Test pattern with special regex characters."""
        matcher = PatternMatcher(r"\d+\.\d+")
        content = "Numbers like 3.14 and 2.71"
        matches = matcher.find_matches(content, "test.txt")

        assert len(matches) == 2


class TestPatternLibrary:
    """Tests for PatternLibrary class."""

    def test_get_preset(self):
        """Test getting a preset pattern."""
        pattern = PatternLibrary.get_pattern("python-dataclass")
        assert pattern is not None
        assert "@dataclass" in pattern

    def test_get_nonexistent_preset(self):
        """Test getting a preset that doesn't exist."""
        pattern = PatternLibrary.get_pattern("nonexistent")
        assert pattern is None

    def test_list_presets(self):
        """Test listing all presets."""
        presets = PatternLibrary.list_presets()

        assert isinstance(presets, list)
        assert len(presets) > 0
        assert "python-dataclass" in presets
        assert "react-useeffect" in presets

    def test_get_presets_by_category(self):
        """Test getting presets organized by category."""
        categories = PatternLibrary.get_presets_by_category()

        assert isinstance(categories, dict)
        assert "Python" in categories
        assert "JavaScript/TypeScript" in categories
        assert "Go" in categories

    def test_all_presets_are_valid_regex(self):
        """Test that all presets are valid regex patterns."""
        for name, pattern in PatternLibrary.PRESETS.items():
            try:
                # Read ``compiled`` so the preset is really compiled; building
                # the matcher alone is not assumed to validate the pattern.
                _ = PatternMatcher(pattern).compiled
            except _REGEX_ERRORS as exc:
                pytest.fail(
                    f"Preset '{name}' has invalid regex: {pattern} ({exc})"
                )


class TestMatchLocation:
    """Tests for MatchLocation dataclass."""

    def test_match_location_creation(self):
        """Test creating a MatchLocation."""
        match = MatchLocation(
            file_path="test.py",
            line_number=10,
            line_content="def hello():",
            match_start=0,
            match_end=11,
        )

        assert match.file_path == "test.py"
        assert match.line_number == 10
        assert match.line_content == "def hello():"
        assert match.match_start == 0
        assert match.match_end == 11

    def test_match_location_to_dict(self):
        """Test converting MatchLocation to dictionary."""
        match = MatchLocation(
            file_path="test.py",
            line_number=5,
            line_content="pass",
            match_start=0,
            match_end=4,
        )

        data = match.to_dict()

        assert data["file_path"] == "test.py"
        assert data["line_number"] == 5
        assert data["line_content"] == "pass"


class TestLanguageExtensions:
    """Tests for language extension mappings."""

    def test_python_extensions(self):
        """Test Python file extensions."""
        for ext in LANGUAGE_EXTENSIONS["python"]:
            assert EXTENSION_TO_LANGUAGE.get(ext) == "python"

    def test_javascript_extensions(self):
        """Test JavaScript file extensions."""
        for ext in LANGUAGE_EXTENSIONS["javascript"]:
            assert EXTENSION_TO_LANGUAGE.get(ext) == "javascript"

    def test_all_extensions_mapped(self):
        """Test that all extensions are mapped to a language."""
        # Only the extension lists are needed here, not the language names.
        for extensions in LANGUAGE_EXTENSIONS.values():
            for ext in extensions:
                assert ext in EXTENSION_TO_LANGUAGE, f"Extension {ext} not mapped"