Initial upload: shell-history-semantic-search v0.1.0
Some checks failed
CI / test (push) Has been cancelled
Some checks failed
CI / test (push) Has been cancelled
This commit is contained in:
97
src/shell_history_search/core/embeddings.py
Normal file
97
src/shell_history_search/core/embeddings.py
Normal file
@@ -0,0 +1,97 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
logger = logging.getLogger(__name__)

# sentence-transformers model identifier used when the caller supplies none.
DEFAULT_MODEL_NAME = "all-MiniLM-L6-v2"
# Default on-disk location for downloaded model weights.
DEFAULT_CACHE_DIR = Path.home() / ".cache" / "shell_history_search" / "models"


def get_cache_dir() -> Path:
    """Return the model cache directory, honoring the env-var override.

    The ``SHELL_HISTORY_MODEL_CACHE`` environment variable, when set (even
    to an empty string), takes precedence over ``DEFAULT_CACHE_DIR``.
    """
    override = os.environ.get("SHELL_HISTORY_MODEL_CACHE")
    if override is None:
        return DEFAULT_CACHE_DIR
    return Path(override)
|
||||
|
||||
|
||||
class EmbeddingService:
    """Compute sentence embeddings with a lazily loaded SentenceTransformer.

    The underlying model is loaded on first use and cached on disk under
    ``cache_dir``.  Embeddings are returned as L2-normalized float32 arrays,
    so a plain dot product equals cosine similarity.
    """

    def __init__(
        self,
        model_name: str = DEFAULT_MODEL_NAME,
        cache_dir: Optional[Path] = None,
        device: Optional[str] = None,
    ):
        """
        Args:
            model_name: sentence-transformers model identifier.
            cache_dir: directory for downloaded model files; defaults to
                ``get_cache_dir()`` (env-overridable).  Created if missing.
            device: torch device string passed to SentenceTransformer;
                defaults to "cpu".
        """
        self.model_name = model_name
        self.cache_dir = cache_dir or get_cache_dir()
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.device = device or "cpu"
        # Model loading is deferred until first use (see _load_model), so
        # constructing the service is cheap and never touches the network.
        self._model: Optional[SentenceTransformer] = None
        self._embedding_dim: Optional[int] = None

    def _load_model(self) -> SentenceTransformer:
        """Load the model on first call; later calls return the cached instance."""
        if self._model is None:
            # Lazy %-style args: the message is only formatted if the
            # record is actually emitted.
            logger.info("Loading embedding model: %s", self.model_name)
            self._model = SentenceTransformer(
                self.model_name,
                cache_folder=str(self.cache_dir),
                device=self.device,
            )
            self._embedding_dim = self._model.get_sentence_embedding_dimension()
            logger.info("Model loaded. Embedding dimension: %s", self._embedding_dim)
        return self._model

    @property
    def model(self) -> SentenceTransformer:
        """The loaded SentenceTransformer (loads it on first access)."""
        return self._load_model()

    @property
    def embedding_dim(self) -> int:
        """Dimensionality of the embedding vectors; loads the model if needed."""
        if self._embedding_dim is None:
            self._load_model()
        assert self._embedding_dim is not None
        return self._embedding_dim

    def encode(self, texts: list[str], batch_size: int = 32) -> np.ndarray:
        """Embed a batch of texts.

        Args:
            texts: strings to embed.
            batch_size: encoder batch size.

        Returns:
            float32 array of shape (len(texts), embedding_dim) with
            L2-normalized rows.  For an empty input, an empty 1-D float32
            array is returned without loading the model (the dimension is
            unknown at that point).
        """
        if not texts:
            # dtype made consistent with the non-empty path (was float64).
            return np.empty((0,), dtype=np.float32)

        embeddings = self.model.encode(
            texts,
            batch_size=batch_size,
            show_progress_bar=False,
            convert_to_numpy=True,
            normalize_embeddings=True,
        )

        return embeddings.astype(np.float32)

    def encode_single(self, text: str) -> np.ndarray:
        """Embed one string and return its 1-D float32 vector."""
        return self.encode([text])[0]

    @staticmethod
    def embedding_to_blob(embedding: np.ndarray) -> bytes:
        """Serialize an embedding to raw native-endian float32 bytes."""
        return embedding.astype(np.float32).tobytes()

    @staticmethod
    def blob_to_embedding(blob: bytes, dim: int) -> np.ndarray:
        """Deserialize bytes produced by ``embedding_to_blob``.

        Returns:
            2-D float32 array of shape (n, dim).  Note the asymmetry with
            ``embedding_to_blob``: a single serialized vector comes back
            as shape (1, dim), not (dim,).
        """
        return np.frombuffer(blob, dtype=np.float32).reshape(-1, dim)

    @staticmethod
    def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
        """Cosine similarity between the first row of ``a`` and of ``b``.

        1-D inputs are treated as single vectors.  Inputs are re-normalized
        defensively (the 1e-8 epsilon guards against zero vectors), so
        unnormalized vectors are accepted.  NOTE: even when multi-row
        arrays are passed, only the [0, 0] entry of the similarity matrix
        is returned.
        """
        if a.ndim == 1:
            a = a.reshape(1, -1)
        if b.ndim == 1:
            b = b.reshape(1, -1)

        a_norm = np.linalg.norm(a, axis=1, keepdims=True)
        b_norm = np.linalg.norm(b, axis=1, keepdims=True)

        a_normalized = a / (a_norm + 1e-8)
        b_normalized = b / (b_norm + 1e-8)

        similarity = np.dot(a_normalized, b_normalized.T)

        return float(similarity[0, 0])
|
||||
Reference in New Issue
Block a user