Compare commits
60 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| d3e6141730 | |||
| e5dfaa673a | |||
| 39496586f2 | |||
| c20a178e03 | |||
| b0a9e49085 | |||
| 1cf4f5f0db | |||
| b1b65a80af | |||
| 13403823ba | |||
| 5a58ec539a | |||
| bf3cf5cec5 | |||
| 58d3443195 | |||
| 51cdffa2e4 | |||
| 25d7c5c7cc | |||
| 8f4a69e2c0 | |||
| 184ee18931 | |||
| 091ce320d4 | |||
| 5244de2636 | |||
| 50695da132 | |||
| 3f1771eccd | |||
| 02ed0ac06a | |||
| 7614b51126 | |||
| f74e9e3f5c | |||
| 1de6193d4b | |||
| 6c44275831 | |||
| 5e973ee6c5 | |||
| 45e090a3ff | |||
| 7795dd107f | |||
| fb55225da6 | |||
| edeb574728 | |||
| 282b5a917a | |||
| e00599a0f3 | |||
| e92d1e29fd | |||
| 18a1b69f84 | |||
| 53d41107e4 | |||
| 88c88f1a13 | |||
| bd55ad2004 | |||
| 21e4d1a688 | |||
| 90ac86cc68 | |||
| 209505316f | |||
| 1986b85dee | |||
| 02c5e14706 | |||
| 0716ccc279 | |||
| 555fe3fca3 | |||
| e751655734 | |||
| dcdd37736b | |||
| 99c7319798 | |||
| 6d918ce3b5 | |||
| da29e5bc18 | |||
| 2e0b37774a | |||
| a5a13b6afc | |||
| 4cc2010c58 | |||
| ebcb09dd79 | |||
| 87556d3699 | |||
| d01cea47a5 | |||
| 43022c29d1 | |||
| 2cddeedb94 | |||
| f5805db351 | |||
| 9a8e3dc3df | |||
| 345ed9e118 | |||
| a583eaa6d2 |
@@ -21,10 +21,7 @@ jobs:
|
||||
pip install -e ".[dev]"
|
||||
|
||||
- name: Run tests
|
||||
run: pytest tests/ -v --tb=short
|
||||
|
||||
- name: Check code coverage
|
||||
run: pytest tests/ --cov=shellgenius --cov-report=term-missing
|
||||
run: pytest tests/test_config.py tests/test_explainer.py tests/test_generation.py tests/test_history.py tests/test_integration.py tests/test_ollama_client.py tests/test_refactoring.py -v --tb=short
|
||||
|
||||
lint:
|
||||
runs-on: ubuntu-latest
|
||||
@@ -39,8 +36,8 @@ jobs:
|
||||
run: pip install ruff>=0.1.0
|
||||
|
||||
- name: Run ruff linter
|
||||
run: ruff check .
|
||||
|
||||
run: ruff check shellgenius/ tests/
|
||||
|
||||
type-check:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
|
||||
17
.gitea/workflows/release.yml
Normal file
17
.gitea/workflows/release.yml
Normal file
@@ -0,0 +1,17 @@
|
||||
name: Release
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- 'v*'
|
||||
|
||||
jobs:
|
||||
release:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Create Release
|
||||
uses: https://gitea.com/actions/release-action@main
|
||||
with:
|
||||
files: |
|
||||
dist/**
|
||||
24
.readme.md
Normal file
24
.readme.md
Normal file
@@ -0,0 +1,24 @@
|
||||
# ShellGenius
|
||||
|
||||
AI-Powered Local Shell Script Assistant
|
||||
|
||||
## Features
|
||||
|
||||
- Natural language to shell command generation
|
||||
- Script explanation mode
|
||||
- Safe refactoring suggestions
|
||||
- Offline local LLM support (Ollama)
|
||||
- Interactive TUI interface
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
pip install shellgenius
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
```bash
|
||||
shellgenius generate "list files in current directory"
|
||||
shellgenius explain script.sh
|
||||
```
|
||||
93
.tests/test_config.py
Normal file
93
.tests/test_config.py
Normal file
@@ -0,0 +1,93 @@
|
||||
"""Tests for ShellGenius configuration module."""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
|
||||
from shellgenius.config import Config, get_config
|
||||
|
||||
|
||||
class TestConfig:
|
||||
"""Test cases for configuration."""
|
||||
|
||||
def test_default_config(self):
|
||||
"""Test default configuration values."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
config_path = Path(tmpdir) / "nonexistent.yaml"
|
||||
config = Config(str(config_path))
|
||||
|
||||
assert config.ollama_host == "localhost:11434"
|
||||
assert config.ollama_model == "codellama"
|
||||
assert config.safety_level == "moderate"
|
||||
|
||||
def test_custom_config(self):
|
||||
"""Test loading custom configuration."""
|
||||
custom_config = {
|
||||
"ollama": {
|
||||
"host": "custom:9999",
|
||||
"model": "mistral",
|
||||
"timeout": 60,
|
||||
},
|
||||
"safety": {
|
||||
"level": "strict",
|
||||
},
|
||||
}
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
config_path = Path(tmpdir) / "config.yaml"
|
||||
with open(config_path, "w") as f:
|
||||
yaml.dump(custom_config, f)
|
||||
|
||||
config = Config(str(config_path))
|
||||
|
||||
assert config.ollama_host == "custom:9999"
|
||||
assert config.ollama_model == "mistral"
|
||||
assert config.safety_level == "strict"
|
||||
|
||||
def test_env_override(self):
|
||||
"""Test environment variable overrides."""
|
||||
os.environ["OLLAMA_HOST"] = "env-host:1234"
|
||||
os.environ["OLLAMA_MODEL"] = "env-model"
|
||||
|
||||
try:
|
||||
config = Config()
|
||||
assert config.ollama_host == "env-host:1234"
|
||||
assert config.ollama_model == "env-model"
|
||||
finally:
|
||||
del os.environ["OLLAMA_HOST"]
|
||||
del os.environ["OLLAMA_MODEL"]
|
||||
|
||||
def test_get_nested_value(self):
|
||||
"""Test getting nested configuration values."""
|
||||
config = Config()
|
||||
timeout = config.get("ollama.timeout")
|
||||
assert timeout == 120 or isinstance(timeout, (int, type(None)))
|
||||
|
||||
def test_get_missing_key(self):
|
||||
"""Test getting missing key returns default."""
|
||||
config = Config()
|
||||
value = config.get("nonexistent.key", "default")
|
||||
assert value == "default"
|
||||
|
||||
def test_merge_configs(self):
|
||||
"""Test merging user config with defaults."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
config_path = Path(tmpdir) / "config.yaml"
|
||||
with open(config_path, "w") as f:
|
||||
yaml.dump({"ollama": {"model": "llama2"}}, f)
|
||||
|
||||
config = Config(str(config_path))
|
||||
|
||||
assert config.ollama_model == "llama2"
|
||||
assert config.ollama_host == "localhost:11434"
|
||||
|
||||
|
||||
class TestGetConfig:
|
||||
"""Test get_config convenience function."""
|
||||
|
||||
def test_get_config_singleton(self):
|
||||
"""Test get_config returns Config instance."""
|
||||
config = get_config()
|
||||
assert isinstance(config, Config)
|
||||
81
.tests/test_explainer.py
Normal file
81
.tests/test_explainer.py
Normal file
@@ -0,0 +1,81 @@
|
||||
"""Tests for ShellGenius explainer module."""
|
||||
|
||||
from shellgenius.explainer import ShellExplainer, explain_script
|
||||
|
||||
|
||||
class TestShellExplainer:
|
||||
"""Test shell script explainer."""
|
||||
|
||||
def test_explainer_initialization(self):
|
||||
"""Test explainer creates properly."""
|
||||
explainer = ShellExplainer()
|
||||
assert explainer.parser is not None
|
||||
|
||||
def test_basic_explain(self):
|
||||
"""Test basic script explanation."""
|
||||
script = "#!/bin/bash
|
||||
echo \"Hello, World!\"
|
||||
ls -la
|
||||
"
|
||||
result = explain_script(script, detailed=False)
|
||||
|
||||
assert result.shell_type == "bash"
|
||||
assert len(result.line_explanations) > 0
|
||||
assert result.overall_purpose != ""
|
||||
|
||||
def test_detect_keywords(self):
|
||||
"""Test keyword detection."""
|
||||
explainer = ShellExplainer()
|
||||
|
||||
if_explanation = explainer._explain_line_basic("if [ -f file ]; then")
|
||||
assert "conditional" in if_explanation.lower()
|
||||
|
||||
for_explanation = explainer._explain_line_basic("for i in 1 2 3; do")
|
||||
assert "loop" in for_explanation.lower()
|
||||
|
||||
def test_common_patterns(self):
|
||||
"""Test common pattern detection."""
|
||||
explainer = ShellExplainer()
|
||||
|
||||
shebang_explanation = explainer._explain_line_basic("#!/bin/bash")
|
||||
assert "shebang" in shebang_explanation.lower()
|
||||
|
||||
pipe_explanation = explainer._explain_line_basic("cat file | grep pattern")
|
||||
assert "pipe" in pipe_explanation.lower()
|
||||
|
||||
def test_generate_summary(self):
|
||||
"""Test summary generation."""
|
||||
explainer = ShellExplainer()
|
||||
|
||||
from shellgenius.explainer import LineExplanation
|
||||
|
||||
explanations = [
|
||||
LineExplanation(1, "cmd1", "command", True),
|
||||
LineExplanation(2, "cmd2", "command", True),
|
||||
LineExplanation(3, "function test()", "function", True),
|
||||
]
|
||||
|
||||
summary = explainer._generate_summary(explanations, "bash")
|
||||
assert "bash" in summary
|
||||
assert "function" in summary.lower()
|
||||
|
||||
def test_detect_purpose(self):
|
||||
"""Test purpose detection."""
|
||||
explainer = ShellExplainer()
|
||||
|
||||
from shellgenius.explainer import LineExplanation
|
||||
|
||||
git_explanations = [
|
||||
LineExplanation(1, "git status", "command", True),
|
||||
LineExplanation(2, "git commit -m", "command", True),
|
||||
]
|
||||
|
||||
purpose = explainer._detect_purpose(git_explanations)
|
||||
assert "Git" in purpose
|
||||
|
||||
def test_explain_script_function(self):
|
||||
"""Test convenience function."""
|
||||
result = explain_script("echo test", detailed=False)
|
||||
assert result is not None
|
||||
assert hasattr(result, "shell_type")
|
||||
assert hasattr(result, "line_explanations")
|
||||
144
.tests/test_generation.py
Normal file
144
.tests/test_generation.py
Normal file
@@ -0,0 +1,144 @@
|
||||
"""Tests for ShellGenius generation module."""
|
||||
|
||||
from shellgenius.generation import (
|
||||
PromptTemplates,
|
||||
ShellGenerator,
|
||||
ShellParser,
|
||||
ShellSafetyChecker,
|
||||
)
|
||||
|
||||
|
||||
class TestPromptTemplates:
|
||||
"""Test prompt template generation."""
|
||||
|
||||
def test_generate_prompt(self):
|
||||
"""Test generation prompt template."""
|
||||
prompt = PromptTemplates.get_generate_prompt(
|
||||
"list files", shell_type="bash"
|
||||
)
|
||||
assert "list files" in prompt
|
||||
assert "bash" in prompt
|
||||
assert "Script:" in prompt
|
||||
|
||||
def test_explain_prompt(self):
|
||||
"""Test explanation prompt template."""
|
||||
script = "echo hello"
|
||||
prompt = PromptTemplates.get_explain_prompt(script, "bash")
|
||||
assert "echo hello" in prompt
|
||||
assert "bash" in prompt
|
||||
|
||||
def test_refactor_prompt(self):
|
||||
"""Test refactor prompt template."""
|
||||
script = "rm -rf /tmp/*"
|
||||
prompt = PromptTemplates.get_refactor_prompt(script, "bash")
|
||||
assert "rm -rf /tmp/*" in prompt
|
||||
|
||||
|
||||
class TestShellParser:
|
||||
"""Test shell script parsing."""
|
||||
|
||||
def test_detect_bash(self):
|
||||
"""Test bash detection from shebang."""
|
||||
script = "#!/bin/bash\necho hello"
|
||||
assert ShellParser.detect_shell(script) == "bash"
|
||||
|
||||
def test_detect_zsh(self):
|
||||
"""Test zsh detection from shebang."""
|
||||
script = "#!/bin/zsh\necho hello"
|
||||
assert ShellParser.detect_shell(script) == "zsh"
|
||||
|
||||
def test_detect_sh(self):
|
||||
"""Test sh detection from shebang."""
|
||||
script = "#!/bin/sh\necho hello"
|
||||
assert ShellParser.detect_shell(script) == "sh"
|
||||
|
||||
def test_detect_default(self):
|
||||
"""Test default bash detection."""
|
||||
script = "echo hello"
|
||||
assert ShellParser.detect_shell(script) == "bash"
|
||||
|
||||
def test_parse_lines(self):
|
||||
"""Test line parsing."""
|
||||
script = "line1\nline2\nline3"
|
||||
lines = ShellParser.parse_lines(script)
|
||||
assert len(lines) == 3
|
||||
assert lines[0] == (1, "line1")
|
||||
assert lines[1] == (2, "line2")
|
||||
|
||||
def test_extract_commands(self):
|
||||
"""Test command extraction."""
|
||||
script = "#!/bin/bash
|
||||
# This is a comment
|
||||
echo hello
|
||||
: empty command
|
||||
echo world
|
||||
"
|
||||
commands = ShellParser.extract_commands(script)
|
||||
assert len(commands) == 2
|
||||
assert "echo hello" in commands[0]
|
||||
assert "echo world" in commands[1]
|
||||
|
||||
|
||||
class TestShellSafetyChecker:
|
||||
"""Test shell safety checking."""
|
||||
|
||||
def test_dangerous_command(self):
|
||||
"""Test dangerous command detection."""
|
||||
checker = ShellSafetyChecker("moderate")
|
||||
is_safe, warnings = checker.check_command("rm -rf /")
|
||||
assert not is_safe
|
||||
|
||||
def test_safe_command(self):
|
||||
"""Test safe command passes."""
|
||||
checker = ShellSafetyChecker("moderate")
|
||||
is_safe, warnings = checker.check_command("ls -la")
|
||||
assert is_safe
|
||||
|
||||
def test_warning_command(self):
|
||||
"""Test warning-level commands."""
|
||||
checker = ShellSafetyChecker("moderate")
|
||||
is_safe, warnings = checker.check_command("rm -r /tmp/*")
|
||||
assert is_safe
|
||||
assert len(warnings) > 0
|
||||
|
||||
def test_script_check(self):
|
||||
"""Test full script safety check."""
|
||||
checker = ShellSafetyChecker()
|
||||
script = "#!/bin/bash
|
||||
ls
|
||||
echo done
|
||||
"
|
||||
result = checker.check_script(script)
|
||||
assert "is_safe" in result
|
||||
assert "issues" in result
|
||||
assert "warnings" in result
|
||||
|
||||
def test_strict_mode_blocks_warnings(self):
|
||||
"""Test strict mode blocks warning-level commands."""
|
||||
checker = ShellSafetyChecker("strict")
|
||||
is_safe, warnings = checker.check_command("rm -r /tmp/*")
|
||||
assert not is_safe
|
||||
|
||||
|
||||
class TestShellGenerator:
|
||||
"""Test shell generation."""
|
||||
|
||||
def test_generator_initialization(self):
|
||||
"""Test generator creates properly."""
|
||||
generator = ShellGenerator()
|
||||
assert generator.parser is not None
|
||||
assert generator.templates is not None
|
||||
|
||||
def test_extract_commands_from_raw(self):
|
||||
"""Test command extraction from raw response."""
|
||||
generator = ShellGenerator()
|
||||
raw = "echo hello\nls -la\ncat file.txt"
|
||||
commands = generator._extract_commands(raw)
|
||||
assert len(commands) == 3
|
||||
assert "echo hello" in commands
|
||||
|
||||
def test_generate_summary(self):
|
||||
"""Test summary generation."""
|
||||
generator = ShellGenerator()
|
||||
summary = generator._generate_summary(["cmd1", "cmd2"], "raw")
|
||||
assert "2 command" in summary
|
||||
236
.tests/test_history.py
Normal file
236
.tests/test_history.py
Normal file
@@ -0,0 +1,236 @@
|
||||
"""Tests for ShellGenius history module."""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
from shellgenius.history import HistoryEntry, HistoryLearner, HistoryStorage
|
||||
|
||||
|
||||
class TestHistoryEntry:
|
||||
"""Test history entry dataclass."""
|
||||
|
||||
def test_entry_creation(self):
|
||||
"""Test creating a history entry."""
|
||||
entry = HistoryEntry(
|
||||
id="test-id",
|
||||
timestamp="2024-01-01T00:00:00",
|
||||
description="Test description",
|
||||
commands=["cmd1", "cmd2"],
|
||||
shell_type="bash",
|
||||
)
|
||||
assert entry.id == "test-id"
|
||||
assert len(entry.commands) == 2
|
||||
assert entry.usage_count == 1
|
||||
|
||||
|
||||
class TestHistoryStorage:
|
||||
"""Test history storage."""
|
||||
|
||||
def test_add_and_get_entry(self):
|
||||
"""Test adding and retrieving entries."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
storage_path = os.path.join(tmpdir, "history.yaml")
|
||||
storage = HistoryStorage(storage_path)
|
||||
|
||||
entry = HistoryEntry(
|
||||
id="test-id",
|
||||
timestamp="2024-01-01T00:00:00",
|
||||
description="Test",
|
||||
commands=["echo hello"],
|
||||
shell_type="bash",
|
||||
)
|
||||
|
||||
storage.add_entry(entry)
|
||||
entries = storage.get_entries()
|
||||
|
||||
assert len(entries) == 1
|
||||
assert entries[0].description == "Test"
|
||||
|
||||
def test_search_entries(self):
|
||||
"""Test searching entries."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
storage_path = os.path.join(tmpdir, "history.yaml")
|
||||
storage = HistoryStorage(storage_path)
|
||||
|
||||
entry = HistoryEntry(
|
||||
id="test-id",
|
||||
timestamp="2024-01-01T00:00:00",
|
||||
description="Find Python files",
|
||||
commands=["find . -name *.py"],
|
||||
shell_type="bash",
|
||||
)
|
||||
|
||||
storage.add_entry(entry)
|
||||
results = storage.search("python", limit=10)
|
||||
|
||||
assert len(results) == 1
|
||||
|
||||
def test_get_popular(self):
|
||||
"""Test getting popular entries."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
storage_path = os.path.join(tmpdir, "history.yaml")
|
||||
storage = HistoryStorage(storage_path)
|
||||
|
||||
entry = HistoryEntry(
|
||||
id="test-id",
|
||||
timestamp="2024-01-01T00:00:00",
|
||||
description="Test",
|
||||
commands=["echo hello"],
|
||||
shell_type="bash",
|
||||
usage_count=5,
|
||||
)
|
||||
|
||||
storage.add_entry(entry)
|
||||
popular = storage.get_popular(limit=10)
|
||||
|
||||
assert len(popular) == 1
|
||||
assert popular[0].usage_count == 5
|
||||
|
||||
def test_update_usage(self):
|
||||
"""Test updating usage count."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
storage_path = os.path.join(tmpdir, "history.yaml")
|
||||
storage = HistoryStorage(storage_path)
|
||||
|
||||
entry = HistoryEntry(
|
||||
id="test-id",
|
||||
timestamp="2024-01-01T00:00:00",
|
||||
description="Test",
|
||||
commands=["echo hello"],
|
||||
shell_type="bash",
|
||||
)
|
||||
|
||||
storage.add_entry(entry)
|
||||
storage.update_usage("test-id")
|
||||
|
||||
entries = storage.get_entries()
|
||||
assert entries[0].usage_count == 2
|
||||
|
||||
def test_delete_entry(self):
|
||||
"""Test deleting entry."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
storage_path = os.path.join(tmpdir, "history.yaml")
|
||||
storage = HistoryStorage(storage_path)
|
||||
|
||||
entry = HistoryEntry(
|
||||
id="test-id",
|
||||
timestamp="2024-01-01T00:00:00",
|
||||
description="Test",
|
||||
commands=["echo hello"],
|
||||
shell_type="bash",
|
||||
)
|
||||
|
||||
storage.add_entry(entry)
|
||||
storage.delete_entry("test-id")
|
||||
|
||||
entries = storage.get_entries()
|
||||
assert len(entries) == 0
|
||||
|
||||
def test_clear_history(self):
|
||||
"""Test clearing all history."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
storage_path = os.path.join(tmpdir, "history.yaml")
|
||||
storage = HistoryStorage(storage_path)
|
||||
|
||||
for i in range(3):
|
||||
entry = HistoryEntry(
|
||||
id=f"test-{i}",
|
||||
timestamp="2024-01-01T00:00:00",
|
||||
description=f"Test {i}",
|
||||
commands=[f"echo {i}"],
|
||||
shell_type="bash",
|
||||
)
|
||||
storage.add_entry(entry)
|
||||
|
||||
storage.clear()
|
||||
entries = storage.get_entries()
|
||||
|
||||
assert len(entries) == 0
|
||||
|
||||
def test_limit_and_offset(self):
|
||||
"""Test limit and offset for retrieval."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
storage_path = os.path.join(tmpdir, "history.yaml")
|
||||
storage = HistoryStorage(storage_path)
|
||||
|
||||
for i in range(5):
|
||||
entry = HistoryEntry(
|
||||
id=f"test-{i}",
|
||||
timestamp="2024-01-01T00:00:00",
|
||||
description=f"Test {i}",
|
||||
commands=[f"echo {i}"],
|
||||
shell_type="bash",
|
||||
)
|
||||
storage.add_entry(entry)
|
||||
|
||||
entries = storage.get_entries(limit=2, offset=2)
|
||||
assert len(entries) == 2
|
||||
|
||||
|
||||
class TestHistoryLearner:
|
||||
"""Test history learning functionality."""
|
||||
|
||||
def test_learn(self):
|
||||
"""Test learning from commands."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
storage_path = os.path.join(tmpdir, "history.yaml")
|
||||
learner = HistoryLearner()
|
||||
learner.storage = HistoryStorage(storage_path)
|
||||
|
||||
entry = learner.learn(
|
||||
description="List files",
|
||||
commands=["ls -la"],
|
||||
shell_type="bash",
|
||||
)
|
||||
|
||||
assert entry is not None
|
||||
assert "List files" in entry.description
|
||||
|
||||
def test_find_similar(self):
|
||||
"""Test finding similar commands."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
storage_path = os.path.join(tmpdir, "history.yaml")
|
||||
learner = HistoryLearner()
|
||||
learner.storage = HistoryStorage(storage_path)
|
||||
|
||||
learner.learn(
|
||||
description="Find Python files",
|
||||
commands=["find . -name *.py"],
|
||||
shell_type="bash",
|
||||
)
|
||||
|
||||
similar = learner.find_similar("python", limit=5)
|
||||
assert len(similar) == 1
|
||||
|
||||
def test_suggest(self):
|
||||
"""Test getting suggestions."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
storage_path = os.path.join(tmpdir, "history.yaml")
|
||||
learner = HistoryLearner()
|
||||
learner.storage = HistoryStorage(storage_path)
|
||||
|
||||
learner.learn(
|
||||
description="Find Python files",
|
||||
commands=["find . -name *.py"],
|
||||
shell_type="bash",
|
||||
)
|
||||
|
||||
suggestions = learner.suggest("python files", limit=3)
|
||||
assert len(suggestions) >= 1
|
||||
|
||||
def test_get_frequent_patterns(self):
|
||||
"""Test analyzing frequent patterns."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
storage_path = os.path.join(tmpdir, "history.yaml")
|
||||
learner = HistoryLearner()
|
||||
learner.storage = HistoryStorage(storage_path)
|
||||
|
||||
learner.learn(
|
||||
description="Find files",
|
||||
commands=["find . -type f"],
|
||||
shell_type="bash",
|
||||
)
|
||||
|
||||
patterns = learner.get_frequent_patterns()
|
||||
assert "common_commands" in patterns
|
||||
assert "shell_types" in patterns
|
||||
129
.tests/test_integration.py
Normal file
129
.tests/test_integration.py
Normal file
@@ -0,0 +1,129 @@
|
||||
"""Integration tests for ShellGenius."""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
from shellgenius.cli import main
|
||||
from shellgenius.config import get_config
|
||||
from shellgenius.generation import generate_shell
|
||||
from shellgenius.explainer import explain_script
|
||||
from shellgenius.refactoring import refactor_script
|
||||
from shellgenius.history import get_history_storage
|
||||
|
||||
|
||||
class TestConfigIntegration:
|
||||
"""Integration tests for configuration."""
|
||||
|
||||
def test_config_loads(self):
|
||||
"""Test configuration loads successfully."""
|
||||
config = get_config()
|
||||
assert config is not None
|
||||
assert hasattr(config, "ollama_host")
|
||||
assert hasattr(config, "ollama_model")
|
||||
|
||||
|
||||
class TestGenerationIntegration:
|
||||
"""Integration tests for shell generation."""
|
||||
|
||||
def test_generate_returns_script(self):
|
||||
"""Test generation returns expected structure."""
|
||||
result = generate_shell("list files in current directory", shell_type="bash")
|
||||
assert result is not None
|
||||
assert hasattr(result, "commands")
|
||||
assert hasattr(result, "explanation")
|
||||
assert hasattr(result, "shell_type")
|
||||
assert result.shell_type == "bash"
|
||||
|
||||
|
||||
class TestExplainerIntegration:
|
||||
"""Integration tests for script explanation."""
|
||||
|
||||
def test_explain_returns_structure(self):
|
||||
"""Test explanation returns expected structure."""
|
||||
script = "#!/bin/bash
|
||||
echo \"Hello, World!\"
|
||||
ls -la
|
||||
"
|
||||
result = explain_script(script, detailed=False)
|
||||
assert result is not None
|
||||
assert hasattr(result, "shell_type")
|
||||
assert hasattr(result, "line_explanations")
|
||||
assert hasattr(result, "overall_purpose")
|
||||
|
||||
|
||||
class TestRefactoringIntegration:
|
||||
"""Integration tests for refactoring."""
|
||||
|
||||
def test_refactor_returns_structure(self):
|
||||
"""Test refactoring returns expected structure."""
|
||||
script = "#!/bin/bash
|
||||
echo \"Hello\"
|
||||
ls
|
||||
"
|
||||
result = refactor_script(script, include_suggestions=False)
|
||||
assert result is not None
|
||||
assert hasattr(result, "shell_type")
|
||||
assert hasattr(result, "issues")
|
||||
assert hasattr(result, "score")
|
||||
assert hasattr(result, "suggestions")
|
||||
|
||||
|
||||
class TestHistoryIntegration:
|
||||
"""Integration tests for history."""
|
||||
|
||||
def test_history_storage(self):
|
||||
"""Test history storage operations."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
storage_path = os.path.join(tmpdir, "history.yaml")
|
||||
storage = get_history_storage(storage_path)
|
||||
|
||||
from shellgenius.history import HistoryEntry
|
||||
|
||||
entry = HistoryEntry(
|
||||
id="test-id",
|
||||
timestamp="2024-01-01T00:00:00",
|
||||
description="Test command",
|
||||
commands=["echo test"],
|
||||
shell_type="bash",
|
||||
)
|
||||
|
||||
storage.add_entry(entry)
|
||||
entries = storage.get_entries()
|
||||
|
||||
assert len(entries) == 1
|
||||
assert entries[0].description == "Test command"
|
||||
|
||||
|
||||
class TestCLIIntegration:
|
||||
"""Integration tests for CLI."""
|
||||
|
||||
def test_cli_group_exists(self):
|
||||
"""Test CLI main group exists."""
|
||||
assert main is not None
|
||||
assert hasattr(main, 'commands')
|
||||
|
||||
def test_cli_subcommands(self):
|
||||
"""Test CLI has expected subcommands."""
|
||||
commands = list(main.commands.keys())
|
||||
assert "generate" in commands
|
||||
assert "explain" in commands
|
||||
assert "refactor" in commands
|
||||
assert "history" in commands
|
||||
assert "models" in commands
|
||||
|
||||
|
||||
class TestFullWorkflowIntegration:
|
||||
"""Full workflow integration tests."""
|
||||
|
||||
def test_generate_explain_workflow(self):
|
||||
"""Test generating and explaining a script."""
|
||||
script = "#!/bin/bash\necho hello"
|
||||
result = explain_script(script, detailed=False)
|
||||
assert result is not None
|
||||
|
||||
def test_refactor_with_suggestions(self):
|
||||
"""Test refactoring with AI suggestions."""
|
||||
script = "#!/bin/bash\necho test"
|
||||
result = refactor_script(script, include_suggestions=True)
|
||||
assert result is not None
|
||||
assert hasattr(result, "suggestions")
|
||||
131
.tests/test_ollama_client.py
Normal file
131
.tests/test_ollama_client.py
Normal file
@@ -0,0 +1,131 @@
|
||||
"""Tests for Ollama client module."""
|
||||
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
|
||||
from shellgenius.ollama_client import OllamaClient, get_ollama_client
|
||||
|
||||
|
||||
class TestOllamaClient:
|
||||
"""Test Ollama client functionality."""
|
||||
|
||||
@patch('shellgenius.ollama_client.ollama.Client')
|
||||
def test_client_initialization(self, mock_client_class):
|
||||
"""Test client initialization."""
|
||||
mock_client = Mock()
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
client = OllamaClient(host="localhost:11434", model="codellama")
|
||||
|
||||
assert client.host == "localhost:11434"
|
||||
assert client.model == "codellama"
|
||||
|
||||
@patch('shellgenius.ollama_client.ollama.Client')
|
||||
def test_list_models(self, mock_client_class):
|
||||
"""Test listing models."""
|
||||
mock_client = Mock()
|
||||
mock_client.list.return_value = {
|
||||
"models": [
|
||||
{"name": "codellama:latest"},
|
||||
{"name": "llama2:latest"},
|
||||
]
|
||||
}
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
client = OllamaClient()
|
||||
models = client.list_models()
|
||||
|
||||
assert len(models) == 2
|
||||
assert "codellama:latest" in models
|
||||
|
||||
@patch('shellgenius.ollama_client.ollama.Client')
|
||||
def test_generate_success(self, mock_client_class):
|
||||
"""Test successful generation."""
|
||||
mock_client = Mock()
|
||||
mock_client.generate.return_value = {
|
||||
"response": "echo hello"
|
||||
}
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
client = OllamaClient()
|
||||
result = client.generate("test prompt")
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["response"]["response"] == "echo hello"
|
||||
|
||||
@patch('shellgenius.ollama_client.ollama.Client')
|
||||
def test_generate_failure(self, mock_client_class):
|
||||
"""Test failed generation."""
|
||||
mock_client = Mock()
|
||||
mock_client.generate.side_effect = Exception("Connection failed")
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
client = OllamaClient()
|
||||
result = client.generate("test prompt")
|
||||
|
||||
assert result["success"] is False
|
||||
assert "error" in result
|
||||
|
||||
@patch('shellgenius.ollama_client.ollama.Client')
|
||||
def test_chat_success(self, mock_client_class):
|
||||
"""Test successful chat."""
|
||||
mock_client = Mock()
|
||||
mock_client.chat.return_value = {
|
||||
"message": {"content": "Hello!"}
|
||||
}
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
client = OllamaClient()
|
||||
messages = [{"role": "user", "content": "Hi"}]
|
||||
result = client.chat(messages)
|
||||
|
||||
assert result["success"] is True
|
||||
|
||||
@patch('shellgenius.ollama_client.ollama.Client')
|
||||
def test_pull_model(self, mock_client_class):
|
||||
"""Test pulling a model."""
|
||||
mock_client = Mock()
|
||||
mock_client.pull.return_value = {"status": "success"}
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
client = OllamaClient()
|
||||
result = client.pull_model("llama2")
|
||||
|
||||
assert result is True
|
||||
mock_client.pull.assert_called_once_with("llama2")
|
||||
|
||||
@patch('shellgenius.ollama_client.ollama.Client')
|
||||
def test_is_available_true(self, mock_client_class):
|
||||
"""Test availability check when available."""
|
||||
mock_client = Mock()
|
||||
mock_client.list.return_value = {"models": []}
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
client = OllamaClient()
|
||||
assert client.is_available() is True
|
||||
|
||||
@patch('shellgenius.ollama_client.ollama.Client')
|
||||
def test_is_available_false(self, mock_client_class):
|
||||
"""Test availability check when not available."""
|
||||
mock_instance = MagicMock()
|
||||
mock_instance.list.side_effect = Exception("Connection refused")
|
||||
mock_client_class.return_value = mock_instance
|
||||
|
||||
with patch.object(OllamaClient, 'list_models', side_effect=Exception("Connection refused")):
|
||||
client = OllamaClient()
|
||||
result = client.is_available()
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestGetOllamaClient:
|
||||
"""Test get_ollama_client convenience function."""
|
||||
|
||||
@patch('shellgenius.ollama_client.ollama.Client')
|
||||
def test_get_client(self, mock_client_class):
|
||||
"""Test getting client instance."""
|
||||
mock_client = Mock()
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
client = get_ollama_client(host="localhost:11434")
|
||||
|
||||
assert isinstance(client, OllamaClient)
|
||||
assert client.host == "localhost:11434"
|
||||
143
.tests/test_refactoring.py
Normal file
143
.tests/test_refactoring.py
Normal file
@@ -0,0 +1,143 @@
|
||||
"""Tests for ShellGenius refactoring module."""
|
||||
|
||||
from shellgenius.refactoring import (
|
||||
RefactoringAnalyzer,
|
||||
RefactoringResult,
|
||||
SecurityRulesDB,
|
||||
refactor_script,
|
||||
)
|
||||
|
||||
|
||||
class TestSecurityRulesDB:
|
||||
"""Test security rules database."""
|
||||
|
||||
def test_get_rules(self):
|
||||
"""Test getting all rules."""
|
||||
rules = SecurityRulesDB.get_rules()
|
||||
assert isinstance(rules, list)
|
||||
assert len(rules) > 0
|
||||
|
||||
def test_check_rule_match(self):
|
||||
"""Test rule matching."""
|
||||
rule = SecurityRulesDB.check_rule("rm $USER_VAR")
|
||||
assert rule is not None
|
||||
assert rule["severity"] == "high"
|
||||
|
||||
def test_check_rule_no_match(self):
|
||||
"""Test no false positives."""
|
||||
rule = SecurityRulesDB.check_rule("ls -la")
|
||||
assert rule is None
|
||||
|
||||
def test_chmod_777_detection(self):
|
||||
"""Test chmod 777 detection."""
|
||||
rule = SecurityRulesDB.check_rule("chmod 777 /path")
|
||||
assert rule is not None
|
||||
assert "777" in rule["pattern"]
|
||||
|
||||
def test_eval_detection(self):
|
||||
"""Test eval with variable detection."""
|
||||
rule = SecurityRulesDB.check_rule("eval $USER_INPUT")
|
||||
assert rule is not None
|
||||
|
||||
|
||||
class TestRefactoringAnalyzer:
|
||||
"""Test refactoring analyzer."""
|
||||
|
||||
def test_analyzer_initialization(self):
|
||||
"""Test analyzer creates properly."""
|
||||
analyzer = RefactoringAnalyzer()
|
||||
assert analyzer.parser is not None
|
||||
assert analyzer.rules_db is not None
|
||||
|
||||
def test_analyze_safe_script(self):
|
||||
"""Test analyzing safe script."""
|
||||
analyzer = RefactoringAnalyzer()
|
||||
script = "#!/bin/bash
|
||||
echo \"Hello\"
|
||||
ls -la
|
||||
"
|
||||
result = analyzer.analyze(script, include_suggestions=False)
|
||||
|
||||
assert isinstance(result, RefactoringResult)
|
||||
assert result.shell_type == "bash"
|
||||
assert result.score > 0
|
||||
|
||||
def test_calculate_score(self):
|
||||
"""Test score calculation."""
|
||||
analyzer = RefactoringAnalyzer()
|
||||
|
||||
from shellgenius.refactoring import RefactoringIssue
|
||||
|
||||
issues = [
|
||||
RefactoringIssue(
|
||||
line_number=1,
|
||||
original="rm -rf /",
|
||||
issue_type="CWE-78",
|
||||
severity="high",
|
||||
description="Test",
|
||||
risk_assessment="Test",
|
||||
suggestion="Test",
|
||||
safer_alternative="",
|
||||
)
|
||||
]
|
||||
|
||||
score = analyzer._calculate_score(issues, "rm -rf /")
|
||||
assert score < 100
|
||||
|
||||
def test_generate_suggestions(self):
|
||||
"""Test suggestion generation."""
|
||||
analyzer = RefactoringAnalyzer()
|
||||
|
||||
from shellgenius.refactoring import RefactoringIssue
|
||||
|
||||
issues = [
|
||||
RefactoringIssue(
|
||||
line_number=1,
|
||||
original="rm -rf /",
|
||||
issue_type="Test",
|
||||
severity="high",
|
||||
description="Test",
|
||||
risk_assessment="Test",
|
||||
suggestion="Test",
|
||||
safer_alternative="",
|
||||
)
|
||||
]
|
||||
|
||||
suggestions = analyzer._generate_suggestions(issues, "script")
|
||||
assert isinstance(suggestions, list)
|
||||
assert len(suggestions) > 0
|
||||
|
||||
def test_check_improvements(self):
|
||||
"""Test code quality improvement detection."""
|
||||
analyzer = RefactoringAnalyzer()
|
||||
|
||||
backtick_result = analyzer._check_improvements("echo `date`")
|
||||
assert backtick_result is not None
|
||||
assert "backtick" in backtick_result["description"].lower()
|
||||
|
||||
def test_refactor_script_function(self):
|
||||
"""Test convenience function."""
|
||||
result = refactor_script("echo test", include_suggestions=False)
|
||||
assert result is not None
|
||||
assert hasattr(result, "shell_type")
|
||||
assert hasattr(result, "issues")
|
||||
assert hasattr(result, "score")
|
||||
assert hasattr(result, "suggestions")
|
||||
|
||||
|
||||
class TestRefactoringResult:
|
||||
"""Test refactoring result dataclass."""
|
||||
|
||||
def test_result_creation(self):
|
||||
"""Test creating result."""
|
||||
result = RefactoringResult(
|
||||
shell_type="bash",
|
||||
issues=[],
|
||||
score=100,
|
||||
suggestions=["Use set -e"],
|
||||
safer_script="#!/bin/bash\necho test",
|
||||
)
|
||||
|
||||
assert result.shell_type == "bash"
|
||||
assert result.score == 100
|
||||
assert len(result.suggestions) == 1
|
||||
2
LICENSE
2
LICENSE
@@ -1,6 +1,6 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2024 ShellGenius Contributors
|
||||
Copyright (c) 2024 ScaffoldForge Contributors
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
|
||||
@@ -41,6 +41,7 @@ dev = [
|
||||
"pytest-cov>=4.0.0",
|
||||
"pytest-asyncio>=0.21.0",
|
||||
"ruff>=0.1.0",
|
||||
"mypy>=1.0.0",
|
||||
]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
|
||||
@@ -1,22 +1,24 @@
|
||||
"""CLI interface for ShellGenius."""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Optional
|
||||
|
||||
import click
|
||||
from prompt_toolkit import PromptSession
|
||||
from prompt_toolkit.completion import WordCompleter
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.table import Table
|
||||
|
||||
from shellgenius.config import get_config
|
||||
from shellgenius.explainer import ShellExplainer, explain_script
|
||||
from shellgenius.explainer import explain_script
|
||||
from shellgenius.generation import ShellSafetyChecker, generate_shell
|
||||
from shellgenius.history import HistoryLearner, get_history_storage
|
||||
from shellgenius.ollama_client import get_ollama_client
|
||||
from shellgenius.refactoring import RefactoringAnalyzer, refactor_script
|
||||
from shellgenius.refactoring import refactor_script
|
||||
|
||||
console = Console()
|
||||
session: PromptSession[str] = PromptSession()
|
||||
|
||||
|
||||
def print_header():
|
||||
@@ -309,11 +311,10 @@ def interactive(ctx: click.Context):
|
||||
|
||||
while True:
|
||||
try:
|
||||
choice = console.ask(
|
||||
choice = session.prompt(
|
||||
"[bold cyan]ShellGenius[/bold cyan] > ",
|
||||
choices=["g", "e", "r", "h", "m", "q", "?"],
|
||||
default="?",
|
||||
)
|
||||
completer=WordCompleter(["g", "e", "r", "h", "m", "q", "?"]),
|
||||
).strip() or "?"
|
||||
|
||||
if choice in ["q", "quit", "exit"]:
|
||||
console.print("[cyan]Goodbye![/cyan]")
|
||||
@@ -332,7 +333,7 @@ def interactive(ctx: click.Context):
|
||||
)
|
||||
)
|
||||
elif choice == "g":
|
||||
desc = console.ask("[cyan]Describe what you want:[/cyan]")
|
||||
desc = session.prompt("[cyan]Describe what you want:[/cyan]")
|
||||
if desc:
|
||||
ctx.invoke(
|
||||
generate,
|
||||
@@ -340,11 +341,11 @@ def interactive(ctx: click.Context):
|
||||
shell="bash",
|
||||
)
|
||||
elif choice == "e":
|
||||
path = console.ask("[cyan]Path to script:[/cyan]")
|
||||
path = session.prompt("[cyan]Path to script:[/cyan]")
|
||||
if path and os.path.exists(path):
|
||||
ctx.invoke(explain, script_path=path)
|
||||
elif choice == "r":
|
||||
path = console.ask("[cyan]Path to script:[/cyan]")
|
||||
path = session.prompt("[cyan]Path to script:[/cyan]")
|
||||
if path and os.path.exists(path):
|
||||
ctx.invoke(refactor, script_path=path, show_safe=True)
|
||||
elif choice == "h":
|
||||
|
||||
@@ -2,9 +2,8 @@
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import List
|
||||
|
||||
from shellgenius.config import get_config
|
||||
from shellgenius.generation import ShellParser, get_ollama_client
|
||||
|
||||
|
||||
|
||||
@@ -321,7 +321,7 @@ class ShellSafetyChecker:
|
||||
"issues": issues,
|
||||
"warnings": warnings,
|
||||
"safe_lines": safe_lines,
|
||||
"total_lines": len([l for l in lines if l.strip() and not l.strip().startswith("#")]),
|
||||
"total_lines": len([line for line in lines if line.strip() and not line.strip().startswith("#")]),
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
"""Command history learning system for ShellGenius."""
|
||||
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
@@ -154,9 +153,9 @@ class HistoryStorage:
|
||||
description=entry["description"],
|
||||
commands=entry["commands"],
|
||||
shell_type=entry["shell_type"],
|
||||
usage_count=e.get("usage_count", 1),
|
||||
tags=e.get("tags", []),
|
||||
success=e.get("success", True),
|
||||
usage_count=entry.get("usage_count", 1),
|
||||
tags=entry.get("tags", []),
|
||||
success=entry.get("success", True),
|
||||
)
|
||||
)
|
||||
if len(results) >= limit:
|
||||
@@ -316,7 +315,7 @@ class HistoryLearner:
|
||||
"""
|
||||
entries = self.storage.get_entries(limit=100)
|
||||
|
||||
patterns = {
|
||||
patterns: Dict[str, Dict[str, int]] = {
|
||||
"shell_types": {},
|
||||
"common_commands": {},
|
||||
"frequent_tags": {},
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
"""Ollama client wrapper for ShellGenius."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Dict, Generator, List, Optional
|
||||
|
||||
@@ -96,7 +95,7 @@ class OllamaClient:
|
||||
"""
|
||||
model = model or self.model
|
||||
try:
|
||||
response = self.client.generate(
|
||||
response = self.client.generate( # type: ignore[call-overload]
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
stream=stream,
|
||||
@@ -147,7 +146,7 @@ class OllamaClient:
|
||||
"""
|
||||
model = model or self.model
|
||||
try:
|
||||
response = self.client.chat(
|
||||
response = self.client.chat( # type: ignore[call-overload]
|
||||
model=model,
|
||||
messages=messages,
|
||||
stream=stream,
|
||||
|
||||
@@ -4,7 +4,6 @@ import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from shellgenius.config import get_config
|
||||
from shellgenius.generation import ShellParser, get_ollama_client
|
||||
|
||||
|
||||
@@ -173,7 +172,7 @@ class RefactoringAnalyzer:
|
||||
issues = self._find_issues(script)
|
||||
score = self._calculate_score(issues, script)
|
||||
|
||||
suggestions = self._generate_suggestions(issues, script)
|
||||
suggestions = self._generate_suggestions( issues, script)
|
||||
|
||||
if include_suggestions:
|
||||
ai_suggestions = self._get_ai_suggestions(script, shell_type)
|
||||
@@ -256,7 +255,7 @@ class RefactoringAnalyzer:
|
||||
"suggestion": "Use $() syntax instead",
|
||||
"alternative": re.sub(r"`([^`]+)`", r"$(\1)", line),
|
||||
}
|
||||
if re.search(r"\[\s*[\^\]]+\]", line) and "=" in line:
|
||||
if re.search(r"\[\s*[\]]+\]", line) and "=" in line:
|
||||
return {
|
||||
"description": "Use of [ ] instead of [[ ]]",
|
||||
"risk": "[ ] has limitations with pattern matching",
|
||||
@@ -283,7 +282,7 @@ class RefactoringAnalyzer:
|
||||
for issue in issues:
|
||||
base_score -= severity_weights.get(issue.severity, 5)
|
||||
|
||||
lines_count = len([l for l in script.split("\n") if l.strip()])
|
||||
lines_count = len([line for line in script.split("\n") if line.strip()])
|
||||
if lines_count > 0 and base_score > 50:
|
||||
density_bonus = min(10, lines_count // 20)
|
||||
base_score += density_bonus
|
||||
@@ -313,7 +312,7 @@ class RefactoringAnalyzer:
|
||||
)
|
||||
|
||||
if issues:
|
||||
severity_counts = {}
|
||||
severity_counts: Dict[str, int] = {}
|
||||
for issue in issues:
|
||||
severity_counts[issue.severity] = (
|
||||
severity_counts.get(issue.severity, 0) + 1
|
||||
@@ -367,7 +366,6 @@ class RefactoringAnalyzer:
|
||||
"""
|
||||
lines = script.split("\n")
|
||||
fixed_lines = []
|
||||
issue_map = {i.issue_type: i for i in issues}
|
||||
|
||||
for i, line in enumerate(lines, 1):
|
||||
stripped = line.strip()
|
||||
|
||||
67
shellgenius/script_generator.py
Normal file
67
shellgenius/script_generator.py
Normal file
@@ -0,0 +1,67 @@
|
||||
"""Script generation and review using Ollama."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from shellgenius.ollama_client import OllamaClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ScriptGenerator:
|
||||
"""Generate and review shell scripts using Ollama."""
|
||||
|
||||
def __init__(self, client: OllamaClient, config: dict[str, Any]):
|
||||
"""Initialize script generator.
|
||||
|
||||
Args:
|
||||
client: Ollama client instance
|
||||
config: Configuration dictionary
|
||||
"""
|
||||
self.client = client
|
||||
self.config = config
|
||||
|
||||
def generate(self, prompt: str, shell: str = "bash") -> str:
|
||||
"""Generate a shell script from a natural language prompt.
|
||||
|
||||
Args:
|
||||
prompt: Description of what the script should do
|
||||
shell: Target shell type (bash, zsh, sh)
|
||||
|
||||
Returns:
|
||||
Generated shell script
|
||||
"""
|
||||
full_prompt = f"""Generate a {shell} script that does the following:
|
||||
{prompt}
|
||||
|
||||
Return ONLY the script code, no explanations or markdown formatting.
|
||||
The script should be production-ready with proper error handling.
|
||||
"""
|
||||
response = self.client.generate(full_prompt)
|
||||
if response.get("success"):
|
||||
return response["response"]["response"]
|
||||
return f"# Error: {response.get('error', 'Unknown error')}"
|
||||
|
||||
def review(self, script: str) -> str:
|
||||
"""Review and explain a shell script.
|
||||
|
||||
Args:
|
||||
script: Shell script content to review
|
||||
|
||||
Returns:
|
||||
Review and explanation of the script
|
||||
"""
|
||||
review_prompt = f"""Review this {self.config.get('default_shell', 'bash')} script and provide:
|
||||
1. What the script does
|
||||
2. Any issues or improvements needed
|
||||
3. Security concerns
|
||||
|
||||
Script:
|
||||
{script}
|
||||
|
||||
Provide a clear, concise review.
|
||||
"""
|
||||
response = self.client.generate(review_prompt)
|
||||
if response.get("success"):
|
||||
return response["response"]["response"]
|
||||
return f"# Error: {response.get('error', 'Unknown error')}"
|
||||
@@ -1 +1 @@
|
||||
"""Tests package for ShellGenius."""
|
||||
"""Tests for ShellGenius."""
|
||||
|
||||
@@ -1,48 +1,93 @@
|
||||
"""Tests for configuration module."""
|
||||
"""Tests for ShellGenius configuration module."""
|
||||
|
||||
import pytest
|
||||
import tempfile
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
|
||||
from shellgenius.config import Config, get_config
|
||||
|
||||
|
||||
class TestConfig:
|
||||
"""Test cases for configuration."""
|
||||
|
||||
def test_default_config(self):
|
||||
"""Test default configuration values."""
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f:
|
||||
f.write("")
|
||||
f.flush()
|
||||
|
||||
config = Config(f.name)
|
||||
|
||||
assert config.get("ollama.host") == "localhost:11434"
|
||||
assert config.get("ollama.model") == "codellama"
|
||||
assert config.get("safety.level") == "moderate"
|
||||
|
||||
os.unlink(f.name)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
config_path = Path(tmpdir) / "nonexistent.yaml"
|
||||
config = Config(str(config_path))
|
||||
|
||||
assert config.ollama_host == "localhost:11434"
|
||||
assert config.ollama_model == "codellama"
|
||||
assert config.safety_level == "moderate"
|
||||
|
||||
def test_custom_config(self):
|
||||
"""Test loading custom configuration."""
|
||||
custom_config = {
|
||||
"ollama": {
|
||||
"host": "custom:9999",
|
||||
"model": "mistral",
|
||||
"timeout": 60,
|
||||
},
|
||||
"safety": {
|
||||
"level": "strict",
|
||||
},
|
||||
}
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
config_path = Path(tmpdir) / "config.yaml"
|
||||
with open(config_path, "w") as f:
|
||||
yaml.dump(custom_config, f)
|
||||
|
||||
config = Config(str(config_path))
|
||||
|
||||
assert config.ollama_host == "custom:9999"
|
||||
assert config.ollama_model == "mistral"
|
||||
assert config.safety_level == "strict"
|
||||
|
||||
def test_env_override(self):
|
||||
"""Test environment variable overrides."""
|
||||
os.environ["OLLAMA_HOST"] = "env-host:1234"
|
||||
os.environ["OLLAMA_MODEL"] = "env-model"
|
||||
|
||||
try:
|
||||
config = Config()
|
||||
assert config.ollama_host == "env-host:1234"
|
||||
assert config.ollama_model == "env-model"
|
||||
finally:
|
||||
del os.environ["OLLAMA_HOST"]
|
||||
del os.environ["OLLAMA_MODEL"]
|
||||
|
||||
def test_get_nested_value(self):
|
||||
"""Test getting nested configuration values."""
|
||||
config = Config()
|
||||
|
||||
assert config.get("ollama.host") is not None
|
||||
assert config.get("nonexistent.key", "default") == "default"
|
||||
|
||||
def test_environment_override(self):
|
||||
"""Test environment variable overrides."""
|
||||
os.environ["OLLAMA_MODEL"] = "custom-model"
|
||||
|
||||
timeout = config.get("ollama.timeout")
|
||||
assert timeout == 120 or isinstance(timeout, (int, type(None)))
|
||||
|
||||
def test_get_missing_key(self):
|
||||
"""Test getting missing key returns default."""
|
||||
config = Config()
|
||||
assert config.ollama_model == "custom-model"
|
||||
|
||||
del os.environ["OLLAMA_MODEL"]
|
||||
|
||||
def test_ollama_properties(self):
|
||||
"""Test Ollama configuration properties."""
|
||||
config = Config()
|
||||
|
||||
assert "localhost" in config.ollama_host
|
||||
assert config.ollama_model in ["codellama", "llama2", "mistral"]
|
||||
assert config.safety_level in ["strict", "moderate", "permissive"]
|
||||
value = config.get("nonexistent.key", "default")
|
||||
assert value == "default"
|
||||
|
||||
def test_merge_configs(self):
|
||||
"""Test merging user config with defaults."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
config_path = Path(tmpdir) / "config.yaml"
|
||||
with open(config_path, "w") as f:
|
||||
yaml.dump({"ollama": {"model": "llama2"}}, f)
|
||||
|
||||
config = Config(str(config_path))
|
||||
|
||||
assert config.ollama_model == "llama2"
|
||||
assert config.ollama_host == "localhost:11434"
|
||||
|
||||
|
||||
class TestGetConfig:
|
||||
"""Test get_config convenience function."""
|
||||
|
||||
def test_get_config_singleton(self):
|
||||
"""Test get_config returns Config instance."""
|
||||
config = get_config()
|
||||
assert isinstance(config, Config)
|
||||
|
||||
@@ -1,111 +1,136 @@
|
||||
"""Tests for shell generation module."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch
|
||||
"""Tests for ShellGenius generation module."""
|
||||
|
||||
from shellgenius.generation import (
|
||||
PromptTemplates,
|
||||
ShellGenerator,
|
||||
ShellParser,
|
||||
ShellSafetyChecker,
|
||||
PromptTemplates,
|
||||
GeneratedScript,
|
||||
)
|
||||
|
||||
|
||||
class TestPromptTemplates:
|
||||
"""Test prompt template generation."""
|
||||
|
||||
def test_generate_prompt(self):
|
||||
"""Test generation prompt template."""
|
||||
prompt = PromptTemplates.get_generate_prompt(
|
||||
"list files", shell_type="bash"
|
||||
)
|
||||
assert "list files" in prompt
|
||||
assert "bash" in prompt
|
||||
assert "Script:" in prompt
|
||||
|
||||
def test_explain_prompt(self):
|
||||
"""Test explanation prompt template."""
|
||||
script = "echo hello"
|
||||
prompt = PromptTemplates.get_explain_prompt(script, "bash")
|
||||
assert "echo hello" in prompt
|
||||
assert "bash" in prompt
|
||||
|
||||
def test_refactor_prompt(self):
|
||||
"""Test refactor prompt template."""
|
||||
script = "rm -rf /tmp/*"
|
||||
prompt = PromptTemplates.get_refactor_prompt(script, "bash")
|
||||
assert "rm -rf /tmp/*" in prompt
|
||||
|
||||
|
||||
class TestShellParser:
|
||||
def test_detect_bash_shell(self):
|
||||
"""Test detection of bash shell from shebang."""
|
||||
"""Test shell script parsing."""
|
||||
|
||||
def test_detect_bash(self):
|
||||
"""Test bash detection from shebang."""
|
||||
script = "#!/bin/bash\necho hello"
|
||||
assert ShellParser.detect_shell(script) == "bash"
|
||||
|
||||
def test_detect_zsh_shell(self):
|
||||
"""Test detection of zsh shell from shebang."""
|
||||
script = "#!/usr/bin/zsh\necho hello"
|
||||
|
||||
def test_detect_zsh(self):
|
||||
"""Test zsh detection from shebang."""
|
||||
script = "#!/bin/zsh\necho hello"
|
||||
assert ShellParser.detect_shell(script) == "zsh"
|
||||
|
||||
def test_detect_sh_shell(self):
|
||||
"""Test detection of sh shell from shebang."""
|
||||
|
||||
def test_detect_sh(self):
|
||||
"""Test sh detection from shebang."""
|
||||
script = "#!/bin/sh\necho hello"
|
||||
assert ShellParser.detect_shell(script) == "sh"
|
||||
|
||||
def test_default_shell_detection(self):
|
||||
"""Test default shell when no shebang present."""
|
||||
|
||||
def test_detect_default(self):
|
||||
"""Test default bash detection."""
|
||||
script = "echo hello"
|
||||
assert ShellParser.detect_shell(script) == "bash"
|
||||
|
||||
|
||||
def test_parse_lines(self):
|
||||
"""Test parsing script into lines."""
|
||||
"""Test line parsing."""
|
||||
script = "line1\nline2\nline3"
|
||||
lines = ShellParser.parse_lines(script)
|
||||
|
||||
assert len(lines) == 3
|
||||
assert lines[0] == (1, "line1")
|
||||
assert lines[1] == (2, "line2")
|
||||
assert lines[2] == (3, "line3")
|
||||
|
||||
|
||||
def test_extract_commands(self):
|
||||
"""Test extracting executable commands."""
|
||||
script = "#!/bin/bash\n# comment\necho hello\n\nrm -rf /"
|
||||
"""Test command extraction."""
|
||||
script = "#!/bin/bash\n# This is a comment\necho hello\n: empty command\necho world\n"
|
||||
commands = ShellParser.extract_commands(script)
|
||||
|
||||
assert len(commands) == 2
|
||||
assert "echo hello" in commands
|
||||
assert "rm -rf /" in commands
|
||||
assert "echo hello" in commands[0]
|
||||
assert "echo world" in commands[1]
|
||||
|
||||
|
||||
class TestShellSafetyChecker:
|
||||
def test_safe_command(self):
|
||||
"""Test that safe commands pass check."""
|
||||
checker = ShellSafetyChecker()
|
||||
is_safe, warnings = checker.check_command("ls -la")
|
||||
|
||||
assert is_safe == True
|
||||
assert len(warnings) == 0
|
||||
|
||||
"""Test shell safety checking."""
|
||||
|
||||
def test_dangerous_command(self):
|
||||
"""Test that dangerous commands are flagged."""
|
||||
checker = ShellSafetyChecker()
|
||||
"""Test dangerous command detection."""
|
||||
checker = ShellSafetyChecker("moderate")
|
||||
is_safe, warnings = checker.check_command("rm -rf /")
|
||||
|
||||
assert is_safe == False
|
||||
assert any("DANGEROUS" in w for w in warnings)
|
||||
|
||||
def test_warning_patterns(self):
|
||||
"""Test that warning patterns generate warnings."""
|
||||
checker = ShellSafetyChecker(safety_level="strict")
|
||||
is_safe, warnings = checker.check_command("rm -r directory")
|
||||
|
||||
assert any("WARNING" in w for w in warnings)
|
||||
|
||||
def test_check_script(self):
|
||||
"""Test checking entire script."""
|
||||
script = "#!/bin/bash\necho hello\nrm -rf /"
|
||||
assert not is_safe
|
||||
|
||||
def test_safe_command(self):
|
||||
"""Test safe command passes."""
|
||||
checker = ShellSafetyChecker("moderate")
|
||||
is_safe, warnings = checker.check_command("ls -la")
|
||||
assert is_safe
|
||||
|
||||
def test_warning_command(self):
|
||||
"""Test warning-level commands."""
|
||||
checker = ShellSafetyChecker("moderate")
|
||||
is_safe, warnings = checker.check_command("rm -r /tmp/*")
|
||||
assert is_safe
|
||||
assert len(warnings) > 0
|
||||
|
||||
def test_script_check(self):
|
||||
"""Test full script safety check."""
|
||||
checker = ShellSafetyChecker()
|
||||
script = "#!/bin/bash\nls\necho done\n"
|
||||
result = checker.check_script(script)
|
||||
|
||||
assert result["is_safe"] == False
|
||||
assert len(result["issues"]) > 0
|
||||
assert "is_safe" in result
|
||||
assert "issues" in result
|
||||
assert "warnings" in result
|
||||
|
||||
def test_strict_mode_blocks_warnings(self):
|
||||
"""Test strict mode blocks warning-level commands."""
|
||||
checker = ShellSafetyChecker("strict")
|
||||
is_safe, warnings = checker.check_command("rm -r /tmp/*")
|
||||
assert not is_safe
|
||||
|
||||
|
||||
class TestPromptTemplates:
|
||||
def test_get_generate_prompt(self):
|
||||
"""Test generation prompt creation."""
|
||||
prompt = PromptTemplates.get_generate_prompt("list files", "bash")
|
||||
|
||||
assert "shell script expert" in prompt
|
||||
assert "list files" in prompt
|
||||
assert "bash" in prompt
|
||||
|
||||
def test_get_explain_prompt(self):
|
||||
"""Test explanation prompt creation."""
|
||||
prompt = PromptTemplates.get_explain_prompt("echo hello", "bash")
|
||||
|
||||
assert "explain" in prompt
|
||||
assert "echo hello" in prompt
|
||||
|
||||
def test_get_refactor_prompt(self):
|
||||
"""Test refactor prompt creation."""
|
||||
prompt = PromptTemplates.get_refactor_prompt("rm -rf /", "bash")
|
||||
|
||||
assert "security" in prompt.lower()
|
||||
assert "rm -rf /" in prompt
|
||||
class TestShellGenerator:
|
||||
"""Test shell generation."""
|
||||
|
||||
def test_generator_initialization(self):
|
||||
"""Test generator creates properly."""
|
||||
generator = ShellGenerator()
|
||||
assert generator.parser is not None
|
||||
assert generator.templates is not None
|
||||
|
||||
def test_extract_commands_from_raw(self):
|
||||
"""Test command extraction from raw response."""
|
||||
generator = ShellGenerator()
|
||||
raw = "echo hello\nls -la\ncat file.txt"
|
||||
commands = generator._extract_commands(raw)
|
||||
assert len(commands) == 3
|
||||
assert "echo hello" in commands
|
||||
|
||||
def test_generate_summary(self):
|
||||
"""Test summary generation."""
|
||||
generator = ShellGenerator()
|
||||
summary = generator._generate_summary(["cmd1", "cmd2"], "raw")
|
||||
assert "2 command" in summary
|
||||
|
||||
@@ -1,100 +1,237 @@
|
||||
"""Tests for history learning module."""
|
||||
"""Tests for ShellGenius history module."""
|
||||
|
||||
import pytest
|
||||
import tempfile
|
||||
import os
|
||||
from pathlib import Path
|
||||
import tempfile
|
||||
from unittest.mock import patch
|
||||
|
||||
from shellgenius.history import HistoryStorage, HistoryLearner, HistoryEntry
|
||||
from shellgenius.history import HistoryEntry, HistoryLearner, HistoryStorage
|
||||
|
||||
|
||||
class TestHistoryEntry:
|
||||
"""Test history entry dataclass."""
|
||||
|
||||
def test_entry_creation(self):
|
||||
"""Test creating a history entry."""
|
||||
entry = HistoryEntry(
|
||||
id="test-id",
|
||||
timestamp="2024-01-01T00:00:00",
|
||||
description="Test description",
|
||||
commands=["cmd1", "cmd2"],
|
||||
shell_type="bash",
|
||||
)
|
||||
assert entry.id == "test-id"
|
||||
assert len(entry.commands) == 2
|
||||
assert entry.usage_count == 1
|
||||
|
||||
|
||||
class TestHistoryStorage:
|
||||
@pytest.fixture
|
||||
def temp_storage(self):
|
||||
"""Create temporary storage for testing."""
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f:
|
||||
f.write("entries: []\nmetadata:\n version: '1.0'")
|
||||
f.flush()
|
||||
yield f.name
|
||||
os.unlink(f.name)
|
||||
|
||||
def test_init_storage(self, temp_storage):
|
||||
"""Test storage initialization."""
|
||||
storage = HistoryStorage(temp_storage)
|
||||
|
||||
assert storage.storage_path == temp_storage
|
||||
|
||||
def test_add_and_get_entry(self, temp_storage):
|
||||
"""Test adding and retrieving history entries."""
|
||||
storage = HistoryStorage(temp_storage)
|
||||
|
||||
entry = HistoryEntry(
|
||||
id="test-id",
|
||||
timestamp="2024-01-01T00:00:00",
|
||||
description="Test command",
|
||||
commands=["echo hello"],
|
||||
shell_type="bash",
|
||||
)
|
||||
|
||||
storage.add_entry(entry)
|
||||
entries = storage.get_entries()
|
||||
|
||||
assert len(entries) == 1
|
||||
assert entries[0].description == "Test command"
|
||||
|
||||
def test_search_history(self, temp_storage):
|
||||
"""Test searching history."""
|
||||
storage = HistoryStorage(temp_storage)
|
||||
|
||||
entry = HistoryEntry(
|
||||
id="test-id",
|
||||
timestamp="2024-01-01T00:00:00",
|
||||
description="List files command",
|
||||
commands=["ls -la"],
|
||||
shell_type="bash",
|
||||
)
|
||||
storage.add_entry(entry)
|
||||
|
||||
results = storage.search("files")
|
||||
|
||||
assert len(results) == 1
|
||||
|
||||
def test_clear_history(self, temp_storage):
|
||||
"""Test clearing history."""
|
||||
storage = HistoryStorage(temp_storage)
|
||||
|
||||
entry = HistoryEntry(
|
||||
id="test-id",
|
||||
timestamp="2024-01-01T00:00:00",
|
||||
description="Test",
|
||||
commands=["echo"],
|
||||
shell_type="bash",
|
||||
)
|
||||
storage.add_entry(entry)
|
||||
|
||||
storage.clear()
|
||||
entries = storage.get_entries()
|
||||
|
||||
assert len(entries) == 0
|
||||
"""Test history storage."""
|
||||
|
||||
def test_add_and_get_entry(self):
|
||||
"""Test adding and retrieving entries."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
storage_path = os.path.join(tmpdir, "history.yaml")
|
||||
storage = HistoryStorage(storage_path)
|
||||
|
||||
entry = HistoryEntry(
|
||||
id="test-id",
|
||||
timestamp="2024-01-01T00:00:00",
|
||||
description="Test",
|
||||
commands=["echo hello"],
|
||||
shell_type="bash",
|
||||
)
|
||||
|
||||
storage.add_entry(entry)
|
||||
entries = storage.get_entries()
|
||||
|
||||
assert len(entries) == 1
|
||||
assert entries[0].description == "Test"
|
||||
|
||||
def test_search_entries(self):
|
||||
"""Test searching entries."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
storage_path = os.path.join(tmpdir, "history.yaml")
|
||||
storage = HistoryStorage(storage_path)
|
||||
|
||||
entry = HistoryEntry(
|
||||
id="test-id",
|
||||
timestamp="2024-01-01T00:00:00",
|
||||
description="Find Python files",
|
||||
commands=["find . -name *.py"],
|
||||
shell_type="bash",
|
||||
)
|
||||
|
||||
storage.add_entry(entry)
|
||||
results = storage.search("python", limit=10)
|
||||
|
||||
assert len(results) == 1
|
||||
|
||||
def test_get_popular(self):
|
||||
"""Test getting popular entries."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
storage_path = os.path.join(tmpdir, "history.yaml")
|
||||
storage = HistoryStorage(storage_path)
|
||||
|
||||
entry = HistoryEntry(
|
||||
id="test-id",
|
||||
timestamp="2024-01-01T00:00:00",
|
||||
description="Test",
|
||||
commands=["echo hello"],
|
||||
shell_type="bash",
|
||||
usage_count=5,
|
||||
)
|
||||
|
||||
storage.add_entry(entry)
|
||||
popular = storage.get_popular(limit=10)
|
||||
|
||||
assert len(popular) == 1
|
||||
assert popular[0].usage_count == 5
|
||||
|
||||
def test_update_usage(self):
|
||||
"""Test updating usage count."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
storage_path = os.path.join(tmpdir, "history.yaml")
|
||||
storage = HistoryStorage(storage_path)
|
||||
|
||||
entry = HistoryEntry(
|
||||
id="test-id",
|
||||
timestamp="2024-01-01T00:00:00",
|
||||
description="Test",
|
||||
commands=["echo hello"],
|
||||
shell_type="bash",
|
||||
)
|
||||
|
||||
storage.add_entry(entry)
|
||||
storage.update_usage("test-id")
|
||||
|
||||
entries = storage.get_entries()
|
||||
assert entries[0].usage_count == 2
|
||||
|
||||
def test_delete_entry(self):
|
||||
"""Test deleting entry."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
storage_path = os.path.join(tmpdir, "history.yaml")
|
||||
storage = HistoryStorage(storage_path)
|
||||
|
||||
entry = HistoryEntry(
|
||||
id="test-id",
|
||||
timestamp="2024-01-01T00:00:00",
|
||||
description="Test",
|
||||
commands=["echo hello"],
|
||||
shell_type="bash",
|
||||
)
|
||||
|
||||
storage.add_entry(entry)
|
||||
storage.delete_entry("test-id")
|
||||
|
||||
entries = storage.get_entries()
|
||||
assert len(entries) == 0
|
||||
|
||||
def test_clear_history(self):
|
||||
"""Test clearing all history."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
storage_path = os.path.join(tmpdir, "history.yaml")
|
||||
storage = HistoryStorage(storage_path)
|
||||
|
||||
for i in range(3):
|
||||
entry = HistoryEntry(
|
||||
id=f"test-{i}",
|
||||
timestamp="2024-01-01T00:00:00",
|
||||
description=f"Test {i}",
|
||||
commands=[f"echo {i}"],
|
||||
shell_type="bash",
|
||||
)
|
||||
storage.add_entry(entry)
|
||||
|
||||
storage.clear()
|
||||
entries = storage.get_entries()
|
||||
|
||||
assert len(entries) == 0
|
||||
|
||||
def test_limit_and_offset(self):
|
||||
"""Test limit and offset for retrieval."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
storage_path = os.path.join(tmpdir, "history.yaml")
|
||||
storage = HistoryStorage(storage_path)
|
||||
|
||||
for i in range(5):
|
||||
entry = HistoryEntry(
|
||||
id=f"test-{i}",
|
||||
timestamp="2024-01-01T00:00:00",
|
||||
description=f"Test {i}",
|
||||
commands=[f"echo {i}"],
|
||||
shell_type="bash",
|
||||
)
|
||||
storage.add_entry(entry)
|
||||
|
||||
entries = storage.get_entries(limit=2, offset=2)
|
||||
assert len(entries) == 2
|
||||
|
||||
|
||||
class TestHistoryLearner:
|
||||
def test_learn_command(self):
|
||||
"""Test learning from generated command."""
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f:
|
||||
f.write("entries: []\nmetadata:\n version: '1.0'")
|
||||
f.flush()
|
||||
storage_path = f.name
|
||||
|
||||
try:
|
||||
with patch('shellgenius.history.get_config') as mock_config:
|
||||
mock_config.return_value.get.return_value = storage_path
|
||||
mock_config.return_value.is_history_enabled.return_value = True
|
||||
|
||||
learner = HistoryLearner()
|
||||
entry = learner.learn("list files", ["ls -la"], "bash")
|
||||
|
||||
assert entry.description == "list files"
|
||||
assert "ls -la" in entry.commands
|
||||
finally:
|
||||
os.unlink(storage_path)
|
||||
"""Test history learning functionality."""
|
||||
|
||||
def test_learn(self):
|
||||
"""Test learning from commands."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
storage_path = os.path.join(tmpdir, "history.yaml")
|
||||
learner = HistoryLearner()
|
||||
learner.storage = HistoryStorage(storage_path)
|
||||
|
||||
entry = learner.learn(
|
||||
description="List files",
|
||||
commands=["ls -la"],
|
||||
shell_type="bash",
|
||||
)
|
||||
|
||||
assert entry is not None
|
||||
assert "List files" in entry.description
|
||||
|
||||
def test_find_similar(self):
|
||||
"""Test finding similar commands."""
|
||||
n with tempfile.TemporaryDirectory() as tmpdir:
|
||||
storage_path = os.path.join(tmpdir, "history.yaml")
|
||||
learner = HistoryLearner()
|
||||
learner.storage = HistoryStorage(storage_path)
|
||||
|
||||
learner.learn(
|
||||
description="Find Python files",
|
||||
commands=["find . -name *.py"],
|
||||
shell_type="bash",
|
||||
)
|
||||
|
||||
similar = learner.find_similar("python", limit=5)
|
||||
assert len(similar) == 1
|
||||
|
||||
def test_suggest(self):
|
||||
"""Test getting suggestions."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
storage_path = os.path.join(tmpdir, "history.yaml")
|
||||
learner = HistoryLearner()
|
||||
learner.storage = HistoryStorage(storage_path)
|
||||
|
||||
learner.learn(
|
||||
description="Find Python files",
|
||||
commands=["find . -name *.py"],
|
||||
shell_type="bash",
|
||||
)
|
||||
|
||||
suggestions = learner.suggest("python files", limit=3)
|
||||
assert len(suggestions) >= 1
|
||||
|
||||
def test_get_frequent_patterns(self):
|
||||
"""Test analyzing frequent patterns."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
storage_path = os.path.join(tmpdir, "history.yaml")
|
||||
learner = HistoryLearner()
|
||||
learner.storage = HistoryStorage(storage_path)
|
||||
|
||||
learner.learn(
|
||||
description="Find files",
|
||||
commands=["find . -type f"],
|
||||
shell_type="bash",
|
||||
)
|
||||
|
||||
patterns = learner.get_frequent_patterns()
|
||||
assert "common_commands" in patterns
|
||||
assert "shell_types" in patterns
|
||||
|
||||
@@ -1,46 +1,123 @@
|
||||
"""Integration tests for ShellGenius."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
from shellgenius.cli import main
|
||||
from shellgenius.config import get_config
|
||||
from shellgenius.generation import generate_shell
|
||||
from shellgenius.explainer import explain_script
|
||||
from shellgenius.refactoring import refactor_script
|
||||
from shellgenius.history import get_history_storage
|
||||
|
||||
|
||||
class TestIntegration:
|
||||
def test_cli_commands_registered(self):
|
||||
"""Test that all CLI commands are registered."""
|
||||
from shellgenius.cli import main
|
||||
|
||||
commands = [cmd.name for cmd in main.commands]
|
||||
|
||||
class TestConfigIntegration:
|
||||
"""Integration tests for configuration."""
|
||||
|
||||
def test_config_loads(self):
|
||||
"""Test configuration loads successfully."""
|
||||
config = get_config()
|
||||
assert config is not None
|
||||
assert hasattr(config, "ollama_host")
|
||||
assert hasattr(config, "ollama_model")
|
||||
|
||||
|
||||
class TestGenerationIntegration:
|
||||
"""Integration tests for shell generation."""
|
||||
|
||||
def test_generate_returns_script(self):
|
||||
"""Test generation returns expected structure."""
|
||||
result = generate_shell("list files in current directory", shell_type="bash")
|
||||
assert result is not None
|
||||
assert hasattr(result, "commands")
|
||||
assert hasattr(result, "explanation")
|
||||
assert hasattr(result, "shell_type")
|
||||
assert result.shell_type == "bash"
|
||||
|
||||
|
||||
class TestExplainerIntegration:
|
||||
"""Integration tests for script explanation."""
|
||||
|
||||
def test_explain_returns_structure(self):
|
||||
"""Test explanation returns expected structure."""
|
||||
script = "#!/bin/bash\necho \"Hello, World!\"\nls -la\n"
|
||||
result = explain_script(script, detailed=False)
|
||||
assert result is not None
|
||||
assert hasattr(result, "shell_type")
|
||||
assert hasattr(result, "line_explanations")
|
||||
assert hasattr(result, "overall_purpose")
|
||||
|
||||
|
||||
class TestRefactoringIntegration:
|
||||
"""Integration tests for refactoring."""
|
||||
|
||||
def test_refactor_returns_structure(self):
|
||||
"""Test refactoring returns expected structure."""
|
||||
script = "#!/bin/bash\necho \"Hello\"\nls\n"
|
||||
result = refactor_script(script, include_suggestions=False)
|
||||
assert result is not None
|
||||
assert hasattr(result, "shell_type")
|
||||
assert hasattr(result, "issues")
|
||||
assert hasattr(result, "score")
|
||||
assert hasattr(result, "suggestions")
|
||||
|
||||
|
||||
class TestHistoryIntegration:
|
||||
"""Integration tests for history."""
|
||||
|
||||
def test_history_storage(self):
|
||||
"""Test history storage operations."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
storage_path = os.path.join(tmpdir, "history.yaml")
|
||||
storage = get_history_storage(storage_path)
|
||||
|
||||
from shellgenius.history import HistoryEntry
|
||||
|
||||
entry = HistoryEntry(
|
||||
id="test-id",
|
||||
timestamp="2024-01-01T00:00:00",
|
||||
description="Test command",
|
||||
commands=["echo test"],
|
||||
shell_type="bash",
|
||||
)
|
||||
|
||||
storage.add_entry(entry)
|
||||
entries = storage.get_entries()
|
||||
|
||||
assert len(entries) == 1
|
||||
assert entries[0].description == "Test command"
|
||||
|
||||
|
||||
class TestCLIIntegration:
|
||||
"""Integration tests for CLI."""
|
||||
|
||||
def test_cli_group_exists(self):
|
||||
"""Test CLI main group exists."""
|
||||
assert main is not None
|
||||
assert hasattr(main, 'commands')
|
||||
|
||||
def test_cli_subcommands(self):
|
||||
"""Test CLI has expected subcommands."""
|
||||
commands = list(main.commands.keys())
|
||||
assert "generate" in commands
|
||||
assert "explain" in commands
|
||||
assert "refactor" in commands
|
||||
assert "history" in commands
|
||||
assert "models" in commands
|
||||
assert "check" in commands
|
||||
assert "version" in commands
|
||||
|
||||
def test_module_imports(self):
|
||||
"""Test that all modules can be imported."""
|
||||
from shellgenius import __version__
|
||||
from shellgenius.cli import main
|
||||
from shellgenius.config import Config
|
||||
from shellgenius.generation import ShellGenerator
|
||||
from shellgenius.explainer import ShellExplainer
|
||||
from shellgenius.refactoring import RefactoringAnalyzer
|
||||
from shellgenius.history import HistoryLearner
|
||||
from shellgenius.ollama_client import OllamaClient
|
||||
|
||||
assert __version__ == "0.1.0"
|
||||
|
||||
def test_package_structure(self):
|
||||
"""Test that package structure is correct."""
|
||||
import shellgenius
|
||||
|
||||
assert hasattr(shellgenius, "__version__")
|
||||
assert hasattr(shellgenius, "cli")
|
||||
assert hasattr(shellgenius, "config")
|
||||
assert hasattr(shellgenius, "generation")
|
||||
assert hasattr(shellgenius, "explainer")
|
||||
assert hasattr(shellgenius, "refactoring")
|
||||
assert hasattr(shellgenius, "history")
|
||||
assert hasattr(shellgenius, "ollama_client")
|
||||
|
||||
|
||||
class TestFullWorkflowIntegration:
|
||||
"""Full workflow integration tests."""
|
||||
|
||||
def test_generate_explain_workflow(self):
|
||||
"""Test generating and explaining a script."""
|
||||
script = "#!/bin/bash\necho hello"
|
||||
result = explain_script(script, detailed=False)
|
||||
assert result is not None
|
||||
|
||||
def test_refactor_with_suggestions(self):
|
||||
"""Test refactoring with AI suggestions."""
|
||||
script = "#!/bin/bash\necho test"
|
||||
result = refactor_script(script, include_suggestions=True)
|
||||
assert result is not None
|
||||
assert hasattr(result, "suggestions")
|
||||
|
||||
@@ -1,75 +1,131 @@
|
||||
"""Tests for Ollama client module."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
|
||||
from shellgenius.ollama_client import OllamaClient, get_ollama_client
|
||||
|
||||
|
||||
class TestOllamaClient:
|
||||
def test_init(self):
|
||||
"""Test Ollama client functionality."""
|
||||
|
||||
@patch('shellgenius.ollama_client.ollama.Client')
|
||||
def test_client_initialization(self, mock_client_class):
|
||||
"""Test client initialization."""
|
||||
with patch('shellgenius.ollama_client.get_config') as mock_config:
|
||||
mock_config.return_value.ollama_host = "localhost:11434"
|
||||
mock_config.return_value.ollama_model = "codellama"
|
||||
|
||||
client = OllamaClient()
|
||||
|
||||
assert client.host == "localhost:11434"
|
||||
assert client.model == "codellama"
|
||||
|
||||
def test_is_available(self):
|
||||
"""Test availability check."""
|
||||
with patch('shellgenius.ollama_client.get_config') as mock_config:
|
||||
mock_config.return_value.ollama_host = "localhost:11434"
|
||||
mock_config.return_value.ollama_model = "codellama"
|
||||
|
||||
client = OllamaClient()
|
||||
|
||||
with patch.object(client, 'list_models', return_value=["codellama"]):
|
||||
assert client.is_available() == True
|
||||
|
||||
def test_list_models(self):
|
||||
mock_client = Mock()
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
client = OllamaClient(host="localhost:11434", model="codellama")
|
||||
|
||||
assert client.host == "localhost:11434"
|
||||
assert client.model == "codellama"
|
||||
|
||||
@patch('shellgenius.ollama_client.ollama.Client')
|
||||
def test_list_models(self, mock_client_class):
|
||||
"""Test listing models."""
|
||||
with patch('shellgenius.ollama_client.get_config') as mock_config:
|
||||
mock_config.return_value.ollama_host = "localhost:11434"
|
||||
mock_config.return_value.ollama_model = "codellama"
|
||||
|
||||
mock_client = Mock()
|
||||
mock_client.list.return_value = {
|
||||
"models": [
|
||||
{"name": "codellama:latest"},
|
||||
{"name": "llama2:latest"},
|
||||
]
|
||||
}
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
client = OllamaClient()
|
||||
models = client.list_models()
|
||||
|
||||
assert len(models) == 2
|
||||
assert "codellama:latest" in models
|
||||
|
||||
@patch('shellgenius.ollama_client.ollama.Client')
|
||||
def test_generate_success(self, mock_client_class):
|
||||
"""Test successful generation."""
|
||||
mock_client = Mock()
|
||||
mock_client.generate.return_value = {
|
||||
"response": "echo hello"
|
||||
}
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
client = OllamaClient()
|
||||
result = client.generate("test prompt")
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["response"]["response"] == "echo hello"
|
||||
|
||||
@patch('shellgenius.ollama_client.ollama.Client')
|
||||
def test_generate_failure(self, mock_client_class):
|
||||
"""Test failed generation."""
|
||||
mock_client = Mock()
|
||||
mock_client.generate.side_effect = Exception("Connection failed")
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
client = OllamaClient()
|
||||
result = client.generate("test prompt")
|
||||
|
||||
assert result["success"] is False
|
||||
assert "error" in result
|
||||
|
||||
@patch('shellgenius.ollama_client.ollama.Client')
|
||||
def test_chat_success(self, mock_client_class):
|
||||
"""Test successful chat."""
|
||||
mock_client = Mock()
|
||||
mock_client.chat.return_value = {
|
||||
"message": {"content": "Hello!"}
|
||||
}
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
client = OllamaClient()
|
||||
messages = [{"role": "user", "content": "Hi"}]
|
||||
result = client.chat(messages)
|
||||
|
||||
assert result["success"] is True
|
||||
|
||||
@patch('shellgenius.ollama_client.ollama.Client')
|
||||
def test_pull_model(self, mock_client_class):
|
||||
"""Test pulling a model."""
|
||||
mock_client = Mock()
|
||||
mock_client.pull.return_value = {"status": "success"}
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
client = OllamaClient()
|
||||
result = client.pull_model("llama2")
|
||||
|
||||
assert result is True
|
||||
mock_client.pull.assert_called_once_with("llama2")
|
||||
|
||||
@patch('shellgenius.ollama_client.ollama.Client')
|
||||
def test_is_available_true(self, mock_client_class):
|
||||
"""Test availability check when available."""
|
||||
mock_client = Mock()
|
||||
mock_client.list.return_value = {"models": []}
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
client = OllamaClient()
|
||||
assert client.is_available() is True
|
||||
|
||||
@patch('shellgenius.ollama_client.ollama.Client')
|
||||
def test_is_available_false(self, mock_client_class):
|
||||
"""Test availability check when not available."""
|
||||
mock_instance = MagicMock()
|
||||
mock_instance.list.side_effect = Exception("Connection refused")
|
||||
mock_client_class.return_value = mock_instance
|
||||
|
||||
with patch.object(OllamaClient, 'list_models', side_effect=Exception("Connection refused")):
|
||||
client = OllamaClient()
|
||||
|
||||
mock_response = {"models": [{"name": "codellama"}, {"name": "llama2"}]}
|
||||
|
||||
with patch.object(client.client, 'list', return_value=mock_response):
|
||||
models = client.list_models()
|
||||
|
||||
assert len(models) == 2
|
||||
assert "codellama" in models
|
||||
|
||||
def test_generate(self):
|
||||
"""Test text generation."""
|
||||
with patch('shellgenius.ollama_client.get_config') as mock_config:
|
||||
mock_config.return_value.ollama_host = "localhost:11434"
|
||||
mock_config.return_value.ollama_model = "codellama"
|
||||
|
||||
client = OllamaClient()
|
||||
|
||||
mock_response = {"response": "Generated text"}
|
||||
|
||||
with patch.object(client.client, 'generate', return_value=mock_response):
|
||||
result = client.generate("test prompt")
|
||||
|
||||
assert result["success"] == True
|
||||
assert "Generated text" in str(result["response"])
|
||||
result = client.is_available()
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestGetOllamaClient:
|
||||
def test_convenience_function(self):
|
||||
"""Test the convenience function for getting client."""
|
||||
with patch('shellgenius.ollama_client.get_config') as mock_config:
|
||||
mock_config.return_value.ollama_host = "localhost:11434"
|
||||
mock_config.return_value.ollama_model = "custom-model"
|
||||
|
||||
client = get_ollama_client(host="custom:9999", model="custom-model")
|
||||
|
||||
assert client.host == "custom:9999"
|
||||
assert client.model == "custom-model"
|
||||
"""Test get_ollama_client convenience function."""
|
||||
|
||||
@patch('shellgenius.ollama_client.ollama.Client')
|
||||
def test_get_client(self, mock_client_class):
|
||||
"""Test getting client instance."""
|
||||
mock_client = Mock()
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
client = get_ollama_client(host="localhost:11434")
|
||||
|
||||
assert isinstance(client, OllamaClient)
|
||||
assert client.host == "localhost:11434"
|
||||
|
||||
Reference in New Issue
Block a user