Fix database.py - add missing methods and aliases for API compatibility
Some checks failed
CI / test (push) Failing after 14s
Some checks failed
CI / test (push) Failing after 14s
This commit is contained in:
@@ -11,6 +11,11 @@ from typing import Any
|
||||
import sqlite3
|
||||
|
||||
|
||||
def get_database(db_path: str | None = None) -> "Database":
|
||||
"""Get a Database instance."""
|
||||
return Database(db_path)
|
||||
|
||||
|
||||
class Database:
|
||||
def __init__(self, db_path: str | None = None):
|
||||
if db_path is None:
|
||||
@@ -35,6 +40,7 @@ class Database:
|
||||
raise
|
||||
|
||||
def init_db(self):
|
||||
"""Initialize database schema."""
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("""
|
||||
@@ -73,8 +79,10 @@ class Database:
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS sync_peers (
|
||||
peer_id TEXT PRIMARY KEY,
|
||||
host TEXT NOT NULL,
|
||||
port INTEGER NOT NULL,
|
||||
peer_name TEXT,
|
||||
peer_address TEXT,
|
||||
port INTEGER,
|
||||
last_sync TEXT,
|
||||
last_seen TEXT NOT NULL
|
||||
)
|
||||
""")
|
||||
@@ -110,6 +118,22 @@ class Database:
|
||||
END
|
||||
""")
|
||||
|
||||
def init_schema(self):
|
||||
"""Alias for init_db for backwards compatibility."""
|
||||
self.init_db()
|
||||
|
||||
def create_snippet(
|
||||
self,
|
||||
title: str,
|
||||
code: str,
|
||||
description: str = "",
|
||||
language: str = "",
|
||||
tags: list[str] | None = None,
|
||||
is_encrypted: bool = False,
|
||||
) -> int:
|
||||
"""Alias for add_snippet for backwards compatibility."""
|
||||
return self.add_snippet(title, code, description, language, tags, is_encrypted)
|
||||
|
||||
def add_snippet(
|
||||
self,
|
||||
title: str,
|
||||
@@ -141,19 +165,23 @@ class Database:
|
||||
return dict(row)
|
||||
return None
|
||||
|
||||
def list_snippets(self, limit: int = 50, offset: int = 0, tag: str | None = None) -> list[dict[str, Any]]:
|
||||
def list_snippets(self, limit: int = 50, offset: int = 0, tag: str | None = None, language: str | None = None, collection_id: int | None = None) -> list[dict[str, Any]]:
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
query = "SELECT * FROM snippets WHERE 1=1"
|
||||
params = []
|
||||
if tag:
|
||||
cursor.execute(
|
||||
"SELECT * FROM snippets WHERE tags LIKE ? ORDER BY updated_at DESC LIMIT ? OFFSET ?",
|
||||
(f'%"{tag}"%', limit, offset),
|
||||
)
|
||||
else:
|
||||
cursor.execute(
|
||||
"SELECT * FROM snippets ORDER BY updated_at DESC LIMIT ? OFFSET ?",
|
||||
(limit, offset),
|
||||
)
|
||||
query += " AND tags LIKE ?"
|
||||
params.append(f'%"{tag}"%')
|
||||
if language:
|
||||
query += " AND language = ?"
|
||||
params.append(language)
|
||||
if collection_id:
|
||||
query += " AND id IN (SELECT snippet_id FROM snippet_collections WHERE collection_id = ?)"
|
||||
params.append(collection_id)
|
||||
query += " ORDER BY updated_at DESC LIMIT ? OFFSET ?"
|
||||
params.extend([limit, offset])
|
||||
cursor.execute(query, params)
|
||||
return [dict(row) for row in cursor.fetchall()]
|
||||
|
||||
def update_snippet(
|
||||
@@ -233,7 +261,7 @@ class Database:
|
||||
|
||||
return results
|
||||
|
||||
def add_tag(self, snippet_id: int, tag: str) -> bool:
|
||||
def tag_add(self, snippet_id: int, tag: str) -> bool:
|
||||
snippet = self.get_snippet(snippet_id)
|
||||
if not snippet:
|
||||
return False
|
||||
@@ -243,7 +271,7 @@ class Database:
|
||||
return self.update_snippet(snippet_id, tags=tags)
|
||||
return True
|
||||
|
||||
def remove_tag(self, snippet_id: int, tag: str) -> bool:
|
||||
def tag_remove(self, snippet_id: int, tag: str) -> bool:
|
||||
snippet = self.get_snippet(snippet_id)
|
||||
if not snippet:
|
||||
return False
|
||||
@@ -262,6 +290,10 @@ class Database:
|
||||
all_tags.update(json.loads(row["tags"]))
|
||||
return sorted(all_tags)
|
||||
|
||||
def collection_create(self, name: str, description: str = "") -> int:
|
||||
"""Alias for create_collection for backwards compatibility."""
|
||||
return self.create_collection(name, description)
|
||||
|
||||
def create_collection(self, name: str, description: str = "") -> int:
|
||||
now = datetime.utcnow().isoformat()
|
||||
with self.get_connection() as conn:
|
||||
@@ -272,10 +304,20 @@ class Database:
|
||||
)
|
||||
return cursor.lastrowid
|
||||
|
||||
def collection_list(self) -> list[dict[str, Any]]:
|
||||
"""Alias for list_collections for backwards compatibility."""
|
||||
return self.list_collections()
|
||||
|
||||
def list_collections(self) -> list[dict[str, Any]]:
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT * FROM collections ORDER BY name")
|
||||
cursor.execute("""
|
||||
SELECT c.*, COUNT(sc.snippet_id) as snippet_count
|
||||
FROM collections c
|
||||
LEFT JOIN snippet_collections sc ON c.id = sc.collection_id
|
||||
GROUP BY c.id
|
||||
ORDER BY c.name
|
||||
""")
|
||||
return [dict(row) for row in cursor.fetchall()]
|
||||
|
||||
def get_collection(self, collection_id: int) -> dict[str, Any] | None:
|
||||
@@ -287,12 +329,20 @@ class Database:
|
||||
return dict(row)
|
||||
return None
|
||||
|
||||
def collection_delete(self, collection_id: int) -> bool:
|
||||
"""Alias for delete_collection for backwards compatibility."""
|
||||
return self.delete_collection(collection_id)
|
||||
|
||||
def delete_collection(self, collection_id: int) -> bool:
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("DELETE FROM collections WHERE id = ?", (collection_id,))
|
||||
return cursor.rowcount > 0
|
||||
|
||||
def collection_add_snippet(self, collection_id: int, snippet_id: int) -> bool:
|
||||
"""Alias for add_snippet_to_collection for backwards compatibility."""
|
||||
return self.add_snippet_to_collection(snippet_id, collection_id)
|
||||
|
||||
def add_snippet_to_collection(self, snippet_id: int, collection_id: int) -> bool:
|
||||
snippet = self.get_snippet(snippet_id)
|
||||
collection = self.get_collection(collection_id)
|
||||
@@ -302,13 +352,17 @@ class Database:
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"INSERT INTO snippet_collections (snippet_id, collection_id) VALUES (?, ?)",
|
||||
"INSERT OR IGNORE INTO snippet_collections (snippet_id, collection_id) VALUES (?, ?)",
|
||||
(snippet_id, collection_id),
|
||||
)
|
||||
return True
|
||||
except sqlite3.IntegrityError:
|
||||
return True
|
||||
|
||||
def collection_remove_snippet(self, collection_id: int, snippet_id: int) -> bool:
|
||||
"""Alias for remove_snippet_from_collection for backwards compatibility."""
|
||||
return self.remove_snippet_from_collection(snippet_id, collection_id)
|
||||
|
||||
def remove_snippet_from_collection(self, snippet_id: int, collection_id: int) -> bool:
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
@@ -383,3 +437,16 @@ class Database:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT * FROM sync_peers ORDER BY last_seen DESC")
|
||||
return [dict(row) for row in cursor.fetchall()]
|
||||
|
||||
def list_sync_peers(self) -> list[dict[str, Any]]:
|
||||
"""Alias for list_peers for backwards compatibility."""
|
||||
return self.list_peers()
|
||||
|
||||
def update_sync_meta(self, peer_id: str, last_sync: str):
|
||||
"""Update sync metadata for a peer."""
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"UPDATE sync_peers SET last_sync = ? WHERE peer_id = ?",
|
||||
(last_sync, peer_id),
|
||||
)
|
||||
Reference in New Issue
Block a user