diff --git a/snip/db/database.py b/snip/db/database.py new file mode 100644 index 0000000..a471614 --- /dev/null +++ b/snip/db/database.py @@ -0,0 +1,385 @@ +"""SQLite database with FTS5 search for snippet storage.""" + +import json +import os +import time +from contextlib import contextmanager +from datetime import datetime +from pathlib import Path +from typing import Any + +import sqlite3 + + +class Database: + def __init__(self, db_path: str | None = None): + if db_path is None: + db_path = os.environ.get("SNIP_DB_PATH", "~/.snip/snippets.db") + self.db_path = os.path.expanduser(db_path) + self._ensure_dir() + self.conn = None + + def _ensure_dir(self): + Path(self.db_path).parent.mkdir(parents=True, exist_ok=True) + + @contextmanager + def get_connection(self): + if self.conn is None: + self.conn = sqlite3.connect(self.db_path) + self.conn.row_factory = sqlite3.Row + try: + yield self.conn + self.conn.commit() + except Exception: + self.conn.rollback() + raise + + def init_db(self): + with self.get_connection() as conn: + cursor = conn.cursor() + cursor.execute(""" + CREATE TABLE IF NOT EXISTS snippets ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + title TEXT NOT NULL, + description TEXT, + code TEXT NOT NULL, + language TEXT, + tags TEXT DEFAULT '[]', + is_encrypted INTEGER DEFAULT 0, + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL + ) + """) + + cursor.execute(""" + CREATE TABLE IF NOT EXISTS collections ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL UNIQUE, + description TEXT, + created_at TEXT NOT NULL + ) + """) + + cursor.execute(""" + CREATE TABLE IF NOT EXISTS snippet_collections ( + snippet_id INTEGER NOT NULL, + collection_id INTEGER NOT NULL, + PRIMARY KEY (snippet_id, collection_id), + FOREIGN KEY (snippet_id) REFERENCES snippets(id) ON DELETE CASCADE, + FOREIGN KEY (collection_id) REFERENCES collections(id) ON DELETE CASCADE + ) + """) + + cursor.execute(""" + CREATE TABLE IF NOT EXISTS sync_peers ( + peer_id TEXT PRIMARY KEY, + host TEXT NOT NULL, + port INTEGER NOT NULL, + last_seen TEXT NOT NULL + ) + """) + + cursor.execute(""" + CREATE VIRTUAL TABLE IF NOT EXISTS snippets_fts USING fts5( + title, description, code, tags, + content='snippets', + content_rowid='id' + ) + """) + + cursor.execute(""" + CREATE TRIGGER IF NOT EXISTS snippets_ai AFTER INSERT ON snippets BEGIN + INSERT INTO snippets_fts(rowid, title, description, code, tags) + VALUES (new.id, new.title, new.description, new.code, new.tags); + END + """) + + cursor.execute(""" + CREATE TRIGGER IF NOT EXISTS snippets_ad AFTER DELETE ON snippets BEGIN + INSERT INTO snippets_fts(snippets_fts, rowid, title, description, code, tags) + VALUES ('delete', old.id, old.title, old.description, old.code, old.tags); + END + """) + + cursor.execute(""" + CREATE TRIGGER IF NOT EXISTS snippets_au AFTER UPDATE ON snippets BEGIN + INSERT INTO snippets_fts(snippets_fts, rowid, title, description, code, tags) + VALUES ('delete', old.id, old.title, old.description, old.code, old.tags); + INSERT INTO snippets_fts(rowid, title, description, code, tags) + VALUES (new.id, new.title, new.description, new.code, new.tags); + END + """) + + def add_snippet( + self, + title: str, + code: str, + description: str = "", + language: str = "", + tags: list[str] | None = None, + is_encrypted: bool = False, + ) -> int: + tags = tags or [] + now = datetime.utcnow().isoformat() + with self.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + """ + INSERT INTO snippets (title, description, code, language, tags, is_encrypted, created_at, updated_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?) + """, + (title, description, code, language, json.dumps(tags), int(is_encrypted), now, now), + ) + return cursor.lastrowid + + def get_snippet(self, snippet_id: int) -> dict[str, Any] | None: + with self.get_connection() as conn: + cursor = conn.cursor() + cursor.execute("SELECT * FROM snippets WHERE id = ?", (snippet_id,)) + row = cursor.fetchone() + if row: + return dict(row) + return None + + def list_snippets(self, limit: int = 50, offset: int = 0, tag: str | None = None) -> list[dict[str, Any]]: + with self.get_connection() as conn: + cursor = conn.cursor() + if tag: + cursor.execute( + "SELECT * FROM snippets WHERE tags LIKE ? ORDER BY updated_at DESC LIMIT ? OFFSET ?", + (f'%"{tag}"%', limit, offset), + ) + else: + cursor.execute( + "SELECT * FROM snippets ORDER BY updated_at DESC LIMIT ? OFFSET ?", + (limit, offset), + ) + return [dict(row) for row in cursor.fetchall()] + + def update_snippet( + self, + snippet_id: int, + title: str | None = None, + description: str | None = None, + code: str | None = None, + language: str | None = None, + tags: list[str] | None = None, + ) -> bool: + snippet = self.get_snippet(snippet_id) + if not snippet: + return False + + now = datetime.utcnow().isoformat() + with self.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + """ + UPDATE snippets SET + title = COALESCE(?, title), + description = COALESCE(?, description), + code = COALESCE(?, code), + language = COALESCE(?, language), + tags = COALESCE(?, tags), + updated_at = ? + WHERE id = ? + """, + ( + title, + description, + code, + language, + json.dumps(tags) if tags is not None else None, + now, + snippet_id, + ), + ) + return True + + def delete_snippet(self, snippet_id: int) -> bool: + with self.get_connection() as conn: + cursor = conn.cursor() + cursor.execute("DELETE FROM snippets WHERE id = ?", (snippet_id,)) + return cursor.rowcount > 0 + + def search_snippets( + self, + query: str, + limit: int = 50, + language: str | None = None, + tag: str | None = None, + ) -> list[dict[str, Any]]: + with self.get_connection() as conn: + cursor = conn.cursor() + if language: + fts_query = f"{query} AND language:{language}" + else: + fts_query = query + + cursor.execute( + """ + SELECT s.*, bm25(snippets_fts) as rank + FROM snippets s + JOIN snippets_fts ON s.id = snippets_fts.rowid + WHERE snippets_fts MATCH ? + ORDER BY rank + LIMIT ? + """, + (fts_query, limit), + ) + results = [dict(row) for row in cursor.fetchall()] + + if tag: + results = [r for r in results if tag in json.loads(r.get("tags", "[]"))] + + return results + + def add_tag(self, snippet_id: int, tag: str) -> bool: + snippet = self.get_snippet(snippet_id) + if not snippet: + return False + tags = json.loads(snippet["tags"]) + if tag not in tags: + tags.append(tag) + return self.update_snippet(snippet_id, tags=tags) + return True + + def remove_tag(self, snippet_id: int, tag: str) -> bool: + snippet = self.get_snippet(snippet_id) + if not snippet: + return False + tags = json.loads(snippet["tags"]) + if tag in tags: + tags.remove(tag) + return self.update_snippet(snippet_id, tags=tags) + return True + + def list_tags(self) -> list[str]: + with self.get_connection() as conn: + cursor = conn.cursor() + cursor.execute("SELECT tags FROM snippets") + all_tags: set[str] = set() + for row in cursor.fetchall(): + all_tags.update(json.loads(row["tags"])) + return sorted(all_tags) + + def create_collection(self, name: str, description: str = "") -> int: + now = datetime.utcnow().isoformat() + with self.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + "INSERT INTO collections (name, description, created_at) VALUES (?, ?, ?)", + (name, description, now), + ) + return cursor.lastrowid + + def list_collections(self) -> list[dict[str, Any]]: + with self.get_connection() as conn: + cursor = conn.cursor() + cursor.execute("SELECT * FROM collections ORDER BY name") + return [dict(row) for row in cursor.fetchall()] + + def get_collection(self, collection_id: int) -> dict[str, Any] | None: + with self.get_connection() as conn: + cursor = conn.cursor() + cursor.execute("SELECT * FROM collections WHERE id = ?", (collection_id,)) + row = cursor.fetchone() + if row: + return dict(row) + return None + + def delete_collection(self, collection_id: int) -> bool: + with self.get_connection() as conn: + cursor = conn.cursor() + cursor.execute("DELETE FROM collections WHERE id = ?", (collection_id,)) + return cursor.rowcount > 0 + + def add_snippet_to_collection(self, snippet_id: int, collection_id: int) -> bool: + snippet = self.get_snippet(snippet_id) + collection = self.get_collection(collection_id) + if not snippet or not collection: + return False + try: + with self.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + "INSERT INTO snippet_collections (snippet_id, collection_id) VALUES (?, ?)", + (snippet_id, collection_id), + ) + return True + except sqlite3.IntegrityError: + return True + + def remove_snippet_from_collection(self, snippet_id: int, collection_id: int) -> bool: + with self.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + "DELETE FROM snippet_collections WHERE snippet_id = ? AND collection_id = ?", + (snippet_id, collection_id), + ) + return cursor.rowcount > 0 + + def get_collection_snippets(self, collection_id: int) -> list[dict[str, Any]]: + with self.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + """ + SELECT s.* FROM snippets s + JOIN snippet_collections sc ON s.id = sc.snippet_id + WHERE sc.collection_id = ? + ORDER BY s.updated_at DESC + """, + (collection_id,), + ) + return [dict(row) for row in cursor.fetchall()] + + def export_all(self) -> list[dict[str, Any]]: + return self.list_snippets(limit=10000) + + def import_snippet( + self, + data: dict[str, Any], + strategy: str = "skip", + ) -> int | None: + existing = None + if "title" in data: + with self.get_connection() as conn: + cursor = conn.cursor() + cursor.execute("SELECT id FROM snippets WHERE title = ?", (data["title"],)) + existing = cursor.fetchone() + + if existing: + if strategy == "skip": + return None + elif strategy == "replace": + self.update_snippet( + existing["id"], + title=data.get("title"), + description=data.get("description"), + code=data.get("code"), + language=data.get("language"), + tags=data.get("tags"), + ) + return existing["id"] + + return self.add_snippet( + title=data.get("title", "Untitled"), + code=data.get("code", ""), + description=data.get("description", ""), + language=data.get("language", ""), + tags=data.get("tags", []), + ) + + def add_peer(self, peer_id: str, host: str, port: int): + now = datetime.utcnow().isoformat() + with self.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + "INSERT OR REPLACE INTO sync_peers (peer_id, host, port, last_seen) VALUES (?, ?, ?, ?)", + (peer_id, host, port, now), + ) + + def list_peers(self) -> list[dict[str, Any]]: + with self.get_connection() as conn: + cursor = conn.cursor() + cursor.execute("SELECT * FROM sync_peers ORDER BY last_seen DESC") + return [dict(row) for row in cursor.fetchall()]