diff --git a/app/cli_memory/database.py b/app/cli_memory/database.py
new file mode 100644
index 0000000..b9f8401
--- /dev/null
+++ b/app/cli_memory/database.py
@@ -0,0 +1,429 @@
+import ast
+import os
+import sqlite3
+import logging
+from datetime import datetime
+from contextlib import contextmanager
+from typing import Optional, List, Generator
+
+from .config import Config
+from .models import Project, Command, Workflow, Suggestion, Pattern
+
+logger = logging.getLogger(__name__)
+
+
+class Database:
+    def __init__(self, config: Optional[Config] = None):
+        self.config = config or Config()
+        self.db_path = self._get_db_path()
+        self._ensure_db_exists()
+
+    def _get_db_path(self) -> str:
+        db_path = self.config.get("database.path", "~/.cli_memory/database.db")
+        return os.path.expanduser(db_path)
+
+    def _ensure_db_exists(self) -> None:
+        db_dir = os.path.dirname(self.db_path)
+        if db_dir and not os.path.exists(db_dir):
+            os.makedirs(db_dir, exist_ok=True)
+        self._init_schema()
+
+    @contextmanager
+    def get_connection(self) -> Generator[sqlite3.Connection, None, None]:
+        conn = sqlite3.connect(self.db_path, timeout=self.config.get("database.timeout", 30.0))
+        conn.row_factory = sqlite3.Row
+        if self.config.get("database.wal_mode", True):
+            conn.execute("PRAGMA journal_mode=WAL")
+        try:
+            yield conn
+            conn.commit()
+        except Exception as e:
+            conn.rollback()
+            logger.error(f"Database error: {e}")
+            raise
+        finally:
+            conn.close()
+
+    def _init_schema(self) -> None:
+        with self.get_connection() as conn:
+            conn.executescript("""
+                CREATE TABLE IF NOT EXISTS projects (
+                    id INTEGER PRIMARY KEY AUTOINCREMENT,
+                    name TEXT NOT NULL,
+                    path TEXT NOT NULL UNIQUE,
+                    git_remote TEXT,
+                    tech_stack TEXT,
+                    created_at TEXT NOT NULL,
+                    updated_at TEXT NOT NULL
+                );
+
+                CREATE TABLE IF NOT EXISTS commands (
+                    id INTEGER PRIMARY KEY AUTOINCREMENT,
+                    workflow_id INTEGER,
+                    project_id INTEGER,
+                    command TEXT NOT NULL,
+                    command_type TEXT NOT NULL,
+                    exit_code INTEGER,
+                    duration_ms INTEGER,
+                    working_directory TEXT NOT NULL,
+                    timestamp TEXT NOT NULL,
+                    tags TEXT,
+                    metadata TEXT,
+                    FOREIGN KEY (workflow_id) REFERENCES workflows(id),
+                    FOREIGN KEY (project_id) REFERENCES projects(id)
+                );
+
+                CREATE TABLE IF NOT EXISTS workflows (
+                    id INTEGER PRIMARY KEY AUTOINCREMENT,
+                    project_id INTEGER,
+                    name TEXT NOT NULL,
+                    description TEXT,
+                    created_at TEXT NOT NULL,
+                    updated_at TEXT NOT NULL,
+                    is_automated INTEGER DEFAULT 0,
+                    pattern_confidence REAL DEFAULT 0.0,
+                    usage_count INTEGER DEFAULT 0,
+                    FOREIGN KEY (project_id) REFERENCES projects(id)
+                );
+
+                CREATE TABLE IF NOT EXISTS patterns (
+                    id INTEGER PRIMARY KEY AUTOINCREMENT,
+                    project_id INTEGER,
+                    name TEXT NOT NULL,
+                    command_sequence TEXT NOT NULL,
+                    occurrences INTEGER DEFAULT 1,
+                    confidence REAL DEFAULT 0.0,
+                    created_at TEXT NOT NULL,
+                    FOREIGN KEY (project_id) REFERENCES projects(id)
+                );
+
+                CREATE TABLE IF NOT EXISTS suggestions (
+                    id INTEGER PRIMARY KEY AUTOINCREMENT,
+                    project_id INTEGER,
+                    command TEXT NOT NULL,
+                    context TEXT,
+                    confidence REAL DEFAULT 0.0,
+                    frequency INTEGER DEFAULT 1,
+                    last_used TEXT,
+                    pattern_id INTEGER,
+                    FOREIGN KEY (project_id) REFERENCES projects(id),
+                    FOREIGN KEY (pattern_id) REFERENCES patterns(id)
+                );
+
+                CREATE INDEX IF NOT EXISTS idx_commands_project ON commands(project_id);
+                CREATE INDEX IF NOT EXISTS idx_commands_workflow ON commands(workflow_id);
+                CREATE INDEX IF NOT EXISTS idx_commands_timestamp ON commands(timestamp);
+                CREATE INDEX IF NOT EXISTS idx_workflows_project ON workflows(project_id);
+                CREATE INDEX IF NOT EXISTS idx_patterns_project ON patterns(project_id);
+                CREATE INDEX IF NOT EXISTS idx_suggestions_project ON suggestions(project_id);
+            """)
+
+    def create_project(self, project: Project) -> int:
+        with self.get_connection() as conn:
+            cursor = conn.execute(
+                """INSERT OR REPLACE INTO projects
+                (name, path, git_remote, tech_stack, created_at, updated_at)
+                VALUES (?, ?, ?, ?, ?, ?)""",
+                (
+                    project.name,
+                    project.path,
+                    project.git_remote,
+                    ",".join(project.tech_stack) if project.tech_stack else "",
+                    project.created_at.isoformat(),
+                    project.updated_at.isoformat(),
+                ),
+            )
+            return cursor.lastrowid
+
+    def get_project(self, project_id: int) -> Optional[Project]:
+        with self.get_connection() as conn:
+            row = conn.execute("SELECT * FROM projects WHERE id = ?", (project_id,)).fetchone()
+            return self._row_to_project(row) if row else None
+
+    def get_project_by_path(self, path: str) -> Optional[Project]:
+        with self.get_connection() as conn:
+            row = conn.execute("SELECT * FROM projects WHERE path = ?", (path,)).fetchone()
+            return self._row_to_project(row) if row else None
+
+    def _row_to_project(self, row: sqlite3.Row) -> Project:
+        return Project(
+            id=row["id"],
+            name=row["name"],
+            path=row["path"],
+            git_remote=row["git_remote"],
+            tech_stack=row["tech_stack"].split(",") if row["tech_stack"] else [],
+            created_at=datetime.fromisoformat(row["created_at"]),
+            updated_at=datetime.fromisoformat(row["updated_at"]),
+        )
+
+    def get_all_projects(self) -> List[Project]:
+        with self.get_connection() as conn:
+            rows = conn.execute("SELECT * FROM projects ORDER BY updated_at DESC").fetchall()
+            return [self._row_to_project(row) for row in rows]
+
+    def create_command(self, command: Command) -> int:
+        with self.get_connection() as conn:
+            cursor = conn.execute(
+                """INSERT INTO commands
+                (workflow_id, project_id, command, command_type, exit_code,
+                 duration_ms, working_directory, timestamp, tags, metadata)
+                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
+                (
+                    command.workflow_id,
+                    command.project_id,
+                    command.command,
+                    command.command_type.value,
+                    command.exit_code,
+                    command.duration_ms,
+                    command.working_directory,
+                    command.timestamp.isoformat(),
+                    ",".join(command.tags) if command.tags else "",
+                    str(command.metadata) if command.metadata else "",
+                ),
+            )
+            return cursor.lastrowid
+
+    def get_commands(
+        self,
+        project_id: Optional[int] = None,
+        workflow_id: Optional[int] = None,
+        limit: int = 100,
+        offset: int = 0,
+    ) -> List[Command]:
+        query = "SELECT * FROM commands WHERE 1=1"
+        params = []
+        if project_id:
+            query += " AND project_id = ?"
+            params.append(project_id)
+        if workflow_id:
+            query += " AND workflow_id = ?"
+            params.append(workflow_id)
+        query += " ORDER BY timestamp DESC LIMIT ? OFFSET ?"
+        params.extend([limit, offset])
+
+        with self.get_connection() as conn:
+            rows = conn.execute(query, params).fetchall()
+            return [self._row_to_command(row) for row in rows]
+
+    def _row_to_command(self, row: sqlite3.Row) -> Command:
+        tags = row["tags"].split(",") if row["tags"] else []
+        try:
+            # Metadata is stored as a Python literal; parse it safely rather than eval() it.
+            metadata = ast.literal_eval(row["metadata"]) if row["metadata"] else {}
+        except Exception:
+            metadata = {}
+        return Command(
+            id=row["id"],
+            workflow_id=row["workflow_id"],
+            project_id=row["project_id"],
+            command=row["command"],
+            command_type=row["command_type"],
+            exit_code=row["exit_code"],
+            duration_ms=row["duration_ms"],
+            working_directory=row["working_directory"],
+            timestamp=datetime.fromisoformat(row["timestamp"]),
+            tags=tags,
+            metadata=metadata,
+        )
+
+    def create_workflow(self, workflow: Workflow) -> int:
+        with self.get_connection() as conn:
+            cursor = conn.execute(
+                """INSERT INTO workflows
+                (project_id, name, description, created_at, updated_at,
+                 is_automated, pattern_confidence, usage_count)
+                VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
+                (
+                    workflow.project_id,
+                    workflow.name,
+                    workflow.description,
+                    workflow.created_at.isoformat(),
+                    workflow.updated_at.isoformat(),
+                    int(workflow.is_automated),
+                    workflow.pattern_confidence,
+                    workflow.usage_count,
+                ),
+            )
+            workflow_id = cursor.lastrowid
+        # Insert commands after the workflow's write transaction has committed, so
+        # create_command's own connection is not blocked by an open write lock.
+        for cmd in workflow.commands:
+            cmd.workflow_id = workflow_id
+            cmd.project_id = workflow.project_id
+            self.create_command(cmd)
+        return workflow_id
+
+    def get_workflow(self, workflow_id: int) -> Optional[Workflow]:
+        with self.get_connection() as conn:
+            row = conn.execute("SELECT * FROM workflows WHERE id = ?", (workflow_id,)).fetchone()
+            if not row:
+                return None
+            workflow = Workflow(
+                id=row["id"],
+                project_id=row["project_id"],
+                name=row["name"],
+                description=row["description"],
+                created_at=datetime.fromisoformat(row["created_at"]),
+                updated_at=datetime.fromisoformat(row["updated_at"]),
+                is_automated=bool(row["is_automated"]),
+                pattern_confidence=row["pattern_confidence"],
+                usage_count=row["usage_count"],
+            )
+            workflow.commands = self.get_commands(workflow_id=workflow_id)
+            return workflow
+
+    def get_all_workflows(self, project_id: Optional[int] = None) -> List[Workflow]:
+        query = "SELECT * FROM workflows"
+        params = []
+        if project_id:
+            query += " WHERE project_id = ?"
+            params.append(project_id)
+        query += " ORDER BY usage_count DESC"
+
+        with self.get_connection() as conn:
+            rows = conn.execute(query, params).fetchall()
+            workflows = []
+            for row in rows:
+                workflow = Workflow(
+                    id=row["id"],
+                    project_id=row["project_id"],
+                    name=row["name"],
+                    description=row["description"],
+                    created_at=datetime.fromisoformat(row["created_at"]),
+                    updated_at=datetime.fromisoformat(row["updated_at"]),
+                    is_automated=bool(row["is_automated"]),
+                    pattern_confidence=row["pattern_confidence"],
+                    usage_count=row["usage_count"],
+                )
+                workflow.commands = self.get_commands(workflow_id=workflow.id)
+                workflows.append(workflow)
+            return workflows
+
+    def update_workflow_usage(self, workflow_id: int) -> None:
+        with self.get_connection() as conn:
+            conn.execute(
+                "UPDATE workflows SET usage_count = usage_count + 1, updated_at = ? WHERE id = ?",
+                (datetime.utcnow().isoformat(), workflow_id),
+            )
+
+    def create_pattern(self, pattern: Pattern) -> int:
+        with self.get_connection() as conn:
+            cursor = conn.execute(
+                """INSERT INTO patterns
+                (project_id, name, command_sequence, occurrences, confidence, created_at)
+                VALUES (?, ?, ?, ?, ?, ?)""",
+                (
+                    pattern.project_id,
+                    pattern.name,
+                    ",".join(pattern.command_sequence),
+                    pattern.occurrences,
+                    pattern.confidence,
+                    pattern.created_at.isoformat(),
+                ),
+            )
+            return cursor.lastrowid
+
+    def get_patterns(self, project_id: Optional[int] = None) -> List[Pattern]:
+        query = "SELECT * FROM patterns"
+        params = []
+        if project_id:
+            query += " WHERE project_id = ?"
+            params.append(project_id)
+        query += " ORDER BY occurrences DESC"
+
+        with self.get_connection() as conn:
+            rows = conn.execute(query, params).fetchall()
+            return [self._row_to_pattern(row) for row in rows]
+
+    def _row_to_pattern(self, row: sqlite3.Row) -> Pattern:
+        return Pattern(
+            id=row["id"],
+            project_id=row["project_id"],
+            name=row["name"],
+            command_sequence=row["command_sequence"].split(",") if row["command_sequence"] else [],
+            occurrences=row["occurrences"],
+            confidence=row["confidence"],
+            created_at=datetime.fromisoformat(row["created_at"]),
+        )
+
+    def create_suggestion(self, suggestion: Suggestion) -> int:
+        with self.get_connection() as conn:
+            cursor = conn.execute(
+                """INSERT INTO suggestions
+                (project_id, command, context, confidence, frequency, last_used, pattern_id)
+                VALUES (?, ?, ?, ?, ?, ?, ?)""",
+                (
+                    suggestion.project_id,
+                    suggestion.command,
+                    suggestion.context,
+                    suggestion.confidence,
+                    suggestion.frequency,
+                    suggestion.last_used.isoformat() if suggestion.last_used else None,
+                    suggestion.pattern_id,
+                ),
+            )
+            return cursor.lastrowid
+
+    def get_suggestions(
+        self, project_id: Optional[int] = None, limit: int = 10
+    ) -> List[Suggestion]:
+        query = "SELECT * FROM suggestions WHERE 1=1"
+        params = []
+        if project_id:
+            query += " AND project_id = ?"
+            params.append(project_id)
+        query += " ORDER BY confidence DESC, frequency DESC LIMIT ?"
+        params.append(limit)
+
+        with self.get_connection() as conn:
+            rows = conn.execute(query, params).fetchall()
+            return [self._row_to_suggestion(row) for row in rows]
+
+    def _row_to_suggestion(self, row: sqlite3.Row) -> Suggestion:
+        return Suggestion(
+            id=row["id"],
+            project_id=row["project_id"],
+            command=row["command"],
+            context=row["context"],
+            confidence=row["confidence"],
+            frequency=row["frequency"],
+            last_used=datetime.fromisoformat(row["last_used"]) if row["last_used"] else None,
+            pattern_id=row["pattern_id"],
+        )
+
+    def search_commands(
+        self,
+        query: str,
+        project_id: Optional[int] = None,
+        command_type: Optional[str] = None,
+        start_time: Optional[datetime] = None,
+        end_time: Optional[datetime] = None,
+        limit: int = 50,
+    ) -> List[Command]:
+        sql = "SELECT * FROM commands WHERE command LIKE ?"
+        params = [f"%{query}%"]
+        if project_id:
+            sql += " AND project_id = ?"
+            params.append(project_id)
+        if command_type:
+            sql += " AND command_type = ?"
+            params.append(command_type)
+        if start_time:
+            sql += " AND timestamp >= ?"
+            params.append(start_time.isoformat())
+        if end_time:
+            sql += " AND timestamp <= ?"
+            params.append(end_time.isoformat())
+        sql += " ORDER BY timestamp DESC LIMIT ?"
+        params.append(limit)
+
+        with self.get_connection() as conn:
+            rows = conn.execute(sql, params).fetchall()
+            return [self._row_to_command(row) for row in rows]
+
+    def delete_project(self, project_id: int) -> None:
+        with self.get_connection() as conn:
+            conn.execute("DELETE FROM commands WHERE project_id = ?", (project_id,))
+            conn.execute("DELETE FROM workflows WHERE project_id = ?", (project_id,))
+            conn.execute("DELETE FROM patterns WHERE project_id = ?", (project_id,))
+            conn.execute("DELETE FROM suggestions WHERE project_id = ?", (project_id,))
+            conn.execute("DELETE FROM projects WHERE id = ?", (project_id,))
+
+    def close(self) -> None:
+        # Connections are opened and closed per operation; nothing to release here.
+        pass
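
For reviewers, here is a minimal usage sketch of the new Database class, not part of the diff. It assumes the package is importable as cli_memory, that the Project and Command models accept the keyword arguments this module reads and writes, and that command_type is an enum with a .value attribute (the CommandType name and its SHELL member are hypothetical).

# --- usage sketch (illustrative only; model constructors and CommandType are assumptions) ---
from datetime import datetime

from cli_memory.database import Database
from cli_memory.models import Project, Command, CommandType  # CommandType assumed

db = Database()  # resolves ~/.cli_memory/database.db and creates the schema if missing

project_id = db.create_project(Project(
    id=None,
    name="demo",
    path="/home/user/demo",
    git_remote=None,
    tech_stack=["python"],
    created_at=datetime.now(),
    updated_at=datetime.now(),
))

db.create_command(Command(
    id=None,
    workflow_id=None,
    project_id=project_id,
    command="pytest -q",
    command_type=CommandType.SHELL,  # assumed enum member
    exit_code=0,
    duration_ms=1200,
    working_directory="/home/user/demo",
    timestamp=datetime.now(),
    tags=["test"],
    metadata={"branch": "main"},
))

for cmd in db.search_commands("pytest", project_id=project_id):
    print(cmd.command, cmd.exit_code)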