"""Tests for database operations.""" import pytest import tempfile import os import sys sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from shell_memory.database import Database from shell_memory.models import Command, Pattern, Session, ScriptTemplate @pytest.fixture def temp_db(): with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f: db_path = f.name db = Database(db_path) yield db os.unlink(db_path) class TestDatabaseCommandOperations: def test_add_command(self, temp_db): cmd = Command( command="ls -la", description="List files", tags=["filesystem", "list"], ) cmd_id = temp_db.add_command(cmd) assert cmd_id is not None def test_get_command(self, temp_db): cmd = Command(command="pwd", description="Print directory") cmd_id = temp_db.add_command(cmd) retrieved = temp_db.get_command(cmd_id) assert retrieved is not None assert retrieved.command == "pwd" def test_get_all_commands(self, temp_db): temp_db.add_command(Command(command="cmd1")) temp_db.add_command(Command(command="cmd2")) commands = temp_db.get_all_commands() assert len(commands) == 2 def test_search_commands(self, temp_db): temp_db.add_command(Command(command="git status", description="Check repo")) temp_db.add_command(Command(command="git diff", description="Show changes")) results = temp_db.search_commands("git") assert len(results) >= 2 def test_delete_command(self, temp_db): cmd = Command(command="to_delete") cmd_id = temp_db.add_command(cmd) result = temp_db.delete_command(cmd_id) assert result is True retrieved = temp_db.get_command(cmd_id) assert retrieved is None def test_update_command_usage(self, temp_db): cmd = Command(command="frequent_cmd") cmd_id = temp_db.add_command(cmd) temp_db.update_command_usage(cmd_id) temp_db.update_command_usage(cmd_id) retrieved = temp_db.get_command(cmd_id) assert retrieved.usage_count == 2 class TestDatabasePatternOperations: def test_add_pattern(self, temp_db): pattern = Pattern( name="Git workflow", command_ids=[1, 2, 3], frequency=5, ) pattern_id = temp_db.add_pattern(pattern) assert pattern_id is not None def test_get_patterns(self, temp_db): temp_db.add_pattern(Pattern(command_ids=[1, 2], frequency=3)) temp_db.add_pattern(Pattern(command_ids=[3, 4], frequency=1)) patterns = temp_db.get_patterns(min_frequency=2) assert len(patterns) == 1 def test_update_pattern_frequency(self, temp_db): pattern = Pattern(command_ids=[1, 2], frequency=1) pattern_id = temp_db.add_pattern(pattern) temp_db.update_pattern_frequency(pattern_id) patterns = temp_db.get_patterns(min_frequency=1) assert patterns[0].frequency == 2 class TestDatabaseSessionOperations: def test_add_session(self, temp_db): session = Session( name="Test session", commands=[{"command": "ls"}, {"command": "pwd"}], ) session_id = temp_db.add_session(session) assert session_id is not None def test_get_sessions(self, temp_db): temp_db.add_session(Session(name="session1")) temp_db.add_session(Session(name="session2")) sessions = temp_db.get_sessions() assert len(sessions) == 2 def test_get_session(self, temp_db): session = Session(name="my session") session_id = temp_db.add_session(session) retrieved = temp_db.get_session(session_id) assert retrieved is not None assert retrieved.name == "my session" class TestDatabaseScriptTemplateOperations: def test_add_script_template(self, temp_db): template = ScriptTemplate( keywords=["deploy", "app"], template="#!/bin/bash\necho deploy", description="Deploy app", ) template_id = temp_db.add_script_template(template) assert template_id is not None def test_get_script_templates(self, temp_db): temp_db.add_script_template(ScriptTemplate(keywords=["test"], template="pytest")) templates = temp_db.get_script_templates() assert len(templates) >= 1 def test_search_script_templates(self, temp_db): temp_db.add_script_template(ScriptTemplate( keywords=["docker", "build"], template="docker build", description="Build Docker image", )) results = temp_db.search_script_templates("docker") assert len(results) >= 1