408 lines
12 KiB
Python
408 lines
12 KiB
Python
"""CLI interface for ShellGenius."""
|
|
|
|
import os
from typing import Optional

import click
from prompt_toolkit import PromptSession
from prompt_toolkit.completion import WordCompleter
from prompt_toolkit.formatted_text import HTML
from rich.console import Console
from rich.markup import escape
from rich.panel import Panel
from rich.table import Table

from shellgenius.config import get_config
from shellgenius.explainer import explain_script
from shellgenius.generation import ShellSafetyChecker, generate_shell
from shellgenius.history import HistoryLearner, get_history_storage
from shellgenius.ollama_client import get_ollama_client
from shellgenius.refactoring import refactor_script
|
|
|
|
# Shared Rich console that every command below prints through.
console: Console = Console()

# One prompt_toolkit session, created at import time and reused by the
# interactive REPL for all of its prompts.
session: PromptSession[str] = PromptSession()
|
|
|
|
|
|
def print_header():
    """Render the ShellGenius welcome banner inside a Rich panel."""
    banner_text = (
        "[bold cyan]ShellGenius[/bold cyan] - AI-Powered Shell Script Assistant\n"
        "[dim]Powered by Ollama - Your Local LLM[/dim]"
    )
    banner = Panel(
        banner_text,
        title="Welcome",
        subtitle="Type 'help' for available commands",
    )
    console.print(banner)
|
|
|
|
|
|
def print_error(message: str):
    """Print an error message with a bold red "Error:" prefix.

    The message is escaped before rendering, so bracketed text is shown
    verbatim instead of being parsed as Rich markup. This covers exception
    strings, file paths and shell snippets such as ``[a-z]`` or ``[/tmp]``.
    Without escaping, text like ``[a-z]`` silently disappears, and an
    unmatched ``[/tag]`` raises ``MarkupError``.

    Args:
        message: Plain-text error message to display.
    """
    console.print(f"[bold red]Error:[/bold red] {escape(str(message))}")
|
|
|
|
|
|
def print_success(message: str):
    """Print a success message with a bold green "Success:" prefix.

    The message is escaped so that any square brackets in it are displayed
    literally rather than interpreted as Rich markup tags.

    Args:
        message: Plain-text success message to display.
    """
    console.print(f"[bold green]Success:[/bold green] {escape(str(message))}")
|
|
|
|
|
|
def print_warning(message: str):
    """Print a warning message with a bold yellow "Warning:" prefix.

    The message is escaped so that any square brackets in it are displayed
    literally rather than interpreted as Rich markup tags.

    Args:
        message: Plain-text warning message to display.
    """
    console.print(f"[bold yellow]Warning:[/bold yellow] {escape(str(message))}")
|
|
|
|
|
|
@click.group()
@click.option("--host", default=None, help="Ollama server URL")
@click.option("--model", default=None, help="Ollama model to use")
@click.option("--config", default=None, help="Path to config file")
@click.pass_context
def main(
    ctx: click.Context, host: Optional[str], model: Optional[str], config: Optional[str]
):
    """ShellGenius - AI-Powered Local Shell Script Assistant."""
    # Stash the group-level options on the context object so subcommands can
    # reach them through ctx.obj.
    # NOTE(review): "host" and "model" are stored here, but no command in this
    # module reads them back -- confirm whether these options take effect.
    shared = ctx.ensure_object(dict)
    shared.update(host=host, model=model, config=config)

    # Export the config path through the environment for code that resolves
    # it from there (presumably shellgenius.config -- verify).
    if config:
        os.environ["SHELLGENIUS_CONFIG"] = config
