Add shell generation engine
Some checks failed
CI / test (push) Failing after 9s
CI / lint (push) Failing after 5s
CI / type-check (push) Failing after 13s

This commit is contained in:
2026-02-04 10:59:36 +00:00
parent bf0bf36e66
commit 9c2fb72bb9

347
shellgenius/generation.py Normal file
View File

@@ -0,0 +1,347 @@
"""Shell generation engine for ShellGenius."""
import logging
import re
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple
from shellgenius.config import get_config
from shellgenius.ollama_client import get_ollama_client
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
@dataclass
class GeneratedScript:
    """Result of a single shell-script generation request.

    Attributes:
        commands: Command lines extracted from the model output.
        explanation: Human-readable summary (or failure message).
        shell_type: Shell the script was requested for.
        raw_response: Raw text the commands were extracted from.
    """

    commands: List[str]
    explanation: str
    shell_type: str
    raw_response: str
class PromptTemplates:
    """Prompt texts sent to the model, plus helpers that fill them in.

    Each template is a ``str.format`` string. ``{shell}`` is replaced by the
    shell name and ``{description}`` / ``{script}`` by the caller's text.
    Braces inside the substituted text are inserted verbatim.
    """

    # Template asking the model to write a script for a task description.
    BASH_GENERATE = """You are a shell script expert. Generate a {shell} script for the following task:
Task: {description}
Requirements:
- Use safe, best practices
- Add comments explaining each command
- Use proper error handling with set -e or similar
- Return ONLY the script content, no markdown code blocks
Script:"""

    # Template asking the model for a line-by-line explanation of a script.
    BASH_EXPLAIN = """You are a shell script expert. Explain this {shell} script line by line:
{script}
Provide explanations in this format:
Line <num>: <command> - <explanation>
Be concise and informative:"""

    # Template asking the model for a security review of a script.
    BASH_REFACTOR = """You are a shell script security expert. Analyze this {shell} script for security issues and suggest safer alternatives:
{script}
For each issue found:
1. Identify the risky command or pattern
2. Explain the risk
3. Provide a safer alternative
4. Suggest improvement
Format your response:"""

    # Shell names listed by this module; not validated by the helpers below.
    SHELL_TYPES = ["bash", "zsh", "sh"]

    @classmethod
    def get_generate_prompt(
        cls, description: str, shell_type: str = "bash"
    ) -> str:
        """Build the script-generation prompt.

        Args:
            description: Natural language description of the task
            shell_type: Shell name substituted into the prompt

        Returns:
            The filled-in BASH_GENERATE template
        """
        fields = {"shell": shell_type, "description": description}
        return cls.BASH_GENERATE.format(**fields)

    @classmethod
    def get_explain_prompt(cls, script: str, shell_type: str = "bash") -> str:
        """Build the line-by-line explanation prompt.

        Args:
            script: Shell script content to explain
            shell_type: Shell name substituted into the prompt

        Returns:
            The filled-in BASH_EXPLAIN template
        """
        fields = {"shell": shell_type, "script": script}
        return cls.BASH_EXPLAIN.format(**fields)

    @classmethod
    def get_refactor_prompt(cls, script: str, shell_type: str = "bash") -> str:
        """Build the security-review prompt.

        Args:
            script: Shell script content to analyze
            shell_type: Shell name substituted into the prompt

        Returns:
            The filled-in BASH_REFACTOR template
        """
        fields = {"shell": shell_type, "script": script}
        return cls.BASH_REFACTOR.format(**fields)
