Initial upload: Local LLM Prompt Manager CLI tool
This commit is contained in:
79
src/llm/lmstudio.py
Normal file
79
src/llm/lmstudio.py
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
"""LM Studio LLM client implementation."""
|
||||||
|
|
||||||
|
import json
from collections.abc import Iterator

import requests

from ..config import get_config
from .base import LLMClient
|
||||||
|
|
||||||
|
|
||||||
|
class LMStudioClient(LLMClient):
    """Client for the LM Studio local inference server.

    Speaks LM Studio's OpenAI-compatible HTTP API using chat-style
    ``messages`` payloads against the ``/v1/chat/completions`` endpoint.
    """

    def __init__(self, url: str = None):
        """Initialize the client.

        Args:
            url: Base URL of the LM Studio server. When ``None``, falls
                back to the configured ``lmstudio_url``.
        """
        config = get_config()
        self.url = url or config.lmstudio_url

    def _build_payload(self, prompt: str, model: str = None, **kwargs) -> dict:
        """Build a chat-completions request body for *prompt*.

        ``model`` is included only when given, so LM Studio falls back to
        its currently loaded model otherwise. ``kwargs`` override the
        defaults (e.g. ``max_tokens``, ``temperature``).
        """
        payload = {
            "messages": [{"role": "user", "content": prompt}],
            "max_tokens": 1024,
            "temperature": 0.7,
        }
        # Fix: the original accepted `model` but never sent it to the server.
        if model is not None:
            payload["model"] = model
        payload.update(kwargs)
        return payload

    def generate(self, prompt: str, model: str = None, **kwargs) -> str:
        """Generate a response using the LM Studio API.

        Args:
            prompt: User prompt, sent as a single chat message.
            model: Optional model id; omitted from the request when ``None``.
            **kwargs: Extra payload fields overriding the defaults.

        Returns:
            The generated text, or ``""`` if the response carried no content.

        Raises:
            requests.HTTPError: On a non-2xx response.
        """
        response = requests.post(
            # Chat-style `messages` payloads belong on the chat endpoint;
            # the original posted them to /v1/completions, which expects
            # a `prompt` field instead.
            f"{self.url}/v1/chat/completions",
            json=self._build_payload(prompt, model, **kwargs),
            timeout=120,
        )
        response.raise_for_status()
        data = response.json()
        # Chat responses carry the text under choices[0].message.content.
        return data.get("choices", [{}])[0].get("message", {}).get("content", "")

    def stream_generate(self, prompt: str, model: str = None, **kwargs) -> Iterator[str]:
        """Stream a response from the LM Studio API.

        Yields non-empty text chunks as they arrive over the SSE stream.

        Raises:
            requests.HTTPError: On a non-2xx response.
        """
        payload = self._build_payload(prompt, model, **kwargs)
        payload["stream"] = True
        response = requests.post(
            f"{self.url}/v1/chat/completions",
            json=payload,
            stream=True,
            timeout=120,
        )
        response.raise_for_status()
        for raw_line in response.iter_lines():
            # SSE frames: skip keep-alive blanks and non-data lines.
            if not raw_line:
                continue
            line = raw_line.decode("utf-8")
            if not line.startswith("data: "):
                continue
            data = line[len("data: "):]
            if data == "[DONE]":
                break  # end-of-stream sentinel
            parsed = json.loads(data)
            # Streaming chat chunks carry text under choices[0].delta.content.
            chunk = parsed.get("choices", [{}])[0].get("delta", {}).get("content", "")
            if chunk:
                yield chunk

    def test_connection(self) -> bool:
        """Return True if the LM Studio server answers on ``/v1/models``."""
        try:
            response = requests.get(f"{self.url}/v1/models", timeout=5)
            return response.status_code == 200
        except requests.exceptions.RequestException:
            # Best-effort probe: any transport failure means "not available".
            return False

    def get_available_models(self) -> list[str]:
        """Return the model ids LM Studio reports, or ``[]`` on any failure."""
        try:
            response = requests.get(f"{self.url}/v1/models", timeout=5)
            if response.status_code == 200:
                data = response.json()
                return [m.get("id", "") for m in data.get("data", [])]
        except requests.exceptions.RequestException:
            # Deliberate best-effort: an unreachable server yields no models.
            pass
        return []
|
||||||
Reference in New Issue
Block a user