"""Refactoring analyzer module for ShellGenius."""

import re
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

from shellgenius.config import get_config
from shellgenius.generation import ShellParser, get_ollama_client


@dataclass
class RefactoringIssue:
    """A refactoring issue found in a script."""

    line_number: int  # 1-based line number within the analyzed script
    original: str  # the offending line with surrounding whitespace stripped
    issue_type: str  # security rule name, or "Code Quality"
    severity: str  # "critical", "high", "medium" or "low"
    description: str
    risk_assessment: str
    suggestion: str
    # Advice text for security rules; a rewritten line for "Code Quality".
    safer_alternative: str


@dataclass
class RefactoringResult:
    """Complete refactoring analysis result."""

    shell_type: str  # value returned by ShellParser.detect_shell
    issues: List[RefactoringIssue]
    score: int  # clamped to the range 0-100
    suggestions: List[str]
    safer_script: str


class SecurityRulesDB:
    """Database of security rules for shell scripts.

    Each rule is a dictionary with the keys ``id``, ``name``, ``pattern``
    (a regular expression searched case-insensitively in a single line),
    ``severity``, ``description``, ``risk`` and ``alternative``.
    """

    RULES = [
        {
            "id": "CWE-78",
            "name": "Shell Injection",
            "pattern": r"(rm|mv|cp|chmod|chown)\s+.*\$",
            "severity": "high",
            "description": "Command injection via variable expansion",
            "risk": "Could execute unintended commands if variable contains shell metacharacters",
            "alternative": "Use quotes around variables: \"$var\" instead of $var",
        },
        {
            "id": "CWE-22",
            "name": "Path Traversal",
            "pattern": r"(cat|less|more|head|tail|rm|cp)\s+.*\.\./",
            "severity": "medium",
            "description": "Path traversal via ..",
            "risk": "Could access files outside intended directory",
            "alternative": "Use realpath() to resolve paths or validate input",
        },
        {
            "id": "CWE-377",
            "name": "Insecure Temporary File",
            "pattern": r"(cat|mkdir|touch)\s+.*\/tmp\/[^\s]+",
            "severity": "medium",
            "description": "Insecure temporary file usage",
            "risk": "Race condition vulnerability (TOCTOU)",
            "alternative": "Use mktemp or dedicated temp directory functions",
        },
        {
            # CWE-95 is the eval-injection entry; the previous identifier
            # was a mis-encoded character.
            "id": "CWE-95",
            "name": "Eval with User Input",
            "pattern": r"\beval\s+\$",
            "severity": "critical",
            "description": "Eval with variable input",
            "risk": "Shell injection vulnerability",
            "alternative": "Avoid eval, use direct variable references",
        },
        {
            # CWE-732: Incorrect Permission Assignment for Critical Resource.
            "id": "CWE-732",
            "name": "Insecure File Permissions",
            "pattern": r"chmod\s+777",
            "severity": "high",
            "description": "World-writable permissions",
            "risk": "Security vulnerability - any user can modify",
            "alternative": "Use chmod 755 for directories, 644 for files",
        },
        {
            "id": "CWE-362",
            "name": "Race Condition",
            "pattern": r"(if|while)\s+.*-f\s+\$",
            "severity": "medium",
            "description": "TOCTOU race condition",
            "risk": "File may change between check and use",
            "alternative": "Use set -C (noclobber) or atomic operations",
        },
        {
            # NOTE(review): CWE-323 is a cryptographic nonce/key reuse entry;
            # confirm which identifier was intended for this rule.
            "id": "CWE-323",
            "name": "Reusing UID/GID",
            "pattern": r"useradd\s+[^-]",
            "severity": "low",
            "description": "User creation without system considerations",
            "risk": "UID/GID conflicts possible",
            "alternative": "Use useradd -M -r or adduser system tools",
        },
    ]

    @classmethod
    def get_rules(cls) -> List[Dict[str, Any]]:
        """Get all security rules.

        Returns:
            List of rule dictionaries (the class-level list itself, not a copy)
        """
        return cls.RULES

    @classmethod
    def check_rule(cls, line: str) -> Optional[Dict[str, Any]]:
        """Check if line matches a security rule.

        Rules are tried in declaration order and the first match wins.

        Args:
            line: Shell command line

        Returns:
            Matching rule or None
        """
        for rule in cls.RULES:
            if re.search(rule["pattern"], line, re.IGNORECASE):
                return rule
        return None


