Files
shell-history-semantic-search/tests/test_embeddings.py

90 lines
3.2 KiB
Python

import numpy as np
from shell_history_search.core import EmbeddingService
class TestEmbeddingService:
    """Tests for ``EmbeddingService``.

    Covers construction, single/batch encoding, blob (de)serialization and
    cosine similarity.  Tests that take ``temp_cache_dir`` rely on a pytest
    fixture defined outside this file (presumably in ``conftest.py`` --
    confirm) and pass it to the service as ``cache_dir``.
    """

    # Expectations shared by several tests; kept in one place so a model
    # change only needs to be reflected here.
    DEFAULT_MODEL = "all-MiniLM-L6-v2"
    EMBEDDING_DIM = 384
    # Bytes per element expected in a serialized embedding; matches the
    # itemsize of np.float32, the dtype asserted by the encode tests.
    BYTES_PER_VALUE = 4

    def test_init_default(self):
        """A service built with no arguments reports the default model and CPU device."""
        service = EmbeddingService()

        assert service.model_name == self.DEFAULT_MODEL
        assert service.device == "cpu"

    def test_init_custom_model(self, temp_cache_dir):
        """Explicit ``model_name`` and ``cache_dir`` arguments are exposed as attributes."""
        service = EmbeddingService(
            model_name=self.DEFAULT_MODEL,
            cache_dir=temp_cache_dir,
        )

        assert service.model_name == self.DEFAULT_MODEL
        assert service.cache_dir == temp_cache_dir

    def test_embedding_dim(self, temp_cache_dir):
        """``embedding_dim`` equals the expected vector width."""
        service = EmbeddingService(cache_dir=temp_cache_dir)

        assert service.embedding_dim == self.EMBEDDING_DIM

    def test_encode_single(self, temp_cache_dir):
        """``encode_single`` returns a 1-D float32 array of the expected width."""
        service = EmbeddingService(cache_dir=temp_cache_dir)

        embedding = service.encode_single("git commit")

        assert isinstance(embedding, np.ndarray)
        assert embedding.shape == (self.EMBEDDING_DIM,)
        assert embedding.dtype == np.float32

    def test_encode_batch(self, temp_cache_dir):
        """``encode`` on three strings returns a ``(3, dim)`` float32 array."""
        service = EmbeddingService(cache_dir=temp_cache_dir)

        embeddings = service.encode(["git add .", "git commit", "git push"])

        assert isinstance(embeddings, np.ndarray)
        assert embeddings.shape == (3, self.EMBEDDING_DIM)
        assert embeddings.dtype == np.float32

    def test_encode_empty_list(self, temp_cache_dir):
        """``encode`` on an empty list returns an empty array of shape ``(0,)``."""
        service = EmbeddingService(cache_dir=temp_cache_dir)

        embeddings = service.encode([])

        assert isinstance(embeddings, np.ndarray)
        assert embeddings.shape == (0,)

    def test_encode_returns_normalized(self, temp_cache_dir):
        """An encoded vector has (approximately) unit L2 norm."""
        service = EmbeddingService(cache_dir=temp_cache_dir)

        embedding = service.encode_single("test command")
        norm = np.linalg.norm(embedding)

        # Tolerance band around 1.0 to absorb float32 rounding.
        assert 0.99 < norm <= 1.01

    def test_embedding_to_blob(self, temp_cache_dir):
        """``embedding_to_blob`` yields ``bytes`` sized ``dim * bytes-per-value``."""
        service = EmbeddingService(cache_dir=temp_cache_dir)
        embedding = service.encode_single("test")

        blob = EmbeddingService.embedding_to_blob(embedding)

        assert isinstance(blob, bytes)
        assert len(blob) == self.EMBEDDING_DIM * self.BYTES_PER_VALUE

    def test_blob_to_embedding(self, temp_cache_dir):
        """A vector survives an ``embedding_to_blob`` / ``blob_to_embedding`` round trip."""
        service = EmbeddingService(cache_dir=temp_cache_dir)
        embedding = service.encode_single("test")

        blob = EmbeddingService.embedding_to_blob(embedding)
        recovered = EmbeddingService.blob_to_embedding(blob, self.EMBEDDING_DIM)

        assert np.allclose(embedding, recovered)

    def test_cosine_similarity(self, temp_cache_dir):
        """Two git commands score closer to each other than a git and a docker command."""
        service = EmbeddingService(cache_dir=temp_cache_dir)
        git_commit = service.encode_single("git commit")
        git_add = service.encode_single("git add .")
        docker_run = service.encode_single("docker run")

        sim_related = EmbeddingService.cosine_similarity(git_commit, git_add)
        sim_unrelated = EmbeddingService.cosine_similarity(git_commit, docker_run)

        # Both scores must lie in the valid cosine range.
        assert -1 <= sim_related <= 1
        assert -1 <= sim_unrelated <= 1
        assert sim_related > sim_unrelated

    def test_cosine_similarity_perfect_match(self, temp_cache_dir):
        """A vector compared with itself scores (approximately) 1."""
        service = EmbeddingService(cache_dir=temp_cache_dir)
        embedding = service.encode_single("same command")

        sim = EmbeddingService.cosine_similarity(embedding, embedding)

        # Small band around 1.0 to absorb floating-point error.
        assert 0.9999 < sim <= 1.0001