"""Vector storage operations using ChromaDB.""" import logging from pathlib import Path from typing import Any, Dict, List, Optional import chromadb from chromadb.config import Settings from src.models.document import Document, IndexStats, SourceType logger = logging.getLogger(__name__) class VectorStore: """ChromaDB-based vector storage for document embeddings.""" COLLECTION_NAME = "api_docs" def __init__( self, persist_dir: Path, collection_name: Optional[str] = None, ): """Initialize the vector store. Args: persist_dir: Directory for persistence collection_name: Name of the collection (default: api_docs) """ self._persist_dir = Path(persist_dir) self._persist_dir.mkdir(parents=True, exist_ok=True) self._collection_name = collection_name or self.COLLECTION_NAME self._client: Optional[chromadb.Client] = None self._collection: Optional[chromadb.Collection] = None def _get_client(self) -> chromadb.Client: """Get or create the ChromaDB client.""" if self._client is None: self._client = chromadb.Client( Settings( persist_directory=str(self._persist_dir), anonymized_telemetry=False, ) ) return self._client def _get_collection(self) -> chromadb.Collection: """Get or create the collection.""" if self._collection is None: client = self._get_client() try: self._collection = client.get_collection(self._collection_name) except ValueError: self._collection = client.create_collection(self._collection_name) logger.info(f"Created new collection: {self._collection_name}") return self._collection def add_documents( self, documents: List[Document], embeddings: List[List[float]], batch_size: int = 100, ) -> int: """Add documents and their embeddings to the store. Args: documents: List of Document objects embeddings: List of embedding vectors batch_size: Documents per batch Returns: Number of documents added """ if not documents: return 0 collection = self._get_collection() total_added = 0 for i in range(0, len(documents), batch_size): batch_docs = documents[i : i + batch_size] batch_embeddings = embeddings[i : i + batch_size] ids = [doc.id for doc in batch_docs] contents = [doc.content for doc in batch_docs] metadatas = [ { "source_type": doc.source_type.value, "title": doc.title, "file_path": doc.file_path, **doc.metadata, } for doc in batch_docs ] try: collection.add( ids=ids, documents=contents, embeddings=batch_embeddings, metadatas=metadatas, ) total_added += len(batch_docs) logger.debug(f"Added batch of {len(batch_docs)} documents") except Exception as e: logger.error(f"Failed to add batch: {e}") logger.info(f"Added {total_added} documents to collection") return total_added def search( self, query_embedding: List[float], n_results: int = 10, source_type: Optional[SourceType] = None, ) -> List[Dict[str, Any]]: """Search for similar documents. Args: query_embedding: Query embedding vector n_results: Number of results to return source_type: Optional filter by source type Returns: List of search results with documents and scores """ collection = self._get_collection() where_filter = None if source_type: where_filter = {"source_type": source_type.value} try: results = collection.query( query_embeddings=[query_embedding], n_results=n_results, where=where_filter, include=["documents", "metadatas", "distances"], ) except Exception as e: logger.error(f"Search failed: {e}") return [] search_results = [] if results["ids"] and results["ids"][0]: for i in range(len(results["ids"][0])): result = { "id": results["ids"][0][i], "content": results["documents"][0][i], "metadata": results["metadatas"][0][i], "distance": results["distances"][0][i], "score": 1.0 - results["distances"][0][i], } search_results.append(result) return search_results def delete_index(self) -> bool: """Delete the entire index. Returns: True if successful """ try: client = self._get_client() client.delete_collection(self._collection_name) self._collection = None logger.info(f"Deleted collection: {self._collection_name}") return True except Exception as e: logger.error(f"Failed to delete collection: {e}") return False def get_stats(self) -> IndexStats: """Get statistics about the index. Returns: IndexStats object """ collection = self._get_collection() total = collection.count() source_counts = {type.value: 0 for type in SourceType} try: all_metadata = collection.get(include=["metadatas"]) for metadata in all_metadata.get("metadatas", []): source_type = metadata.get("source_type") if source_type in source_counts: source_counts[source_type] += 1 except Exception as e: logger.warning(f"Failed to get source counts: {e}") return IndexStats( total_documents=total, openapi_count=source_counts[SourceType.OPENAPI.value], readme_count=source_counts[SourceType.README.value], code_count=source_counts[SourceType.CODE.value], ) def get_all_documents( self, limit: int = 1000, offset: int = 0 ) -> List[Document]: """Get all documents from the store. Args: limit: Maximum number of documents offset: Offset for pagination Returns: List of Document objects """ collection = self._get_collection() try: results = collection.get(limit=limit, offset=offset, include=["documents", "metadatas"]) except Exception as e: logger.error(f"Failed to get documents: {e}") return [] documents = [] for i in range(len(results["ids"])): metadata = results["metadatas"][i] doc = Document( id=results["ids"][i], content=results["documents"][i], source_type=SourceType(metadata["source_type"]), title=metadata["title"], file_path=metadata["file_path"], metadata={k: v for k, v in metadata.items() if k not in ["source_type", "title", "file_path"]}, ) documents.append(doc) return documents def delete_by_ids(self, ids: List[str]) -> int: """Delete documents by IDs. Args: ids: List of document IDs to delete Returns: Number of documents deleted """ if not ids: return 0 collection = self._get_collection() try: collection.delete(ids=ids) logger.info(f"Deleted {len(ids)} documents") return len(ids) except Exception as e: logger.error(f"Failed to delete documents: {e}") return 0 def delete_by_source_type(self, source_type: SourceType) -> int: """Delete all documents of a given source type. Args: source_type: Source type to delete Returns: Number of documents deleted """ collection = self._get_collection() try: results = collection.get(where={"source_type": source_type.value}) if results["ids"]: return self.delete_by_ids(results["ids"]) except Exception as e: logger.error(f"Failed to delete by source type: {e}") return 0 def exists(self) -> bool: """Check if the collection exists. Returns: True if collection exists """ try: client = self._get_client() client.get_collection(self._collection_name) return True except ValueError: return False def count(self) -> int: """Get the document count. Returns: Number of documents in the store """ collection = self._get_collection() return collection.count() def close(self) -> None: """Close the client connection.""" self._client = None self._collection = None