diff --git a/src/llm/ollama.py b/src/llm/ollama.py new file mode 100644 index 0000000..d745651 --- /dev/null +++ b/src/llm/ollama.py @@ -0,0 +1,72 @@ +"""Ollama LLM client implementation.""" + +from collections.abc import Iterator + +import requests + +from ..config import get_config +from .base import LLMClient + + +class OllamaClient(LLMClient): + """Client for Ollama API.""" + + def __init__(self, url: str = None): + config = get_config() + self.url = url or config.ollama_url + + def generate(self, prompt: str, model: str = None, **kwargs) -> str: + """Generate a response using Ollama API.""" + config = get_config() + model = model or config.default_model + payload = { + "model": model, + "prompt": prompt, + "stream": False, + } + payload.update(kwargs) + response = requests.post(f"{self.url}/api/generate", json=payload, timeout=120) + response.raise_for_status() + data = response.json() + return data.get("response", "") + + def stream_generate(self, prompt: str, model: str = None, **kwargs) -> Iterator[str]: + """Stream a response using Ollama API.""" + config = get_config() + model = model or config.default_model + payload = { + "model": model, + "prompt": prompt, + "stream": True, + } + payload.update(kwargs) + response = requests.post( + f"{self.url}/api/generate", + json=payload, + stream=True, + timeout=120 + ) + response.raise_for_status() + for line in response.iter_lines(): + if line: + data = requests.get(line.decode("utf-8")).json() + yield data.get("response", "") + + def test_connection(self) -> bool: + """Test if Ollama is available.""" + try: + response = requests.get(f"{self.url}/api/tags", timeout=5) + return response.status_code == 200 + except requests.exceptions.RequestException: + return False + + def get_available_models(self) -> list[str]: + """Get list of available models from Ollama.""" + try: + response = requests.get(f"{self.url}/api/tags", timeout=5) + if response.status_code == 200: + data = response.json() + return [m["name"] for m in data.get("models", [])] + except requests.exceptions.RequestException: + pass + return []