diff --git a/src/shell_history_search/core/indexing.py b/src/shell_history_search/core/indexing.py new file mode 100644 index 0000000..3aa7da9 --- /dev/null +++ b/src/shell_history_search/core/indexing.py @@ -0,0 +1,129 @@ +import sqlite3 +import logging +from pathlib import Path +from typing import Optional + +from ..parsers import HistoryEntry +from ..parsers.factory import get_all_parsers +from ..db import init_database, get_db_path +from .embeddings import EmbeddingService + +logger = logging.getLogger(__name__) + + +class IndexingService: + def __init__( + self, + db_path: Optional[sqlite3.Connection] = None, + embedding_service: Optional[EmbeddingService] = None, + ): + if db_path is None: + self._db_path: Optional[Path] = get_db_path() + self._conn = init_database(self._db_path) + else: + self._conn = db_path + self._db_path = None + + self._embedding_service = embedding_service or EmbeddingService() + + @property + def embedding_service(self) -> EmbeddingService: + return self._embedding_service + + def index_shell_history(self, shell_type: Optional[str] = None) -> dict: + if shell_type: + parsers = [p for p in get_all_parsers() if p.shell_type == shell_type] + if not parsers: + raise ValueError(f"Unknown shell type: {shell_type}") + else: + parsers = get_all_parsers() + + total_indexed = 0 + total_skipped = 0 + + for parser in parsers: + indexed, skipped = self._index_parser(parser) + total_indexed += indexed + total_skipped += skipped + logger.info( + f"Indexed {indexed} commands from {parser.shell_type}, " + f"skipped {skipped} duplicates" + ) + + return { + "total_indexed": total_indexed, + "total_skipped": total_skipped, + } + + def _index_parser(self, parser) -> tuple[int, int]: + history_path = parser.get_history_path() + + if not history_path.exists(): + logger.warning( + f"History file not found for {parser.shell_type}: {history_path}" + ) + return 0, 0 + + existing_hashes = self._get_existing_hashes() + entries_to_embed: list[HistoryEntry] = [] + + indexed = 0 + skipped = 0 + + for entry in parser.parse(history_path): + if entry.command_hash in existing_hashes: + skipped += 1 + continue + + self._conn.execute( + """ + INSERT INTO commands (command, shell_type, timestamp, hostname, command_hash, indexed_at) + VALUES (?, ?, ?, ?, ?, strftime('%s', 'now')) + """, + ( + entry.command, + entry.shell_type, + entry.timestamp, + entry.hostname, + entry.command_hash, + ), + ) + entries_to_embed.append(entry) + + existing_hashes.add(entry.command_hash) + indexed += 1 + + if entries_to_embed: + self._generate_and_store_embeddings(entries_to_embed) + + self._conn.commit() + + return indexed, skipped + + def _get_existing_hashes(self) -> set: + cursor = self._conn.execute("SELECT command_hash FROM commands") + return {row["command_hash"] for row in cursor.fetchall()} + + def _generate_and_store_embeddings(self, entries: list[HistoryEntry]) -> None: + commands = [entry.command for entry in entries] + embeddings = self._embedding_service.encode(commands) + + for entry, embedding in zip(entries, embeddings): + cursor = self._conn.execute( + """ + SELECT id FROM commands WHERE command_hash = ? + """, + (entry.command_hash,), + ) + row = cursor.fetchone() + if row: + cmd_id = row["id"] + embedding_blob = EmbeddingService.embedding_to_blob(embedding) + + self._conn.execute( + """ + INSERT OR REPLACE INTO embeddings (command_id, embedding, model_name) + VALUES (?, ?, ?) + """, + (cmd_id, embedding_blob, self._embedding_service.model_name), + ) \ No newline at end of file