Add source code files
This commit is contained in:
166
src/codeguard/llm/client.py
Normal file
166
src/codeguard/llm/client.py
Normal file
@@ -0,0 +1,166 @@
|
|||||||
|
"""LLM client for CodeGuard."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Any, Optional
|
||||||
|
import urllib.request
|
||||||
|
import urllib.error
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class LLMClient(ABC):
|
||||||
|
@abstractmethod
|
||||||
|
def chat(self, messages: list[dict[str, str]], **kwargs: Any) -> str:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def health_check(self) -> bool:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def list_models(self) -> list[str]:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class OllamaClient(LLMClient):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
base_url: str = "http://localhost:11434",
|
||||||
|
timeout: int = 120,
|
||||||
|
max_retries: int = 3,
|
||||||
|
):
|
||||||
|
self.base_url = base_url.rstrip("/")
|
||||||
|
self.timeout = timeout
|
||||||
|
self.max_retries = max_retries
|
||||||
|
|
||||||
|
def _make_request(
|
||||||
|
self,
|
||||||
|
endpoint: str,
|
||||||
|
data: Optional[dict] = None,
|
||||||
|
method: str = "POST",
|
||||||
|
) -> dict:
|
||||||
|
url = f"{self.base_url}/{endpoint}"
|
||||||
|
headers = {"Content-Type": "application/json"}
|
||||||
|
|
||||||
|
if data:
|
||||||
|
body = json.dumps(data).encode("utf-8")
|
||||||
|
else:
|
||||||
|
body = None
|
||||||
|
|
||||||
|
for attempt in range(self.max_retries):
|
||||||
|
try:
|
||||||
|
req = urllib.request.Request(
|
||||||
|
url, data=body, headers=headers, method=method
|
||||||
|
)
|
||||||
|
with urllib.request.urlopen(req, timeout=self.timeout) as response:
|
||||||
|
return json.loads(response.read().decode("utf-8"))
|
||||||
|
except urllib.error.HTTPError as e:
|
||||||
|
if e.code == 404:
|
||||||
|
model_name: str = "unknown"
|
||||||
|
if data is not None:
|
||||||
|
model_name = data.get("model", "unknown")
|
||||||
|
raise ModelNotFoundError(f"Model not found: {model_name}")
|
||||||
|
if attempt == self.max_retries - 1:
|
||||||
|
raise ConnectionError(f"HTTP error: {e.code}")
|
||||||
|
except urllib.error.URLError as e:
|
||||||
|
if attempt == self.max_retries - 1:
|
||||||
|
raise ConnectionError(f"Connection error: {e.reason}")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def chat(
|
||||||
|
self,
|
||||||
|
messages: list[dict[str, str]],
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> str:
|
||||||
|
model = kwargs.get("model", "codellama")
|
||||||
|
stream = kwargs.get("stream", False)
|
||||||
|
data = {
|
||||||
|
"model": model,
|
||||||
|
"messages": messages,
|
||||||
|
"stream": stream,
|
||||||
|
"options": {
|
||||||
|
"temperature": 0.1,
|
||||||
|
"top_k": 10,
|
||||||
|
"top_p": 0.9,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
result = self._make_request("api/chat", data)
|
||||||
|
return result.get("message", {}).get("content", "")
|
||||||
|
|
||||||
|
def health_check(self) -> bool:
|
||||||
|
try:
|
||||||
|
self._make_request("api/tags", method="GET")
|
||||||
|
return True
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def list_models(self) -> list[str]:
|
||||||
|
try:
|
||||||
|
result = self._make_request("api/tags", method="GET")
|
||||||
|
models = result.get("models", [])
|
||||||
|
return [m.get("name", "unknown") for m in models]
|
||||||
|
except Exception:
|
||||||
|
return []
|
||||||
|
|
||||||
|
def pull_model(self, model: str) -> bool:
|
||||||
|
try:
|
||||||
|
data = {"name": model}
|
||||||
|
self._make_request("api/pull", data)
|
||||||
|
return True
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class LlamaCppClient:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
base_url: str = "http://localhost:8080",
|
||||||
|
timeout: int = 120,
|
||||||
|
):
|
||||||
|
self.base_url = base_url.rstrip("/")
|
||||||
|
self.timeout = timeout
|
||||||
|
|
||||||
|
def chat(self, messages: list[dict[str, str]], **kwargs) -> str:
|
||||||
|
data = {"messages": messages, "stream": False}
|
||||||
|
result = self._make_request("v1/chat/completions", data)
|
||||||
|
return result.get("choices", [{}])[0].get("message", {}).get("content", "")
|
||||||
|
|
||||||
|
def health_check(self) -> bool:
|
||||||
|
try:
|
||||||
|
self._make_request("health", method="GET")
|
||||||
|
return True
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def list_models(self) -> list[str]:
|
||||||
|
return []
|
||||||
|
|
||||||
|
def _make_request(
|
||||||
|
self,
|
||||||
|
endpoint: str,
|
||||||
|
data: Optional[dict] = None,
|
||||||
|
method: str = "POST",
|
||||||
|
) -> dict:
|
||||||
|
url = f"{self.base_url}/{endpoint}"
|
||||||
|
headers = {"Content-Type": "application/json"}
|
||||||
|
|
||||||
|
if data:
|
||||||
|
body = json.dumps(data).encode("utf-8")
|
||||||
|
else:
|
||||||
|
body = None
|
||||||
|
|
||||||
|
req = urllib.request.Request(
|
||||||
|
url, data=body, headers=headers, method=method
|
||||||
|
)
|
||||||
|
with urllib.request.urlopen(req, timeout=self.timeout) as response:
|
||||||
|
return json.loads(response.read().decode("utf-8"))
|
||||||
|
|
||||||
|
|
||||||
|
class ModelNotFoundError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ConnectionError(Exception):
|
||||||
|
pass
|
||||||
Reference in New Issue
Block a user