"""Shell generation engine for ShellGenius."""

import logging
import re
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

from shellgenius.config import get_config
from shellgenius.ollama_client import get_ollama_client

logger = logging.getLogger(__name__)


@dataclass
class GeneratedScript:
    """Container for generated shell script."""

    # Non-empty, non-comment lines extracted from the model response.
    commands: List[str]
    # Human-readable summary, or the failure message when generation failed.
    explanation: str
    # Shell the script was requested for (e.g. "bash").
    shell_type: str
    # Unmodified model output; empty string when generation failed.
    raw_response: str


class PromptTemplates:
    """Prompt templates for shell generation."""

    BASH_GENERATE = """You are a shell script expert. Generate a {shell} script for the following task:

Task: {description}

Requirements:
- Use safe, best practices
- Add comments explaining each command
- Use proper error handling with set -e or similar
- Return ONLY the script content, no markdown code blocks

Script:"""

    # The format line uses <angle-bracket> placeholders on purpose: curly
    # braces would be interpreted by str.format().
    BASH_EXPLAIN = """You are a shell script expert. Explain this {shell} script line by line:

{script}

Provide explanations in this format:
Line <number>: <command> - <explanation>

Be concise and informative:"""

    BASH_REFACTOR = """You are a shell script security expert. Analyze this {shell} script for security issues and suggest safer alternatives:

{script}

For each issue found:
1. Identify the risky command or pattern
2. Explain the risk
3. Provide a safer alternative
4. Suggest improvement

Format your response:"""

    SHELL_TYPES = ["bash", "zsh", "sh"]

    @classmethod
    def get_generate_prompt(
        cls, description: str, shell_type: str = "bash"
    ) -> str:
        """Get generation prompt for shell type.

        Args:
            description: Natural language description of the task
            shell_type: Shell name substituted into the prompt

        Returns:
            Fully formatted generation prompt
        """
        return cls.BASH_GENERATE.format(
            description=description, shell=shell_type
        )

    @classmethod
    def get_explain_prompt(cls, script: str, shell_type: str = "bash") -> str:
        """Get explanation prompt for script.

        Args:
            script: Shell script content to explain
            shell_type: Shell name substituted into the prompt

        Returns:
            Fully formatted explanation prompt
        """
        return cls.BASH_EXPLAIN.format(script=script, shell=shell_type)

    @classmethod
    def get_refactor_prompt(cls, script: str, shell_type: str = "bash") -> str:
        """Get refactor prompt for script.

        Args:
            script: Shell script content to analyze
            shell_type: Shell name substituted into the prompt

        Returns:
            Fully formatted refactor prompt
        """
        return cls.BASH_REFACTOR.format(script=script, shell=shell_type)


class ShellParser:
    """Parser for shell scripts."""

    LINE_PATTERN = re.compile(r"^(?:(\s*)(.*?)(\s*))$", re.MULTILINE)
    COMMENT_PATTERN = re.compile(r"^\s*#")
    SHEBANG_PATTERN = re.compile(r"^#!(.+)$", re.MULTILINE)

    @classmethod
    def detect_shell(cls, script: str) -> str:
        """Detect shell type from shebang.

        Only a shebang at the very start of the script is considered
        (``Pattern.match`` anchors at position 0).

        Args:
            script: Shell script content

        Returns:
            Detected shell type; "bash" when no shebang is recognised
        """
        shebang_match = cls.SHEBANG_PATTERN.match(script)
        if shebang_match:
            shebang = shebang_match.group(1)
            # Order matters: "zsh" and "bash" both contain "sh".
            for shell in ("zsh", "bash", "sh"):
                if shell in shebang:
                    return shell
        return "bash"

    @classmethod
    def parse_lines(cls, script: str) -> List[Tuple[int, str]]:
        """Parse script into lines.

        Args:
            script: Shell script content

        Returns:
            List of (line_number, line_content) tuples, numbered from 1
        """
        return list(enumerate(script.split("\n"), start=1))

    @classmethod
    def extract_commands(cls, script: str) -> List[str]:
        """Extract executable commands from script.

        Blank lines, comment lines and lines starting with ":" are skipped.

        Args:
            script: Shell script content

        Returns:
            List of commands
        """
        # NOTE(review): skipping ":"-prefixed lines also drops constructs such
        # as ": > file" or a ":(){ ... }" definition - confirm this is intended.
        commands = []
        for _, line in cls.parse_lines(script):
            stripped = line.strip()
            if stripped and not stripped.startswith(("#", ":")):
                commands.append(stripped)
        return commands


