Initial upload: shell-history-semantic-search v0.1.0
Some checks failed
CI / test (push) Has been cancelled

This commit is contained in:
2026-03-22 18:15:30 +00:00
parent 1d15227281
commit 55467b198f

View File

@@ -0,0 +1,129 @@
import logging
import sqlite3
from pathlib import Path
from typing import Optional, Union

from ..parsers import HistoryEntry
from ..parsers.factory import get_all_parsers
from ..db import init_database, get_db_path
from .embeddings import EmbeddingService
# Module-level logger named after this module; IndexingService uses it to
# report per-shell indexing results and missing history files.
logger = logging.getLogger(__name__)
class IndexingService:
    """Index shell history commands, and their embeddings, into SQLite.

    Commands produced by the shell history parsers are written to the
    ``commands`` table (de-duplicated by ``command_hash``) and one embedding
    per newly inserted command is written to the ``embeddings`` table.
    """

    def __init__(
        self,
        db_path: Union[sqlite3.Connection, Path, str, None] = None,
        embedding_service: "Optional[EmbeddingService]" = None,
    ):
        """Create the service.

        Args:
            db_path: Where to store the index. Despite its name this argument
                has always accepted an already-open connection object; that
                still works and the object is used as-is. A ``str`` or
                ``Path`` is treated as a database location and opened with
                ``init_database``. ``None`` opens the location returned by
                ``get_db_path()``.
            embedding_service: Service used to encode commands. A default
                ``EmbeddingService()`` is created when ``None``.
        """
        if db_path is None:
            self._db_path: Optional[Path] = get_db_path()
            self._conn = init_database(self._db_path)
        elif isinstance(db_path, (str, Path)):
            # The parameter is called ``db_path``, so honour an actual path
            # instead of failing later with an AttributeError on ``execute``.
            self._db_path = Path(db_path)
            self._conn = init_database(self._db_path)
        else:
            # Anything else is assumed to be a connection-like object.
            self._conn = db_path
            self._db_path = None

        # ``is None`` rather than ``or`` so an injected service is never
        # silently replaced just because it happens to be falsy.
        if embedding_service is None:
            embedding_service = EmbeddingService()
        self._embedding_service = embedding_service

    @property
    def embedding_service(self) -> "EmbeddingService":
        """The embedding service used to encode commands."""
        return self._embedding_service

    def index_shell_history(self, shell_type: Optional[str] = None) -> dict[str, int]:
        """Index the history of one shell, or of every known shell.

        Args:
            shell_type: Only index parsers whose ``shell_type`` equals this
                value. A falsy value (``None`` or ``""``) indexes all parsers
                returned by ``get_all_parsers()``.

        Returns:
            A dict with ``"total_indexed"`` (newly inserted commands) and
            ``"total_skipped"`` (commands whose hash was already stored).

        Raises:
            ValueError: If ``shell_type`` is given but matches no parser.
        """
        parsers = get_all_parsers()
        if shell_type:
            parsers = [p for p in parsers if p.shell_type == shell_type]
            if not parsers:
                raise ValueError(f"Unknown shell type: {shell_type}")

        total_indexed = 0
        total_skipped = 0
        for parser in parsers:
            indexed, skipped = self._index_parser(parser)
            total_indexed += indexed
            total_skipped += skipped
            logger.info(
                "Indexed %s commands from %s, skipped %s duplicates",
                indexed,
                parser.shell_type,
                skipped,
            )

        return {
            "total_indexed": total_indexed,
            "total_skipped": total_skipped,
        }

    def _index_parser(self, parser) -> tuple[int, int]:
        """Index every new entry produced by a single parser.

        Inserts, embedding generation and the commit form one unit of work:
        if anything fails the pending changes are rolled back and the
        exception is re-raised, so a later commit cannot persist command rows
        that never received an embedding.

        Args:
            parser: Object providing ``shell_type``, ``get_history_path()``
                and ``parse(path)``.

        Returns:
            ``(indexed, skipped)`` counts; ``(0, 0)`` when the history file
            does not exist.
        """
        history_path = parser.get_history_path()
        if not history_path.exists():
            logger.warning(
                "History file not found for %s: %s",
                parser.shell_type,
                history_path,
            )
            return 0, 0

        existing_hashes = self._get_existing_hashes()
        entries_to_embed: list[HistoryEntry] = []
        skipped = 0

        try:
            for entry in parser.parse(history_path):
                if entry.command_hash in existing_hashes:
                    skipped += 1
                    continue
                self._conn.execute(
                    """
                    INSERT INTO commands (command, shell_type, timestamp, hostname, command_hash, indexed_at)
                    VALUES (?, ?, ?, ?, ?, strftime('%s', 'now'))
                    """,
                    (
                        entry.command,
                        entry.shell_type,
                        entry.timestamp,
                        entry.hostname,
                        entry.command_hash,
                    ),
                )
                entries_to_embed.append(entry)
                # Also de-duplicates repeated commands within this same file.
                existing_hashes.add(entry.command_hash)

            if entries_to_embed:
                self._generate_and_store_embeddings(entries_to_embed)
            self._conn.commit()
        except Exception:
            # Undo the partially written batch before propagating the error.
            self._conn.rollback()
            raise

        return len(entries_to_embed), skipped

    def _get_existing_hashes(self) -> set:
        """Return the set of ``command_hash`` values already stored."""
        cursor = self._conn.execute("SELECT command_hash FROM commands")
        # Positional access works for plain tuple rows and for sqlite3.Row,
        # so a caller-supplied connection needs no particular row_factory.
        return {row[0] for row in cursor.fetchall()}

    def _generate_and_store_embeddings(self, entries: "list[HistoryEntry]") -> None:
        """Encode the commands of ``entries`` and store one embedding each.

        Each entry's command row is looked up by ``command_hash``; entries
        with no matching row are skipped. Nothing is committed here.

        Args:
            entries: Entries whose commands were already inserted.
        """
        if not entries:
            return

        commands = [entry.command for entry in entries]
        embeddings = self._embedding_service.encode(commands)
        model_name = self._embedding_service.model_name

        for entry, embedding in zip(entries, embeddings):
            row = self._conn.execute(
                """
                SELECT id FROM commands WHERE command_hash = ?
                """,
                (entry.command_hash,),
            ).fetchone()
            if row is None:
                continue
            self._conn.execute(
                """
                INSERT OR REPLACE INTO embeddings (command_id, embedding, model_name)
                VALUES (?, ?, ?)
                """,
                (row[0], EmbeddingService.embedding_to_blob(embedding), model_name),
            )