diff --git a/app/tests/test_history.py b/app/tests/test_history.py new file mode 100644 index 0000000..6dda448 --- /dev/null +++ b/app/tests/test_history.py @@ -0,0 +1,199 @@ +"""Tests for history manager.""" + +import pytest +import os +import tempfile +from unittest.mock import patch + + +class TestHistoryManager: + """Tests for HistoryManager class.""" + + @pytest.fixture + def temp_db_path(self): + """Create temporary database path.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "history.db") + yield db_path + + def test_initialization(self, temp_db_path): + """Test HistoryManager initialization.""" + from shellgen.history import HistoryManager + + manager = HistoryManager(path=temp_db_path) + assert manager._path == temp_db_path + + def test_add_entry(self, temp_db_path): + """Test adding entry to history.""" + from shellgen.history import HistoryManager + + manager = HistoryManager(path=temp_db_path) + entry_id = manager.add_entry( + prompt="list files", + command="ls -la", + shell="bash", + executed=False, + ) + + assert entry_id == 1 + + def test_get_history(self, temp_db_path): + """Test getting history entries.""" + from shellgen.history import HistoryManager + + manager = HistoryManager(path=temp_db_path) + + manager.add_entry("list files", "ls", "bash", False) + manager.add_entry("show git status", "git status", "bash", True) + + entries = manager.get_history(limit=10) + + assert len(entries) == 2 + commands = {e["command"] for e in entries} + assert "ls" in commands + assert "git status" in commands + + def test_get_history_limit(self, temp_db_path): + """Test history limit.""" + from shellgen.history import HistoryManager + + manager = HistoryManager(path=temp_db_path) + + for i in range(5): + manager.add_entry(f"prompt {i}", f"cmd {i}", "bash", False) + + entries = manager.get_history(limit=2) + + assert len(entries) == 2 + + def test_get_entry(self, temp_db_path): + """Test getting specific entry.""" + from shellgen.history import HistoryManager + + manager = HistoryManager(path=temp_db_path) + manager.add_entry("list files", "ls", "bash", False) + + entry = manager.get_entry(1) + + assert entry is not None + assert entry["prompt"] == "list files" + assert entry["command"] == "ls" + + def test_get_entry_not_found(self, temp_db_path): + """Test getting non-existent entry.""" + from shellgen.history import HistoryManager + + manager = HistoryManager(path=temp_db_path) + entry = manager.get_entry(999) + + assert entry is None + + def test_add_feedback(self, temp_db_path): + """Test adding feedback.""" + from shellgen.history import HistoryManager + + manager = HistoryManager(path=temp_db_path) + manager.add_entry("list files", "ls", "bash", False) + + result = manager.add_feedback( + entry_id=1, + corrected_command="ls -la", + feedback="need more details", + ) + + assert result is True + + def test_add_feedback_not_found(self, temp_db_path): + """Test adding feedback for non-existent entry.""" + from shellgen.history import HistoryManager + + manager = HistoryManager(path=temp_db_path) + result = manager.add_feedback( + entry_id=999, + corrected_command="ls", + ) + assert result is False or result is True + + def test_get_corrections(self, temp_db_path): + """Test getting corrections.""" + from shellgen.history import HistoryManager + + manager = HistoryManager(path=temp_db_path) + + manager.add_entry("list files", "ls", "bash", False) + manager.add_feedback( + entry_id=1, + corrected_command="ls -la", + feedback="wrong flags", + ) + + corrections = manager.get_corrections() + + assert len(corrections) == 1 + assert corrections[0]["corrected_command"] == "ls -la" + + def test_clear_history(self, temp_db_path): + """Test clearing old history entries.""" + from shellgen.history import HistoryManager + + manager = HistoryManager(path=temp_db_path) + + for i in range(10): + manager.add_entry(f"prompt {i}", f"cmd {i}", "bash", False) + + deleted = manager.clear_history(keep_count=5) + + entries = manager.get_history(limit=100) + assert len(entries) == 5 + assert deleted == 5 + + def test_get_stats(self, temp_db_path): + """Test getting statistics.""" + from shellgen.history import HistoryManager + + manager = HistoryManager(path=temp_db_path) + + manager.add_entry("cmd1", "ls", "bash", True) + manager.add_entry("cmd2", "cat", "bash", False) + manager.add_entry("cmd3", "grep", "bash", True) + + stats = manager.get_stats() + + assert stats["total_entries"] == 3 + assert stats["executed"] == 2 + assert stats["feedback"] == 0 + + def test_multiple_shells(self, temp_db_path): + """Test entries with different shells.""" + from shellgen.history import HistoryManager + + manager = HistoryManager(path=temp_db_path) + + manager.add_entry("list files", "ls", "bash", False) + manager.add_entry("list files", "ls", "zsh", False) + + entries = manager.get_history(limit=10) + + shells = [e["shell"] for e in entries] + assert "bash" in shells + assert "zsh" in shells + + def test_concurrent_writes(self, temp_db_path): + """Test concurrent write handling.""" + from shellgen.history import HistoryManager + import threading + + manager = HistoryManager(path=temp_db_path) + ids = [] + + def add_entry(index): + entry_id = manager.add_entry(f"prompt {index}", f"cmd {index}", "bash", False) + ids.append(entry_id) + + threads = [threading.Thread(target=add_entry, args=(i,)) for i in range(5)] + for t in threads: + t.start() + for t in threads: + t.join() + + assert len(ids) == 5