class ShellGenerator:
    """Shell generation engine."""

    def __init__(self):
        """Initialize generator."""
        self.client = get_ollama_client()
        self.parser = ShellParser()
        self.templates = PromptTemplates

    def generate(
        self,
        description: str,
        shell_type: str = "bash",
        model: Optional[str] = None,
    ) -> GeneratedScript:
        """Generate shell script from description.

        Args:
            description: Natural language description
            shell_type: Target shell type
            model: Ollama model to use

        Returns:
            GeneratedScript with commands and explanation. On failure the
            commands list is empty and the explanation carries the error.
        """
        prompt = self.templates.get_generate_prompt(description, shell_type)

        result = self.client.generate(prompt, model=model)

        if result.get("success"):
            # Tolerate a missing or None payload instead of raising.
            payload = result.get("response") or {}
            raw = payload.get("response", "") or ""
            return self._parse_generation(raw, shell_type)

        error = result.get("error", "Unknown error")
        logger.warning("Shell generation failed: %s", error)
        return GeneratedScript(
            commands=[],
            explanation=f"Generation failed: {error}",
            shell_type=shell_type,
            raw_response="",
        )

    def _parse_generation(self, raw: str, shell_type: str) -> GeneratedScript:
        """Parse raw LLM response.

        Args:
            raw: Raw response text
            shell_type: Shell type

        Returns:
            GeneratedScript
        """
        commands = self._extract_commands(raw)
        explanation = self._generate_summary(commands, raw)
        return GeneratedScript(
            commands=commands,
            explanation=explanation,
            shell_type=shell_type,
            raw_response=raw,
        )

    def _extract_commands(self, raw: str) -> List[str]:
        """Extract commands from raw response.

        Blank lines, "#" comment lines (including the shebang) and markdown
        code-fence lines are dropped; the prompt asks the model not to emit
        fences, but that instruction is not always followed.

        Args:
            raw: Raw LLM response

        Returns:
            List of commands
        """
        commands = []
        for line in raw.strip().split("\n"):
            line = line.strip()
            if line and not line.startswith(("#", "```")):
                commands.append(line)
        return commands

    def _generate_summary(
        self, commands: List[str], raw: str
    ) -> str:
        """Generate explanation summary.

        Args:
            commands: List of commands
            raw: Raw response (currently unused)

        Returns:
            Explanation summary
        """
        return f"Generated {len(commands)} command(s) for your request."


