fix: resolve CI linting and type errors
Some checks failed
CI / test (push) Has been cancelled
CI / lint (push) Has been cancelled
CI / type-check (push) Has been cancelled

This commit is contained in:
2026-02-04 12:58:25 +00:00
parent 03ed9d92b2
commit 8090d3eeba

View File

@@ -1,6 +1,9 @@
import asyncio """Ollama provider implementation for local models."""
import json
import time import time
from typing import Any, AsyncIterator, Dict, Optional from typing import Any, AsyncIterator, Dict, Optional
import httpx import httpx
from .base import ProviderBase, ProviderResponse from .base import ProviderBase, ProviderResponse
@@ -8,20 +11,35 @@ from ..core.exceptions import ProviderError
class OllamaProvider(ProviderBase):
    """Ollama local model provider.

    Talks to a locally running Ollama server over its HTTP ``/api/chat``
    endpoint. Ollama does not require an API key; ``api_key`` is accepted
    only for signature parity with other providers.
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        model: str = "llama2",
        temperature: float = 0.7,
        base_url: str = "http://localhost:11434",
        **kwargs,
    ):
        """Initialize Ollama provider.

        Args:
            api_key: Unused by Ollama; kept for interface parity and
                forwarded to ``ProviderBase``.
            model: Name of the local model to run.
            temperature: Sampling temperature forwarded in ``options``.
            base_url: Base URL of the Ollama server.
            **kwargs: Extra options forwarded to ``ProviderBase``.
        """
        super().__init__(api_key, model, temperature, **kwargs)
        # Normalize so URL joining in _get_api_url never doubles the slash.
        self.base_url = base_url.rstrip('/')
        # Lazily created, reused HTTP client (see _get_client).
        self._client: Optional[httpx.AsyncClient] = None

    @property
    def name(self) -> str:
        """Provider identifier reported in responses."""
        return "ollama"

    def _get_client(self) -> httpx.AsyncClient:
        """Get or lazily create the shared async HTTP client."""
        if self._client is None:
            # Local generation can be slow; allow up to two minutes.
            self._client = httpx.AsyncClient(timeout=120.0)
        return self._client

    def _get_api_url(self, endpoint: str) -> str:
        """Build the full URL for *endpoint* relative to ``base_url``."""
        return f"{self.base_url}/{endpoint.lstrip('/')}"

    def _build_payload(
        self,
        prompt: str,
        system_prompt: Optional[str],
        max_tokens: Optional[int],
        stream: bool,
    ) -> Dict[str, Any]:
        """Assemble the ``/api/chat`` request body (shared by both paths)."""
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})

        payload: Dict[str, Any] = {
            "model": self.model,
            "messages": messages,
            "stream": stream,
            "options": {
                "temperature": self.temperature,
            },
        }
        if max_tokens:
            payload["options"]["num_predict"] = max_tokens
        return payload

    async def complete(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        max_tokens: Optional[int] = None,
        **kwargs,
    ) -> ProviderResponse:
        """Send a chat completion request to Ollama.

        Args:
            prompt: User message content.
            system_prompt: Optional system message prepended to the chat.
            max_tokens: Optional cap, mapped to Ollama's ``num_predict``.

        Returns:
            ProviderResponse with content, token usage, and latency.

        Raises:
            ProviderError: On any HTTP transport or status error.
        """
        start_time = time.time()
        try:
            client = self._get_client()
            payload = self._build_payload(
                prompt, system_prompt, max_tokens, stream=False
            )

            response = await client.post(
                self._get_api_url("/api/chat"),
                json=payload,
            )
            response.raise_for_status()
            data = response.json()

            latency_ms = (time.time() - start_time) * 1000

            # /api/chat returns {"message": {"role": ..., "content": "..."}}.
            # BUG FIX: the previous code iterated the content *string*
            # character by character (with a dead isinstance-dict branch
            # and quadratic += concatenation); read it directly instead.
            content = data.get("message", {}).get("content", "") or ""

            prompt_tokens = data.get("prompt_eval_count", 0)
            completion_tokens = data.get("eval_count", 0)
            return ProviderResponse(
                content=content,
                model=self.model,
                provider=self.name,
                usage={
                    "prompt_tokens": prompt_tokens,
                    "completion_tokens": completion_tokens,
                    "total_tokens": prompt_tokens + completion_tokens,
                },
                latency_ms=latency_ms,
                metadata={
                    "done": data.get("done", False),
                },
            )
        except httpx.HTTPError as e:
            raise ProviderError(f"Ollama connection error: {e}")

    async def stream_complete(  # type: ignore[override]
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        max_tokens: Optional[int] = None,
        **kwargs,
    ) -> AsyncIterator[str]:
        """Stream completion chunks from Ollama.

        Yields:
            Non-empty content fragments as they arrive, one per
            newline-delimited JSON object in the response stream.

        Raises:
            ProviderError: On any HTTP transport or status error.
        """
        try:
            client = self._get_client()
            payload = self._build_payload(
                prompt, system_prompt, max_tokens, stream=True
            )

            async with client.stream(
                "POST",
                self._get_api_url("/api/chat"),
                json=payload,
            ) as response:
                # FIX: without this, HTTP error statuses were silently
                # fed through the JSON-lines parser instead of raising.
                response.raise_for_status()
                async for line in response.aiter_lines():
                    if line:
                        data = json.loads(line)
                        if "message" in data:
                            content = data["message"].get("content", "")
                            if content:
                                yield content
        except httpx.HTTPError as e:
            raise ProviderError(f"Ollama connection error: {e}")

    async def pull_model(self, model: Optional[str] = None) -> bool:
        """Pull a model from the Ollama registry (best effort).

        Args:
            model: Model name to pull; defaults to ``self.model``.

        Returns:
            True on success, False on any HTTP failure.
        """
        try:
            client = self._get_client()
            target_model = model or self.model
            # FIX: the payload requests "stream": False, so the server
            # replies with a single JSON status object — a plain POST is
            # correct; the old streaming request never consumed its body.
            response = await client.post(
                self._get_api_url("/api/pull"),
                json={"name": target_model, "stream": False},
            )
            response.raise_for_status()
            return True
        except httpx.HTTPError:
            return False

    def validate_api_key(self) -> bool:
        """Ollama doesn't use API keys, always returns True."""
        return True

    def list_models(self) -> list[str]:
        """List commonly available Ollama model names (static list)."""
        return [
            "llama2",
            "llama2-uncensored",
            "mistral",
            "mixtral",
            "codellama",
            "deepseek-coder",
            "neural-chat",
        ]