Add search modules (embeddings, vectorstore, searcher)
This commit is contained in:
305
src/search/vectorstore.py
Normal file
305
src/search/vectorstore.py
Normal file
@@ -0,0 +1,305 @@
|
|||||||
|
"""Vector storage operations using ChromaDB."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
import chromadb
|
||||||
|
from chromadb.config import Settings
|
||||||
|
|
||||||
|
from src.models.document import Document, IndexStats, SourceType
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class VectorStore:
    """ChromaDB-based vector storage for document embeddings.

    The ChromaDB client and the collection handle are created lazily on first
    use and cached on the instance. Most operations follow a log-and-continue
    policy: storage errors are logged and reported through the return value
    (``0``, ``[]`` or ``False``) instead of being raised.
    """

    COLLECTION_NAME = "api_docs"

    # Metadata keys that add_documents() derives from Document fields. They
    # are stripped again when Documents are rebuilt in get_all_documents().
    _RESERVED_METADATA_KEYS = ("source_type", "title", "file_path")

    def __init__(
        self,
        persist_dir: Path,
        collection_name: Optional[str] = None,
    ):
        """Initialize the vector store.

        The persistence directory is created if it does not exist. No
        ChromaDB client is opened until the store is first used.

        Args:
            persist_dir: Directory for persistence
            collection_name: Name of the collection (default: api_docs)
        """
        self._persist_dir = Path(persist_dir)
        self._persist_dir.mkdir(parents=True, exist_ok=True)
        self._collection_name = collection_name or self.COLLECTION_NAME
        self._client: Optional[chromadb.Client] = None
        self._collection: Optional[chromadb.Collection] = None

    def _get_client(self) -> chromadb.Client:
        """Get or create the ChromaDB client.

        Returns:
            The cached client, created on first call.
        """
        if self._client is None:
            persistent_client = getattr(chromadb, "PersistentClient", None)
            if persistent_client is not None:
                # chromadb >= 0.4: a plain Client(Settings(persist_directory=...))
                # is not persistent unless is_persistent is also set, so use
                # the dedicated on-disk client.
                self._client = persistent_client(
                    path=str(self._persist_dir),
                    settings=Settings(anonymized_telemetry=False),
                )
            else:
                # Older chromadb without PersistentClient: keep the original
                # construction unchanged.
                self._client = chromadb.Client(
                    Settings(
                        persist_directory=str(self._persist_dir),
                        anonymized_telemetry=False,
                    )
                )
        return self._client

    def _get_collection(self) -> chromadb.Collection:
        """Get or create the collection.

        Returns:
            The cached collection handle, looked up (or created) on first call.
        """
        if self._collection is None:
            client = self._get_client()
            try:
                self._collection = client.get_collection(self._collection_name)
            except Exception:
                # The "collection does not exist" exception type differs
                # between chromadb releases, so it cannot be caught by class.
                # get_or_create_collection re-raises any genuine backend
                # failure and tolerates a concurrent creator.
                self._collection = client.get_or_create_collection(
                    self._collection_name
                )
                logger.info("Created new collection: %s", self._collection_name)
        return self._collection

    def add_documents(
        self,
        documents: List[Document],
        embeddings: List[List[float]],
        batch_size: int = 100,
    ) -> int:
        """Add documents and their embeddings to the store.

        Documents are written in batches. A failing batch is logged and
        skipped; the remaining batches are still attempted.

        Args:
            documents: List of Document objects
            embeddings: List of embedding vectors, one per document and in
                the same order
            batch_size: Documents per batch

        Returns:
            Number of documents added
        """
        if not documents:
            return 0

        if len(documents) != len(embeddings):
            # Pairing is positional, so a mismatch would otherwise only show
            # up as a late batch failure after earlier batches were written.
            logger.error(
                "Cannot add documents: got %d documents but %d embeddings",
                len(documents),
                len(embeddings),
            )
            return 0

        collection = self._get_collection()

        total_added = 0
        for start in range(0, len(documents), batch_size):
            batch_docs = documents[start : start + batch_size]
            batch_embeddings = embeddings[start : start + batch_size]

            # Reserved keys are written last so that free-form document
            # metadata can never overwrite them.
            metadatas = [
                {
                    **(doc.metadata or {}),
                    "source_type": doc.source_type.value,
                    "title": doc.title,
                    "file_path": doc.file_path,
                }
                for doc in batch_docs
            ]

            try:
                collection.add(
                    ids=[doc.id for doc in batch_docs],
                    documents=[doc.content for doc in batch_docs],
                    embeddings=batch_embeddings,
                    metadatas=metadatas,
                )
            except Exception as e:
                logger.error("Failed to add batch: %s", e)
            else:
                total_added += len(batch_docs)
                logger.debug("Added batch of %d documents", len(batch_docs))

        logger.info("Added %d documents to collection", total_added)
        return total_added

    def search(
        self,
        query_embedding: List[float],
        n_results: int = 10,
        source_type: Optional[SourceType] = None,
    ) -> List[Dict[str, Any]]:
        """Search for similar documents.

        Args:
            query_embedding: Query embedding vector
            n_results: Number of results to return
            source_type: Optional filter by source type

        Returns:
            List of result dicts with the keys ``id``, ``content``,
            ``metadata``, ``distance`` and ``score``. Empty if the query
            fails or nothing matches.
        """
        collection = self._get_collection()

        # Compare against None explicitly rather than relying on the
        # truthiness of an enum member.
        where_filter = None
        if source_type is not None:
            where_filter = {"source_type": source_type.value}

        try:
            results = collection.query(
                query_embeddings=[query_embedding],
                n_results=n_results,
                where=where_filter,
                include=["documents", "metadatas", "distances"],
            )
        except Exception as e:
            logger.error("Search failed: %s", e)
            return []

        # Results are nested one level per query embedding; only one was sent.
        if not results["ids"] or not results["ids"][0]:
            return []

        # NOTE(review): score = 1 - distance is only a bounded similarity if
        # the collection's distance metric is itself bounded (e.g. cosine).
        # The collection is created with Chroma's default metric here, so
        # confirm that callers interpret `score` accordingly.
        return [
            {
                "id": doc_id,
                "content": content,
                "metadata": metadata,
                "distance": distance,
                "score": 1.0 - distance,
            }
            for doc_id, content, metadata, distance in zip(
                results["ids"][0],
                results["documents"][0],
                results["metadatas"][0],
                results["distances"][0],
            )
        ]

    def delete_index(self) -> bool:
        """Delete the entire index.

        Returns:
            True if successful, False if the deletion failed (the error is
            logged).
        """
        try:
            self._get_client().delete_collection(self._collection_name)
        except Exception as e:
            logger.error("Failed to delete collection: %s", e)
            return False

        # Drop the cached handle so the next access re-creates the collection.
        self._collection = None
        logger.info("Deleted collection: %s", self._collection_name)
        return True

    def get_stats(self) -> IndexStats:
        """Get statistics about the index.

        The per-source counts are tallied from the stored metadata. If that
        lookup fails, a warning is logged and the counts gathered so far
        (possibly all zero) are reported alongside the total.

        Returns:
            IndexStats object
        """
        collection = self._get_collection()

        total = collection.count()

        source_counts = {member.value: 0 for member in SourceType}

        try:
            records = collection.get(include=["metadatas"])
            # Both the list and individual entries may be None for records
            # stored without metadata.
            for metadata in records.get("metadatas") or []:
                if not metadata:
                    continue
                source = metadata.get("source_type")
                if source in source_counts:
                    source_counts[source] += 1
        except Exception as e:
            logger.warning("Failed to get source counts: %s", e)

        return IndexStats(
            total_documents=total,
            openapi_count=source_counts[SourceType.OPENAPI.value],
            readme_count=source_counts[SourceType.README.value],
            code_count=source_counts[SourceType.CODE.value],
        )

    def get_all_documents(
        self, limit: int = 1000, offset: int = 0
    ) -> List[Document]:
        """Get all documents from the store.

        Records whose metadata cannot be turned back into a Document (no
        metadata, a missing reserved key, or an unknown source type) are
        skipped with a warning so one bad record does not lose the page.

        Args:
            limit: Maximum number of documents
            offset: Offset for pagination

        Returns:
            List of Document objects
        """
        collection = self._get_collection()

        try:
            results = collection.get(
                limit=limit, offset=offset, include=["documents", "metadatas"]
            )
        except Exception as e:
            logger.error("Failed to get documents: %s", e)
            return []

        documents = []
        for doc_id, content, metadata in zip(
            results["ids"], results["documents"], results["metadatas"]
        ):
            if not metadata:
                logger.warning("Skipping document %s: no metadata stored", doc_id)
                continue
            try:
                doc = Document(
                    id=doc_id,
                    content=content,
                    source_type=SourceType(metadata["source_type"]),
                    title=metadata["title"],
                    file_path=metadata["file_path"],
                    metadata={
                        key: value
                        for key, value in metadata.items()
                        if key not in self._RESERVED_METADATA_KEYS
                    },
                )
            except (KeyError, ValueError) as e:
                logger.warning("Skipping document %s: invalid metadata (%s)", doc_id, e)
                continue
            documents.append(doc)

        return documents

    def delete_by_ids(self, ids: List[str]) -> int:
        """Delete documents by IDs.

        Args:
            ids: List of document IDs to delete

        Returns:
            Number of IDs submitted for deletion, or 0 if ``ids`` is empty or
            the deletion failed (the error is logged).
        """
        if not ids:
            return 0

        collection = self._get_collection()

        try:
            collection.delete(ids=ids)
        except Exception as e:
            logger.error("Failed to delete documents: %s", e)
            return 0

        logger.info("Deleted %d documents", len(ids))
        return len(ids)

    def delete_by_source_type(self, source_type: SourceType) -> int:
        """Delete all documents of a given source type.

        Args:
            source_type: Source type to delete

        Returns:
            Number of documents deleted (0 if none matched or on failure)
        """
        collection = self._get_collection()

        try:
            matches = collection.get(where={"source_type": source_type.value})
            if matches["ids"]:
                return self.delete_by_ids(matches["ids"])
        except Exception as e:
            logger.error("Failed to delete by source type: %s", e)

        return 0

    def exists(self) -> bool:
        """Check if the collection exists.

        Returns:
            True if the collection could be looked up, False otherwise
        """
        try:
            self._get_client().get_collection(self._collection_name)
        except Exception:
            # The missing-collection exception type differs between chromadb
            # releases, so any lookup failure is reported as "does not exist".
            return False
        return True

    def count(self) -> int:
        """Get the document count.

        Returns:
            Number of documents in the store
        """
        return self._get_collection().count()

    def close(self) -> None:
        """Drop the cached client and collection references.

        A new client is created lazily if the store is used again.
        """
        self._client = None
        self._collection = None
|
||||||
Reference in New Issue
Block a user