"""Tests for ShellGenius generation module.""" from shellgenius.generation import ( PromptTemplates, ShellGenerator, ShellParser, ShellSafetyChecker, ) class TestPromptTemplates: """Test prompt template generation.""" def test_generate_prompt(self): """Test generation prompt template.""" prompt = PromptTemplates.get_generate_prompt( "list files", shell_type="bash" ) assert "list files" in prompt assert "bash" in prompt assert "Script:" in prompt def test_explain_prompt(self): """Test explanation prompt template.""" script = "echo hello" prompt = PromptTemplates.get_explain_prompt(script, "bash") assert "echo hello" in prompt assert "bash" in prompt def test_refactor_prompt(self): """Test refactor prompt template.""" script = "rm -rf /tmp/*" prompt = PromptTemplates.get_refactor_prompt(script, "bash") assert "rm -rf /tmp/*" in prompt class TestShellParser: """Test shell script parsing.""" def test_detect_bash(self): """Test bash detection from shebang.""" script = "#!/bin/bash\necho hello" assert ShellParser.detect_shell(script) == "bash" def test_detect_zsh(self): """Test zsh detection from shebang.""" script = "#!/bin/zsh\necho hello" assert ShellParser.detect_shell(script) == "zsh" def test_detect_sh(self): """Test sh detection from shebang.""" script = "#!/bin/sh\necho hello" assert ShellParser.detect_shell(script) == "sh" def test_detect_default(self): """Test default bash detection.""" script = "echo hello" assert ShellParser.detect_shell(script) == "bash" def test_parse_lines(self): """Test line parsing.""" script = "line1\nline2\nline3" lines = ShellParser.parse_lines(script) assert len(lines) == 3 assert lines[0] == (1, "line1") assert lines[1] == (2, "line2") def test_extract_commands(self): """Test command extraction.""" script = "#!/bin/bash # This is a comment echo hello : empty command echo world " commands = ShellParser.extract_commands(script) assert len(commands) == 2 assert "echo hello" in commands[0] assert "echo world" in commands[1] class TestShellSafetyChecker: """Test shell safety checking.""" def test_dangerous_command(self): """Test dangerous command detection.""" checker = ShellSafetyChecker("moderate") is_safe, warnings = checker.check_command("rm -rf /") assert not is_safe def test_safe_command(self): """Test safe command passes.""" checker = ShellSafetyChecker("moderate") is_safe, warnings = checker.check_command("ls -la") assert is_safe def test_warning_command(self): """Test warning-level commands.""" checker = ShellSafetyChecker("moderate") is_safe, warnings = checker.check_command("rm -r /tmp/*") assert is_safe assert len(warnings) > 0 def test_script_check(self): """Test full script safety check.""" checker = ShellSafetyChecker() script = "#!/bin/bash ls echo done " result = checker.check_script(script) assert "is_safe" in result assert "issues" in result assert "warnings" in result def test_strict_mode_blocks_warnings(self): """Test strict mode blocks warning-level commands.""" checker = ShellSafetyChecker("strict") is_safe, warnings = checker.check_command("rm -r /tmp/*") assert not is_safe class TestShellGenerator: """Test shell generation.""" def test_generator_initialization(self): """Test generator creates properly.""" generator = ShellGenerator() assert generator.parser is not None assert generator.templates is not None def test_extract_commands_from_raw(self): """Test command extraction from raw response.""" generator = ShellGenerator() raw = "echo hello\nls -la\ncat file.txt" commands = generator._extract_commands(raw) assert len(commands) == 3 assert "echo hello" in commands def test_generate_summary(self): """Test summary generation.""" generator = ShellGenerator() summary = generator._generate_summary(["cmd1", "cmd2"], "raw") assert "2 command" in summary