"""Git operations handler for Git Commit AI.""" import os from pathlib import Path from typing import Optional from git import Repo from git.exc import GitCommandError, InvalidGitRepositoryError class GitHandler: """Handler for Git operations.""" def __init__(self, repo_path: Optional[str] = None): if repo_path is None: repo_path = os.getcwd() self.repo_path = Path(repo_path) self._repo: Optional[Repo] = None @property def repo(self) -> Repo: if self._repo is None: self._repo = Repo(str(self.repo_path)) return self._repo def is_repository(self) -> bool: try: self.repo.git.status() return True except (InvalidGitRepositoryError, GitCommandError): return False def ensure_repository(self) -> bool: return self.is_repository() def get_staged_changes(self) -> str: try: if not self.is_staged(): return "" diff = self.repo.git.diff("--cached") return diff except GitCommandError as e: raise GitError(f"Failed to get staged changes: {e}") from e def get_staged_files(self) -> list[str]: try: staged = self.repo.index.diff("HEAD") return [s.a_path for s in staged] except GitCommandError: return [] def is_staged(self) -> bool: try: return bool(self.repo.index.diff("HEAD")) except GitCommandError: return False def get_commit_history(self, max_commits: int = 5, conventional_only: bool = False) -> list[dict[str, str]]: try: commits = [] for commit in self.repo.iter_commits(max_count=max_commits): message = commit.message.strip() if conventional_only: if not self._is_conventional(message): continue commits.append({"hash": commit.hexsha[:7], "message": message, "type": self._extract_type(message)}) return commits except GitCommandError as e: raise GitError(f"Failed to get commit history: {e}") from e def _is_conventional(self, message: str) -> bool: conventional_types = ["feat", "fix", "docs", "style", "refactor", "perf", "test", "build", "ci", "chore", "revert"] return any(message.startswith(f"{t}:") for t in conventional_types) def _extract_type(self, message: str) -> str: conventional_types = ["feat", "fix", "docs", "style", "refactor", "perf", "test", "build", "ci", "chore", "revert"] for t in conventional_types: if message.startswith(f"{t}:"): return t return "unknown" def get_changed_languages(self) -> list[str]: staged_files = self.get_staged_files() languages = set() extension_map = { ".py": "Python", ".js": "JavaScript", ".ts": "TypeScript", ".jsx": "React", ".tsx": "TypeScript React", ".java": "Java", ".go": "Go", ".rs": "Rust", ".rb": "Ruby", ".php": "PHP", ".swift": "Swift", ".c": "C", ".cpp": "C++", ".h": "C Header", ".cs": "C#", ".scala": "Scala", ".kt": "Kotlin", ".lua": "Lua", ".r": "R", ".sql": "SQL", ".html": "HTML", ".css": "CSS", ".scss": "SCSS", ".json": "JSON", ".yaml": "YAML", ".yml": "YAML", ".xml": "XML", ".md": "Markdown", ".sh": "Shell", ".bash": "Bash", ".zsh": "Zsh", ".dockerfile": "Docker", ".tf": "Terraform", } for file_path in staged_files: ext = Path(file_path).suffix.lower() if ext in extension_map: languages.add(extension_map[ext]) return sorted(list(languages)) def get_diff_summary(self) -> str: diff = self.get_staged_changes() if not diff: return "No staged changes" files = self.get_staged_files() languages = self.get_changed_languages() summary = f"Files changed: {len(files)}\n" if languages: summary += f"Languages: {', '.join(languages)}\n" summary += f"\nDiff length: {len(diff)} characters" return summary class GitError(Exception): """Exception raised for Git-related errors.""" pass def get_git_handler(repo_path: Optional[str] = None) -> GitHandler: return GitHandler(repo_path)