diff --git a/tests/test_embeddings.py b/tests/test_embeddings.py index ab18d3a..01c6e79 100644 --- a/tests/test_embeddings.py +++ b/tests/test_embeddings.py @@ -1,5 +1,6 @@ import pytest -from shell_history_search.embeddings import EmbeddingService +import numpy as np +from shell_history_search.core.embeddings import EmbeddingService class TestEmbeddingService: @@ -7,17 +8,23 @@ class TestEmbeddingService: def service(self): return EmbeddingService() - def test_get_embedding(self, service): - embedding = service.get_embedding("test command") - assert isinstance(embedding, list) + def test_encode_single(self, service): + embedding = service.encode_single("test command") + assert isinstance(embedding, np.ndarray) assert len(embedding) == 384 - def test_get_embedding_consistency(self, service): - emb1 = service.get_embedding("test command") - emb2 = service.get_embedding("test command") - assert emb1 == emb2 + def test_encode_consistency(self, service): + emb1 = service.encode_single("test command") + emb2 = service.encode_single("test command") + assert np.allclose(emb1, emb2) - def test_get_embedding_different_commands(self, service): - emb1 = service.get_embedding("command one") - emb2 = service.get_embedding("command two") - assert emb1 != emb2 + def test_encode_different_commands(self, service): + emb1 = service.encode_single("command one") + emb2 = service.encode_single("command two") + assert not np.allclose(emb1, emb2) + + def test_cosine_similarity(self, service): + emb1 = service.encode_single("list files") + emb2 = service.encode_single("show directory contents") + similarity = EmbeddingService.cosine_similarity(emb1, emb2) + assert -1.0 <= similarity <= 1.0