fix: resolve CI/CD issues with proper package structure and imports
Some checks failed
CI / test (3.10) (push) Has been cancelled
CI / test (3.11) (push) Has been cancelled
CI / test (3.12) (push) Has been cancelled
CI / build (push) Has been cancelled

This commit is contained in:
2026-02-03 03:54:44 +00:00
parent 4ea77b830b
commit d27d8fffa9

View File

@@ -0,0 +1,117 @@
"""Embedding model management using sentence-transformers."""
import logging
from pathlib import Path
from typing import List, Optional
from sentence_transformers import SentenceTransformer
logger = logging.getLogger(__name__)
class EmbeddingManager:
    """Manages local embedding models for semantic search.

    Wraps a lazily-loaded ``SentenceTransformer``: the model is only
    instantiated on first use (``load_model``/``embed``) and can be
    released again with ``unload_model`` to free memory.
    """

    # Small, fast general-purpose model; reasonable default for semantic search.
    DEFAULT_MODEL = "all-MiniLM-L6-v2"

    def __init__(
        self,
        model_name: Optional[str] = None,
        device: Optional[str] = None,
        cache_dir: Optional[Path] = None,
    ):
        """Initialize the embedding manager.

        Args:
            model_name: Name of the model to use (default: all-MiniLM-L6-v2).
            device: Device to run on (cpu, cuda, auto); defaults to "cpu".
            cache_dir: Directory to cache downloaded models; library default
                location is used when omitted.
        """
        self._model_name = model_name or self.DEFAULT_MODEL
        self._device = device or "cpu"
        self._cache_dir = cache_dir
        # Loaded lazily; stays None until load_model() is first called.
        self._model: Optional[SentenceTransformer] = None

    @property
    def model_name(self) -> str:
        """Name of the configured embedding model."""
        return self._model_name

    @property
    def device(self) -> str:
        """Device the model runs (or will run) on."""
        return self._device

    def load_model(self, force_download: bool = False) -> SentenceTransformer:
        """Load the embedding model, reusing a previously loaded instance.

        Args:
            force_download: Re-instantiate the model even if one is already
                loaded. NOTE(review): this reloads the in-memory model object
                but does not bypass the on-disk HuggingFace cache — confirm
                whether a true re-download is required by callers.

        Returns:
            Loaded SentenceTransformer model.

        Raises:
            Exception: Any model-loading failure is logged (with traceback)
                and re-raised unchanged.
        """
        if self._model is not None and not force_download:
            return self._model
        try:
            model_kwargs = {"device": self._device}
            if self._cache_dir:
                model_kwargs["cache_folder"] = str(self._cache_dir)
            self._model = SentenceTransformer(self._model_name, **model_kwargs)
            # Lazy %-args: message is only rendered if INFO is enabled.
            logger.info(
                "Loaded embedding model: %s on %s", self._model_name, self._device
            )
            return self._model
        except Exception:
            # logger.exception records the traceback, unlike logger.error(f"..").
            logger.exception("Failed to load model %s", self._model_name)
            raise

    def embed(self, texts: List[str], show_progress: bool = False) -> List[List[float]]:
        """Generate embeddings for a list of texts.

        Triggers a lazy model load on first call.

        Args:
            texts: List of text strings to embed.
            show_progress: Show a progress bar while encoding.

        Returns:
            One embedding vector (list of floats) per input text; an empty
            list for empty input (without loading the model).
        """
        if not texts:
            return []
        model = self.load_model()
        embeddings = model.encode(
            texts,
            show_progress_bar=show_progress,
            convert_to_numpy=True,
        )
        # encode() returned a numpy array (convert_to_numpy=True); convert to
        # plain nested lists for a dependency-free return type.
        return embeddings.tolist()

    def embed_query(self, query: str) -> List[float]:
        """Generate the embedding for a single query string.

        Args:
            query: Query string.

        Returns:
            Embedding vector for the query.
        """
        return self.embed([query])[0]

    def get_embedding_dim(self) -> int:
        """Get the embedding dimension, loading the model if needed.

        Returns:
            Dimension of the embedding vectors produced by this model.
        """
        model = self.load_model()
        return model.get_sentence_embedding_dimension()

    def unload_model(self) -> None:
        """Unload the model to free memory; a later call reloads it lazily."""
        self._model = None
        logger.info("Unloaded embedding model")

    def __repr__(self) -> str:
        return f"EmbeddingManager(model={self._model_name}, device={self._device})"