"""Tests for shell generation module.""" from unittest.mock import Mock, patch from shellgenius.generation import ( ShellParser, ShellSafetyChecker, PromptTemplates, ) class TestShellParser: def test_detect_bash_shell(self): """Test detection of bash shell from shebang.""" script = "#!/bin/bash\necho hello" assert ShellParser.detect_shell(script) == "bash" def test_detect_zsh_shell(self): """Test detection of zsh shell from shebang.""" script = "#!/usr/bin/zsh\necho hello" assert ShellParser.detect_shell(script) == "zsh" def test_detect_sh_shell(self): """Test detection of sh shell from shebang.""" script = "#!/bin/sh\necho hello" assert ShellParser.detect_shell(script) == "sh" def test_default_shell_detection(self): """Test default shell when no shebang present.""" script = "echo hello" assert ShellParser.detect_shell(script) == "bash" def test_parse_lines(self): """Test parsing script into lines.""" script = "line1\nline2\nline3" lines = ShellParser.parse_lines(script) assert len(lines) == 3 assert lines[0] == (1, "line1") assert lines[1] == (2, "line2") assert lines[2] == (3, "line3") def test_extract_commands(self): """Test extracting executable commands.""" script = "#!/bin/bash\n# comment\necho hello\n\nrm -rf /" commands = ShellParser.extract_commands(script) assert len(commands) == 2 assert "echo hello" in commands assert "rm -rf /" in commands class TestShellSafetyChecker: def test_safe_command(self): """Test that safe commands pass check.""" checker = ShellSafetyChecker() is_safe, warnings = checker.check_command("ls -la") assert is_safe assert len(warnings) == 0 def test_dangerous_command(self): """Test that dangerous commands are flagged.""" checker = ShellSafetyChecker() is_safe, warnings = checker.check_command("rm -rf /") assert not is_safe assert any("DANGEROUS" in w for w in warnings) def test_warning_patterns(self): """Test that warning patterns generate warnings.""" checker = ShellSafetyChecker(safety_level="strict") is_safe, warnings = checker.check_command("rm -r directory") assert any("WARNING" in w for w in warnings) def test_check_script(self): """Test checking entire script.""" script = "#!/bin/bash\necho hello\nrm -rf /" checker = ShellSafetyChecker() result = checker.check_script(script) assert not result["is_safe"] assert len(result["issues"]) > 0 class TestPromptTemplates: def test_get_generate_prompt(self): """Test generation prompt creation.""" prompt = PromptTemplates.get_generate_prompt("list files", "bash") assert "shell script expert" in prompt assert "list files" in prompt assert "bash" in prompt def test_get_explain_prompt(self): """Test explanation prompt creation.""" prompt = PromptTemplates.get_explain_prompt("echo hello", "bash") assert "Explain" in prompt assert "echo hello" in prompt def test_get_refactor_prompt(self): """Test refactor prompt creation.""" prompt = PromptTemplates.get_refactor_prompt("rm -rf /", "bash") assert "security" in prompt.lower() assert "rm -rf /" in prompt