Add shell generation engine
This commit is contained in:
347
shellgenius/generation.py
Normal file
347
shellgenius/generation.py
Normal file
@@ -0,0 +1,347 @@
|
|||||||
|
"""Shell generation engine for ShellGenius."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
from shellgenius.config import get_config
|
||||||
|
from shellgenius.ollama_client import get_ollama_client
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class GeneratedScript:
|
||||||
|
"""Container for generated shell script."""
|
||||||
|
|
||||||
|
commands: List[str]
|
||||||
|
explanation: str
|
||||||
|
shell_type: str
|
||||||
|
raw_response: str
|
||||||
|
|
||||||
|
|
||||||
|
class PromptTemplates:
|
||||||
|
"""Prompt templates for shell generation."""
|
||||||
|
|
||||||
|
BASH_GENERATE = """You are a shell script expert. Generate a {shell} script for the following task:
|
||||||
|
|
||||||
|
Task: {description}
|
||||||
|
|
||||||
|
Requirements:
|
||||||
|
- Use safe, best practices
|
||||||
|
- Add comments explaining each command
|
||||||
|
- Use proper error handling with set -e or similar
|
||||||
|
- Return ONLY the script content, no markdown code blocks
|
||||||
|
|
||||||
|
Script:"""
|
||||||
|
|
||||||
|
BASH_EXPLAIN = """You are a shell script expert. Explain this {shell} script line by line:
|
||||||
|
|
||||||
|
{script}
|
||||||
|
|
||||||
|
Provide explanations in this format:
|
||||||
|
Line <num>: <command> - <explanation>
|
||||||
|
|
||||||
|
Be concise and informative:"""
|
||||||
|
|
||||||
|
BASH_REFACTOR = """You are a shell script security expert. Analyze this {shell} script for security issues and suggest safer alternatives:
|
||||||
|
|
||||||
|
{script}
|
||||||
|
|
||||||
|
For each issue found:
|
||||||
|
1. Identify the risky command or pattern
|
||||||
|
2. Explain the risk
|
||||||
|
3. Provide a safer alternative
|
||||||
|
4. Suggest improvement
|
||||||
|
|
||||||
|
Format your response:"""
|
||||||
|
|
||||||
|
SHELL_TYPES = ["bash", "zsh", "sh"]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_generate_prompt(
|
||||||
|
cls, description: str, shell_type: str = "bash"
|
||||||
|
) -> str:
|
||||||
|
"""Get generation prompt for shell type."""
|
||||||
|
return cls.BASH_GENERATE.format(
|
||||||
|
description=description, shell=shell_type
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_explain_prompt(cls, script: str, shell_type: str = "bash") -> str:
|
||||||
|
"""Get explanation prompt for script."""
|
||||||
|
return cls.BASH_EXPLAIN.format(script=script, shell=shell_type)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_refactor_prompt(cls, script: str, shell_type: str = "bash") -> str:
|
||||||
|
"""Get refactor prompt for script."""
|
||||||
|
return cls.BASH_REFACTOR.format(script=script, shell=shell_type)
|
||||||
|
|
||||||
|
|
||||||
|
class ShellParser:
|
||||||
|
"""Parser for shell scripts."""
|
||||||
|
|
||||||
|
LINE_PATTERN = re.compile(r"^(?:(\s*)(.*?)(\s*))$", re.MULTILINE)
|
||||||
|
COMMENT_PATTERN = re.compile(r"^\s*#")
|
||||||
|
SHEBANG_PATTERN = re.compile(r"^#!(.+)$", re.MULTILINE)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def detect_shell(cls, script: str) -> str:
|
||||||
|
"""Detect shell type from shebang.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
script: Shell script content
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Detected shell type
|
||||||
|
"""
|
||||||
|
shebang_match = cls.SHEBANG_PATTERN.match(script)
|
||||||
|
if shebang_match:
|
||||||
|
shebang = shebang_match.group(1)
|
||||||
|
if "zsh" in shebang:
|
||||||
|
return "zsh"
|
||||||
|
elif "bash" in shebang:
|
||||||
|
return "bash"
|
||||||
|
elif "sh" in shebang:
|
||||||
|
return "sh"
|
||||||
|
return "bash"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def parse_lines(cls, script: str) -> List[Tuple[int, str]]:
|
||||||
|
"""Parse script into lines.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
script: Shell script content
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of (line_number, line_content) tuples
|
||||||
|
"""
|
||||||
|
lines = script.split("\n")
|
||||||
|
return [(i + 1, line) for i, line in enumerate(lines)]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def extract_commands(cls, script: str) -> List[str]:
|
||||||
|
"""Extract executable commands from script.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
script: Shell script content
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of commands
|
||||||
|
"""
|
||||||
|
commands = []
|
||||||
|
for _, line in cls.parse_lines(script):
|
||||||
|
stripped = line.strip()
|
||||||
|
if stripped and not stripped.startswith("#") and not stripped.startswith(":"):
|
||||||
|
commands.append(stripped)
|
||||||
|
return commands
|
||||||
|
|
||||||
|
|
||||||
|
class ShellGenerator:
|
||||||
|
"""Shell generation engine."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
"""Initialize generator."""
|
||||||
|
self.client = get_ollama_client()
|
||||||
|
self.parser = ShellParser()
|
||||||
|
self.templates = PromptTemplates
|
||||||
|
|
||||||
|
def generate(
|
||||||
|
self,
|
||||||
|
description: str,
|
||||||
|
shell_type: str = "bash",
|
||||||
|
model: Optional[str] = None,
|
||||||
|
) -> GeneratedScript:
|
||||||
|
"""Generate shell script from description.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
description: Natural language description
|
||||||
|
shell_type: Target shell type
|
||||||
|
model: Ollama model to use
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
GeneratedScript with commands and explanation
|
||||||
|
"""
|
||||||
|
prompt = self.templates.get_generate_prompt(description, shell_type)
|
||||||
|
|
||||||
|
result = self.client.generate(prompt, model=model)
|
||||||
|
|
||||||
|
if result["success"]:
|
||||||
|
raw = result["response"].get("response", "")
|
||||||
|
return self._parse_generation(raw, shell_type)
|
||||||
|
else:
|
||||||
|
return GeneratedScript(
|
||||||
|
commands=[],
|
||||||
|
explanation=f"Generation failed: {result.get('error', 'Unknown error')}",
|
||||||
|
shell_type=shell_type,
|
||||||
|
raw_response="",
|
||||||
|
)
|
||||||
|
|
||||||
|
def _parse_generation(self, raw: str, shell_type: str) -> GeneratedScript:
|
||||||
|
"""Parse raw LLM response.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
raw: Raw response text
|
||||||
|
shell_type: Shell type
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
GeneratedScript
|
||||||
|
"""
|
||||||
|
commands = self._extract_commands(raw)
|
||||||
|
explanation = self._generate_summary(commands, raw)
|
||||||
|
return GeneratedScript(
|
||||||
|
commands=commands,
|
||||||
|
explanation=explanation,
|
||||||
|
shell_type=shell_type,
|
||||||
|
raw_response=raw,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _extract_commands(self, raw: str) -> List[str]:
|
||||||
|
"""Extract commands from raw response.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
raw: Raw LLM response
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of commands
|
||||||
|
"""
|
||||||
|
lines = raw.strip().split("\n")
|
||||||
|
commands = []
|
||||||
|
for line in lines:
|
||||||
|
line = line.strip()
|
||||||
|
if line and not line.startswith("#"):
|
||||||
|
commands.append(line)
|
||||||
|
return commands
|
||||||
|
|
||||||
|
def _generate_summary(
|
||||||
|
self, commands: List[str], raw: str
|
||||||
|
) -> str:
|
||||||
|
"""Generate explanation summary.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
commands: List of commands
|
||||||
|
raw: Raw response
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Explanation summary
|
||||||
|
"""
|
||||||
|
return f"Generated {len(commands)} command(s) for your request."
|
||||||
|
|
||||||
|
|
||||||
|
class ShellSafetyChecker:
|
||||||
|
"""Safety checker for shell commands."""
|
||||||
|
|
||||||
|
DANGEROUS_PATTERNS = [
|
||||||
|
(r"rm\s+-rf\s+/", "Removes all files recursively - catastrophic data loss"),
|
||||||
|
(r":\(\)\s*\{\s*:\s*\|\s*:\s&\s*;\s*\}", "Fork bomb - can crash system"),
|
||||||
|
(r"chmod\s+777", "Removes all file permissions - security risk"),
|
||||||
|
(r"sudo\s+su", "Escalates to root without proper controls"),
|
||||||
|
(r"dd\s+if=/dev/zero", "Can overwrite disks if misconfigured"),
|
||||||
|
(r">\s*/dev/sda", "Direct disk write - destructive"),
|
||||||
|
(r"mv\s+.*\s+/dev/null", "Sends data to null device - data loss"),
|
||||||
|
(r"cat\s+.*\|\s*sh", "Executes arbitrary shell code"),
|
||||||
|
(r"curl.*\|\s*sh", "Downloads and executes code - security risk"),
|
||||||
|
(r"wget.*\|\s*sh", "Downloads and executes code - security risk"),
|
||||||
|
]
|
||||||
|
|
||||||
|
WARNING_PATTERNS = [
|
||||||
|
(r"rm\s+-r", "Recursive removal"),
|
||||||
|
(r"rm\s+[^-]", "File removal"),
|
||||||
|
(r"chmod\s+[0-7][0-7][0-7]", "Permission change"),
|
||||||
|
(r"sudo", "Requires sudo privileges"),
|
||||||
|
(r">\s*\S+", "File redirection"),
|
||||||
|
(r"\|\s*bash", "Pipe to shell"),
|
||||||
|
]
|
||||||
|
|
||||||
|
def __init__(self, safety_level: str = "moderate"):
|
||||||
|
"""Initialize safety checker.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
safety_level: strict, moderate, or permissive
|
||||||
|
"""
|
||||||
|
self.safety_level = safety_level
|
||||||
|
self.config = get_config()
|
||||||
|
|
||||||
|
def check_command(self, command: str) -> Tuple[bool, List[str]]:
|
||||||
|
"""Check if command is dangerous.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
command: Shell command to check
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (is_safe, list of warnings)
|
||||||
|
"""
|
||||||
|
warnings = []
|
||||||
|
is_safe = True
|
||||||
|
|
||||||
|
for pattern, message in self.DANGEROUS_PATTERNS:
|
||||||
|
if re.search(pattern, command):
|
||||||
|
if self.safety_level in ["strict", "moderate"]:
|
||||||
|
is_safe = False
|
||||||
|
warnings.append(f"DANGEROUS: {message}")
|
||||||
|
|
||||||
|
for pattern, message in self.WARNING_PATTERNS:
|
||||||
|
if re.search(pattern, command):
|
||||||
|
if self.safety_level == "strict":
|
||||||
|
is_safe = False
|
||||||
|
warnings.append(f"WARNING: {message}")
|
||||||
|
|
||||||
|
if self.safety_level == "strict" and not warnings:
|
||||||
|
is_safe = True
|
||||||
|
|
||||||
|
return is_safe, warnings
|
||||||
|
|
||||||
|
def check_script(self, script: str) -> Dict[str, Any]:
|
||||||
|
"""Check entire script for safety issues.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
script: Shell script content
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with safety assessment
|
||||||
|
"""
|
||||||
|
lines = script.split("\n")
|
||||||
|
issues = []
|
||||||
|
warnings = []
|
||||||
|
safe_lines = []
|
||||||
|
|
||||||
|
for i, line in enumerate(lines, 1):
|
||||||
|
stripped = line.strip()
|
||||||
|
if stripped and not stripped.startswith("#"):
|
||||||
|
is_safe, line_warnings = self.check_command(stripped)
|
||||||
|
if not is_safe:
|
||||||
|
issues.append({"line": i, "command": stripped, "warnings": line_warnings})
|
||||||
|
else:
|
||||||
|
safe_lines.append(i)
|
||||||
|
warnings.extend(line_warnings)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"is_safe": len(issues) == 0,
|
||||||
|
"issues": issues,
|
||||||
|
"warnings": warnings,
|
||||||
|
"safe_lines": safe_lines,
|
||||||
|
"total_lines": len([l for l in lines if l.strip() and not l.strip().startswith("#")]),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def generate_shell(
|
||||||
|
description: str,
|
||||||
|
shell_type: str = "bash",
|
||||||
|
safety_level: Optional[str] = None,
|
||||||
|
) -> GeneratedScript:
|
||||||
|
"""Convenience function to generate shell script.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
description: Natural language description
|
||||||
|
shell_type: Target shell type
|
||||||
|
safety_level: Safety level override
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
GeneratedScript
|
||||||
|
"""
|
||||||
|
generator = ShellGenerator()
|
||||||
|
if safety_level:
|
||||||
|
config = get_config()
|
||||||
|
config.config["safety"]["level"] = safety_level
|
||||||
|
return generator.generate(description, shell_type)
|
||||||
Reference in New Issue
Block a user