From f79e9425a9b5ffd5e50a1b1d71ff9bb6834ce66a Mon Sep 17 00:00:00 2001 From: 7000pctAUTO Date: Sun, 22 Mar 2026 11:46:29 +0000 Subject: [PATCH] Fix database.py - add missing methods and aliases for API compatibility --- snip/db/database.py | 99 +++++++++++++++++++++++++++++++++++++-------- 1 file changed, 83 insertions(+), 16 deletions(-) diff --git a/snip/db/database.py b/snip/db/database.py index a471614..00b6497 100644 --- a/snip/db/database.py +++ b/snip/db/database.py @@ -11,6 +11,11 @@ from typing import Any import sqlite3 +def get_database(db_path: str | None = None) -> "Database": + """Get a Database instance.""" + return Database(db_path) + + class Database: def __init__(self, db_path: str | None = None): if db_path is None: @@ -35,6 +40,7 @@ class Database: raise def init_db(self): + """Initialize database schema.""" with self.get_connection() as conn: cursor = conn.cursor() cursor.execute(""" @@ -73,8 +79,10 @@ class Database: cursor.execute(""" CREATE TABLE IF NOT EXISTS sync_peers ( peer_id TEXT PRIMARY KEY, - host TEXT NOT NULL, - port INTEGER NOT NULL, + peer_name TEXT, + peer_address TEXT, + port INTEGER, + last_sync TEXT, last_seen TEXT NOT NULL ) """) @@ -110,6 +118,22 @@ class Database: END """) + def init_schema(self): + """Alias for init_db for backwards compatibility.""" + self.init_db() + + def create_snippet( + self, + title: str, + code: str, + description: str = "", + language: str = "", + tags: list[str] | None = None, + is_encrypted: bool = False, + ) -> int: + """Alias for add_snippet for backwards compatibility.""" + return self.add_snippet(title, code, description, language, tags, is_encrypted) + def add_snippet( self, title: str, @@ -141,19 +165,23 @@ class Database: return dict(row) return None - def list_snippets(self, limit: int = 50, offset: int = 0, tag: str | None = None) -> list[dict[str, Any]]: + def list_snippets(self, limit: int = 50, offset: int = 0, tag: str | None = None, language: str | None = None, collection_id: int | None = None) -> list[dict[str, Any]]: with self.get_connection() as conn: cursor = conn.cursor() + query = "SELECT * FROM snippets WHERE 1=1" + params = [] if tag: - cursor.execute( - "SELECT * FROM snippets WHERE tags LIKE ? ORDER BY updated_at DESC LIMIT ? OFFSET ?", - (f'%"{tag}"%', limit, offset), - ) - else: - cursor.execute( - "SELECT * FROM snippets ORDER BY updated_at DESC LIMIT ? OFFSET ?", - (limit, offset), - ) + query += " AND tags LIKE ?" + params.append(f'%"{tag}"%') + if language: + query += " AND language = ?" + params.append(language) + if collection_id: + query += " AND id IN (SELECT snippet_id FROM snippet_collections WHERE collection_id = ?)" + params.append(collection_id) + query += " ORDER BY updated_at DESC LIMIT ? OFFSET ?" + params.extend([limit, offset]) + cursor.execute(query, params) return [dict(row) for row in cursor.fetchall()] def update_snippet( @@ -233,7 +261,7 @@ class Database: return results - def add_tag(self, snippet_id: int, tag: str) -> bool: + def tag_add(self, snippet_id: int, tag: str) -> bool: snippet = self.get_snippet(snippet_id) if not snippet: return False @@ -243,7 +271,7 @@ class Database: return self.update_snippet(snippet_id, tags=tags) return True - def remove_tag(self, snippet_id: int, tag: str) -> bool: + def tag_remove(self, snippet_id: int, tag: str) -> bool: snippet = self.get_snippet(snippet_id) if not snippet: return False @@ -262,6 +290,10 @@ class Database: all_tags.update(json.loads(row["tags"])) return sorted(all_tags) + def collection_create(self, name: str, description: str = "") -> int: + """Alias for create_collection for backwards compatibility.""" + return self.create_collection(name, description) + def create_collection(self, name: str, description: str = "") -> int: now = datetime.utcnow().isoformat() with self.get_connection() as conn: @@ -272,10 +304,20 @@ class Database: ) return cursor.lastrowid + def collection_list(self) -> list[dict[str, Any]]: + """Alias for list_collections for backwards compatibility.""" + return self.list_collections() + def list_collections(self) -> list[dict[str, Any]]: with self.get_connection() as conn: cursor = conn.cursor() - cursor.execute("SELECT * FROM collections ORDER BY name") + cursor.execute(""" + SELECT c.*, COUNT(sc.snippet_id) as snippet_count + FROM collections c + LEFT JOIN snippet_collections sc ON c.id = sc.collection_id + GROUP BY c.id + ORDER BY c.name + """) return [dict(row) for row in cursor.fetchall()] def get_collection(self, collection_id: int) -> dict[str, Any] | None: @@ -287,12 +329,20 @@ class Database: return dict(row) return None + def collection_delete(self, collection_id: int) -> bool: + """Alias for delete_collection for backwards compatibility.""" + return self.delete_collection(collection_id) + def delete_collection(self, collection_id: int) -> bool: with self.get_connection() as conn: cursor = conn.cursor() cursor.execute("DELETE FROM collections WHERE id = ?", (collection_id,)) return cursor.rowcount > 0 + def collection_add_snippet(self, collection_id: int, snippet_id: int) -> bool: + """Alias for add_snippet_to_collection for backwards compatibility.""" + return self.add_snippet_to_collection(snippet_id, collection_id) + def add_snippet_to_collection(self, snippet_id: int, collection_id: int) -> bool: snippet = self.get_snippet(snippet_id) collection = self.get_collection(collection_id) @@ -302,13 +352,17 @@ class Database: with self.get_connection() as conn: cursor = conn.cursor() cursor.execute( - "INSERT INTO snippet_collections (snippet_id, collection_id) VALUES (?, ?)", + "INSERT OR IGNORE INTO snippet_collections (snippet_id, collection_id) VALUES (?, ?)", (snippet_id, collection_id), ) return True except sqlite3.IntegrityError: return True + def collection_remove_snippet(self, collection_id: int, snippet_id: int) -> bool: + """Alias for remove_snippet_from_collection for backwards compatibility.""" + return self.remove_snippet_from_collection(snippet_id, collection_id) + def remove_snippet_from_collection(self, snippet_id: int, collection_id: int) -> bool: with self.get_connection() as conn: cursor = conn.cursor() @@ -383,3 +437,16 @@ class Database: cursor = conn.cursor() cursor.execute("SELECT * FROM sync_peers ORDER BY last_seen DESC") return [dict(row) for row in cursor.fetchall()] + + def list_sync_peers(self) -> list[dict[str, Any]]: + """Alias for list_peers for backwards compatibility.""" + return self.list_peers() + + def update_sync_meta(self, peer_id: str, last_sync: str): + """Update sync metadata for a peer.""" + with self.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + "UPDATE sync_peers SET last_sync = ? WHERE peer_id = ?", + (last_sync, peer_id), + ) \ No newline at end of file