"""Search logic with semantic similarity and hybrid search."""

import logging
import re
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional

from src.models.document import Document, SearchResult, SourceType
from src.search.embeddings import EmbeddingManager
from src.search.vectorstore import VectorStore
from src.utils.config import get_config

logger = logging.getLogger(__name__)

# Words too common to carry signal; rebuilt-per-call in the original,
# hoisted here since the set is immutable.
_STOP_WORDS = frozenset({
    "a", "an", "the", "and", "or", "but", "in", "on", "at", "to", "for",
    "of", "with", "by", "from", "up", "about", "into", "through", "during",
    "how", "what", "when", "where", "why", "which", "who", "whom",
    "this", "that", "these", "those", "is", "are", "was", "were", "be",
    "been", "being", "have", "has", "had", "do", "does", "did", "will",
    "would", "could", "should", "may", "might", "must", "shall", "can",
})

# Word-token pattern for keyword extraction; compiled once at import time.
_WORD_RE = re.compile(r"\b\w+\b")


@dataclass
class SearchOptions:
    """Options for search operations."""

    limit: int = 10  # maximum number of results to return
    source_type: Optional[SourceType] = None  # restrict results to one source
    min_score: float = 0.0  # drop results scoring strictly below this
    include_scores: bool = True  # whether callers should surface scores


