Add search modules (embeddings, vectorstore, searcher)
This commit is contained in:
305
src/search/vectorstore.py
Normal file
305
src/search/vectorstore.py
Normal file
@@ -0,0 +1,305 @@
|
||||
"""Vector storage operations using ChromaDB."""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import chromadb
|
||||
from chromadb.config import Settings
|
||||
|
||||
from src.models.document import Document, IndexStats, SourceType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class VectorStore:
    """ChromaDB-based vector storage for document embeddings.

    Wraps a persistent ChromaDB collection and exposes CRUD + similarity
    search over :class:`Document` objects. The client and collection are
    created lazily on first use and cached on the instance.
    """

    COLLECTION_NAME = "api_docs"

    # Metadata keys managed by this store itself; everything else in a
    # record's metadata dict is treated as user-supplied Document.metadata.
    _CORE_METADATA_KEYS = frozenset({"source_type", "title", "file_path"})

    def __init__(
        self,
        persist_dir: Path,
        collection_name: Optional[str] = None,
    ):
        """Initialize the vector store.

        Args:
            persist_dir: Directory for persistence
            collection_name: Name of the collection (default: api_docs)
        """
        self._persist_dir = Path(persist_dir)
        self._persist_dir.mkdir(parents=True, exist_ok=True)
        self._collection_name = collection_name or self.COLLECTION_NAME
        # Lazily-created handles; reset to None by close()/delete_index().
        self._client: Optional[chromadb.Client] = None
        self._collection: Optional[chromadb.Collection] = None

    def _get_client(self) -> chromadb.Client:
        """Get or create the ChromaDB client (cached after first call)."""
        if self._client is None:
            # NOTE(review): Settings(persist_directory=...) is the legacy
            # configuration style; chromadb >= 0.4 persists via
            # chromadb.PersistentClient(path=...) instead — confirm the
            # pinned chromadb version actually persists with this form.
            self._client = chromadb.Client(
                Settings(
                    persist_directory=str(self._persist_dir),
                    anonymized_telemetry=False,
                )
            )
        return self._client

    def _get_collection(self) -> chromadb.Collection:
        """Get the collection, creating it on first access if needed."""
        if self._collection is None:
            client = self._get_client()
            # get_or_create_collection is atomic and works across chromadb
            # versions. The previous get_collection + `except ValueError`
            # fallback breaks on newer releases, which raise
            # chromadb.errors.NotFoundError for a missing collection.
            self._collection = client.get_or_create_collection(
                self._collection_name
            )
            logger.debug("Using collection: %s", self._collection_name)
        return self._collection

    def add_documents(
        self,
        documents: List[Document],
        embeddings: List[List[float]],
        batch_size: int = 100,
    ) -> int:
        """Add documents and their embeddings to the store.

        Batches are best-effort: a failed batch is logged and skipped so
        the remaining batches still get indexed.

        Args:
            documents: List of Document objects
            embeddings: List of embedding vectors
            batch_size: Documents per batch

        Returns:
            Number of documents added
        """
        if not documents:
            return 0

        collection = self._get_collection()

        total_added = 0
        for start in range(0, len(documents), batch_size):
            batch_docs = documents[start : start + batch_size]
            batch_embeddings = embeddings[start : start + batch_size]

            ids = [doc.id for doc in batch_docs]
            contents = [doc.content for doc in batch_docs]
            # Flatten the store-managed fields alongside any per-document
            # metadata. NOTE(review): Chroma rejects None metadata values —
            # confirm doc.file_path / doc.title are always non-None upstream.
            metadatas = [
                {
                    "source_type": doc.source_type.value,
                    "title": doc.title,
                    "file_path": doc.file_path,
                    **doc.metadata,
                }
                for doc in batch_docs
            ]

            try:
                collection.add(
                    ids=ids,
                    documents=contents,
                    embeddings=batch_embeddings,
                    metadatas=metadatas,
                )
                total_added += len(batch_docs)
                logger.debug("Added batch of %d documents", len(batch_docs))
            except Exception as e:
                logger.error("Failed to add batch: %s", e)

        logger.info("Added %d documents to collection", total_added)
        return total_added

    def search(
        self,
        query_embedding: List[float],
        n_results: int = 10,
        source_type: Optional[SourceType] = None,
    ) -> List[Dict[str, Any]]:
        """Search for similar documents.

        Args:
            query_embedding: Query embedding vector
            n_results: Number of results to return
            source_type: Optional filter by source type

        Returns:
            List of search results with documents and scores. Each result
            has keys: id, content, metadata, distance, score. Returns []
            on query failure.
        """
        collection = self._get_collection()

        where_filter = None
        if source_type:
            where_filter = {"source_type": source_type.value}

        try:
            results = collection.query(
                query_embeddings=[query_embedding],
                n_results=n_results,
                where=where_filter,
                include=["documents", "metadatas", "distances"],
            )
        except Exception as e:
            logger.error("Search failed: %s", e)
            return []

        search_results = []
        # Chroma returns per-query lists; we issued one query, so index [0].
        if results["ids"] and results["ids"][0]:
            for i in range(len(results["ids"][0])):
                distance = results["distances"][0][i]
                search_results.append(
                    {
                        "id": results["ids"][0][i],
                        "content": results["documents"][0][i],
                        "metadata": results["metadatas"][0][i],
                        "distance": distance,
                        # NOTE(review): 1 - distance is only a sensible score
                        # for a bounded metric (e.g. cosine); with L2 this can
                        # go negative — confirm the collection's metric.
                        "score": 1.0 - distance,
                    }
                )

        return search_results

    def delete_index(self) -> bool:
        """Delete the entire index.

        Returns:
            True if successful
        """
        try:
            client = self._get_client()
            client.delete_collection(self._collection_name)
            # Drop the cached handle so the next access recreates it.
            self._collection = None
            logger.info("Deleted collection: %s", self._collection_name)
            return True
        except Exception as e:
            logger.error("Failed to delete collection: %s", e)
            return False

    def get_stats(self) -> IndexStats:
        """Get statistics about the index.

        Returns:
            IndexStats object
        """
        collection = self._get_collection()

        total = collection.count()

        # Avoid shadowing the builtin `type` as the loop variable.
        source_counts = {source.value: 0 for source in SourceType}

        try:
            all_metadata = collection.get(include=["metadatas"])
            for metadata in all_metadata.get("metadatas", []):
                source_type = metadata.get("source_type")
                if source_type in source_counts:
                    source_counts[source_type] += 1
        except Exception as e:
            logger.warning("Failed to get source counts: %s", e)

        return IndexStats(
            total_documents=total,
            openapi_count=source_counts[SourceType.OPENAPI.value],
            readme_count=source_counts[SourceType.README.value],
            code_count=source_counts[SourceType.CODE.value],
        )

    def get_all_documents(
        self, limit: int = 1000, offset: int = 0
    ) -> List[Document]:
        """Get all documents from the store.

        Args:
            limit: Maximum number of documents
            offset: Offset for pagination

        Returns:
            List of Document objects ([] on retrieval failure)
        """
        collection = self._get_collection()

        try:
            results = collection.get(
                limit=limit, offset=offset, include=["documents", "metadatas"]
            )
        except Exception as e:
            logger.error("Failed to get documents: %s", e)
            return []

        documents = []
        for i in range(len(results["ids"])):
            metadata = results["metadatas"][i]
            documents.append(
                Document(
                    id=results["ids"][i],
                    content=results["documents"][i],
                    source_type=SourceType(metadata["source_type"]),
                    title=metadata["title"],
                    file_path=metadata["file_path"],
                    # Everything beyond the store-managed keys round-trips
                    # back into Document.metadata.
                    metadata={
                        k: v
                        for k, v in metadata.items()
                        if k not in self._CORE_METADATA_KEYS
                    },
                )
            )

        return documents

    def delete_by_ids(self, ids: List[str]) -> int:
        """Delete documents by IDs.

        Args:
            ids: List of document IDs to delete

        Returns:
            Number of documents deleted (0 on failure or empty input)
        """
        if not ids:
            return 0

        collection = self._get_collection()

        try:
            collection.delete(ids=ids)
            logger.info("Deleted %d documents", len(ids))
            return len(ids)
        except Exception as e:
            logger.error("Failed to delete documents: %s", e)
            return 0

    def delete_by_source_type(self, source_type: SourceType) -> int:
        """Delete all documents of a given source type.

        Args:
            source_type: Source type to delete

        Returns:
            Number of documents deleted
        """
        collection = self._get_collection()

        try:
            results = collection.get(where={"source_type": source_type.value})
            if results["ids"]:
                return self.delete_by_ids(results["ids"])
        except Exception as e:
            logger.error("Failed to delete by source type: %s", e)

        return 0

    def exists(self) -> bool:
        """Check if the collection exists.

        Returns:
            True if collection exists
        """
        try:
            client = self._get_client()
            client.get_collection(self._collection_name)
            return True
        except Exception:
            # Newer chromadb raises NotFoundError (not ValueError) for a
            # missing collection; catch broadly — any failure here means
            # the collection is not usable.
            return False

    def count(self) -> int:
        """Get the document count.

        Returns:
            Number of documents in the store
        """
        collection = self._get_collection()
        return collection.count()

    def close(self) -> None:
        """Drop cached client/collection handles (no explicit disconnect)."""
        self._client = None
        self._collection = None
|
||||
Reference in New Issue
Block a user