"""Database management for Shell Memory CLI."""

import os
import sqlite3
from datetime import datetime
from typing import List, Optional, Dict, Any
from contextlib import contextmanager

from .models import Command, Pattern, Session, ScriptTemplate


# Escape character for LIKE patterns built by ``_contains_pattern``.
# It must match the ``ESCAPE '\\'`` clause used in the search queries below.
_LIKE_ESCAPE = "\\"


def _contains_pattern(text: str) -> str:
    """Return a LIKE pattern that matches values containing *text* literally.

    '%' and '_' in *text* are escaped so they are not treated as SQL
    wildcards. The escape character itself is escaped first so that it is
    not reinterpreted. Queries using this pattern must declare an
    ``ESCAPE`` clause naming ``_LIKE_ESCAPE``.
    """
    escaped = (
        text.replace(_LIKE_ESCAPE, _LIKE_ESCAPE + _LIKE_ESCAPE)
        .replace("%", _LIKE_ESCAPE + "%")
        .replace("_", _LIKE_ESCAPE + "_")
    )
    return f"%{escaped}%"


class Database:
    """Manages SQLite database operations.

    Each public method opens its own connection through ``get_connection``,
    and write methods commit before the connection is closed.

    List-valued model fields (command tags, pattern command ids, session
    commands, template keywords) are stored as comma-joined TEXT columns.
    They are split on ',' when read back.
    """

    def __init__(self, db_path: Optional[str] = None):
        """Set the database path and make sure the schema exists.

        Args:
            db_path: Path of the SQLite file. When omitted, the file
                ``~/.shell_memory/shell_memory.db`` is used, and the
                ``~/.shell_memory`` directory is created if missing.
        """
        if db_path is None:
            home = os.path.expanduser("~/.shell_memory")
            os.makedirs(home, exist_ok=True)
            db_path = os.path.join(home, "shell_memory.db")
        self.db_path = db_path
        self.init_db()

    @contextmanager
    def get_connection(self):
        """Yield a fresh connection whose rows are ``sqlite3.Row`` objects.

        The connection is always closed on exit. ``close()`` does not
        commit, so callers must call ``conn.commit()`` for writes to persist.
        """
        conn = sqlite3.connect(self.db_path)
        conn.row_factory = sqlite3.Row
        try:
            yield conn
        finally:
            conn.close()

    def init_db(self):
        """Create the tables and indexes if they do not already exist."""
        with self.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS commands (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    command TEXT NOT NULL,
                    description TEXT,
                    tags TEXT,
                    created_at TEXT,
                    usage_count INTEGER DEFAULT 0
                )
            """)
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS patterns (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    name TEXT,
                    command_ids TEXT,
                    frequency INTEGER DEFAULT 1,
                    last_seen TEXT
                )
            """)
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS sessions (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    name TEXT,
                    commands TEXT,
                    start_time TEXT,
                    end_time TEXT
                )
            """)
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS script_templates (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    keywords TEXT,
                    template TEXT,
                    description TEXT
                )
            """)
            cursor.execute("""
                CREATE INDEX IF NOT EXISTS idx_commands_tags ON commands(tags)
            """)
            cursor.execute("""
                CREATE INDEX IF NOT EXISTS idx_commands_created ON commands(created_at)
            """)
            cursor.execute("""
                CREATE INDEX IF NOT EXISTS idx_patterns_frequency ON patterns(frequency)
            """)
            conn.commit()

    def add_command(self, command: Command) -> Optional[int]:
        """Insert *command* and return the new row id.

        Tags are stored comma-joined. A tag that itself contains ',' will
        come back split into several tags.
        """
        with self.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("""
                INSERT INTO commands (command, description, tags, created_at, usage_count)
                VALUES (?, ?, ?, ?, ?)
            """, (
                command.command,
                command.description,
                ",".join(command.tags),
                command.created_at.isoformat(),
                command.usage_count,
            ))
            conn.commit()
            return cursor.lastrowid

    def get_command(self, command_id: int) -> Optional[Command]:
        """Return the command with id *command_id*, or ``None`` if absent."""
        with self.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("SELECT * FROM commands WHERE id = ?", (command_id,))
            row = cursor.fetchone()
            if row:
                return self._row_to_command(row)
            return None

    def get_all_commands(self) -> List[Command]:
        """Return every stored command, ordered by ``created_at`` descending."""
        with self.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("SELECT * FROM commands ORDER BY created_at DESC")
            return [self._row_to_command(row) for row in cursor.fetchall()]

    def search_commands(self, query: str, limit: int = 10) -> List[Command]:
        """Return up to *limit* commands whose text, description or tags contain *query*.

        *query* is matched as a literal substring, so '%' and '_' in it are
        not wildcards. Results are ordered by ``usage_count`` descending.
        """
        pattern = _contains_pattern(query)
        with self.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("""
                SELECT * FROM commands
                WHERE command LIKE ? ESCAPE '\\'
                   OR description LIKE ? ESCAPE '\\'
                   OR tags LIKE ? ESCAPE '\\'
                ORDER BY usage_count DESC
                LIMIT ?
            """, (pattern, pattern, pattern, limit))
            return [self._row_to_command(row) for row in cursor.fetchall()]

    def delete_command(self, command_id: int) -> bool:
        """Delete the command with id *command_id*.

        Returns:
            ``True`` if a row was deleted, ``False`` if no such id existed.
        """
        with self.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("DELETE FROM commands WHERE id = ?", (command_id,))
            conn.commit()
            return cursor.rowcount > 0

    def update_command_usage(self, command_id: int):
        """Increment ``usage_count`` of the command with id *command_id*."""
        with self.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("""
                UPDATE commands SET usage_count = usage_count + 1
                WHERE id = ?
            """, (command_id,))
            conn.commit()

    def add_pattern(self, pattern: Pattern) -> Optional[int]:
        """Insert *pattern* and return the new row id.

        ``command_ids`` are stored comma-joined.
        """
        with self.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("""
                INSERT INTO patterns (name, command_ids, frequency, last_seen)
                VALUES (?, ?, ?, ?)
            """, (
                pattern.name,
                ",".join(map(str, pattern.command_ids)),
                pattern.frequency,
                pattern.last_seen.isoformat(),
            ))
            conn.commit()
            return cursor.lastrowid

    def get_patterns(self, min_frequency: int = 2) -> List[Pattern]:
        """Return patterns with ``frequency >= min_frequency``, most frequent first."""
        with self.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("""
                SELECT * FROM patterns
                WHERE frequency >= ?
                ORDER BY frequency DESC
            """, (min_frequency,))
            return [self._row_to_pattern(row) for row in cursor.fetchall()]

    def update_pattern_frequency(self, pattern_id: int):
        """Increment a pattern's frequency and set ``last_seen`` to ``datetime.now()``."""
        with self.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("""
                UPDATE patterns
                SET frequency = frequency + 1, last_seen = ?
                WHERE id = ?
            """, (datetime.now().isoformat(), pattern_id))
            conn.commit()

    def add_session(self, session: Session) -> Optional[int]:
        """Insert *session* and return the new row id.

        ``commands`` are stored comma-joined via ``str()``, and a missing
        ``end_time`` is stored as NULL.
        """
        with self.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("""
                INSERT INTO sessions (name, commands, start_time, end_time)
                VALUES (?, ?, ?, ?)
            """, (
                session.name,
                ",".join(map(str, session.commands)),
                session.start_time.isoformat(),
                session.end_time.isoformat() if session.end_time else None,
            ))
            conn.commit()
            return cursor.lastrowid

    def get_sessions(self) -> List[Session]:
        """Return every session, ordered by ``start_time`` descending."""
        with self.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("SELECT * FROM sessions ORDER BY start_time DESC")
            return [self._row_to_session(row) for row in cursor.fetchall()]

    def get_session(self, session_id: int) -> Optional[Session]:
        """Return the session with id *session_id*, or ``None`` if absent."""
        with self.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("SELECT * FROM sessions WHERE id = ?", (session_id,))
            row = cursor.fetchone()
            if row:
                return self._row_to_session(row)
            return None

    def add_script_template(self, template: ScriptTemplate) -> Optional[int]:
        """Insert *template* and return the new row id (keywords stored comma-joined)."""
        with self.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("""
                INSERT INTO script_templates (keywords, template, description)
                VALUES (?, ?, ?)
            """, (
                ",".join(template.keywords),
                template.template,
                template.description,
            ))
            conn.commit()
            return cursor.lastrowid

    def get_script_templates(self) -> List[ScriptTemplate]:
        """Return every stored script template (no explicit ordering)."""
        with self.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("SELECT * FROM script_templates")
            return [self._row_to_template(row) for row in cursor.fetchall()]

    def search_script_templates(self, query: str) -> List[ScriptTemplate]:
        """Return templates whose keywords or description contain *query*.

        *query* is matched as a literal substring, so '%' and '_' in it are
        not wildcards.
        """
        pattern = _contains_pattern(query)
        with self.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("""
                SELECT * FROM script_templates
                WHERE keywords LIKE ? ESCAPE '\\'
                   OR description LIKE ? ESCAPE '\\'
            """, (pattern, pattern))
            return [self._row_to_template(row) for row in cursor.fetchall()]

    def _row_to_command(self, row: sqlite3.Row) -> Command:
        """Build a ``Command`` from a ``commands`` row.

        A NULL description becomes ``""`` and NULL or empty tags become
        ``[]``.
        """
        # NOTE(review): created_at is nullable in the schema. A NULL value
        # would make fromisoformat raise; add_command always sets it.
        return Command(
            id=row["id"],
            command=row["command"],
            description=row["description"] or "",
            tags=row["tags"].split(",") if row["tags"] else [],
            created_at=datetime.fromisoformat(row["created_at"]),
            usage_count=row["usage_count"],
        )

    def _row_to_pattern(self, row: sqlite3.Row) -> Pattern:
        """Build a ``Pattern`` from a ``patterns`` row, parsing command ids as ints."""
        return Pattern(
            id=row["id"],
            name=row["name"],
            command_ids=[int(x) for x in row["command_ids"].split(",")] if row["command_ids"] else [],
            frequency=row["frequency"],
            last_seen=datetime.fromisoformat(row["last_seen"]),
        )

    def _row_to_session(self, row: sqlite3.Row) -> Session:
        """Build a ``Session`` from a ``sessions`` row; a NULL end_time stays ``None``."""
        # NOTE(review): entries come back as str. add_session wrote them via
        # str(), so int ids (if Session.commands holds ints) are not restored,
        # and entries containing ',' are split. Confirm against models.Session.
        commands = row["commands"].split(",") if row["commands"] else []
        return Session(
            id=row["id"],
            name=row["name"],
            commands=commands,
            start_time=datetime.fromisoformat(row["start_time"]),
            end_time=datetime.fromisoformat(row["end_time"]) if row["end_time"] else None,
        )

    def _row_to_template(self, row: sqlite3.Row) -> ScriptTemplate:
        """Build a ``ScriptTemplate`` from a ``script_templates`` row."""
        return ScriptTemplate(
            id=row["id"],
            keywords=row["keywords"].split(",") if row["keywords"] else [],
            template=row["template"],
            description=row["description"] or "",
        )