|
|
|
|
|
|
@main.command()
@click.argument("description", nargs=-1, type=str)
@click.option(
    "--shell",
    default="bash",
    type=click.Choice(["bash", "zsh", "sh"]),
    help="Target shell type",
)
@click.option(
    "--safety/--no-safety",
    default=True,
    help="Enable safety checks",
)
@click.option(
    "--dry-run",
    is_flag=True,
    default=False,
    # This command never executes anything; the flag only skips recording.
    help="Preview only; do not save the result to history",
)
@click.pass_context
def generate(
    ctx: click.Context,
    description: tuple,
    shell: str,
    safety: bool,
    dry_run: bool,
):
    """Generate shell commands from natural language."""
    # The description arrives as separate CLI words; rebuild the sentence and
    # reject input that is empty or whitespace-only.
    desc = " ".join(description).strip()

    if not desc:
        print_error("Please provide a description")
        return

    # Under the `main` group (or when invoked from `interactive`) there is
    # always a parent context, so the banner only shows for standalone use.
    if ctx.parent is None:
        print_header()

    # User- and model-supplied text is escaped so shell syntax containing
    # brackets (e.g. "${arr[@]}", "tr '[a-z]'") is printed verbatim instead of
    # being parsed -- and partly swallowed -- as Rich markup.
    console.print(f"[cyan]Generating {shell} commands for:[/cyan] {escape(desc)}")

    result = generate_shell(desc, shell_type=shell)

    console.print("\n[bold]Generated Commands:[/bold]")
    if not result.commands:
        print_warning("No commands were generated")
    for i, cmd in enumerate(result.commands, 1):
        console.print(f" {i}. {escape(str(cmd))}")

    if safety:
        checker = ShellSafetyChecker()
        safety_result = checker.check_script("\n".join(result.commands))

        if not safety_result["is_safe"]:
            console.print("\n[bold yellow]Safety Warnings:[/bold yellow]")
            for issue in safety_result["issues"]:
                console.print(
                    f" Line {issue['line']}: {escape(str(issue['command']))}"
                )
                for warning in issue.get("warnings", []):
                    console.print(f" - {escape(str(warning))}")

    # Record the result in history unless this is a dry run; empty results
    # are not worth remembering.
    if not dry_run and result.commands:
        learner = HistoryLearner()
        learner.learn(desc, result.commands, shell)

    console.print(f"\n[dim]{escape(str(result.explanation))}[/dim]")
|
|
|
|
|
|
@main.command()
# dir_okay=False turns "path is a directory" into a clean usage error instead
# of an IsADirectoryError traceback from open().
@click.argument("script_path", type=click.Path(exists=True, dir_okay=False))
@click.option(
    "--detailed/--basic",
    default=True,
    help="Use detailed AI-based explanation",
)
@click.pass_context
def explain(ctx: click.Context, script_path: str, detailed: bool):
    """Explain a shell script line by line."""
    # Decode as UTF-8 regardless of the locale; undecodable bytes become U+FFFD
    # instead of aborting the command with UnicodeDecodeError.
    with open(script_path, "r", encoding="utf-8", errors="replace") as f:
        script = f.read()

    console.print(f"[cyan]Explaining:[/cyan] {escape(script_path)}")

    result = explain_script(script, detailed=detailed)

    # Script text and model output are escaped: shell code routinely contains
    # bracket sequences such as "${arr[@]}" or "${arr[i]}" that Rich would
    # otherwise treat as markup tags and drop from the output.
    console.print(
        f"\n[bold]Script Analysis ({escape(str(result.shell_type))}):[/bold]"
    )
    console.print(f"Purpose: [cyan]{escape(str(result.overall_purpose))}[/cyan]")
    console.print(f"Summary: {escape(str(result.summary))}")

    console.print("\n[bold]Line-by-Line Explanation:[/bold]")
    table = Table(show_header=True)
    table.add_column("Line", style="dim", width=5)
    table.add_column("Content", style="cyan", max_width=50)
    table.add_column("Explanation", style="green")

    for exp in result.line_explanations:
        # Only rows flagged is_command are listed (presumably skipping
        # comments/blank lines -- verify against explain_script).
        if not exp.is_command:
            continue
        table.add_row(
            str(exp.line_number),
            # Truncate before escaping so an escape sequence is never cut.
            escape(exp.content[:50]),
            escape(str(exp.explanation)),
        )

    console.print(table)
