diff --git a/src/promptforge/providers/ollama.py b/src/promptforge/providers/ollama.py
index 1c04daf..2e4c500 100644
--- a/src/promptforge/providers/ollama.py
+++ b/src/promptforge/providers/ollama.py
@@ -1,6 +1,9 @@
-import asyncio
+"""Ollama provider implementation for local models."""
+
+import json
 import time
 from typing import Any, AsyncIterator, Dict, Optional
+
 import httpx
 
 from .base import ProviderBase, ProviderResponse
@@ -8,20 +11,35 @@ from ..core.exceptions import ProviderError
 
 
 class OllamaProvider(ProviderBase):
+    """Ollama local model provider."""
+
     def __init__(
         self,
+        api_key: Optional[str] = None,
         model: str = "llama2",
         temperature: float = 0.7,
         base_url: str = "http://localhost:11434",
         **kwargs,
    ):
-        super().__init__(None, model, temperature, **kwargs)
+        """Initialize Ollama provider."""
+        super().__init__(api_key, model, temperature, **kwargs)
         self.base_url = base_url.rstrip('/')
+        self._client: Optional[httpx.AsyncClient] = None
 
     @property
     def name(self) -> str:
         return "ollama"
 
+    def _get_client(self) -> httpx.AsyncClient:
+        """Get or create HTTP client."""
+        if self._client is None:
+            self._client = httpx.AsyncClient(timeout=120.0)
+        return self._client
+
+    def _get_api_url(self, endpoint: str) -> str:
+        """Get full URL for an endpoint."""
+        return f"{self.base_url}/{endpoint.lstrip('/')}"
+
     async def complete(
         self,
         prompt: str,
@@ -29,78 +47,133 @@ class OllamaProvider(ProviderBase):
         max_tokens: Optional[int] = None,
         **kwargs,
     ) -> ProviderResponse:
+        """Send completion request to Ollama."""
         start_time = time.time()
         try:
-            async with httpx.AsyncClient() as client:
-                payload = {
-                    "model": self.model,
-                    "prompt": prompt,
-                    "stream": False,
-                    "options": {
-                        "temperature": self.temperature,
-                    }
-                }
-                if max_tokens:
-                    payload["options"]["num_predict"] = max_tokens
+            client = self._get_client()
 
-                response = await client.post(
-                    f"{self.base_url}/api/generate",
-                    json=payload,
-                    timeout=120.0
-                )
+            messages = []
+            if system_prompt:
+                messages.append({"role": "system", "content": system_prompt})
+            messages.append({"role": "user", "content": prompt})
 
-                response.raise_for_status()
-                data = response.json()
+            payload: Dict[str, Any] = {
+                "model": self.model,
+                "messages": messages,
+                "stream": False,
+                "options": {
+                    "temperature": self.temperature,
+                },
+            }
 
-                latency_ms = (time.time() - start_time) * 1000
+            if max_tokens:
+                payload["options"]["num_predict"] = max_tokens
 
-                return ProviderResponse(
-                    content=data.get("response", ""),
-                    model=self.model,
-                    provider=self.name,
-                    latency_ms=latency_ms,
-                )
-        except httpx.HTTPStatusError as e:
-            raise ProviderError(f"Ollama HTTP error: {e}")
-        except httpx.RequestError as e:
+            response = await client.post(
+                self._get_api_url("/api/chat"),
+                json=payload,
+            )
+            response.raise_for_status()
+            data = response.json()
+
+            latency_ms = (time.time() - start_time) * 1000
+
+            # /api/chat returns the reply as a single string in message.content
+            # (the old /api/generate endpoint used the top-level "response" field).
+            content = data.get("message", {}).get("content", "")
+            if not isinstance(content, str):
+                # Defensive fallback in case the payload shape ever changes.
+                content = str(content or "")
+
+            return ProviderResponse(
+                content=content,
+                model=self.model,
+                provider=self.name,
+                usage={
+                    "prompt_tokens": data.get("prompt_eval_count", 0),
+                    "completion_tokens": data.get("eval_count", 0),
+                    "total_tokens": data.get("prompt_eval_count", 0) + data.get("eval_count", 0),
+                },
+                latency_ms=latency_ms,
+                metadata={
+                    "done": data.get("done", False),
+                },
+            )
+        except httpx.HTTPError as e:
             raise ProviderError(f"Ollama connection error: {e}")
 
-    async def stream_complete(
+    async def stream_complete(  # type: ignore[override]
         self,
         prompt: str,
         system_prompt: Optional[str] = None,
         max_tokens: Optional[int] = None,
         **kwargs,
     ) -> AsyncIterator[str]:
+        """Stream completion from Ollama."""
         try:
-            async with httpx.AsyncClient() as client:
-                payload = {
-                    "model": self.model,
-                    "prompt": prompt,
-                    "stream": True,
-                    "options": {
-                        "temperature": self.temperature,
-                    }
-                }
-                if max_tokens:
-                    payload["options"]["num_predict"] = max_tokens
+            client = self._get_client()
 
-                async with client.stream(
-                    "POST",
-                    f"{self.base_url}/api/generate",
-                    json=payload,
-                    timeout=120.0
-                ) as response:
-                    async for line in response.aiter_lines():
-                        import json
+            messages = []
+            if system_prompt:
+                messages.append({"role": "system", "content": system_prompt})
+            messages.append({"role": "user", "content": prompt})
+
+            payload: Dict[str, Any] = {
+                "model": self.model,
+                "messages": messages,
+                "stream": True,
+                "options": {
+                    "temperature": self.temperature,
+                },
+            }
+
+            if max_tokens:
+                payload["options"]["num_predict"] = max_tokens
+
+            async with client.stream(
+                "POST",
+                self._get_api_url("/api/chat"),
+                json=payload,
+            ) as response:
+                async for line in response.aiter_lines():
+                    if line:
                         data = json.loads(line)
-                        if "response" in data:
-                            yield data["response"]
-        except httpx.HTTPStatusError as e:
-            raise ProviderError(f"Ollama HTTP error: {e}")
-        except httpx.RequestError as e:
+                        if "message" in data:
+                            content = data["message"].get("content", "")
+                            if content:
+                                yield content
+        except httpx.HTTPError as e:
             raise ProviderError(f"Ollama connection error: {e}")
 
+    async def pull_model(self, model: Optional[str] = None) -> bool:
+        """Pull a model from Ollama registry."""
+        try:
+            client = self._get_client()
+            target_model = model or self.model
+
+            async with client.stream(
+                "POST",
+                self._get_api_url("/api/pull"),
+                json={"name": target_model, "stream": False},
+            ) as response:
+                response.raise_for_status()
+                return True
+        except httpx.HTTPError:
+            return False
+
     def validate_api_key(self) -> bool:
-        return True
\ No newline at end of file
+        """Ollama doesn't use API keys, always returns True."""
+        return True
+
+    def list_models(self) -> list[str]:
+        """List available Ollama models."""
+        return [
+            "llama2",
+            "llama2-uncensored",
+            "mistral",
+            "mixtral",
+            "codellama",
+            "deepseek-coder",
+            "neural-chat",
+        ]
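
A minimal usage sketch for the reworked provider (not part of the diff; it assumes a local Ollama server on the default port, the src-layout import path shown in the diff header, and that ProviderResponse exposes the content and usage fields constructed above):

# usage_sketch.py: illustrative only, not part of this change
import asyncio

from promptforge.providers.ollama import OllamaProvider


async def main() -> None:
    provider = OllamaProvider(model="mistral", temperature=0.2)

    # One-shot chat completion via /api/chat.
    response = await provider.complete(
        "Explain what a context window is.",
        system_prompt="Answer in one sentence.",
        max_tokens=64,
    )
    print(response.content)
    print(response.usage)

    # Streaming variant: chunks are yielded as Ollama produces them.
    async for chunk in provider.stream_complete("Count from 1 to 5."):
        print(chunk, end="", flush=True)


asyncio.run(main())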