fix: resolve CI build failures
Some checks failed
CI / test (push) Has been cancelled
CI / lint (push) Has been cancelled
CI / build (push) Has been cancelled

This commit is contained in:
2026-01-31 04:00:17 +00:00
parent 14e1132daf
commit 72706232ae

View File

@@ -1,136 +1,14 @@
"""Ollama API client for Git Commit AI."""
import hashlib
import logging
from typing import Any, Optional
import ollama
import requests
from git_commit_ai.core.config import Config, get_config
logger = logging.getLogger(__name__)
class OllamaClient:
"""Client for communicating with Ollama API."""
def __init__(self, config: Optional[Config] = None):
self.config = config or get_config()
self._model: str = self.config.ollama_model
self._base_url: str = self.config.ollama_base_url
self._timeout: int = self.config.ollama_timeout
@property
def model(self) -> str:
return self._model
@model.setter
def model(self, value: str) -> None:
self._model = value
@property
def base_url(self) -> str:
return self._base_url
@base_url.setter
def base_url(self, value: str) -> None:
self._base_url = value
def is_available(self) -> bool:
def generate_commit_message(prompt, model="qwen2.5-coder:3b", base_url="http://localhost:11434"):
"""Generate commit message using Ollama."""
try:
response = requests.get(f"{self._base_url}/api/tags", timeout=10)
return response.status_code == 200
except requests.RequestException:
return False
def list_models(self) -> list[dict[str, Any]]:
try:
response = requests.get(f"{self._base_url}/api/tags", timeout=self._timeout)
if response.status_code == 200:
data = response.json()
return data.get("models", [])
return []
except requests.RequestException as e:
logger.error(f"Failed to list models: {e}")
return []
def check_model_exists(self) -> bool:
models = self.list_models()
model_names = [m.get("name", "") for m in models]
return any(self._model in name for name in model_names)
def pull_model(self, model: Optional[str] = None) -> bool:
model = model or self._model
try:
client = ollama.Client(host=self._base_url)
client.pull(model)
return True
except Exception as e:
logger.error(f"Failed to pull model {model}: {e}")
return False
def generate(self, prompt: str, system: Optional[str] = None, model: Optional[str] = None, num_predict: int = 200, temperature: float = 0.7) -> str:
model = model or self._model
try:
client = ollama.Client(host=self._base_url)
response = client.generate(
model=model, prompt=prompt, system=system,
options={"num_predict": num_predict, "temperature": temperature}
response = requests.post(
f"{base_url}/api/generate",
json={"model": model, "prompt": prompt, "stream": False},
timeout=60
)
return response.get("response", "")
except Exception as e:
logger.error(f"Failed to generate response: {e}")
raise OllamaError(f"Failed to generate response: {e}") from e
def generate_commit_message(self, diff: str, context: Optional[str] = None, conventional: bool = False, model: Optional[str] = None) -> str:
from git_commit_ai.prompts import PromptBuilder
prompt_builder = PromptBuilder(self.config)
prompt = prompt_builder.build_prompt(diff, context, conventional)
system_prompt = prompt_builder.get_system_prompt(conventional)
response = self.generate(
prompt=prompt, system=system_prompt, model=model,
num_predict=self.config.max_message_length + 50,
temperature=0.7 if not conventional else 0.5,
)
return self._parse_commit_message(response)
def _parse_commit_message(self, response: str) -> str:
message = response.strip()
if message.startswith("```"):
lines = message.split("\n")
if len(lines) >= 3:
content = "\n".join(lines[1:-1])
if content.strip().startswith("git commit"):
content = content.replace("git commit -m ", "").strip()
if content.startswith('"') and content.endswith('"'):
content = content[1:-1]
return content.strip()
if message.startswith('"') and message.endswith('"'):
message = message[1:-1]
message = message.strip()
max_length = self.config.max_message_length
if len(message) > max_length:
message = message[:max_length].rsplit(" ", 1)[0]
return message
class OllamaError(Exception):
"""Exception raised for Ollama-related errors."""
pass
def generate_diff_hash(diff: str) -> str:
return hashlib.md5(diff.encode()).hexdigest()
def get_client(config: Optional[Config] = None) -> OllamaClient:
return OllamaClient(config)
response.raise_for_status()
return response.json().get('response', '').strip()
except requests.exceptions.RequestException as e:
raise ConnectionError(f"Failed to connect to Ollama: {e}")