"""Tests for command library loader.""" import json import os import sys from pathlib import Path import pytest sys.path.insert(0, str(Path(__file__).parent.parent)) class TestCommandLibraryLoader: """Tests for CommandLibraryLoader class.""" @pytest.fixture(autouse=True) def setup(self, tmp_path, sample_docker_yaml, sample_git_yaml): self.test_dir = tmp_path docker_file = self.test_dir / "docker.yaml" docker_file.write_text(sample_docker_yaml) git_file = self.test_dir / "git.yaml" git_file.write_text(sample_git_yaml) os.environ["SHELL_SPEAK_DATA_DIR"] = str(self.test_dir) def test_load_docker_library(self): from shell_speak.library import get_loader loader = get_loader() loader._loaded = False loader.load_library("docker") patterns = loader.get_patterns() assert len(patterns) > 0 assert any(p.tool == "docker" for p in patterns) def test_load_git_library(self): from shell_speak.library import get_loader loader = get_loader() loader._loaded = False loader.load_library("git") patterns = loader.get_patterns() assert len(patterns) > 0 assert any(p.tool == "git" for p in patterns) def test_load_all_libraries(self): from shell_speak.library import get_loader loader = get_loader() loader._loaded = False loader.load_library() patterns = loader.get_patterns() docker_patterns = [p for p in patterns if p.tool == "docker"] git_patterns = [p for p in patterns if p.tool == "git"] assert len(docker_patterns) > 0 assert len(git_patterns) > 0 def test_pattern_structure(self): from shell_speak.library import get_loader loader = get_loader() loader._loaded = False loader.load_library("docker") patterns = loader.get_patterns() if patterns: pattern = patterns[0] assert hasattr(pattern, "name") assert hasattr(pattern, "tool") assert hasattr(pattern, "template") assert hasattr(pattern, "patterns") def test_corrections(self, tmp_path, sample_corrections_json): from shell_speak.library import get_loader corrections_file = tmp_path / "corrections.json" corrections_file.write_text(json.dumps(sample_corrections_json)) os.environ["SHELL_SPEAK_DATA_DIR"] = str(tmp_path) loader = get_loader() loader._loaded = False loader.load_library() corrections = loader.get_corrections() assert "custom:my custom query" in corrections assert corrections["custom:my custom query"] == "echo custom command" def test_add_correction(self): from shell_speak.library import get_loader loader = get_loader() loader._loaded = False loader.load_library() loader.add_correction("new query", "echo new", "unix") corrections = loader.get_corrections() assert "unix:new query" in corrections def test_remove_correction(self): from shell_speak.library import get_loader loader = get_loader() loader._loaded = False loader.load_library() loader.add_correction("test query", "echo test", "unix") loader.remove_correction("test query", "unix") corrections = loader.get_corrections() assert "unix:test query" not in corrections def test_reload(self): from shell_speak.library import get_loader loader = get_loader() loader.load_library() initial_count = len(loader.get_patterns()) loader.reload() reload_count = len(loader.get_patterns()) assert initial_count == reload_count