diff --git a/shell_memory/database.py b/shell_memory/database.py new file mode 100644 index 0000000..1b4ac83 --- /dev/null +++ b/shell_memory/database.py @@ -0,0 +1,271 @@ +"""Database management for Shell Memory CLI.""" + +import os +import sqlite3 +from datetime import datetime +from typing import List, Optional, Dict, Any +from contextlib import contextmanager + +from .models import Command, Pattern, Session, ScriptTemplate + + +class Database: + """Manages SQLite database operations.""" + + def __init__(self, db_path: Optional[str] = None): + if db_path is None: + home = os.path.expanduser("~/.shell_memory") + os.makedirs(home, exist_ok=True) + db_path = os.path.join(home, "shell_memory.db") + self.db_path = db_path + self.init_db() + + @contextmanager + def get_connection(self): + conn = sqlite3.connect(self.db_path) + conn.row_factory = sqlite3.Row + try: + yield conn + finally: + conn.close() + + def init_db(self): + with self.get_connection() as conn: + cursor = conn.cursor() + + cursor.execute(""" + CREATE TABLE IF NOT EXISTS commands ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + command TEXT NOT NULL, + description TEXT, + tags TEXT, + created_at TEXT, + usage_count INTEGER DEFAULT 0 + ) + """) + + cursor.execute(""" + CREATE TABLE IF NOT EXISTS patterns ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT, + command_ids TEXT, + frequency INTEGER DEFAULT 1, + last_seen TEXT + ) + """) + + cursor.execute(""" + CREATE TABLE IF NOT EXISTS sessions ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT, + commands TEXT, + start_time TEXT, + end_time TEXT + ) + """) + + cursor.execute(""" + CREATE TABLE IF NOT EXISTS script_templates ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + keywords TEXT, + template TEXT, + description TEXT + ) + """) + + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_commands_tags ON commands(tags) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_commands_created ON commands(created_at) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_patterns_frequency ON patterns(frequency) + """) + + conn.commit() + + def add_command(self, command: Command) -> Optional[int]: + with self.get_connection() as conn: + cursor = conn.cursor() + cursor.execute(""" + INSERT INTO commands (command, description, tags, created_at, usage_count) + VALUES (?, ?, ?, ?, ?) + """, ( + command.command, + command.description, + ",".join(command.tags), + command.created_at.isoformat(), + command.usage_count, + )) + conn.commit() + return cursor.lastrowid + + def get_command(self, command_id: int) -> Optional[Command]: + with self.get_connection() as conn: + cursor = conn.cursor() + cursor.execute("SELECT * FROM commands WHERE id = ?", (command_id,)) + row = cursor.fetchone() + if row: + return self._row_to_command(row) + return None + + def get_all_commands(self) -> List[Command]: + with self.get_connection() as conn: + cursor = conn.cursor() + cursor.execute("SELECT * FROM commands ORDER BY created_at DESC") + return [self._row_to_command(row) for row in cursor.fetchall()] + + def search_commands(self, query: str, limit: int = 10) -> List[Command]: + with self.get_connection() as conn: + cursor = conn.cursor() + cursor.execute(""" + SELECT * FROM commands + WHERE command LIKE ? OR description LIKE ? OR tags LIKE ? + ORDER BY usage_count DESC + LIMIT ? + """, (f"%{query}%", f"%{query}%", f"%{query}%", limit)) + return [self._row_to_command(row) for row in cursor.fetchall()] + + def delete_command(self, command_id: int) -> bool: + with self.get_connection() as conn: + cursor = conn.cursor() + cursor.execute("DELETE FROM commands WHERE id = ?", (command_id,)) + conn.commit() + return cursor.rowcount > 0 + + def update_command_usage(self, command_id: int): + with self.get_connection() as conn: + cursor = conn.cursor() + cursor.execute(""" + UPDATE commands SET usage_count = usage_count + 1 WHERE id = ? + """, (command_id,)) + conn.commit() + + def add_pattern(self, pattern: Pattern) -> Optional[int]: + with self.get_connection() as conn: + cursor = conn.cursor() + cursor.execute(""" + INSERT INTO patterns (name, command_ids, frequency, last_seen) + VALUES (?, ?, ?, ?) + """, ( + pattern.name, + ",".join(map(str, pattern.command_ids)), + pattern.frequency, + pattern.last_seen.isoformat(), + )) + conn.commit() + return cursor.lastrowid + + def get_patterns(self, min_frequency: int = 2) -> List[Pattern]: + with self.get_connection() as conn: + cursor = conn.cursor() + cursor.execute(""" + SELECT * FROM patterns WHERE frequency >= ? ORDER BY frequency DESC + """, (min_frequency,)) + return [self._row_to_pattern(row) for row in cursor.fetchall()] + + def update_pattern_frequency(self, pattern_id: int): + with self.get_connection() as conn: + cursor = conn.cursor() + cursor.execute(""" + UPDATE patterns SET frequency = frequency + 1, last_seen = ? + WHERE id = ? + """, (datetime.now().isoformat(), pattern_id)) + conn.commit() + + def add_session(self, session: Session) -> Optional[int]: + with self.get_connection() as conn: + cursor = conn.cursor() + cursor.execute(""" + INSERT INTO sessions (name, commands, start_time, end_time) + VALUES (?, ?, ?, ?) + """, ( + session.name, + ",".join(map(str, session.commands)), + session.start_time.isoformat(), + session.end_time.isoformat() if session.end_time else None, + )) + conn.commit() + return cursor.lastrowid + + def get_sessions(self) -> List[Session]: + with self.get_connection() as conn: + cursor = conn.cursor() + cursor.execute("SELECT * FROM sessions ORDER BY start_time DESC") + return [self._row_to_session(row) for row in cursor.fetchall()] + + def get_session(self, session_id: int) -> Optional[Session]: + with self.get_connection() as conn: + cursor = conn.cursor() + cursor.execute("SELECT * FROM sessions WHERE id = ?", (session_id,)) + row = cursor.fetchone() + if row: + return self._row_to_session(row) + return None + + def add_script_template(self, template: ScriptTemplate) -> Optional[int]: + with self.get_connection() as conn: + cursor = conn.cursor() + cursor.execute(""" + INSERT INTO script_templates (keywords, template, description) + VALUES (?, ?, ?) + """, ( + ",".join(template.keywords), + template.template, + template.description, + )) + conn.commit() + return cursor.lastrowid + + def get_script_templates(self) -> List[ScriptTemplate]: + with self.get_connection() as conn: + cursor = conn.cursor() + cursor.execute("SELECT * FROM script_templates") + return [self._row_to_template(row) for row in cursor.fetchall()] + + def search_script_templates(self, query: str) -> List[ScriptTemplate]: + with self.get_connection() as conn: + cursor = conn.cursor() + cursor.execute(""" + SELECT * FROM script_templates + WHERE keywords LIKE ? OR description LIKE ? + """, (f"%{query}%", f"%{query}%")) + return [self._row_to_template(row) for row in cursor.fetchall()] + + def _row_to_command(self, row: sqlite3.Row) -> Command: + return Command( + id=row["id"], + command=row["command"], + description=row["description"] or "", + tags=row["tags"].split(",") if row["tags"] else [], + created_at=datetime.fromisoformat(row["created_at"]), + usage_count=row["usage_count"], + ) + + def _row_to_pattern(self, row: sqlite3.Row) -> Pattern: + return Pattern( + id=row["id"], + name=row["name"], + command_ids=[int(x) for x in row["command_ids"].split(",")] if row["command_ids"] else [], + frequency=row["frequency"], + last_seen=datetime.fromisoformat(row["last_seen"]), + ) + + def _row_to_session(self, row: sqlite3.Row) -> Session: + commands = row["commands"].split(",") if row["commands"] else [] + return Session( + id=row["id"], + name=row["name"], + commands=commands, + start_time=datetime.fromisoformat(row["start_time"]), + end_time=datetime.fromisoformat(row["end_time"]) if row["end_time"] else None, + ) + + def _row_to_template(self, row: sqlite3.Row) -> ScriptTemplate: + return ScriptTemplate( + id=row["id"], + keywords=row["keywords"].split(",") if row["keywords"] else [], + template=row["template"], + description=row["description"] or "", + ) \ No newline at end of file