diff --git a/src/local_api_docs_search/search/vectorstore.py b/src/local_api_docs_search/search/vectorstore.py
new file mode 100644
index 0000000..50e86e5
--- /dev/null
+++ b/src/local_api_docs_search/search/vectorstore.py
@@ -0,0 +1,305 @@
+"""Vector storage operations using ChromaDB."""
+
+import logging
+from pathlib import Path
+from typing import Any, Dict, List, Optional
+
+import chromadb
+from chromadb.config import Settings
+
+from local_api_docs_search.models.document import Document, IndexStats, SourceType
+
+logger = logging.getLogger(__name__)
+
+
+class VectorStore:
+    """ChromaDB-based vector storage for document embeddings."""
+
+    COLLECTION_NAME = "api_docs"
+
+    def __init__(
+        self,
+        persist_dir: Path,
+        collection_name: Optional[str] = None,
+    ):
+        """Initialize the vector store.
+
+        Args:
+            persist_dir: Directory for on-disk persistence
+            collection_name: Name of the collection (default: api_docs)
+        """
+        self._persist_dir = Path(persist_dir)
+        self._persist_dir.mkdir(parents=True, exist_ok=True)
+        self._collection_name = collection_name or self.COLLECTION_NAME
+        self._client: Optional["chromadb.api.ClientAPI"] = None
+        self._collection: Optional["chromadb.Collection"] = None
+
+    def _get_client(self) -> "chromadb.api.ClientAPI":
+        """Get or create the ChromaDB client."""
+        if self._client is None:
+            # chromadb.Client(Settings(persist_directory=...)) creates an
+            # in-memory client on chromadb >= 0.4; PersistentClient is the
+            # supported way to get on-disk persistence.
+            self._client = chromadb.PersistentClient(
+                path=str(self._persist_dir),
+                settings=Settings(anonymized_telemetry=False),
+            )
+        return self._client
+
+    def _get_collection(self) -> "chromadb.Collection":
+        """Get or create the collection."""
+        if self._collection is None:
+            client = self._get_client()
+            # Cosine space keeps distances in [0, 2] so the score computed in
+            # search() below is meaningful; ChromaDB's default space is l2.
+            self._collection = client.get_or_create_collection(
+                self._collection_name,
+                metadata={"hnsw:space": "cosine"},
+            )
+            logger.debug(f"Using collection: {self._collection_name}")
+        return self._collection
+
+    def add_documents(
+        self,
+        documents: List[Document],
+        embeddings: List[List[float]],
+        batch_size: int = 100,
+    ) -> int:
+        """Add documents and their embeddings to the store.
+
+        Args:
+            documents: List of Document objects
+            embeddings: List of embedding vectors, one per document
+            batch_size: Documents per batch
+
+        Returns:
+            Number of documents added
+        """
+        if not documents:
+            return 0
+
+        collection = self._get_collection()
+
+        total_added = 0
+        for i in range(0, len(documents), batch_size):
+            batch_docs = documents[i : i + batch_size]
+            batch_embeddings = embeddings[i : i + batch_size]
+
+            ids = [doc.id for doc in batch_docs]
+            contents = [doc.content for doc in batch_docs]
+            # ChromaDB metadata values must be str, int, float, or bool.
+            metadatas = [
+                {
+                    "source_type": doc.source_type.value,
+                    "title": doc.title,
+                    "file_path": doc.file_path,
+                    **doc.metadata,
+                }
+                for doc in batch_docs
+            ]
+
+            try:
+                collection.add(
+                    ids=ids,
+                    documents=contents,
+                    embeddings=batch_embeddings,
+                    metadatas=metadatas,
+                )
+                total_added += len(batch_docs)
+                logger.debug(f"Added batch of {len(batch_docs)} documents")
+            except Exception as e:
+                logger.error(f"Failed to add batch: {e}")
+
+        logger.info(f"Added {total_added} documents to collection")
+        return total_added
+
+    def search(
+        self,
+        query_embedding: List[float],
+        n_results: int = 10,
+        source_type: Optional[SourceType] = None,
+    ) -> List[Dict[str, Any]]:
+        """Search for similar documents.
+
+        Args:
+            query_embedding: Query embedding vector
+            n_results: Number of results to return
+            source_type: Optional filter by source type
+
+        Returns:
+            List of search results with documents and scores
+        """
+        collection = self._get_collection()
+
+        where_filter = None
+        if source_type:
+            where_filter = {"source_type": source_type.value}
+
+        try:
+            results = collection.query(
+                query_embeddings=[query_embedding],
+                n_results=n_results,
+                where=where_filter,
+                include=["documents", "metadatas", "distances"],
+            )
+        except Exception as e:
+            logger.error(f"Search failed: {e}")
+            return []
+
+        search_results = []
+        if results["ids"] and results["ids"][0]:
+            for i in range(len(results["ids"][0])):
+                result = {
+                    "id": results["ids"][0][i],
+                    "content": results["documents"][0][i],
+                    "metadata": results["metadatas"][0][i],
+                    "distance": results["distances"][0][i],
+                    # Valid for cosine distance (see _get_collection);
+                    # higher means more similar.
+                    "score": 1.0 - results["distances"][0][i],
+                }
+                search_results.append(result)
+
+        return search_results
+
+    def delete_index(self) -> bool:
+        """Delete the entire index.
+
+        Returns:
+            True if successful
+        """
+        try:
+            client = self._get_client()
+            client.delete_collection(self._collection_name)
+            self._collection = None
+            logger.info(f"Deleted collection: {self._collection_name}")
+            return True
+        except Exception as e:
+            logger.error(f"Failed to delete collection: {e}")
+            return False
+
+    def get_stats(self) -> IndexStats:
+        """Get statistics about the index.
+
+        Returns:
+            IndexStats object
+        """
+        collection = self._get_collection()
+
+        total = collection.count()
+
+        source_counts = {st.value: 0 for st in SourceType}
+
+        try:
+            # Pulls all metadata into memory; fine for local-scale indexes,
+            # revisit if collections grow large.
+            all_metadata = collection.get(include=["metadatas"])
+            for metadata in all_metadata.get("metadatas") or []:
+                source_type = metadata.get("source_type")
+                if source_type in source_counts:
+                    source_counts[source_type] += 1
+        except Exception as e:
+            logger.warning(f"Failed to get source counts: {e}")
+
+        return IndexStats(
+            total_documents=total,
+            openapi_count=source_counts[SourceType.OPENAPI.value],
+            readme_count=source_counts[SourceType.README.value],
+            code_count=source_counts[SourceType.CODE.value],
+        )
+
+    def get_all_documents(
+        self, limit: int = 1000, offset: int = 0
+    ) -> List[Document]:
+        """Get documents from the store, paginated.
+
+        Args:
+            limit: Maximum number of documents
+            offset: Offset for pagination
+
+        Returns:
+            List of Document objects
+        """
+        collection = self._get_collection()
+
+        try:
+            results = collection.get(
+                limit=limit,
+                offset=offset,
+                include=["documents", "metadatas"],
+            )
+        except Exception as e:
+            logger.error(f"Failed to get documents: {e}")
+            return []
+
+        reserved_keys = {"source_type", "title", "file_path"}
+        documents = []
+        for i in range(len(results["ids"])):
+            metadata = results["metadatas"][i]
+            doc = Document(
+                id=results["ids"][i],
+                content=results["documents"][i],
+                source_type=SourceType(metadata["source_type"]),
+                title=metadata["title"],
+                file_path=metadata["file_path"],
+                metadata={
+                    k: v for k, v in metadata.items() if k not in reserved_keys
+                },
+            )
+            documents.append(doc)
+
+        return documents
+
+    def delete_by_ids(self, ids: List[str]) -> int:
+        """Delete documents by IDs.
+
+        Args:
+            ids: List of document IDs to delete
+
+        Returns:
+            Number of documents deleted
+        """
+        if not ids:
+            return 0
+
+        collection = self._get_collection()
+
+        try:
+            collection.delete(ids=ids)
+            logger.info(f"Deleted {len(ids)} documents")
+            return len(ids)
+        except Exception as e:
+            logger.error(f"Failed to delete documents: {e}")
+            return 0
+
+    def delete_by_source_type(self, source_type: SourceType) -> int:
+        """Delete all documents of a given source type.
+
+        Args:
+            source_type: Source type to delete
+
+        Returns:
+            Number of documents deleted
+        """
+        collection = self._get_collection()
+
+        try:
+            results = collection.get(where={"source_type": source_type.value})
+            if results["ids"]:
+                return self.delete_by_ids(results["ids"])
+        except Exception as e:
+            logger.error(f"Failed to delete by source type: {e}")
+
+        return 0
+
+    def exists(self) -> bool:
+        """Check if the collection exists.
+
+        Returns:
+            True if collection exists
+        """
+        try:
+            client = self._get_client()
+            client.get_collection(self._collection_name)
+            return True
+        except Exception:
+            # The exception raised for a missing collection varies across
+            # chromadb versions (ValueError in 0.4.x, NotFoundError later),
+            # so catch broadly here.
+            return False
+
+    def count(self) -> int:
+        """Get the document count.
+
+        Returns:
+            Number of documents in the store
+        """
+        collection = self._get_collection()
+        return collection.count()
+
+    def close(self) -> None:
+        """Release the client and collection references."""
+        self._client = None
+        self._collection = None
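
For context, a quick usage sketch of the class above; this is not part of the
patch. The Document fields mirror the ones add_documents() writes into
metadata, and embed_texts() is a hypothetical stand-in for whatever embedding
backend the project wires in (any function returning one vector per text
works).

    from pathlib import Path
    from typing import List

    from local_api_docs_search.models.document import Document, SourceType
    from local_api_docs_search.search.vectorstore import VectorStore

    def embed_texts(texts: List[str]) -> List[List[float]]:
        """Hypothetical embedder: fixed-size bag-of-bytes count vectors."""
        dims = 16
        vectors = []
        for text in texts:
            v = [0.0] * dims
            for byte in text.encode("utf-8"):
                v[byte % dims] += 1.0
            vectors.append(v)
        return vectors

    store = VectorStore(persist_dir=Path(".chroma"))

    docs = [
        Document(
            id="readme-0",
            content="GET /users returns the list of registered users.",
            source_type=SourceType.README,
            title="Users endpoint",
            file_path="README.md",
            metadata={"section": "endpoints"},
        )
    ]
    store.add_documents(docs, embed_texts([d.content for d in docs]))

    for hit in store.search(embed_texts(["how do I list users?"])[0], n_results=5):
        print(f"{hit['score']:.3f} {hit['metadata']['title']}")

    print(store.get_stats())
    store.close()

Because search() converts distance to a score assuming the cosine space set in
_get_collection(), any embedding model works as long as the same model is used
for both indexing and querying.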