diff --git a/tests/test_search.py b/tests/test_search.py new file mode 100644 index 0000000..cbe6238 --- /dev/null +++ b/tests/test_search.py @@ -0,0 +1,51 @@ +"""Tests for search functionality.""" + +import os +import tempfile + +import pytest + +from snip.db.database import Database +from snip.search.engine import SearchEngine + + +@pytest.fixture +def db(): + with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f: + db_path = f.name + database = Database(db_path) + database.init_db() + yield database + os.unlink(db_path) + + +@pytest.fixture +def search_engine(db): + return SearchEngine(db) + + +def test_search_basic(search_engine, db): + """Test basic search.""" + db.add_snippet(title="Hello World", code="print('hello')", language="python") + db.add_snippet(title="Goodbye", code="print('bye')", language="python") + + results = search_engine.search("hello") + assert len(results) >= 1 + + +def test_search_with_language_filter(search_engine, db): + """Test search with language filter.""" + db.add_snippet(title="Python Hello", code="print('hello')", language="python") + db.add_snippet(title="JS Hello", code="console.log('hello')", language="javascript") + + results = search_engine.search("hello", language="python") + assert all(r["language"] == "python" for r in results) + + +def test_search_ranking(search_engine, db): + """Test that search results are ranked.""" + db.add_snippet(title="Hello Function", code="def hello(): pass", language="python") + db.add_snippet(title="Hello Class", code="class Hello: pass", language="python") + + results = search_engine.search("hello") + assert len(results) >= 1