Initial upload: shell-history-semantic-search v0.1.0
Some checks failed
CI / test (push) Has been cancelled
Some checks failed
CI / test (push) Has been cancelled
This commit is contained in:
118
src/shell_history_search/core/search.py
Normal file
118
src/shell_history_search/core/search.py
Normal file
@@ -0,0 +1,118 @@
|
|||||||
|
import sqlite3
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from ..db import init_database, get_db_path
|
||||||
|
from .embeddings import EmbeddingService
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SearchResult:
|
||||||
|
command: str
|
||||||
|
shell_type: str
|
||||||
|
timestamp: Optional[int]
|
||||||
|
similarity: float
|
||||||
|
command_id: int
|
||||||
|
|
||||||
|
|
||||||
|
class SearchEngine:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
db_path: Optional[sqlite3.Connection] = None,
|
||||||
|
embedding_service: Optional[EmbeddingService] = None,
|
||||||
|
):
|
||||||
|
if db_path is None:
|
||||||
|
self._db_path: Optional[Path] = get_db_path()
|
||||||
|
self._conn = init_database(self._db_path)
|
||||||
|
else:
|
||||||
|
self._conn = db_path
|
||||||
|
self._db_path = None
|
||||||
|
|
||||||
|
self._embedding_service = embedding_service or EmbeddingService()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def embedding_service(self) -> EmbeddingService:
|
||||||
|
return self._embedding_service
|
||||||
|
|
||||||
|
def search(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
limit: int = 10,
|
||||||
|
shell_type: Optional[str] = None,
|
||||||
|
) -> list[SearchResult]:
|
||||||
|
query_embedding = self._embedding_service.encode_single(query)
|
||||||
|
|
||||||
|
base_query = """
|
||||||
|
SELECT c.id, c.command, c.shell_type, c.timestamp, e.embedding
|
||||||
|
FROM commands c
|
||||||
|
JOIN embeddings e ON c.id = e.command_id
|
||||||
|
"""
|
||||||
|
params: list = []
|
||||||
|
|
||||||
|
if shell_type:
|
||||||
|
base_query += " WHERE c.shell_type = ?"
|
||||||
|
params.append(shell_type)
|
||||||
|
|
||||||
|
base_query += " ORDER BY c.id DESC"
|
||||||
|
|
||||||
|
cursor = self._conn.execute(base_query, params)
|
||||||
|
rows = cursor.fetchall()
|
||||||
|
|
||||||
|
results: list[tuple[SearchResult, float]] = []
|
||||||
|
|
||||||
|
for row in rows:
|
||||||
|
cmd_id = row["id"]
|
||||||
|
command = row["command"]
|
||||||
|
cmd_shell_type = row["shell_type"]
|
||||||
|
timestamp = row["timestamp"]
|
||||||
|
embedding_blob = row["embedding"]
|
||||||
|
|
||||||
|
stored_embedding = EmbeddingService.blob_to_embedding(
|
||||||
|
embedding_blob, self._embedding_service.embedding_dim
|
||||||
|
)
|
||||||
|
|
||||||
|
similarity = EmbeddingService.cosine_similarity(
|
||||||
|
query_embedding, stored_embedding
|
||||||
|
)
|
||||||
|
|
||||||
|
result = SearchResult(
|
||||||
|
command=command,
|
||||||
|
shell_type=cmd_shell_type,
|
||||||
|
timestamp=timestamp,
|
||||||
|
similarity=similarity,
|
||||||
|
command_id=cmd_id,
|
||||||
|
)
|
||||||
|
results.append((result, similarity))
|
||||||
|
|
||||||
|
results.sort(key=lambda x: x[1], reverse=True)
|
||||||
|
|
||||||
|
top_results = [result for result, _ in results[:limit]]
|
||||||
|
|
||||||
|
return top_results
|
||||||
|
|
||||||
|
def get_stats(self) -> dict:
|
||||||
|
cursor = self._conn.execute("""
|
||||||
|
SELECT shell_type, COUNT(*) as count FROM commands GROUP BY shell_type
|
||||||
|
""")
|
||||||
|
shell_counts = {row["shell_type"]: row["count"] for row in cursor.fetchall()}
|
||||||
|
|
||||||
|
cursor = self._conn.execute("SELECT COUNT(*) as count FROM commands")
|
||||||
|
total_commands = cursor.fetchone()["count"]
|
||||||
|
|
||||||
|
cursor = self._conn.execute("SELECT COUNT(*) as count FROM embeddings")
|
||||||
|
total_embeddings = cursor.fetchone()["count"]
|
||||||
|
|
||||||
|
return {
|
||||||
|
"total_commands": total_commands,
|
||||||
|
"total_embeddings": total_embeddings,
|
||||||
|
"shell_counts": shell_counts,
|
||||||
|
"embedding_model": self._embedding_service.model_name,
|
||||||
|
"embedding_dim": self._embedding_service.embedding_dim,
|
||||||
|
}
|
||||||
|
|
||||||
|
def clear_all(self) -> None:
|
||||||
|
self._conn.execute("DELETE FROM embeddings")
|
||||||
|
self._conn.execute("DELETE FROM commands")
|
||||||
|
self._conn.execute("DELETE FROM index_state")
|
||||||
|
self._conn.commit()
|
||||||
Reference in New Issue
Block a user