import os import sqlite3 import threading from contextlib import contextmanager from datetime import datetime from pathlib import Path from typing import Any class Database: def __init__(self, db_path: str | None = None): if db_path is None: db_path = os.environ.get("SNIP_DB_PATH", str(Path.home() / ".snip" / "snippets.db")) self.db_path = db_path self._local = threading.local() self._ensure_dir() def _ensure_dir(self): Path(self.db_path).parent.mkdir(parents=True, exist_ok=True) def get_connection(self) -> sqlite3.Connection: if not hasattr(self._local, "conn") or self._local.conn is None: self._local.conn = sqlite3.connect(self.db_path, check_same_thread=False) self._local.conn.row_factory = sqlite3.Row self._local.conn.execute("PRAGMA foreign_keys = ON") self._local.conn.execute("PRAGMA journal_mode = WAL") return self._local.conn @contextmanager def get_cursor(self): conn = self.get_connection() cursor = conn.cursor() try: yield cursor conn.commit() except Exception: conn.rollback() raise finally: cursor.close() def init_schema(self): with self.get_cursor() as cursor: cursor.executescript(""" CREATE TABLE IF NOT EXISTS snippets ( id INTEGER PRIMARY KEY AUTOINCREMENT, title TEXT NOT NULL, description TEXT DEFAULT '', code TEXT NOT NULL, language TEXT DEFAULT 'text', is_encrypted INTEGER DEFAULT 0, created_at TEXT NOT NULL, updated_at TEXT NOT NULL ); CREATE TABLE IF NOT EXISTS tags ( id INTEGER PRIMARY KEY AUTOINCREMENT, name TEXT UNIQUE NOT NULL ); CREATE TABLE IF NOT EXISTS snippet_tags ( snippet_id INTEGER NOT NULL, tag_id INTEGER NOT NULL, PRIMARY KEY (snippet_id, tag_id), FOREIGN KEY (snippet_id) REFERENCES snippets(id) ON DELETE CASCADE, FOREIGN KEY (tag_id) REFERENCES tags(id) ON DELETE CASCADE ); CREATE TABLE IF NOT EXISTS collections ( id INTEGER PRIMARY KEY AUTOINCREMENT, name TEXT UNIQUE NOT NULL, description TEXT DEFAULT '', created_at TEXT NOT NULL ); CREATE TABLE IF NOT EXISTS snippet_collections ( snippet_id INTEGER NOT NULL, collection_id INTEGER NOT NULL, PRIMARY KEY (snippet_id, collection_id), FOREIGN KEY (snippet_id) REFERENCES snippets(id) ON DELETE CASCADE, FOREIGN KEY (collection_id) REFERENCES collections(id) ON DELETE CASCADE ); CREATE TABLE IF NOT EXISTS sync_meta ( peer_id TEXT PRIMARY KEY, peer_name TEXT, last_sync TEXT, peer_address TEXT, peer_port INTEGER ); CREATE VIRTUAL TABLE IF NOT EXISTS snippets_fts USING fts5( title, description, code, tags, content='snippets', content_rowid='id' ); CREATE TRIGGER IF NOT EXISTS snippets_ai AFTER INSERT ON snippets BEGIN INSERT INTO snippets_fts(rowid, title, description, code, tags) VALUES (new.id, new.title, new.description, new.code, (SELECT GROUP_CONCAT(t.name) FROM tags t JOIN snippet_tags st ON st.tag_id = t.id WHERE st.snippet_id = new.id)); END; CREATE TRIGGER IF NOT EXISTS snippets_ad AFTER DELETE ON snippets BEGIN INSERT INTO snippets_fts(snippets_fts, rowid, title, description, code, tags) VALUES ('delete', old.id, old.title, old.description, old.code, ''); END; CREATE TRIGGER IF NOT EXISTS snippets_au AFTER UPDATE ON snippets BEGIN INSERT INTO snippets_fts(snippets_fts, rowid, title, description, code, tags) VALUES ('delete', old.id, old.title, old.description, old.code, ''); INSERT INTO snippets_fts(rowid, title, description, code, tags) VALUES (new.id, new.title, new.description, new.code, (SELECT GROUP_CONCAT(t.name) FROM tags t JOIN snippet_tags st ON st.tag_id = t.id WHERE st.snippet_id = new.id)); END; CREATE INDEX IF NOT EXISTS idx_snippets_language ON snippets(language); CREATE INDEX IF NOT EXISTS idx_snippets_created ON snippets(created_at); CREATE INDEX IF NOT EXISTS idx_tags_name ON tags(name); """) def create_snippet(self, title: str, code: str, description: str = "", language: str = "text", tags: list[str] | None = None, is_encrypted: bool = False) -> int: now = datetime.utcnow().isoformat() with self.get_cursor() as cursor: cursor.execute( """INSERT INTO snippets (title, description, code, language, is_encrypted, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?)""", (title, description, code, language, int(is_encrypted), now, now), ) snippet_id = cursor.lastrowid assert snippet_id is not None if tags: self._update_snippet_tags(cursor, snippet_id, tags) return snippet_id def _update_snippet_tags(self, cursor: sqlite3.Cursor, snippet_id: int, tags: list[str]): cursor.execute("DELETE FROM snippet_tags WHERE snippet_id = ?", (snippet_id,)) for tag_name in tags: tag_name = tag_name.strip().lower() if not tag_name: continue cursor.execute("INSERT OR IGNORE INTO tags (name) VALUES (?)", (tag_name,)) cursor.execute("SELECT id FROM tags WHERE name = ?", (tag_name,)) tag_id = cursor.fetchone()[0] cursor.execute( "INSERT OR IGNORE INTO snippet_tags (snippet_id, tag_id) VALUES (?, ?)", (snippet_id, tag_id), ) def get_snippet(self, snippet_id: int) -> dict[str, Any] | None: with self.get_cursor() as cursor: cursor.execute("SELECT * FROM snippets WHERE id = ?", (snippet_id,)) row = cursor.fetchone() if not row: return None snippet = dict(row) cursor.execute( """SELECT t.name FROM tags t JOIN snippet_tags st ON st.tag_id = t.id WHERE st.snippet_id = ?""", (snippet_id,), ) snippet["tags"] = [r["name"] for r in cursor.fetchall()] return snippet def update_snippet(self, snippet_id: int, title: str | None = None, description: str | None = None, code: str | None = None, language: str | None = None, tags: list[str] | None = None) -> bool: now = datetime.utcnow().isoformat() updates = [] params = [] if title is not None: updates.append("title = ?") params.append(title) if description is not None: updates.append("description = ?") params.append(description) if code is not None: updates.append("code = ?") params.append(code) if language is not None: updates.append("language = ?") params.append(language) if updates: updates.append("updated_at = ?") params.append(now) params.append(snippet_id) with self.get_cursor() as cursor: cursor.execute(f"UPDATE snippets SET {', '.join(updates)} WHERE id = ?", params) if tags is not None: self._update_snippet_tags(cursor, snippet_id, tags) return cursor.rowcount > 0 if tags is not None: with self.get_cursor() as cursor: self._update_snippet_tags(cursor, snippet_id, tags) return False def delete_snippet(self, snippet_id: int) -> bool: with self.get_cursor() as cursor: cursor.execute("DELETE FROM snippets WHERE id = ?", (snippet_id,)) return cursor.rowcount > 0 def list_snippets(self, language: str | None = None, tag: str | None = None, collection_id: int | None = None, limit: int = 50, offset: int = 0) -> list[dict[str, Any]]: query = """SELECT s.* FROM snippets s""" joins = [] conditions = [] params = [] if tag: joins.append("JOIN snippet_tags st ON st.snippet_id = s.id") joins.append("JOIN tags t ON t.id = st.tag_id") conditions.append("t.name = ?") params.append(tag.lower()) if collection_id: joins.append("JOIN snippet_collections sc ON sc.snippet_id = s.id") conditions.append("sc.collection_id = ?") params.append(collection_id) if joins: query += " " + " ".join(joins) if conditions: query += " WHERE " + " AND ".join(conditions) if language: if conditions: query += " AND s.language = ?" else: query += " WHERE s.language = ?" params.append(language) query += " ORDER BY s.updated_at DESC LIMIT ? OFFSET ?" params.extend([limit, offset]) with self.get_cursor() as cursor: cursor.execute(query, params) snippets = [] for row in cursor.fetchall(): snippet = dict(row) cursor.execute( """SELECT t.name FROM tags t JOIN snippet_tags st ON st.tag_id = t.id WHERE st.snippet_id = ?""", (snippet["id"],), ) snippet["tags"] = [r["name"] for r in cursor.fetchall()] snippets.append(snippet) return snippets def search_snippets(self, query: str, limit: int = 50) -> list[dict[str, Any]]: fts_query = self._build_fts_query(query) with self.get_cursor() as cursor: cursor.execute( """SELECT s.*, bm25(snippets_fts) as rank FROM snippets_fts JOIN snippets s ON s.id = snippets_fts.rowid WHERE snippets_fts MATCH ? ORDER BY rank LIMIT ?""", (fts_query, limit), ) snippets = [] for row in cursor.fetchall(): snippet = dict(row) cursor.execute( """SELECT t.name FROM tags t JOIN snippet_tags st ON st.tag_id = t.id WHERE st.snippet_id = ?""", (snippet["id"],), ) snippet["tags"] = [r["name"] for r in cursor.fetchall()] snippets.append(snippet) return snippets def _build_fts_query(self, query: str) -> str: parts = query.strip().split() if not parts: return '""' processed = [] for part in parts: if part in ("AND", "OR", "NOT"): processed.append(part) elif part.startswith('"') and part.endswith('"'): processed.append(part) else: escaped = part.replace('"', '""') processed.append(f'"{escaped}"') return " ".join(processed) + "*" def tag_add(self, snippet_id: int, tag_name: str) -> bool: tag_name = tag_name.strip().lower() with self.get_cursor() as cursor: cursor.execute("INSERT OR IGNORE INTO tags (name) VALUES (?)", (tag_name,)) cursor.execute("SELECT id FROM tags WHERE name = ?", (tag_name,)) tag_id = cursor.fetchone()[0] cursor.execute( "INSERT OR IGNORE INTO snippet_tags (snippet_id, tag_id) VALUES (?, ?)", (snippet_id, tag_id), ) return True def tag_remove(self, snippet_id: int, tag_name: str) -> bool: tag_name = tag_name.strip().lower() with self.get_cursor() as cursor: cursor.execute( """DELETE FROM snippet_tags WHERE snippet_id = ? AND tag_id = ( SELECT id FROM tags WHERE name = ? )""", (snippet_id, tag_name), ) return cursor.rowcount > 0 def list_tags(self) -> list[str]: with self.get_cursor() as cursor: cursor.execute("SELECT name FROM tags ORDER BY name") return [row["name"] for row in cursor.fetchall()] def collection_create(self, name: str, description: str = "") -> int: now = datetime.utcnow().isoformat() with self.get_cursor() as cursor: cursor.execute( "INSERT INTO collections (name, description, created_at) VALUES (?, ?, ?)", (name, description, now), ) return cursor.lastrowid def collection_list(self) -> list[dict[str, Any]]: with self.get_cursor() as cursor: cursor.execute( "SELECT c.*, COUNT(sc.snippet_id) as snippet_count " "FROM collections c LEFT JOIN snippet_collections sc ON sc.collection_id = c.id " "GROUP BY c.id ORDER BY c.name" ) return [dict(row) for row in cursor.fetchall()] def collection_delete(self, collection_id: int) -> bool: with self.get_cursor() as cursor: cursor.execute("DELETE FROM collections WHERE id = ?", (collection_id,)) return cursor.rowcount > 0 def collection_add_snippet(self, collection_id: int, snippet_id: int) -> bool: with self.get_cursor() as cursor: cursor.execute( "INSERT OR IGNORE INTO snippet_collections (collection_id, snippet_id) VALUES (?, ?)", (collection_id, snippet_id), ) return True def collection_remove_snippet(self, collection_id: int, snippet_id: int) -> bool: with self.get_cursor() as cursor: cursor.execute( "DELETE FROM snippet_collections WHERE collection_id = ? AND snippet_id = ?", (collection_id, snippet_id), ) return cursor.rowcount > 0 def get_all_snippets_for_sync(self, since: str | None = None) -> list[dict[str, Any]]: with self.get_cursor() as cursor: if since: cursor.execute( "SELECT * FROM snippets WHERE updated_at > ? ORDER BY updated_at", (since,), ) else: cursor.execute("SELECT * FROM snippets ORDER BY updated_at") snippets = [] for row in cursor.fetchall(): snippet = dict(row) cursor.execute( """SELECT t.name FROM tags t JOIN snippet_tags st ON st.tag_id = t.id WHERE st.snippet_id = ?""", (snippet["id"],), ) snippet["tags"] = [r["name"] for r in cursor.fetchall()] snippets.append(snippet) return snippets def upsert_snippet(self, snippet_data: dict[str, Any]) -> int: with self.get_cursor() as cursor: cursor.execute("SELECT id FROM snippets WHERE id = ?", (snippet_data["id"],)) existing = cursor.fetchone() if existing: fields = ["title", "description", "code", "language", "is_encrypted", "updated_at"] updates = [] params = [] for field in fields: if field in snippet_data: updates.append(f"{field} = ?") params.append(snippet_data[field]) params.append(snippet_data["id"]) cursor.execute(f"UPDATE snippets SET {', '.join(updates)} WHERE id = ?", params) if "tags" in snippet_data: self._update_snippet_tags(cursor, snippet_data["id"], snippet_data["tags"]) return snippet_data["id"] else: cursor.execute( """INSERT INTO snippets (id, title, description, code, language, is_encrypted, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?)""", ( snippet_data["id"], snippet_data["title"], snippet_data.get("description", ""), snippet_data["code"], snippet_data.get("language", "text"), snippet_data.get("is_encrypted", 0), snippet_data["created_at"], snippet_data["updated_at"], ), ) if "tags" in snippet_data: self._update_snippet_tags(cursor, snippet_data["id"], snippet_data["tags"]) return snippet_data["id"] def update_sync_meta(self, peer_id: str, peer_name: str | None = None, address: str | None = None, port: int | None = None): now = datetime.utcnow().isoformat() with self.get_cursor() as cursor: cursor.execute( """INSERT INTO sync_meta (peer_id, peer_name, last_sync, peer_address, peer_port) VALUES (?, ?, ?, ?, ?) ON CONFLICT(peer_id) DO UPDATE SET last_sync = COALESCE(excluded.last_sync, last_sync), peer_name = COALESCE(excluded.peer_name, peer_name), peer_address = COALESCE(excluded.peer_address, peer_address), peer_port = COALESCE(excluded.peer_port, peer_port)""", (peer_id, peer_name, now, address, port), ) def get_sync_meta(self, peer_id: str) -> dict[str, Any] | None: with self.get_cursor() as cursor: cursor.execute("SELECT * FROM sync_meta WHERE peer_id = ?", (peer_id,)) row = cursor.fetchone() return dict(row) if row else None def list_sync_peers(self) -> list[dict[str, Any]]: with self.get_cursor() as cursor: cursor.execute("SELECT * FROM sync_meta ORDER BY last_sync DESC") return [dict(row) for row in cursor.fetchall()] def close(self): if hasattr(self._local, "conn") and self._local.conn: self._local.conn.close() self._local.conn = None _db_instance: Database | None = None _db_path: str | None = None def get_database(db_path: str | None = None) -> Database: global _db_instance, _db_path if db_path is not None and _db_path != db_path: _db_instance = Database(db_path) _db_instance.init_schema() _db_path = db_path elif _db_instance is None: _db_instance = Database(db_path) _db_instance.init_schema() _db_path = db_path return _db_instance