Add source code files

This commit is contained in:
2026-02-01 02:55:41 +00:00
parent 8bbaf80a9b
commit 6b684e5699

166
src/codeguard/llm/client.py Normal file
View File

@@ -0,0 +1,166 @@
"""LLM client for CodeGuard."""
import json
import logging
from abc import ABC, abstractmethod
from typing import Any, Optional
import urllib.request
import urllib.error
logger = logging.getLogger(__name__)
class LLMClient(ABC):
@abstractmethod
def chat(self, messages: list[dict[str, str]], **kwargs: Any) -> str:
pass
@abstractmethod
def health_check(self) -> bool:
pass
@abstractmethod
def list_models(self) -> list[str]:
pass
class OllamaClient(LLMClient):
def __init__(
self,
base_url: str = "http://localhost:11434",
timeout: int = 120,
max_retries: int = 3,
):
self.base_url = base_url.rstrip("/")
self.timeout = timeout
self.max_retries = max_retries
def _make_request(
self,
endpoint: str,
data: Optional[dict] = None,
method: str = "POST",
) -> dict:
url = f"{self.base_url}/{endpoint}"
headers = {"Content-Type": "application/json"}
if data:
body = json.dumps(data).encode("utf-8")
else:
body = None
for attempt in range(self.max_retries):
try:
req = urllib.request.Request(
url, data=body, headers=headers, method=method
)
with urllib.request.urlopen(req, timeout=self.timeout) as response:
return json.loads(response.read().decode("utf-8"))
except urllib.error.HTTPError as e:
if e.code == 404:
model_name: str = "unknown"
if data is not None:
model_name = data.get("model", "unknown")
raise ModelNotFoundError(f"Model not found: {model_name}")
if attempt == self.max_retries - 1:
raise ConnectionError(f"HTTP error: {e.code}")
except urllib.error.URLError as e:
if attempt == self.max_retries - 1:
raise ConnectionError(f"Connection error: {e.reason}")
return {}
def chat(
self,
messages: list[dict[str, str]],
**kwargs: Any,
) -> str:
model = kwargs.get("model", "codellama")
stream = kwargs.get("stream", False)
data = {
"model": model,
"messages": messages,
"stream": stream,
"options": {
"temperature": 0.1,
"top_k": 10,
"top_p": 0.9,
},
}
result = self._make_request("api/chat", data)
return result.get("message", {}).get("content", "")
def health_check(self) -> bool:
try:
self._make_request("api/tags", method="GET")
return True
except Exception:
return False
def list_models(self) -> list[str]:
try:
result = self._make_request("api/tags", method="GET")
models = result.get("models", [])
return [m.get("name", "unknown") for m in models]
except Exception:
return []
def pull_model(self, model: str) -> bool:
try:
data = {"name": model}
self._make_request("api/pull", data)
return True
except Exception:
return False
class LlamaCppClient:
def __init__(
self,
base_url: str = "http://localhost:8080",
timeout: int = 120,
):
self.base_url = base_url.rstrip("/")
self.timeout = timeout
def chat(self, messages: list[dict[str, str]], **kwargs) -> str:
data = {"messages": messages, "stream": False}
result = self._make_request("v1/chat/completions", data)
return result.get("choices", [{}])[0].get("message", {}).get("content", "")
def health_check(self) -> bool:
try:
self._make_request("health", method="GET")
return True
except Exception:
return False
def list_models(self) -> list[str]:
return []
def _make_request(
self,
endpoint: str,
data: Optional[dict] = None,
method: str = "POST",
) -> dict:
url = f"{self.base_url}/{endpoint}"
headers = {"Content-Type": "application/json"}
if data:
body = json.dumps(data).encode("utf-8")
else:
body = None
req = urllib.request.Request(
url, data=body, headers=headers, method=method
)
with urllib.request.urlopen(req, timeout=self.timeout) as response:
return json.loads(response.read().decode("utf-8"))
class ModelNotFoundError(Exception):
pass
class ConnectionError(Exception):
pass