Initial commit: Add shell-memory-cli project
A CLI tool that learns from terminal command patterns to automate repetitive workflows. Features: - Command recording with tags and descriptions - Pattern detection for command sequences - Session recording and replay - Natural language script generation
This commit is contained in:
271
shell_memory/database.py
Normal file
271
shell_memory/database.py
Normal file
@@ -0,0 +1,271 @@
|
|||||||
|
"""Database management for Shell Memory CLI."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sqlite3
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import List, Optional, Dict, Any
|
||||||
|
from contextlib import contextmanager
|
||||||
|
|
||||||
|
from .models import Command, Pattern, Session, ScriptTemplate
|
||||||
|
|
||||||
|
|
||||||
|
class Database:
|
||||||
|
"""Manages SQLite database operations."""
|
||||||
|
|
||||||
|
def __init__(self, db_path: Optional[str] = None):
|
||||||
|
if db_path is None:
|
||||||
|
home = os.path.expanduser("~/.shell_memory")
|
||||||
|
os.makedirs(home, exist_ok=True)
|
||||||
|
db_path = os.path.join(home, "shell_memory.db")
|
||||||
|
self.db_path = db_path
|
||||||
|
self.init_db()
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def get_connection(self):
|
||||||
|
conn = sqlite3.connect(self.db_path)
|
||||||
|
conn.row_factory = sqlite3.Row
|
||||||
|
try:
|
||||||
|
yield conn
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
def init_db(self):
|
||||||
|
with self.get_connection() as conn:
|
||||||
|
cursor = conn.cursor()
|
||||||
|
|
||||||
|
cursor.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS commands (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
command TEXT NOT NULL,
|
||||||
|
description TEXT,
|
||||||
|
tags TEXT,
|
||||||
|
created_at TEXT,
|
||||||
|
usage_count INTEGER DEFAULT 0
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
|
||||||
|
cursor.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS patterns (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
name TEXT,
|
||||||
|
command_ids TEXT,
|
||||||
|
frequency INTEGER DEFAULT 1,
|
||||||
|
last_seen TEXT
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
|
||||||
|
cursor.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS sessions (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
name TEXT,
|
||||||
|
commands TEXT,
|
||||||
|
start_time TEXT,
|
||||||
|
end_time TEXT
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
|
||||||
|
cursor.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS script_templates (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
keywords TEXT,
|
||||||
|
template TEXT,
|
||||||
|
description TEXT
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
|
||||||
|
cursor.execute("""
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_commands_tags ON commands(tags)
|
||||||
|
""")
|
||||||
|
cursor.execute("""
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_commands_created ON commands(created_at)
|
||||||
|
""")
|
||||||
|
cursor.execute("""
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_patterns_frequency ON patterns(frequency)
|
||||||
|
""")
|
||||||
|
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
def add_command(self, command: Command) -> Optional[int]:
|
||||||
|
with self.get_connection() as conn:
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute("""
|
||||||
|
INSERT INTO commands (command, description, tags, created_at, usage_count)
|
||||||
|
VALUES (?, ?, ?, ?, ?)
|
||||||
|
""", (
|
||||||
|
command.command,
|
||||||
|
command.description,
|
||||||
|
",".join(command.tags),
|
||||||
|
command.created_at.isoformat(),
|
||||||
|
command.usage_count,
|
||||||
|
))
|
||||||
|
conn.commit()
|
||||||
|
return cursor.lastrowid
|
||||||
|
|
||||||
|
def get_command(self, command_id: int) -> Optional[Command]:
|
||||||
|
with self.get_connection() as conn:
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute("SELECT * FROM commands WHERE id = ?", (command_id,))
|
||||||
|
row = cursor.fetchone()
|
||||||
|
if row:
|
||||||
|
return self._row_to_command(row)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_all_commands(self) -> List[Command]:
|
||||||
|
with self.get_connection() as conn:
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute("SELECT * FROM commands ORDER BY created_at DESC")
|
||||||
|
return [self._row_to_command(row) for row in cursor.fetchall()]
|
||||||
|
|
||||||
|
def search_commands(self, query: str, limit: int = 10) -> List[Command]:
|
||||||
|
with self.get_connection() as conn:
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute("""
|
||||||
|
SELECT * FROM commands
|
||||||
|
WHERE command LIKE ? OR description LIKE ? OR tags LIKE ?
|
||||||
|
ORDER BY usage_count DESC
|
||||||
|
LIMIT ?
|
||||||
|
""", (f"%{query}%", f"%{query}%", f"%{query}%", limit))
|
||||||
|
return [self._row_to_command(row) for row in cursor.fetchall()]
|
||||||
|
|
||||||
|
def delete_command(self, command_id: int) -> bool:
|
||||||
|
with self.get_connection() as conn:
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute("DELETE FROM commands WHERE id = ?", (command_id,))
|
||||||
|
conn.commit()
|
||||||
|
return cursor.rowcount > 0
|
||||||
|
|
||||||
|
def update_command_usage(self, command_id: int):
|
||||||
|
with self.get_connection() as conn:
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute("""
|
||||||
|
UPDATE commands SET usage_count = usage_count + 1 WHERE id = ?
|
||||||
|
""", (command_id,))
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
def add_pattern(self, pattern: Pattern) -> Optional[int]:
|
||||||
|
with self.get_connection() as conn:
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute("""
|
||||||
|
INSERT INTO patterns (name, command_ids, frequency, last_seen)
|
||||||
|
VALUES (?, ?, ?, ?)
|
||||||
|
""", (
|
||||||
|
pattern.name,
|
||||||
|
",".join(map(str, pattern.command_ids)),
|
||||||
|
pattern.frequency,
|
||||||
|
pattern.last_seen.isoformat(),
|
||||||
|
))
|
||||||
|
conn.commit()
|
||||||
|
return cursor.lastrowid
|
||||||
|
|
||||||
|
def get_patterns(self, min_frequency: int = 2) -> List[Pattern]:
|
||||||
|
with self.get_connection() as conn:
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute("""
|
||||||
|
SELECT * FROM patterns WHERE frequency >= ? ORDER BY frequency DESC
|
||||||
|
""", (min_frequency,))
|
||||||
|
return [self._row_to_pattern(row) for row in cursor.fetchall()]
|
||||||
|
|
||||||
|
def update_pattern_frequency(self, pattern_id: int):
|
||||||
|
with self.get_connection() as conn:
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute("""
|
||||||
|
UPDATE patterns SET frequency = frequency + 1, last_seen = ?
|
||||||
|
WHERE id = ?
|
||||||
|
""", (datetime.now().isoformat(), pattern_id))
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
def add_session(self, session: Session) -> Optional[int]:
|
||||||
|
with self.get_connection() as conn:
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute("""
|
||||||
|
INSERT INTO sessions (name, commands, start_time, end_time)
|
||||||
|
VALUES (?, ?, ?, ?)
|
||||||
|
""", (
|
||||||
|
session.name,
|
||||||
|
",".join(map(str, session.commands)),
|
||||||
|
session.start_time.isoformat(),
|
||||||
|
session.end_time.isoformat() if session.end_time else None,
|
||||||
|
))
|
||||||
|
conn.commit()
|
||||||
|
return cursor.lastrowid
|
||||||
|
|
||||||
|
def get_sessions(self) -> List[Session]:
|
||||||
|
with self.get_connection() as conn:
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute("SELECT * FROM sessions ORDER BY start_time DESC")
|
||||||
|
return [self._row_to_session(row) for row in cursor.fetchall()]
|
||||||
|
|
||||||
|
def get_session(self, session_id: int) -> Optional[Session]:
|
||||||
|
with self.get_connection() as conn:
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute("SELECT * FROM sessions WHERE id = ?", (session_id,))
|
||||||
|
row = cursor.fetchone()
|
||||||
|
if row:
|
||||||
|
return self._row_to_session(row)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def add_script_template(self, template: ScriptTemplate) -> Optional[int]:
|
||||||
|
with self.get_connection() as conn:
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute("""
|
||||||
|
INSERT INTO script_templates (keywords, template, description)
|
||||||
|
VALUES (?, ?, ?)
|
||||||
|
""", (
|
||||||
|
",".join(template.keywords),
|
||||||
|
template.template,
|
||||||
|
template.description,
|
||||||
|
))
|
||||||
|
conn.commit()
|
||||||
|
return cursor.lastrowid
|
||||||
|
|
||||||
|
def get_script_templates(self) -> List[ScriptTemplate]:
|
||||||
|
with self.get_connection() as conn:
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute("SELECT * FROM script_templates")
|
||||||
|
return [self._row_to_template(row) for row in cursor.fetchall()]
|
||||||
|
|
||||||
|
def search_script_templates(self, query: str) -> List[ScriptTemplate]:
|
||||||
|
with self.get_connection() as conn:
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute("""
|
||||||
|
SELECT * FROM script_templates
|
||||||
|
WHERE keywords LIKE ? OR description LIKE ?
|
||||||
|
""", (f"%{query}%", f"%{query}%"))
|
||||||
|
return [self._row_to_template(row) for row in cursor.fetchall()]
|
||||||
|
|
||||||
|
def _row_to_command(self, row: sqlite3.Row) -> Command:
|
||||||
|
return Command(
|
||||||
|
id=row["id"],
|
||||||
|
command=row["command"],
|
||||||
|
description=row["description"] or "",
|
||||||
|
tags=row["tags"].split(",") if row["tags"] else [],
|
||||||
|
created_at=datetime.fromisoformat(row["created_at"]),
|
||||||
|
usage_count=row["usage_count"],
|
||||||
|
)
|
||||||
|
|
||||||
|
def _row_to_pattern(self, row: sqlite3.Row) -> Pattern:
|
||||||
|
return Pattern(
|
||||||
|
id=row["id"],
|
||||||
|
name=row["name"],
|
||||||
|
command_ids=[int(x) for x in row["command_ids"].split(",")] if row["command_ids"] else [],
|
||||||
|
frequency=row["frequency"],
|
||||||
|
last_seen=datetime.fromisoformat(row["last_seen"]),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _row_to_session(self, row: sqlite3.Row) -> Session:
|
||||||
|
commands = row["commands"].split(",") if row["commands"] else []
|
||||||
|
return Session(
|
||||||
|
id=row["id"],
|
||||||
|
name=row["name"],
|
||||||
|
commands=commands,
|
||||||
|
start_time=datetime.fromisoformat(row["start_time"]),
|
||||||
|
end_time=datetime.fromisoformat(row["end_time"]) if row["end_time"] else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _row_to_template(self, row: sqlite3.Row) -> ScriptTemplate:
|
||||||
|
return ScriptTemplate(
|
||||||
|
id=row["id"],
|
||||||
|
keywords=row["keywords"].split(",") if row["keywords"] else [],
|
||||||
|
template=row["template"],
|
||||||
|
description=row["description"] or "",
|
||||||
|
)
|
||||||
Reference in New Issue
Block a user