fix: resolve CI/CD issues with proper package structure and imports
This commit is contained in:
117
src/local_api_docs_search/search/embeddings.py
Normal file
117
src/local_api_docs_search/search/embeddings.py
Normal file
@@ -0,0 +1,117 @@
|
||||
"""Embedding model management using sentence-transformers."""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EmbeddingManager:
    """Manages local embedding models for semantic search.

    Wraps a lazily-loaded :class:`SentenceTransformer` so that construction
    is cheap and the (potentially slow) model initialization only happens on
    first use.
    """

    # Small general-purpose model; the default for local semantic search.
    DEFAULT_MODEL = "all-MiniLM-L6-v2"

    def __init__(
        self,
        model_name: Optional[str] = None,
        device: Optional[str] = None,
        cache_dir: Optional[Path] = None,
    ):
        """Initialize the embedding manager.

        Args:
            model_name: Name of the model to use (default: all-MiniLM-L6-v2)
            device: Device to run on (cpu, cuda, auto); defaults to "cpu"
            cache_dir: Directory to cache downloaded models
        """
        self._model_name = model_name or self.DEFAULT_MODEL
        self._device = device or "cpu"
        self._cache_dir = cache_dir
        # Populated lazily by load_model(); None means "not loaded yet".
        self._model: Optional[SentenceTransformer] = None

    @property
    def model_name(self) -> str:
        """Get the model name."""
        return self._model_name

    @property
    def device(self) -> str:
        """Get the device being used."""
        return self._device

    def load_model(self, force_download: bool = False) -> SentenceTransformer:
        """Load the embedding model, reusing a previously loaded instance.

        Args:
            force_download: Re-instantiate the model even if one is already
                loaded. NOTE(review): this reloads from the local cache; it
                does not purge the cache or force a network re-download.

        Returns:
            Loaded SentenceTransformer model

        Raises:
            Exception: Propagates whatever the model constructor raises
                (download/initialization failure), after logging it.
        """
        if self._model is not None and not force_download:
            return self._model

        # Build kwargs outside the try so only the raising call is guarded.
        model_kwargs = {"device": self._device}
        if self._cache_dir:
            model_kwargs["cache_folder"] = str(self._cache_dir)

        try:
            self._model = SentenceTransformer(self._model_name, **model_kwargs)
        except Exception:
            # logger.exception records the traceback; lazy %-args avoid
            # formatting work when the record is filtered out.
            logger.exception("Failed to load model %s", self._model_name)
            raise
        logger.info(
            "Loaded embedding model: %s on %s", self._model_name, self._device
        )
        return self._model

    def embed(self, texts: List[str], show_progress: bool = False) -> List[List[float]]:
        """Generate embeddings for a list of texts.

        Args:
            texts: List of text strings to embed
            show_progress: Show a progress bar during encoding

        Returns:
            List of embedding vectors, one per input text
            (empty list for empty input, without loading the model)
        """
        if not texts:
            return []

        model = self.load_model()
        embeddings = model.encode(
            texts,
            show_progress_bar=show_progress,
            convert_to_numpy=True,
        )
        # encode() returns a numpy array here; convert to plain lists.
        return embeddings.tolist()

    def embed_query(self, query: str) -> List[float]:
        """Generate embedding for a single query.

        Args:
            query: Query string

        Returns:
            Embedding vector
        """
        return self.embed([query])[0]

    def get_embedding_dim(self) -> int:
        """Get the embedding dimension.

        Returns:
            Dimension of the vectors produced by the loaded model
        """
        model = self.load_model()
        return model.get_sentence_embedding_dimension()

    def unload_model(self) -> None:
        """Unload the model to free memory; it is lazily reloaded on next use."""
        self._model = None
        logger.info("Unloaded embedding model")

    def __repr__(self) -> str:
        return f"EmbeddingManager(model={self._model_name}, device={self._device})"
|
||||
Reference in New Issue
Block a user