Files
local-api-docs-search/src/search/vectorstore.py
7000pctAUTO acea9424a4
Some checks failed
CI / test (push) Has been cancelled
CI / build (push) Has been cancelled
Add search modules (embeddings, vectorstore, searcher)
2026-02-03 01:23:02 +00:00

306 lines
9.3 KiB
Python

"""Vector storage operations using ChromaDB."""
import logging
from pathlib import Path
from typing import Any, Dict, List, Optional
import chromadb
from chromadb.config import Settings
from src.models.document import Document, IndexStats, SourceType
# Per-module logger; records are emitted under this module's dotted name.
logger: logging.Logger = logging.getLogger(__name__)
class VectorStore:
    """ChromaDB-based vector storage for document embeddings.

    Each instance manages one collection (``api_docs`` unless overridden)
    stored on disk under ``persist_dir``. The ChromaDB client and collection
    handles are created lazily on first use and cached on the instance until
    :meth:`close` is called.
    """

    COLLECTION_NAME = "api_docs"

    # Metadata keys that add_documents() writes for every record and that
    # get_all_documents() maps back onto Document fields.
    _RESERVED_METADATA_KEYS = ("source_type", "title", "file_path")

    def __init__(
        self,
        persist_dir: Path,
        collection_name: Optional[str] = None,
    ):
        """Initialize the vector store.

        No ChromaDB client is created here; that happens on first use.

        Args:
            persist_dir: Directory for persistence; created (with parents)
                if it does not already exist.
            collection_name: Name of the collection (default: api_docs).
        """
        self._persist_dir = Path(persist_dir)
        self._persist_dir.mkdir(parents=True, exist_ok=True)
        self._collection_name = collection_name or self.COLLECTION_NAME
        self._client: Optional[chromadb.Client] = None
        self._collection: Optional[chromadb.Collection] = None

    def _get_client(self) -> chromadb.Client:
        """Get or create the on-disk ChromaDB client.

        ``PersistentClient`` is used because ``chromadb.Client`` with only
        ``persist_directory`` set (and ``is_persistent`` left False) keeps
        data in memory, so nothing would be written under ``persist_dir``.

        Returns:
            The cached ChromaDB client for ``persist_dir``.
        """
        if self._client is None:
            self._client = chromadb.PersistentClient(
                path=str(self._persist_dir),
                settings=Settings(anonymized_telemetry=False),
            )
        return self._client

    def _get_collection(self) -> chromadb.Collection:
        """Get or create the collection, caching the handle.

        ``get_or_create_collection`` is used instead of catching the
        "collection not found" error from ``get_collection``, whose exception
        type differs between ChromaDB releases.

        Returns:
            The cached collection handle.
        """
        if self._collection is None:
            client = self._get_client()
            existed = self.exists()
            self._collection = client.get_or_create_collection(self._collection_name)
            if not existed:
                logger.info("Created new collection: %s", self._collection_name)
        return self._collection

    @staticmethod
    def _build_metadata(doc: Document) -> Dict[str, Any]:
        """Build the ChromaDB metadata dict stored alongside ``doc``.

        Reserved keys are written after ``doc.metadata`` so a user-supplied
        key such as ``"title"`` cannot overwrite the fields that
        get_all_documents() uses to rebuild the Document.
        """
        return {
            **doc.metadata,
            "source_type": doc.source_type.value,
            "title": doc.title,
            "file_path": doc.file_path,
        }

    def add_documents(
        self,
        documents: List[Document],
        embeddings: List[List[float]],
        batch_size: int = 100,
    ) -> int:
        """Add documents and their embeddings to the store.

        Documents are sent in batches. A batch that ChromaDB rejects is
        logged and skipped, so the remaining batches are still indexed.

        Args:
            documents: List of Document objects.
            embeddings: One embedding vector per document, in the same order.
            batch_size: Documents per batch; must be positive.

        Returns:
            Number of documents successfully added.

        Raises:
            ValueError: If ``embeddings`` and ``documents`` differ in length,
                or ``batch_size`` is not positive.
        """
        if not documents:
            return 0
        if len(embeddings) != len(documents):
            # A mismatch would silently pair documents with the wrong vectors.
            raise ValueError(
                f"Got {len(documents)} documents but {len(embeddings)} embeddings"
            )
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")

        collection = self._get_collection()
        total_added = 0
        for start in range(0, len(documents), batch_size):
            batch_docs = documents[start : start + batch_size]
            batch_embeddings = embeddings[start : start + batch_size]
            try:
                collection.add(
                    ids=[doc.id for doc in batch_docs],
                    documents=[doc.content for doc in batch_docs],
                    embeddings=batch_embeddings,
                    metadatas=[self._build_metadata(doc) for doc in batch_docs],
                )
            except Exception as e:
                # Best effort: skip this batch and keep indexing the rest.
                logger.error("Failed to add batch: %s", e)
                continue
            total_added += len(batch_docs)
            logger.debug("Added batch of %d documents", len(batch_docs))

        logger.info("Added %d documents to collection", total_added)
        return total_added

    def search(
        self,
        query_embedding: List[float],
        n_results: int = 10,
        source_type: Optional[SourceType] = None,
    ) -> List[Dict[str, Any]]:
        """Search for documents similar to ``query_embedding``.

        Args:
            query_embedding: Query embedding vector.
            n_results: Maximum number of results to return; values <= 0
                yield an empty list.
            source_type: Optional filter restricting hits to one source type.

        Returns:
            One dict per hit, in the order ChromaDB returns them, with keys
            ``id``, ``content``, ``metadata``, ``distance`` and ``score``
            (``1.0 - distance``). Returns an empty list if the query fails.
        """
        if n_results <= 0:
            return []

        collection = self._get_collection()
        where_filter = (
            {"source_type": source_type.value} if source_type is not None else None
        )
        try:
            results = collection.query(
                query_embeddings=[query_embedding],
                n_results=n_results,
                where=where_filter,
                include=["documents", "metadatas", "distances"],
            )
        except Exception as e:
            logger.error("Search failed: %s", e)
            return []

        # One query embedding was sent, so each field holds a single inner list.
        if not results["ids"] or not results["ids"][0]:
            return []
        rows = zip(
            results["ids"][0],
            results["documents"][0],
            results["metadatas"][0],
            results["distances"][0],
        )
        # NOTE(review): unless the collection uses a cosine-style metric,
        # 1 - distance is not bounded to [0, 1]; it is only a ranking score.
        return [
            {
                "id": doc_id,
                "content": content,
                "metadata": metadata,
                "distance": distance,
                "score": 1.0 - distance,
            }
            for doc_id, content, metadata, distance in rows
        ]

    def delete_index(self) -> bool:
        """Delete the entire collection.

        Returns:
            True if the collection was deleted, False if deletion failed
            (the error is logged).
        """
        try:
            client = self._get_client()
            client.delete_collection(self._collection_name)
        except Exception as e:
            logger.error("Failed to delete collection: %s", e)
            return False
        # Drop the cached handle so the next access recreates the collection.
        self._collection = None
        logger.info("Deleted collection: %s", self._collection_name)
        return True

    def get_stats(self) -> IndexStats:
        """Get statistics about the index.

        Per-source-type counts fall back to zero (with a warning) if the
        metadata cannot be fetched.

        Returns:
            IndexStats with the total count and per-source-type counts.
        """
        collection = self._get_collection()
        total = collection.count()
        source_counts = {member.value: 0 for member in SourceType}
        try:
            fetched = collection.get(include=["metadatas"])
        except Exception as e:
            logger.warning("Failed to get source counts: %s", e)
        else:
            for metadata in fetched.get("metadatas") or []:
                if not metadata:
                    continue
                value = metadata.get("source_type")
                if value in source_counts:
                    source_counts[value] += 1

        return IndexStats(
            total_documents=total,
            openapi_count=source_counts[SourceType.OPENAPI.value],
            readme_count=source_counts[SourceType.README.value],
            code_count=source_counts[SourceType.CODE.value],
        )

    def get_all_documents(
        self, limit: int = 1000, offset: int = 0
    ) -> List[Document]:
        """Get a page of documents from the store.

        Records whose metadata lacks the reserved keys, or whose
        ``source_type`` is not a valid SourceType, are skipped with a warning
        instead of aborting the whole listing.

        Args:
            limit: Maximum number of documents.
            offset: Offset for pagination.

        Returns:
            List of Document objects; empty if the fetch fails.
        """
        collection = self._get_collection()
        try:
            results = collection.get(
                limit=limit, offset=offset, include=["documents", "metadatas"]
            )
        except Exception as e:
            logger.error("Failed to get documents: %s", e)
            return []

        documents = []
        for doc_id, content, metadata in zip(
            results["ids"], results["documents"], results["metadatas"]
        ):
            metadata = metadata or {}
            try:
                source_type = SourceType(metadata["source_type"])
                title = metadata["title"]
                file_path = metadata["file_path"]
            except (KeyError, ValueError) as e:
                logger.warning("Skipping document %s with invalid metadata: %s", doc_id, e)
                continue
            documents.append(
                Document(
                    id=doc_id,
                    content=content,
                    source_type=source_type,
                    title=title,
                    file_path=file_path,
                    metadata={
                        key: value
                        for key, value in metadata.items()
                        if key not in self._RESERVED_METADATA_KEYS
                    },
                )
            )
        return documents

    def delete_by_ids(self, ids: List[str]) -> int:
        """Delete documents by IDs.

        Args:
            ids: List of document IDs to delete.

        Returns:
            Number of IDs submitted for deletion (0 on failure). IDs that were
            not present in the collection are still counted.
        """
        if not ids:
            return 0
        collection = self._get_collection()
        try:
            collection.delete(ids=ids)
        except Exception as e:
            logger.error("Failed to delete documents: %s", e)
            return 0
        logger.info("Deleted %d documents", len(ids))
        return len(ids)

    def delete_by_source_type(self, source_type: SourceType) -> int:
        """Delete all documents of a given source type.

        Args:
            source_type: Source type to delete.

        Returns:
            Number of documents deleted (0 on failure or when none match).
        """
        collection = self._get_collection()
        try:
            # include=[] fetches only IDs; contents and metadata are not needed.
            matched = collection.get(
                where={"source_type": source_type.value}, include=[]
            )
        except Exception as e:
            logger.error("Failed to delete by source type: %s", e)
            return 0
        return self.delete_by_ids(matched["ids"])

    def exists(self) -> bool:
        """Check whether the collection exists.

        Uses ``list_collections`` rather than catching the error from
        ``get_collection``, whose type differs between ChromaDB releases.
        Some releases return collection names and others return Collection
        objects, so both shapes are accepted.

        Returns:
            True if the collection exists.
        """
        client = self._get_client()
        names = {
            entry if isinstance(entry, str) else entry.name
            for entry in client.list_collections()
        }
        return self._collection_name in names

    def count(self) -> int:
        """Get the document count.

        Returns:
            Number of documents in the store.
        """
        return self._get_collection().count()

    def close(self) -> None:
        """Drop the cached client and collection handles.

        Nothing is explicitly shut down; the next call that needs ChromaDB
        creates a fresh client.
        """
        self._client = None
        self._collection = None