137 lines
4.4 KiB
Python
137 lines
4.4 KiB
Python
"""Tests for ShellGenius generation module."""
|
|
|
|
from shellgenius.generation import (
|
|
PromptTemplates,
|
|
ShellGenerator,
|
|
ShellParser,
|
|
ShellSafetyChecker,
|
|
)
|
|
|
|
|
|
class TestPromptTemplates:
    """Exercise the static prompt builders exposed by ``PromptTemplates``."""

    def test_generate_prompt(self):
        """The generate prompt contains the request, the shell type and ``Script:``."""
        generated = PromptTemplates.get_generate_prompt("list files", shell_type="bash")
        for expected in ("list files", "bash", "Script:"):
            assert expected in generated

    def test_explain_prompt(self):
        """The explain prompt contains the script text and the shell type."""
        source_script = "echo hello"
        explained = PromptTemplates.get_explain_prompt(source_script, "bash")
        for expected in (source_script, "bash"):
            assert expected in explained

    def test_refactor_prompt(self):
        """The refactor prompt contains the original script text verbatim."""
        original = "rm -rf /tmp/*"
        refactor_prompt = PromptTemplates.get_refactor_prompt(original, "bash")
        assert original in refactor_prompt
|
|
|
|
|
|
class TestShellParser:
    """Test shell script parsing and shell-type detection."""

    def test_detect_bash(self):
        """A ``#!/bin/bash`` shebang is detected as bash."""
        script = "#!/bin/bash\necho hello"
        assert ShellParser.detect_shell(script) == "bash"

    def test_detect_zsh(self):
        """A ``#!/bin/zsh`` shebang is detected as zsh."""
        script = "#!/bin/zsh\necho hello"
        assert ShellParser.detect_shell(script) == "zsh"

    def test_detect_sh(self):
        """A ``#!/bin/sh`` shebang is detected as sh."""
        script = "#!/bin/sh\necho hello"
        assert ShellParser.detect_shell(script) == "sh"

    def test_detect_default(self):
        """A script without a shebang falls back to bash."""
        script = "echo hello"
        assert ShellParser.detect_shell(script) == "bash"

    def test_parse_lines(self):
        """Every line is returned as a 1-based ``(line_number, text)`` pair."""
        script = "line1\nline2\nline3"
        lines = ShellParser.parse_lines(script)
        assert len(lines) == 3
        assert lines[0] == (1, "line1")
        assert lines[1] == (2, "line2")
        # The final line has no trailing newline; check it is neither dropped
        # nor renumbered, which the length check alone would not catch.
        assert lines[2] == (3, "line3")

    def test_extract_commands(self):
        """Only the two ``echo`` lines are extracted.

        The shebang, the comment line and the ``:`` no-op line are all
        expected to be skipped.
        """
        script = "#!/bin/bash\n# This is a comment\necho hello\n: empty command\necho world\n"
        commands = ShellParser.extract_commands(script)
        assert len(commands) == 2
        assert "echo hello" in commands[0]
        assert "echo world" in commands[1]
|
|
|
|
|
|
class TestShellSafetyChecker:
    """Test shell safety checking at command and script level."""

    def test_dangerous_command(self):
        """``rm -rf /`` is reported as unsafe in moderate mode."""
        checker = ShellSafetyChecker("moderate")
        # Only the verdict matters here; the warnings list is deliberately ignored.
        is_safe, _ = checker.check_command("rm -rf /")
        assert not is_safe

    def test_safe_command(self):
        """A read-only listing command is reported as safe in moderate mode."""
        checker = ShellSafetyChecker("moderate")
        is_safe, _ = checker.check_command("ls -la")
        assert is_safe

    def test_warning_command(self):
        """In moderate mode ``rm -r /tmp/*`` passes but produces at least one warning."""
        checker = ShellSafetyChecker("moderate")
        is_safe, warnings = checker.check_command("rm -r /tmp/*")
        assert is_safe
        assert len(warnings) > 0

    def test_script_check(self):
        """``check_script`` returns a mapping with ``is_safe``, ``issues`` and ``warnings`` keys."""
        checker = ShellSafetyChecker()
        script = "#!/bin/bash\nls\necho done\n"
        result = checker.check_script(script)
        assert "is_safe" in result
        assert "issues" in result
        assert "warnings" in result

    def test_strict_mode_blocks_warnings(self):
        """Strict mode rejects the command that moderate mode only warns about."""
        checker = ShellSafetyChecker("strict")
        is_safe, _ = checker.check_command("rm -r /tmp/*")
        assert not is_safe
|
|
|
|
|
|
class TestShellGenerator:
    """Exercise ``ShellGenerator`` construction and its private helpers."""

    def test_generator_initialization(self):
        """A freshly constructed generator exposes non-None ``parser`` and ``templates``."""
        gen = ShellGenerator()
        for attr_name in ("parser", "templates"):
            assert getattr(gen, attr_name) is not None

    def test_extract_commands_from_raw(self):
        """Three newline-separated commands in a raw response yield three commands."""
        gen = ShellGenerator()
        raw_response = "\n".join(["echo hello", "ls -la", "cat file.txt"])
        extracted = gen._extract_commands(raw_response)
        assert len(extracted) == 3
        assert "echo hello" in extracted

    def test_generate_summary(self):
        """The summary for two commands mentions ``2 command``."""
        gen = ShellGenerator()
        text = gen._generate_summary(["cmd1", "cmd2"], "raw")
        assert "2 command" in text
|