fix: resolve CI/CD issues with proper package structure and imports
Some checks failed
CI / test (3.10) (push) Has been cancelled
CI / test (3.11) (push) Has been cancelled
CI / test (3.12) (push) Has been cancelled
CI / build (push) Has been cancelled

2026-02-03 03:54:45 +00:00
parent d27d8fffa9
commit 4c9c795764


@@ -0,0 +1,305 @@
"""Vector storage operations using ChromaDB."""
import logging
from pathlib import Path
from typing import Any, Dict, List, Optional
import chromadb
from chromadb.config import Settings
from local_api_docs_search.models.document import Document, IndexStats, SourceType
logger = logging.getLogger(__name__)
class VectorStore:
"""ChromaDB-based vector storage for document embeddings."""
COLLECTION_NAME = "api_docs"
def __init__(
self,
persist_dir: Path,
collection_name: Optional[str] = None,
):
"""Initialize the vector store.
Args:
persist_dir: Directory for persistence
collection_name: Name of the collection (default: api_docs)
"""
self._persist_dir = Path(persist_dir)
self._persist_dir.mkdir(parents=True, exist_ok=True)
self._collection_name = collection_name or self.COLLECTION_NAME
self._client: Optional[chromadb.Client] = None
self._collection: Optional[chromadb.Collection] = None
    def _get_client(self) -> chromadb.Client:
        """Get or create the ChromaDB client (created lazily on first use)."""
        if self._client is None:
            # PersistentClient (ChromaDB >= 0.4) stores the index under persist_dir;
            # the older Client(Settings(persist_directory=...)) form is deprecated there.
            self._client = chromadb.PersistentClient(
                path=str(self._persist_dir),
                settings=Settings(anonymized_telemetry=False),
            )
        return self._client
    def _get_collection(self) -> chromadb.Collection:
        """Get or create the collection."""
        if self._collection is None:
            client = self._get_client()
            # get_or_create_collection avoids depending on which exception type
            # get_collection raises when the collection does not exist yet.
            self._collection = client.get_or_create_collection(self._collection_name)
            logger.debug(f"Using collection: {self._collection_name}")
        return self._collection
def add_documents(
self,
documents: List[Document],
embeddings: List[List[float]],
batch_size: int = 100,
) -> int:
"""Add documents and their embeddings to the store.
Args:
documents: List of Document objects
embeddings: List of embedding vectors
batch_size: Documents per batch
Returns:
Number of documents added
"""
if not documents:
return 0
collection = self._get_collection()
total_added = 0
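        # Insert in batches: one huge add() call would hold every payload in memory
        # at once and can exceed ChromaDB's per-request batch limit.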
for i in range(0, len(documents), batch_size):
batch_docs = documents[i : i + batch_size]
batch_embeddings = embeddings[i : i + batch_size]
ids = [doc.id for doc in batch_docs]
contents = [doc.content for doc in batch_docs]
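            # ChromaDB metadata values must be scalars (str, int, float, bool),
            # so doc.metadata is expected to contain only flat values here.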
metadatas = [
{
"source_type": doc.source_type.value,
"title": doc.title,
"file_path": doc.file_path,
**doc.metadata,
}
for doc in batch_docs
]
try:
collection.add(
ids=ids,
documents=contents,
embeddings=batch_embeddings,
metadatas=metadatas,
)
total_added += len(batch_docs)
logger.debug(f"Added batch of {len(batch_docs)} documents")
except Exception as e:
logger.error(f"Failed to add batch: {e}")
logger.info(f"Added {total_added} documents to collection")
return total_added
def search(
self,
query_embedding: List[float],
n_results: int = 10,
source_type: Optional[SourceType] = None,
) -> List[Dict[str, Any]]:
"""Search for similar documents.
Args:
query_embedding: Query embedding vector
n_results: Number of results to return
source_type: Optional filter by source type
Returns:
List of search results with documents and scores
"""
collection = self._get_collection()
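        # ChromaDB's `where` clause filters on stored metadata, letting results be
        # restricted to a single source type (OpenAPI, README, or code).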
where_filter = None
if source_type:
where_filter = {"source_type": source_type.value}
try:
results = collection.query(
query_embeddings=[query_embedding],
n_results=n_results,
where=where_filter,
include=["documents", "metadatas", "distances"],
)
except Exception as e:
logger.error(f"Search failed: {e}")
return []
search_results = []
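        # Flatten ChromaDB's per-query result lists and turn each distance
        # (lower = more similar) into a score where higher is better; 1 - distance
        # is only a rough similarity under the collection's default distance metric.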
if results["ids"] and results["ids"][0]:
for i in range(len(results["ids"][0])):
result = {
"id": results["ids"][0][i],
"content": results["documents"][0][i],
"metadata": results["metadatas"][0][i],
"distance": results["distances"][0][i],
"score": 1.0 - results["distances"][0][i],
}
search_results.append(result)
return search_results
def delete_index(self) -> bool:
"""Delete the entire index.
Returns:
True if successful
"""
try:
client = self._get_client()
client.delete_collection(self._collection_name)
self._collection = None
logger.info(f"Deleted collection: {self._collection_name}")
return True
except Exception as e:
logger.error(f"Failed to delete collection: {e}")
return False
def get_stats(self) -> IndexStats:
"""Get statistics about the index.
Returns:
IndexStats object
"""
collection = self._get_collection()
total = collection.count()
        # Count documents per source type by scanning the stored metadata.
        source_counts = {source.value: 0 for source in SourceType}
try:
all_metadata = collection.get(include=["metadatas"])
for metadata in all_metadata.get("metadatas", []):
source_type = metadata.get("source_type")
if source_type in source_counts:
source_counts[source_type] += 1
except Exception as e:
logger.warning(f"Failed to get source counts: {e}")
return IndexStats(
total_documents=total,
openapi_count=source_counts[SourceType.OPENAPI.value],
readme_count=source_counts[SourceType.README.value],
code_count=source_counts[SourceType.CODE.value],
)
def get_all_documents(
self, limit: int = 1000, offset: int = 0
) -> List[Document]:
"""Get all documents from the store.
Args:
limit: Maximum number of documents
offset: Offset for pagination
Returns:
List of Document objects
"""
collection = self._get_collection()
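        # ChromaDB's get() supports limit/offset, so callers can page through
        # large collections instead of loading everything at once.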
try:
results = collection.get(limit=limit, offset=offset, include=["documents", "metadatas"])
except Exception as e:
logger.error(f"Failed to get documents: {e}")
return []
documents = []
for i in range(len(results["ids"])):
metadata = results["metadatas"][i]
doc = Document(
id=results["ids"][i],
content=results["documents"][i],
source_type=SourceType(metadata["source_type"]),
title=metadata["title"],
file_path=metadata["file_path"],
metadata={k: v for k, v in metadata.items() if k not in ["source_type", "title", "file_path"]},
)
documents.append(doc)
return documents
def delete_by_ids(self, ids: List[str]) -> int:
"""Delete documents by IDs.
Args:
ids: List of document IDs to delete
Returns:
Number of documents deleted
"""
if not ids:
return 0
collection = self._get_collection()
try:
collection.delete(ids=ids)
logger.info(f"Deleted {len(ids)} documents")
return len(ids)
except Exception as e:
logger.error(f"Failed to delete documents: {e}")
return 0
def delete_by_source_type(self, source_type: SourceType) -> int:
"""Delete all documents of a given source type.
Args:
source_type: Source type to delete
Returns:
Number of documents deleted
"""
collection = self._get_collection()
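        # Look up the matching ids first (rather than calling delete(where=...))
        # so the number of deleted documents can be reported to the caller.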
try:
results = collection.get(where={"source_type": source_type.value})
if results["ids"]:
return self.delete_by_ids(results["ids"])
except Exception as e:
logger.error(f"Failed to delete by source type: {e}")
return 0
def exists(self) -> bool:
"""Check if the collection exists.
Returns:
True if collection exists
"""
try:
client = self._get_client()
client.get_collection(self._collection_name)
return True
        except Exception:
            # ChromaDB versions differ in the exception raised for a missing
            # collection, so treat any failure here as "does not exist".
            return False
def count(self) -> int:
"""Get the document count.
Returns:
Number of documents in the store
"""
collection = self._get_collection()
return collection.count()
def close(self) -> None:
"""Close the client connection."""
self._client = None
self._collection = None
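
For reviewers, a minimal usage sketch of the class above. The import path for VectorStore is assumed (the diff does not show the file's location), the Document constructor is assumed to accept the fields this module reads as keyword arguments, and embed_texts is a stand-in for whatever embedding backend the project actually wires in:

from pathlib import Path

from local_api_docs_search.models.document import Document, SourceType
# Import path assumed; the diff does not show where this module lives.
from local_api_docs_search.storage.vector_store import VectorStore


def embed_texts(texts):
    # Stand-in embedder: small deterministic vectors so the sketch runs without a
    # model. The real project would call its embedding backend here instead.
    return [[float((hash(t) >> s) % 97) / 97.0 for s in range(8)] for t in texts]


store = VectorStore(persist_dir=Path(".api_docs_index"))
docs = [
    Document(
        id="readme-0",
        content="POST /v1/search runs a semantic query over the indexed docs.",
        source_type=SourceType.README,
        title="README",
        file_path="README.md",
        metadata={"section": "usage"},
    )
]
store.add_documents(docs, embeddings=embed_texts([d.content for d in docs]))

query_vec = embed_texts(["how do I search the docs?"])[0]
for hit in store.search(query_embedding=query_vec, n_results=5):
    print(f"{hit['score']:.3f}  {hit['metadata']['title']}")

print(store.get_stats())
store.close()

Search results come back as plain dicts (id, content, metadata, distance, score), so callers do not need to handle ChromaDB types directly.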