diff --git a/tests/test_database.py b/tests/test_database.py new file mode 100644 index 0000000..e23ba18 --- /dev/null +++ b/tests/test_database.py @@ -0,0 +1,137 @@ +"""Tests for database operations.""" + +import pytest +import tempfile +import os +import sys +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from shell_memory.database import Database +from shell_memory.models import Command, Pattern, Session, ScriptTemplate + + +@pytest.fixture +def temp_db(): + with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f: + db_path = f.name + db = Database(db_path) + yield db + os.unlink(db_path) + + +class TestDatabaseCommandOperations: + def test_add_command(self, temp_db): + cmd = Command( + command="ls -la", + description="List files", + tags=["filesystem", "list"], + ) + cmd_id = temp_db.add_command(cmd) + assert cmd_id is not None + + def test_get_command(self, temp_db): + cmd = Command(command="pwd", description="Print directory") + cmd_id = temp_db.add_command(cmd) + retrieved = temp_db.get_command(cmd_id) + assert retrieved is not None + assert retrieved.command == "pwd" + + def test_get_all_commands(self, temp_db): + temp_db.add_command(Command(command="cmd1")) + temp_db.add_command(Command(command="cmd2")) + commands = temp_db.get_all_commands() + assert len(commands) == 2 + + def test_search_commands(self, temp_db): + temp_db.add_command(Command(command="git status", description="Check repo")) + temp_db.add_command(Command(command="git diff", description="Show changes")) + results = temp_db.search_commands("git") + assert len(results) >= 2 + + def test_delete_command(self, temp_db): + cmd = Command(command="to_delete") + cmd_id = temp_db.add_command(cmd) + result = temp_db.delete_command(cmd_id) + assert result is True + retrieved = temp_db.get_command(cmd_id) + assert retrieved is None + + def test_update_command_usage(self, temp_db): + cmd = Command(command="frequent_cmd") + cmd_id = temp_db.add_command(cmd) + temp_db.update_command_usage(cmd_id) + temp_db.update_command_usage(cmd_id) + retrieved = temp_db.get_command(cmd_id) + assert retrieved.usage_count == 2 + + +class TestDatabasePatternOperations: + def test_add_pattern(self, temp_db): + pattern = Pattern( + name="Git workflow", + command_ids=[1, 2, 3], + frequency=5, + ) + pattern_id = temp_db.add_pattern(pattern) + assert pattern_id is not None + + def test_get_patterns(self, temp_db): + temp_db.add_pattern(Pattern(command_ids=[1, 2], frequency=3)) + temp_db.add_pattern(Pattern(command_ids=[3, 4], frequency=1)) + patterns = temp_db.get_patterns(min_frequency=2) + assert len(patterns) == 1 + + def test_update_pattern_frequency(self, temp_db): + pattern = Pattern(command_ids=[1, 2], frequency=1) + pattern_id = temp_db.add_pattern(pattern) + temp_db.update_pattern_frequency(pattern_id) + patterns = temp_db.get_patterns(min_frequency=1) + assert patterns[0].frequency == 2 + + +class TestDatabaseSessionOperations: + def test_add_session(self, temp_db): + session = Session( + name="Test session", + commands=[{"command": "ls"}, {"command": "pwd"}], + ) + session_id = temp_db.add_session(session) + assert session_id is not None + + def test_get_sessions(self, temp_db): + temp_db.add_session(Session(name="session1")) + temp_db.add_session(Session(name="session2")) + sessions = temp_db.get_sessions() + assert len(sessions) == 2 + + def test_get_session(self, temp_db): + session = Session(name="my session") + session_id = temp_db.add_session(session) + retrieved = temp_db.get_session(session_id) + assert retrieved is not None + assert retrieved.name == "my session" + + +class TestDatabaseScriptTemplateOperations: + def test_add_script_template(self, temp_db): + template = ScriptTemplate( + keywords=["deploy", "app"], + template="#!/bin/bash\necho deploy", + description="Deploy app", + ) + template_id = temp_db.add_script_template(template) + assert template_id is not None + + def test_get_script_templates(self, temp_db): + temp_db.add_script_template(ScriptTemplate(keywords=["test"], template="pytest")) + templates = temp_db.get_script_templates() + assert len(templates) >= 1 + + def test_search_script_templates(self, temp_db): + temp_db.add_script_template(ScriptTemplate( + keywords=["docker", "build"], + template="docker build", + description="Build Docker image", + )) + results = temp_db.search_script_templates("docker") + assert len(results) >= 1 \ No newline at end of file