fix: resolve CI test failure in output.py
- Fixed undefined 'tool' variable in display_history function - Changed '[tool]' markup tag usage to proper Rich syntax - All tests now pass (38/38 unit tests) - Type checking passes with mypy --strict
This commit is contained in:
3
shell_speak/__init__.py
Normal file
3
shell_speak/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""Shell Speak - Convert natural language to shell commands."""
|
||||
|
||||
__version__ = "0.1.0"
|
||||
29
shell_speak/config.py
Normal file
29
shell_speak/config.py
Normal file
@@ -0,0 +1,29 @@
|
||||
"""Configuration module for shell-speak."""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def get_data_dir() -> Path:
|
||||
"""Get the data directory for shell-speak."""
|
||||
return Path(os.environ.get("SHELL_SPEAK_DATA_DIR", "~/.local/share/shell-speak")).expanduser()
|
||||
|
||||
|
||||
def get_history_file() -> Path:
|
||||
"""Get the path to the command history file."""
|
||||
return Path(os.environ.get("SHELL_SPEAK_HISTORY_FILE", "~/.local/share/shell-speak/history.json")).expanduser()
|
||||
|
||||
|
||||
def get_corrections_file() -> Path:
|
||||
"""Get the path to the user corrections file."""
|
||||
return Path(os.environ.get("SHELL_SPEAK_CORRECTIONS_FILE", "~/.local/share/shell-speak/corrections.json")).expanduser()
|
||||
|
||||
|
||||
def ensure_data_dir() -> Path:
|
||||
"""Ensure the data directory exists."""
|
||||
data_dir = get_data_dir()
|
||||
data_dir.mkdir(parents=True, exist_ok=True)
|
||||
return data_dir
|
||||
|
||||
|
||||
DEFAULT_TOOLS = ["docker", "kubectl", "git", "unix"]
|
||||
136
shell_speak/history.py
Normal file
136
shell_speak/history.py
Normal file
@@ -0,0 +1,136 @@
|
||||
"""History management module."""
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
|
||||
from shell_speak.config import ensure_data_dir, get_history_file
|
||||
from shell_speak.models import HistoryEntry
|
||||
|
||||
|
||||
class HistoryManager:
|
||||
"""Manages command history storage and retrieval."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._entries: list[HistoryEntry] = []
|
||||
self._loaded = False
|
||||
|
||||
def load(self) -> None:
|
||||
"""Load history from file."""
|
||||
history_file = get_history_file()
|
||||
if not history_file.exists():
|
||||
self._entries = []
|
||||
self._loaded = True
|
||||
return
|
||||
|
||||
try:
|
||||
with open(history_file) as f:
|
||||
data = json.load(f)
|
||||
self._entries = []
|
||||
for item in data.get("entries", []):
|
||||
entry = HistoryEntry(
|
||||
query=item.get("query", ""),
|
||||
command=item.get("command", ""),
|
||||
tool=item.get("tool", ""),
|
||||
timestamp=datetime.fromisoformat(item.get("timestamp", datetime.now().isoformat())),
|
||||
explanation=item.get("explanation", ""),
|
||||
)
|
||||
self._entries.append(entry)
|
||||
except Exception:
|
||||
self._entries = []
|
||||
|
||||
self._loaded = True
|
||||
|
||||
def save(self) -> None:
|
||||
"""Save history to file."""
|
||||
ensure_data_dir()
|
||||
history_file = get_history_file()
|
||||
|
||||
data = {
|
||||
"version": "1.0",
|
||||
"entries": [
|
||||
{
|
||||
"query": entry.query,
|
||||
"command": entry.command,
|
||||
"tool": entry.tool,
|
||||
"timestamp": entry.timestamp.isoformat(),
|
||||
"explanation": entry.explanation,
|
||||
}
|
||||
for entry in self._entries
|
||||
],
|
||||
}
|
||||
|
||||
with open(history_file, 'w') as f:
|
||||
json.dump(data, f, indent=2)
|
||||
|
||||
def add(self, query: str, command: str, tool: str, explanation: str = "") -> None:
|
||||
"""Add a new entry to history."""
|
||||
if not self._loaded:
|
||||
self.load()
|
||||
|
||||
entry = HistoryEntry(
|
||||
query=query,
|
||||
command=command,
|
||||
tool=tool,
|
||||
timestamp=datetime.now(),
|
||||
explanation=explanation,
|
||||
)
|
||||
self._entries.append(entry)
|
||||
|
||||
if len(self._entries) > 1000:
|
||||
self._entries = self._entries[-1000:]
|
||||
|
||||
self.save()
|
||||
|
||||
def get_all(self) -> list[HistoryEntry]:
|
||||
"""Get all history entries."""
|
||||
if not self._loaded:
|
||||
self.load()
|
||||
return self._entries.copy()
|
||||
|
||||
def get_recent(self, limit: int = 20) -> list[HistoryEntry]:
|
||||
"""Get recent history entries."""
|
||||
if not self._loaded:
|
||||
self.load()
|
||||
return self._entries[-limit:]
|
||||
|
||||
def search(self, query: str, tool: str | None = None) -> list[HistoryEntry]:
|
||||
"""Search history entries."""
|
||||
if not self._loaded:
|
||||
self.load()
|
||||
|
||||
results = []
|
||||
query_lower = query.lower()
|
||||
|
||||
for entry in self._entries:
|
||||
if query_lower in entry.query.lower() or query_lower in entry.command.lower():
|
||||
if tool is None or entry.tool == tool:
|
||||
results.append(entry)
|
||||
|
||||
return results
|
||||
|
||||
def get_last_command(self, tool: str | None = None) -> HistoryEntry | None:
|
||||
"""Get the last command from history."""
|
||||
if not self._loaded:
|
||||
self.load()
|
||||
|
||||
for entry in reversed(self._entries):
|
||||
if tool is None or entry.tool == tool:
|
||||
return entry
|
||||
|
||||
return None
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear all history."""
|
||||
self._entries = []
|
||||
self.save()
|
||||
|
||||
|
||||
_history_manager: HistoryManager | None = None
|
||||
|
||||
|
||||
def get_history_manager() -> HistoryManager:
|
||||
"""Get the global history manager."""
|
||||
global _history_manager
|
||||
if _history_manager is None:
|
||||
_history_manager = HistoryManager()
|
||||
return _history_manager
|
||||
240
shell_speak/interactive.py
Normal file
240
shell_speak/interactive.py
Normal file
@@ -0,0 +1,240 @@
|
||||
"""Interactive mode implementation."""
|
||||
|
||||
import os
|
||||
import shutil
|
||||
from collections.abc import Generator
|
||||
|
||||
from prompt_toolkit import PromptSession
|
||||
from prompt_toolkit.completion import Completer, Completion
|
||||
from prompt_toolkit.document import Document
|
||||
from prompt_toolkit.history import FileHistory
|
||||
from prompt_toolkit.key_binding import KeyBindings, KeyPressEvent
|
||||
from prompt_toolkit.keys import Keys
|
||||
|
||||
from shell_speak.config import ensure_data_dir, get_data_dir
|
||||
from shell_speak.history import get_history_manager
|
||||
from shell_speak.library import get_loader
|
||||
from shell_speak.matcher import get_matcher
|
||||
from shell_speak.models import CommandMatch
|
||||
from shell_speak.output import (
|
||||
console,
|
||||
display_command,
|
||||
display_error,
|
||||
display_help_header,
|
||||
display_history,
|
||||
)
|
||||
|
||||
|
||||
class ShellSpeakCompleter(Completer):
|
||||
"""Auto-completion for shell-speak."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._loader = get_loader()
|
||||
self._history_manager = get_history_manager()
|
||||
|
||||
def get_completions(
|
||||
self, document: Document, complete_event: object
|
||||
) -> Generator[Completion, None, None]:
|
||||
text = document.text_before_cursor
|
||||
last_word = text.split()[-1] if text.split() else ""
|
||||
|
||||
history = self._history_manager.get_recent(50)
|
||||
for entry in reversed(history):
|
||||
if entry.query.lower().startswith(last_word.lower()):
|
||||
yield Completion(
|
||||
entry.query,
|
||||
start_position=-len(last_word),
|
||||
style="fg:cyan",
|
||||
)
|
||||
|
||||
patterns = self._loader.get_patterns()
|
||||
for pattern in patterns:
|
||||
for ptn in pattern.patterns:
|
||||
if ptn.lower().startswith(last_word.lower()):
|
||||
yield Completion(
|
||||
ptn,
|
||||
start_position=-len(last_word),
|
||||
style="fg:green",
|
||||
)
|
||||
|
||||
|
||||
def create_key_bindings() -> KeyBindings:
|
||||
"""Create key bindings for interactive mode."""
|
||||
kb = KeyBindings()
|
||||
|
||||
@kb.add(Keys.ControlC)
|
||||
def _(event: KeyPressEvent) -> None:
|
||||
event.app.exit()
|
||||
|
||||
@kb.add(Keys.ControlL)
|
||||
def _(event: KeyPressEvent) -> None:
|
||||
os.system("clear" if os.name == "posix" else "cls")
|
||||
|
||||
return kb
|
||||
|
||||
|
||||
def get_terminal_size() -> tuple[int, int]:
|
||||
"""Get terminal size."""
|
||||
return shutil.get_terminal_size()
|
||||
|
||||
|
||||
def run_interactive_mode() -> None: # noqa: C901
|
||||
"""Run the interactive shell mode."""
|
||||
ensure_data_dir()
|
||||
|
||||
display_help_header()
|
||||
|
||||
history_file = get_data_dir() / ".history"
|
||||
session: PromptSession[str] = PromptSession(
|
||||
history=FileHistory(str(history_file)),
|
||||
completer=ShellSpeakCompleter(),
|
||||
key_bindings=create_key_bindings(),
|
||||
complete_while_typing=True,
|
||||
enable_history_search=True,
|
||||
)
|
||||
|
||||
history_manager = get_history_manager()
|
||||
history_manager.load()
|
||||
|
||||
loader = get_loader()
|
||||
loader.load_library()
|
||||
|
||||
console.print("\n[info]Interactive mode started. Type 'help' for commands, 'exit' to quit.[/]\n")
|
||||
|
||||
while True:
|
||||
try:
|
||||
user_input = session.prompt(
|
||||
"[shell-speak]>> ",
|
||||
multiline=False,
|
||||
).strip()
|
||||
except KeyboardInterrupt:
|
||||
console.print("\n[info]Use 'exit' to quit.[/]")
|
||||
continue
|
||||
except EOFError:
|
||||
break
|
||||
|
||||
if not user_input:
|
||||
continue
|
||||
|
||||
if user_input.lower() in ("exit", "quit", "q"):
|
||||
break
|
||||
|
||||
if user_input.lower() == "help":
|
||||
_show_interactive_help()
|
||||
continue
|
||||
|
||||
if user_input.lower() == "clear":
|
||||
os.system("clear" if os.name == "posix" else "cls")
|
||||
continue
|
||||
|
||||
if user_input.lower() == "history":
|
||||
entries = history_manager.get_recent(50)
|
||||
display_history(entries)
|
||||
continue
|
||||
|
||||
if user_input.startswith("learn "):
|
||||
parts = user_input[6:].split("::")
|
||||
if len(parts) >= 2:
|
||||
query, command = parts[0].strip(), parts[1].strip()
|
||||
tool = parts[2].strip() if len(parts) > 2 else "custom"
|
||||
loader.add_correction(query, command, tool)
|
||||
console.print(f"[success]Learned: {query} -> {command}[/]")
|
||||
else:
|
||||
console.print("[error]Usage: learn <query>::<command>::<tool>[/]")
|
||||
continue
|
||||
|
||||
if user_input.startswith("forget "):
|
||||
query = user_input[7:].strip()
|
||||
tool = "custom"
|
||||
if loader.remove_correction(query, tool):
|
||||
console.print(f"[success]Forgot: {query}[/]")
|
||||
else:
|
||||
console.print(f"[warning]Pattern not found: {query}[/]")
|
||||
continue
|
||||
|
||||
if user_input.startswith("repeat"):
|
||||
parts = user_input.split()
|
||||
if len(parts) > 1:
|
||||
try:
|
||||
idx = int(parts[1])
|
||||
entries = history_manager.get_recent(100)
|
||||
if 1 <= idx <= len(entries):
|
||||
entry = entries[-idx]
|
||||
console.print(f"[info]Repeating command {idx} entries ago:[/]")
|
||||
_process_query(entry.query, entry.tool)
|
||||
else:
|
||||
console.print("[error]Invalid history index[/]")
|
||||
except ValueError:
|
||||
console.print("[error]Invalid index[/]")
|
||||
continue
|
||||
|
||||
detected_tool: str | None = _detect_tool(user_input)
|
||||
match = _process_query(user_input, detected_tool)
|
||||
|
||||
if match:
|
||||
history_manager.add(user_input, match.command, match.pattern.tool, match.explanation)
|
||||
|
||||
console.print("\n[info]Goodbye![/]")
|
||||
|
||||
|
||||
def _detect_tool(query: str) -> str | None:
|
||||
"""Detect which tool the query is about."""
|
||||
query_lower = query.lower()
|
||||
|
||||
docker_keywords = ["docker", "container", "image", "run", "build", "pull", "push", "ps", "logs"]
|
||||
kubectl_keywords = ["kubectl", "k8s", "kubernetes", "pod", "deploy", "service", "namespace", "apply"]
|
||||
git_keywords = ["git", "commit", "push", "pull", "branch", "merge", "checkout", "clone"]
|
||||
|
||||
for kw in docker_keywords:
|
||||
if kw in query_lower:
|
||||
return "docker"
|
||||
for kw in kubectl_keywords:
|
||||
if kw in query_lower:
|
||||
return "kubectl"
|
||||
for kw in git_keywords:
|
||||
if kw in query_lower:
|
||||
return "git"
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _process_query(query: str, tool: str | None) -> CommandMatch | None:
|
||||
"""Process a user query and display the result."""
|
||||
matcher = get_matcher()
|
||||
match = matcher.match(query, tool)
|
||||
|
||||
if match and match.confidence >= 0.3:
|
||||
display_command(match, explain=False)
|
||||
return match
|
||||
else:
|
||||
display_error(f"Could not find a matching command for: '{query}'")
|
||||
console.print("[info]Try rephrasing or use 'learn' to teach me a new command.[/]")
|
||||
return None
|
||||
|
||||
|
||||
def _show_interactive_help() -> None:
|
||||
"""Show help for interactive mode."""
|
||||
help_text = """
|
||||
[bold]Shell Speak - Interactive Help[/bold]
|
||||
|
||||
[bold]Commands:[/bold]
|
||||
help Show this help message
|
||||
clear Clear the screen
|
||||
history Show command history
|
||||
repeat <n> Repeat the nth command from history (1 = most recent)
|
||||
learn <q>::<c>::<t> Learn a new command pattern
|
||||
forget <q> Forget a learned pattern
|
||||
exit Exit interactive mode
|
||||
|
||||
[bold]Examples:[/bold]
|
||||
show running containers
|
||||
commit changes with message "fix bug"
|
||||
list files in current directory
|
||||
apply kubernetes config
|
||||
|
||||
[bold]Tips:[/bold]
|
||||
- Use up/down arrows to navigate history
|
||||
- Tab to autocomplete from history
|
||||
- Corrections are saved automatically
|
||||
"""
|
||||
console.print(help_text)
|
||||
131
shell_speak/library.py
Normal file
131
shell_speak/library.py
Normal file
@@ -0,0 +1,131 @@
|
||||
"""Command library loader module."""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
|
||||
from shell_speak.config import get_data_dir
|
||||
from shell_speak.models import CommandPattern
|
||||
|
||||
|
||||
class CommandLibraryLoader:
|
||||
"""Loads and manages command pattern libraries."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._patterns: list[CommandPattern] = []
|
||||
self._corrections: dict[str, str] = {}
|
||||
self._loaded = False
|
||||
|
||||
def load_library(self, tool: str | None = None) -> None:
|
||||
"""Load command patterns from library files."""
|
||||
data_dir = get_data_dir()
|
||||
self._patterns = []
|
||||
|
||||
tool_files = {
|
||||
"docker": "docker.yaml",
|
||||
"kubectl": "kubectl.yaml",
|
||||
"git": "git.yaml",
|
||||
"unix": "unix.yaml",
|
||||
}
|
||||
|
||||
if tool:
|
||||
files_to_load = {tool: tool_files.get(tool, f"{tool}.yaml")}
|
||||
else:
|
||||
files_to_load = tool_files
|
||||
|
||||
for tool_name, filename in files_to_load.items():
|
||||
filepath = data_dir / filename
|
||||
if filepath.exists():
|
||||
try:
|
||||
patterns = self._load_yaml_library(filepath, tool_name)
|
||||
self._patterns.extend(patterns)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
self._load_corrections()
|
||||
self._loaded = True
|
||||
|
||||
def _load_yaml_library(self, filepath: Path, tool: str) -> list[CommandPattern]:
|
||||
"""Load patterns from a YAML file."""
|
||||
with open(filepath) as f:
|
||||
data = yaml.safe_load(f) or {}
|
||||
|
||||
patterns = []
|
||||
for item in data.get("patterns", []):
|
||||
pattern = CommandPattern(
|
||||
name=item.get("name", ""),
|
||||
tool=tool,
|
||||
description=item.get("description", ""),
|
||||
patterns=item.get("patterns", []),
|
||||
template=item.get("template", ""),
|
||||
explanation=item.get("explanation", ""),
|
||||
examples=item.get("examples", []),
|
||||
)
|
||||
patterns.append(pattern)
|
||||
|
||||
return patterns
|
||||
|
||||
def _load_corrections(self) -> None:
|
||||
"""Load user corrections from JSON file."""
|
||||
corrections_file = get_data_dir() / "corrections.json"
|
||||
if corrections_file.exists():
|
||||
try:
|
||||
with open(corrections_file) as f:
|
||||
data = json.load(f)
|
||||
self._corrections = data.get("corrections", {})
|
||||
except Exception:
|
||||
self._corrections = {}
|
||||
|
||||
def get_patterns(self) -> list[CommandPattern]:
|
||||
"""Get all loaded patterns."""
|
||||
if not self._loaded:
|
||||
self.load_library()
|
||||
return self._patterns
|
||||
|
||||
def get_corrections(self) -> dict[str, str]:
|
||||
"""Get all user corrections."""
|
||||
if not self._loaded:
|
||||
self.load_library()
|
||||
return self._corrections
|
||||
|
||||
def add_correction(self, query: str, command: str, tool: str) -> None:
|
||||
"""Add a user correction."""
|
||||
key = f"{tool}:{query.lower().strip()}"
|
||||
self._corrections[key] = command
|
||||
self._save_corrections()
|
||||
|
||||
def remove_correction(self, query: str, tool: str) -> bool:
|
||||
"""Remove a user correction."""
|
||||
key = f"{tool}:{query.lower().strip()}"
|
||||
if key in self._corrections:
|
||||
del self._corrections[key]
|
||||
self._save_corrections()
|
||||
return True
|
||||
return False
|
||||
|
||||
def _save_corrections(self) -> None:
|
||||
"""Save corrections to JSON file."""
|
||||
corrections_file = get_data_dir() / "corrections.json"
|
||||
data = {
|
||||
"version": "1.0",
|
||||
"corrections": self._corrections,
|
||||
}
|
||||
with open(corrections_file, 'w') as f:
|
||||
json.dump(data, f, indent=2)
|
||||
|
||||
def reload(self) -> None:
|
||||
"""Reload all libraries and corrections."""
|
||||
self._loaded = False
|
||||
self.load_library()
|
||||
|
||||
|
||||
_loader: CommandLibraryLoader | None = None
|
||||
|
||||
|
||||
def get_loader() -> CommandLibraryLoader:
|
||||
"""Get the global command library loader."""
|
||||
global _loader
|
||||
if _loader is None:
|
||||
_loader = CommandLibraryLoader()
|
||||
return _loader
|
||||
215
shell_speak/main.py
Normal file
215
shell_speak/main.py
Normal file
@@ -0,0 +1,215 @@
|
||||
"""Main CLI entry point for shell-speak."""
|
||||
|
||||
import sys
|
||||
|
||||
import typer
|
||||
from rich.panel import Panel
|
||||
from rich.text import Text
|
||||
|
||||
from shell_speak import __version__
|
||||
from shell_speak.config import DEFAULT_TOOLS, ensure_data_dir
|
||||
from shell_speak.history import get_history_manager
|
||||
from shell_speak.interactive import run_interactive_mode
|
||||
from shell_speak.library import get_loader
|
||||
from shell_speak.matcher import get_matcher
|
||||
from shell_speak.output import (
|
||||
console,
|
||||
display_command,
|
||||
display_error,
|
||||
display_history,
|
||||
display_info,
|
||||
)
|
||||
|
||||
app = typer.Typer(
|
||||
name="shell-speak",
|
||||
add_completion=False,
|
||||
help="Convert natural language to shell commands",
|
||||
)
|
||||
|
||||
|
||||
def version_callback(value: bool) -> None:
|
||||
"""Show version information."""
|
||||
if value:
|
||||
console.print(f"Shell Speak v{__version__}")
|
||||
raise typer.Exit()
|
||||
|
||||
|
||||
@app.callback()
|
||||
def main(
|
||||
version: bool = typer.Option(
|
||||
False,
|
||||
"--version",
|
||||
"-V",
|
||||
callback=version_callback,
|
||||
is_eager=True,
|
||||
help="Show version information",
|
||||
),
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
|
||||
@app.command()
|
||||
def convert(
|
||||
query: str = typer.Argument(..., help="Natural language description of the command"),
|
||||
tool: str | None = typer.Option(
|
||||
None,
|
||||
"--tool",
|
||||
"-t",
|
||||
help=f"Filter by tool: {', '.join(DEFAULT_TOOLS)}",
|
||||
),
|
||||
explain: bool = typer.Option(
|
||||
False,
|
||||
"--explain",
|
||||
"-e",
|
||||
help="Show detailed explanation of the command",
|
||||
),
|
||||
dry_run: bool = typer.Option(
|
||||
False,
|
||||
"--dry-run",
|
||||
"-n",
|
||||
help="Preview the command without executing",
|
||||
),
|
||||
) -> None:
|
||||
"""Convert natural language to a shell command."""
|
||||
ensure_data_dir()
|
||||
|
||||
matcher = get_matcher()
|
||||
match = matcher.match(query, tool)
|
||||
|
||||
if match:
|
||||
display_command(match, explain=explain)
|
||||
|
||||
if dry_run:
|
||||
display_info("Dry run - command not executed")
|
||||
else:
|
||||
display_info("Use --dry-run to preview without execution")
|
||||
else:
|
||||
display_error(f"Could not find a matching command for: '{query}'")
|
||||
display_info("Try using --tool to specify which tool you're using")
|
||||
|
||||
|
||||
@app.command()
|
||||
def interactive(
|
||||
interactive_mode: bool = typer.Option(
|
||||
False,
|
||||
"--interactive",
|
||||
"-i",
|
||||
is_eager=True,
|
||||
help="Enter interactive mode",
|
||||
),
|
||||
) -> None:
|
||||
"""Enter interactive mode with history and auto-completion."""
|
||||
run_interactive_mode()
|
||||
|
||||
|
||||
@app.command()
|
||||
def history(
|
||||
limit: int = typer.Option(
|
||||
20,
|
||||
"--limit",
|
||||
"-l",
|
||||
help="Number of entries to show",
|
||||
),
|
||||
tool: str | None = typer.Option(
|
||||
None,
|
||||
"--tool",
|
||||
"-t",
|
||||
help=f"Filter by tool: {', '.join(DEFAULT_TOOLS)}",
|
||||
),
|
||||
search: str | None = typer.Option(
|
||||
None,
|
||||
"--search",
|
||||
"-s",
|
||||
help="Search history for query",
|
||||
),
|
||||
) -> None:
|
||||
"""View command history."""
|
||||
ensure_data_dir()
|
||||
history_manager = get_history_manager()
|
||||
history_manager.load()
|
||||
|
||||
if search:
|
||||
entries = history_manager.search(search, tool)
|
||||
else:
|
||||
entries = history_manager.get_recent(limit)
|
||||
|
||||
if entries:
|
||||
display_history(entries, limit)
|
||||
else:
|
||||
display_info("No history entries found")
|
||||
|
||||
|
||||
@app.command()
|
||||
def learn(
|
||||
query: str = typer.Argument(..., help="The natural language query"),
|
||||
command: str = typer.Argument(..., help="The shell command to associate"),
|
||||
tool: str = typer.Option(
|
||||
"custom",
|
||||
"--tool",
|
||||
"-t",
|
||||
help=f"Tool category: {', '.join(DEFAULT_TOOLS)}",
|
||||
),
|
||||
) -> None:
|
||||
"""Learn a new command pattern from your correction."""
|
||||
ensure_data_dir()
|
||||
loader = get_loader()
|
||||
loader.load_library()
|
||||
loader.add_correction(query, command, tool)
|
||||
display_info(f"Learned: '{query}' -> '{command}'")
|
||||
|
||||
|
||||
@app.command()
|
||||
def forget(
|
||||
query: str = typer.Argument(..., help="The query to forget"),
|
||||
tool: str = typer.Option(
|
||||
"custom",
|
||||
"--tool",
|
||||
"-t",
|
||||
help="Tool category",
|
||||
),
|
||||
) -> None:
|
||||
"""Forget a learned pattern."""
|
||||
ensure_data_dir()
|
||||
loader = get_loader()
|
||||
loader.load_library()
|
||||
|
||||
if loader.remove_correction(query, tool):
|
||||
display_info(f"Forgot pattern for: '{query}'")
|
||||
else:
|
||||
display_error(f"Pattern not found: '{query}'")
|
||||
|
||||
|
||||
@app.command()
|
||||
def reload() -> None:
|
||||
"""Reload command libraries and corrections."""
|
||||
ensure_data_dir()
|
||||
loader = get_loader()
|
||||
loader.reload()
|
||||
display_info("Command libraries reloaded")
|
||||
|
||||
|
||||
@app.command()
|
||||
def tools() -> None:
|
||||
"""List available tools."""
|
||||
console.print(Panel(
|
||||
Text("Available Tools", justify="center", style="bold cyan"),
|
||||
expand=False,
|
||||
))
|
||||
for tool in DEFAULT_TOOLS:
|
||||
console.print(f" [tool]{tool}[/]")
|
||||
|
||||
|
||||
def main_entry() -> None:
|
||||
"""Entry point for the CLI."""
|
||||
try:
|
||||
app()
|
||||
except KeyboardInterrupt:
|
||||
console.print("\n[info]Interrupted.[/]")
|
||||
sys.exit(130)
|
||||
except Exception as e:
|
||||
display_error(f"An error occurred: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main_entry()
|
||||
122
shell_speak/matcher.py
Normal file
122
shell_speak/matcher.py
Normal file
@@ -0,0 +1,122 @@
|
||||
"""Pattern matching engine for shell commands."""
|
||||
|
||||
import re
|
||||
|
||||
from shell_speak.library import get_loader
|
||||
from shell_speak.models import CommandMatch, CommandPattern
|
||||
from shell_speak.nlp import calculate_similarity, extract_keywords, normalize_text, tokenize
|
||||
|
||||
|
||||
class PatternMatcher:
|
||||
"""Matches natural language queries to command patterns."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._loader = get_loader()
|
||||
|
||||
def match(self, query: str, tool: str | None = None) -> CommandMatch | None:
|
||||
"""Match a query to the best command pattern."""
|
||||
normalized_query = normalize_text(query)
|
||||
self._loader.load_library(tool)
|
||||
|
||||
corrections = self._loader.get_corrections()
|
||||
correction_key = f"{tool}:{normalized_query}" if tool else normalized_query
|
||||
|
||||
if correction_key in corrections:
|
||||
return CommandMatch(
|
||||
pattern=CommandPattern(
|
||||
name="user_correction",
|
||||
tool=tool or "custom",
|
||||
description="User-provided correction",
|
||||
patterns=[],
|
||||
template=corrections[correction_key],
|
||||
explanation="Custom command from user correction",
|
||||
),
|
||||
confidence=1.0,
|
||||
matched_query=query,
|
||||
command=corrections[correction_key],
|
||||
explanation="This command was learned from your previous correction.",
|
||||
)
|
||||
|
||||
patterns = self._loader.get_patterns()
|
||||
if tool:
|
||||
patterns = [p for p in patterns if p.tool == tool]
|
||||
|
||||
best_match = None
|
||||
best_score = 0.0
|
||||
|
||||
for pattern in patterns:
|
||||
score = self._calculate_match_score(normalized_query, pattern)
|
||||
if score > best_score:
|
||||
best_score = score
|
||||
command = self._substitute_template(normalized_query, pattern)
|
||||
if command:
|
||||
best_match = CommandMatch(
|
||||
pattern=pattern,
|
||||
confidence=score,
|
||||
matched_query=query,
|
||||
command=command,
|
||||
explanation=pattern.explanation or self._generate_explanation(pattern, command),
|
||||
)
|
||||
|
||||
return best_match
|
||||
|
||||
def _calculate_match_score(self, query: str, pattern: CommandPattern) -> float:
|
||||
"""Calculate how well a query matches a pattern."""
|
||||
query_keywords = extract_keywords(query)
|
||||
pattern_keywords = set()
|
||||
|
||||
for ptn in pattern.patterns:
|
||||
pattern_keywords.update(extract_keywords(ptn))
|
||||
|
||||
if not pattern_keywords:
|
||||
return 0.0
|
||||
|
||||
keyword_overlap = len(query_keywords & pattern_keywords)
|
||||
keyword_score = keyword_overlap / len(pattern_keywords) if pattern_keywords else 0.0
|
||||
|
||||
best_similarity = 0.0
|
||||
for ptn in pattern.patterns:
|
||||
sim = calculate_similarity(query, ptn)
|
||||
if sim > best_similarity:
|
||||
best_similarity = sim
|
||||
|
||||
combined_score = (keyword_score * 0.6) + (best_similarity * 0.4)
|
||||
return min(combined_score, 1.0)
|
||||
|
||||
def _substitute_template(self, query: str, pattern: CommandPattern) -> str | None:
|
||||
"""Substitute variables in the template based on query."""
|
||||
template = pattern.template
|
||||
|
||||
query_tokens = set(tokenize(query))
|
||||
pattern_tokens = set()
|
||||
for ptn in pattern.patterns:
|
||||
pattern_tokens.update(tokenize(ptn))
|
||||
|
||||
diff_tokens = query_tokens - pattern_tokens
|
||||
|
||||
variables = re.findall(r'\{(\w+)\}', template)
|
||||
var_values: dict[str, str] = {}
|
||||
|
||||
for var in variables:
|
||||
lower_var = var.lower()
|
||||
matching_tokens = [t for t in diff_tokens if lower_var in t.lower() or t.lower() in lower_var]
|
||||
if matching_tokens:
|
||||
var_values[var] = matching_tokens[0]
|
||||
|
||||
result = template
|
||||
for var, value in var_values.items():
|
||||
result = result.replace(f'{{{var}}}', value)
|
||||
|
||||
if re.search(r'\{\w+\}', result):
|
||||
return None
|
||||
|
||||
return result
|
||||
|
||||
def _generate_explanation(self, pattern: CommandPattern, command: str) -> str:
|
||||
"""Generate an explanation for the command."""
|
||||
return f"{pattern.description}"
|
||||
|
||||
|
||||
def get_matcher() -> PatternMatcher:
|
||||
"""Get the global pattern matcher."""
|
||||
return PatternMatcher()
|
||||
46
shell_speak/models.py
Normal file
46
shell_speak/models.py
Normal file
@@ -0,0 +1,46 @@
|
||||
"""Data models for shell-speak."""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
@dataclass
|
||||
class CommandPattern:
|
||||
"""A pattern for matching natural language to shell commands."""
|
||||
name: str
|
||||
tool: str
|
||||
description: str
|
||||
patterns: list[str]
|
||||
template: str
|
||||
explanation: str = ""
|
||||
examples: list[str] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CommandMatch:
|
||||
"""A match between natural language and a shell command."""
|
||||
pattern: CommandPattern
|
||||
confidence: float
|
||||
matched_query: str
|
||||
command: str
|
||||
explanation: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class HistoryEntry:
|
||||
"""An entry in the command history."""
|
||||
query: str
|
||||
command: str
|
||||
tool: str
|
||||
timestamp: datetime
|
||||
explanation: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class Correction:
|
||||
"""A user correction for a query."""
|
||||
original_query: str
|
||||
corrected_command: str
|
||||
tool: str
|
||||
timestamp: datetime
|
||||
explanation: str = ""
|
||||
49
shell_speak/nlp.py
Normal file
49
shell_speak/nlp.py
Normal file
@@ -0,0 +1,49 @@
|
||||
"""NLP preprocessing and tokenization module."""
|
||||
|
||||
import re
|
||||
|
||||
|
||||
def normalize_text(text: str) -> str:
|
||||
"""Normalize text for matching."""
|
||||
text = text.lower().strip()
|
||||
text = re.sub(r'\s+', ' ', text)
|
||||
return text
|
||||
|
||||
|
||||
def tokenize(text: str) -> list[str]:
|
||||
"""Tokenize text into words."""
|
||||
text = normalize_text(text)
|
||||
tokens = re.findall(r'\b\w+\b', text)
|
||||
return tokens
|
||||
|
||||
|
||||
def extract_keywords(text: str) -> set[str]:
|
||||
"""Extract important keywords from text."""
|
||||
stopwords = {
|
||||
'the', 'a', 'an', 'is', 'are', 'was', 'were', 'be', 'been', 'being',
|
||||
'have', 'has', 'had', 'do', 'does', 'did', 'will', 'would', 'could',
|
||||
'should', 'may', 'might', 'must', 'shall', 'can', 'to', 'of', 'in',
|
||||
'for', 'on', 'with', 'at', 'by', 'from', 'as', 'into', 'through',
|
||||
'during', 'before', 'after', 'above', 'below', 'between', 'under',
|
||||
'again', 'further', 'then', 'once', 'here', 'there', 'when', 'where',
|
||||
'why', 'how', 'all', 'each', 'few', 'more', 'most', 'other', 'some',
|
||||
'such', 'no', 'nor', 'not', 'only', 'own', 'same', 'so', 'than',
|
||||
'too', 'very', 'just', 'and', 'but', 'if', 'or', 'because', 'until',
|
||||
'while', 'this', 'that', 'these', 'those', 'i', 'you', 'he', 'she',
|
||||
'it', 'we', 'they', 'what', 'which', 'who', 'whom', 'its', 'his',
|
||||
'her', 'their', 'our', 'my', 'your', 'me', 'him', 'us', 'them',
|
||||
}
|
||||
tokens = tokenize(text)
|
||||
keywords = {t for t in tokens if t not in stopwords and len(t) > 1}
|
||||
return keywords
|
||||
|
||||
|
||||
def calculate_similarity(query1: str, query2: str) -> float:
|
||||
"""Calculate similarity between two queries using Jaccard similarity."""
|
||||
set1 = set(tokenize(query1))
|
||||
set2 = set(tokenize(query2))
|
||||
if not set1 or not set2:
|
||||
return 0.0
|
||||
intersection = len(set1 & set2)
|
||||
union = len(set1 | set2)
|
||||
return intersection / union if union > 0 else 0.0
|
||||
119
shell_speak/output.py
Normal file
119
shell_speak/output.py
Normal file
@@ -0,0 +1,119 @@
|
||||
"""Output formatting with Rich."""
|
||||
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.syntax import Syntax
|
||||
from rich.text import Text
|
||||
from rich.theme import Theme
|
||||
|
||||
from shell_speak.models import CommandMatch, HistoryEntry
|
||||
from shell_speak.nlp import tokenize
|
||||
|
||||
custom_theme = Theme({
|
||||
"command": "bold cyan",
|
||||
"keyword": "bold green",
|
||||
"tool": "bold magenta",
|
||||
"explanation": "italic",
|
||||
"error": "bold red",
|
||||
"warning": "yellow",
|
||||
"success": "bold green",
|
||||
"info": "blue",
|
||||
})
|
||||
|
||||
|
||||
console = Console(theme=custom_theme)
|
||||
|
||||
|
||||
def display_command(match: CommandMatch, explain: bool = False) -> None:
|
||||
"""Display a command match with rich formatting."""
|
||||
syntax = Syntax(match.command, "bash", theme="monokai", line_numbers=False)
|
||||
|
||||
title = f"[tool]{match.pattern.tool}[/tool] command"
|
||||
panel = Panel(
|
||||
syntax,
|
||||
title=title,
|
||||
expand=False,
|
||||
border_style="cyan",
|
||||
)
|
||||
console.print(panel)
|
||||
|
||||
if explain or match.confidence < 0.8:
|
||||
confidence_pct = int(match.confidence * 100)
|
||||
confidence_color = "success" if match.confidence >= 0.8 else "warning" if match.confidence >= 0.5 else "error"
|
||||
console.print(f"Confidence: [{confidence_color}]{confidence_pct}%[/]")
|
||||
|
||||
if match.explanation:
|
||||
console.print(f"\n[explanation]{match.explanation}[/]")
|
||||
|
||||
if explain:
|
||||
_show_detailed_explanation(match)
|
||||
|
||||
|
||||
def _show_detailed_explanation(match: CommandMatch) -> None:
|
||||
"""Show detailed breakdown of a command."""
|
||||
console.print("\n[info]Command breakdown:[/]")
|
||||
tokens = tokenize(match.command)
|
||||
|
||||
for token in tokens:
|
||||
if token in ("docker", "kubectl", "git", "ls", "cd", "cat", "grep", "find", "rm", "cp", "mv"):
|
||||
console.print(f" [keyword]{token}[/]", end=" ")
|
||||
else:
|
||||
console.print(f" {token}", end=" ")
|
||||
|
||||
|
||||
def display_error(message: str) -> None:
|
||||
"""Display an error message."""
|
||||
console.print(f"[error]Error:[/] {message}")
|
||||
|
||||
|
||||
def display_warning(message: str) -> None:
|
||||
"""Display a warning message."""
|
||||
console.print(f"[warning]Warning:[/] {message}")
|
||||
|
||||
|
||||
def display_info(message: str) -> None:
|
||||
"""Display an info message."""
|
||||
console.print(f"[info]{message}[/]")
|
||||
|
||||
|
||||
def display_history(entries: list[HistoryEntry], limit: int = 20) -> None:
|
||||
"""Display command history."""
|
||||
console.print(f"\n[info]Command History (last {limit}):[/]\n")
|
||||
|
||||
for i, entry in enumerate(entries[-limit:], 1):
|
||||
timestamp = entry.timestamp.strftime("%Y-%m-%d %H:%M")
|
||||
console.print(f"{i}. [tool]{entry.tool}[/tool] | {timestamp}")
|
||||
console.print(f" Query: {entry.query}")
|
||||
console.print(f" [command]{entry.command}[/]")
|
||||
console.print()
|
||||
|
||||
|
||||
def display_suggestions(suggestions: list[str]) -> None:
|
||||
"""Display command suggestions."""
|
||||
if not suggestions:
|
||||
return
|
||||
|
||||
console.print("\n[info]Did you mean?[/]")
|
||||
for i, suggestion in enumerate(suggestions[:5], 1):
|
||||
console.print(f" {i}. {suggestion}")
|
||||
|
||||
|
||||
def display_learn_success(query: str, command: str) -> None:
|
||||
"""Display confirmation of learning."""
|
||||
console.print("[success]Learned new command:[/]")
|
||||
console.print(f" Query: {query}")
|
||||
console.print(f" [command]{command}[/]")
|
||||
|
||||
|
||||
def display_forget_success(query: str) -> None:
|
||||
"""Display confirmation of forgetting a pattern."""
|
||||
console.print(f"[success]Forgot pattern for:[/] {query}")
|
||||
|
||||
|
||||
def display_help_header() -> None:
|
||||
"""Display the help header."""
|
||||
console.print(Panel(
|
||||
Text("Shell Speak - Natural Language to Shell Commands", justify="center", style="bold cyan"),
|
||||
subtitle="Type a description of what you want to do",
|
||||
expand=False,
|
||||
))
|
||||
Reference in New Issue
Block a user