"""Command library loader module."""

import json
import logging
from pathlib import Path

import yaml

from shell_speak.config import get_data_dir
from shell_speak.models import CommandPattern

logger = logging.getLogger(__name__)

# File (inside the data directory) that stores user corrections.
_CORRECTIONS_FILENAME = "corrections.json"


class CommandLibraryLoader:
    """Loads and manages command pattern libraries.

    Patterns are read from per-tool YAML files and user corrections from a
    JSON file, all located in the directory returned by ``get_data_dir()``.
    Loading is lazy: the public accessors call ``load_library()`` on first use.
    """

    # Built-in tool name -> YAML library filename (relative to the data dir).
    # Treated as read-only.
    _TOOL_FILES = {
        "docker": "docker.yaml",
        "kubectl": "kubectl.yaml",
        "git": "git.yaml",
        "unix": "unix.yaml",
    }

    def __init__(self) -> None:
        self._patterns: list[CommandPattern] = []
        # Maps "<tool>:<normalized query>" -> corrected command.
        self._corrections: dict[str, str] = {}
        self._loaded = False

    def load_library(self, tool: str | None = None) -> None:
        """Load command patterns from library files.

        Args:
            tool: If given, load only this tool's library (``<tool>.yaml``
                when it is not one of the built-in tools). Otherwise load
                every built-in library.

        Missing library files are skipped. A file that fails to load is
        skipped with a logged warning, so one bad library does not prevent
        the others from loading. User corrections are (re)loaded as well.
        """
        data_dir = get_data_dir()
        self._patterns = []

        if tool:
            files_to_load = {tool: self._TOOL_FILES.get(tool, f"{tool}.yaml")}
        else:
            files_to_load = self._TOOL_FILES

        for tool_name, filename in files_to_load.items():
            filepath = data_dir / filename
            if not filepath.exists():
                continue
            try:
                self._patterns.extend(self._load_yaml_library(filepath, tool_name))
            except Exception:
                # Deliberately best-effort, but do not hide the failure.
                logger.warning(
                    "Failed to load command library %s", filepath, exc_info=True
                )

        self._load_corrections()
        self._loaded = True

    def _load_yaml_library(self, filepath: Path, tool: str) -> list[CommandPattern]:
        """Load patterns from a YAML file.

        Args:
            filepath: Path of the YAML library file.
            tool: Tool name recorded on every pattern loaded from this file.

        Returns:
            One ``CommandPattern`` per entry under the file's ``patterns`` key.
            An empty file yields an empty list.
        """
        with open(filepath, encoding="utf-8") as f:
            data = yaml.safe_load(f) or {}

        return [
            CommandPattern(
                name=item.get("name", ""),
                tool=tool,
                description=item.get("description", ""),
                patterns=item.get("patterns", []),
                template=item.get("template", ""),
                explanation=item.get("explanation", ""),
                examples=item.get("examples", []),
            )
            for item in data.get("patterns", [])
        ]

    def _load_corrections(self) -> None:
        """Load user corrections from the JSON file into memory.

        If the file is missing, unreadable, or malformed, the in-memory
        corrections become empty. A broken file also logs a warning.
        """
        corrections_file = get_data_dir() / _CORRECTIONS_FILENAME
        if not corrections_file.exists():
            # Reset so reload() does not keep stale corrections around.
            self._corrections = {}
            return

        try:
            with open(corrections_file, encoding="utf-8") as f:
                data = json.load(f)
            corrections = data.get("corrections", {})
        except (OSError, ValueError, AttributeError):
            # ValueError covers JSONDecodeError/UnicodeDecodeError.
            # AttributeError covers a top-level value that is not an object.
            logger.warning(
                "Ignoring unreadable corrections file %s",
                corrections_file,
                exc_info=True,
            )
            corrections = {}

        if not isinstance(corrections, dict):
            logger.warning(
                "Ignoring malformed 'corrections' entry in %s", corrections_file
            )
            corrections = {}
        self._corrections = corrections

    def _ensure_loaded(self) -> None:
        """Load libraries and corrections on first use."""
        if not self._loaded:
            self.load_library()

    @staticmethod
    def _correction_key(query: str, tool: str) -> str:
        """Build the corrections-dict key for a (query, tool) pair."""
        return f"{tool}:{query.lower().strip()}"

    def get_patterns(self) -> list[CommandPattern]:
        """Get all loaded patterns, loading the libraries on first use."""
        self._ensure_loaded()
        return self._patterns

    def get_corrections(self) -> dict[str, str]:
        """Get all user corrections, loading them on first use."""
        self._ensure_loaded()
        return self._corrections

    def add_correction(self, query: str, command: str, tool: str) -> None:
        """Add (or overwrite) a user correction and persist it.

        Existing corrections are loaded first. Otherwise saving before any
        load would overwrite the file with only this entry.
        """
        self._ensure_loaded()
        self._corrections[self._correction_key(query, tool)] = command
        self._save_corrections()

    def remove_correction(self, query: str, tool: str) -> bool:
        """Remove a user correction.

        Returns:
            True if a correction existed and was removed (and the file was
            rewritten), otherwise False.
        """
        self._ensure_loaded()
        key = self._correction_key(query, tool)
        if key not in self._corrections:
            return False
        del self._corrections[key]
        self._save_corrections()
        return True

    def _save_corrections(self) -> None:
        """Save corrections to the JSON file.

        Writes to a temporary sibling file and then renames it over the
        target, so a crash mid-write cannot leave a truncated corrections
        file behind (the rename is atomic on POSIX).
        """
        corrections_file = get_data_dir() / _CORRECTIONS_FILENAME
        corrections_file.parent.mkdir(parents=True, exist_ok=True)
        data = {
            "version": "1.0",
            "corrections": self._corrections,
        }
        tmp_file = corrections_file.with_name(corrections_file.name + ".tmp")
        with open(tmp_file, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2)
        tmp_file.replace(corrections_file)

    def reload(self) -> None:
        """Reload all libraries and corrections from disk."""
        self._loaded = False
        self.load_library()


_loader: CommandLibraryLoader | None = None


def get_loader() -> CommandLibraryLoader:
    """Get the global command library loader, creating it on first call."""
    global _loader
    if _loader is None:
        _loader = CommandLibraryLoader()
    return _loader