|
|
|
|
|
|
@main.command()
# dir_okay=False turns "path is a directory" into a clean usage error instead
# of an IsADirectoryError traceback from open().
@click.argument("script_path", type=click.Path(exists=True, dir_okay=False))
@click.option(
    "--suggest/--no-suggest",
    default=True,
    help="Include AI suggestions",
)
@click.option(
    "--show-safe",
    is_flag=True,
    default=False,
    help="Show safer version of script",
)
@click.pass_context
def refactor(ctx: click.Context, script_path: str, suggest: bool, show_safe: bool):
    """Analyze and refactor shell script for security."""
    # Decode as UTF-8 regardless of the locale; undecodable bytes become U+FFFD
    # instead of aborting the command with UnicodeDecodeError.
    with open(script_path, "r", encoding="utf-8", errors="replace") as f:
        script = f.read()

    console.print(f"[cyan]Analyzing:[/cyan] {escape(script_path)}")

    result = refactor_script(script, include_suggestions=suggest)

    console.print(f"\n[bold]Security Score:[/bold] {result.score}/100")

    # Everything derived from the script or the model is escaped so shell
    # bracket syntax is displayed verbatim rather than parsed as Rich markup.
    if result.issues:
        console.print("\n[bold yellow]Issues Found:[/bold yellow]")
        for issue in result.issues:
            console.print(
                f" Line {issue.line_number}: [{issue.severity.upper()}] "
                f"{escape(str(issue.issue_type))}"
            )
            console.print(f" Original: {escape(str(issue.original))}")
            console.print(f" Risk: {escape(str(issue.risk_assessment))}")
            console.print(f" Alternative: {escape(str(issue.safer_alternative))}")
    else:
        console.print("\n[bold green]No issues found![/bold green]")

    if result.suggestions:
        console.print("\n[bold]Suggestions:[/bold]")
        # Show at most the first five suggestions.
        for suggestion in result.suggestions[:5]:
            console.print(f" - {escape(str(suggestion))}")

    if show_safe:
        console.print("\n[bold]Safer Version:[/bold]")
        # A plain string inside a Panel is also markup-parsed, so escape it.
        console.print(Panel(escape(str(result.safer_script)), style="green"))
|
|
|
|
|
|
@main.command()
# A limit below 1 is meaningless; reject it at parse time.
@click.option(
    "--limit",
    default=20,
    type=click.IntRange(min=1),
    help="Number of entries to show",
)
@click.option("--popular", is_flag=True, help="Show most used entries")
@click.option("--clear", is_flag=True, help="Clear history")
def history(limit: int, popular: bool, clear: bool):
    """Manage command history."""
    storage = get_history_storage()

    if clear:
        if click.confirm("Are you sure you want to clear all history?"):
            storage.clear()
            print_success("History cleared")
        return

    if popular:
        entries = storage.get_popular(limit)
        console.print("[bold]Most Used Commands:[/bold]")
    else:
        entries = storage.get_entries(limit=limit)
        console.print(f"[bold]Recent History (last {limit}):[/bold]")

    if not entries:
        console.print("[dim]No history yet[/dim]")
        return

    table = Table(show_header=True)
    table.add_column("Description", style="cyan", max_width=40)
    table.add_column("Commands", style="green", max_width=30)
    table.add_column("Shell", style="dim")
    table.add_column("Used", style="dim")

    for entry in entries:
        # Preview the first two commands within 30 characters. Truncate BEFORE
        # appending the ellipsis so the "..." marker is never cut off.
        cmd_preview = ", ".join(entry.commands[:2])
        if len(entry.commands) > 2 or len(cmd_preview) > 30:
            cmd_preview = cmd_preview[:27] + "..."
        # Stored descriptions/commands are escaped (after truncation) so
        # bracketed shell syntax is not parsed as Rich markup in table cells.
        table.add_row(
            escape(entry.description[:40]),
            escape(cmd_preview),
            escape(str(entry.shell_type)),
            str(entry.usage_count),
        )

    console.print(table)
|
|
|
|
|
|
@main.command()
def models():
    """List available Ollama models."""
    client = get_ollama_client()

    if not client.is_available():
        print_error("Ollama is not available. Make sure it's running.")
        return

    # Named `available` so the local does not shadow this command function.
    available = client.list_models()

    console.print("[bold]Available Models:[/bold]")

    if not available:
        console.print("[dim]No models found[/dim]")
        return

    # Mark the model currently selected in the configuration.
    # NOTE(review): exact string match -- if the server reports tagged names
    # (e.g. "name:latest") while the config holds the bare name, no marker
    # is shown; confirm the formats agree.
    current = get_config().ollama_model

    for name in available:
        marker = "[*]" if name == current else "[ ]"
        console.print(f" {marker} {escape(str(name))}")