class ShellParser:
    """Lightweight, line-oriented helpers for inspecting shell scripts."""

    # LINE_PATTERN and COMMENT_PATTERN are not used by the methods below;
    # they are kept because they are part of the class's public attributes.
    LINE_PATTERN = re.compile(r"^(?:(\s*)(.*?)(\s*))$", re.MULTILINE)
    COMMENT_PATTERN = re.compile(r"^\s*#")
    SHEBANG_PATTERN = re.compile(r"^#!(.+)$", re.MULTILINE)

    @staticmethod
    def _interpreter_name(shebang: str) -> str:
        """Return the basename of the program a shebang line runs.

        Args:
            shebang: Text following ``#!`` on the first line

        Returns:
            Interpreter basename, or "" when none can be determined
        """
        tokens = shebang.split()
        if not tokens:
            return ""
        name = tokens[0].rsplit("/", 1)[-1]
        if name != "env":
            return name
        # "#!/usr/bin/env [-S] [VAR=value ...] program": the real interpreter
        # is the first token that is neither an option nor an assignment.
        for token in tokens[1:]:
            if token.startswith("-") or "=" in token:
                continue
            return token.rsplit("/", 1)[-1]
        return ""

    @classmethod
    def detect_shell(cls, script: str) -> str:
        """Detect shell type from shebang.

        Only a shebang at the very start of the script is considered, because
        ``Pattern.match`` anchors at position 0 even with ``re.MULTILINE``.
        The shell is decided from the interpreter's basename rather than the
        whole line, so directory names (``/usr/share/...``) and arguments
        cannot be mistaken for a shell name.

        Args:
            script: Shell script content

        Returns:
            "zsh", "bash" or "sh"; "bash" when there is no shebang or the
            interpreter is not recognised
        """
        shebang_match = cls.SHEBANG_PATTERN.match(script)
        if not shebang_match:
            return "bash"
        interpreter = cls._interpreter_name(shebang_match.group(1))
        # Order matters: "zsh" and "bash" both contain "sh".
        if "zsh" in interpreter:
            return "zsh"
        if "bash" in interpreter:
            return "bash"
        if "sh" in interpreter:
            return "sh"
        return "bash"

    @classmethod
    def parse_lines(cls, script: str) -> List[Tuple[int, str]]:
        """Parse script into lines.

        Lines are split on "\\n" only, so a trailing newline yields a final
        empty line.

        Args:
            script: Shell script content

        Returns:
            List of (line_number, line_content) tuples, numbered from 1
        """
        return list(enumerate(script.split("\n"), start=1))

    @classmethod
    def extract_commands(cls, script: str) -> List[str]:
        """Extract executable commands from script.

        Blank lines and lines whose first non-blank character is "#" or ":"
        are dropped; every remaining line is returned stripped.

        Args:
            script: Shell script content

        Returns:
            List of commands
        """
        stripped_lines = (line.strip() for _, line in cls.parse_lines(script))
        return [
            line
            for line in stripped_lines
            if line and not line.startswith(("#", ":"))
        ]
class ShellGenerator:
    """Shell generation engine.

    Builds a prompt from a task description, sends it to the Ollama client
    and turns the model's text into a GeneratedScript.
    """

    def __init__(self):
        """Initialize generator."""
        self.client = get_ollama_client()
        self.parser = ShellParser()
        self.templates = PromptTemplates

    def generate(
        self,
        description: str,
        shell_type: str = "bash",
        model: Optional[str] = None,
    ) -> GeneratedScript:
        """Generate shell script from description.

        The client result is read as a mapping with a "success" flag, a
        "response" mapping whose own "response" entry holds the text, and an
        optional "error" message.

        Args:
            description: Natural language description
            shell_type: Target shell type
            model: Ollama model to use

        Returns:
            GeneratedScript with commands and explanation. On failure the
            command list and raw response are empty and the explanation
            carries the error message.
        """
        prompt = self.templates.get_generate_prompt(description, shell_type)
        result = self.client.generate(prompt, model=model)
        if not result["success"]:
            error = result.get("error", "Unknown error")
            return GeneratedScript(
                commands=[],
                explanation=f"Generation failed: {error}",
                shell_type=shell_type,
                raw_response="",
            )
        raw = result["response"].get("response", "")
        return self._parse_generation(raw, shell_type)

    def _parse_generation(self, raw: str, shell_type: str) -> GeneratedScript:
        """Parse raw LLM response.

        Args:
            raw: Raw response text
            shell_type: Shell type

        Returns:
            GeneratedScript holding the extracted commands, a summary and
            the untouched raw text
        """
        commands = self._extract_commands(raw)
        return GeneratedScript(
            commands=commands,
            explanation=self._generate_summary(commands, raw),
            shell_type=shell_type,
            raw_response=raw,
        )

    def _extract_commands(self, raw: str) -> List[str]:
        """Extract commands from raw response.

        Blank lines, "#" lines (comments and the shebang) and Markdown code
        fence lines are dropped. The prompt asks for no Markdown, but models
        do not always comply, and a fence line is never a command.

        Args:
            raw: Raw LLM response

        Returns:
            List of commands, each stripped of surrounding whitespace
        """
        skipped_prefixes = ("#", "```")
        commands = []
        for line in raw.strip().split("\n"):
            line = line.strip()
            if line and not line.startswith(skipped_prefixes):
                commands.append(line)
        return commands

    def _generate_summary(
        self, commands: List[str], raw: str
    ) -> str:
        """Generate explanation summary.

        Args:
            commands: List of commands
            raw: Raw response (currently unused)

        Returns:
            Explanation summary stating how many commands were extracted
        """
        return f"Generated {len(commands)} command(s) for your request."
