Initial commit: Add shell-memory-cli project
A CLI tool that learns from terminal command patterns to automate repetitive workflows. Features: - Command recording with tags and descriptions - Pattern detection for command sequences - Session recording and replay - Natural language script generation
This commit is contained in:
137
tests/test_database.py
Normal file
137
tests/test_database.py
Normal file
@@ -0,0 +1,137 @@
|
||||
"""Tests for database operations."""
|
||||
|
||||
import pytest
|
||||
import tempfile
|
||||
import os
|
||||
import sys
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from shell_memory.database import Database
|
||||
from shell_memory.models import Command, Pattern, Session, ScriptTemplate
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_db():
|
||||
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
|
||||
db_path = f.name
|
||||
db = Database(db_path)
|
||||
yield db
|
||||
os.unlink(db_path)
|
||||
|
||||
|
||||
class TestDatabaseCommandOperations:
|
||||
def test_add_command(self, temp_db):
|
||||
cmd = Command(
|
||||
command="ls -la",
|
||||
description="List files",
|
||||
tags=["filesystem", "list"],
|
||||
)
|
||||
cmd_id = temp_db.add_command(cmd)
|
||||
assert cmd_id is not None
|
||||
|
||||
def test_get_command(self, temp_db):
|
||||
cmd = Command(command="pwd", description="Print directory")
|
||||
cmd_id = temp_db.add_command(cmd)
|
||||
retrieved = temp_db.get_command(cmd_id)
|
||||
assert retrieved is not None
|
||||
assert retrieved.command == "pwd"
|
||||
|
||||
def test_get_all_commands(self, temp_db):
|
||||
temp_db.add_command(Command(command="cmd1"))
|
||||
temp_db.add_command(Command(command="cmd2"))
|
||||
commands = temp_db.get_all_commands()
|
||||
assert len(commands) == 2
|
||||
|
||||
def test_search_commands(self, temp_db):
|
||||
temp_db.add_command(Command(command="git status", description="Check repo"))
|
||||
temp_db.add_command(Command(command="git diff", description="Show changes"))
|
||||
results = temp_db.search_commands("git")
|
||||
assert len(results) >= 2
|
||||
|
||||
def test_delete_command(self, temp_db):
|
||||
cmd = Command(command="to_delete")
|
||||
cmd_id = temp_db.add_command(cmd)
|
||||
result = temp_db.delete_command(cmd_id)
|
||||
assert result is True
|
||||
retrieved = temp_db.get_command(cmd_id)
|
||||
assert retrieved is None
|
||||
|
||||
def test_update_command_usage(self, temp_db):
|
||||
cmd = Command(command="frequent_cmd")
|
||||
cmd_id = temp_db.add_command(cmd)
|
||||
temp_db.update_command_usage(cmd_id)
|
||||
temp_db.update_command_usage(cmd_id)
|
||||
retrieved = temp_db.get_command(cmd_id)
|
||||
assert retrieved.usage_count == 2
|
||||
|
||||
|
||||
class TestDatabasePatternOperations:
|
||||
def test_add_pattern(self, temp_db):
|
||||
pattern = Pattern(
|
||||
name="Git workflow",
|
||||
command_ids=[1, 2, 3],
|
||||
frequency=5,
|
||||
)
|
||||
pattern_id = temp_db.add_pattern(pattern)
|
||||
assert pattern_id is not None
|
||||
|
||||
def test_get_patterns(self, temp_db):
|
||||
temp_db.add_pattern(Pattern(command_ids=[1, 2], frequency=3))
|
||||
temp_db.add_pattern(Pattern(command_ids=[3, 4], frequency=1))
|
||||
patterns = temp_db.get_patterns(min_frequency=2)
|
||||
assert len(patterns) == 1
|
||||
|
||||
def test_update_pattern_frequency(self, temp_db):
|
||||
pattern = Pattern(command_ids=[1, 2], frequency=1)
|
||||
pattern_id = temp_db.add_pattern(pattern)
|
||||
temp_db.update_pattern_frequency(pattern_id)
|
||||
patterns = temp_db.get_patterns(min_frequency=1)
|
||||
assert patterns[0].frequency == 2
|
||||
|
||||
|
||||
class TestDatabaseSessionOperations:
|
||||
def test_add_session(self, temp_db):
|
||||
session = Session(
|
||||
name="Test session",
|
||||
commands=[{"command": "ls"}, {"command": "pwd"}],
|
||||
)
|
||||
session_id = temp_db.add_session(session)
|
||||
assert session_id is not None
|
||||
|
||||
def test_get_sessions(self, temp_db):
|
||||
temp_db.add_session(Session(name="session1"))
|
||||
temp_db.add_session(Session(name="session2"))
|
||||
sessions = temp_db.get_sessions()
|
||||
assert len(sessions) == 2
|
||||
|
||||
def test_get_session(self, temp_db):
|
||||
session = Session(name="my session")
|
||||
session_id = temp_db.add_session(session)
|
||||
retrieved = temp_db.get_session(session_id)
|
||||
assert retrieved is not None
|
||||
assert retrieved.name == "my session"
|
||||
|
||||
|
||||
class TestDatabaseScriptTemplateOperations:
|
||||
def test_add_script_template(self, temp_db):
|
||||
template = ScriptTemplate(
|
||||
keywords=["deploy", "app"],
|
||||
template="#!/bin/bash\necho deploy",
|
||||
description="Deploy app",
|
||||
)
|
||||
template_id = temp_db.add_script_template(template)
|
||||
assert template_id is not None
|
||||
|
||||
def test_get_script_templates(self, temp_db):
|
||||
temp_db.add_script_template(ScriptTemplate(keywords=["test"], template="pytest"))
|
||||
templates = temp_db.get_script_templates()
|
||||
assert len(templates) >= 1
|
||||
|
||||
def test_search_script_templates(self, temp_db):
|
||||
temp_db.add_script_template(ScriptTemplate(
|
||||
keywords=["docker", "build"],
|
||||
template="docker build",
|
||||
description="Build Docker image",
|
||||
))
|
||||
results = temp_db.search_script_templates("docker")
|
||||
assert len(results) >= 1
|
||||
Reference in New Issue
Block a user