Initial upload: shell-history-semantic-search v0.1.0
Some checks failed
CI / test (push) Has been cancelled

This commit is contained in:
2026-03-22 18:15:30 +00:00
parent 7fc45e0849
commit 1d15227281

View File

@@ -0,0 +1,118 @@
import sqlite3
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Union

from ..db import init_database, get_db_path
from .embeddings import EmbeddingService
@dataclass
class SearchResult:
    """One command returned by a search, together with its similarity score."""

    # The command text as stored in the ``commands`` table.
    command: str
    # Value of the ``shell_type`` column for this command.
    shell_type: str
    # Value of the ``timestamp`` column; ``None`` when the row has none.
    timestamp: Optional[int]
    # Cosine similarity between the query embedding and this command's embedding.
    similarity: float
    # Primary key of the command row (``commands.id``).
    command_id: int
class SearchEngine:
    """Rank stored shell-history commands by embedding similarity to a query.

    The engine reads the ``commands`` and ``embeddings`` tables (and clears
    ``index_state`` in :meth:`clear_all`) through a single SQLite connection.
    It can be used as a context manager; leaving the ``with`` block calls
    :meth:`close`.
    """

    def __init__(
        self,
        db_path: "Optional[Union[sqlite3.Connection, Path, str]]" = None,
        embedding_service: "Optional[EmbeddingService]" = None,
    ):
        """Bind the engine to a database and an embedding service.

        Args:
            db_path: Despite its name, this is normally an already-open
                ``sqlite3.Connection``; it is used as-is and is never closed
                by the engine. A ``str`` or ``Path`` is also accepted and is
                opened with ``init_database``. ``None`` opens the default
                database located by ``get_db_path()``.
            embedding_service: Service used to embed queries. A default
                ``EmbeddingService()`` is created when omitted.
        """
        if db_path is None or isinstance(db_path, (str, Path)):
            self._db_path: Optional[Path] = (
                get_db_path() if db_path is None else Path(db_path)
            )
            self._conn = init_database(self._db_path)
            # We opened this connection, so close() is responsible for it.
            self._owns_connection = True
        else:
            # Anything else is treated as a ready-to-use connection object
            # whose lifetime is managed by the caller.
            self._conn = db_path
            self._db_path = None
            self._owns_connection = False

        if embedding_service is None:
            embedding_service = EmbeddingService()
        self._embedding_service = embedding_service

    def __enter__(self) -> "SearchEngine":
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        self.close()

    def close(self) -> None:
        """Close the connection if this engine opened it.

        Connections handed to the constructor are left open for their owner.
        Calling this more than once is harmless.
        """
        if self._owns_connection:
            self._conn.close()
            self._owns_connection = False

    @property
    def embedding_service(self) -> "EmbeddingService":
        """The embedding service used to encode queries."""
        return self._embedding_service

    def _query(self, sql: str, params=()) -> sqlite3.Cursor:
        """Execute *sql* on a fresh cursor whose rows support lookup by name.

        The row factory is set on the cursor only, so results can be read by
        column name even when the connection was supplied without one, and the
        connection's own ``row_factory`` is left untouched.
        """
        cursor = self._conn.cursor()
        cursor.row_factory = sqlite3.Row
        return cursor.execute(sql, params)

    def search(
        self,
        query: str,
        limit: int = 10,
        shell_type: Optional[str] = None,
    ) -> "list[SearchResult]":
        """Return the stored commands most similar to *query*.

        Every command that has an embedding (optionally restricted to one
        shell type) is scored against the query embedding with cosine
        similarity, and the best-scoring entries are returned.

        Args:
            query: Free-text search string to embed.
            limit: Maximum number of results. Zero or a negative value yields
                an empty list.
            shell_type: When given (and non-empty), only commands whose
                ``shell_type`` column equals this value are considered.

        Returns:
            Results ordered by descending similarity. Ties keep the
            newest-first (descending ``id``) order of the underlying query.
        """
        # ``None`` is tolerated, as before, and means "no cap" in the slice
        # below. Non-positive limits return early so a negative value cannot
        # turn into a "drop the last N" slice.
        if limit is not None and limit <= 0:
            return []

        query_embedding = self._embedding_service.encode_single(query)

        sql = (
            "SELECT c.id, c.command, c.shell_type, c.timestamp, e.embedding "
            "FROM commands c "
            "JOIN embeddings e ON c.id = e.command_id"
        )
        params: list = []
        if shell_type:
            sql += " WHERE c.shell_type = ?"
            params.append(shell_type)
        sql += " ORDER BY c.id DESC"

        embedding_dim = self._embedding_service.embedding_dim
        results = []
        for row in self._query(sql, params):
            stored_embedding = EmbeddingService.blob_to_embedding(
                row["embedding"], embedding_dim
            )
            results.append(
                SearchResult(
                    command=row["command"],
                    shell_type=row["shell_type"],
                    timestamp=row["timestamp"],
                    similarity=EmbeddingService.cosine_similarity(
                        query_embedding, stored_embedding
                    ),
                    command_id=row["id"],
                )
            )

        # list.sort is stable, so equal scores stay in query (newest-first) order.
        results.sort(key=lambda result: result.similarity, reverse=True)
        return results[:limit]

    def get_stats(self) -> dict:
        """Return row counts and embedding-model details.

        Returns:
            A dict with the keys ``total_commands``, ``total_embeddings``,
            ``shell_counts`` (shell type -> number of commands),
            ``embedding_model`` and ``embedding_dim``.
        """
        shell_counts = {
            row["shell_type"]: row["count"]
            for row in self._query(
                "SELECT shell_type, COUNT(*) as count "
                "FROM commands GROUP BY shell_type"
            )
        }
        total_commands = self._query(
            "SELECT COUNT(*) as count FROM commands"
        ).fetchone()["count"]
        total_embeddings = self._query(
            "SELECT COUNT(*) as count FROM embeddings"
        ).fetchone()["count"]
        return {
            "total_commands": total_commands,
            "total_embeddings": total_embeddings,
            "shell_counts": shell_counts,
            "embedding_model": self._embedding_service.model_name,
            "embedding_dim": self._embedding_service.embedding_dim,
        }

    def clear_all(self) -> None:
        """Delete every row from ``embeddings``, ``commands`` and ``index_state``.

        The deletes run inside the connection's context manager, so they are
        committed together on success and rolled back if any of them fails.
        """
        with self._conn:
            # Embeddings go first, matching the original deletion order
            # (they reference commands through ``command_id``).
            self._conn.execute("DELETE FROM embeddings")
            self._conn.execute("DELETE FROM commands")
            self._conn.execute("DELETE FROM index_state")