Add test suite for ShellGen CLI
This commit is contained in:
199
app/tests/test_history.py
Normal file
199
app/tests/test_history.py
Normal file
@@ -0,0 +1,199 @@
|
|||||||
|
"""Tests for history manager."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
|
||||||
|
class TestHistoryManager:
|
||||||
|
"""Tests for HistoryManager class."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def temp_db_path(self):
|
||||||
|
"""Create temporary database path."""
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
db_path = os.path.join(tmpdir, "history.db")
|
||||||
|
yield db_path
|
||||||
|
|
||||||
|
def test_initialization(self, temp_db_path):
|
||||||
|
"""Test HistoryManager initialization."""
|
||||||
|
from shellgen.history import HistoryManager
|
||||||
|
|
||||||
|
manager = HistoryManager(path=temp_db_path)
|
||||||
|
assert manager._path == temp_db_path
|
||||||
|
|
||||||
|
def test_add_entry(self, temp_db_path):
|
||||||
|
"""Test adding entry to history."""
|
||||||
|
from shellgen.history import HistoryManager
|
||||||
|
|
||||||
|
manager = HistoryManager(path=temp_db_path)
|
||||||
|
entry_id = manager.add_entry(
|
||||||
|
prompt="list files",
|
||||||
|
command="ls -la",
|
||||||
|
shell="bash",
|
||||||
|
executed=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert entry_id == 1
|
||||||
|
|
||||||
|
def test_get_history(self, temp_db_path):
|
||||||
|
"""Test getting history entries."""
|
||||||
|
from shellgen.history import HistoryManager
|
||||||
|
|
||||||
|
manager = HistoryManager(path=temp_db_path)
|
||||||
|
|
||||||
|
manager.add_entry("list files", "ls", "bash", False)
|
||||||
|
manager.add_entry("show git status", "git status", "bash", True)
|
||||||
|
|
||||||
|
entries = manager.get_history(limit=10)
|
||||||
|
|
||||||
|
assert len(entries) == 2
|
||||||
|
commands = {e["command"] for e in entries}
|
||||||
|
assert "ls" in commands
|
||||||
|
assert "git status" in commands
|
||||||
|
|
||||||
|
def test_get_history_limit(self, temp_db_path):
|
||||||
|
"""Test history limit."""
|
||||||
|
from shellgen.history import HistoryManager
|
||||||
|
|
||||||
|
manager = HistoryManager(path=temp_db_path)
|
||||||
|
|
||||||
|
for i in range(5):
|
||||||
|
manager.add_entry(f"prompt {i}", f"cmd {i}", "bash", False)
|
||||||
|
|
||||||
|
entries = manager.get_history(limit=2)
|
||||||
|
|
||||||
|
assert len(entries) == 2
|
||||||
|
|
||||||
|
def test_get_entry(self, temp_db_path):
|
||||||
|
"""Test getting specific entry."""
|
||||||
|
from shellgen.history import HistoryManager
|
||||||
|
|
||||||
|
manager = HistoryManager(path=temp_db_path)
|
||||||
|
manager.add_entry("list files", "ls", "bash", False)
|
||||||
|
|
||||||
|
entry = manager.get_entry(1)
|
||||||
|
|
||||||
|
assert entry is not None
|
||||||
|
assert entry["prompt"] == "list files"
|
||||||
|
assert entry["command"] == "ls"
|
||||||
|
|
||||||
|
def test_get_entry_not_found(self, temp_db_path):
|
||||||
|
"""Test getting non-existent entry."""
|
||||||
|
from shellgen.history import HistoryManager
|
||||||
|
|
||||||
|
manager = HistoryManager(path=temp_db_path)
|
||||||
|
entry = manager.get_entry(999)
|
||||||
|
|
||||||
|
assert entry is None
|
||||||
|
|
||||||
|
def test_add_feedback(self, temp_db_path):
|
||||||
|
"""Test adding feedback."""
|
||||||
|
from shellgen.history import HistoryManager
|
||||||
|
|
||||||
|
manager = HistoryManager(path=temp_db_path)
|
||||||
|
manager.add_entry("list files", "ls", "bash", False)
|
||||||
|
|
||||||
|
result = manager.add_feedback(
|
||||||
|
entry_id=1,
|
||||||
|
corrected_command="ls -la",
|
||||||
|
feedback="need more details",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
|
||||||
|
def test_add_feedback_not_found(self, temp_db_path):
|
||||||
|
"""Test adding feedback for non-existent entry."""
|
||||||
|
from shellgen.history import HistoryManager
|
||||||
|
|
||||||
|
manager = HistoryManager(path=temp_db_path)
|
||||||
|
result = manager.add_feedback(
|
||||||
|
entry_id=999,
|
||||||
|
corrected_command="ls",
|
||||||
|
)
|
||||||
|
assert result is False or result is True
|
||||||
|
|
||||||
|
def test_get_corrections(self, temp_db_path):
|
||||||
|
"""Test getting corrections."""
|
||||||
|
from shellgen.history import HistoryManager
|
||||||
|
|
||||||
|
manager = HistoryManager(path=temp_db_path)
|
||||||
|
|
||||||
|
manager.add_entry("list files", "ls", "bash", False)
|
||||||
|
manager.add_feedback(
|
||||||
|
entry_id=1,
|
||||||
|
corrected_command="ls -la",
|
||||||
|
feedback="wrong flags",
|
||||||
|
)
|
||||||
|
|
||||||
|
corrections = manager.get_corrections()
|
||||||
|
|
||||||
|
assert len(corrections) == 1
|
||||||
|
assert corrections[0]["corrected_command"] == "ls -la"
|
||||||
|
|
||||||
|
def test_clear_history(self, temp_db_path):
|
||||||
|
"""Test clearing old history entries."""
|
||||||
|
from shellgen.history import HistoryManager
|
||||||
|
|
||||||
|
manager = HistoryManager(path=temp_db_path)
|
||||||
|
|
||||||
|
for i in range(10):
|
||||||
|
manager.add_entry(f"prompt {i}", f"cmd {i}", "bash", False)
|
||||||
|
|
||||||
|
deleted = manager.clear_history(keep_count=5)
|
||||||
|
|
||||||
|
entries = manager.get_history(limit=100)
|
||||||
|
assert len(entries) == 5
|
||||||
|
assert deleted == 5
|
||||||
|
|
||||||
|
def test_get_stats(self, temp_db_path):
|
||||||
|
"""Test getting statistics."""
|
||||||
|
from shellgen.history import HistoryManager
|
||||||
|
|
||||||
|
manager = HistoryManager(path=temp_db_path)
|
||||||
|
|
||||||
|
manager.add_entry("cmd1", "ls", "bash", True)
|
||||||
|
manager.add_entry("cmd2", "cat", "bash", False)
|
||||||
|
manager.add_entry("cmd3", "grep", "bash", True)
|
||||||
|
|
||||||
|
stats = manager.get_stats()
|
||||||
|
|
||||||
|
assert stats["total_entries"] == 3
|
||||||
|
assert stats["executed"] == 2
|
||||||
|
assert stats["feedback"] == 0
|
||||||
|
|
||||||
|
def test_multiple_shells(self, temp_db_path):
|
||||||
|
"""Test entries with different shells."""
|
||||||
|
from shellgen.history import HistoryManager
|
||||||
|
|
||||||
|
manager = HistoryManager(path=temp_db_path)
|
||||||
|
|
||||||
|
manager.add_entry("list files", "ls", "bash", False)
|
||||||
|
manager.add_entry("list files", "ls", "zsh", False)
|
||||||
|
|
||||||
|
entries = manager.get_history(limit=10)
|
||||||
|
|
||||||
|
shells = [e["shell"] for e in entries]
|
||||||
|
assert "bash" in shells
|
||||||
|
assert "zsh" in shells
|
||||||
|
|
||||||
|
def test_concurrent_writes(self, temp_db_path):
|
||||||
|
"""Test concurrent write handling."""
|
||||||
|
from shellgen.history import HistoryManager
|
||||||
|
import threading
|
||||||
|
|
||||||
|
manager = HistoryManager(path=temp_db_path)
|
||||||
|
ids = []
|
||||||
|
|
||||||
|
def add_entry(index):
|
||||||
|
entry_id = manager.add_entry(f"prompt {index}", f"cmd {index}", "bash", False)
|
||||||
|
ids.append(entry_id)
|
||||||
|
|
||||||
|
threads = [threading.Thread(target=add_entry, args=(i,)) for i in range(5)]
|
||||||
|
for t in threads:
|
||||||
|
t.start()
|
||||||
|
for t in threads:
|
||||||
|
t.join()
|
||||||
|
|
||||||
|
assert len(ids) == 5
|
||||||
Reference in New Issue
Block a user