diff --git a/git_commit_ai/core/ollama_client.py b/git_commit_ai/core/ollama_client.py new file mode 100644 index 0000000..e2c4e3f --- /dev/null +++ b/git_commit_ai/core/ollama_client.py @@ -0,0 +1,136 @@ +"""Ollama API client for Git Commit AI.""" + +import hashlib +import logging +from typing import Any, Optional + +import ollama +import requests + +from git_commit_ai.core.config import Config, get_config + +logger = logging.getLogger(__name__) + + +class OllamaClient: + """Client for communicating with Ollama API.""" + + def __init__(self, config: Optional[Config] = None): + self.config = config or get_config() + self._model: str = self.config.ollama_model + self._base_url: str = self.config.ollama_base_url + self._timeout: int = self.config.ollama_timeout + + @property + def model(self) -> str: + return self._model + + @model.setter + def model(self, value: str) -> None: + self._model = value + + @property + def base_url(self) -> str: + return self._base_url + + @base_url.setter + def base_url(self, value: str) -> None: + self._base_url = value + + def is_available(self) -> bool: + try: + response = requests.get(f"{self._base_url}/api/tags", timeout=10) + return response.status_code == 200 + except requests.RequestException: + return False + + def list_models(self) -> list[dict[str, Any]]: + try: + response = requests.get(f"{self._base_url}/api/tags", timeout=self._timeout) + if response.status_code == 200: + data = response.json() + return data.get("models", []) + return [] + except requests.RequestException as e: + logger.error(f"Failed to list models: {e}") + return [] + + def check_model_exists(self) -> bool: + models = self.list_models() + model_names = [m.get("name", "") for m in models] + return any(self._model in name for name in model_names) + + def pull_model(self, model: Optional[str] = None) -> bool: + model = model or self._model + try: + client = ollama.Client(host=self._base_url) + client.pull(model) + return True + except Exception as e: + logger.error(f"Failed to pull model {model}: {e}") + return False + + def generate(self, prompt: str, system: Optional[str] = None, model: Optional[str] = None, num_predict: int = 200, temperature: float = 0.7) -> str: + model = model or self._model + try: + client = ollama.Client(host=self._base_url) + response = client.generate( + model=model, prompt=prompt, system=system, + options={"num_predict": num_predict, "temperature": temperature} + ) + return response.get("response", "") + except Exception as e: + logger.error(f"Failed to generate response: {e}") + raise OllamaError(f"Failed to generate response: {e}") from e + + def generate_commit_message(self, diff: str, context: Optional[str] = None, conventional: bool = False, model: Optional[str] = None) -> str: + from git_commit_ai.prompts import PromptBuilder + + prompt_builder = PromptBuilder(self.config) + prompt = prompt_builder.build_prompt(diff, context, conventional) + system_prompt = prompt_builder.get_system_prompt(conventional) + + response = self.generate( + prompt=prompt, system=system_prompt, model=model, + num_predict=self.config.max_message_length + 50, + temperature=0.7 if not conventional else 0.5, + ) + + return self._parse_commit_message(response) + + def _parse_commit_message(self, response: str) -> str: + message = response.strip() + + if message.startswith("```"): + lines = message.split("\n") + if len(lines) >= 3: + content = "\n".join(lines[1:-1]) + if content.strip().startswith("git commit"): + content = content.replace("git commit -m ", "").strip() + if content.startswith('"') and content.endswith('"'): + content = content[1:-1] + return content.strip() + + if message.startswith('"') and message.endswith('"'): + message = message[1:-1] + + message = message.strip() + + max_length = self.config.max_message_length + if len(message) > max_length: + message = message[:max_length].rsplit(" ", 1)[0] + + return message + + +class OllamaError(Exception): + """Exception raised for Ollama-related errors.""" + pass + + +def generate_diff_hash(diff: str) -> str: + return hashlib.md5(diff.encode()).hexdigest() + + +def get_client(config: Optional[Config] = None) -> OllamaClient: + return OllamaClient(config)