Initial upload: shell-history-semantic-search v0.1.0
Some checks failed
CI / test (push) Has been cancelled
Some checks failed
CI / test (push) Has been cancelled
This commit is contained in:
129
src/shell_history_search/core/indexing.py
Normal file
129
src/shell_history_search/core/indexing.py
Normal file
@@ -0,0 +1,129 @@
|
|||||||
|
import sqlite3
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from ..parsers import HistoryEntry
|
||||||
|
from ..parsers.factory import get_all_parsers
|
||||||
|
from ..db import init_database, get_db_path
|
||||||
|
from .embeddings import EmbeddingService
|
||||||
|
|
||||||
|
# Module-level logger named after this module; used for indexing progress
# messages and missing-history warnings.
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class IndexingService:
    """Index shell history into the ``commands`` and ``embeddings`` tables.

    The service reads history entries from the available parsers, inserts
    every command whose ``command_hash`` is not yet stored, and stores one
    embedding per newly inserted command.

    It can be used as a context manager; on exit it closes the database
    connection if (and only if) the service opened that connection itself.
    """

    def __init__(
        self,
        db_path: Optional[sqlite3.Connection] = None,
        embedding_service: Optional[EmbeddingService] = None,
    ):
        """Create the service.

        Args:
            db_path: Despite its name, an already-open SQLite connection to
                use as-is (the name is kept for backward compatibility).
                When ``None``, the default database location from
                ``get_db_path()`` is opened via ``init_database()``.
            embedding_service: Service used to encode commands. A default
                ``EmbeddingService()`` is created when omitted (or falsy).
        """
        if db_path is None:
            self._db_path: Optional[Path] = get_db_path()
            self._conn = init_database(self._db_path)
        else:
            # Caller-supplied connection: we do not own it, so ``close()``
            # leaves it open. ``_db_path is None`` records that fact.
            self._conn = db_path
            self._db_path = None

        self._embedding_service = embedding_service or EmbeddingService()

    def __enter__(self) -> "IndexingService":
        """Return the service itself for use in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        """Release the connection on leaving the ``with`` block."""
        self.close()

    def close(self) -> None:
        """Close the database connection if this service opened it.

        A connection handed in by the caller is left open, because the
        caller remains responsible for its lifetime.
        """
        if self._db_path is not None:
            self._conn.close()

    @property
    def embedding_service(self) -> EmbeddingService:
        """The embedding service used to encode commands."""
        return self._embedding_service

    def index_shell_history(self, shell_type: Optional[str] = None) -> dict:
        """Index history for one shell type, or for every known shell.

        Args:
            shell_type: Only parsers whose ``shell_type`` equals this value
                are used. When ``None`` (or empty), all parsers are used.

        Returns:
            A dict with the keys ``"total_indexed"`` and ``"total_skipped"``
            summed over all parsers that were run.

        Raises:
            ValueError: If ``shell_type`` is given but no parser matches it.
        """
        parsers = get_all_parsers()
        if shell_type:
            parsers = [p for p in parsers if p.shell_type == shell_type]
            if not parsers:
                raise ValueError(f"Unknown shell type: {shell_type}")

        total_indexed = 0
        total_skipped = 0

        for parser in parsers:
            indexed, skipped = self._index_parser(parser)
            total_indexed += indexed
            total_skipped += skipped
            # Lazy %-style arguments: the message is only formatted when the
            # record is actually emitted.
            logger.info(
                "Indexed %s commands from %s, skipped %s duplicates",
                indexed,
                parser.shell_type,
                skipped,
            )

        return {
            "total_indexed": total_indexed,
            "total_skipped": total_skipped,
        }

    def _index_parser(self, parser) -> tuple[int, int]:
        """Index every new entry produced by a single parser.

        Commands and their embeddings are written in one transaction: if
        anything fails (for example the embedding step), the inserts are
        rolled back so that no command is left stored without an embedding
        and then skipped as a "duplicate" on the next run.

        Args:
            parser: Object providing ``shell_type``, ``get_history_path()``
                and ``parse(path)``.

        Returns:
            ``(indexed, skipped)``: the number of newly inserted commands and
            the number of entries skipped because their hash already existed.
            ``(0, 0)`` when the history file does not exist.
        """
        history_path = parser.get_history_path()

        if not history_path.exists():
            logger.warning(
                "History file not found for %s: %s",
                parser.shell_type,
                history_path,
            )
            return 0, 0

        existing_hashes = self._get_existing_hashes()
        entries_to_embed: list[HistoryEntry] = []

        indexed = 0
        skipped = 0

        # The connection context manager commits on success and rolls back
        # if an exception escapes the block.
        with self._conn:
            for entry in parser.parse(history_path):
                if entry.command_hash in existing_hashes:
                    skipped += 1
                    continue

                self._conn.execute(
                    """
                    INSERT INTO commands (command, shell_type, timestamp, hostname, command_hash, indexed_at)
                    VALUES (?, ?, ?, ?, ?, strftime('%s', 'now'))
                    """,
                    (
                        entry.command,
                        entry.shell_type,
                        entry.timestamp,
                        entry.hostname,
                        entry.command_hash,
                    ),
                )
                entries_to_embed.append(entry)

                # Remember the hash so repeats inside the same history file
                # are also counted as duplicates.
                existing_hashes.add(entry.command_hash)
                indexed += 1

            if entries_to_embed:
                self._generate_and_store_embeddings(entries_to_embed)

        return indexed, skipped

    def _get_existing_hashes(self) -> set:
        """Return the set of ``command_hash`` values already stored."""
        cursor = self._conn.execute("SELECT command_hash FROM commands")
        # Positional access works for plain tuple rows and ``sqlite3.Row``
        # alike, so a caller-supplied connection needs no row_factory.
        return {row[0] for row in cursor}

    def _generate_and_store_embeddings(self, entries: list[HistoryEntry]) -> None:
        """Encode the commands of ``entries`` and store one embedding each.

        All commands are encoded in a single ``encode`` call. Each embedding
        is converted with ``EmbeddingService.embedding_to_blob`` and written
        to ``embeddings`` keyed by the id of the matching ``commands`` row.
        This method does not commit; the caller owns the transaction.

        Args:
            entries: Entries whose commands were already inserted into
                ``commands``.
        """
        commands = [entry.command for entry in entries]
        embeddings = self._embedding_service.encode(commands)
        model_name = self._embedding_service.model_name

        for entry, embedding in zip(entries, embeddings):
            row = self._conn.execute(
                """
                SELECT id FROM commands WHERE command_hash = ?
                """,
                (entry.command_hash,),
            ).fetchone()
            if row is None:
                # No stored command for this hash; nothing to attach to.
                continue

            cmd_id = row[0]
            embedding_blob = EmbeddingService.embedding_to_blob(embedding)

            self._conn.execute(
                """
                INSERT OR REPLACE INTO embeddings (command_id, embedding, model_name)
                VALUES (?, ?, ?)
                """,
                (cmd_id, embedding_blob, model_name),
            )
|
||||||
Reference in New Issue
Block a user