diff --git a/app/src/git_commit_generator/message_generator.py b/app/src/git_commit_generator/message_generator.py new file mode 100644 index 0000000..dd9981a --- /dev/null +++ b/app/src/git_commit_generator/message_generator.py @@ -0,0 +1,208 @@ +"""Message generator for commit messages.""" +from typing import Any + +from git_commit_generator.config import Config, get_config +from git_commit_generator.git_utils import GitUtils, get_git_utils +from git_commit_generator.ollama_client import OllamaClient, get_ollama_client + + +class MessageGenerator: + """Generates conventional commit messages using LLM.""" + + CHANGE_TYPE_INDICATORS = { + "test": ["+def test_", "+class Test", "test_", "assert ", "pytest", "+ assert "], + "feat": ["+def ", "+class ", "+async def ", "new feature", "add "], + "fix": ["-def ", "bug", "fix", "resolve", "error", "issue"], + "docs": ["#", "doc", "readme", "documentation"], + "refactor": ["refactor", "rename", "reorganize"], + "chore": ["chore", "dependency", "config"], + } + + def __init__(self, config: Config, ollama_client: OllamaClient, git_utils: GitUtils | None = None): + """Initialize message generator. + + Args: + config: Configuration object. + ollama_client: Ollama client instance. + git_utils: Git utils instance (optional). + """ + self.config = config + self.ollama_client = ollama_client + self.git_utils = git_utils or get_git_utils() + + def detect_change_type(self, diff: str) -> str: + """Detect the type of change from diff. + + Args: + diff: Git diff content. + + Returns: + Change type (feat, fix, docs, test, refactor, chore). + """ + diff_lower = diff.lower() + + for change_type, indicators in self.CHANGE_TYPE_INDICATORS.items(): + for indicator in indicators: + if indicator.lower() in diff_lower: + return change_type + + return "feat" + + def detect_scope(self, files: list[str]) -> str: + """Detect the scope from file paths. + + Args: + files: List of file paths. + + Returns: + Scope name (directory or 'core' for root files). + """ + if not files: + return "core" + + scopes: dict[str, int] = {} + + for file_path in files: + parts = file_path.split("/") + if len(parts) > 1: + scope = parts[0] + scopes[scope] = scopes.get(scope, 0) + 1 + else: + scopes["core"] = scopes.get("core", 0) + 1 + + if not scopes: + return "core" + + return max(scopes.items(), key=lambda x: x[1])[0] + + def parse_conventional_message(self, message: str) -> dict[str, Any]: + """Parse a conventional commit message. + + Args: + message: Commit message to parse. + + Returns: + Dictionary with type, scope, description, and full_message. + """ + message = self._clean_message(message) + + if ": " in message: + parts = message.split(": ", 1) + type_scope = parts[0] + description = parts[1] if len(parts) > 1 else "" + elif ": " not in message and " " in message: + type_scope = message.split(" ")[0] + description = message[len(type_scope) + 1:] if len(message) > len(type_scope) + 1 else "" + else: + return { + "type": "feat", + "scope": "", + "description": message, + "full_message": message, + } + + if "(" in type_scope and ")" in type_scope: + scope_start = type_scope.find("(") + scope_end = type_scope.find(")") + change_type = type_scope[:scope_start] + scope = type_scope[scope_start + 1:scope_end] + else: + change_type = type_scope + scope = "" + + return { + "type": change_type, + "scope": scope, + "description": description, + "full_message": message, + } + + def format_conventional_message(self, message_type: str, scope: str, description: str) -> str: + """Format a conventional commit message. + + Args: + message_type: Type of change (feat, fix, etc.). + scope: Scope of the change. + description: Description of the change. + + Returns: + Formatted commit message. + """ + if scope: + return f"{message_type}({scope}): {description}" + return f"{message_type}: {description}" + + def _clean_message(self, message: str) -> str: + """Clean a commit message. + + Args: + message: Raw message. + + Returns: + Cleaned message. + """ + message = message.strip() + if (message.startswith('"') and message.endswith('"')) or ( + message.startswith("'") and message.endswith("'") + ): + message = message[1:-1] + return message.strip() + + def generate(self) -> str: + """Generate a commit message. + + Returns: + Generated commit message. + + Raises: + ValueError: If no changes are detected. + """ + diff = self.git_utils.get_all_changes() + + if not diff: + raise ValueError("No changes detected. Stage some changes first.") + + change_type = self.detect_change_type(diff) + files = self.git_utils.get_staged_files() + scope = self.detect_scope(files) if files else "core" + + prompt = self.config.read_prompt("commit_message.txt") + + raw_message = self.ollama_client.generate_commit_message(diff, prompt) + message = self._clean_message(raw_message) + + parsed = self.parse_conventional_message(message) + parsed["type"] = change_type + parsed["scope"] = scope + + formatted = self.format_conventional_message( + message_type=parsed["type"], + scope=parsed["scope"], + description=parsed["description"], + ) + + return formatted + + +def get_message_generator( + config: Config | None = None, + ollama_client: OllamaClient | None = None, + git_utils: GitUtils | None = None, +) -> MessageGenerator: + """Get MessageGenerator instance. + + Args: + config: Configuration object (optional). + ollama_client: Ollama client (optional). + git_utils: Git utils (optional). + + Returns: + MessageGenerator instance. + """ + config = config or get_config() + ollama_client = ollama_client or get_ollama_client() + return MessageGenerator( + config=config, + ollama_client=ollama_client, + git_utils=git_utils, + )