diff --git a/shell_speak/library.py b/shell_speak/library.py new file mode 100644 index 0000000..092d9ad --- /dev/null +++ b/shell_speak/library.py @@ -0,0 +1,132 @@ +"""Command library loader module.""" + +import json +from pathlib import Path +from typing import Dict, List, Optional + +import yaml + +from shell_speak.models import CommandPattern +from shell_speak.config import get_data_dir + + +class CommandLibraryLoader: + """Loads and manages command pattern libraries.""" + + def __init__(self): + self._patterns: List[CommandPattern] = [] + self._corrections: Dict[str, str] = {} + self._loaded = False + + def load_library(self, tool: Optional[str] = None) -> None: + """Load command patterns from library files.""" + data_dir = get_data_dir() + self._patterns = [] + + tool_files = { + "docker": "docker.yaml", + "kubectl": "kubectl.yaml", + "git": "git.yaml", + "unix": "unix.yaml", + } + + if tool: + files_to_load = {tool: tool_files.get(tool, f"{tool}.yaml")} + else: + files_to_load = tool_files + + for tool_name, filename in files_to_load.items(): + filepath = data_dir / filename + if filepath.exists(): + try: + patterns = self._load_yaml_library(filepath, tool_name) + self._patterns.extend(patterns) + except Exception: + pass + + self._load_corrections() + self._loaded = True + + def _load_yaml_library(self, filepath: Path, tool: str) -> List[CommandPattern]: + """Load patterns from a YAML file.""" + with open(filepath, 'r') as f: + data = yaml.safe_load(f) or {} + + patterns = [] + for item in data.get("patterns", []): + pattern = CommandPattern( + name=item.get("name", ""), + tool=tool, + description=item.get("description", ""), + patterns=item.get("patterns", []), + template=item.get("template", ""), + explanation=item.get("explanation", ""), + examples=item.get("examples", []), + ) + patterns.append(pattern) + + return patterns + + def _load_corrections(self) -> None: + """Load user corrections from JSON file.""" + corrections_file = get_data_dir() / "corrections.json" + if corrections_file.exists(): + try: + with open(corrections_file, 'r') as f: + data = json.load(f) + self._corrections = data.get("corrections", {}) + except Exception: + self._corrections = {} + + def get_patterns(self) -> List[CommandPattern]: + """Get all loaded patterns.""" + if not self._loaded: + self.load_library() + return self._patterns + + def get_corrections(self) -> Dict[str, str]: + """Get all user corrections.""" + if not self._loaded: + self.load_library() + return self._corrections + + def add_correction(self, query: str, command: str, tool: str) -> None: + """Add a user correction.""" + key = f"{tool}:{query.lower().strip()}" + self._corrections[key] = command + self._save_corrections() + + def remove_correction(self, query: str, tool: str) -> bool: + """Remove a user correction.""" + key = f"{tool}:{query.lower().strip()}" + if key in self._corrections: + del self._corrections[key] + self._save_corrections() + return True + return False + + def _save_corrections(self) -> None: + """Save corrections to JSON file.""" + corrections_file = get_data_dir() / "corrections.json" + data = { + "version": "1.0", + "corrections": self._corrections, + } + with open(corrections_file, 'w') as f: + json.dump(data, f, indent=2) + + def reload(self) -> None: + """Reload all libraries and corrections.""" + self._loaded = False + self.load_library() + + +_loader: Optional[CommandLibraryLoader] = None + + +def get_loader() -> CommandLibraryLoader: + """Get the global command library loader.""" + global _loader + if _loader is None: + _loader = CommandLibraryLoader() + return _loader