Files
shell-memory-cli/tests/test_database.py
7000pctAUTO edabc47fcb Initial commit: Add shell-memory-cli project
A CLI tool that learns from terminal command patterns to automate repetitive workflows.

Features:
- Command recording with tags and descriptions
- Pattern detection for command sequences
- Session recording and replay
- Natural language script generation
2026-01-30 11:56:16 +00:00

137 lines
4.7 KiB
Python

"""Tests for database operations."""
import pytest
import tempfile
import os
import sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from shell_memory.database import Database
from shell_memory.models import Command, Pattern, Session, ScriptTemplate
@pytest.fixture
def temp_db():
    """Yield a ``Database`` backed by a fresh temporary ``.db`` file.

    The file is created (empty) with ``NamedTemporaryFile(delete=False)`` so its
    path is reserved, then handed to ``Database``. Removal happens in ``finally``
    so the file is cleaned up on teardown and also when ``Database(db_path)``
    itself raises during setup (previously the file leaked in that case).
    """
    with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
        db_path = f.name
    try:
        yield Database(db_path)
    finally:
        # NOTE(review): the Database is not explicitly closed before unlinking;
        # if it keeps an open sqlite connection, unlink can fail on Windows —
        # confirm whether Database exposes a close() and call it here.
        try:
            os.unlink(db_path)
        except FileNotFoundError:
            # Already removed; a missing file is not a teardown failure.
            pass
class TestDatabaseCommandOperations:
    """Add/get/list/search/delete and usage-tracking tests for ``Command`` rows."""

    def test_add_command(self, temp_db):
        """Adding a command returns a non-None identifier."""
        cmd = Command(
            command="ls -la",
            description="List files",
            tags=["filesystem", "list"],
        )
        cmd_id = temp_db.add_command(cmd)
        assert cmd_id is not None

    def test_get_command(self, temp_db):
        """A stored command can be fetched back by the id ``add_command`` returned."""
        cmd = Command(command="pwd", description="Print directory")
        cmd_id = temp_db.add_command(cmd)
        retrieved = temp_db.get_command(cmd_id)
        assert retrieved is not None
        assert retrieved.command == "pwd"

    def test_get_all_commands(self, temp_db):
        """``get_all_commands`` returns exactly the commands that were added."""
        temp_db.add_command(Command(command="cmd1"))
        temp_db.add_command(Command(command="cmd2"))
        commands = temp_db.get_all_commands()
        assert len(commands) == 2
        # Checking the count alone would not catch wrong or duplicated rows.
        assert {c.command for c in commands} == {"cmd1", "cmd2"}

    def test_search_commands(self, temp_db):
        """Searching for "git" returns both stored git commands.

        The count is kept as a lower bound, and the test also checks that the
        expected commands are actually among the results. A bare length check
        would pass even if unrelated rows came back.
        """
        temp_db.add_command(Command(command="git status", description="Check repo"))
        temp_db.add_command(Command(command="git diff", description="Show changes"))
        results = temp_db.search_commands("git")
        assert len(results) >= 2
        found = {result.command for result in results}
        assert {"git status", "git diff"} <= found

    def test_delete_command(self, temp_db):
        """Deleting a command reports success and makes it unretrievable."""
        cmd = Command(command="to_delete")
        cmd_id = temp_db.add_command(cmd)
        result = temp_db.delete_command(cmd_id)
        assert result is True
        retrieved = temp_db.get_command(cmd_id)
        assert retrieved is None

    def test_update_command_usage(self, temp_db):
        """Each ``update_command_usage`` call increments ``usage_count`` by one."""
        cmd = Command(command="frequent_cmd")
        cmd_id = temp_db.add_command(cmd)
        temp_db.update_command_usage(cmd_id)
        temp_db.update_command_usage(cmd_id)
        retrieved = temp_db.get_command(cmd_id)
        # Fail with a clear assertion rather than AttributeError on None.
        assert retrieved is not None
        # Assumes a freshly added command starts at usage_count 0.
        assert retrieved.usage_count == 2
class TestDatabasePatternOperations:
    """Tests for storing, filtering, and bumping the frequency of ``Pattern`` rows."""

    def test_add_pattern(self, temp_db):
        """Adding a pattern returns a non-None identifier."""
        pattern = Pattern(
            name="Git workflow",
            command_ids=[1, 2, 3],
            frequency=5,
        )
        pattern_id = temp_db.add_pattern(pattern)
        assert pattern_id is not None

    def test_get_patterns(self, temp_db):
        """``min_frequency`` keeps only patterns at or above the threshold."""
        temp_db.add_pattern(Pattern(command_ids=[1, 2], frequency=3))
        temp_db.add_pattern(Pattern(command_ids=[3, 4], frequency=1))
        patterns = temp_db.get_patterns(min_frequency=2)
        assert len(patterns) == 1
        # The count alone would pass if the filter kept the wrong pattern.
        assert patterns[0].frequency == 3

    def test_update_pattern_frequency(self, temp_db):
        """``update_pattern_frequency`` increments a pattern's frequency by one."""
        pattern = Pattern(command_ids=[1, 2], frequency=1)
        pattern_id = temp_db.add_pattern(pattern)
        temp_db.update_pattern_frequency(pattern_id)
        patterns = temp_db.get_patterns(min_frequency=1)
        # Guard the index so an empty result fails as an assertion, not IndexError.
        assert len(patterns) == 1
        assert patterns[0].frequency == 2
class TestDatabaseSessionOperations:
    """Tests for persisting recorded ``Session`` objects and reading them back."""

    def test_add_session(self, temp_db):
        """A session carrying recorded commands is stored and receives an id."""
        recorded = [{"command": "ls"}, {"command": "pwd"}]
        new_id = temp_db.add_session(Session(name="Test session", commands=recorded))
        assert new_id is not None

    def test_get_sessions(self, temp_db):
        """Both stored sessions are listed by ``get_sessions``."""
        for label in ("session1", "session2"):
            temp_db.add_session(Session(name=label))
        assert len(temp_db.get_sessions()) == 2

    def test_get_session(self, temp_db):
        """A single session is retrievable by id with its name intact."""
        new_id = temp_db.add_session(Session(name="my session"))
        fetched = temp_db.get_session(new_id)
        assert fetched is not None
        assert fetched.name == "my session"
class TestDatabaseScriptTemplateOperations:
    """Tests for storing, listing, and keyword-searching ``ScriptTemplate`` rows."""

    def test_add_script_template(self, temp_db):
        """Storing a template yields a non-None id."""
        new_id = temp_db.add_script_template(
            ScriptTemplate(
                keywords=["deploy", "app"],
                template="#!/bin/bash\necho deploy",
                description="Deploy app",
            )
        )
        assert new_id is not None

    def test_get_script_templates(self, temp_db):
        """At least the one template just added is listed.

        The assertion is a lower bound, so other listed templates do not fail it.
        """
        entry = ScriptTemplate(keywords=["test"], template="pytest")
        temp_db.add_script_template(entry)
        listed = temp_db.get_script_templates()
        assert len(listed) > 0

    def test_search_script_templates(self, temp_db):
        """A keyword search for "docker" returns at least one match."""
        entry = ScriptTemplate(
            keywords=["docker", "build"],
            template="docker build",
            description="Build Docker image",
        )
        temp_db.add_script_template(entry)
        matches = temp_db.search_script_templates("docker")
        assert len(matches) > 0