52 lines
1.5 KiB
Python
52 lines
1.5 KiB
Python
"""Tests for search functionality."""
|
|
|
|
import os
|
|
import tempfile
|
|
|
|
import pytest
|
|
|
|
from snip.db.database import Database
|
|
from snip.search.engine import SearchEngine
|
|
|
|
|
|
@pytest.fixture
def db():
    """Yield an initialised ``Database`` backed by a throwaway on-disk file.

    The file is created with ``delete=False`` so that its path survives after
    the ``NamedTemporaryFile`` handle is closed, letting ``Database`` open it
    itself. The file is unlinked during teardown.
    """
    with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
        db_path = f.name
    # try/finally so the temp file is removed even when Database(...) or
    # init_db() raises during setup: pytest only runs the post-yield part of
    # a fixture if the fixture actually reached its yield.
    try:
        database = Database(db_path)
        database.init_db()
        yield database
    finally:
        # NOTE(review): no close() call is visible on Database here; if it
        # keeps an open handle, unlinking may fail on Windows — confirm.
        os.unlink(db_path)
|
|
|
|
|
|
@pytest.fixture
def search_engine(db):
    """Provide a ``SearchEngine`` built on top of the per-test ``db`` fixture."""
    engine = SearchEngine(db)
    return engine
|
|
|
|
|
|
def test_search_basic(search_engine, db):
    """A plain keyword query returns at least one result."""
    fixtures = (
        ("Hello World", "print('hello')"),
        ("Goodbye", "print('bye')"),
    )
    for snippet_title, snippet_code in fixtures:
        db.add_snippet(title=snippet_title, code=snippet_code, language="python")

    found = search_engine.search("hello")
    assert len(found) >= 1
|
|
|
|
|
|
def test_search_with_language_filter(search_engine, db):
    """Test that a language filter restricts results to that language.

    A Python snippet and a JavaScript snippet both mention "hello". Searching
    with ``language="python"`` must return at least one result, and every
    returned result must report ``language == "python"``.
    """
    db.add_snippet(title="Python Hello", code="print('hello')", language="python")
    db.add_snippet(title="JS Hello", code="console.log('hello')", language="javascript")

    results = search_engine.search("hello", language="python")

    # all() over an empty sequence is True, so without this check the test
    # would pass even if the search/filter returned nothing at all.
    assert len(results) >= 1
    assert all(r["language"] == "python" for r in results)
|
|
|
|
|
|
def test_search_ranking(search_engine, db):
    """Test that search results are ranked.

    NOTE(review): this only checks that a "hello" query returns at least one
    result; it does not assert any ordering or score between the two
    snippets. Tighten once the ranking contract of SearchEngine is known.
    """
    for snippet_title, snippet_code in (
        ("Hello Function", "def hello(): pass"),
        ("Hello Class", "class Hello: pass"),
    ):
        db.add_snippet(title=snippet_title, code=snippet_code, language="python")

    hits = search_engine.search("hello")
    assert len(hits) >= 1
|