diff --git a/local_code_assistant/services/ollama.py b/local_code_assistant/services/ollama.py
new file mode 100644
index 0000000..abeba8b
--- /dev/null
+++ b/local_code_assistant/services/ollama.py
@@ -0,0 +1,295 @@
+"""Ollama service for Local Code Assistant."""
+
+import json
+from collections.abc import Generator
+from typing import Any, Optional
+
+import requests
+from requests.exceptions import ConnectionError, Timeout
+
+from local_code_assistant.services.config import ConfigService
+
+
+class OllamaServiceError(Exception):
+    """Base exception for Ollama service errors."""
+    pass
+
+
+class OllamaConnectionError(OllamaServiceError):
+    """Exception raised when connection to Ollama fails."""
+    pass
+
+
+class OllamaModelError(OllamaServiceError):
+    """Exception raised when model operation fails."""
+    pass
+
+
+class OllamaService:
+    """Service for interacting with Ollama API."""
+
+    def __init__(self, config: ConfigService):
+        """Initialize Ollama service.
+
+        Args:
+            config: Configuration service instance.
+        """
+        self.config = config
+        self.base_url = config.ollama_base_url
+        self.timeout = config.ollama_timeout
+
+    def _make_request(
+        self,
+        endpoint: str,
+        method: str = "GET",
+        data: Optional[dict[str, Any]] = None
+    ) -> dict[str, Any]:
+        """Make HTTP request to Ollama API.
+
+        Args:
+            endpoint: API endpoint.
+            method: HTTP method.
+            data: Request data.
+
+        Returns:
+            Response data as dictionary.
+
+        Raises:
+            OllamaConnectionError: If connection fails.
+            OllamaModelError: If API returns error.
+        """
+        url = f"{self.base_url}/{endpoint}"
+
+        try:
+            if method == "GET":
+                response = requests.get(url, timeout=self.timeout)
+            elif method == "POST":
+                response = requests.post(url, json=data, timeout=self.timeout)
+            else:
+                raise ValueError(f"Unsupported HTTP method: {method}")
+
+            response.raise_for_status()
+            return response.json()
+
+        except ConnectionError as e:
+            raise OllamaConnectionError(
+                f"Failed to connect to Ollama at {self.base_url}. "
+                "Make sure Ollama is running."
+            ) from e
+        except Timeout as e:
+            raise OllamaServiceError(f"Request timed out after {self.timeout}s") from e
+        except requests.exceptions.HTTPError as e:
+            error_msg = f"API request failed: {e.response.text}"
+            try:
+                error_data = e.response.json()
+                if "error" in error_data:
+                    error_msg = f"Ollama error: {error_data['error']}"
+            except Exception:
+                pass
+            raise OllamaModelError(error_msg) from e
+        except Exception as e:
+            raise OllamaServiceError(f"Unexpected error: {str(e)}") from e
+
+    def check_connection(self) -> bool:
+        """Check if Ollama is running and accessible.
+
+        Returns:
+            True if connection successful, False otherwise.
+        """
+        try:
+            self._make_request("api/tags")
+            return True
+        except Exception:
+            return False
+
+    def list_models(self) -> list[str]:
+        """List available models.
+
+        Returns:
+            List of model names.
+        """
+        try:
+            response = self._make_request("api/tags")
+            models = response.get("models", [])
+            return [model["name"] for model in models]
+        except Exception:
+            return []
+
+    def generate(
+        self,
+        prompt: str,
+        model: Optional[str] = None,
+        stream: Optional[bool] = None,
+        system: Optional[str] = None,
+        temperature: Optional[float] = None,
+        max_tokens: Optional[int] = None
+    ) -> str:
+        """Generate response from model.
+
+        Args:
+            prompt: User prompt.
+            model: Model to use. Defaults to config default.
+            stream: Whether to stream response. Defaults to config setting.
+            system: System prompt.
+            temperature: Temperature for generation.
+            max_tokens: Maximum tokens to generate.
+
+        Returns:
+            Generated response text.
+        """
+        model = model or self.config.ollama_model
+        stream = stream if stream is not None else self.config.streaming
+
+        data: dict[str, Any] = {
+            "model": model,
+            "prompt": prompt,
+            "stream": stream,
+            "options": {}
+        }
+
+        if system:
+            data["system"] = system
+        if temperature is not None:
+            data["options"]["temperature"] = temperature
+        if max_tokens is not None:
+            data["options"]["num_predict"] = max_tokens
+
+        if stream:
+            response_text = ""
+            for chunk in self._stream_generate(data):
+                if "response" in chunk:
+                    response_text += chunk["response"]
+            return response_text
+        else:
+            response = self._make_request("api/generate", method="POST", data=data)
+            return response.get("response", "")
+
+    def _stream_generate(self, data: dict[str, Any]) -> Generator[dict[str, Any], None, None]:
+        """Stream response from model.
+
+        Args:
+            data: Request data.
+
+        Yields:
+            Response chunks.
+        """
+        url = f"{self.base_url}/api/generate"
+        try:
+            response = requests.post(
+                url,
+                json=data,
+                timeout=self.timeout,
+                stream=True
+            )
+            response.raise_for_status()
+
+            for line in response.iter_lines():
+                if line:
+                    chunk = json.loads(line.decode('utf-8'))
+                    yield chunk
+                    if chunk.get("done", False):
+                        break
+
+        except ConnectionError as e:
+            raise OllamaConnectionError(
+                f"Failed to connect to Ollama at {self.base_url}."
+            ) from e
+        except Timeout as e:
+            raise OllamaServiceError(f"Streaming timed out after {self.timeout}s") from e
+        except Exception as e:
+            raise OllamaServiceError(f"Streaming error: {str(e)}") from e
+
+    def chat(
+        self,
+        messages: list[dict[str, str]],
+        model: Optional[str] = None,
+        stream: Optional[bool] = None,
+        temperature: Optional[float] = None,
+        max_tokens: Optional[int] = None
+    ) -> str:
+        """Chat with model using message history.
+
+        Args:
+            messages: List of messages with 'role' and 'content'.
+            model: Model to use. Defaults to config default.
+            stream: Whether to stream response. Defaults to config setting.
+            temperature: Temperature for generation.
+            max_tokens: Maximum tokens to generate.
+
+        Returns:
+            Generated response text.
+        """
+        model = model or self.config.ollama_model
+        stream = stream if stream is not None else self.config.streaming
+
+        data: dict[str, Any] = {
+            "model": model,
+            "messages": messages,
+            "stream": stream,
+            "options": {}
+        }
+
+        if temperature is not None:
+            data["options"]["temperature"] = temperature
+        if max_tokens is not None:
+            data["options"]["num_predict"] = max_tokens
+
+        if stream:
+            response_text = ""
+            for chunk in self._stream_chat(data):
+                if "message" in chunk and "content" in chunk["message"]:
+                    response_text += chunk["message"]["content"]
+            return response_text
+        else:
+            response = self._make_request("api/chat", method="POST", data=data)
+            # The chat endpoint returns the reply under message.content, not message.response.
+            return response.get("message", {}).get("content", "")
+
+    def _stream_chat(self, data: dict[str, Any]) -> Generator[dict[str, Any], None, None]:
+        """Stream chat response from model.
+
+        Args:
+            data: Request data.
+
+        Yields:
+            Response chunks.
+        """
+        url = f"{self.base_url}/api/chat"
+        try:
+            response = requests.post(
+                url,
+                json=data,
+                timeout=self.timeout,
+                stream=True
+            )
+            response.raise_for_status()
+
+            for line in response.iter_lines():
+                if line:
+                    chunk = json.loads(line.decode('utf-8'))
+                    yield chunk
+                    if chunk.get("done", False):
+                        break
+
+        except ConnectionError as e:
+            raise OllamaConnectionError(
+                f"Failed to connect to Ollama at {self.base_url}."
+            ) from e
+        except Timeout as e:
+            raise OllamaServiceError(f"Streaming timed out after {self.timeout}s") from e
+        except Exception as e:
+            raise OllamaServiceError(f"Streaming error: {str(e)}") from e
+
+    def get_model_info(self, model: str) -> dict[str, Any]:
+        """Get information about a specific model.
+
+        Args:
+            model: Model name.
+
+        Returns:
+            Model information.
+        """
+        try:
+            response = self._make_request("api/show", method="POST", data={"name": model})
+            return response
+        except Exception:
+            return {}
\ No newline at end of file
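A minimal usage sketch of the new service (not part of the diff). It assumes a ConfigService that exposes ollama_base_url, ollama_timeout, ollama_model, and streaming, since the code above reads those attributes, plus a locally running Ollama daemon with some pulled model; the constructor call and model names below are illustrative only.

    # Hypothetical usage sketch; ConfigService() construction and model names are assumptions.
    from local_code_assistant.services.config import ConfigService
    from local_code_assistant.services.ollama import OllamaService

    config = ConfigService()            # assumed default constructor
    ollama = OllamaService(config)

    # Bail out early if the daemon is not reachable.
    if not ollama.check_connection():
        raise SystemExit("Ollama is not reachable; start it with `ollama serve`.")

    # List whatever models are installed locally, e.g. ["llama3:latest"].
    print(ollama.list_models())

    # Single-turn completion; model and stream fall back to the config defaults.
    print(ollama.generate("Write a haiku about Python", temperature=0.2))

    # Multi-turn chat using role/content messages.
    reply = ollama.chat([
        {"role": "system", "content": "You are a concise code assistant."},
        {"role": "user", "content": "Explain list comprehensions in one sentence."},
    ])
    print(reply)

Note that generate() and chat() return the fully assembled text even when streaming is enabled; callers that want incremental chunks would need the private _stream_* generators or a future public streaming API.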