diff --git a/src/codeguard/llm/client.py b/src/codeguard/llm/client.py
new file mode 100644
index 0000000..0eea84d
--- /dev/null
+++ b/src/codeguard/llm/client.py
@@ -0,0 +1,175 @@
+"""LLM client for CodeGuard."""
+
+import json
+import logging
+from abc import ABC, abstractmethod
+from typing import Any, Optional
+import urllib.request
+import urllib.error
+
+logger = logging.getLogger(__name__)
+
+
+class LLMClient(ABC):
+    """Abstract interface for chat-based LLM backends."""
+
+    @abstractmethod
+    def chat(self, messages: list[dict[str, str]], **kwargs: Any) -> str:
+        pass
+
+    @abstractmethod
+    def health_check(self) -> bool:
+        pass
+
+    @abstractmethod
+    def list_models(self) -> list[str]:
+        pass
+
+
+class OllamaClient(LLMClient):
+    """Client for a local Ollama server."""
+
+    def __init__(
+        self,
+        base_url: str = "http://localhost:11434",
+        timeout: int = 120,
+        max_retries: int = 3,
+    ):
+        self.base_url = base_url.rstrip("/")
+        self.timeout = timeout
+        self.max_retries = max_retries
+
+    def _make_request(
+        self,
+        endpoint: str,
+        data: Optional[dict] = None,
+        method: str = "POST",
+    ) -> dict:
+        url = f"{self.base_url}/{endpoint}"
+        headers = {"Content-Type": "application/json"}
+
+        if data:
+            body = json.dumps(data).encode("utf-8")
+        else:
+            body = None
+
+        # Retry transient failures; raise on the final attempt.
+        for attempt in range(self.max_retries):
+            try:
+                req = urllib.request.Request(
+                    url, data=body, headers=headers, method=method
+                )
+                with urllib.request.urlopen(req, timeout=self.timeout) as response:
+                    return json.loads(response.read().decode("utf-8"))
+            except urllib.error.HTTPError as e:
+                # Ollama returns 404 when the requested model is not installed.
+                if e.code == 404:
+                    model_name: str = "unknown"
+                    if data is not None:
+                        model_name = data.get("model", "unknown")
+                    raise ModelNotFoundError(f"Model not found: {model_name}")
+                if attempt == self.max_retries - 1:
+                    raise ConnectionError(f"HTTP error: {e.code}")
+            except urllib.error.URLError as e:
+                if attempt == self.max_retries - 1:
+                    raise ConnectionError(f"Connection error: {e.reason}")
+        return {}
+
+    def chat(
+        self,
+        messages: list[dict[str, str]],
+        **kwargs: Any,
+    ) -> str:
+        model = kwargs.get("model", "codellama")
+        stream = kwargs.get("stream", False)
+        data = {
+            "model": model,
+            "messages": messages,
+            "stream": stream,
+            "options": {
+                "temperature": 0.1,
+                "top_k": 10,
+                "top_p": 0.9,
+            },
+        }
+        result = self._make_request("api/chat", data)
+        return result.get("message", {}).get("content", "")
+
+    def health_check(self) -> bool:
+        try:
+            self._make_request("api/tags", method="GET")
+            return True
+        except Exception:
+            return False
+
+    def list_models(self) -> list[str]:
+        try:
+            result = self._make_request("api/tags", method="GET")
+            models = result.get("models", [])
+            return [m.get("name", "unknown") for m in models]
+        except Exception:
+            return []
+
+    def pull_model(self, model: str) -> bool:
+        try:
+            # Disable streaming so the pull endpoint returns a single JSON object.
+            data = {"name": model, "stream": False}
+            self._make_request("api/pull", data)
+            return True
+        except Exception:
+            return False
+
+
+class LlamaCppClient(LLMClient):
+    """Client for a llama.cpp server exposing the OpenAI-compatible API."""
+
+    def __init__(
+        self,
+        base_url: str = "http://localhost:8080",
+        timeout: int = 120,
+    ):
+        self.base_url = base_url.rstrip("/")
+        self.timeout = timeout
+
+    def chat(self, messages: list[dict[str, str]], **kwargs: Any) -> str:
+        data = {"messages": messages, "stream": False}
+        result = self._make_request("v1/chat/completions", data)
+        return result.get("choices", [{}])[0].get("message", {}).get("content", "")
+
+    def health_check(self) -> bool:
+        try:
+            self._make_request("health", method="GET")
+            return True
+        except Exception:
+            return False
+
+    def list_models(self) -> list[str]:
+        return []
+
+    def _make_request(
+        self,
+        endpoint: str,
+        data: Optional[dict] = None,
+        method: str = "POST",
+    ) -> dict:
+        url = f"{self.base_url}/{endpoint}"
+        headers = {"Content-Type": "application/json"}
+
+        if data:
+            body = json.dumps(data).encode("utf-8")
+        else:
+            body = None
+
+        req = urllib.request.Request(
+            url, data=body, headers=headers, method=method
+        )
+        with urllib.request.urlopen(req, timeout=self.timeout) as response:
+            return json.loads(response.read().decode("utf-8"))
+
+
+class ModelNotFoundError(Exception):
+    """Raised when the requested model is not available on the server."""
+
+
+class ConnectionError(Exception):
+    """Raised when the LLM server cannot be reached after retries."""