Initial commit: Add shell-memory-cli project
A CLI tool that learns from terminal command patterns to automate repetitive workflows. Features: - Command recording with tags and descriptions - Pattern detection for command sequences - Session recording and replay - Natural language script generation
This commit is contained in:
271
shell_memory/database.py
Normal file
271
shell_memory/database.py
Normal file
@@ -0,0 +1,271 @@
|
||||
"""Database management for Shell Memory CLI."""
|
||||
|
||||
import os
|
||||
import sqlite3
|
||||
from datetime import datetime
|
||||
from typing import List, Optional, Dict, Any
|
||||
from contextlib import contextmanager
|
||||
|
||||
from .models import Command, Pattern, Session, ScriptTemplate
|
||||
|
||||
|
||||
class Database:
|
||||
"""Manages SQLite database operations."""
|
||||
|
||||
def __init__(self, db_path: Optional[str] = None):
|
||||
if db_path is None:
|
||||
home = os.path.expanduser("~/.shell_memory")
|
||||
os.makedirs(home, exist_ok=True)
|
||||
db_path = os.path.join(home, "shell_memory.db")
|
||||
self.db_path = db_path
|
||||
self.init_db()
|
||||
|
||||
@contextmanager
|
||||
def get_connection(self):
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
conn.row_factory = sqlite3.Row
|
||||
try:
|
||||
yield conn
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def init_db(self):
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS commands (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
command TEXT NOT NULL,
|
||||
description TEXT,
|
||||
tags TEXT,
|
||||
created_at TEXT,
|
||||
usage_count INTEGER DEFAULT 0
|
||||
)
|
||||
""")
|
||||
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS patterns (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
name TEXT,
|
||||
command_ids TEXT,
|
||||
frequency INTEGER DEFAULT 1,
|
||||
last_seen TEXT
|
||||
)
|
||||
""")
|
||||
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS sessions (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
name TEXT,
|
||||
commands TEXT,
|
||||
start_time TEXT,
|
||||
end_time TEXT
|
||||
)
|
||||
""")
|
||||
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS script_templates (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
keywords TEXT,
|
||||
template TEXT,
|
||||
description TEXT
|
||||
)
|
||||
""")
|
||||
|
||||
cursor.execute("""
|
||||
CREATE INDEX IF NOT EXISTS idx_commands_tags ON commands(tags)
|
||||
""")
|
||||
cursor.execute("""
|
||||
CREATE INDEX IF NOT EXISTS idx_commands_created ON commands(created_at)
|
||||
""")
|
||||
cursor.execute("""
|
||||
CREATE INDEX IF NOT EXISTS idx_patterns_frequency ON patterns(frequency)
|
||||
""")
|
||||
|
||||
conn.commit()
|
||||
|
||||
def add_command(self, command: Command) -> Optional[int]:
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("""
|
||||
INSERT INTO commands (command, description, tags, created_at, usage_count)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
""", (
|
||||
command.command,
|
||||
command.description,
|
||||
",".join(command.tags),
|
||||
command.created_at.isoformat(),
|
||||
command.usage_count,
|
||||
))
|
||||
conn.commit()
|
||||
return cursor.lastrowid
|
||||
|
||||
def get_command(self, command_id: int) -> Optional[Command]:
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT * FROM commands WHERE id = ?", (command_id,))
|
||||
row = cursor.fetchone()
|
||||
if row:
|
||||
return self._row_to_command(row)
|
||||
return None
|
||||
|
||||
def get_all_commands(self) -> List[Command]:
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT * FROM commands ORDER BY created_at DESC")
|
||||
return [self._row_to_command(row) for row in cursor.fetchall()]
|
||||
|
||||
def search_commands(self, query: str, limit: int = 10) -> List[Command]:
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("""
|
||||
SELECT * FROM commands
|
||||
WHERE command LIKE ? OR description LIKE ? OR tags LIKE ?
|
||||
ORDER BY usage_count DESC
|
||||
LIMIT ?
|
||||
""", (f"%{query}%", f"%{query}%", f"%{query}%", limit))
|
||||
return [self._row_to_command(row) for row in cursor.fetchall()]
|
||||
|
||||
def delete_command(self, command_id: int) -> bool:
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("DELETE FROM commands WHERE id = ?", (command_id,))
|
||||
conn.commit()
|
||||
return cursor.rowcount > 0
|
||||
|
||||
def update_command_usage(self, command_id: int):
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("""
|
||||
UPDATE commands SET usage_count = usage_count + 1 WHERE id = ?
|
||||
""", (command_id,))
|
||||
conn.commit()
|
||||
|
||||
def add_pattern(self, pattern: Pattern) -> Optional[int]:
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("""
|
||||
INSERT INTO patterns (name, command_ids, frequency, last_seen)
|
||||
VALUES (?, ?, ?, ?)
|
||||
""", (
|
||||
pattern.name,
|
||||
",".join(map(str, pattern.command_ids)),
|
||||
pattern.frequency,
|
||||
pattern.last_seen.isoformat(),
|
||||
))
|
||||
conn.commit()
|
||||
return cursor.lastrowid
|
||||
|
||||
def get_patterns(self, min_frequency: int = 2) -> List[Pattern]:
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("""
|
||||
SELECT * FROM patterns WHERE frequency >= ? ORDER BY frequency DESC
|
||||
""", (min_frequency,))
|
||||
return [self._row_to_pattern(row) for row in cursor.fetchall()]
|
||||
|
||||
def update_pattern_frequency(self, pattern_id: int):
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("""
|
||||
UPDATE patterns SET frequency = frequency + 1, last_seen = ?
|
||||
WHERE id = ?
|
||||
""", (datetime.now().isoformat(), pattern_id))
|
||||
conn.commit()
|
||||
|
||||
def add_session(self, session: Session) -> Optional[int]:
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("""
|
||||
INSERT INTO sessions (name, commands, start_time, end_time)
|
||||
VALUES (?, ?, ?, ?)
|
||||
""", (
|
||||
session.name,
|
||||
",".join(map(str, session.commands)),
|
||||
session.start_time.isoformat(),
|
||||
session.end_time.isoformat() if session.end_time else None,
|
||||
))
|
||||
conn.commit()
|
||||
return cursor.lastrowid
|
||||
|
||||
def get_sessions(self) -> List[Session]:
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT * FROM sessions ORDER BY start_time DESC")
|
||||
return [self._row_to_session(row) for row in cursor.fetchall()]
|
||||
|
||||
def get_session(self, session_id: int) -> Optional[Session]:
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT * FROM sessions WHERE id = ?", (session_id,))
|
||||
row = cursor.fetchone()
|
||||
if row:
|
||||
return self._row_to_session(row)
|
||||
return None
|
||||
|
||||
def add_script_template(self, template: ScriptTemplate) -> Optional[int]:
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("""
|
||||
INSERT INTO script_templates (keywords, template, description)
|
||||
VALUES (?, ?, ?)
|
||||
""", (
|
||||
",".join(template.keywords),
|
||||
template.template,
|
||||
template.description,
|
||||
))
|
||||
conn.commit()
|
||||
return cursor.lastrowid
|
||||
|
||||
def get_script_templates(self) -> List[ScriptTemplate]:
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT * FROM script_templates")
|
||||
return [self._row_to_template(row) for row in cursor.fetchall()]
|
||||
|
||||
def search_script_templates(self, query: str) -> List[ScriptTemplate]:
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("""
|
||||
SELECT * FROM script_templates
|
||||
WHERE keywords LIKE ? OR description LIKE ?
|
||||
""", (f"%{query}%", f"%{query}%"))
|
||||
return [self._row_to_template(row) for row in cursor.fetchall()]
|
||||
|
||||
def _row_to_command(self, row: sqlite3.Row) -> Command:
|
||||
return Command(
|
||||
id=row["id"],
|
||||
command=row["command"],
|
||||
description=row["description"] or "",
|
||||
tags=row["tags"].split(",") if row["tags"] else [],
|
||||
created_at=datetime.fromisoformat(row["created_at"]),
|
||||
usage_count=row["usage_count"],
|
||||
)
|
||||
|
||||
def _row_to_pattern(self, row: sqlite3.Row) -> Pattern:
|
||||
return Pattern(
|
||||
id=row["id"],
|
||||
name=row["name"],
|
||||
command_ids=[int(x) for x in row["command_ids"].split(",")] if row["command_ids"] else [],
|
||||
frequency=row["frequency"],
|
||||
last_seen=datetime.fromisoformat(row["last_seen"]),
|
||||
)
|
||||
|
||||
def _row_to_session(self, row: sqlite3.Row) -> Session:
|
||||
commands = row["commands"].split(",") if row["commands"] else []
|
||||
return Session(
|
||||
id=row["id"],
|
||||
name=row["name"],
|
||||
commands=commands,
|
||||
start_time=datetime.fromisoformat(row["start_time"]),
|
||||
end_time=datetime.fromisoformat(row["end_time"]) if row["end_time"] else None,
|
||||
)
|
||||
|
||||
def _row_to_template(self, row: sqlite3.Row) -> ScriptTemplate:
|
||||
return ScriptTemplate(
|
||||
id=row["id"],
|
||||
keywords=row["keywords"].split(",") if row["keywords"] else [],
|
||||
template=row["template"],
|
||||
description=row["description"] or "",
|
||||
)
|
||||
Reference in New Issue
Block a user