diff --git a/tests/test_search.py b/tests/test_search.py index f237aec..9b58834 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -1,52 +1,67 @@ -from shell_history_search.core import SearchEngine, SearchResult -from shell_history_search.db import init_database +import pytest +from datetime import datetime + +from shell_history_search.search import SemanticSearch, SearchResult +from shell_history_search.database import Database +from shell_history_search.models import HistoryEntry -class TestSearchEngine: - def test_init_with_db_path(self, temp_db_path): - conn = init_database(temp_db_path) - engine = SearchEngine(db_path=conn) - stats = engine.get_stats() +class TestSemanticSearch: + @pytest.fixture + def search(self, tmp_path): + db_path = tmp_path / "test.db" + db = Database(str(db_path)) + search = SemanticSearch(db) + entry = HistoryEntry( + id=1, + timestamp=1234567890.0, + command="ls -la", + exit_code=0, + shell="bash", + working_dir="/home" + ) + db.add_entry(entry) + yield search + db.close() - assert stats["total_commands"] == 0 - assert stats["total_embeddings"] == 0 - conn.close() + def test_semantic_search_basic(self, search): + results = search.search("list files", limit=5) + assert isinstance(results, list) - def test_get_stats_empty(self, temp_db_path): - conn = init_database(temp_db_path) - engine = SearchEngine(db_path=conn) - stats = engine.get_stats() + def test_semantic_search_with_filter(self, search): + results = search.search("list", shell="bash", limit=5) + assert isinstance(results, list) - assert stats["total_commands"] == 0 - assert stats["total_embeddings"] == 0 - assert stats["shell_counts"] == {} - assert stats["embedding_model"] == "all-MiniLM-L6-v2" - assert stats["embedding_dim"] == 384 - conn.close() + def test_search_returns_search_results(self, search): + results = search.search("ls", limit=5) + for result in results: + assert isinstance(result, SearchResult) - def test_clear_all(self, temp_db_path): - conn = init_database(temp_db_path) - engine = SearchEngine(db_path=conn) - engine.clear_all() + def test_search_empty_query(self, search): + results = search.search("", limit=5) + assert isinstance(results, list) - stats = engine.get_stats() - assert stats["total_commands"] == 0 - assert stats["total_embeddings"] == 0 - conn.close() + def test_search_with_time_filter(self, search): + results = search.search("ls", limit=5) + assert isinstance(results, list) class TestSearchResult: def test_search_result_creation(self): result = SearchResult( - command="git commit", - shell_type="bash", - timestamp=1700000000, - similarity=0.95, - command_id=1, + command="ls -la", + timestamp=1234567890.0, + score=0.95, + shell="bash" ) + assert result.command == "ls -la" + assert result.score == 0.95 - assert result.command == "git commit" - assert result.shell_type == "bash" - assert result.timestamp == 1700000000 - assert result.similarity == 0.95 - assert result.command_id == 1 + def test_search_result_sorting(self): + results = [ + SearchResult("cmd1", 123.0, 0.5, "bash"), + SearchResult("cmd2", 124.0, 0.9, "bash"), + SearchResult("cmd3", 125.0, 0.7, "bash"), + ] + sorted_results = sorted(results) + assert sorted_results[0].score >= sorted_results[1].score >= sorted_results[2].score