fix: resolve CI/CD issues with proper package structure and imports
src/local_api_docs_search/search/searcher.py (new file, 368 lines)
@@ -0,0 +1,368 @@
"""Search logic with semantic similarity and hybrid search."""

import logging
import re
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional

from local_api_docs_search.models.document import Document, SearchResult, SourceType
from local_api_docs_search.search.embeddings import EmbeddingManager
from local_api_docs_search.search.vectorstore import VectorStore
from local_api_docs_search.utils.config import get_config

logger = logging.getLogger(__name__)


@dataclass
class SearchOptions:
    """Options for search operations."""

    limit: int = 10
    source_type: Optional[SourceType] = None
    min_score: float = 0.0
    include_scores: bool = True


class Searcher:
    """Main search class for semantic and hybrid search."""

    def __init__(
        self,
        embedding_manager: Optional[EmbeddingManager] = None,
        vector_store: Optional[VectorStore] = None,
        config_path: Optional[Path] = None,
    ):
        """Initialize the searcher.

        Args:
            embedding_manager: Embedding manager instance
            vector_store: Vector store instance
            config_path: Path to configuration file
        """
        config = get_config(config_path)

        self._embedding_manager = embedding_manager or EmbeddingManager(
            model_name=config.model_name,
            device=config.embedding_device,
            cache_dir=config.chroma_persist_dir / ".cache",
        )

        self._vector_store = vector_store or VectorStore(
            persist_dir=config.chroma_persist_dir,
        )

        self._config = config

    def search(
        self, query: str, options: Optional[SearchOptions] = None
    ) -> List[SearchResult]:
        """Perform semantic search for a query.

        Args:
            query: Search query string
            options: Search options

        Returns:
            List of SearchResult objects
        """
        if options is None:
            options = SearchOptions(limit=self._config.default_limit)

        if not query.strip():
            return []

        try:
            query_embedding = self._embedding_manager.embed_query(query)
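
            # Fetch more candidates than requested (limit * 2) so that hits
            # filtered out by `min_score` below can be backfilled before the
            # list is truncated to `options.limit`.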
            results = self._vector_store.search(
                query_embedding=query_embedding,
                n_results=options.limit * 2,
                source_type=options.source_type,
            )

            search_results = []
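
            # Rebuild Document objects from the stored payloads; metadata keys
            # that map to typed fields (source_type, title, file_path) are
            # promoted, the rest stay in Document.metadata.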
            for result in results:
                if options.min_score > 0 and result["score"] < options.min_score:
                    continue

                doc = Document(
                    id=result["id"],
                    content=result["content"],
                    source_type=SourceType(result["metadata"]["source_type"]),
                    title=result["metadata"]["title"],
                    file_path=result["metadata"]["file_path"],
                    metadata={
                        k: v
                        for k, v in result["metadata"].items()
                        if k not in ["source_type", "title", "file_path"]
                    },
                )

                highlights = self._generate_highlights(query, result["content"])

                search_results.append(
                    SearchResult(
                        document=doc,
                        score=result["score"],
                        highlights=highlights,
                    )
                )

                if len(search_results) >= options.limit:
                    break

            return search_results

        except Exception as e:
            logger.error(f"Search failed for query '{query}': {e}")
            return []

    def hybrid_search(
        self, query: str, options: Optional[SearchOptions] = None
    ) -> List[SearchResult]:
        """Perform hybrid search combining semantic and keyword search.

        Args:
            query: Search query string
            options: Search options

        Returns:
            List of SearchResult objects sorted by combined relevance
        """
        if options is None:
            options = SearchOptions(limit=self._config.default_limit)

        semantic_results = self.search(query, options)

        if not query.strip():
            return semantic_results

        keyword_results = self._keyword_search(query, options)
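
        # Merge by document id: results found by both searches get the mean of
        # their semantic and keyword scores and the union of their highlights;
        # results found by only one search keep their original score.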
        combined = {}
        for result in semantic_results:
            combined[result.document.id] = result

        for result in keyword_results:
            if result.document.id in combined:
                existing = combined[result.document.id]
                combined[result.document.id] = SearchResult(
                    document=result.document,
                    score=(existing.score + result.score) / 2,
                    highlights=list(set(existing.highlights + result.highlights)),
                )
            else:
                combined[result.document.id] = result

        sorted_results = sorted(
            combined.values(), key=lambda r: r.score, reverse=True
        )

        return sorted_results[: options.limit]

    def _keyword_search(
        self, query: str, options: SearchOptions
    ) -> List[SearchResult]:
        """Perform keyword-based search.

        Args:
            query: Search query
            options: Search options

        Returns:
            List of SearchResult objects
        """
        keywords = self._extract_keywords(query)

        if not keywords:
            return []

        try:
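            # Brute-force scan: pull up to 1000 stored documents and score each
            # one in Python. This assumes the indexed corpus is small; a larger
            # corpus would likely want a real full-text index instead.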
            all_docs = self._vector_store.get_all_documents(limit=1000)

            results = []
            for doc in all_docs:
                if options.source_type and doc.source_type != options.source_type:
                    continue

                keyword_score = self._calculate_keyword_score(keywords, doc.content)
                if keyword_score > 0:
                    highlights = self._generate_highlights(query, doc.content)
                    results.append(
                        SearchResult(
                            document=doc,
                            score=keyword_score,
                            highlights=highlights,
                        )
                    )

            results.sort(key=lambda r: r.score, reverse=True)
            return results[: options.limit]

        except Exception as e:
            logger.error(f"Keyword search failed: {e}")
            return []

    def _extract_keywords(self, query: str) -> List[str]:
        """Extract keywords from a query.

        Args:
            query: Search query

        Returns:
            List of keywords
        """
        stop_words = {
            "a", "an", "the", "and", "or", "but", "in", "on", "at", "to", "for",
            "of", "with", "by", "from", "up", "about", "into", "through", "during",
            "how", "what", "when", "where", "why", "which", "who", "whom",
            "this", "that", "these", "those", "is", "are", "was", "were", "be",
            "been", "being", "have", "has", "had", "do", "does", "did", "will",
            "would", "could", "should", "may", "might", "must", "shall", "can",
        }

        words = re.findall(r"\b\w+\b", query.lower())
        keywords = [w for w in words if w not in stop_words and len(w) > 1]

        return keywords

    def _calculate_keyword_score(self, keywords: List[str], content: str) -> float:
        """Calculate keyword matching score.

        Args:
            keywords: List of keywords
            content: Document content

        Returns:
            Score between 0 and 1
        """
        if not keywords:
            return 0.0
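
        # Base score is the fraction of query keywords that occur anywhere in
        # the document; an exact match of the full keyword phrase adds a 0.3
        # bonus, capped at 1.0.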
        content_lower = content.lower()

        matched_keywords = sum(1 for kw in keywords if kw in content_lower)

        keyword_density = matched_keywords / len(keywords)

        exact_phrase = " ".join(keywords)
        if exact_phrase in content_lower:
            return min(1.0, keyword_density + 0.3)

        return keyword_density

    def _generate_highlights(self, query: str, content: str) -> List[str]:
        """Generate highlight snippets for a query.

        Args:
            query: Search query
            content: Document content

        Returns:
            List of highlight strings
        """
        keywords = self._extract_keywords(query)
        if not keywords:
            return []

        highlights = []
        content_lower = content.lower()
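
        # Build snippets with ~30 characters of context on either side of each
        # match, limited to the first 3 keywords and at most 5 snippets overall.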
        for keyword in keywords[:3]:
            pattern = re.compile(re.escape(keyword), re.IGNORECASE)
            for match in pattern.finditer(content_lower):
                start = max(0, match.start() - 30)
                end = min(len(content), match.end() + 30)
                snippet = content[start:end]
                if start > 0:
                    snippet = "..." + snippet
                if end < len(content):
                    snippet = snippet + "..."
                highlights.append(snippet)

        return highlights[:5]

    def index(
        self,
        path: Path,
        doc_type: str = "all",
        recursive: bool = False,
        batch_size: int = 32,
    ) -> int:
        """Index documents from a path.

        Args:
            path: Path to file or directory
            doc_type: Type of documents (openapi, readme, code, all)
            recursive: Search recursively
            batch_size: Batch size for indexing

        Returns:
            Number of documents indexed
        """
        from local_api_docs_search.indexer.openapi import OpenAPIIndexer
        from local_api_docs_search.indexer.readme import READMEIndexer
        from local_api_docs_search.indexer.code import CodeIndexer

        indexers = []

        if doc_type in ("openapi", "all"):
            indexers.append(OpenAPIIndexer())
        if doc_type in ("readme", "all"):
            indexers.append(READMEIndexer())
        if doc_type in ("code", "all"):
            indexers.append(CodeIndexer())

        all_documents = []

        for indexer in indexers:
            documents = indexer.index(path, recursive=recursive, batch_size=batch_size)
            all_documents.extend(documents)

        if not all_documents:
            logger.warning("No documents found to index")
            return 0
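
        # Embed all document contents, then write documents and embeddings to
        # the vector store in batches of `batch_size`.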
        texts = [doc.content for doc in all_documents]
        embeddings = self._embedding_manager.embed(texts, show_progress=True)

        self._vector_store.add_documents(all_documents, embeddings, batch_size=batch_size)

        logger.info(f"Indexed {len(all_documents)} documents")
        return len(all_documents)

    def get_stats(self):
        """Get index statistics.

        Returns:
            IndexStats object
        """
        return self._vector_store.get_stats()

    def clear_index(self) -> bool:
        """Clear the entire index.

        Returns:
            True if successful
        """
        return self._vector_store.delete_index()

    def list_documents(
        self, source_type: Optional[SourceType] = None, limit: int = 100
    ) -> List[Document]:
        """List indexed documents.

        Args:
            source_type: Optional filter by source type
            limit: Maximum results

        Returns:
            List of Document objects
        """
        docs = self._vector_store.get_all_documents(limit=limit * 2)

        if source_type:
            docs = [d for d in docs if d.source_type == source_type]

        return docs[:limit]
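

# Minimal usage sketch (illustrative only; "./docs" and the query string are
# hypothetical, and configuration falls back to the package defaults):
if __name__ == "__main__":
    searcher = Searcher()
    searcher.index(Path("./docs"), doc_type="all", recursive=True)
    for hit in searcher.hybrid_search("authentication endpoints"):
        print(f"{hit.score:.2f}  {hit.document.title}")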