Initial upload: Git Commit AI - privacy-first CLI for generating commit messages with local LLM
Some checks failed
CI / test (push) Has been cancelled
Some checks failed
CI / test (push) Has been cancelled
This commit is contained in:
136
git_commit_ai/core/ollama_client.py
Normal file
136
git_commit_ai/core/ollama_client.py
Normal file
@@ -0,0 +1,136 @@
|
|||||||
|
"""Ollama API client for Git Commit AI."""
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import logging
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
import ollama
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from git_commit_ai.core.config import Config, get_config
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class OllamaClient:
    """Client for communicating with a local Ollama API server.

    Thin wrapper around the ``ollama`` package (for generation and model
    pulls) plus raw ``requests`` calls against ``/api/tags`` for lightweight
    health/inventory checks.  Connection parameters come from the project
    :class:`Config`.
    """

    def __init__(self, config: Optional[Config] = None):
        """Initialize from *config*, falling back to the global configuration."""
        self.config = config or get_config()
        self._model: str = self.config.ollama_model
        self._base_url: str = self.config.ollama_base_url
        self._timeout: int = self.config.ollama_timeout

    @property
    def model(self) -> str:
        """Name of the model used for generation (e.g. ``llama3``)."""
        return self._model

    @model.setter
    def model(self, value: str) -> None:
        self._model = value

    @property
    def base_url(self) -> str:
        """Base URL of the Ollama server (e.g. ``http://localhost:11434``)."""
        return self._base_url

    @base_url.setter
    def base_url(self, value: str) -> None:
        self._base_url = value

    def is_available(self) -> bool:
        """Return True if the Ollama server answers a quick health check.

        Deliberately uses a short fixed 10s timeout (not the configured
        request timeout) so an unreachable server is detected promptly.
        """
        try:
            response = requests.get(f"{self._base_url}/api/tags", timeout=10)
            return response.status_code == 200
        except requests.RequestException:
            return False

    def list_models(self) -> list[dict[str, Any]]:
        """Return locally installed models as reported by ``/api/tags``.

        Returns an empty list on any HTTP or connection failure (best-effort;
        callers treat "no models" and "server unreachable" the same way).
        """
        try:
            response = requests.get(f"{self._base_url}/api/tags", timeout=self._timeout)
            if response.status_code == 200:
                data = response.json()
                return data.get("models", [])
            return []
        except requests.RequestException as e:
            logger.error("Failed to list models: %s", e)
            return []

    def check_model_exists(self) -> bool:
        """Return True if the configured model is installed locally.

        Matches the configured model against either an installed model's full
        name (``llama3:latest``) or its name with the tag stripped
        (``llama3``).  A plain substring test was used before, which made
        ``llama`` falsely match ``codellama:latest``.
        """
        for entry in self.list_models():
            name = entry.get("name", "")
            if self._model == name or self._model == name.split(":", 1)[0]:
                return True
        return False

    def pull_model(self, model: Optional[str] = None) -> bool:
        """Download *model* (default: the configured model); return success.

        Broad ``except`` is intentional: the ollama client can raise several
        exception types, and callers only need a boolean outcome.
        """
        model = model or self._model
        try:
            client = ollama.Client(host=self._base_url)
            client.pull(model)
            return True
        except Exception as e:
            logger.error("Failed to pull model %s: %s", model, e)
            return False

    def generate(self, prompt: str, system: Optional[str] = None, model: Optional[str] = None, num_predict: int = 200, temperature: float = 0.7) -> str:
        """Generate a completion for *prompt* via the Ollama API.

        Args:
            prompt: User prompt text.
            system: Optional system prompt.
            model: Model name override (defaults to the configured model).
            num_predict: Maximum number of tokens to generate.
            temperature: Sampling temperature.

        Returns:
            The raw generated text (empty string if the API returns none).

        Raises:
            OllamaError: If the request fails for any reason; the original
                exception is chained as the cause.
        """
        model = model or self._model
        try:
            client = ollama.Client(host=self._base_url)
            response = client.generate(
                model=model,
                prompt=prompt,
                system=system,
                options={"num_predict": num_predict, "temperature": temperature},
            )
            return response.get("response", "")
        except Exception as e:
            logger.error("Failed to generate response: %s", e)
            raise OllamaError(f"Failed to generate response: {e}") from e

    def generate_commit_message(self, diff: str, context: Optional[str] = None, conventional: bool = False, model: Optional[str] = None) -> str:
        """Generate a commit message for *diff* and return the cleaned text.

        Args:
            diff: The staged git diff.
            context: Optional extra context for the prompt.
            conventional: Whether to request Conventional Commits format.
            model: Model name override.

        Raises:
            OllamaError: Propagated from :meth:`generate` on API failure.
        """
        # Imported here to avoid a circular import with the prompts module.
        from git_commit_ai.prompts import PromptBuilder

        prompt_builder = PromptBuilder(self.config)
        prompt = prompt_builder.build_prompt(diff, context, conventional)
        system_prompt = prompt_builder.get_system_prompt(conventional)

        response = self.generate(
            prompt=prompt,
            system=system_prompt,
            model=model,
            # Headroom beyond the target length; output is trimmed afterwards.
            num_predict=self.config.max_message_length + 50,
            # Lower temperature for conventional commits to keep format strict.
            temperature=0.7 if not conventional else 0.5,
        )

        return self._parse_commit_message(response)

    def _parse_commit_message(self, response: str) -> str:
        """Clean up a raw model response into a usable commit message.

        Strips markdown code fences, a leading ``git commit -m`` the model
        sometimes emits, surrounding quotes, and truncates to the configured
        maximum length on a word boundary.
        """
        message = response.strip()

        # Unwrap a fenced code block (```...```), if present.
        if message.startswith("```"):
            lines = message.split("\n")
            if len(lines) >= 3:
                content = "\n".join(lines[1:-1])
                if content.strip().startswith("git commit"):
                    content = content.replace("git commit -m ", "").strip()
                    if content.startswith('"') and content.endswith('"'):
                        content = content[1:-1]
                return content.strip()

        # Strip surrounding double quotes the model sometimes adds.
        if message.startswith('"') and message.endswith('"'):
            message = message[1:-1]

        message = message.strip()

        # Truncate over-long messages at the last word boundary that fits.
        max_length = self.config.max_message_length
        if len(message) > max_length:
            message = message[:max_length].rsplit(" ", 1)[0]

        return message
||||||
|
class OllamaError(Exception):
    """Raised when communication with the Ollama server fails."""
||||||
|
def generate_diff_hash(diff: str) -> str:
    """Return the hex MD5 digest of *diff*, for use as a cache/identity key.

    MD5 is used purely as a fast content fingerprint, not for security;
    ``usedforsecurity=False`` (Python 3.9+) keeps this working on
    FIPS-restricted builds where security-grade MD5 is disabled.
    """
    return hashlib.md5(diff.encode("utf-8"), usedforsecurity=False).hexdigest()
||||||
|
def get_client(config: Optional[Config] = None) -> OllamaClient:
    """Convenience factory for :class:`OllamaClient`.

    Passes *config* straight through; when it is ``None`` the client falls
    back to the global configuration.
    """
    client = OllamaClient(config)
    return client
Reference in New Issue
Block a user