202 lines
6.7 KiB
Python
202 lines
6.7 KiB
Python
"""Tests for the pattern matching engine."""
|
|
|
|
import re
|
|
import pytest
|
|
from src.pattern_matcher import (
|
|
PatternMatcher,
|
|
PatternLibrary,
|
|
MatchLocation,
|
|
LANGUAGE_EXTENSIONS,
|
|
EXTENSION_TO_LANGUAGE,
|
|
)
|
|
|
|
|
|
class TestPatternMatcher:
    """Tests for PatternMatcher class."""

    def test_simple_pattern_match(self):
        """Test basic regex pattern matching and the reported line numbers."""
        matcher = PatternMatcher(r"def\s+\w+")
        content = "def hello():\n pass\ndef world():"
        matches = matcher.find_matches(content, "test.py")

        assert len(matches) == 2
        assert matches[0].line_number == 1
        assert matches[1].line_number == 3

    def test_invalid_pattern_raises_error(self):
        """Test that reading ``compiled`` on an invalid regex raises."""
        matcher = PatternMatcher(r"[")
        # Either exception type is accepted, as in the original hand-written
        # try/except; pytest.raises fails the test if nothing is raised.
        with pytest.raises((ValueError, re.error)):
            _ = matcher.compiled

    def test_valid_pattern_no_error(self):
        """Test that constructing a matcher does not raise by itself.

        NOTE(review): despite the test name, the pattern used here is an
        unclosed character class, i.e. not a valid regex. What this test
        checks is that construction alone succeeds for it; confirm that this
        is the intended contract of PatternMatcher.
        """
        matcher = PatternMatcher(r"[unclosed")
        assert matcher is not None

    def test_no_matches(self):
        """Test pattern with no matches."""
        matcher = PatternMatcher(r"notgoingtomatch")
        content = "def hello(): pass"
        matches = matcher.find_matches(content, "test.py")

        assert len(matches) == 0

    def test_matches_extension(self):
        """Test file extension matching."""
        matcher = PatternMatcher(r"test")

        assert matcher.matches_extension("test.py") is True
        assert matcher.matches_extension("test.js") is True
        # No extension at all, and an extension that is not a source language.
        assert matcher.matches_extension("test") is False
        assert matcher.matches_extension("test.txt") is False

    def test_get_language(self):
        """Test language detection from extension."""
        matcher = PatternMatcher(r"test")

        assert matcher.get_language("file.py") == "python"
        assert matcher.get_language("file.js") == "javascript"
        assert matcher.get_language("file.ts") == "typescript"
        assert matcher.get_language("file.go") == "go"
        assert matcher.get_language("file.rs") == "rust"
        assert matcher.get_language("file.java") == "java"
        assert matcher.get_language("file.unknown") is None

    def test_count_matches(self):
        """Test counting matches in content."""
        matcher = PatternMatcher(r"\b\w+\b")
        content = "hello world test one"
        count = matcher.count_matches(content)

        assert count == 4

    def test_has_matches(self):
        """Test checking if content has matches."""
        matcher = PatternMatcher(r"def\s+\w+")

        assert matcher.has_matches("def hello():") is True
        assert matcher.has_matches("class Hello:") is False

    def test_case_sensitive_match(self):
        """Test case-sensitive matching."""
        matcher = PatternMatcher(r"Hello")
        content = "Hello hello HELLO"

        matches = matcher.find_matches(content, "test.txt")
        assert len(matches) == 1

    def test_multiline_pattern(self):
        """Test that a pattern spanning multiple lines is accepted."""
        matcher = PatternMatcher(r"def\s+\w+\s*\n\s*:", flags=0)
        # Constructing the matcher and discarding it checked nothing. Reading
        # `compiled` is how the invalid-pattern test above triggers errors, so
        # do the same here to make sure this pattern is really built.
        assert matcher.compiled is not None

    def test_special_characters_in_pattern(self):
        """Test pattern with special regex characters."""
        matcher = PatternMatcher(r"\d+\.\d+")
        content = "Numbers like 3.14 and 2.71"
        matches = matcher.find_matches(content, "test.txt")

        assert len(matches) == 2
