"""Embedding model management using sentence-transformers."""

import logging
from pathlib import Path
from typing import List, Optional

from sentence_transformers import SentenceTransformer

logger = logging.getLogger(__name__)


class EmbeddingManager:
    """Manages a local sentence-transformers embedding model for semantic search.

    The underlying model is loaded lazily on first use and cached on the
    instance; call :meth:`unload_model` to release it.
    """

    # Small, fast general-purpose model; good default for local search.
    DEFAULT_MODEL = "all-MiniLM-L6-v2"

    def __init__(
        self,
        model_name: Optional[str] = None,
        device: Optional[str] = None,
        cache_dir: Optional[Path] = None,
    ):
        """Initialize the embedding manager.

        Args:
            model_name: Name of the model to use (default: all-MiniLM-L6-v2).
            device: Device to run on (e.g. "cpu", "cuda"); defaults to "cpu".
            cache_dir: Directory to cache downloaded models, or None for the
                sentence-transformers default cache location.
        """
        self._model_name = model_name or self.DEFAULT_MODEL
        self._device = device or "cpu"
        self._cache_dir = cache_dir
        # Lazily populated by load_model(); None until first load.
        self._model: Optional[SentenceTransformer] = None

    @property
    def model_name(self) -> str:
        """Get the model name."""
        return self._model_name

    @property
    def device(self) -> str:
        """Get the device being used."""
        return self._device

    def load_model(self, force_download: bool = False) -> SentenceTransformer:
        """Load (or return the cached) embedding model.

        Args:
            force_download: Rebuild the in-process model object even if one is
                already loaded. NOTE: this does not purge the on-disk model
                cache, so a re-download only happens if the cached files are
                missing — TODO confirm whether callers expect a true re-download.

        Returns:
            Loaded SentenceTransformer model.

        Raises:
            Exception: Propagates any error from SentenceTransformer loading
                (e.g. unknown model name, network failure on first download).
        """
        if self._model is not None and not force_download:
            return self._model

        try:
            model_kwargs = {"device": self._device}
            if self._cache_dir:
                model_kwargs["cache_folder"] = str(self._cache_dir)

            # Build into a local first so a failed reload leaves any
            # previously loaded model on self._model intact.
            model = SentenceTransformer(self._model_name, **model_kwargs)
            self._model = model
            # Lazy %-formatting: arguments are only rendered if INFO is enabled.
            logger.info(
                "Loaded embedding model: %s on %s", self._model_name, self._device
            )
            return model

        except Exception:
            # logger.exception records the full traceback, unlike logger.error.
            logger.exception("Failed to load model %s", self._model_name)
            raise

    def embed(self, texts: List[str], show_progress: bool = False) -> List[List[float]]:
        """Generate embeddings for a list of texts.

        Args:
            texts: List of text strings to embed.
            show_progress: Show a progress bar during encoding.

        Returns:
            List of embedding vectors (one list of floats per input text);
            empty list for empty input.
        """
        if not texts:
            return []

        model = self.load_model()
        embeddings = model.encode(
            texts,
            show_progress_bar=show_progress,
            convert_to_numpy=True,
        )
        # Convert the numpy array to plain Python lists for JSON-friendly output.
        return embeddings.tolist()

    def embed_query(self, query: str) -> List[float]:
        """Generate an embedding for a single query string.

        Args:
            query: Query string.

        Returns:
            Embedding vector for the query.
        """
        return self.embed([query])[0]

    def get_embedding_dim(self) -> int:
        """Get the embedding dimension of the loaded model.

        Returns:
            Dimension of the embedding vectors.

        NOTE(review): sentence-transformers may return None here for models
        without a declared dimension — verify against the models actually used.
        """
        model = self.load_model()
        return model.get_sentence_embedding_dimension()

    def unload_model(self) -> None:
        """Unload the model to free memory; it will be reloaded on next use."""
        self._model = None
        logger.info("Unloaded embedding model")

    def __repr__(self) -> str:
        return f"EmbeddingManager(model={self._model_name}, device={self._device})"