Add shell generation engine
This commit is contained in:
347
shellgenius/generation.py
Normal file
347
shellgenius/generation.py
Normal file
@@ -0,0 +1,347 @@
|
||||
"""Shell generation engine for ShellGenius."""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from shellgenius.config import get_config
|
||||
from shellgenius.ollama_client import get_ollama_client
|
||||
|
||||
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class GeneratedScript:
    """Result record for a single shell-generation request.

    The fields are declared in the order accepted by the generated
    ``__init__``; do not reorder them.
    """

    # Individual command lines making up the generated script.
    commands: List[str]
    # Human-readable summary (or a failure message) for the request.
    explanation: str
    # Name of the shell the script targets, e.g. "bash".
    shell_type: str
    # Text of the model response the commands were derived from.
    raw_response: str
|
||||
|
||||
|
||||
class PromptTemplates:
    """Prompt text sent to the model, plus helpers that fill it in.

    Every template is a ``str.format`` pattern. ``{shell}`` is replaced by
    the shell name; ``{description}`` / ``{script}`` are replaced by the
    caller-supplied text. Only the template itself is parsed for
    placeholders, so braces inside the substituted text are left alone.
    """

    BASH_GENERATE = """You are a shell script expert. Generate a {shell} script for the following task:

Task: {description}

Requirements:
- Use safe, best practices
- Add comments explaining each command
- Use proper error handling with set -e or similar
- Return ONLY the script content, no markdown code blocks

Script:"""

    BASH_EXPLAIN = """You are a shell script expert. Explain this {shell} script line by line:

{script}

Provide explanations in this format:
Line <num>: <command> - <explanation>

Be concise and informative:"""

    BASH_REFACTOR = """You are a shell script security expert. Analyze this {shell} script for security issues and suggest safer alternatives:

{script}

For each issue found:
1. Identify the risky command or pattern
2. Explain the risk
3. Provide a safer alternative
4. Suggest improvement

Format your response:"""

    # Shell names listed by this module; the helpers below do not validate
    # ``shell_type`` against this list.
    SHELL_TYPES = ["bash", "zsh", "sh"]

    @classmethod
    def get_generate_prompt(
        cls, description: str, shell_type: str = "bash"
    ) -> str:
        """Return the script-generation prompt for ``description``.

        Args:
            description: Natural-language description of the task.
            shell_type: Shell name substituted into the prompt.

        Returns:
            The filled-in ``BASH_GENERATE`` template.
        """
        template = cls.BASH_GENERATE
        return template.format(shell=shell_type, description=description)

    @classmethod
    def get_explain_prompt(cls, script: str, shell_type: str = "bash") -> str:
        """Return the line-by-line explanation prompt for ``script``.

        Args:
            script: Shell script text to embed in the prompt.
            shell_type: Shell name substituted into the prompt.

        Returns:
            The filled-in ``BASH_EXPLAIN`` template.
        """
        template = cls.BASH_EXPLAIN
        return template.format(shell=shell_type, script=script)

    @classmethod
    def get_refactor_prompt(cls, script: str, shell_type: str = "bash") -> str:
        """Return the security-review prompt for ``script``.

        Args:
            script: Shell script text to embed in the prompt.
            shell_type: Shell name substituted into the prompt.

        Returns:
            The filled-in ``BASH_REFACTOR`` template.
        """
        template = cls.BASH_REFACTOR
        return template.format(shell=shell_type, script=script)
|
||||
|
||||
|
||||
class ShellParser:
    """Lightweight, line-oriented helpers for inspecting shell scripts."""

    # Splits a line into (leading whitespace, content, trailing whitespace).
    LINE_PATTERN = re.compile(r"^(?:(\s*)(.*?)(\s*))$", re.MULTILINE)
    # Matches a line whose first non-blank character is '#'.
    COMMENT_PATTERN = re.compile(r"^\s*#")
    # Captures everything after '#!' on a line.
    SHEBANG_PATTERN = re.compile(r"^#!(.+)$", re.MULTILINE)

    @classmethod
    def detect_shell(cls, script: str) -> str:
        """Detect shell type from shebang.

        Only a shebang at the very start of ``script`` is considered. The
        decision is based on the interpreter's file name alone, so directory
        names and interpreter arguments cannot cause a false match (for
        example ``#!/home/josh/bin/python3`` is not reported as ``sh``).
        When the interpreter is ``env``, the program that ``env`` launches
        is inspected instead.

        Args:
            script: Shell script content

        Returns:
            "zsh", "bash" or "sh"; "bash" when there is no shebang or the
            interpreter is not recognised.
        """
        # .match() anchors at position 0, i.e. the first line only.
        shebang_match = cls.SHEBANG_PATTERN.match(script)
        if not shebang_match:
            return "bash"

        tokens = shebang_match.group(1).split()
        if not tokens:
            return "bash"

        interpreter = tokens[0].rsplit("/", 1)[-1]
        if interpreter == "env":
            # Skip env's own options (-S, -i, ...) and VAR=value assignments
            # to find the program it will execute.
            interpreter = next(
                (
                    token.rsplit("/", 1)[-1]
                    for token in tokens[1:]
                    if not token.startswith("-") and "=" not in token
                ),
                "",
            )

        # Most specific names first: "zsh" and "bash" both contain "sh".
        for shell in ("zsh", "bash", "sh"):
            if shell in interpreter:
                return shell
        return "bash"

    @classmethod
    def parse_lines(cls, script: str) -> List[Tuple[int, str]]:
        """Parse script into lines.

        The script is split on "\\n" only, so a trailing newline yields a
        final empty line and an empty script yields ``[(1, "")]``.

        Args:
            script: Shell script content

        Returns:
            List of (line_number, line_content) tuples, numbered from 1
        """
        return list(enumerate(script.split("\n"), start=1))

    @classmethod
    def extract_commands(cls, script: str) -> List[str]:
        """Extract executable commands from script.

        Blank lines, lines starting with '#' (comments and the shebang) and
        lines starting with ':' are skipped. Remaining lines are returned
        with surrounding whitespace removed.

        Args:
            script: Shell script content

        Returns:
            List of commands
        """
        commands = []
        for _, line in cls.parse_lines(script):
            stripped = line.strip()
            if stripped and not stripped.startswith(("#", ":")):
                commands.append(stripped)
        return commands
