"""Database management for sessions.""" import sqlite3 from contextlib import contextmanager from datetime import datetime from pathlib import Path from typing import Any, Generator, List, Optional class SessionDatabase: """Manages session storage in SQLite.""" def __init__(self, db_path: str): self.db_path = Path(db_path) self.db_path.parent.mkdir(parents=True, exist_ok=True) self._init_db() def _init_db(self) -> None: """Initialize the database schema.""" with self._get_connection() as conn: conn.execute(""" CREATE TABLE IF NOT EXISTS sessions ( id INTEGER PRIMARY KEY AUTOINCREMENT, name TEXT NOT NULL, start_time TEXT, end_time TEXT, created_at TEXT NOT NULL, command_count INTEGER DEFAULT 0 ) """) conn.execute(""" CREATE TABLE IF NOT EXISTS commands ( id INTEGER PRIMARY KEY AUTOINCREMENT, session_id INTEGER NOT NULL, command TEXT NOT NULL, timestamp TEXT, exit_code INTEGER, output TEXT, FOREIGN KEY (session_id) REFERENCES sessions(id) ) """) conn.commit() @contextmanager def _get_connection(self) -> Generator[sqlite3.Connection, None, None]: """Get a database connection.""" conn = sqlite3.connect(str(self.db_path)) try: yield conn finally: conn.close() def create_session(self, session) -> Any: """Create a new session.""" with self._get_connection() as conn: cursor = conn.execute(""" INSERT INTO sessions (name, start_time, end_time, created_at, command_count) VALUES (?, ?, ?, ?, ?) """, ( session.name, session.start_time.isoformat() if session.start_time else None, session.end_time.isoformat() if session.end_time else None, session.created_at.isoformat() if session.created_at else datetime.now().isoformat(), session.command_count, )) session_id = cursor.lastrowid conn.commit() for cmd in session.commands: conn.execute(""" INSERT INTO commands (session_id, command, timestamp, exit_code, output) VALUES (?, ?, ?, ?, ?) """, ( session_id, cmd.command, cmd.timestamp, cmd.exit_code, cmd.output, )) conn.commit() session.id = session_id return session def get_session(self, session_id: int) -> Optional[dict]: """Get a session by ID.""" with self._get_connection() as conn: cursor = conn.execute("SELECT * FROM sessions WHERE id = ?", (session_id,)) row = cursor.fetchone() if row: return { "id": row[0], "name": row[1], "start_time": row[2], "end_time": row[3], "created_at": row[4], "command_count": row[5], } return None def get_all_sessions(self) -> List[dict]: """Get all sessions.""" with self._get_connection() as conn: cursor = conn.execute("SELECT * FROM sessions ORDER BY created_at DESC") return [ { "id": row[0], "name": row[1], "start_time": row[2], "end_time": row[3], "created_at": row[4], "command_count": row[5], } for row in cursor.fetchall() ] def get_session_commands(self, session_id: int) -> List[dict]: """Get all commands for a session.""" with self._get_connection() as conn: cursor = conn.execute( "SELECT * FROM commands WHERE session_id = ? ORDER BY id", (session_id,) ) return [ { "id": row[0], "session_id": row[1], "command": row[2], "timestamp": row[3], "exit_code": row[4], "output": row[5], } for row in cursor.fetchall() ] def delete_session(self, session_id: int) -> bool: """Delete a session and its commands.""" with self._get_connection() as conn: conn.execute("DELETE FROM commands WHERE session_id = ?", (session_id,)) conn.execute("DELETE FROM sessions WHERE id = ?", (session_id,)) conn.commit() return True