class ShellSafetyChecker:
    """Safety checker for shell commands."""

    SAFETY_LEVELS = ("strict", "moderate", "permissive")

    # A pipe into sh/bash/zsh, optionally through sudo. The trailing \b keeps
    # commands such as "sha256sum" or "shuf" from matching.
    _PIPE_TO_SHELL = r"\|\s*(?:sudo\s+)?(?:ba|z)?sh\b"

    DANGEROUS_PATTERNS = [
        (r"rm\s+-rf\s+/", "Removes all files recursively - catastrophic data loss"),
        # Matches the canonical ":(){ :|:& };:" as well as spaced variants.
        (r":\(\)\s*\{\s*:\s*\|\s*:\s*&\s*;?\s*\}", "Fork bomb - can crash system"),
        (r"chmod\s+777", "Grants read/write/execute to everyone - security risk"),
        (r"sudo\s+su\b", "Escalates to root without proper controls"),
        (r"dd\s+if=/dev/zero", "Can overwrite disks if misconfigured"),
        (r">\s*/dev/sda", "Direct disk write - destructive"),
        (r"mv\s+.*\s+/dev/null", "Sends data to null device - data loss"),
        (r"cat\s+.*" + _PIPE_TO_SHELL, "Executes arbitrary shell code"),
        (r"curl.*" + _PIPE_TO_SHELL, "Downloads and executes code - security risk"),
        (r"wget.*" + _PIPE_TO_SHELL, "Downloads and executes code - security risk"),
    ]

    WARNING_PATTERNS = [
        (r"rm\s+-r", "Recursive removal"),
        (r"rm\s+[^-]", "File removal"),
        (r"chmod\s+[0-7][0-7][0-7]", "Permission change"),
        (r"sudo", "Requires sudo privileges"),
        (r">\s*\S+", "File redirection"),
        (_PIPE_TO_SHELL, "Pipe to shell"),
    ]

    def __init__(self, safety_level: str = "moderate"):
        """Initialize safety checker.

        Args:
            safety_level: strict, moderate, or permissive. Any other value is
                accepted but blocks nothing, exactly like "permissive".
        """
        if safety_level not in self.SAFETY_LEVELS:
            # An unrecognised level would otherwise silently disable blocking.
            logger.warning(
                "Unknown safety level %r; no commands will be blocked",
                safety_level,
            )
        self.safety_level = safety_level
        self.config = get_config()

    def check_command(self, command: str) -> Tuple[bool, List[str]]:
        """Check if command is dangerous.

        Dangerous patterns mark the command unsafe in "strict" and "moderate"
        mode; warning patterns mark it unsafe only in "strict" mode. Messages
        are collected at every level.

        Args:
            command: Shell command to check

        Returns:
            Tuple of (is_safe, list of warnings)
        """
        warnings: List[str] = []
        is_safe = True
        blocks_dangerous = self.safety_level in ("strict", "moderate")
        blocks_warnings = self.safety_level == "strict"

        for pattern, message in self.DANGEROUS_PATTERNS:
            if re.search(pattern, command):
                warnings.append(f"DANGEROUS: {message}")
                if blocks_dangerous:
                    is_safe = False

        for pattern, message in self.WARNING_PATTERNS:
            if re.search(pattern, command):
                warnings.append(f"WARNING: {message}")
                if blocks_warnings:
                    is_safe = False

        return is_safe, warnings

    def check_script(self, script: str) -> Dict[str, Any]:
        """Check entire script for safety issues.

        Blank lines and "#" comment lines are ignored; every other line is
        passed to check_command.

        Args:
            script: Shell script content

        Returns:
            Dictionary with keys "is_safe", "issues" (one entry per unsafe
            line), "warnings" (messages from all checked lines),
            "safe_lines" (1-based line numbers) and "total_lines" (number of
            lines checked)
        """
        issues: List[Dict[str, Any]] = []
        warnings: List[str] = []
        safe_lines: List[int] = []
        total_lines = 0

        for line_number, line in enumerate(script.split("\n"), 1):
            stripped = line.strip()
            if not stripped or stripped.startswith("#"):
                continue
            total_lines += 1
            is_safe, line_warnings = self.check_command(stripped)
            if is_safe:
                safe_lines.append(line_number)
            else:
                issues.append(
                    {
                        "line": line_number,
                        "command": stripped,
                        "warnings": line_warnings,
                    }
                )
            warnings.extend(line_warnings)

        return {
            "is_safe": not issues,
            "issues": issues,
            "warnings": warnings,
            "safe_lines": safe_lines,
            "total_lines": total_lines,
        }


def generate_shell(
    description: str,
    shell_type: str = "bash",
    safety_level: Optional[str] = None,
) -> GeneratedScript:
    """Convenience function to generate shell script.

    Args:
        description: Natural language description
        shell_type: Target shell type
        safety_level: Safety level override; when given it is written to
            the object returned by get_config()

    Returns:
        GeneratedScript
    """
    generator = ShellGenerator()
    if safety_level:
        # NOTE(review): presumably get_config() returns a shared object, so
        # this override would persist beyond this call - confirm.
        config = get_config()
        config.config["safety"]["level"] = safety_level
    return generator.generate(description, shell_type)