- Fixed undefined 'tool' variable in display_history function
- Changed '[tool]' markup tag usage to proper Rich syntax
- All tests now pass (38/38 unit tests)
- Type checking passes with mypy --strict
132 lines
4.0 KiB
Python
132 lines
4.0 KiB
Python
"""Command library loader module."""
|
|
|
|
import json
import logging
from pathlib import Path

import yaml

from shell_speak.config import get_data_dir
from shell_speak.models import CommandPattern
|
|
|
|
|
|
class CommandLibraryLoader:
    """Loads and manages command pattern libraries and user corrections.

    Patterns are read from per-tool YAML files and corrections from
    ``corrections.json``; both live in the directory returned by
    ``get_data_dir()``. Loading is lazy: accessors and mutators trigger
    ``load_library()`` on first use.
    """

    _log = logging.getLogger(__name__)

    def __init__(self) -> None:
        self._patterns: list[CommandPattern] = []
        self._corrections: dict[str, str] = {}
        self._loaded = False

    def load_library(self, tool: str | None = None) -> None:
        """Load command patterns from library files.

        Args:
            tool: When given, only that tool's file is loaded (falling back
                to ``<tool>.yaml`` for tools not in the built-in table).
                When omitted, every built-in tool file is loaded.

        Previously loaded patterns are discarded. Missing files are skipped,
        and a file that fails to load is logged and skipped so that one bad
        library does not prevent the others from loading. Corrections are
        (re)loaded as part of every call.
        """
        data_dir = get_data_dir()
        self._patterns = []

        tool_files = {
            "docker": "docker.yaml",
            "kubectl": "kubectl.yaml",
            "git": "git.yaml",
            "unix": "unix.yaml",
        }

        if tool:
            files_to_load = {tool: tool_files.get(tool, f"{tool}.yaml")}
        else:
            files_to_load = tool_files

        for tool_name, filename in files_to_load.items():
            filepath = data_dir / filename
            if not filepath.exists():
                continue
            try:
                self._patterns.extend(self._load_yaml_library(filepath, tool_name))
            except Exception:
                # Deliberately broad: loading is best-effort, but the failure
                # is reported instead of being silently dropped.
                self._log.warning(
                    "Skipping command library %s: failed to load",
                    filepath,
                    exc_info=True,
                )

        self._load_corrections()
        self._loaded = True

    def _load_yaml_library(self, filepath: Path, tool: str) -> list[CommandPattern]:
        """Load patterns from a YAML file.

        Args:
            filepath: Path of the YAML library to read.
            tool: Tool name stamped onto every pattern built from the file.

        Returns:
            One ``CommandPattern`` per entry under the top-level ``patterns``
            key; an empty list when the file is empty or has no such key.
        """
        with open(filepath, encoding="utf-8") as f:
            data = yaml.safe_load(f) or {}

        return [
            CommandPattern(
                name=item.get("name", ""),
                tool=tool,
                description=item.get("description", ""),
                patterns=item.get("patterns", []),
                template=item.get("template", ""),
                explanation=item.get("explanation", ""),
                examples=item.get("examples", []),
            )
            for item in data.get("patterns", [])
        ]

    def _load_corrections(self) -> None:
        """Load user corrections from JSON file.

        The in-memory corrections are always replaced: they end up empty when
        the file is missing, unreadable, or does not hold a ``corrections``
        mapping.
        """
        # Reset first so a reload never keeps entries from a file that has
        # since been removed.
        self._corrections = {}

        corrections_file = get_data_dir() / "corrections.json"
        if not corrections_file.exists():
            return

        try:
            with open(corrections_file, encoding="utf-8") as f:
                data = json.load(f)
        except (OSError, ValueError):
            # ValueError covers both invalid JSON and undecodable bytes.
            self._log.warning(
                "Ignoring corrections file %s: failed to load",
                corrections_file,
                exc_info=True,
            )
            return

        corrections = data.get("corrections", {}) if isinstance(data, dict) else {}
        if isinstance(corrections, dict):
            self._corrections = corrections

    def _ensure_loaded(self) -> None:
        """Run the default ``load_library()`` if nothing has been loaded yet."""
        if not self._loaded:
            self.load_library()

    def get_patterns(self) -> list[CommandPattern]:
        """Get all loaded patterns, loading the library first if needed."""
        self._ensure_loaded()
        return self._patterns

    def get_corrections(self) -> dict[str, str]:
        """Get all user corrections, loading the library first if needed."""
        self._ensure_loaded()
        return self._corrections

    def add_correction(self, query: str, command: str, tool: str) -> None:
        """Add a user correction and persist it.

        Args:
            query: User query; lower-cased and stripped to form the key.
            command: Command to store for the query.
            tool: Tool name used as the key prefix (``"<tool>:<query>"``).
        """
        # Load existing corrections first; otherwise saving from a fresh
        # loader would overwrite the file with only this one entry.
        self._ensure_loaded()
        key = f"{tool}:{query.lower().strip()}"
        self._corrections[key] = command
        self._save_corrections()

    def remove_correction(self, query: str, tool: str) -> bool:
        """Remove a user correction.

        Args:
            query: User query; lower-cased and stripped to form the key.
            tool: Tool name used as the key prefix.

        Returns:
            True if a matching correction existed and was removed (the file
            is rewritten), False otherwise.
        """
        # Load first so corrections that exist only on disk can be found.
        self._ensure_loaded()
        key = f"{tool}:{query.lower().strip()}"
        if key in self._corrections:
            del self._corrections[key]
            self._save_corrections()
            return True
        return False

    def _save_corrections(self) -> None:
        """Save corrections to JSON file."""
        corrections_file = get_data_dir() / "corrections.json"
        data = {
            "version": "1.0",
            "corrections": self._corrections,
        }
        with open(corrections_file, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2)

    def reload(self) -> None:
        """Reload all libraries and corrections."""
        self._loaded = False
        self.load_library()
|
|
|
|
|
|
# Module-wide loader instance; stays None until get_loader() is first called.
_loader: CommandLibraryLoader | None = None


def get_loader() -> CommandLibraryLoader:
    """Return the shared command library loader, creating it on first use."""
    global _loader
    if _loader is not None:
        return _loader
    _loader = CommandLibraryLoader()
    return _loader
|