|
|
|
|
|
|
class TestPatternLibrary:
    """Tests for PatternLibrary class."""

    def test_get_preset(self):
        """Test getting a preset pattern."""
        pattern = PatternLibrary.get_pattern("python-dataclass")
        assert pattern is not None
        assert "@dataclass" in pattern

    def test_get_nonexistent_preset(self):
        """Test getting a preset that doesn't exist."""
        pattern = PatternLibrary.get_pattern("nonexistent")
        assert pattern is None

    def test_list_presets(self):
        """Test listing all presets."""
        presets = PatternLibrary.list_presets()

        assert isinstance(presets, list)
        assert len(presets) > 0
        assert "python-dataclass" in presets
        assert "react-useeffect" in presets

    def test_get_presets_by_category(self):
        """Test getting presets organized by category."""
        categories = PatternLibrary.get_presets_by_category()

        assert isinstance(categories, dict)
        assert "Python" in categories
        assert "JavaScript/TypeScript" in categories
        assert "Go" in categories

    def test_all_presets_are_valid_regex(self):
        """Test that all presets are valid regex patterns.

        Fails with the offending preset name and pattern as soon as one
        preset cannot be turned into a compiled pattern.
        """
        for name, pattern in PatternLibrary.PRESETS.items():
            try:
                # Construct the matcher AND read `compiled`: elsewhere in this
                # file an unclosed pattern is constructed without error and
                # the failure only shows up on `compiled`, so construction
                # alone would let a broken preset slip through.
                _ = PatternMatcher(pattern).compiled
            except (ValueError, re.error):
                pytest.fail(f"Preset '{name}' has invalid regex: {pattern}")
|
|
|
|
|
|
class TestMatchLocation:
    """Tests for MatchLocation."""

    def test_match_location_creation(self):
        """Build a MatchLocation and read every field back unchanged."""
        expected_fields = {
            "file_path": "test.py",
            "line_number": 10,
            "line_content": "def hello():",
            "match_start": 0,
            "match_end": 11,
        }
        location = MatchLocation(**expected_fields)

        # Each constructor argument must be exposed as a same-named attribute.
        for field_name, expected_value in expected_fields.items():
            assert getattr(location, field_name) == expected_value

    def test_match_location_to_dict(self):
        """Convert a MatchLocation to a dict and check the exported keys."""
        location = MatchLocation(
            file_path="test.py",
            line_number=5,
            line_content="pass",
            match_start=0,
            match_end=4,
        )

        exported = location.to_dict()

        # Only these three keys are checked; the offsets are not asserted.
        checked = {
            "file_path": "test.py",
            "line_number": 5,
            "line_content": "pass",
        }
        for key, expected_value in checked.items():
            assert exported[key] == expected_value
|
|
|
|
|
|
class TestLanguageExtensions:
    """Tests for language extension mappings."""

    def test_python_extensions(self):
        """Every extension listed for Python resolves back to "python"."""
        resolved = [
            EXTENSION_TO_LANGUAGE.get(ext)
            for ext in LANGUAGE_EXTENSIONS["python"]
        ]
        assert resolved == ["python"] * len(resolved)

    def test_javascript_extensions(self):
        """Every extension listed for JavaScript resolves to "javascript"."""
        resolved = [
            EXTENSION_TO_LANGUAGE.get(ext)
            for ext in LANGUAGE_EXTENSIONS["javascript"]
        ]
        assert resolved == ["javascript"] * len(resolved)

    def test_all_extensions_mapped(self):
        """Every extension of every language has an entry in the reverse map."""
        every_extension = (
            ext
            for extensions in LANGUAGE_EXTENSIONS.values()
            for ext in extensions
        )
        for ext in every_extension:
            assert ext in EXTENSION_TO_LANGUAGE, f"Extension {ext} not mapped"
|