diff --git a/src/pattern_matcher.py b/src/pattern_matcher.py new file mode 100644 index 0000000..a87a6e6 --- /dev/null +++ b/src/pattern_matcher.py @@ -0,0 +1,263 @@ +"""Pattern matching engine for Code Pattern Search CLI.""" + +import re +from dataclasses import dataclass +from pathlib import Path +from typing import Optional + +from .models import MatchLocation + + +@dataclass +class FileMatch: + """Represents matches found in a single file.""" + + file_path: str + matches: list[MatchLocation] + match_count: int + + +LANGUAGE_EXTENSIONS: dict[str, list[str]] = { + "python": [".py", ".pyw", ".pyx"], + "javascript": [".js", ".jsx", ".mjs"], + "typescript": [".ts", ".tsx"], + "java": [".java"], + "kotlin": [".kt", ".kts"], + "go": [".go"], + "rust": [".rs"], + "cpp": [".cpp", ".cc", ".cxx", ".hpp", ".h", ".hxx"], + "c": [".c", ".h"], + "c#": [".cs"], + "ruby": [".rb", ".erb"], + "php": [".php", ".phtml"], + "swift": [".swift"], + "objective-c": [".m", ".mm"], + "scala": [".scala"], + "html": [".html", ".htm"], + "css": [".css", ".scss", ".sass", ".less"], + "json": [".json"], + "yaml": [".yaml", ".yml"], + "xml": [".xml"], + "markdown": [".md", ".markdown"], + "shell": [".sh", ".bash", ".zsh"], + "powershell": [".ps1"], + "dockerfile": ["Dockerfile"], + "sql": [".sql"], + "vue": [".vue"], + "svelte": [".svelte"], +} + + +EXTENSION_TO_LANGUAGE: dict[str, str] = {} +for lang, extensions in LANGUAGE_EXTENSIONS.items(): + for ext in extensions: + EXTENSION_TO_LANGUAGE[ext] = lang + + +class PatternLibrary: + """Library of built-in code patterns.""" + + PRESETS: dict[str, str] = { + "react-useeffect": r"useEffect\s*\([^)]*\)\s*\{", + "react-useeffect-deps": r"useEffect\s*\([^)]*,\s*\[[^\]]*\]\s*\)", + "python-dataclass": r"@dataclass\s*\n?class\s+\w+", + "python-decorator": r"@\w+\s*\n?def\s+\w+\s*\(", + "python-async": r"async\s+def\s+\w+", + "python-typed-dict": r"class\s+\w+\s*\(\s*TypedDict\s*\)", + "go-error-handling": r"if\s+err\s*!=\s*nil", + "go-defer": r"defer\s+\w+\.", + "go-goroutine": r"go\s+\w+\(", + "go-error-wrap": r"fmt\.Errorf\s*\(\s*[\"']", + "ts-interface": r"interface\s+\w+\s*\{", + "ts-type-alias": r"type\s+\w+\s*=\s*", + "ts-generic": r"<[A-Z]\w*>", + "js-async-await": r"async\s+\w+\s*\([^)]*\)\s*=>\s*await", + "js-fetch": r"fetch\s*\(\s*[\"']", + "rust-match": r"match\s+\w+\s*\{", + "rust-trait": r"impl\s+\w+\s+for\s+\w+", + "rust-lifetime": r"'[a-z]", + "java-stream": r"\.stream\s*\(\)", + "java-optional": r"Optional\.", + "ruby-on-rails": r"class\s+\w+\s*<\s*ApplicationController", + "docker-cmd": r"CMD\s+\[", + "docker-entrypoint": r"ENTRYPOINT\s+\[", + "sql-select": r"SELECT\s+\*?\s+FROM\s+\w+", + "sql-join": r"JOIN\s+\w+\s+ON", + } + + @classmethod + def get_pattern(cls, name: str) -> Optional[str]: + """Get a preset pattern by name.""" + return cls.PRESETS.get(name) + + @classmethod + def list_presets(cls) -> list[str]: + """List all available preset names.""" + return sorted(cls.PRESETS.keys()) + + @classmethod + def get_presets_by_category(cls) -> dict[str, list[str]]: + """Get presets organized by category.""" + categories: dict[str, list[str]] = {} + + for name in cls.PRESETS: + if name.startswith("python"): + category = "Python" + elif name.startswith(("react", "js", "ts")): + category = "JavaScript/TypeScript" + elif name.startswith("go"): + category = "Go" + elif name.startswith("rust"): + category = "Rust" + elif name.startswith("java"): + category = "Java" + elif name.startswith("ruby"): + category = "Ruby" + elif name.startswith("docker"): + category = "Docker" + elif name.startswith("sql"): + category = "SQL" + else: + category = "Other" + + if category not in categories: + categories[category] = [] + categories[category].append(name) + + return categories + + +class PatternMatcher: + """Engine for matching patterns in code content.""" + + def __init__( + self, + pattern: str, + flags: re.RegexFlag = re.MULTILINE, + ) -> None: + """Initialize the pattern matcher.""" + self.pattern = pattern + self.flags = flags + self._compiled: Optional[re.Pattern[str]] = None + + @property + def compiled(self) -> re.Pattern[str]: + """Get compiled regex pattern.""" + if self._compiled is None: + try: + self._compiled = re.compile(self.pattern, self.flags) + except re.error as e: + raise ValueError(f"Invalid regex pattern: {e}") + return self._compiled + + def matches_extension(self, file_path: str) -> bool: + """Check if a file path extension should be searched.""" + ext = Path(file_path).suffix.lower() + + if not ext: + return False + + return ext in EXTENSION_TO_LANGUAGE + + def get_language(self, file_path: str) -> Optional[str]: + """Get the language of a file based on its extension.""" + ext = Path(file_path).suffix.lower() + return EXTENSION_TO_LANGUAGE.get(ext) + + def find_matches( + self, + content: str, + file_path: str, + ) -> list[MatchLocation]: + """Find all pattern matches in content.""" + matches: list[MatchLocation] = [] + + lines = content.split("\n") + + for line_num, line in enumerate(lines, start=1): + for match in self.compiled.finditer(line): + matches.append( + MatchLocation( + file_path=file_path, + line_number=line_num, + line_content=line, + match_start=match.start(), + match_end=match.end(), + ) + ) + + return matches + + def find_multiline_matches( + self, + content: str, + file_path: str, + ) -> list[MatchLocation]: + """Find multiline pattern matches in content.""" + matches: list[MatchLocation] = [] + + lines = content.split("\n") + + for line_num, line in enumerate(lines, start=1): + if self.compiled.search(line): + matches.append( + MatchLocation( + file_path=file_path, + line_number=line_num, + line_content=line, + match_start=0, + match_end=len(line), + ) + ) + + return matches + + def count_matches(self, content: str) -> int: + """Count total matches in content.""" + return len(self.compiled.findall(content)) + + def has_matches(self, content: str) -> bool: + """Check if content contains any matches.""" + return self.compiled.search(content) is not None + + def get_match_context( + self, + content: str, + file_path: str, + context_lines: int = 2, + ) -> list[FileMatch]: + """Get matches with surrounding context.""" + matches = self.find_matches(content, file_path) + + if not matches: + return [] + + file_matches: list[FileMatch] = [] + current_match_file: Optional[str] = None + current_matches: list[MatchLocation] = [] + + for match in matches: + if current_match_file != match.file_path: + if current_matches and current_match_file is not None: + file_matches.append( + FileMatch( + file_path=current_match_file, + matches=current_matches, + match_count=len(current_matches), + ) + ) + current_match_file = match.file_path + current_matches = [] + + current_matches.append(match) + + if current_matches and current_match_file is not None: + file_matches.append( + FileMatch( + file_path=current_match_file, + matches=current_matches, + match_count=len(current_matches), + ) + ) + + return file_matches