Add search modules (embeddings, vectorstore, searcher)
This commit is contained in:
117
src/search/embeddings.py
Normal file
117
src/search/embeddings.py
Normal file
@@ -0,0 +1,117 @@
|
|||||||
|
"""Embedding model management using sentence-transformers."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingManager:
|
||||||
|
"""Manages local embedding models for semantic search."""
|
||||||
|
|
||||||
|
DEFAULT_MODEL = "all-MiniLM-L6-v2"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_name: Optional[str] = None,
|
||||||
|
device: Optional[str] = None,
|
||||||
|
cache_dir: Optional[Path] = None,
|
||||||
|
):
|
||||||
|
"""Initialize the embedding manager.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name: Name of the model to use (default: all-MiniLM-L6-v2)
|
||||||
|
device: Device to run on (cpu, cuda, auto)
|
||||||
|
cache_dir: Directory to cache models
|
||||||
|
"""
|
||||||
|
self._model_name = model_name or self.DEFAULT_MODEL
|
||||||
|
self._device = device or "cpu"
|
||||||
|
self._cache_dir = cache_dir
|
||||||
|
self._model: Optional[SentenceTransformer] = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def model_name(self) -> str:
|
||||||
|
"""Get the model name."""
|
||||||
|
return self._model_name
|
||||||
|
|
||||||
|
@property
|
||||||
|
def device(self) -> str:
|
||||||
|
"""Get the device being used."""
|
||||||
|
return self._device
|
||||||
|
|
||||||
|
def load_model(self, force_download: bool = False) -> SentenceTransformer:
|
||||||
|
"""Load the embedding model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
force_download: Force re-download of the model
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Loaded SentenceTransformer model
|
||||||
|
"""
|
||||||
|
if self._model is not None and not force_download:
|
||||||
|
return self._model
|
||||||
|
|
||||||
|
try:
|
||||||
|
model_kwargs = {"device": self._device}
|
||||||
|
if self._cache_dir:
|
||||||
|
model_kwargs["cache_folder"] = str(self._cache_dir)
|
||||||
|
|
||||||
|
self._model = SentenceTransformer(self._model_name, **model_kwargs)
|
||||||
|
logger.info(f"Loaded embedding model: {self._model_name} on {self._device}")
|
||||||
|
return self._model
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to load model {self._model_name}: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def embed(self, texts: List[str], show_progress: bool = False) -> List[List[float]]:
|
||||||
|
"""Generate embeddings for a list of texts.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
texts: List of text strings to embed
|
||||||
|
show_progress: Show progress bar
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of embedding vectors
|
||||||
|
"""
|
||||||
|
if not texts:
|
||||||
|
return []
|
||||||
|
|
||||||
|
model = self.load_model()
|
||||||
|
embeddings = model.encode(
|
||||||
|
texts,
|
||||||
|
show_progress_bar=show_progress,
|
||||||
|
convert_to_numpy=True,
|
||||||
|
)
|
||||||
|
return embeddings.tolist()
|
||||||
|
|
||||||
|
def embed_query(self, query: str) -> List[float]:
|
||||||
|
"""Generate embedding for a single query.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: Query string
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Embedding vector
|
||||||
|
"""
|
||||||
|
return self.embed([query])[0]
|
||||||
|
|
||||||
|
def get_embedding_dim(self) -> int:
|
||||||
|
"""Get the embedding dimension.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dimension of the embedding vectors
|
||||||
|
"""
|
||||||
|
model = self.load_model()
|
||||||
|
return model.get_sentence_embedding_dimension()
|
||||||
|
|
||||||
|
def unload_model(self) -> None:
|
||||||
|
"""Unload the model to free memory."""
|
||||||
|
self._model = None
|
||||||
|
logger.info("Unloaded embedding model")
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"EmbeddingManager(model={self._model_name}, device={self._device})"
|
||||||
Reference in New Issue
Block a user