class Searcher:
    """Main search class for semantic and hybrid search."""

    def __init__(
        self,
        embedding_manager: Optional[EmbeddingManager] = None,
        vector_store: Optional[VectorStore] = None,
        config_path: Optional[Path] = None,
    ):
        """Initialize the searcher.

        Args:
            embedding_manager: Embedding manager instance; one is built from
                config when omitted.
            vector_store: Vector store instance; one is built from config
                when omitted.
            config_path: Path to configuration file.
        """
        config = get_config(config_path)

        self._embedding_manager = embedding_manager or EmbeddingManager(
            model_name=config.model_name,
            device=config.embedding_device,
            cache_dir=config.chroma_persist_dir / ".cache",
        )

        self._vector_store = vector_store or VectorStore(
            persist_dir=config.chroma_persist_dir,
        )

        self._config = config

    def search(
        self, query: str, options: Optional[SearchOptions] = None
    ) -> List[SearchResult]:
        """Perform semantic search for a query.

        Args:
            query: Search query string.
            options: Search options; defaults to the configured limit.

        Returns:
            List of SearchResult objects (empty for blank queries or on
            backend failure).
        """
        if options is None:
            options = SearchOptions(limit=self._config.default_limit)

        if not query.strip():
            return []

        try:
            query_embedding = self._embedding_manager.embed_query(query)

            # Over-fetch so min_score filtering can still fill the limit.
            raw_hits = self._vector_store.search(
                query_embedding=query_embedding,
                n_results=options.limit * 2,
                source_type=options.source_type,
            )

            search_results: List[SearchResult] = []
            for hit in raw_hits:
                if options.min_score > 0 and hit["score"] < options.min_score:
                    continue

                search_results.append(
                    SearchResult(
                        document=self._result_to_document(hit),
                        score=hit["score"],
                        highlights=self._generate_highlights(
                            query, hit["content"]
                        ),
                    )
                )

                if len(search_results) >= options.limit:
                    break

            return search_results

        except Exception:
            # User-facing boundary: log (with traceback) and degrade to an
            # empty result set rather than propagate backend failures.
            logger.exception("Search failed for query %r", query)
            return []

    @staticmethod
    def _result_to_document(hit: dict) -> Document:
        """Build a Document from a raw vector-store hit dict."""
        meta = hit["metadata"]
        reserved = ("source_type", "title", "file_path")
        return Document(
            id=hit["id"],
            content=hit["content"],
            source_type=SourceType(meta["source_type"]),
            title=meta["title"],
            file_path=meta["file_path"],
            # Remaining metadata keys pass through untouched.
            metadata={k: v for k, v in meta.items() if k not in reserved},
        )

    def hybrid_search(
        self, query: str, options: Optional[SearchOptions] = None
    ) -> List[SearchResult]:
        """Perform hybrid search combining semantic and keyword search.

        Args:
            query: Search query string.
            options: Search options; defaults to the configured limit.

        Returns:
            List of SearchResult objects sorted by combined relevance.
        """
        if options is None:
            options = SearchOptions(limit=self._config.default_limit)

        # Blank queries produce no results in either branch; bail out before
        # paying for the semantic pass.
        if not query.strip():
            return []

        semantic_results = self.search(query, options)
        keyword_results = self._keyword_search(query, options)

        combined = {result.document.id: result for result in semantic_results}

        for result in keyword_results:
            existing = combined.get(result.document.id)
            if existing is None:
                combined[result.document.id] = result
            else:
                # Document matched both ways: average the scores and merge
                # highlights. dict.fromkeys dedupes while keeping a stable,
                # deterministic order (list(set(...)) did not).
                combined[result.document.id] = SearchResult(
                    document=result.document,
                    score=(existing.score + result.score) / 2,
                    highlights=list(
                        dict.fromkeys(existing.highlights + result.highlights)
                    ),
                )

        ranked = sorted(combined.values(), key=lambda r: r.score, reverse=True)
        return ranked[: options.limit]

    def _keyword_search(
        self, query: str, options: SearchOptions
    ) -> List[SearchResult]:
        """Perform keyword-based search.

        NOTE: scans at most the first 1000 stored documents; larger indexes
        are only partially covered by the keyword pass.

        Args:
            query: Search query.
            options: Search options.

        Returns:
            List of SearchResult objects sorted by keyword score.
        """
        keywords = self._extract_keywords(query)
        if not keywords:
            return []

        try:
            candidates = self._vector_store.get_all_documents(limit=1000)

            matches: List[SearchResult] = []
            for doc in candidates:
                if options.source_type and doc.source_type != options.source_type:
                    continue

                score = self._calculate_keyword_score(keywords, doc.content)
                if score > 0:
                    matches.append(
                        SearchResult(
                            document=doc,
                            score=score,
                            highlights=self._generate_highlights(
                                query, doc.content
                            ),
                        )
                    )

            matches.sort(key=lambda r: r.score, reverse=True)
            return matches[: options.limit]

        except Exception:
            logger.exception("Keyword search failed")
            return []

    def _extract_keywords(self, query: str) -> List[str]:
        """Extract keywords from a query.

        Lowercases, tokenizes on word boundaries, and drops stop words and
        single-character tokens.

        Args:
            query: Search query.

        Returns:
            List of keywords in query order.
        """
        words = _WORD_RE.findall(query.lower())
        return [w for w in words if w not in _STOP_WORDS and len(w) > 1]

    def _calculate_keyword_score(self, keywords: List[str], content: str) -> float:
        """Calculate keyword matching score.

        Score is the fraction of keywords found in the content, with a +0.3
        bonus (capped at 1.0) when the stop-word-stripped keyword sequence
        appears contiguously.

        Args:
            keywords: List of keywords.
            content: Document content.

        Returns:
            Score between 0 and 1.
        """
        if not keywords:
            return 0.0

        content_lower = content.lower()

        matched = sum(1 for kw in keywords if kw in content_lower)
        density = matched / len(keywords)

        # NOTE: this is the keywords rejoined (stop words removed), not the
        # user's literal phrase.
        if " ".join(keywords) in content_lower:
            return min(1.0, density + 0.3)

        return density

    def _generate_highlights(self, query: str, content: str) -> List[str]:
        """Generate highlight snippets for a query.

        Takes up to three keywords, finds each occurrence case-insensitively,
        and returns up to five unique snippets of ~60 characters of context
        with ellipses where truncated.

        Args:
            query: Search query.
            content: Document content.

        Returns:
            List of highlight strings (no duplicates).
        """
        keywords = self._extract_keywords(query)
        if not keywords:
            return []

        highlights: List[str] = []
        seen: set = set()

        for keyword in keywords[:3]:
            pattern = re.compile(re.escape(keyword), re.IGNORECASE)
            # IGNORECASE makes matching on the original content equivalent to
            # matching on a lowercased copy, and indices line up for slicing.
            for match in pattern.finditer(content):
                start = max(0, match.start() - 30)
                end = min(len(content), match.end() + 30)

                snippet = content[start:end]
                if start > 0:
                    snippet = "..." + snippet
                if end < len(content):
                    snippet = snippet + "..."

                if snippet not in seen:
                    seen.add(snippet)
                    highlights.append(snippet)
                if len(highlights) >= 5:
                    return highlights

        return highlights

    def index(
        self,
        path: Path,
        doc_type: str = "all",
        recursive: bool = False,
        batch_size: int = 32,
    ) -> int:
        """Index documents from a path.

        Args:
            path: Path to file or directory.
            doc_type: Type of documents (openapi, readme, code, all).
            recursive: Search recursively.
            batch_size: Batch size for embedding and storage.

        Returns:
            Number of documents indexed (0 when nothing was found).
        """
        # Local imports keep heavy indexer dependencies off the search path.
        from src.indexer.openapi import OpenAPIIndexer
        from src.indexer.readme import READMEIndexer
        from src.indexer.code import CodeIndexer

        indexers = []
        if doc_type in ("openapi", "all"):
            indexers.append(OpenAPIIndexer())
        if doc_type in ("readme", "all"):
            indexers.append(READMEIndexer())
        if doc_type in ("code", "all"):
            indexers.append(CodeIndexer())

        all_documents = []
        for indexer in indexers:
            all_documents.extend(
                indexer.index(path, recursive=recursive, batch_size=batch_size)
            )

        if not all_documents:
            logger.warning("No documents found to index")
            return 0

        texts = [doc.content for doc in all_documents]
        embeddings = self._embedding_manager.embed(texts, show_progress=True)

        self._vector_store.add_documents(
            all_documents, embeddings, batch_size=batch_size
        )

        logger.info(f"Indexed {len(all_documents)} documents")
        return len(all_documents)

    def get_stats(self):
        """Get index statistics.

        Returns:
            IndexStats object from the underlying vector store.
        """
        return self._vector_store.get_stats()

    def clear_index(self) -> bool:
        """Clear the entire index.

        Returns:
            True if successful.
        """
        return self._vector_store.delete_index()

    def list_documents(
        self, source_type: Optional[SourceType] = None, limit: int = 100
    ) -> List[Document]:
        """List indexed documents.

        Args:
            source_type: Optional filter by source type.
            limit: Maximum results.

        Returns:
            List of Document objects. NOTE: only 2*limit documents are
            fetched before filtering, so a heavily filtered listing may
            return fewer than ``limit`` even when more exist.
        """
        docs = self._vector_store.get_all_documents(limit=limit * 2)

        if source_type:
            docs = [d for d in docs if d.source_type == source_type]

        return docs[:limit]