Files
shellgenius/shellgenius/ollama_client.py
7000pctAUTO bf0bf36e66
Some checks failed
CI / test (push) Failing after 11s
CI / lint (push) Failing after 6s
CI / type-check (push) Failing after 11s
Add config and ollama client modules
2026-02-04 10:58:50 +00:00

174 lines
4.7 KiB
Python

"""Ollama client wrapper for ShellGenius."""
import json
import logging
from typing import Any, Dict, Generator, List, Optional
import ollama
from shellgenius.config import get_config
# Module-level logger named after this module; OllamaClient uses it to report
# failed Ollama calls instead of raising.
logger = logging.getLogger(__name__)
class OllamaClient:
    """Client for interacting with the Ollama API.

    Wraps an ``ollama.Client`` for ShellGenius. Host and model default to the
    values returned by ``get_config()``. The wrapped client is created lazily
    by the :attr:`client` property. The request methods catch exceptions, log
    them and return a failure value (``False``, ``[]``, a dict with
    ``"success": False``, or an ``"Error: ..."`` chunk for
    :meth:`generate_stream`) instead of raising.
    """

    def __init__(self, host: Optional[str] = None, model: Optional[str] = None):
        """Initialize Ollama client.

        Args:
            host: Ollama server URL. Falsy values fall back to
                ``config.ollama_host``.
            model: Default model name. Falsy values fall back to
                ``config.ollama_model``.
        """
        config = get_config()
        self.host = host or config.ollama_host
        self.model = model or config.ollama_model
        # Built on first access of the ``client`` property.
        self._client: Optional[ollama.Client] = None

    @property
    def client(self) -> "ollama.Client":
        """Return the wrapped ``ollama.Client``, creating it on first access."""
        if self._client is None:
            self._client = ollama.Client(host=self.host)
        return self._client

    @staticmethod
    def _field(obj: Any, key: str) -> Any:
        """Read ``key`` from a dict or an attribute-style response object.

        Older ollama-python releases return plain dicts; newer ones return
        typed response objects exposing the same data as attributes.
        Returns None when the key/attribute is absent.
        """
        if isinstance(obj, dict):
            return obj.get(key)
        return getattr(obj, key, None)

    @classmethod
    def _model_name(cls, entry: Any) -> Optional[str]:
        """Return the name of one model-list entry, or None if it has none."""
        # Dict entries carry "name"; typed entries may only expose "model".
        return cls._field(entry, "name") or cls._field(entry, "model")

    def is_available(self) -> bool:
        """Check whether the Ollama server answers a model-list request.

        Returns:
            True if ``client.list()`` succeeds, False if it raises.
        """
        # Query the server directly: list_models() swallows every error and
        # returns [], so delegating to it made this method always return True.
        try:
            self.client.list()
        except Exception as e:
            logger.error("Ollama not available: %s", e)
            return False
        return True

    def list_models(self) -> List[str]:
        """List the models available on the Ollama server.

        Returns:
            Model names in server order. Entries without a name are skipped.
            Returns an empty list if the request or parsing fails.
        """
        try:
            response = self.client.list()
            entries = self._field(response, "models") or []
            names = [self._model_name(entry) for entry in entries]
        except Exception as e:
            logger.error("Failed to list models: %s", e)
            return []
        return [name for name in names if name]

    def pull_model(self, model: Optional[str] = None) -> bool:
        """Pull a model onto the Ollama server.

        Args:
            model: Model name to pull; defaults to ``self.model``.

        Returns:
            True if the pull call completed, False if it raised.
        """
        model = model or self.model
        try:
            self.client.pull(model)
        except Exception as e:
            logger.error("Failed to pull model %s: %s", model, e)
            return False
        return True

    def generate(
        self,
        prompt: str,
        model: Optional[str] = None,
        stream: bool = False,
        **kwargs,
    ) -> Dict[str, Any]:
        """Generate a completion for ``prompt``.

        Args:
            prompt: Input prompt.
            model: Model name; defaults to ``self.model``.
            stream: Passed through to ``ollama.Client.generate``. When True,
                ``"response"`` holds the library's streaming object; errors
                raised while the caller iterates it are not caught here.
            **kwargs: Extra keyword arguments forwarded to
                ``ollama.Client.generate``.

        Returns:
            ``{"success": True, "response": <library response>}`` on success,
            or ``{"success": False, "error": <message>, "response": None}``
            if the call raised.
        """
        model = model or self.model
        try:
            response = self.client.generate(
                model=model,
                prompt=prompt,
                stream=stream,
                **kwargs,
            )
        except Exception as e:
            logger.error("Generation failed: %s", e)
            return {"success": False, "error": str(e), "response": None}
        return {"success": True, "response": response}

    def generate_stream(
        self, prompt: str, model: Optional[str] = None
    ) -> Generator[str, None, None]:
        """Stream a completion for ``prompt`` as text chunks.

        Args:
            prompt: Input prompt.
            model: Model name; defaults to ``self.model``.

        Yields:
            The ``"response"`` text of each streamed chunk. If the request or
            the iteration raises, a final ``"Error: <message>"`` chunk is
            yielded instead of propagating the exception.
        """
        model = model or self.model
        try:
            response = self.client.generate(model=model, prompt=prompt, stream=True)
            for chunk in response:
                # Skip chunks that carry no text field.
                if "response" in chunk:
                    yield chunk["response"]
        except Exception as e:
            logger.error("Streaming generation failed: %s", e)
            yield f"Error: {e}"

    def chat(
        self,
        messages: List[Dict[str, str]],
        model: Optional[str] = None,
        stream: bool = False,
    ) -> Dict[str, Any]:
        """Chat with the model using the messages format.

        Args:
            messages: Message dicts with ``'role'`` and ``'content'`` keys.
            model: Model name; defaults to ``self.model``.
            stream: Passed through to ``ollama.Client.chat``. When True,
                errors raised while the caller iterates the returned
                ``"response"`` are not caught here.

        Returns:
            ``{"success": True, "response": <library response>}`` on success,
            or ``{"success": False, "error": <message>, "response": None}``
            if the call raised.
        """
        model = model or self.model
        try:
            response = self.client.chat(
                model=model,
                messages=messages,
                stream=stream,
            )
        except Exception as e:
            logger.error("Chat failed: %s", e)
            return {"success": False, "error": str(e), "response": None}
        return {"success": True, "response": response}
def get_ollama_client(
    host: Optional[str] = None,
    model: Optional[str] = None,
) -> OllamaClient:
    """Create a new OllamaClient.

    A fresh instance is built on every call; nothing is cached here.

    Args:
        host: Ollama server URL, or None to use the configured default.
        model: Model name, or None to use the configured default.

    Returns:
        A newly constructed OllamaClient.
    """
    settings = {"host": host, "model": model}
    return OllamaClient(**settings)