Fix database.py - add missing methods and aliases for API compatibility
This commit is contained in:
@@ -11,6 +11,11 @@ from typing import Any
|
|||||||
import sqlite3
|
import sqlite3
|
||||||
|
|
||||||
|
|
||||||
|
def get_database(db_path: str | None = None) -> "Database":
|
||||||
|
"""Get a Database instance."""
|
||||||
|
return Database(db_path)
|
||||||
|
|
||||||
|
|
||||||
class Database:
|
class Database:
|
||||||
def __init__(self, db_path: str | None = None):
|
def __init__(self, db_path: str | None = None):
|
||||||
if db_path is None:
|
if db_path is None:
|
||||||
@@ -35,6 +40,7 @@ class Database:
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
def init_db(self):
|
def init_db(self):
|
||||||
|
"""Initialize database schema."""
|
||||||
with self.get_connection() as conn:
|
with self.get_connection() as conn:
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
cursor.execute("""
|
cursor.execute("""
|
||||||
@@ -73,8 +79,10 @@ class Database:
|
|||||||
cursor.execute("""
|
cursor.execute("""
|
||||||
CREATE TABLE IF NOT EXISTS sync_peers (
|
CREATE TABLE IF NOT EXISTS sync_peers (
|
||||||
peer_id TEXT PRIMARY KEY,
|
peer_id TEXT PRIMARY KEY,
|
||||||
host TEXT NOT NULL,
|
peer_name TEXT,
|
||||||
port INTEGER NOT NULL,
|
peer_address TEXT,
|
||||||
|
port INTEGER,
|
||||||
|
last_sync TEXT,
|
||||||
last_seen TEXT NOT NULL
|
last_seen TEXT NOT NULL
|
||||||
)
|
)
|
||||||
""")
|
""")
|
||||||
@@ -110,6 +118,22 @@ class Database:
|
|||||||
END
|
END
|
||||||
""")
|
""")
|
||||||
|
|
||||||
|
def init_schema(self):
|
||||||
|
"""Alias for init_db for backwards compatibility."""
|
||||||
|
self.init_db()
|
||||||
|
|
||||||
|
def create_snippet(
|
||||||
|
self,
|
||||||
|
title: str,
|
||||||
|
code: str,
|
||||||
|
description: str = "",
|
||||||
|
language: str = "",
|
||||||
|
tags: list[str] | None = None,
|
||||||
|
is_encrypted: bool = False,
|
||||||
|
) -> int:
|
||||||
|
"""Alias for add_snippet for backwards compatibility."""
|
||||||
|
return self.add_snippet(title, code, description, language, tags, is_encrypted)
|
||||||
|
|
||||||
def add_snippet(
|
def add_snippet(
|
||||||
self,
|
self,
|
||||||
title: str,
|
title: str,
|
||||||
@@ -141,19 +165,23 @@ class Database:
|
|||||||
return dict(row)
|
return dict(row)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def list_snippets(self, limit: int = 50, offset: int = 0, tag: str | None = None) -> list[dict[str, Any]]:
|
def list_snippets(self, limit: int = 50, offset: int = 0, tag: str | None = None, language: str | None = None, collection_id: int | None = None) -> list[dict[str, Any]]:
|
||||||
with self.get_connection() as conn:
|
with self.get_connection() as conn:
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
query = "SELECT * FROM snippets WHERE 1=1"
|
||||||
|
params = []
|
||||||
if tag:
|
if tag:
|
||||||
cursor.execute(
|
query += " AND tags LIKE ?"
|
||||||
"SELECT * FROM snippets WHERE tags LIKE ? ORDER BY updated_at DESC LIMIT ? OFFSET ?",
|
params.append(f'%"{tag}"%')
|
||||||
(f'%"{tag}"%', limit, offset),
|
if language:
|
||||||
)
|
query += " AND language = ?"
|
||||||
else:
|
params.append(language)
|
||||||
cursor.execute(
|
if collection_id:
|
||||||
"SELECT * FROM snippets ORDER BY updated_at DESC LIMIT ? OFFSET ?",
|
query += " AND id IN (SELECT snippet_id FROM snippet_collections WHERE collection_id = ?)"
|
||||||
(limit, offset),
|
params.append(collection_id)
|
||||||
)
|
query += " ORDER BY updated_at DESC LIMIT ? OFFSET ?"
|
||||||
|
params.extend([limit, offset])
|
||||||
|
cursor.execute(query, params)
|
||||||
return [dict(row) for row in cursor.fetchall()]
|
return [dict(row) for row in cursor.fetchall()]
|
||||||
|
|
||||||
def update_snippet(
|
def update_snippet(
|
||||||
@@ -233,7 +261,7 @@ class Database:
|
|||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def add_tag(self, snippet_id: int, tag: str) -> bool:
|
def tag_add(self, snippet_id: int, tag: str) -> bool:
|
||||||
snippet = self.get_snippet(snippet_id)
|
snippet = self.get_snippet(snippet_id)
|
||||||
if not snippet:
|
if not snippet:
|
||||||
return False
|
return False
|
||||||
@@ -243,7 +271,7 @@ class Database:
|
|||||||
return self.update_snippet(snippet_id, tags=tags)
|
return self.update_snippet(snippet_id, tags=tags)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def remove_tag(self, snippet_id: int, tag: str) -> bool:
|
def tag_remove(self, snippet_id: int, tag: str) -> bool:
|
||||||
snippet = self.get_snippet(snippet_id)
|
snippet = self.get_snippet(snippet_id)
|
||||||
if not snippet:
|
if not snippet:
|
||||||
return False
|
return False
|
||||||
@@ -262,6 +290,10 @@ class Database:
|
|||||||
all_tags.update(json.loads(row["tags"]))
|
all_tags.update(json.loads(row["tags"]))
|
||||||
return sorted(all_tags)
|
return sorted(all_tags)
|
||||||
|
|
||||||
|
def collection_create(self, name: str, description: str = "") -> int:
|
||||||
|
"""Alias for create_collection for backwards compatibility."""
|
||||||
|
return self.create_collection(name, description)
|
||||||
|
|
||||||
def create_collection(self, name: str, description: str = "") -> int:
|
def create_collection(self, name: str, description: str = "") -> int:
|
||||||
now = datetime.utcnow().isoformat()
|
now = datetime.utcnow().isoformat()
|
||||||
with self.get_connection() as conn:
|
with self.get_connection() as conn:
|
||||||
@@ -272,10 +304,20 @@ class Database:
|
|||||||
)
|
)
|
||||||
return cursor.lastrowid
|
return cursor.lastrowid
|
||||||
|
|
||||||
|
def collection_list(self) -> list[dict[str, Any]]:
|
||||||
|
"""Alias for list_collections for backwards compatibility."""
|
||||||
|
return self.list_collections()
|
||||||
|
|
||||||
def list_collections(self) -> list[dict[str, Any]]:
|
def list_collections(self) -> list[dict[str, Any]]:
|
||||||
with self.get_connection() as conn:
|
with self.get_connection() as conn:
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
cursor.execute("SELECT * FROM collections ORDER BY name")
|
cursor.execute("""
|
||||||
|
SELECT c.*, COUNT(sc.snippet_id) as snippet_count
|
||||||
|
FROM collections c
|
||||||
|
LEFT JOIN snippet_collections sc ON c.id = sc.collection_id
|
||||||
|
GROUP BY c.id
|
||||||
|
ORDER BY c.name
|
||||||
|
""")
|
||||||
return [dict(row) for row in cursor.fetchall()]
|
return [dict(row) for row in cursor.fetchall()]
|
||||||
|
|
||||||
def get_collection(self, collection_id: int) -> dict[str, Any] | None:
|
def get_collection(self, collection_id: int) -> dict[str, Any] | None:
|
||||||
@@ -287,12 +329,20 @@ class Database:
|
|||||||
return dict(row)
|
return dict(row)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def collection_delete(self, collection_id: int) -> bool:
|
||||||
|
"""Alias for delete_collection for backwards compatibility."""
|
||||||
|
return self.delete_collection(collection_id)
|
||||||
|
|
||||||
def delete_collection(self, collection_id: int) -> bool:
|
def delete_collection(self, collection_id: int) -> bool:
|
||||||
with self.get_connection() as conn:
|
with self.get_connection() as conn:
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
cursor.execute("DELETE FROM collections WHERE id = ?", (collection_id,))
|
cursor.execute("DELETE FROM collections WHERE id = ?", (collection_id,))
|
||||||
return cursor.rowcount > 0
|
return cursor.rowcount > 0
|
||||||
|
|
||||||
|
def collection_add_snippet(self, collection_id: int, snippet_id: int) -> bool:
|
||||||
|
"""Alias for add_snippet_to_collection for backwards compatibility."""
|
||||||
|
return self.add_snippet_to_collection(snippet_id, collection_id)
|
||||||
|
|
||||||
def add_snippet_to_collection(self, snippet_id: int, collection_id: int) -> bool:
|
def add_snippet_to_collection(self, snippet_id: int, collection_id: int) -> bool:
|
||||||
snippet = self.get_snippet(snippet_id)
|
snippet = self.get_snippet(snippet_id)
|
||||||
collection = self.get_collection(collection_id)
|
collection = self.get_collection(collection_id)
|
||||||
@@ -302,13 +352,17 @@ class Database:
|
|||||||
with self.get_connection() as conn:
|
with self.get_connection() as conn:
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"INSERT INTO snippet_collections (snippet_id, collection_id) VALUES (?, ?)",
|
"INSERT OR IGNORE INTO snippet_collections (snippet_id, collection_id) VALUES (?, ?)",
|
||||||
(snippet_id, collection_id),
|
(snippet_id, collection_id),
|
||||||
)
|
)
|
||||||
return True
|
return True
|
||||||
except sqlite3.IntegrityError:
|
except sqlite3.IntegrityError:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def collection_remove_snippet(self, collection_id: int, snippet_id: int) -> bool:
|
||||||
|
"""Alias for remove_snippet_from_collection for backwards compatibility."""
|
||||||
|
return self.remove_snippet_from_collection(snippet_id, collection_id)
|
||||||
|
|
||||||
def remove_snippet_from_collection(self, snippet_id: int, collection_id: int) -> bool:
|
def remove_snippet_from_collection(self, snippet_id: int, collection_id: int) -> bool:
|
||||||
with self.get_connection() as conn:
|
with self.get_connection() as conn:
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
@@ -383,3 +437,16 @@ class Database:
|
|||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
cursor.execute("SELECT * FROM sync_peers ORDER BY last_seen DESC")
|
cursor.execute("SELECT * FROM sync_peers ORDER BY last_seen DESC")
|
||||||
return [dict(row) for row in cursor.fetchall()]
|
return [dict(row) for row in cursor.fetchall()]
|
||||||
|
|
||||||
|
def list_sync_peers(self) -> list[dict[str, Any]]:
|
||||||
|
"""Alias for list_peers for backwards compatibility."""
|
||||||
|
return self.list_peers()
|
||||||
|
|
||||||
|
def update_sync_meta(self, peer_id: str, last_sync: str):
|
||||||
|
"""Update sync metadata for a peer."""
|
||||||
|
with self.get_connection() as conn:
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute(
|
||||||
|
"UPDATE sync_peers SET last_sync = ? WHERE peer_id = ?",
|
||||||
|
(last_sync, peer_id),
|
||||||
|
)
|
||||||
Reference in New Issue
Block a user