|
|
|
|
|
|
@main.command()
@click.pass_context
def interactive(ctx: click.Context):
    """Start interactive mode."""
    print_header()

    # prompt_toolkit does not understand Rich markup (it would print the
    # literal "[bold cyan]" text), so prompts use prompt_toolkit's own HTML
    # formatting. The completer is built once rather than on every loop.
    main_prompt = HTML("<ansicyan><b>ShellGenius</b></ansicyan> > ")
    completer = WordCompleter(
        ["g", "e", "r", "h", "m", "q", "?", "help", "quit", "exit"]
    )

    def ask_script_path():
        """Prompt for a script path; return it if it is an existing file, else None."""
        raw = session.prompt(HTML("<ansicyan>Path to script:</ansicyan> ")).strip()
        if not raw:
            return None
        path = os.path.expanduser(raw)
        # isfile (not exists) so directories are rejected up front instead of
        # failing later inside open().
        if not os.path.isfile(path):
            print_error("Script not found or not a regular file")
            return None
        return path

    while True:
        try:
            # An empty line (or just whitespace) shows the help panel.
            choice = (
                session.prompt(main_prompt, completer=completer).strip().lower()
                or "?"
            )

            if choice in ["q", "quit", "exit"]:
                console.print("[cyan]Goodbye![/cyan]")
                break
            # The welcome banner tells users to type 'help', so accept it.
            elif choice in ["?", "help"]:
                console.print(
                    Panel(
                        "[bold]Commands:[/bold]\n"
                        " g - Generate commands\n"
                        " e - Explain script\n"
                        " r - Refactor/Analyze script\n"
                        " h - View history\n"
                        " m - List models\n"
                        " q - Quit\n"
                        " ? - Show this help"
                    )
                )
            elif choice == "g":
                desc = session.prompt(
                    HTML("<ansicyan>Describe what you want:</ansicyan> ")
                ).strip()
                if desc:
                    ctx.invoke(
                        generate,
                        description=tuple(desc.split()),
                        shell="bash",
                    )
            elif choice == "e":
                path = ask_script_path()
                if path:
                    ctx.invoke(explain, script_path=path)
            elif choice == "r":
                path = ask_script_path()
                if path:
                    ctx.invoke(refactor, script_path=path, show_safe=True)
            elif choice == "h":
                ctx.invoke(history, limit=10)
            elif choice == "m":
                ctx.invoke(models)
            else:
                # The input is deliberately not echoed: it could contain
                # bracket text that Rich would parse as markup.
                print_warning("Unknown command. Type '?' for help.")

        except KeyboardInterrupt:
            console.print("\n[cyan]Use 'q' to quit[/cyan]")
        except EOFError:
            # Ctrl-D or closed stdin. EOFError is an Exception subclass, so
            # without this branch the generic handler below would catch it and
            # loop forever, printing the same error.
            console.print("\n[cyan]Goodbye![/cyan]")
            break
        except Exception as e:
            print_error(str(e))
|
|
|
|
|
|
@main.command()
def version():
    """Show version information."""
    # Imported inside the command rather than at module top (presumably to
    # avoid an import cycle with the package __init__ -- verify).
    from shellgenius import __version__ as installed_version

    version_line = f"ShellGenius v{installed_version}"
    console.print(version_line)
|
|
|
|
|
|
@main.command()
def check():
    """Check system requirements."""
    console.print("[bold]System Check:[/bold]\n")

    config_ok = True
    ollama_ok = False

    # Exception text and paths are escaped so bracketed content in them is not
    # parsed as Rich markup (which could hide text or raise MarkupError).
    try:
        config = get_config()
        console.print(
            f"[green]✓[/green] Config loaded from: {escape(str(config.config_path))}"
        )
    except Exception as e:
        console.print(f"[red]✗[/red] Config error: {escape(str(e))}")
        config_ok = False

    try:
        client = get_ollama_client()
        if client.is_available():
            console.print(
                f"[green]✓[/green] Ollama connected: {escape(str(client.host))}"
            )
            console.print(
                f"[green]✓[/green] Current model: {escape(str(client.model))}"
            )
            # Named `available` so the local does not shadow the `models` command.
            available = client.list_models()
            console.print(f"[green]✓[/green] Available models: {len(available)}")
            # Mark Ollama healthy only after every probe above succeeded, so a
            # failing list_models() no longer ends in "System ready!".
            ollama_ok = True
        else:
            console.print("[red]✗[/red] Ollama not reachable")
    except Exception as e:
        console.print(f"[red]✗[/red] Ollama error: {escape(str(e))}")

    # History storage problems are reported as a warning only; they do not
    # affect the overall verdict below.
    try:
        storage = get_history_storage()
        console.print(
            f"[green]✓[/green] History storage: {escape(str(storage.storage_path))}"
        )
    except Exception as e:
        # Closing tag fixed: it was "[yellow]" instead of "[/yellow]".
        console.print(f"[yellow]![/yellow] History storage warning: {escape(str(e))}")

    if config_ok and ollama_ok:
        console.print("\n[bold green]System ready![/bold green]")
    else:
        console.print("\n[bold yellow]Some issues detected[/bold yellow]")
|