diff --git a/snip/db/database.py b/snip/db/database.py index 00b6497..9f63590 100644 --- a/snip/db/database.py +++ b/snip/db/database.py @@ -1,303 +1,323 @@ -"""SQLite database with FTS5 search for snippet storage.""" - -import json import os -import time +import sqlite3 +import threading from contextlib import contextmanager from datetime import datetime from pathlib import Path from typing import Any -import sqlite3 - - -def get_database(db_path: str | None = None) -> "Database": - """Get a Database instance.""" - return Database(db_path) - class Database: def __init__(self, db_path: str | None = None): if db_path is None: - db_path = os.environ.get("SNIP_DB_PATH", "~/.snip/snippets.db") - self.db_path = os.path.expanduser(db_path) + db_path = os.environ.get("SNIP_DB_PATH", str(Path.home() / ".snip" / "snippets.db")) + self.db_path = db_path + self._local = threading.local() self._ensure_dir() - self.conn = None def _ensure_dir(self): Path(self.db_path).parent.mkdir(parents=True, exist_ok=True) - @contextmanager - def get_connection(self): - if self.conn is None: - self.conn = sqlite3.connect(self.db_path) - self.conn.row_factory = sqlite3.Row - try: - yield self.conn - self.conn.commit() - except Exception: - self.conn.rollback() - raise + def get_connection(self) -> sqlite3.Connection: + if not hasattr(self._local, "conn") or self._local.conn is None: + self._local.conn = sqlite3.connect(self.db_path, check_same_thread=False) + self._local.conn.row_factory = sqlite3.Row + self._local.conn.execute("PRAGMA foreign_keys = ON") + self._local.conn.execute("PRAGMA journal_mode = WAL") + return self._local.conn - def init_db(self): - """Initialize database schema.""" - with self.get_connection() as conn: - cursor = conn.cursor() - cursor.execute(""" + @contextmanager + def get_cursor(self): + conn = self.get_connection() + cursor = conn.cursor() + try: + yield cursor + conn.commit() + except Exception: + conn.rollback() + raise + finally: + cursor.close() + + def init_schema(self): + with self.get_cursor() as cursor: + cursor.executescript(""" CREATE TABLE IF NOT EXISTS snippets ( id INTEGER PRIMARY KEY AUTOINCREMENT, title TEXT NOT NULL, - description TEXT, + description TEXT DEFAULT '', code TEXT NOT NULL, - language TEXT, - tags TEXT DEFAULT '[]', + language TEXT DEFAULT 'text', is_encrypted INTEGER DEFAULT 0, created_at TEXT NOT NULL, updated_at TEXT NOT NULL - ) - """) + ); + + CREATE TABLE IF NOT EXISTS tags ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT UNIQUE NOT NULL + ); + + CREATE TABLE IF NOT EXISTS snippet_tags ( + snippet_id INTEGER NOT NULL, + tag_id INTEGER NOT NULL, + PRIMARY KEY (snippet_id, tag_id), + FOREIGN KEY (snippet_id) REFERENCES snippets(id) ON DELETE CASCADE, + FOREIGN KEY (tag_id) REFERENCES tags(id) ON DELETE CASCADE + ); - cursor.execute(""" CREATE TABLE IF NOT EXISTS collections ( id INTEGER PRIMARY KEY AUTOINCREMENT, - name TEXT NOT NULL UNIQUE, - description TEXT, + name TEXT UNIQUE NOT NULL, + description TEXT DEFAULT '', created_at TEXT NOT NULL - ) - """) + ); - cursor.execute(""" CREATE TABLE IF NOT EXISTS snippet_collections ( snippet_id INTEGER NOT NULL, collection_id INTEGER NOT NULL, PRIMARY KEY (snippet_id, collection_id), FOREIGN KEY (snippet_id) REFERENCES snippets(id) ON DELETE CASCADE, FOREIGN KEY (collection_id) REFERENCES collections(id) ON DELETE CASCADE - ) - """) + ); - cursor.execute(""" - CREATE TABLE IF NOT EXISTS sync_peers ( + CREATE TABLE IF NOT EXISTS sync_meta ( peer_id TEXT PRIMARY KEY, peer_name TEXT, - peer_address TEXT, - port INTEGER, last_sync TEXT, - last_seen TEXT NOT NULL - ) - """) + peer_address TEXT, + peer_port INTEGER + ); - cursor.execute(""" CREATE VIRTUAL TABLE IF NOT EXISTS snippets_fts USING fts5( title, description, code, tags, content='snippets', content_rowid='id' - ) - """) + ); - cursor.execute(""" CREATE TRIGGER IF NOT EXISTS snippets_ai AFTER INSERT ON snippets BEGIN INSERT INTO snippets_fts(rowid, title, description, code, tags) - VALUES (new.id, new.title, new.description, new.code, new.tags); - END - """) + VALUES (new.id, new.title, new.description, new.code, + (SELECT GROUP_CONCAT(t.name) FROM tags t + JOIN snippet_tags st ON st.tag_id = t.id + WHERE st.snippet_id = new.id)); + END; - cursor.execute(""" CREATE TRIGGER IF NOT EXISTS snippets_ad AFTER DELETE ON snippets BEGIN INSERT INTO snippets_fts(snippets_fts, rowid, title, description, code, tags) - VALUES ('delete', old.id, old.title, old.description, old.code, old.tags); - END - """) + VALUES ('delete', old.id, old.title, old.description, old.code, ''); + END; - cursor.execute(""" CREATE TRIGGER IF NOT EXISTS snippets_au AFTER UPDATE ON snippets BEGIN INSERT INTO snippets_fts(snippets_fts, rowid, title, description, code, tags) - VALUES ('delete', old.id, old.title, old.description, old.code, old.tags); + VALUES ('delete', old.id, old.title, old.description, old.code, ''); INSERT INTO snippets_fts(rowid, title, description, code, tags) - VALUES (new.id, new.title, new.description, new.code, new.tags); - END + VALUES (new.id, new.title, new.description, new.code, + (SELECT GROUP_CONCAT(t.name) FROM tags t + JOIN snippet_tags st ON st.tag_id = t.id + WHERE st.snippet_id = new.id)); + END; + + CREATE INDEX IF NOT EXISTS idx_snippets_language ON snippets(language); + CREATE INDEX IF NOT EXISTS idx_snippets_created ON snippets(created_at); + CREATE INDEX IF NOT EXISTS idx_tags_name ON tags(name); """) - def init_schema(self): - """Alias for init_db for backwards compatibility.""" - self.init_db() - - def create_snippet( - self, - title: str, - code: str, - description: str = "", - language: str = "", - tags: list[str] | None = None, - is_encrypted: bool = False, - ) -> int: - """Alias for add_snippet for backwards compatibility.""" - return self.add_snippet(title, code, description, language, tags, is_encrypted) - - def add_snippet( - self, - title: str, - code: str, - description: str = "", - language: str = "", - tags: list[str] | None = None, - is_encrypted: bool = False, - ) -> int: - tags = tags or [] + def create_snippet(self, title: str, code: str, description: str = "", language: str = "text", + tags: list[str] | None = None, is_encrypted: bool = False) -> int: now = datetime.utcnow().isoformat() - with self.get_connection() as conn: - cursor = conn.cursor() + with self.get_cursor() as cursor: cursor.execute( - """ - INSERT INTO snippets (title, description, code, language, tags, is_encrypted, created_at, updated_at) - VALUES (?, ?, ?, ?, ?, ?, ?, ?) - """, - (title, description, code, language, json.dumps(tags), int(is_encrypted), now, now), + """INSERT INTO snippets (title, description, code, language, is_encrypted, created_at, updated_at) + VALUES (?, ?, ?, ?, ?, ?, ?)""", + (title, description, code, language, int(is_encrypted), now, now), + ) + snippet_id = cursor.lastrowid + assert snippet_id is not None + if tags: + self._update_snippet_tags(cursor, snippet_id, tags) + return snippet_id + + def _update_snippet_tags(self, cursor: sqlite3.Cursor, snippet_id: int, tags: list[str]): + cursor.execute("DELETE FROM snippet_tags WHERE snippet_id = ?", (snippet_id,)) + for tag_name in tags: + tag_name = tag_name.strip().lower() + if not tag_name: + continue + cursor.execute("INSERT OR IGNORE INTO tags (name) VALUES (?)", (tag_name,)) + cursor.execute("SELECT id FROM tags WHERE name = ?", (tag_name,)) + tag_id = cursor.fetchone()[0] + cursor.execute( + "INSERT OR IGNORE INTO snippet_tags (snippet_id, tag_id) VALUES (?, ?)", + (snippet_id, tag_id), ) - return cursor.lastrowid def get_snippet(self, snippet_id: int) -> dict[str, Any] | None: - with self.get_connection() as conn: - cursor = conn.cursor() + with self.get_cursor() as cursor: cursor.execute("SELECT * FROM snippets WHERE id = ?", (snippet_id,)) row = cursor.fetchone() - if row: - return dict(row) - return None - - def list_snippets(self, limit: int = 50, offset: int = 0, tag: str | None = None, language: str | None = None, collection_id: int | None = None) -> list[dict[str, Any]]: - with self.get_connection() as conn: - cursor = conn.cursor() - query = "SELECT * FROM snippets WHERE 1=1" - params = [] - if tag: - query += " AND tags LIKE ?" - params.append(f'%"{tag}"%') - if language: - query += " AND language = ?" - params.append(language) - if collection_id: - query += " AND id IN (SELECT snippet_id FROM snippet_collections WHERE collection_id = ?)" - params.append(collection_id) - query += " ORDER BY updated_at DESC LIMIT ? OFFSET ?" - params.extend([limit, offset]) - cursor.execute(query, params) - return [dict(row) for row in cursor.fetchall()] - - def update_snippet( - self, - snippet_id: int, - title: str | None = None, - description: str | None = None, - code: str | None = None, - language: str | None = None, - tags: list[str] | None = None, - ) -> bool: - snippet = self.get_snippet(snippet_id) - if not snippet: - return False - - now = datetime.utcnow().isoformat() - with self.get_connection() as conn: - cursor = conn.cursor() + if not row: + return None + snippet = dict(row) cursor.execute( - """ - UPDATE snippets SET - title = COALESCE(?, title), - description = COALESCE(?, description), - code = COALESCE(?, code), - language = COALESCE(?, language), - tags = COALESCE(?, tags), - updated_at = ? - WHERE id = ? - """, - ( - title, - description, - code, - language, - json.dumps(tags) if tags is not None else None, - now, - snippet_id, - ), + """SELECT t.name FROM tags t + JOIN snippet_tags st ON st.tag_id = t.id + WHERE st.snippet_id = ?""", + (snippet_id,), ) - return True + snippet["tags"] = [r["name"] for r in cursor.fetchall()] + return snippet + + def update_snippet(self, snippet_id: int, title: str | None = None, description: str | None = None, + code: str | None = None, language: str | None = None, tags: list[str] | None = None) -> bool: + now = datetime.utcnow().isoformat() + updates = [] + params = [] + if title is not None: + updates.append("title = ?") + params.append(title) + if description is not None: + updates.append("description = ?") + params.append(description) + if code is not None: + updates.append("code = ?") + params.append(code) + if language is not None: + updates.append("language = ?") + params.append(language) + if updates: + updates.append("updated_at = ?") + params.append(now) + params.append(snippet_id) + with self.get_cursor() as cursor: + cursor.execute(f"UPDATE snippets SET {', '.join(updates)} WHERE id = ?", params) + if tags is not None: + self._update_snippet_tags(cursor, snippet_id, tags) + return cursor.rowcount > 0 + if tags is not None: + with self.get_cursor() as cursor: + self._update_snippet_tags(cursor, snippet_id, tags) + return False def delete_snippet(self, snippet_id: int) -> bool: - with self.get_connection() as conn: - cursor = conn.cursor() + with self.get_cursor() as cursor: cursor.execute("DELETE FROM snippets WHERE id = ?", (snippet_id,)) return cursor.rowcount > 0 - def search_snippets( - self, - query: str, - limit: int = 50, - language: str | None = None, - tag: str | None = None, - ) -> list[dict[str, Any]]: - with self.get_connection() as conn: - cursor = conn.cursor() - if language: - fts_query = f"{query} AND language:{language}" + def list_snippets(self, language: str | None = None, tag: str | None = None, + collection_id: int | None = None, limit: int = 50, offset: int = 0) -> list[dict[str, Any]]: + query = """SELECT s.* FROM snippets s""" + joins = [] + conditions = [] + params = [] + if tag: + joins.append("JOIN snippet_tags st ON st.snippet_id = s.id") + joins.append("JOIN tags t ON t.id = st.tag_id") + conditions.append("t.name = ?") + params.append(tag.lower()) + if collection_id: + joins.append("JOIN snippet_collections sc ON sc.snippet_id = s.id") + conditions.append("sc.collection_id = ?") + params.append(collection_id) + if joins: + query += " " + " ".join(joins) + if conditions: + query += " WHERE " + " AND ".join(conditions) + if language: + if conditions: + query += " AND s.language = ?" else: - fts_query = query + query += " WHERE s.language = ?" + params.append(language) + query += " ORDER BY s.updated_at DESC LIMIT ? OFFSET ?" + params.extend([limit, offset]) + with self.get_cursor() as cursor: + cursor.execute(query, params) + snippets = [] + for row in cursor.fetchall(): + snippet = dict(row) + cursor.execute( + """SELECT t.name FROM tags t + JOIN snippet_tags st ON st.tag_id = t.id + WHERE st.snippet_id = ?""", + (snippet["id"],), + ) + snippet["tags"] = [r["name"] for r in cursor.fetchall()] + snippets.append(snippet) + return snippets + def search_snippets(self, query: str, limit: int = 50) -> list[dict[str, Any]]: + fts_query = self._build_fts_query(query) + with self.get_cursor() as cursor: cursor.execute( - """ - SELECT s.*, bm25(snippets_fts) as rank - FROM snippets s - JOIN snippets_fts ON s.id = snippets_fts.rowid - WHERE snippets_fts MATCH ? - ORDER BY rank - LIMIT ? - """, + """SELECT s.*, bm25(snippets_fts) as rank + FROM snippets_fts + JOIN snippets s ON s.id = snippets_fts.rowid + WHERE snippets_fts MATCH ? + ORDER BY rank + LIMIT ?""", (fts_query, limit), ) - results = [dict(row) for row in cursor.fetchall()] + snippets = [] + for row in cursor.fetchall(): + snippet = dict(row) + cursor.execute( + """SELECT t.name FROM tags t + JOIN snippet_tags st ON st.tag_id = t.id + WHERE st.snippet_id = ?""", + (snippet["id"],), + ) + snippet["tags"] = [r["name"] for r in cursor.fetchall()] + snippets.append(snippet) + return snippets - if tag: - results = [r for r in results if tag in json.loads(r.get("tags", "[]"))] + def _build_fts_query(self, query: str) -> str: + parts = query.strip().split() + if not parts: + return '""' + processed = [] + for part in parts: + if part in ("AND", "OR", "NOT"): + processed.append(part) + elif part.startswith('"') and part.endswith('"'): + processed.append(part) + else: + escaped = part.replace('"', '""') + processed.append(f'"{escaped}"') + return " ".join(processed) + "*" - return results + def tag_add(self, snippet_id: int, tag_name: str) -> bool: + tag_name = tag_name.strip().lower() + with self.get_cursor() as cursor: + cursor.execute("INSERT OR IGNORE INTO tags (name) VALUES (?)", (tag_name,)) + cursor.execute("SELECT id FROM tags WHERE name = ?", (tag_name,)) + tag_id = cursor.fetchone()[0] + cursor.execute( + "INSERT OR IGNORE INTO snippet_tags (snippet_id, tag_id) VALUES (?, ?)", + (snippet_id, tag_id), + ) + return True - def tag_add(self, snippet_id: int, tag: str) -> bool: - snippet = self.get_snippet(snippet_id) - if not snippet: - return False - tags = json.loads(snippet["tags"]) - if tag not in tags: - tags.append(tag) - return self.update_snippet(snippet_id, tags=tags) - return True - - def tag_remove(self, snippet_id: int, tag: str) -> bool: - snippet = self.get_snippet(snippet_id) - if not snippet: - return False - tags = json.loads(snippet["tags"]) - if tag in tags: - tags.remove(tag) - return self.update_snippet(snippet_id, tags=tags) - return True + def tag_remove(self, snippet_id: int, tag_name: str) -> bool: + tag_name = tag_name.strip().lower() + with self.get_cursor() as cursor: + cursor.execute( + """DELETE FROM snippet_tags + WHERE snippet_id = ? AND tag_id = ( + SELECT id FROM tags WHERE name = ? + )""", + (snippet_id, tag_name), + ) + return cursor.rowcount > 0 def list_tags(self) -> list[str]: - with self.get_connection() as conn: - cursor = conn.cursor() - cursor.execute("SELECT tags FROM snippets") - all_tags: set[str] = set() - for row in cursor.fetchall(): - all_tags.update(json.loads(row["tags"])) - return sorted(all_tags) + with self.get_cursor() as cursor: + cursor.execute("SELECT name FROM tags ORDER BY name") + return [row["name"] for row in cursor.fetchall()] def collection_create(self, name: str, description: str = "") -> int: - """Alias for create_collection for backwards compatibility.""" - return self.create_collection(name, description) - - def create_collection(self, name: str, description: str = "") -> int: now = datetime.utcnow().isoformat() - with self.get_connection() as conn: - cursor = conn.cursor() + with self.get_cursor() as cursor: cursor.execute( "INSERT INTO collections (name, description, created_at) VALUES (?, ?, ?)", (name, description, now), @@ -305,148 +325,136 @@ class Database: return cursor.lastrowid def collection_list(self) -> list[dict[str, Any]]: - """Alias for list_collections for backwards compatibility.""" - return self.list_collections() - - def list_collections(self) -> list[dict[str, Any]]: - with self.get_connection() as conn: - cursor = conn.cursor() - cursor.execute(""" - SELECT c.*, COUNT(sc.snippet_id) as snippet_count - FROM collections c - LEFT JOIN snippet_collections sc ON c.id = sc.collection_id - GROUP BY c.id - ORDER BY c.name - """) + with self.get_cursor() as cursor: + cursor.execute( + "SELECT c.*, COUNT(sc.snippet_id) as snippet_count " + "FROM collections c LEFT JOIN snippet_collections sc ON sc.collection_id = c.id " + "GROUP BY c.id ORDER BY c.name" + ) return [dict(row) for row in cursor.fetchall()] - def get_collection(self, collection_id: int) -> dict[str, Any] | None: - with self.get_connection() as conn: - cursor = conn.cursor() - cursor.execute("SELECT * FROM collections WHERE id = ?", (collection_id,)) - row = cursor.fetchone() - if row: - return dict(row) - return None - def collection_delete(self, collection_id: int) -> bool: - """Alias for delete_collection for backwards compatibility.""" - return self.delete_collection(collection_id) - - def delete_collection(self, collection_id: int) -> bool: - with self.get_connection() as conn: - cursor = conn.cursor() + with self.get_cursor() as cursor: cursor.execute("DELETE FROM collections WHERE id = ?", (collection_id,)) return cursor.rowcount > 0 def collection_add_snippet(self, collection_id: int, snippet_id: int) -> bool: - """Alias for add_snippet_to_collection for backwards compatibility.""" - return self.add_snippet_to_collection(snippet_id, collection_id) - - def add_snippet_to_collection(self, snippet_id: int, collection_id: int) -> bool: - snippet = self.get_snippet(snippet_id) - collection = self.get_collection(collection_id) - if not snippet or not collection: - return False - try: - with self.get_connection() as conn: - cursor = conn.cursor() - cursor.execute( - "INSERT OR IGNORE INTO snippet_collections (snippet_id, collection_id) VALUES (?, ?)", - (snippet_id, collection_id), - ) - return True - except sqlite3.IntegrityError: + with self.get_cursor() as cursor: + cursor.execute( + "INSERT OR IGNORE INTO snippet_collections (collection_id, snippet_id) VALUES (?, ?)", + (collection_id, snippet_id), + ) return True def collection_remove_snippet(self, collection_id: int, snippet_id: int) -> bool: - """Alias for remove_snippet_from_collection for backwards compatibility.""" - return self.remove_snippet_from_collection(snippet_id, collection_id) - - def remove_snippet_from_collection(self, snippet_id: int, collection_id: int) -> bool: - with self.get_connection() as conn: - cursor = conn.cursor() + with self.get_cursor() as cursor: cursor.execute( - "DELETE FROM snippet_collections WHERE snippet_id = ? AND collection_id = ?", - (snippet_id, collection_id), + "DELETE FROM snippet_collections WHERE collection_id = ? AND snippet_id = ?", + (collection_id, snippet_id), ) return cursor.rowcount > 0 - def get_collection_snippets(self, collection_id: int) -> list[dict[str, Any]]: - with self.get_connection() as conn: - cursor = conn.cursor() - cursor.execute( - """ - SELECT s.* FROM snippets s - JOIN snippet_collections sc ON s.id = sc.snippet_id - WHERE sc.collection_id = ? - ORDER BY s.updated_at DESC - """, - (collection_id,), - ) - return [dict(row) for row in cursor.fetchall()] - - def export_all(self) -> list[dict[str, Any]]: - return self.list_snippets(limit=10000) - - def import_snippet( - self, - data: dict[str, Any], - strategy: str = "skip", - ) -> int | None: - existing = None - if "title" in data: - with self.get_connection() as conn: - cursor = conn.cursor() - cursor.execute("SELECT id FROM snippets WHERE title = ?", (data["title"],)) - existing = cursor.fetchone() - - if existing: - if strategy == "skip": - return None - elif strategy == "replace": - self.update_snippet( - existing["id"], - title=data.get("title"), - description=data.get("description"), - code=data.get("code"), - language=data.get("language"), - tags=data.get("tags"), + def get_all_snippets_for_sync(self, since: str | None = None) -> list[dict[str, Any]]: + with self.get_cursor() as cursor: + if since: + cursor.execute( + "SELECT * FROM snippets WHERE updated_at > ? ORDER BY updated_at", + (since,), ) - return existing["id"] + else: + cursor.execute("SELECT * FROM snippets ORDER BY updated_at") + snippets = [] + for row in cursor.fetchall(): + snippet = dict(row) + cursor.execute( + """SELECT t.name FROM tags t + JOIN snippet_tags st ON st.tag_id = t.id + WHERE st.snippet_id = ?""", + (snippet["id"],), + ) + snippet["tags"] = [r["name"] for r in cursor.fetchall()] + snippets.append(snippet) + return snippets - return self.add_snippet( - title=data.get("title", "Untitled"), - code=data.get("code", ""), - description=data.get("description", ""), - language=data.get("language", ""), - tags=data.get("tags", []), - ) + def upsert_snippet(self, snippet_data: dict[str, Any]) -> int: + with self.get_cursor() as cursor: + cursor.execute("SELECT id FROM snippets WHERE id = ?", (snippet_data["id"],)) + existing = cursor.fetchone() + if existing: + fields = ["title", "description", "code", "language", "is_encrypted", "updated_at"] + updates = [] + params = [] + for field in fields: + if field in snippet_data: + updates.append(f"{field} = ?") + params.append(snippet_data[field]) + params.append(snippet_data["id"]) + cursor.execute(f"UPDATE snippets SET {', '.join(updates)} WHERE id = ?", params) + if "tags" in snippet_data: + self._update_snippet_tags(cursor, snippet_data["id"], snippet_data["tags"]) + return snippet_data["id"] + else: + cursor.execute( + """INSERT INTO snippets (id, title, description, code, language, is_encrypted, created_at, updated_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?)""", + ( + snippet_data["id"], + snippet_data["title"], + snippet_data.get("description", ""), + snippet_data["code"], + snippet_data.get("language", "text"), + snippet_data.get("is_encrypted", 0), + snippet_data["created_at"], + snippet_data["updated_at"], + ), + ) + if "tags" in snippet_data: + self._update_snippet_tags(cursor, snippet_data["id"], snippet_data["tags"]) + return snippet_data["id"] - def add_peer(self, peer_id: str, host: str, port: int): + def update_sync_meta(self, peer_id: str, peer_name: str | None = None, address: str | None = None, port: int | None = None): now = datetime.utcnow().isoformat() - with self.get_connection() as conn: - cursor = conn.cursor() + with self.get_cursor() as cursor: cursor.execute( - "INSERT OR REPLACE INTO sync_peers (peer_id, host, port, last_seen) VALUES (?, ?, ?, ?)", - (peer_id, host, port, now), + """INSERT INTO sync_meta (peer_id, peer_name, last_sync, peer_address, peer_port) + VALUES (?, ?, ?, ?, ?) + ON CONFLICT(peer_id) DO UPDATE SET + last_sync = COALESCE(excluded.last_sync, last_sync), + peer_name = COALESCE(excluded.peer_name, peer_name), + peer_address = COALESCE(excluded.peer_address, peer_address), + peer_port = COALESCE(excluded.peer_port, peer_port)""", + (peer_id, peer_name, now, address, port), ) - def list_peers(self) -> list[dict[str, Any]]: - with self.get_connection() as conn: - cursor = conn.cursor() - cursor.execute("SELECT * FROM sync_peers ORDER BY last_seen DESC") - return [dict(row) for row in cursor.fetchall()] + def get_sync_meta(self, peer_id: str) -> dict[str, Any] | None: + with self.get_cursor() as cursor: + cursor.execute("SELECT * FROM sync_meta WHERE peer_id = ?", (peer_id,)) + row = cursor.fetchone() + return dict(row) if row else None def list_sync_peers(self) -> list[dict[str, Any]]: - """Alias for list_peers for backwards compatibility.""" - return self.list_peers() + with self.get_cursor() as cursor: + cursor.execute("SELECT * FROM sync_meta ORDER BY last_sync DESC") + return [dict(row) for row in cursor.fetchall()] - def update_sync_meta(self, peer_id: str, last_sync: str): - """Update sync metadata for a peer.""" - with self.get_connection() as conn: - cursor = conn.cursor() - cursor.execute( - "UPDATE sync_peers SET last_sync = ? WHERE peer_id = ?", - (last_sync, peer_id), - ) \ No newline at end of file + def close(self): + if hasattr(self._local, "conn") and self._local.conn: + self._local.conn.close() + self._local.conn = None + + +_db_instance: Database | None = None +_db_path: str | None = None + + +def get_database(db_path: str | None = None) -> Database: + global _db_instance, _db_path + if db_path is not None and _db_path != db_path: + _db_instance = Database(db_path) + _db_instance.init_schema() + _db_path = db_path + elif _db_instance is None: + _db_instance = Database(db_path) + _db_instance.init_schema() + _db_path = db_path + return _db_instance