Add test files
This commit is contained in:
111
tests/test_generation.py
Normal file
111
tests/test_generation.py
Normal file
@@ -0,0 +1,111 @@
|
||||
"""Tests for shell generation module."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from shellgenius.generation import (
|
||||
ShellGenerator,
|
||||
ShellParser,
|
||||
ShellSafetyChecker,
|
||||
PromptTemplates,
|
||||
GeneratedScript,
|
||||
)
|
||||
|
||||
|
||||
class TestShellParser:
|
||||
def test_detect_bash_shell(self):
|
||||
"""Test detection of bash shell from shebang."""
|
||||
script = "#!/bin/bash\necho hello"
|
||||
assert ShellParser.detect_shell(script) == "bash"
|
||||
|
||||
def test_detect_zsh_shell(self):
|
||||
"""Test detection of zsh shell from shebang."""
|
||||
script = "#!/usr/bin/zsh\necho hello"
|
||||
assert ShellParser.detect_shell(script) == "zsh"
|
||||
|
||||
def test_detect_sh_shell(self):
|
||||
"""Test detection of sh shell from shebang."""
|
||||
script = "#!/bin/sh\necho hello"
|
||||
assert ShellParser.detect_shell(script) == "sh"
|
||||
|
||||
def test_default_shell_detection(self):
|
||||
"""Test default shell when no shebang present."""
|
||||
script = "echo hello"
|
||||
assert ShellParser.detect_shell(script) == "bash"
|
||||
|
||||
def test_parse_lines(self):
|
||||
"""Test parsing script into lines."""
|
||||
script = "line1\nline2\nline3"
|
||||
lines = ShellParser.parse_lines(script)
|
||||
|
||||
assert len(lines) == 3
|
||||
assert lines[0] == (1, "line1")
|
||||
assert lines[1] == (2, "line2")
|
||||
assert lines[2] == (3, "line3")
|
||||
|
||||
def test_extract_commands(self):
|
||||
"""Test extracting executable commands."""
|
||||
script = "#!/bin/bash\n# comment\necho hello\n\nrm -rf /"
|
||||
commands = ShellParser.extract_commands(script)
|
||||
|
||||
assert len(commands) == 2
|
||||
assert "echo hello" in commands
|
||||
assert "rm -rf /" in commands
|
||||
|
||||
|
||||
class TestShellSafetyChecker:
|
||||
def test_safe_command(self):
|
||||
"""Test that safe commands pass check."""
|
||||
checker = ShellSafetyChecker()
|
||||
is_safe, warnings = checker.check_command("ls -la")
|
||||
|
||||
assert is_safe == True
|
||||
assert len(warnings) == 0
|
||||
|
||||
def test_dangerous_command(self):
|
||||
"""Test that dangerous commands are flagged."""
|
||||
checker = ShellSafetyChecker()
|
||||
is_safe, warnings = checker.check_command("rm -rf /")
|
||||
|
||||
assert is_safe == False
|
||||
assert any("DANGEROUS" in w for w in warnings)
|
||||
|
||||
def test_warning_patterns(self):
|
||||
"""Test that warning patterns generate warnings."""
|
||||
checker = ShellSafetyChecker(safety_level="strict")
|
||||
is_safe, warnings = checker.check_command("rm -r directory")
|
||||
|
||||
assert any("WARNING" in w for w in warnings)
|
||||
|
||||
def test_check_script(self):
|
||||
"""Test checking entire script."""
|
||||
script = "#!/bin/bash\necho hello\nrm -rf /"
|
||||
checker = ShellSafetyChecker()
|
||||
result = checker.check_script(script)
|
||||
|
||||
assert result["is_safe"] == False
|
||||
assert len(result["issues"]) > 0
|
||||
|
||||
|
||||
class TestPromptTemplates:
|
||||
def test_get_generate_prompt(self):
|
||||
"""Test generation prompt creation."""
|
||||
prompt = PromptTemplates.get_generate_prompt("list files", "bash")
|
||||
|
||||
assert "shell script expert" in prompt
|
||||
assert "list files" in prompt
|
||||
assert "bash" in prompt
|
||||
|
||||
def test_get_explain_prompt(self):
|
||||
"""Test explanation prompt creation."""
|
||||
prompt = PromptTemplates.get_explain_prompt("echo hello", "bash")
|
||||
|
||||
assert "explain" in prompt
|
||||
assert "echo hello" in prompt
|
||||
|
||||
def test_get_refactor_prompt(self):
|
||||
"""Test refactor prompt creation."""
|
||||
prompt = PromptTemplates.get_refactor_prompt("rm -rf /", "bash")
|
||||
|
||||
assert "security" in prompt.lower()
|
||||
assert "rm -rf /" in prompt
|
||||
Reference in New Issue
Block a user