"""Ollama API client for Git Commit AI.""" import hashlib import logging from typing import Any, Optional import ollama import requests from git_commit_ai.core.config import Config, get_config logger = logging.getLogger(__name__) class OllamaClient: """Client for communicating with Ollama API.""" def __init__(self, config: Optional[Config] = None): self.config = config or get_config() self._model: str = self.config.ollama_model self._base_url: str = self.config.ollama_base_url self._timeout: int = self.config.ollama_timeout @property def model(self) -> str: return self._model @model.setter def model(self, value: str) -> None: self._model = value @property def base_url(self) -> str: return self._base_url @base_url.setter def base_url(self, value: str) -> None: self._base_url = value def is_available(self) -> bool: try: response = requests.get(f"{self._base_url}/api/tags", timeout=10) return response.status_code == 200 except requests.RequestException: return False def list_models(self) -> list[dict[str, Any]]: try: response = requests.get(f"{self._base_url}/api/tags", timeout=self._timeout) if response.status_code == 200: data = response.json() return data.get("models", []) return [] except requests.RequestException as e: logger.error(f"Failed to list models: {e}") return [] def check_model_exists(self) -> bool: models = self.list_models() model_names = [m.get("name", "") for m in models] return any(self._model in name for name in model_names) def pull_model(self, model: Optional[str] = None) -> bool: model = model or self._model try: client = ollama.Client(host=self._base_url) client.pull(model) return True except Exception as e: logger.error(f"Failed to pull model {model}: {e}") return False def generate(self, prompt: str, system: Optional[str] = None, model: Optional[str] = None, num_predict: int = 200, temperature: float = 0.7) -> str: model = model or self._model try: client = ollama.Client(host=self._base_url) response = client.generate( model=model, prompt=prompt, system=system, options={"num_predict": num_predict, "temperature": temperature} ) return response.get("response", "") except Exception as e: logger.error(f"Failed to generate response: {e}") raise OllamaError(f"Failed to generate response: {e}") from e def generate_commit_message(self, diff: str, context: Optional[str] = None, conventional: bool = False, model: Optional[str] = None) -> str: from git_commit_ai.prompts import PromptBuilder prompt_builder = PromptBuilder(self.config) prompt = prompt_builder.build_prompt(diff, context, conventional) system_prompt = prompt_builder.get_system_prompt(conventional) response = self.generate( prompt=prompt, system=system_prompt, model=model, num_predict=self.config.max_message_length + 50, temperature=0.7 if not conventional else 0.5, ) return self._parse_commit_message(response) def _parse_commit_message(self, response: str) -> str: message = response.strip() if message.startswith("```"): lines = message.split("\n") if len(lines) >= 3: content = "\n".join(lines[1:-1]) if content.strip().startswith("git commit"): content = content.replace("git commit -m ", "").strip() if content.startswith('"') and content.endswith('"'): content = content[1:-1] return content.strip() if message.startswith('"') and message.endswith('"'): message = message[1:-1] message = message.strip() max_length = self.config.max_message_length if len(message) > max_length: message = message[:max_length].rsplit(" ", 1)[0] return message class OllamaError(Exception): """Exception raised for Ollama-related errors.""" pass def generate_diff_hash(diff: str) -> str: return hashlib.md5(diff.encode()).hexdigest() def get_client(config: Optional[Config] = None) -> OllamaClient: return OllamaClient(config)