Add source code files
166
src/codeguard/llm/client.py
Normal file
@@ -0,0 +1,166 @@
"""LLM client for CodeGuard."""

import json
import logging
import urllib.error
import urllib.request
from abc import ABC, abstractmethod
from typing import Any, Optional

logger = logging.getLogger(__name__)


class LLMClient(ABC):
    """Abstract interface shared by all LLM backends."""

    @abstractmethod
    def chat(self, messages: list[dict[str, str]], **kwargs: Any) -> str:
        """Send a chat conversation and return the model's reply text."""

    @abstractmethod
    def health_check(self) -> bool:
        """Return True if the backend is reachable."""

    @abstractmethod
    def list_models(self) -> list[str]:
        """Return the names of the models available on the backend."""

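# For illustration only, a hypothetical backend (not part of this module)
# that satisfies the interface; any new backend plugs in the same way:
#
#     class EchoClient(LLMClient):
#         def chat(self, messages: list[dict[str, str]], **kwargs: Any) -> str:
#             return messages[-1]["content"] if messages else ""
#
#         def health_check(self) -> bool:
#             return True
#
#         def list_models(self) -> list[str]:
#             return ["echo"]

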
class OllamaClient(LLMClient):
    """Client for a local Ollama server."""

    def __init__(
        self,
        base_url: str = "http://localhost:11434",
        timeout: int = 120,
        max_retries: int = 3,
    ):
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout
        self.max_retries = max_retries

    def _make_request(
        self,
        endpoint: str,
        data: Optional[dict] = None,
        method: str = "POST",
    ) -> dict:
        """Send a JSON request to the server, retrying transient failures."""
        url = f"{self.base_url}/{endpoint}"
        headers = {"Content-Type": "application/json"}
        body = json.dumps(data).encode("utf-8") if data else None

        for attempt in range(self.max_retries):
            try:
                req = urllib.request.Request(
                    url, data=body, headers=headers, method=method
                )
                with urllib.request.urlopen(req, timeout=self.timeout) as response:
                    return json.loads(response.read().decode("utf-8"))
            except urllib.error.HTTPError as e:
                if e.code == 404:
                    # A 404 usually means the requested model has not been
                    # pulled; retrying will not help.
                    model_name = data.get("model", "unknown") if data else "unknown"
                    raise ModelNotFoundError(f"Model not found: {model_name}")
                if attempt == self.max_retries - 1:
                    raise ConnectionError(f"HTTP error: {e.code}")
                logger.warning(
                    "HTTP %s from %s, retry %d/%d",
                    e.code, url, attempt + 1, self.max_retries,
                )
            except urllib.error.URLError as e:
                if attempt == self.max_retries - 1:
                    raise ConnectionError(f"Connection error: {e.reason}")
                logger.warning(
                    "Connection to %s failed, retry %d/%d",
                    url, attempt + 1, self.max_retries,
                )
        return {}  # reached only if max_retries < 1

    def chat(
        self,
        messages: list[dict[str, str]],
        **kwargs: Any,
    ) -> str:
        model = kwargs.get("model", "codellama")
        stream = kwargs.get("stream", False)
        data = {
            "model": model,
            "messages": messages,
            "stream": stream,
            # Low temperature and tight sampling keep review output
            # close to deterministic.
            "options": {
                "temperature": 0.1,
                "top_k": 10,
                "top_p": 0.9,
            },
        }
        result = self._make_request("api/chat", data)
        return result.get("message", {}).get("content", "")

    def health_check(self) -> bool:
        try:
            self._make_request("api/tags", method="GET")
            return True
        except Exception:
            return False

    def list_models(self) -> list[str]:
        try:
            result = self._make_request("api/tags", method="GET")
            models = result.get("models", [])
            return [m.get("name", "unknown") for m in models]
        except Exception:
            return []

    def pull_model(self, model: str) -> bool:
        """Ask the server to download a model; returns False on failure."""
        try:
            data = {"name": model}
            self._make_request("api/pull", data)
            return True
        except Exception:
            return False

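# Typical use, assuming a local Ollama server (the model name is an example):
#
#     client = OllamaClient()
#     if not any(m.startswith("codellama") for m in client.list_models()):
#         client.pull_model("codellama")
#     reply = client.chat(
#         [{"role": "user", "content": "Review this function for bugs."}],
#         model="codellama",
#     )

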
class LlamaCppClient(LLMClient):
    """Client for a llama.cpp server's OpenAI-compatible chat endpoint."""

    def __init__(
        self,
        base_url: str = "http://localhost:8080",
        timeout: int = 120,
    ):
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout

    def chat(self, messages: list[dict[str, str]], **kwargs: Any) -> str:
        data = {"messages": messages, "stream": False}
        result = self._make_request("v1/chat/completions", data)
        return result.get("choices", [{}])[0].get("message", {}).get("content", "")

    def health_check(self) -> bool:
        try:
            self._make_request("health", method="GET")
            return True
        except Exception:
            return False

    def list_models(self) -> list[str]:
        # The server runs the single model it was started with, so there is
        # nothing to enumerate here.
        return []

    def _make_request(
        self,
        endpoint: str,
        data: Optional[dict] = None,
        method: str = "POST",
    ) -> dict:
        """Send a single JSON request (no retries, unlike OllamaClient)."""
        url = f"{self.base_url}/{endpoint}"
        headers = {"Content-Type": "application/json"}
        body = json.dumps(data).encode("utf-8") if data else None

        req = urllib.request.Request(
            url, data=body, headers=headers, method=method
        )
        with urllib.request.urlopen(req, timeout=self.timeout) as response:
            return json.loads(response.read().decode("utf-8"))

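# Typical use, assuming a llama.cpp server is already running on port 8080
# with a model loaded (the server decides the model, so none is passed):
#
#     client = LlamaCppClient()
#     if client.health_check():
#         reply = client.chat([{"role": "user", "content": "Hello"}])

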
class ModelNotFoundError(Exception):
    """Raised when the requested model is not available on the server."""


class ConnectionError(Exception):
    """Raised when the server cannot be reached after all retries.

    Note: within this module, this shadows Python's builtin ConnectionError.
    """
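

if __name__ == "__main__":
    # Manual smoke test; a sketch assuming an Ollama server is listening on
    # localhost:11434 and that "codellama" is an example model name.
    logging.basicConfig(level=logging.INFO)
    client = OllamaClient()
    if not client.health_check():
        print("Ollama server not reachable")
    else:
        print("available models:", client.list_models())
        print("reply:", client.chat(
            [{"role": "user", "content": "Reply with the single word: ok"}],
            model="codellama",
        ))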