diff --git a/shellgenius/ollama_client.py b/shellgenius/ollama_client.py new file mode 100644 index 0000000..9d693b8 --- /dev/null +++ b/shellgenius/ollama_client.py @@ -0,0 +1,173 @@ +"""Ollama client wrapper for ShellGenius.""" + +import json +import logging +from typing import Any, Dict, Generator, List, Optional + +import ollama + +from shellgenius.config import get_config + +logger = logging.getLogger(__name__) + + +class OllamaClient: + """Client for interacting with Ollama API.""" + + def __init__(self, host: Optional[str] = None, model: Optional[str] = None): + """Initialize Ollama client. + + Args: + host: Ollama server URL + model: Model name to use + """ + config = get_config() + self.host = host or config.ollama_host + self.model = model or config.ollama_model + self._client: Optional[ollama.Client] = None + + @property + def client(self) -> ollama.Client: + """Get or create Ollama client.""" + if self._client is None: + self._client = ollama.Client(host=self.host) + return self._client + + def is_available(self) -> bool: + """Check if Ollama is available. + + Returns: + True if Ollama is running and accessible + """ + try: + self.list_models() + return True + except Exception as e: + logger.error(f"Ollama not available: {e}") + return False + + def list_models(self) -> List[str]: + """List available models. + + Returns: + List of model names + """ + try: + response = self.client.list() + return [m["name"] for m in response.get("models", [])] + except Exception as e: + logger.error(f"Failed to list models: {e}") + return [] + + def pull_model(self, model: Optional[str] = None) -> bool: + """Pull a model from Ollama. + + Args: + model: Model name to pull + + Returns: + True if successful + """ + model = model or self.model + try: + self.client.pull(model) + return True + except Exception as e: + logger.error(f"Failed to pull model {model}: {e}") + return False + + def generate( + self, + prompt: str, + model: Optional[str] = None, + stream: bool = False, + **kwargs, + ) -> Dict[str, Any]: + """Generate response from model. + + Args: + prompt: Input prompt + model: Model name (uses default if not specified) + stream: Whether to stream response + **kwargs: Additional arguments for Ollama + + Returns: + Response dictionary + """ + model = model or self.model + try: + response = self.client.generate( + model=model, + prompt=prompt, + stream=stream, + **kwargs, + ) + return {"success": True, "response": response} + except Exception as e: + logger.error(f"Generation failed: {e}") + return {"success": False, "error": str(e), "response": None} + + def generate_stream( + self, prompt: str, model: Optional[str] = None + ) -> Generator[str, None, None]: + """Stream generation response. + + Args: + prompt: Input prompt + model: Model name + + Yields: + Chunks of generated text + """ + model = model or self.model + try: + response = self.client.generate(model=model, prompt=prompt, stream=True) + for chunk in response: + if "response" in chunk: + yield chunk["response"] + except Exception as e: + logger.error(f"Streaming generation failed: {e}") + yield f"Error: {e}" + + def chat( + self, + messages: List[Dict[str, str]], + model: Optional[str] = None, + stream: bool = False, + ) -> Dict[str, Any]: + """Chat with model using messages format. + + Args: + messages: List of message dictionaries with 'role' and 'content' + model: Model name + stream: Whether to stream response + + Returns: + Response dictionary + """ + model = model or self.model + try: + response = self.client.chat( + model=model, + messages=messages, + stream=stream, + ) + return {"success": True, "response": response} + except Exception as e: + logger.error(f"Chat failed: {e}") + return {"success": False, "error": str(e), "response": None} + + +def get_ollama_client( + host: Optional[str] = None, model: Optional[str] = None +) -> OllamaClient: + """Get Ollama client instance. + + Args: + host: Ollama server URL + model: Model name + + Returns: + OllamaClient instance + """ + return OllamaClient(host=host, model=model)