|
||||
|
||||
|
||||
class ShellGenerator:
    """Shell generation engine.

    Builds a prompt from a task description, sends it to the Ollama client
    and turns the reply into a ``GeneratedScript``.
    """

    def __init__(self):
        """Initialize generator."""
        self.client = get_ollama_client()
        self.parser = ShellParser()
        # The class itself is stored; its helpers are classmethods.
        self.templates = PromptTemplates

    def generate(
        self,
        description: str,
        shell_type: str = "bash",
        model: Optional[str] = None,
    ) -> "GeneratedScript":
        """Generate shell script from description.

        The client result is expected to be a mapping with a "success"
        flag. On success, ``result["response"]["response"]`` holds the
        generated text; on failure, ``result["error"]`` (if present)
        describes the problem.

        Args:
            description: Natural language description
            shell_type: Target shell type
            model: Ollama model to use

        Returns:
            GeneratedScript with commands and explanation. On failure the
            command list and raw response are empty and the explanation
            carries the error message.
        """
        prompt = self.templates.get_generate_prompt(description, shell_type)
        result = self.client.generate(prompt, model=model)

        if not result["success"]:
            return GeneratedScript(
                commands=[],
                explanation=f"Generation failed: {result.get('error', 'Unknown error')}",
                shell_type=shell_type,
                raw_response="",
            )

        raw = result["response"].get("response", "")
        return self._parse_generation(raw, shell_type)

    def _parse_generation(self, raw: str, shell_type: str) -> "GeneratedScript":
        """Parse raw LLM response.

        Args:
            raw: Raw response text
            shell_type: Shell type

        Returns:
            GeneratedScript holding the extracted commands, a summary and
            the untouched raw text
        """
        commands = self._extract_commands(raw)
        return GeneratedScript(
            commands=commands,
            explanation=self._generate_summary(commands, raw),
            shell_type=shell_type,
            raw_response=raw,
        )

    def _extract_commands(self, raw: str) -> List[str]:
        """Extract commands from raw response.

        Blank lines and lines starting with '#' (comments, shebang) are
        dropped. Markdown code-fence markers (lines starting with three
        backticks) are dropped as well: the prompt asks the model not to
        emit them, but a model may still wrap the script in a fenced block,
        and the markers are not shell commands.

        Args:
            raw: Raw LLM response

        Returns:
            List of commands with surrounding whitespace removed
        """
        skipped_prefixes = ("#", "```")
        commands = []
        for line in raw.strip().split("\n"):
            stripped = line.strip()
            if stripped and not stripped.startswith(skipped_prefixes):
                commands.append(stripped)
        return commands

    def _generate_summary(
        self, commands: List[str], raw: str
    ) -> str:
        """Generate explanation summary.

        Args:
            commands: List of commands
            raw: Raw response (currently unused)

        Returns:
            Explanation summary stating how many commands were generated
        """
        return f"Generated {len(commands)} command(s) for your request."
