Add services, prompts, and utils modules
local_code_assistant/services/ollama.py (new file, 295 lines added)
@@ -0,0 +1,295 @@
"""Ollama service for Local Code Assistant."""

import json
from collections.abc import Generator
from typing import Any, Optional

import requests
from requests.exceptions import ConnectionError, Timeout

from local_code_assistant.services.config import ConfigService


class OllamaServiceError(Exception):
    """Base exception for Ollama service errors."""
    pass


class OllamaConnectionError(OllamaServiceError):
    """Exception raised when connection to Ollama fails."""
    pass


class OllamaModelError(OllamaServiceError):
    """Exception raised when model operation fails."""
    pass


class OllamaService:
    """Service for interacting with Ollama API."""

    def __init__(self, config: ConfigService):
        """Initialize Ollama service.

        Args:
            config: Configuration service instance.
        """
        self.config = config
        self.base_url = config.ollama_base_url
        self.timeout = config.ollama_timeout

    def _make_request(
        self,
        endpoint: str,
        method: str = "GET",
        data: Optional[dict[str, Any]] = None
    ) -> dict[str, Any]:
        """Make HTTP request to Ollama API.

        Args:
            endpoint: API endpoint.
            method: HTTP method.
            data: Request data.

        Returns:
            Response data as dictionary.

        Raises:
            OllamaConnectionError: If connection fails.
            OllamaModelError: If API returns error.
        """
        url = f"{self.base_url}/{endpoint}"

        try:
            if method == "GET":
                response = requests.get(url, timeout=self.timeout)
            elif method == "POST":
                response = requests.post(url, json=data, timeout=self.timeout)
            else:
                raise ValueError(f"Unsupported HTTP method: {method}")

            response.raise_for_status()
            return response.json()

        except ConnectionError as e:
            raise OllamaConnectionError(
                f"Failed to connect to Ollama at {self.base_url}. "
                "Make sure Ollama is running."
            ) from e
        except Timeout as e:
            raise OllamaServiceError(f"Request timed out after {self.timeout}s") from e
        except requests.exceptions.HTTPError as e:
            error_msg = f"API request failed: {e.response.text}"
            try:
                error_data = e.response.json()
                if "error" in error_data:
                    error_msg = f"Ollama error: {error_data['error']}"
            except Exception:
                pass
            raise OllamaModelError(error_msg) from e
        except Exception as e:
            raise OllamaServiceError(f"Unexpected error: {str(e)}") from e

    def check_connection(self) -> bool:
        """Check if Ollama is running and accessible.

        Returns:
            True if connection successful, False otherwise.
        """
        try:
            self._make_request("api/tags")
            return True
        except Exception:
            return False

    def list_models(self) -> list[str]:
        """List available models.

        Returns:
            List of model names.
        """
        try:
            response = self._make_request("api/tags")
            models = response.get("models", [])
            return [model["name"] for model in models]
        except Exception:
            return []

    def generate(
        self,
        prompt: str,
        model: Optional[str] = None,
        stream: Optional[bool] = None,
        system: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None
    ) -> str:
        """Generate response from model.

        Args:
            prompt: User prompt.
            model: Model to use. Defaults to config default.
            stream: Whether to stream response. Defaults to config setting.
            system: System prompt.
            temperature: Temperature for generation.
            max_tokens: Maximum tokens to generate.

        Returns:
            Generated response text.
        """
        model = model or self.config.ollama_model
        stream = stream if stream is not None else self.config.streaming

        data: dict[str, Any] = {
            "model": model,
            "prompt": prompt,
            "stream": stream,
            "options": {}
        }

        if system:
            data["system"] = system
        if temperature is not None:
            data["options"]["temperature"] = temperature
        if max_tokens is not None:
            data["options"]["num_predict"] = max_tokens

        if stream:
            response_text = ""
            for chunk in self._stream_generate(data):
                if "response" in chunk:
                    response_text += chunk["response"]
            return response_text
        else:
            response = self._make_request("api/generate", method="POST", data=data)
            return response.get("response", "")

    def _stream_generate(self, data: dict[str, Any]) -> Generator[dict[str, Any], None, None]:
        """Stream response from model.

        Args:
            data: Request data.

        Yields:
            Response chunks.
        """
        url = f"{self.base_url}/api/generate"
        try:
            response = requests.post(
                url,
                json=data,
                timeout=self.timeout,
                stream=True
            )
            response.raise_for_status()

            for line in response.iter_lines():
                if line:
                    chunk = json.loads(line.decode('utf-8'))
                    yield chunk
                    if chunk.get("done", False):
                        break

        except ConnectionError as e:
            raise OllamaConnectionError(
                f"Failed to connect to Ollama at {self.base_url}."
            ) from e
        except Timeout as e:
            raise OllamaServiceError(f"Streaming timed out after {self.timeout}s") from e
        except Exception as e:
            raise OllamaServiceError(f"Streaming error: {str(e)}") from e

    def chat(
        self,
        messages: list[dict[str, str]],
        model: Optional[str] = None,
        stream: Optional[bool] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None
    ) -> str:
        """Chat with model using message history.

        Args:
            messages: List of messages with 'role' and 'content'.
            model: Model to use. Defaults to config default.
            stream: Whether to stream response.
            temperature: Temperature for generation.
            max_tokens: Maximum tokens to generate.

        Returns:
            Generated response text.
        """
        model = model or self.config.ollama_model
        stream = stream if stream is not None else self.config.streaming

        data: dict[str, Any] = {
            "model": model,
            "messages": messages,
            "stream": stream,
            "options": {}
        }

        if temperature is not None:
            data["options"]["temperature"] = temperature
        if max_tokens is not None:
            data["options"]["num_predict"] = max_tokens

        if stream:
            response_text = ""
            for chunk in self._stream_chat(data):
                if "message" in chunk and "content" in chunk["message"]:
                    response_text += chunk["message"]["content"]
            return response_text
        else:
            response = self._make_request("api/chat", method="POST", data=data)
            # The chat endpoint returns the reply text under message.content,
            # matching the key used by the streaming branch above.
            return response.get("message", {}).get("content", "")

    def _stream_chat(self, data: dict[str, Any]) -> Generator[dict[str, Any], None, None]:
        """Stream chat response from model.

        Args:
            data: Request data.

        Yields:
            Response chunks.
        """
        url = f"{self.base_url}/api/chat"
        try:
            response = requests.post(
                url,
                json=data,
                timeout=self.timeout,
                stream=True
            )
            response.raise_for_status()

            for line in response.iter_lines():
                if line:
                    chunk = json.loads(line.decode('utf-8'))
                    yield chunk
                    if chunk.get("done", False):
                        break

        except ConnectionError as e:
            raise OllamaConnectionError(
                f"Failed to connect to Ollama at {self.base_url}."
            ) from e
        except Timeout as e:
            raise OllamaServiceError(f"Streaming timed out after {self.timeout}s") from e
        except Exception as e:
            raise OllamaServiceError(f"Streaming error: {str(e)}") from e

    def get_model_info(self, model: str) -> dict[str, Any]:
        """Get information about a specific model.

        Args:
            model: Model name.

        Returns:
            Model information.
        """
        try:
            response = self._make_request("api/show", method="POST", data={"name": model})
            return response
        except Exception:
            return {}
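
For reviewers, a minimal usage sketch (not part of the committed file). It assumes ConfigService can be constructed with no arguments and exposes the attributes referenced above (ollama_base_url, ollama_timeout, ollama_model, streaming), and that a local Ollama instance has at least one model pulled:

from local_code_assistant.services.config import ConfigService
from local_code_assistant.services.ollama import OllamaService

config = ConfigService()            # assumption: zero-argument construction
service = OllamaService(config)

if not service.check_connection():
    raise SystemExit("Ollama is not reachable; start it with `ollama serve`.")

print(service.list_models())        # e.g. ['llama3:latest', 'codellama:7b'] (illustrative)

# One-shot completion; takes the non-streaming branch when config.streaming is False
# and no explicit stream flag is passed.
answer = service.generate(
    "Write a one-line docstring for a function that reverses a string",
    temperature=0.2,
    max_tokens=128,
)
print(answer)

# Chat-style call with explicit message history.
reply = service.chat(
    [
        {"role": "system", "content": "You are a terse code reviewer."},
        {"role": "user", "content": "Review: def add(a, b): return a+b"},
    ]
)
print(reply)

Note that both generate() and chat() return the full text even when streaming is enabled, because the streaming branches accumulate chunks before returning.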
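If a reviewer wants to see the wire format that _stream_generate and _stream_chat parse, a second sketch under the same assumptions (and note that _stream_generate is a private helper, used here only for illustration):

from local_code_assistant.services.config import ConfigService
from local_code_assistant.services.ollama import OllamaService

config = ConfigService()            # assumption: zero-argument construction, as above
service = OllamaService(config)

# Each yielded chunk is one parsed NDJSON line from /api/generate, roughly
# {"model": "...", "response": "...", "done": false}; the final chunk has "done": true.
request = {
    "model": config.ollama_model,
    "prompt": "Count to three",
    "stream": True,
    "options": {},
}
for chunk in service._stream_generate(request):
    print(chunk.get("response", ""), end="", flush=True)
print()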