Initial upload: shell-history-semantic-search v0.1.0
Some checks failed
CI / test (push) Has been cancelled
Some checks failed
CI / test (push) Has been cancelled
This commit is contained in:
97
src/shell_history_search/core/embeddings.py
Normal file
97
src/shell_history_search/core/embeddings.py
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
|
||||||
|
# Module-level logger; handlers/levels are configured by the application.
logger = logging.getLogger(__name__)


# Default SentenceTransformer model identifier used when callers do not
# specify one.
DEFAULT_MODEL_NAME = "all-MiniLM-L6-v2"
# Default on-disk location for downloaded model weights.  Can be overridden
# at runtime via the SHELL_HISTORY_MODEL_CACHE environment variable
# (see get_cache_dir()).
DEFAULT_CACHE_DIR = Path.home() / ".cache" / "shell_history_search" / "models"
|
def get_cache_dir() -> Path:
    """Return the directory where embedding models are cached.

    The SHELL_HISTORY_MODEL_CACHE environment variable takes precedence
    when set; otherwise DEFAULT_CACHE_DIR is used.
    """
    override = os.environ.get("SHELL_HISTORY_MODEL_CACHE")
    if override is not None:
        return Path(override)
    return Path(str(DEFAULT_CACHE_DIR))
|
class EmbeddingService:
    """Compute and (de)serialize sentence embeddings.

    The SentenceTransformer model is loaded lazily on first use so that
    constructing the service stays cheap (no download or model init).
    All embeddings produced by :meth:`encode` are L2-normalized float32.
    """

    def __init__(
        self,
        model_name: str = DEFAULT_MODEL_NAME,
        cache_dir: Optional[Path] = None,
        device: Optional[str] = None,
    ):
        """Create the service without loading the model.

        Args:
            model_name: SentenceTransformer model identifier.
            cache_dir: Directory for cached model weights; defaults to
                get_cache_dir() (env-var override or package default).
            device: Device string passed to SentenceTransformer;
                defaults to "cpu".
        """
        self.model_name = model_name
        self.cache_dir = cache_dir or get_cache_dir()
        # Create the cache directory eagerly so the first model download
        # has somewhere to land.
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.device = device or "cpu"
        # Both are populated by _load_model() on first access.
        self._model: Optional[SentenceTransformer] = None
        self._embedding_dim: Optional[int] = None

    def _load_model(self) -> SentenceTransformer:
        """Load and memoize the model; also records its embedding dimension."""
        if self._model is None:
            # Lazy %-style args avoid string formatting when INFO is disabled.
            logger.info("Loading embedding model: %s", self.model_name)
            self._model = SentenceTransformer(
                self.model_name,
                cache_folder=str(self.cache_dir),
                device=self.device,
            )
            self._embedding_dim = self._model.get_sentence_embedding_dimension()
            logger.info("Model loaded. Embedding dimension: %s", self._embedding_dim)
        return self._model

    @property
    def model(self) -> SentenceTransformer:
        """The underlying SentenceTransformer, loaded on first access."""
        return self._load_model()

    @property
    def embedding_dim(self) -> int:
        """Dimensionality of the embeddings (loads the model if needed)."""
        if self._embedding_dim is None:
            self._load_model()
        # _load_model() sets the dimension; this guard is for type-checkers.
        assert self._embedding_dim is not None
        return self._embedding_dim

    def encode(self, texts: list[str], batch_size: int = 32) -> np.ndarray:
        """Embed a batch of texts.

        Args:
            texts: Strings to embed.
            batch_size: Forwarded to SentenceTransformer.encode.

        Returns:
            A float32 array of L2-normalized embeddings.  For an empty
            input an empty float32 array is returned and the model is
            NOT loaded.
        """
        if not texts:
            # Keep the float32 contract even in the empty case (np.array([])
            # would default to float64), and avoid triggering a model load
            # just to learn the embedding dimension.
            return np.array([], dtype=np.float32)

        embeddings = self.model.encode(
            texts,
            batch_size=batch_size,
            show_progress_bar=False,
            convert_to_numpy=True,
            normalize_embeddings=True,
        )

        return embeddings.astype(np.float32)

    def encode_single(self, text: str) -> np.ndarray:
        """Embed a single text and return its 1-D embedding vector."""
        return self.encode([text])[0]

    @staticmethod
    def embedding_to_blob(embedding: np.ndarray) -> bytes:
        """Serialize an embedding to raw float32 bytes (native byte order)."""
        return embedding.astype(np.float32).tobytes()

    @staticmethod
    def blob_to_embedding(blob: bytes, dim: int) -> np.ndarray:
        """Deserialize float32 bytes into a 2-D ``(n, dim)`` array.

        NOTE(review): this always returns a 2-D array — round-tripping a
        single 1-D embedding through embedding_to_blob yields shape
        ``(1, dim)``, not the original 1-D shape; callers appear to rely
        on this — confirm before changing.
        """
        return np.frombuffer(blob, dtype=np.float32).reshape(-1, dim)

    @staticmethod
    def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
        """Cosine similarity between the first vectors of ``a`` and ``b``.

        1-D inputs are treated as single vectors; for 2-D inputs only the
        similarity between the first rows is returned.
        """
        if a.ndim == 1:
            a = a.reshape(1, -1)
        if b.ndim == 1:
            b = b.reshape(1, -1)

        a_norm = np.linalg.norm(a, axis=1, keepdims=True)
        b_norm = np.linalg.norm(b, axis=1, keepdims=True)

        # Small epsilon guards against division by zero for all-zero vectors.
        a_normalized = a / (a_norm + 1e-8)
        b_normalized = b / (b_norm + 1e-8)

        similarity = np.dot(a_normalized, b_normalized.T)

        return float(similarity[0, 0])
Reference in New Issue
Block a user