Fix database.py - add missing methods and aliases for API compatibility
Some checks failed
CI / test (push) Failing after 14s

This commit is contained in:
2026-03-22 11:46:29 +00:00
parent 583c25eaaa
commit f79e9425a9

View File

@@ -11,6 +11,11 @@ from typing import Any
import sqlite3
def get_database(db_path: str | None = None) -> "Database":
"""Get a Database instance."""
return Database(db_path)
class Database:
def __init__(self, db_path: str | None = None):
if db_path is None:
@@ -35,6 +40,7 @@ class Database:
raise
def init_db(self):
"""Initialize database schema."""
with self.get_connection() as conn:
cursor = conn.cursor()
cursor.execute("""
@@ -73,8 +79,10 @@ class Database:
cursor.execute("""
CREATE TABLE IF NOT EXISTS sync_peers (
peer_id TEXT PRIMARY KEY,
host TEXT NOT NULL,
port INTEGER NOT NULL,
peer_name TEXT,
peer_address TEXT,
port INTEGER,
last_sync TEXT,
last_seen TEXT NOT NULL
)
""")
@@ -110,6 +118,22 @@ class Database:
END
""")
def init_schema(self):
"""Alias for init_db for backwards compatibility."""
self.init_db()
def create_snippet(
self,
title: str,
code: str,
description: str = "",
language: str = "",
tags: list[str] | None = None,
is_encrypted: bool = False,
) -> int:
"""Alias for add_snippet for backwards compatibility."""
return self.add_snippet(title, code, description, language, tags, is_encrypted)
def add_snippet(
self,
title: str,
@@ -141,19 +165,23 @@ class Database:
return dict(row)
return None
def list_snippets(self, limit: int = 50, offset: int = 0, tag: str | None = None) -> list[dict[str, Any]]:
def list_snippets(self, limit: int = 50, offset: int = 0, tag: str | None = None, language: str | None = None, collection_id: int | None = None) -> list[dict[str, Any]]:
with self.get_connection() as conn:
cursor = conn.cursor()
query = "SELECT * FROM snippets WHERE 1=1"
params = []
if tag:
cursor.execute(
"SELECT * FROM snippets WHERE tags LIKE ? ORDER BY updated_at DESC LIMIT ? OFFSET ?",
(f'%"{tag}"%', limit, offset),
)
else:
cursor.execute(
"SELECT * FROM snippets ORDER BY updated_at DESC LIMIT ? OFFSET ?",
(limit, offset),
)
query += " AND tags LIKE ?"
params.append(f'%"{tag}"%')
if language:
query += " AND language = ?"
params.append(language)
if collection_id:
query += " AND id IN (SELECT snippet_id FROM snippet_collections WHERE collection_id = ?)"
params.append(collection_id)
query += " ORDER BY updated_at DESC LIMIT ? OFFSET ?"
params.extend([limit, offset])
cursor.execute(query, params)
return [dict(row) for row in cursor.fetchall()]
def update_snippet(
@@ -233,7 +261,7 @@ class Database:
return results
def add_tag(self, snippet_id: int, tag: str) -> bool:
def tag_add(self, snippet_id: int, tag: str) -> bool:
snippet = self.get_snippet(snippet_id)
if not snippet:
return False
@@ -243,7 +271,7 @@ class Database:
return self.update_snippet(snippet_id, tags=tags)
return True
def remove_tag(self, snippet_id: int, tag: str) -> bool:
def tag_remove(self, snippet_id: int, tag: str) -> bool:
snippet = self.get_snippet(snippet_id)
if not snippet:
return False
@@ -262,6 +290,10 @@ class Database:
all_tags.update(json.loads(row["tags"]))
return sorted(all_tags)
def collection_create(self, name: str, description: str = "") -> int:
"""Alias for create_collection for backwards compatibility."""
return self.create_collection(name, description)
def create_collection(self, name: str, description: str = "") -> int:
now = datetime.utcnow().isoformat()
with self.get_connection() as conn:
@@ -272,10 +304,20 @@ class Database:
)
return cursor.lastrowid
def collection_list(self) -> list[dict[str, Any]]:
"""Alias for list_collections for backwards compatibility."""
return self.list_collections()
def list_collections(self) -> list[dict[str, Any]]:
with self.get_connection() as conn:
cursor = conn.cursor()
cursor.execute("SELECT * FROM collections ORDER BY name")
cursor.execute("""
SELECT c.*, COUNT(sc.snippet_id) as snippet_count
FROM collections c
LEFT JOIN snippet_collections sc ON c.id = sc.collection_id
GROUP BY c.id
ORDER BY c.name
""")
return [dict(row) for row in cursor.fetchall()]
def get_collection(self, collection_id: int) -> dict[str, Any] | None:
@@ -287,12 +329,20 @@ class Database:
return dict(row)
return None
def collection_delete(self, collection_id: int) -> bool:
"""Alias for delete_collection for backwards compatibility."""
return self.delete_collection(collection_id)
def delete_collection(self, collection_id: int) -> bool:
with self.get_connection() as conn:
cursor = conn.cursor()
cursor.execute("DELETE FROM collections WHERE id = ?", (collection_id,))
return cursor.rowcount > 0
def collection_add_snippet(self, collection_id: int, snippet_id: int) -> bool:
"""Alias for add_snippet_to_collection for backwards compatibility."""
return self.add_snippet_to_collection(snippet_id, collection_id)
def add_snippet_to_collection(self, snippet_id: int, collection_id: int) -> bool:
snippet = self.get_snippet(snippet_id)
collection = self.get_collection(collection_id)
@@ -302,13 +352,17 @@ class Database:
with self.get_connection() as conn:
cursor = conn.cursor()
cursor.execute(
"INSERT INTO snippet_collections (snippet_id, collection_id) VALUES (?, ?)",
"INSERT OR IGNORE INTO snippet_collections (snippet_id, collection_id) VALUES (?, ?)",
(snippet_id, collection_id),
)
return True
except sqlite3.IntegrityError:
return True
def collection_remove_snippet(self, collection_id: int, snippet_id: int) -> bool:
"""Alias for remove_snippet_from_collection for backwards compatibility."""
return self.remove_snippet_from_collection(snippet_id, collection_id)
def remove_snippet_from_collection(self, snippet_id: int, collection_id: int) -> bool:
with self.get_connection() as conn:
cursor = conn.cursor()
@@ -383,3 +437,16 @@ class Database:
cursor = conn.cursor()
cursor.execute("SELECT * FROM sync_peers ORDER BY last_seen DESC")
return [dict(row) for row in cursor.fetchall()]
def list_sync_peers(self) -> list[dict[str, Any]]:
"""Alias for list_peers for backwards compatibility."""
return self.list_peers()
def update_sync_meta(self, peer_id: str, last_sync: str):
"""Update sync metadata for a peer."""
with self.get_connection() as conn:
cursor = conn.cursor()
cursor.execute(
"UPDATE sync_peers SET last_sync = ? WHERE peer_id = ?",
(last_sync, peer_id),
)