"""Ollama provider: sends chat requests to an Ollama server over HTTP."""

import json
import time
from typing import Any, AsyncIterator, Dict, Optional

import httpx

from .base import ProviderBase, ProviderResponse
from ..core.exceptions import ProviderError


class OllamaProvider(ProviderBase):
    """Ollama local model provider.

    Requests go to the ``/api/chat`` endpoint under ``base_url``. A single
    ``httpx.AsyncClient`` is created on first use and reused afterwards;
    call :meth:`aclose` to release it.
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        model: str = "llama2",
        temperature: float = 0.7,
        base_url: str = "http://localhost:11434",
        **kwargs,
    ):
        """Initialize Ollama provider.

        Args:
            api_key: Forwarded to the base class; this provider never reads it.
            model: Model name sent in every request payload.
            temperature: Sampling temperature sent under ``options``.
            base_url: Root URL of the Ollama server; a trailing ``/`` is
                stripped so endpoint joining never produces ``//``.
            **kwargs: Forwarded unchanged to the base class.
        """
        # NOTE(review): self.model / self.temperature are read below and are
        # presumably set by ProviderBase.__init__ -- confirm against base.py.
        super().__init__(api_key, model, temperature, **kwargs)
        self.base_url = base_url.rstrip("/")
        self._client: Optional[httpx.AsyncClient] = None

    @property
    def name(self) -> str:
        """Provider identifier used in responses."""
        return "ollama"

    def _get_client(self) -> httpx.AsyncClient:
        """Get or create HTTP client.

        A client that was closed (for example via :meth:`aclose`) is replaced
        by a fresh one, so the provider stays usable afterwards.
        """
        if self._client is None or self._client.is_closed:
            self._client = httpx.AsyncClient(timeout=120.0)
        return self._client

    async def aclose(self) -> None:
        """Close the underlying HTTP client, if one was created."""
        # Detach first so a failing close cannot leave a half-closed client
        # cached on the instance.
        client, self._client = self._client, None
        if client is not None:
            await client.aclose()

    def _get_api_url(self, endpoint: str) -> str:
        """Get full URL for an endpoint."""
        return f"{self.base_url}/{endpoint.lstrip('/')}"

    def _build_payload(
        self,
        prompt: str,
        system_prompt: Optional[str],
        max_tokens: Optional[int],
        *,
        stream: bool,
    ) -> Dict[str, Any]:
        """Build the ``/api/chat`` request body shared by both completion modes."""
        messages: list[Dict[str, str]] = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})

        options: Dict[str, Any] = {"temperature": self.temperature}
        # A falsy max_tokens (None or 0) means "do not send a limit".
        if max_tokens:
            options["num_predict"] = max_tokens

        return {
            "model": self.model,
            "messages": messages,
            "stream": stream,
            "options": options,
        }

    @staticmethod
    def _extract_content(data: Dict[str, Any]) -> str:
        """Return the assistant text from a non-streaming chat response."""
        message = data.get("message") or {}
        content = message.get("content") or ""
        if isinstance(content, str):
            return content

        # Fallback for list-shaped content: join plain strings and the
        # "content" field of dict parts; anything else is ignored.
        parts = []
        for part in content:
            if isinstance(part, str):
                parts.append(part)
            elif isinstance(part, dict):
                parts.append(part.get("content", ""))
        return "".join(parts)

    async def complete(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        max_tokens: Optional[int] = None,
        **kwargs,
    ) -> ProviderResponse:
        """Send completion request to Ollama.

        Args:
            prompt: User message text.
            system_prompt: Optional system message placed before the prompt.
            max_tokens: Optional generation limit, sent as ``num_predict``.
            **kwargs: Accepted for interface compatibility; not used.

        Returns:
            A ``ProviderResponse`` carrying the generated text, token counts
            read from ``prompt_eval_count`` / ``eval_count``, and the
            request latency in milliseconds.

        Raises:
            ProviderError: On transport failure, a non-2xx status, or a
                response body that is not valid JSON.
        """
        # perf_counter is monotonic, so latency cannot go negative or jump
        # when the wall clock is adjusted.
        started = time.perf_counter()
        payload = self._build_payload(
            prompt, system_prompt, max_tokens, stream=False
        )

        try:
            response = await self._get_client().post(
                self._get_api_url("/api/chat"),
                json=payload,
            )
            response.raise_for_status()
        except httpx.HTTPError as e:
            raise ProviderError(f"Ollama connection error: {e}") from e

        try:
            data = response.json()
        except ValueError as e:
            raise ProviderError(f"Ollama returned invalid JSON: {e}") from e

        latency_ms = (time.perf_counter() - started) * 1000

        prompt_tokens = data.get("prompt_eval_count", 0)
        completion_tokens = data.get("eval_count", 0)

        return ProviderResponse(
            content=self._extract_content(data),
            model=self.model,
            provider=self.name,
            usage={
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens,
            },
            latency_ms=latency_ms,
            metadata={
                "done": data.get("done", False),
            },
        )

    async def stream_complete(  # type: ignore[override]
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        max_tokens: Optional[int] = None,
        **kwargs,
    ) -> AsyncIterator[str]:
        """Stream completion from Ollama.

        Yields each non-empty ``message.content`` fragment as it arrives.

        Args:
            prompt: User message text.
            system_prompt: Optional system message placed before the prompt.
            max_tokens: Optional generation limit, sent as ``num_predict``.
            **kwargs: Accepted for interface compatibility; not used.

        Raises:
            ProviderError: On transport failure, a non-2xx status, a stream
                line that is not valid JSON, or a stream line carrying an
                ``error`` field.
        """
        payload = self._build_payload(
            prompt, system_prompt, max_tokens, stream=True
        )

        try:
            async with self._get_client().stream(
                "POST",
                self._get_api_url("/api/chat"),
                json=payload,
            ) as response:
                # Without this an HTTP error status would look like an
                # empty, successful stream.
                response.raise_for_status()

                async for line in response.aiter_lines():
                    if not line:
                        continue

                    try:
                        data = json.loads(line)
                    except ValueError as e:
                        raise ProviderError(
                            f"Ollama returned invalid JSON: {e}"
                        ) from e

                    if not isinstance(data, dict):
                        continue
                    if "error" in data:
                        # The server reports mid-stream failures as a JSON
                        # object with an "error" field.
                        raise ProviderError(f"Ollama error: {data['error']}")

                    content = (data.get("message") or {}).get("content", "")
                    if content:
                        yield content
        except httpx.HTTPError as e:
            raise ProviderError(f"Ollama connection error: {e}") from e

    async def pull_model(self, model: Optional[str] = None) -> bool:
        """Pull a model from Ollama registry.

        Args:
            model: Model to pull; defaults to this provider's model.

        Returns:
            ``True`` when the server answers with a success status,
            ``False`` on any ``httpx.HTTPError``.
        """
        # NOTE(review): the shared client has a 120 s timeout, so a slow
        # pull may be reported as False -- confirm this is acceptable.
        target_model = model or self.model
        try:
            async with self._get_client().stream(
                "POST",
                self._get_api_url("/api/pull"),
                json={"name": target_model, "stream": False},
            ) as response:
                response.raise_for_status()
                return True
        except httpx.HTTPError:
            return False

    def validate_api_key(self) -> bool:
        """Ollama doesn't use API keys, always returns True."""
        return True

    def list_models(self) -> list[str]:
        """List available Ollama models.

        This is a fixed list of model names; the server is not queried.
        """
        return [
            "llama2",
            "llama2-uncensored",
            "mistral",
            "mixtral",
            "codellama",
            "deepseek-coder",
            "neural-chat",
        ]