class ShellSafetyChecker:
    """Safety checker for shell commands.

    Commands are matched with ``re.search`` against two tables of
    (regex, message) pairs. Which matches make a command unsafe depends on
    the safety level:

    - "strict": any DANGEROUS or WARNING match makes the command unsafe.
    - "moderate": only DANGEROUS matches make the command unsafe.
    - any other value (e.g. "permissive"): matches are reported but the
      command is still considered safe.
    """

    DANGEROUS_PATTERNS = [
        (r"rm\s+-rf\s+/", "Removes all files recursively - catastrophic data loss"),
        # Whitespace around "&" is optional and so is a ";" before "}", so
        # that the canonical form ":(){ :|:& };:" is matched.
        (r":\(\)\s*\{\s*:\s*\|\s*:\s*&\s*;?\s*\}", "Fork bomb - can crash system"),
        (r"chmod\s+777", "Grants all permissions to everyone - security risk"),
        (r"sudo\s+su", "Escalates to root without proper controls"),
        (r"dd\s+if=/dev/zero", "Can overwrite disks if misconfigured"),
        (r">\s*/dev/sda", "Direct disk write - destructive"),
        (r"mv\s+.*\s+/dev/null", "Sends data to null device - data loss"),
        # The trailing \b keeps "| shuf" or "| sha256sum" from matching;
        # the optional prefix also covers bash and zsh.
        (r"cat\s+.*\|\s*(?:ba|z)?sh\b", "Executes arbitrary shell code"),
        (r"curl.*\|\s*(?:ba|z)?sh\b", "Downloads and executes code - security risk"),
        (r"wget.*\|\s*(?:ba|z)?sh\b", "Downloads and executes code - security risk"),
    ]
    WARNING_PATTERNS = [
        (r"rm\s+-r", "Recursive removal"),
        (r"rm\s+[^-]", "File removal"),
        (r"chmod\s+[0-7][0-7][0-7]", "Permission change"),
        (r"sudo", "Requires sudo privileges"),
        (r">\s*\S+", "File redirection"),
        (r"\|\s*bash", "Pipe to shell"),
    ]

    def __init__(self, safety_level: str = "moderate"):
        """Initialize safety checker.

        Args:
            safety_level: strict, moderate, or permissive
        """
        self.safety_level = safety_level
        self.config = get_config()

    def check_command(self, command: str) -> Tuple[bool, List[str]]:
        """Check if command is dangerous.

        Args:
            command: Shell command to check

        Returns:
            Tuple of (is_safe, list of warnings). Messages are prefixed with
            "DANGEROUS: " or "WARNING: " and are reported at every safety
            level; only is_safe depends on the level.
        """
        danger_blocks = self.safety_level in ("strict", "moderate")
        warning_blocks = self.safety_level == "strict"
        warnings: List[str] = []
        is_safe = True
        for pattern, message in self.DANGEROUS_PATTERNS:
            if re.search(pattern, command):
                warnings.append(f"DANGEROUS: {message}")
                if danger_blocks:
                    is_safe = False
        for pattern, message in self.WARNING_PATTERNS:
            if re.search(pattern, command):
                warnings.append(f"WARNING: {message}")
                if warning_blocks:
                    is_safe = False
        return is_safe, warnings

    def check_script(self, script: str) -> Dict[str, Any]:
        """Check entire script for safety issues.

        Blank lines and lines starting with "#" are skipped; every other
        line is stripped and passed to check_command.

        Args:
            script: Shell script content

        Returns:
            Dictionary with safety assessment:
            "is_safe" (True when no line was unsafe),
            "issues" (one dict per unsafe line: "line", "command", "warnings"),
            "warnings" (all messages from every checked line),
            "safe_lines" (1-based numbers of checked lines that were safe),
            "total_lines" (number of checked lines).
        """
        issues = []
        warnings = []
        safe_lines = []
        total_lines = 0
        for line_no, line in enumerate(script.split("\n"), 1):
            stripped = line.strip()
            if not stripped or stripped.startswith("#"):
                continue
            total_lines += 1
            is_safe, line_warnings = self.check_command(stripped)
            if is_safe:
                safe_lines.append(line_no)
            else:
                issues.append(
                    {"line": line_no, "command": stripped, "warnings": line_warnings}
                )
            warnings.extend(line_warnings)
        return {
            "is_safe": not issues,
            "issues": issues,
            "warnings": warnings,
            "safe_lines": safe_lines,
            "total_lines": total_lines,
        }
def generate_shell(
    description: str,
    shell_type: str = "bash",
    safety_level: Optional[str] = None,
) -> GeneratedScript:
    """Generate a shell script with a freshly created ShellGenerator.

    When a non-empty safety_level is given, it is written into the object
    returned by get_config() under ["safety"]["level"] before generating.

    Args:
        description: Natural language description
        shell_type: Target shell type
        safety_level: Safety level override

    Returns:
        GeneratedScript
    """
    engine = ShellGenerator()
    if safety_level:
        safety_section = get_config().config["safety"]
        safety_section["level"] = safety_level
    return engine.generate(description, shell_type)