From 380b30934e8a2fd28c3430f9a761af7e3d0c18d6 Mon Sep 17 00:00:00 2001 From: 7000pctAUTO Date: Wed, 4 Feb 2026 11:03:01 +0000 Subject: [PATCH] Add test files --- tests/test_generation.py | 111 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 111 insertions(+) create mode 100644 tests/test_generation.py diff --git a/tests/test_generation.py b/tests/test_generation.py new file mode 100644 index 0000000..326e87b --- /dev/null +++ b/tests/test_generation.py @@ -0,0 +1,111 @@ +"""Tests for shell generation module.""" + +import pytest +from unittest.mock import Mock, patch + +from shellgenius.generation import ( + ShellGenerator, + ShellParser, + ShellSafetyChecker, + PromptTemplates, + GeneratedScript, +) + + +class TestShellParser: + def test_detect_bash_shell(self): + """Test detection of bash shell from shebang.""" + script = "#!/bin/bash\necho hello" + assert ShellParser.detect_shell(script) == "bash" + + def test_detect_zsh_shell(self): + """Test detection of zsh shell from shebang.""" + script = "#!/usr/bin/zsh\necho hello" + assert ShellParser.detect_shell(script) == "zsh" + + def test_detect_sh_shell(self): + """Test detection of sh shell from shebang.""" + script = "#!/bin/sh\necho hello" + assert ShellParser.detect_shell(script) == "sh" + + def test_default_shell_detection(self): + """Test default shell when no shebang present.""" + script = "echo hello" + assert ShellParser.detect_shell(script) == "bash" + + def test_parse_lines(self): + """Test parsing script into lines.""" + script = "line1\nline2\nline3" + lines = ShellParser.parse_lines(script) + + assert len(lines) == 3 + assert lines[0] == (1, "line1") + assert lines[1] == (2, "line2") + assert lines[2] == (3, "line3") + + def test_extract_commands(self): + """Test extracting executable commands.""" + script = "#!/bin/bash\n# comment\necho hello\n\nrm -rf /" + commands = ShellParser.extract_commands(script) + + assert len(commands) == 2 + assert "echo hello" in commands + assert "rm -rf /" in commands + + +class TestShellSafetyChecker: + def test_safe_command(self): + """Test that safe commands pass check.""" + checker = ShellSafetyChecker() + is_safe, warnings = checker.check_command("ls -la") + + assert is_safe == True + assert len(warnings) == 0 + + def test_dangerous_command(self): + """Test that dangerous commands are flagged.""" + checker = ShellSafetyChecker() + is_safe, warnings = checker.check_command("rm -rf /") + + assert is_safe == False + assert any("DANGEROUS" in w for w in warnings) + + def test_warning_patterns(self): + """Test that warning patterns generate warnings.""" + checker = ShellSafetyChecker(safety_level="strict") + is_safe, warnings = checker.check_command("rm -r directory") + + assert any("WARNING" in w for w in warnings) + + def test_check_script(self): + """Test checking entire script.""" + script = "#!/bin/bash\necho hello\nrm -rf /" + checker = ShellSafetyChecker() + result = checker.check_script(script) + + assert result["is_safe"] == False + assert len(result["issues"]) > 0 + + +class TestPromptTemplates: + def test_get_generate_prompt(self): + """Test generation prompt creation.""" + prompt = PromptTemplates.get_generate_prompt("list files", "bash") + + assert "shell script expert" in prompt + assert "list files" in prompt + assert "bash" in prompt + + def test_get_explain_prompt(self): + """Test explanation prompt creation.""" + prompt = PromptTemplates.get_explain_prompt("echo hello", "bash") + + assert "explain" in prompt + assert "echo hello" in prompt + + def test_get_refactor_prompt(self): + """Test refactor prompt creation.""" + prompt = PromptTemplates.get_refactor_prompt("rm -rf /", "bash") + + assert "security" in prompt.lower() + assert "rm -rf /" in prompt