class RefactoringAnalyzer:
    """Shell script refactoring and security analyzer."""

    BEST_PRACTICES = [
        "Use set -euo pipefail for error handling",
        "Quote all variable expansions: \"$var\"",
        "Use functions for code organization",
        "Add shebang for portability",
        "Use absolute paths when possible",
        "Implement proper error handling",
        "Avoid magic numbers, use variables",
        "Add comments for complex logic",
    ]

    IMPROVEMENTS = [
        "Replace backticks with $() for command substitution",
        "Use [[ ]] instead of [ ] for conditionals",
        "Replace deprecated syntax with modern alternatives",
        "Add error handling with set -e or trap",
        "Use readonly for constants",
        "Use local variables in functions",
        "Avoid using exit codes incorrectly",
    ]

    # Issue type used for non-security findings; for these issues
    # ``safer_alternative`` holds a rewritten line rather than advice text.
    _CODE_QUALITY = "Code Quality"

    # Most severe first; used to order the per-severity summary lines.
    _SEVERITY_ORDER = ("critical", "high", "medium", "low")

    _BACKTICK_RE = re.compile(r"`([^`]+)`")

    # A single-bracket test: "[", whitespace, content, whitespace, "]".
    # The lookarounds keep "[[ ... ]]" from matching, and the mandatory
    # whitespace after "[" keeps array subscripts such as ${a[0]} out.
    _SINGLE_BRACKET_RE = re.compile(r"(?<!\[)\[\s+([^\[\]]+?)\s+\](?!\])")

    def __init__(self):
        """Initialize analyzer."""
        self.parser = ShellParser()
        self.client = get_ollama_client()
        self.rules_db = SecurityRulesDB()

    def analyze(
        self, script: str, include_suggestions: bool = True
    ) -> RefactoringResult:
        """Analyze script for refactoring opportunities.

        Args:
            script: Shell script content
            include_suggestions: Include AI-generated suggestions in addition
                to the rule-based ones

        Returns:
            RefactoringResult with issues and suggestions
        """
        shell_type = self.parser.detect_shell(script)
        issues = self._find_issues(script)
        score = self._calculate_score(issues, script)

        suggestions = self._generate_suggestions(issues, script)

        if include_suggestions:
            suggestions.extend(self._get_ai_suggestions(script, shell_type))

        safer_script = self._generate_safer_script(script, issues)

        return RefactoringResult(
            shell_type=shell_type,
            issues=issues,
            score=score,
            suggestions=suggestions,
            safer_script=safer_script,
        )

    def _find_issues(self, script: str) -> List[RefactoringIssue]:
        """Find issues in script.

        Blank lines and comment lines are skipped. A line yields at most one
        issue: a security rule match takes precedence over a code-quality
        improvement.

        Args:
            script: Shell script content

        Returns:
            List of RefactoringIssue
        """
        issues = []

        for number, line in enumerate(script.split("\n"), 1):
            stripped = line.strip()
            if not stripped or stripped.startswith("#"):
                continue

            rule = self.rules_db.check_rule(stripped)
            if rule:
                issues.append(
                    RefactoringIssue(
                        line_number=number,
                        original=stripped,
                        issue_type=rule["name"],
                        severity=rule["severity"],
                        description=rule["description"],
                        risk_assessment=rule["risk"],
                        suggestion=f"See alternative for {rule['name']}",
                        safer_alternative=rule["alternative"],
                    )
                )
                continue

            improvement = self._check_improvements(stripped)
            if improvement:
                issues.append(
                    RefactoringIssue(
                        line_number=number,
                        original=stripped,
                        issue_type=self._CODE_QUALITY,
                        severity="low",
                        description=improvement["description"],
                        risk_assessment=improvement["risk"],
                        suggestion=improvement["suggestion"],
                        safer_alternative=improvement["alternative"],
                    )
                )

        return issues

    def _check_improvements(
        self, line: str
    ) -> Optional[Dict[str, str]]:
        """Check for code quality improvements.

        Args:
            line: Shell command line

        Returns:
            Improvement suggestion or None. The ``alternative`` entry is the
            rewritten line.
        """
        if self._BACKTICK_RE.search(line):
            return {
                "description": "Use of backticks for command substitution",
                "risk": "Backticks are deprecated and hard to nest",
                "suggestion": "Use $() syntax instead",
                "alternative": self._BACKTICK_RE.sub(r"$(\1)", line),
            }
        if "=" in line and self._SINGLE_BRACKET_RE.search(line):
            return {
                "description": "Use of [ ] instead of [[ ]]",
                "risk": "[ ] has limitations with pattern matching",
                "suggestion": "Use [[ ]] for modern bash",
                # Rewrite only the matched tests so other brackets on the
                # line are left untouched.
                "alternative": self._SINGLE_BRACKET_RE.sub(r"[[ \1 ]]", line),
            }
        return None

    def _calculate_score(
        self, issues: List[RefactoringIssue], script: str
    ) -> int:
        """Calculate overall script score.

        Starts at 100 and subtracts a weight per issue. Scripts still above
        50 gain one point per 20 non-blank lines, capped at 10.

        Args:
            issues: List of found issues
            script: Original script

        Returns:
            Score from 0-100
        """
        severity_weights = {"critical": 25, "high": 15, "medium": 10, "low": 5}

        score = 100
        for issue in issues:
            # Unknown severities are penalized like "low".
            score -= severity_weights.get(issue.severity, 5)

        lines_count = sum(1 for row in script.split("\n") if row.strip())
        if lines_count > 0 and score > 50:
            score += min(10, lines_count // 20)

        return max(0, min(100, score))

    def _generate_suggestions(
        self, issues: List[RefactoringIssue], script: str
    ) -> List[str]:
        """Generate list of improvement suggestions.

        Args:
            issues: List of found issues
            script: Original script

        Returns:
            List of suggestion strings
        """
        suggestions = []

        if not script.strip().startswith("#!"):
            suggestions.append("Add shebang (#!/bin/bash) for portability")

        if "set -" not in script:
            suggestions.append(
                "Add 'set -euo pipefail' for better error handling"
            )

        if issues:
            severity_counts: Dict[str, int] = {}
            for issue in issues:
                severity_counts[issue.severity] = (
                    severity_counts.get(issue.severity, 0) + 1
                )

            def rank(severity: str) -> Any:
                # Known severities go most-severe first; anything else
                # follows in alphabetical order.
                if severity in self._SEVERITY_ORDER:
                    return (self._SEVERITY_ORDER.index(severity), severity)
                return (len(self._SEVERITY_ORDER), severity)

            for severity in sorted(severity_counts, key=rank):
                suggestions.append(
                    f"Address {severity_counts[severity]} {severity} severity issue(s)"
                )

        suggestions.extend(self.BEST_PRACTICES[:3])

        return suggestions

    def _get_ai_suggestions(
        self, script: str, shell_type: str
    ) -> List[str]:
        """Get AI-generated improvement suggestions.

        Args:
            script: Shell script content
            shell_type: Shell type

        Returns:
            List of suggestion strings, one per non-blank line of the model
            response; empty when the client reports failure
        """
        from shellgenius.generation import PromptTemplates

        prompt = PromptTemplates.get_refactor_prompt(script, shell_type)
        result = self.client.generate(prompt)

        if result["success"]:
            response = result["response"].get("response", "")
            return [row.strip() for row in response.split("\n") if row.strip()]
        return []

    def _generate_safer_script(
        self, script: str, issues: List[RefactoringIssue]
    ) -> str:
        """Generate safer version of script.

        Code-quality issues carry a rewritten line, which replaces the
        original. Security-rule issues carry advice text, so the original
        line is kept and the advice is emitted as a comment above it.
        Leading indentation is preserved. If the first non-blank line is not
        a shebang, a ``#!/bin/bash`` / ``set -euo pipefail`` header is
        inserted before it.

        Args:
            script: Original script
            issues: List of issues to fix

        Returns:
            Safer version of script
        """
        issues_by_line: Dict[int, List[RefactoringIssue]] = {}
        for issue in issues:
            issues_by_line.setdefault(issue.line_number, []).append(issue)

        fixed_lines: List[str] = []
        header_checked = False
        previous_continues = False

        for number, line in enumerate(script.split("\n"), 1):
            stripped = line.strip()
            if not stripped:
                fixed_lines.append(line)
                previous_continues = False
                continue

            if not header_checked:
                header_checked = True
                if not stripped.startswith("#!"):
                    fixed_lines.extend(["#!/bin/bash", "set -euo pipefail", ""])

            indent = line[: len(line) - len(line.lstrip())]
            replacement = stripped

            for issue in issues_by_line.get(number, ()):
                alternative = issue.safer_alternative
                if not alternative or alternative == stripped:
                    continue
                if issue.issue_type == self._CODE_QUALITY:
                    replacement = alternative
                elif not previous_continues:
                    # A comment placed after a backslash continuation would
                    # split the command, so annotate only standalone lines.
                    fixed_lines.append(
                        f"{indent}# {issue.severity.upper()} "
                        f"{issue.issue_type}: {alternative}"
                    )
                break

            fixed_lines.append(indent + replacement)
            previous_continues = stripped.endswith("\\")

        return "\n".join(fixed_lines)


def refactor_script(
    script: str, include_suggestions: bool = True
) -> RefactoringResult:
    """Convenience function to analyze and refactor a shell script.

    Args:
        script: Shell script content
        include_suggestions: Include AI suggestions

    Returns:
        RefactoringResult
    """
    analyzer = RefactoringAnalyzer()
    return analyzer.analyze(script, include_suggestions)