diff --git a/src/utils/file_utils.py b/src/utils/file_utils.py new file mode 100644 index 0000000..0d95490 --- /dev/null +++ b/src/utils/file_utils.py @@ -0,0 +1,113 @@ +"""File utilities for AI Code Audit CLI.""" + +import os +from pathlib import Path +from typing import Iterator, Optional + + +class FileUtils: + """Utilities for file operations.""" + + SUPPORTED_EXTENSIONS = { + ".py", + ".js", + ".ts", + ".jsx", + ".tsx", + } + + def __init__(self): + """Initialize file utilities.""" + self.supported_extensions = self.SUPPORTED_EXTENSIONS + + def find_files( + self, + path: Path, + max_size: int = 5 * 1024 * 1024, + excluded_patterns: list[str] | None = None, + ) -> Iterator[Path]: + """Find all files in a directory recursively.""" + if excluded_patterns is None: + excluded_patterns = [] + + if path.is_file(): + if self._should_include(path, max_size, excluded_patterns): + yield path + return + + for root, dirs, files in os.walk(path): + for pattern in excluded_patterns[:]: + if pattern in dirs: + dirs.remove(pattern) + + for file in files: + file_path = Path(root) / file + if self._should_include(file_path, max_size, excluded_patterns): + yield file_path + + def _should_include( + self, file_path: Path, max_size: int, excluded_patterns: list[str] + ) -> bool: + """Check if a file should be included in scanning.""" + try: + if file_path.suffix.lower() not in self.supported_extensions: + return False + + if file_path.stat().st_size > max_size: + return False + + str_path = str(file_path) + for pattern in excluded_patterns: + if pattern in str_path: + return False + + return True + except (OSError, PermissionError): + return False + + def read_file(self, file_path: Path, encoding: str = "utf-8") -> str: + """Read file contents with error handling.""" + try: + return file_path.read_text(encoding=encoding, errors="replace") + except UnicodeDecodeError: + try: + return file_path.read_text(encoding="latin-1", errors="replace") + except Exception: + raise ValueError(f"Could not decode file: {file_path}") + + def get_file_info(self, file_path: Path) -> dict: + """Get information about a file.""" + try: + stat = file_path.stat() + return { + "path": str(file_path.absolute()), + "name": file_path.name, + "extension": file_path.suffix, + "size": stat.st_size, + "modified": stat.st_mtime, + "is_readable": os.access(file_path, os.R_OK), + } + except Exception: + return { + "path": str(file_path.absolute()), + "name": file_path.name, + "extension": file_path.suffix, + "size": 0, + "modified": 0, + "is_readable": False, + } + + def count_lines(self, content: str) -> int: + """Count lines in content.""" + if not content: + return 0 + return len(content.split('\n')) + + def get_lines_around( + self, content: str, line_number: int, context: int = 2 + ) -> list[str]: + """Get lines around a specific line number.""" + lines = content.split('\n') + start = max(0, line_number - context - 1) + end = min(len(lines), line_number + context) + return lines[start:end]