|
||||
|
||||
|
||||
class ShellSafetyChecker:
    """Safety checker for shell commands.

    Commands are matched against two regex tables. How a match affects the
    verdict depends on ``safety_level``:

    - "strict": any dangerous or warning match makes the command unsafe.
    - "moderate": only dangerous matches make the command unsafe.
    - any other value (e.g. "permissive"): matches are reported but the
      command is still considered safe.
    """

    # (regex, message) pairs; a match is reported with a "DANGEROUS: " prefix.
    DANGEROUS_PATTERNS = [
        # NOTE: matches any "rm -rf" whose target starts with "/", not only
        # the filesystem root.
        (r"rm\s+-rf\s+/", "Removes all files recursively - catastrophic data loss"),
        # Canonical form is ":(){ :|:& };:" - whitespace around each token
        # and a ';' before the closing brace are optional.
        (r":\(\)\s*\{\s*:\s*\|\s*:\s*&\s*;?\s*\}", "Fork bomb - can crash system"),
        (r"chmod\s+777", "Grants all permissions to everyone - security risk"),
        (r"sudo\s+su", "Escalates to root without proper controls"),
        (r"dd\s+if=/dev/zero", "Can overwrite disks if misconfigured"),
        (r">\s*/dev/sda", "Direct disk write - destructive"),
        (r"mv\s+.*\s+/dev/null", "Sends data to null device - data loss"),
        (r"cat\s+.*\|\s*sh", "Executes arbitrary shell code"),
        (r"curl.*\|\s*sh", "Downloads and executes code - security risk"),
        (r"wget.*\|\s*sh", "Downloads and executes code - security risk"),
    ]

    # (regex, message) pairs; a match is reported with a "WARNING: " prefix.
    WARNING_PATTERNS = [
        (r"rm\s+-r", "Recursive removal"),
        (r"rm\s+[^-]", "File removal"),
        (r"chmod\s+[0-7][0-7][0-7]", "Permission change"),
        (r"sudo", "Requires sudo privileges"),
        (r">\s*\S+", "File redirection"),
        (r"\|\s*bash", "Pipe to shell"),
    ]

    def __init__(self, safety_level: str = "moderate"):
        """Initialize safety checker.

        Args:
            safety_level: strict, moderate, or permissive
        """
        self.safety_level = safety_level
        self.config = get_config()

    def check_command(self, command: str) -> Tuple[bool, List[str]]:
        """Check if command is dangerous.

        Args:
            command: Shell command to check

        Returns:
            Tuple of (is_safe, list of warnings). Dangerous matches come
            first, then warning matches, each in table order.
        """
        warnings = []
        is_safe = True

        # Dangerous matches block the command at "strict" and "moderate".
        blocks_on_danger = self.safety_level in ["strict", "moderate"]
        for pattern, message in self.DANGEROUS_PATTERNS:
            if re.search(pattern, command):
                if blocks_on_danger:
                    is_safe = False
                warnings.append(f"DANGEROUS: {message}")

        # Warning matches block the command only at "strict".
        blocks_on_warning = self.safety_level == "strict"
        for pattern, message in self.WARNING_PATTERNS:
            if re.search(pattern, command):
                if blocks_on_warning:
                    is_safe = False
                warnings.append(f"WARNING: {message}")

        return is_safe, warnings

    def check_script(self, script: str) -> Dict[str, Any]:
        """Check entire script for safety issues.

        Blank lines and lines starting with '#' are ignored. Every other
        line is stripped and passed to ``check_command``.

        Args:
            script: Shell script content

        Returns:
            Dictionary with safety assessment:
                "is_safe": True when no line was judged unsafe.
                "issues": one dict per unsafe line with "line" (1-based),
                    "command" and "warnings".
                "warnings": messages raised by lines that were still
                    judged safe.
                "safe_lines": 1-based numbers of the lines judged safe.
                "total_lines": number of lines that were checked.
        """
        lines = script.split("\n")
        issues = []
        warnings = []
        safe_lines = []

        for i, line in enumerate(lines, 1):
            stripped = line.strip()
            if not stripped or stripped.startswith("#"):
                continue
            is_safe, line_warnings = self.check_command(stripped)
            if is_safe:
                safe_lines.append(i)
                warnings.extend(line_warnings)
            else:
                issues.append({"line": i, "command": stripped, "warnings": line_warnings})

        return {
            "is_safe": len(issues) == 0,
            "issues": issues,
            "warnings": warnings,
            "safe_lines": safe_lines,
            # Every checked line lands in exactly one of the two lists.
            "total_lines": len(issues) + len(safe_lines),
        }
|
||||
|
||||
|
||||
def generate_shell(
    description: str,
    shell_type: str = "bash",
    safety_level: Optional[str] = None,
) -> GeneratedScript:
    """Generate a shell script with a freshly constructed ``ShellGenerator``.

    When ``safety_level`` is truthy it is written to
    ``config["safety"]["level"]`` on the object returned by ``get_config()``
    before generation runs; this function does not restore the previous
    value afterwards.

    Args:
        description: Natural language description
        shell_type: Target shell type
        safety_level: Safety level override

    Returns:
        GeneratedScript
    """
    engine = ShellGenerator()
    if safety_level:
        settings = get_config().config
        settings["safety"]["level"] = safety_level
    return engine.generate(description, shell_type)
|
||||
Reference in New Issue
Block a user