fix: resolve CI test failures - API compatibility fixes
This commit is contained in:
@@ -1,303 +1,323 @@
|
|||||||
"""SQLite database with FTS5 search for snippet storage."""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import os
|
import os
|
||||||
import time
|
import sqlite3
|
||||||
|
import threading
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import sqlite3
|
|
||||||
|
|
||||||
|
|
||||||
def get_database(db_path: str | None = None) -> "Database":
|
|
||||||
"""Get a Database instance."""
|
|
||||||
return Database(db_path)
|
|
||||||
|
|
||||||
|
|
||||||
class Database:
|
class Database:
|
||||||
def __init__(self, db_path: str | None = None):
|
def __init__(self, db_path: str | None = None):
|
||||||
if db_path is None:
|
if db_path is None:
|
||||||
db_path = os.environ.get("SNIP_DB_PATH", "~/.snip/snippets.db")
|
db_path = os.environ.get("SNIP_DB_PATH", str(Path.home() / ".snip" / "snippets.db"))
|
||||||
self.db_path = os.path.expanduser(db_path)
|
self.db_path = db_path
|
||||||
|
self._local = threading.local()
|
||||||
self._ensure_dir()
|
self._ensure_dir()
|
||||||
self.conn = None
|
|
||||||
|
|
||||||
def _ensure_dir(self):
|
def _ensure_dir(self):
|
||||||
Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)
|
Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
@contextmanager
|
def get_connection(self) -> sqlite3.Connection:
|
||||||
def get_connection(self):
|
if not hasattr(self._local, "conn") or self._local.conn is None:
|
||||||
if self.conn is None:
|
self._local.conn = sqlite3.connect(self.db_path, check_same_thread=False)
|
||||||
self.conn = sqlite3.connect(self.db_path)
|
self._local.conn.row_factory = sqlite3.Row
|
||||||
self.conn.row_factory = sqlite3.Row
|
self._local.conn.execute("PRAGMA foreign_keys = ON")
|
||||||
try:
|
self._local.conn.execute("PRAGMA journal_mode = WAL")
|
||||||
yield self.conn
|
return self._local.conn
|
||||||
self.conn.commit()
|
|
||||||
except Exception:
|
|
||||||
self.conn.rollback()
|
|
||||||
raise
|
|
||||||
|
|
||||||
def init_db(self):
|
@contextmanager
|
||||||
"""Initialize database schema."""
|
def get_cursor(self):
|
||||||
with self.get_connection() as conn:
|
conn = self.get_connection()
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
cursor.execute("""
|
try:
|
||||||
|
yield cursor
|
||||||
|
conn.commit()
|
||||||
|
except Exception:
|
||||||
|
conn.rollback()
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
cursor.close()
|
||||||
|
|
||||||
|
def init_schema(self):
|
||||||
|
with self.get_cursor() as cursor:
|
||||||
|
cursor.executescript("""
|
||||||
CREATE TABLE IF NOT EXISTS snippets (
|
CREATE TABLE IF NOT EXISTS snippets (
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
title TEXT NOT NULL,
|
title TEXT NOT NULL,
|
||||||
description TEXT,
|
description TEXT DEFAULT '',
|
||||||
code TEXT NOT NULL,
|
code TEXT NOT NULL,
|
||||||
language TEXT,
|
language TEXT DEFAULT 'text',
|
||||||
tags TEXT DEFAULT '[]',
|
|
||||||
is_encrypted INTEGER DEFAULT 0,
|
is_encrypted INTEGER DEFAULT 0,
|
||||||
created_at TEXT NOT NULL,
|
created_at TEXT NOT NULL,
|
||||||
updated_at TEXT NOT NULL
|
updated_at TEXT NOT NULL
|
||||||
)
|
);
|
||||||
""")
|
|
||||||
|
CREATE TABLE IF NOT EXISTS tags (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
name TEXT UNIQUE NOT NULL
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS snippet_tags (
|
||||||
|
snippet_id INTEGER NOT NULL,
|
||||||
|
tag_id INTEGER NOT NULL,
|
||||||
|
PRIMARY KEY (snippet_id, tag_id),
|
||||||
|
FOREIGN KEY (snippet_id) REFERENCES snippets(id) ON DELETE CASCADE,
|
||||||
|
FOREIGN KEY (tag_id) REFERENCES tags(id) ON DELETE CASCADE
|
||||||
|
);
|
||||||
|
|
||||||
cursor.execute("""
|
|
||||||
CREATE TABLE IF NOT EXISTS collections (
|
CREATE TABLE IF NOT EXISTS collections (
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
name TEXT NOT NULL UNIQUE,
|
name TEXT UNIQUE NOT NULL,
|
||||||
description TEXT,
|
description TEXT DEFAULT '',
|
||||||
created_at TEXT NOT NULL
|
created_at TEXT NOT NULL
|
||||||
)
|
);
|
||||||
""")
|
|
||||||
|
|
||||||
cursor.execute("""
|
|
||||||
CREATE TABLE IF NOT EXISTS snippet_collections (
|
CREATE TABLE IF NOT EXISTS snippet_collections (
|
||||||
snippet_id INTEGER NOT NULL,
|
snippet_id INTEGER NOT NULL,
|
||||||
collection_id INTEGER NOT NULL,
|
collection_id INTEGER NOT NULL,
|
||||||
PRIMARY KEY (snippet_id, collection_id),
|
PRIMARY KEY (snippet_id, collection_id),
|
||||||
FOREIGN KEY (snippet_id) REFERENCES snippets(id) ON DELETE CASCADE,
|
FOREIGN KEY (snippet_id) REFERENCES snippets(id) ON DELETE CASCADE,
|
||||||
FOREIGN KEY (collection_id) REFERENCES collections(id) ON DELETE CASCADE
|
FOREIGN KEY (collection_id) REFERENCES collections(id) ON DELETE CASCADE
|
||||||
)
|
);
|
||||||
""")
|
|
||||||
|
|
||||||
cursor.execute("""
|
CREATE TABLE IF NOT EXISTS sync_meta (
|
||||||
CREATE TABLE IF NOT EXISTS sync_peers (
|
|
||||||
peer_id TEXT PRIMARY KEY,
|
peer_id TEXT PRIMARY KEY,
|
||||||
peer_name TEXT,
|
peer_name TEXT,
|
||||||
peer_address TEXT,
|
|
||||||
port INTEGER,
|
|
||||||
last_sync TEXT,
|
last_sync TEXT,
|
||||||
last_seen TEXT NOT NULL
|
peer_address TEXT,
|
||||||
)
|
peer_port INTEGER
|
||||||
""")
|
);
|
||||||
|
|
||||||
cursor.execute("""
|
|
||||||
CREATE VIRTUAL TABLE IF NOT EXISTS snippets_fts USING fts5(
|
CREATE VIRTUAL TABLE IF NOT EXISTS snippets_fts USING fts5(
|
||||||
title, description, code, tags,
|
title, description, code, tags,
|
||||||
content='snippets',
|
content='snippets',
|
||||||
content_rowid='id'
|
content_rowid='id'
|
||||||
)
|
);
|
||||||
""")
|
|
||||||
|
|
||||||
cursor.execute("""
|
|
||||||
CREATE TRIGGER IF NOT EXISTS snippets_ai AFTER INSERT ON snippets BEGIN
|
CREATE TRIGGER IF NOT EXISTS snippets_ai AFTER INSERT ON snippets BEGIN
|
||||||
INSERT INTO snippets_fts(rowid, title, description, code, tags)
|
INSERT INTO snippets_fts(rowid, title, description, code, tags)
|
||||||
VALUES (new.id, new.title, new.description, new.code, new.tags);
|
VALUES (new.id, new.title, new.description, new.code,
|
||||||
END
|
(SELECT GROUP_CONCAT(t.name) FROM tags t
|
||||||
""")
|
JOIN snippet_tags st ON st.tag_id = t.id
|
||||||
|
WHERE st.snippet_id = new.id));
|
||||||
|
END;
|
||||||
|
|
||||||
cursor.execute("""
|
|
||||||
CREATE TRIGGER IF NOT EXISTS snippets_ad AFTER DELETE ON snippets BEGIN
|
CREATE TRIGGER IF NOT EXISTS snippets_ad AFTER DELETE ON snippets BEGIN
|
||||||
INSERT INTO snippets_fts(snippets_fts, rowid, title, description, code, tags)
|
INSERT INTO snippets_fts(snippets_fts, rowid, title, description, code, tags)
|
||||||
VALUES ('delete', old.id, old.title, old.description, old.code, old.tags);
|
VALUES ('delete', old.id, old.title, old.description, old.code, '');
|
||||||
END
|
END;
|
||||||
""")
|
|
||||||
|
|
||||||
cursor.execute("""
|
|
||||||
CREATE TRIGGER IF NOT EXISTS snippets_au AFTER UPDATE ON snippets BEGIN
|
CREATE TRIGGER IF NOT EXISTS snippets_au AFTER UPDATE ON snippets BEGIN
|
||||||
INSERT INTO snippets_fts(snippets_fts, rowid, title, description, code, tags)
|
INSERT INTO snippets_fts(snippets_fts, rowid, title, description, code, tags)
|
||||||
VALUES ('delete', old.id, old.title, old.description, old.code, old.tags);
|
VALUES ('delete', old.id, old.title, old.description, old.code, '');
|
||||||
INSERT INTO snippets_fts(rowid, title, description, code, tags)
|
INSERT INTO snippets_fts(rowid, title, description, code, tags)
|
||||||
VALUES (new.id, new.title, new.description, new.code, new.tags);
|
VALUES (new.id, new.title, new.description, new.code,
|
||||||
END
|
(SELECT GROUP_CONCAT(t.name) FROM tags t
|
||||||
|
JOIN snippet_tags st ON st.tag_id = t.id
|
||||||
|
WHERE st.snippet_id = new.id));
|
||||||
|
END;
|
||||||
|
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_snippets_language ON snippets(language);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_snippets_created ON snippets(created_at);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_tags_name ON tags(name);
|
||||||
""")
|
""")
|
||||||
|
|
||||||
def init_schema(self):
|
def create_snippet(self, title: str, code: str, description: str = "", language: str = "text",
|
||||||
"""Alias for init_db for backwards compatibility."""
|
tags: list[str] | None = None, is_encrypted: bool = False) -> int:
|
||||||
self.init_db()
|
|
||||||
|
|
||||||
def create_snippet(
|
|
||||||
self,
|
|
||||||
title: str,
|
|
||||||
code: str,
|
|
||||||
description: str = "",
|
|
||||||
language: str = "",
|
|
||||||
tags: list[str] | None = None,
|
|
||||||
is_encrypted: bool = False,
|
|
||||||
) -> int:
|
|
||||||
"""Alias for add_snippet for backwards compatibility."""
|
|
||||||
return self.add_snippet(title, code, description, language, tags, is_encrypted)
|
|
||||||
|
|
||||||
def add_snippet(
|
|
||||||
self,
|
|
||||||
title: str,
|
|
||||||
code: str,
|
|
||||||
description: str = "",
|
|
||||||
language: str = "",
|
|
||||||
tags: list[str] | None = None,
|
|
||||||
is_encrypted: bool = False,
|
|
||||||
) -> int:
|
|
||||||
tags = tags or []
|
|
||||||
now = datetime.utcnow().isoformat()
|
now = datetime.utcnow().isoformat()
|
||||||
with self.get_connection() as conn:
|
with self.get_cursor() as cursor:
|
||||||
cursor = conn.cursor()
|
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"""
|
"""INSERT INTO snippets (title, description, code, language, is_encrypted, created_at, updated_at)
|
||||||
INSERT INTO snippets (title, description, code, language, tags, is_encrypted, created_at, updated_at)
|
VALUES (?, ?, ?, ?, ?, ?, ?)""",
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
(title, description, code, language, int(is_encrypted), now, now),
|
||||||
""",
|
)
|
||||||
(title, description, code, language, json.dumps(tags), int(is_encrypted), now, now),
|
snippet_id = cursor.lastrowid
|
||||||
|
assert snippet_id is not None
|
||||||
|
if tags:
|
||||||
|
self._update_snippet_tags(cursor, snippet_id, tags)
|
||||||
|
return snippet_id
|
||||||
|
|
||||||
|
def _update_snippet_tags(self, cursor: sqlite3.Cursor, snippet_id: int, tags: list[str]):
|
||||||
|
cursor.execute("DELETE FROM snippet_tags WHERE snippet_id = ?", (snippet_id,))
|
||||||
|
for tag_name in tags:
|
||||||
|
tag_name = tag_name.strip().lower()
|
||||||
|
if not tag_name:
|
||||||
|
continue
|
||||||
|
cursor.execute("INSERT OR IGNORE INTO tags (name) VALUES (?)", (tag_name,))
|
||||||
|
cursor.execute("SELECT id FROM tags WHERE name = ?", (tag_name,))
|
||||||
|
tag_id = cursor.fetchone()[0]
|
||||||
|
cursor.execute(
|
||||||
|
"INSERT OR IGNORE INTO snippet_tags (snippet_id, tag_id) VALUES (?, ?)",
|
||||||
|
(snippet_id, tag_id),
|
||||||
)
|
)
|
||||||
return cursor.lastrowid
|
|
||||||
|
|
||||||
def get_snippet(self, snippet_id: int) -> dict[str, Any] | None:
|
def get_snippet(self, snippet_id: int) -> dict[str, Any] | None:
|
||||||
with self.get_connection() as conn:
|
with self.get_cursor() as cursor:
|
||||||
cursor = conn.cursor()
|
|
||||||
cursor.execute("SELECT * FROM snippets WHERE id = ?", (snippet_id,))
|
cursor.execute("SELECT * FROM snippets WHERE id = ?", (snippet_id,))
|
||||||
row = cursor.fetchone()
|
row = cursor.fetchone()
|
||||||
if row:
|
if not row:
|
||||||
return dict(row)
|
return None
|
||||||
return None
|
snippet = dict(row)
|
||||||
|
|
||||||
def list_snippets(self, limit: int = 50, offset: int = 0, tag: str | None = None, language: str | None = None, collection_id: int | None = None) -> list[dict[str, Any]]:
|
|
||||||
with self.get_connection() as conn:
|
|
||||||
cursor = conn.cursor()
|
|
||||||
query = "SELECT * FROM snippets WHERE 1=1"
|
|
||||||
params = []
|
|
||||||
if tag:
|
|
||||||
query += " AND tags LIKE ?"
|
|
||||||
params.append(f'%"{tag}"%')
|
|
||||||
if language:
|
|
||||||
query += " AND language = ?"
|
|
||||||
params.append(language)
|
|
||||||
if collection_id:
|
|
||||||
query += " AND id IN (SELECT snippet_id FROM snippet_collections WHERE collection_id = ?)"
|
|
||||||
params.append(collection_id)
|
|
||||||
query += " ORDER BY updated_at DESC LIMIT ? OFFSET ?"
|
|
||||||
params.extend([limit, offset])
|
|
||||||
cursor.execute(query, params)
|
|
||||||
return [dict(row) for row in cursor.fetchall()]
|
|
||||||
|
|
||||||
def update_snippet(
|
|
||||||
self,
|
|
||||||
snippet_id: int,
|
|
||||||
title: str | None = None,
|
|
||||||
description: str | None = None,
|
|
||||||
code: str | None = None,
|
|
||||||
language: str | None = None,
|
|
||||||
tags: list[str] | None = None,
|
|
||||||
) -> bool:
|
|
||||||
snippet = self.get_snippet(snippet_id)
|
|
||||||
if not snippet:
|
|
||||||
return False
|
|
||||||
|
|
||||||
now = datetime.utcnow().isoformat()
|
|
||||||
with self.get_connection() as conn:
|
|
||||||
cursor = conn.cursor()
|
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"""
|
"""SELECT t.name FROM tags t
|
||||||
UPDATE snippets SET
|
JOIN snippet_tags st ON st.tag_id = t.id
|
||||||
title = COALESCE(?, title),
|
WHERE st.snippet_id = ?""",
|
||||||
description = COALESCE(?, description),
|
(snippet_id,),
|
||||||
code = COALESCE(?, code),
|
|
||||||
language = COALESCE(?, language),
|
|
||||||
tags = COALESCE(?, tags),
|
|
||||||
updated_at = ?
|
|
||||||
WHERE id = ?
|
|
||||||
""",
|
|
||||||
(
|
|
||||||
title,
|
|
||||||
description,
|
|
||||||
code,
|
|
||||||
language,
|
|
||||||
json.dumps(tags) if tags is not None else None,
|
|
||||||
now,
|
|
||||||
snippet_id,
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
return True
|
snippet["tags"] = [r["name"] for r in cursor.fetchall()]
|
||||||
|
return snippet
|
||||||
|
|
||||||
|
def update_snippet(self, snippet_id: int, title: str | None = None, description: str | None = None,
|
||||||
|
code: str | None = None, language: str | None = None, tags: list[str] | None = None) -> bool:
|
||||||
|
now = datetime.utcnow().isoformat()
|
||||||
|
updates = []
|
||||||
|
params = []
|
||||||
|
if title is not None:
|
||||||
|
updates.append("title = ?")
|
||||||
|
params.append(title)
|
||||||
|
if description is not None:
|
||||||
|
updates.append("description = ?")
|
||||||
|
params.append(description)
|
||||||
|
if code is not None:
|
||||||
|
updates.append("code = ?")
|
||||||
|
params.append(code)
|
||||||
|
if language is not None:
|
||||||
|
updates.append("language = ?")
|
||||||
|
params.append(language)
|
||||||
|
if updates:
|
||||||
|
updates.append("updated_at = ?")
|
||||||
|
params.append(now)
|
||||||
|
params.append(snippet_id)
|
||||||
|
with self.get_cursor() as cursor:
|
||||||
|
cursor.execute(f"UPDATE snippets SET {', '.join(updates)} WHERE id = ?", params)
|
||||||
|
if tags is not None:
|
||||||
|
self._update_snippet_tags(cursor, snippet_id, tags)
|
||||||
|
return cursor.rowcount > 0
|
||||||
|
if tags is not None:
|
||||||
|
with self.get_cursor() as cursor:
|
||||||
|
self._update_snippet_tags(cursor, snippet_id, tags)
|
||||||
|
return False
|
||||||
|
|
||||||
def delete_snippet(self, snippet_id: int) -> bool:
|
def delete_snippet(self, snippet_id: int) -> bool:
|
||||||
with self.get_connection() as conn:
|
with self.get_cursor() as cursor:
|
||||||
cursor = conn.cursor()
|
|
||||||
cursor.execute("DELETE FROM snippets WHERE id = ?", (snippet_id,))
|
cursor.execute("DELETE FROM snippets WHERE id = ?", (snippet_id,))
|
||||||
return cursor.rowcount > 0
|
return cursor.rowcount > 0
|
||||||
|
|
||||||
def search_snippets(
|
def list_snippets(self, language: str | None = None, tag: str | None = None,
|
||||||
self,
|
collection_id: int | None = None, limit: int = 50, offset: int = 0) -> list[dict[str, Any]]:
|
||||||
query: str,
|
query = """SELECT s.* FROM snippets s"""
|
||||||
limit: int = 50,
|
joins = []
|
||||||
language: str | None = None,
|
conditions = []
|
||||||
tag: str | None = None,
|
params = []
|
||||||
) -> list[dict[str, Any]]:
|
if tag:
|
||||||
with self.get_connection() as conn:
|
joins.append("JOIN snippet_tags st ON st.snippet_id = s.id")
|
||||||
cursor = conn.cursor()
|
joins.append("JOIN tags t ON t.id = st.tag_id")
|
||||||
if language:
|
conditions.append("t.name = ?")
|
||||||
fts_query = f"{query} AND language:{language}"
|
params.append(tag.lower())
|
||||||
|
if collection_id:
|
||||||
|
joins.append("JOIN snippet_collections sc ON sc.snippet_id = s.id")
|
||||||
|
conditions.append("sc.collection_id = ?")
|
||||||
|
params.append(collection_id)
|
||||||
|
if joins:
|
||||||
|
query += " " + " ".join(joins)
|
||||||
|
if conditions:
|
||||||
|
query += " WHERE " + " AND ".join(conditions)
|
||||||
|
if language:
|
||||||
|
if conditions:
|
||||||
|
query += " AND s.language = ?"
|
||||||
else:
|
else:
|
||||||
fts_query = query
|
query += " WHERE s.language = ?"
|
||||||
|
params.append(language)
|
||||||
|
query += " ORDER BY s.updated_at DESC LIMIT ? OFFSET ?"
|
||||||
|
params.extend([limit, offset])
|
||||||
|
with self.get_cursor() as cursor:
|
||||||
|
cursor.execute(query, params)
|
||||||
|
snippets = []
|
||||||
|
for row in cursor.fetchall():
|
||||||
|
snippet = dict(row)
|
||||||
|
cursor.execute(
|
||||||
|
"""SELECT t.name FROM tags t
|
||||||
|
JOIN snippet_tags st ON st.tag_id = t.id
|
||||||
|
WHERE st.snippet_id = ?""",
|
||||||
|
(snippet["id"],),
|
||||||
|
)
|
||||||
|
snippet["tags"] = [r["name"] for r in cursor.fetchall()]
|
||||||
|
snippets.append(snippet)
|
||||||
|
return snippets
|
||||||
|
|
||||||
|
def search_snippets(self, query: str, limit: int = 50) -> list[dict[str, Any]]:
|
||||||
|
fts_query = self._build_fts_query(query)
|
||||||
|
with self.get_cursor() as cursor:
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"""
|
"""SELECT s.*, bm25(snippets_fts) as rank
|
||||||
SELECT s.*, bm25(snippets_fts) as rank
|
FROM snippets_fts
|
||||||
FROM snippets s
|
JOIN snippets s ON s.id = snippets_fts.rowid
|
||||||
JOIN snippets_fts ON s.id = snippets_fts.rowid
|
WHERE snippets_fts MATCH ?
|
||||||
WHERE snippets_fts MATCH ?
|
ORDER BY rank
|
||||||
ORDER BY rank
|
LIMIT ?""",
|
||||||
LIMIT ?
|
|
||||||
""",
|
|
||||||
(fts_query, limit),
|
(fts_query, limit),
|
||||||
)
|
)
|
||||||
results = [dict(row) for row in cursor.fetchall()]
|
snippets = []
|
||||||
|
for row in cursor.fetchall():
|
||||||
|
snippet = dict(row)
|
||||||
|
cursor.execute(
|
||||||
|
"""SELECT t.name FROM tags t
|
||||||
|
JOIN snippet_tags st ON st.tag_id = t.id
|
||||||
|
WHERE st.snippet_id = ?""",
|
||||||
|
(snippet["id"],),
|
||||||
|
)
|
||||||
|
snippet["tags"] = [r["name"] for r in cursor.fetchall()]
|
||||||
|
snippets.append(snippet)
|
||||||
|
return snippets
|
||||||
|
|
||||||
if tag:
|
def _build_fts_query(self, query: str) -> str:
|
||||||
results = [r for r in results if tag in json.loads(r.get("tags", "[]"))]
|
parts = query.strip().split()
|
||||||
|
if not parts:
|
||||||
|
return '""'
|
||||||
|
processed = []
|
||||||
|
for part in parts:
|
||||||
|
if part in ("AND", "OR", "NOT"):
|
||||||
|
processed.append(part)
|
||||||
|
elif part.startswith('"') and part.endswith('"'):
|
||||||
|
processed.append(part)
|
||||||
|
else:
|
||||||
|
escaped = part.replace('"', '""')
|
||||||
|
processed.append(f'"{escaped}"')
|
||||||
|
return " ".join(processed) + "*"
|
||||||
|
|
||||||
return results
|
def tag_add(self, snippet_id: int, tag_name: str) -> bool:
|
||||||
|
tag_name = tag_name.strip().lower()
|
||||||
|
with self.get_cursor() as cursor:
|
||||||
|
cursor.execute("INSERT OR IGNORE INTO tags (name) VALUES (?)", (tag_name,))
|
||||||
|
cursor.execute("SELECT id FROM tags WHERE name = ?", (tag_name,))
|
||||||
|
tag_id = cursor.fetchone()[0]
|
||||||
|
cursor.execute(
|
||||||
|
"INSERT OR IGNORE INTO snippet_tags (snippet_id, tag_id) VALUES (?, ?)",
|
||||||
|
(snippet_id, tag_id),
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
def tag_add(self, snippet_id: int, tag: str) -> bool:
|
def tag_remove(self, snippet_id: int, tag_name: str) -> bool:
|
||||||
snippet = self.get_snippet(snippet_id)
|
tag_name = tag_name.strip().lower()
|
||||||
if not snippet:
|
with self.get_cursor() as cursor:
|
||||||
return False
|
cursor.execute(
|
||||||
tags = json.loads(snippet["tags"])
|
"""DELETE FROM snippet_tags
|
||||||
if tag not in tags:
|
WHERE snippet_id = ? AND tag_id = (
|
||||||
tags.append(tag)
|
SELECT id FROM tags WHERE name = ?
|
||||||
return self.update_snippet(snippet_id, tags=tags)
|
)""",
|
||||||
return True
|
(snippet_id, tag_name),
|
||||||
|
)
|
||||||
def tag_remove(self, snippet_id: int, tag: str) -> bool:
|
return cursor.rowcount > 0
|
||||||
snippet = self.get_snippet(snippet_id)
|
|
||||||
if not snippet:
|
|
||||||
return False
|
|
||||||
tags = json.loads(snippet["tags"])
|
|
||||||
if tag in tags:
|
|
||||||
tags.remove(tag)
|
|
||||||
return self.update_snippet(snippet_id, tags=tags)
|
|
||||||
return True
|
|
||||||
|
|
||||||
def list_tags(self) -> list[str]:
|
def list_tags(self) -> list[str]:
|
||||||
with self.get_connection() as conn:
|
with self.get_cursor() as cursor:
|
||||||
cursor = conn.cursor()
|
cursor.execute("SELECT name FROM tags ORDER BY name")
|
||||||
cursor.execute("SELECT tags FROM snippets")
|
return [row["name"] for row in cursor.fetchall()]
|
||||||
all_tags: set[str] = set()
|
|
||||||
for row in cursor.fetchall():
|
|
||||||
all_tags.update(json.loads(row["tags"]))
|
|
||||||
return sorted(all_tags)
|
|
||||||
|
|
||||||
def collection_create(self, name: str, description: str = "") -> int:
|
def collection_create(self, name: str, description: str = "") -> int:
|
||||||
"""Alias for create_collection for backwards compatibility."""
|
|
||||||
return self.create_collection(name, description)
|
|
||||||
|
|
||||||
def create_collection(self, name: str, description: str = "") -> int:
|
|
||||||
now = datetime.utcnow().isoformat()
|
now = datetime.utcnow().isoformat()
|
||||||
with self.get_connection() as conn:
|
with self.get_cursor() as cursor:
|
||||||
cursor = conn.cursor()
|
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"INSERT INTO collections (name, description, created_at) VALUES (?, ?, ?)",
|
"INSERT INTO collections (name, description, created_at) VALUES (?, ?, ?)",
|
||||||
(name, description, now),
|
(name, description, now),
|
||||||
@@ -305,148 +325,136 @@ class Database:
|
|||||||
return cursor.lastrowid
|
return cursor.lastrowid
|
||||||
|
|
||||||
def collection_list(self) -> list[dict[str, Any]]:
|
def collection_list(self) -> list[dict[str, Any]]:
|
||||||
"""Alias for list_collections for backwards compatibility."""
|
with self.get_cursor() as cursor:
|
||||||
return self.list_collections()
|
cursor.execute(
|
||||||
|
"SELECT c.*, COUNT(sc.snippet_id) as snippet_count "
|
||||||
def list_collections(self) -> list[dict[str, Any]]:
|
"FROM collections c LEFT JOIN snippet_collections sc ON sc.collection_id = c.id "
|
||||||
with self.get_connection() as conn:
|
"GROUP BY c.id ORDER BY c.name"
|
||||||
cursor = conn.cursor()
|
)
|
||||||
cursor.execute("""
|
|
||||||
SELECT c.*, COUNT(sc.snippet_id) as snippet_count
|
|
||||||
FROM collections c
|
|
||||||
LEFT JOIN snippet_collections sc ON c.id = sc.collection_id
|
|
||||||
GROUP BY c.id
|
|
||||||
ORDER BY c.name
|
|
||||||
""")
|
|
||||||
return [dict(row) for row in cursor.fetchall()]
|
return [dict(row) for row in cursor.fetchall()]
|
||||||
|
|
||||||
def get_collection(self, collection_id: int) -> dict[str, Any] | None:
|
|
||||||
with self.get_connection() as conn:
|
|
||||||
cursor = conn.cursor()
|
|
||||||
cursor.execute("SELECT * FROM collections WHERE id = ?", (collection_id,))
|
|
||||||
row = cursor.fetchone()
|
|
||||||
if row:
|
|
||||||
return dict(row)
|
|
||||||
return None
|
|
||||||
|
|
||||||
def collection_delete(self, collection_id: int) -> bool:
|
def collection_delete(self, collection_id: int) -> bool:
|
||||||
"""Alias for delete_collection for backwards compatibility."""
|
with self.get_cursor() as cursor:
|
||||||
return self.delete_collection(collection_id)
|
|
||||||
|
|
||||||
def delete_collection(self, collection_id: int) -> bool:
|
|
||||||
with self.get_connection() as conn:
|
|
||||||
cursor = conn.cursor()
|
|
||||||
cursor.execute("DELETE FROM collections WHERE id = ?", (collection_id,))
|
cursor.execute("DELETE FROM collections WHERE id = ?", (collection_id,))
|
||||||
return cursor.rowcount > 0
|
return cursor.rowcount > 0
|
||||||
|
|
||||||
def collection_add_snippet(self, collection_id: int, snippet_id: int) -> bool:
|
def collection_add_snippet(self, collection_id: int, snippet_id: int) -> bool:
|
||||||
"""Alias for add_snippet_to_collection for backwards compatibility."""
|
with self.get_cursor() as cursor:
|
||||||
return self.add_snippet_to_collection(snippet_id, collection_id)
|
cursor.execute(
|
||||||
|
"INSERT OR IGNORE INTO snippet_collections (collection_id, snippet_id) VALUES (?, ?)",
|
||||||
def add_snippet_to_collection(self, snippet_id: int, collection_id: int) -> bool:
|
(collection_id, snippet_id),
|
||||||
snippet = self.get_snippet(snippet_id)
|
)
|
||||||
collection = self.get_collection(collection_id)
|
|
||||||
if not snippet or not collection:
|
|
||||||
return False
|
|
||||||
try:
|
|
||||||
with self.get_connection() as conn:
|
|
||||||
cursor = conn.cursor()
|
|
||||||
cursor.execute(
|
|
||||||
"INSERT OR IGNORE INTO snippet_collections (snippet_id, collection_id) VALUES (?, ?)",
|
|
||||||
(snippet_id, collection_id),
|
|
||||||
)
|
|
||||||
return True
|
|
||||||
except sqlite3.IntegrityError:
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def collection_remove_snippet(self, collection_id: int, snippet_id: int) -> bool:
|
def collection_remove_snippet(self, collection_id: int, snippet_id: int) -> bool:
|
||||||
"""Alias for remove_snippet_from_collection for backwards compatibility."""
|
with self.get_cursor() as cursor:
|
||||||
return self.remove_snippet_from_collection(snippet_id, collection_id)
|
|
||||||
|
|
||||||
def remove_snippet_from_collection(self, snippet_id: int, collection_id: int) -> bool:
|
|
||||||
with self.get_connection() as conn:
|
|
||||||
cursor = conn.cursor()
|
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"DELETE FROM snippet_collections WHERE snippet_id = ? AND collection_id = ?",
|
"DELETE FROM snippet_collections WHERE collection_id = ? AND snippet_id = ?",
|
||||||
(snippet_id, collection_id),
|
(collection_id, snippet_id),
|
||||||
)
|
)
|
||||||
return cursor.rowcount > 0
|
return cursor.rowcount > 0
|
||||||
|
|
||||||
def get_collection_snippets(self, collection_id: int) -> list[dict[str, Any]]:
|
def get_all_snippets_for_sync(self, since: str | None = None) -> list[dict[str, Any]]:
|
||||||
with self.get_connection() as conn:
|
with self.get_cursor() as cursor:
|
||||||
cursor = conn.cursor()
|
if since:
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"""
|
"SELECT * FROM snippets WHERE updated_at > ? ORDER BY updated_at",
|
||||||
SELECT s.* FROM snippets s
|
(since,),
|
||||||
JOIN snippet_collections sc ON s.id = sc.snippet_id
|
|
||||||
WHERE sc.collection_id = ?
|
|
||||||
ORDER BY s.updated_at DESC
|
|
||||||
""",
|
|
||||||
(collection_id,),
|
|
||||||
)
|
|
||||||
return [dict(row) for row in cursor.fetchall()]
|
|
||||||
|
|
||||||
def export_all(self) -> list[dict[str, Any]]:
|
|
||||||
return self.list_snippets(limit=10000)
|
|
||||||
|
|
||||||
def import_snippet(
|
|
||||||
self,
|
|
||||||
data: dict[str, Any],
|
|
||||||
strategy: str = "skip",
|
|
||||||
) -> int | None:
|
|
||||||
existing = None
|
|
||||||
if "title" in data:
|
|
||||||
with self.get_connection() as conn:
|
|
||||||
cursor = conn.cursor()
|
|
||||||
cursor.execute("SELECT id FROM snippets WHERE title = ?", (data["title"],))
|
|
||||||
existing = cursor.fetchone()
|
|
||||||
|
|
||||||
if existing:
|
|
||||||
if strategy == "skip":
|
|
||||||
return None
|
|
||||||
elif strategy == "replace":
|
|
||||||
self.update_snippet(
|
|
||||||
existing["id"],
|
|
||||||
title=data.get("title"),
|
|
||||||
description=data.get("description"),
|
|
||||||
code=data.get("code"),
|
|
||||||
language=data.get("language"),
|
|
||||||
tags=data.get("tags"),
|
|
||||||
)
|
)
|
||||||
return existing["id"]
|
else:
|
||||||
|
cursor.execute("SELECT * FROM snippets ORDER BY updated_at")
|
||||||
|
snippets = []
|
||||||
|
for row in cursor.fetchall():
|
||||||
|
snippet = dict(row)
|
||||||
|
cursor.execute(
|
||||||
|
"""SELECT t.name FROM tags t
|
||||||
|
JOIN snippet_tags st ON st.tag_id = t.id
|
||||||
|
WHERE st.snippet_id = ?""",
|
||||||
|
(snippet["id"],),
|
||||||
|
)
|
||||||
|
snippet["tags"] = [r["name"] for r in cursor.fetchall()]
|
||||||
|
snippets.append(snippet)
|
||||||
|
return snippets
|
||||||
|
|
||||||
return self.add_snippet(
|
def upsert_snippet(self, snippet_data: dict[str, Any]) -> int:
|
||||||
title=data.get("title", "Untitled"),
|
with self.get_cursor() as cursor:
|
||||||
code=data.get("code", ""),
|
cursor.execute("SELECT id FROM snippets WHERE id = ?", (snippet_data["id"],))
|
||||||
description=data.get("description", ""),
|
existing = cursor.fetchone()
|
||||||
language=data.get("language", ""),
|
if existing:
|
||||||
tags=data.get("tags", []),
|
fields = ["title", "description", "code", "language", "is_encrypted", "updated_at"]
|
||||||
)
|
updates = []
|
||||||
|
params = []
|
||||||
|
for field in fields:
|
||||||
|
if field in snippet_data:
|
||||||
|
updates.append(f"{field} = ?")
|
||||||
|
params.append(snippet_data[field])
|
||||||
|
params.append(snippet_data["id"])
|
||||||
|
cursor.execute(f"UPDATE snippets SET {', '.join(updates)} WHERE id = ?", params)
|
||||||
|
if "tags" in snippet_data:
|
||||||
|
self._update_snippet_tags(cursor, snippet_data["id"], snippet_data["tags"])
|
||||||
|
return snippet_data["id"]
|
||||||
|
else:
|
||||||
|
cursor.execute(
|
||||||
|
"""INSERT INTO snippets (id, title, description, code, language, is_encrypted, created_at, updated_at)
|
||||||
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
|
||||||
|
(
|
||||||
|
snippet_data["id"],
|
||||||
|
snippet_data["title"],
|
||||||
|
snippet_data.get("description", ""),
|
||||||
|
snippet_data["code"],
|
||||||
|
snippet_data.get("language", "text"),
|
||||||
|
snippet_data.get("is_encrypted", 0),
|
||||||
|
snippet_data["created_at"],
|
||||||
|
snippet_data["updated_at"],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
if "tags" in snippet_data:
|
||||||
|
self._update_snippet_tags(cursor, snippet_data["id"], snippet_data["tags"])
|
||||||
|
return snippet_data["id"]
|
||||||
|
|
||||||
def add_peer(self, peer_id: str, host: str, port: int):
|
def update_sync_meta(self, peer_id: str, peer_name: str | None = None, address: str | None = None, port: int | None = None):
|
||||||
now = datetime.utcnow().isoformat()
|
now = datetime.utcnow().isoformat()
|
||||||
with self.get_connection() as conn:
|
with self.get_cursor() as cursor:
|
||||||
cursor = conn.cursor()
|
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"INSERT OR REPLACE INTO sync_peers (peer_id, host, port, last_seen) VALUES (?, ?, ?, ?)",
|
"""INSERT INTO sync_meta (peer_id, peer_name, last_sync, peer_address, peer_port)
|
||||||
(peer_id, host, port, now),
|
VALUES (?, ?, ?, ?, ?)
|
||||||
|
ON CONFLICT(peer_id) DO UPDATE SET
|
||||||
|
last_sync = COALESCE(excluded.last_sync, last_sync),
|
||||||
|
peer_name = COALESCE(excluded.peer_name, peer_name),
|
||||||
|
peer_address = COALESCE(excluded.peer_address, peer_address),
|
||||||
|
peer_port = COALESCE(excluded.peer_port, peer_port)""",
|
||||||
|
(peer_id, peer_name, now, address, port),
|
||||||
)
|
)
|
||||||
|
|
||||||
def list_peers(self) -> list[dict[str, Any]]:
|
def get_sync_meta(self, peer_id: str) -> dict[str, Any] | None:
|
||||||
with self.get_connection() as conn:
|
with self.get_cursor() as cursor:
|
||||||
cursor = conn.cursor()
|
cursor.execute("SELECT * FROM sync_meta WHERE peer_id = ?", (peer_id,))
|
||||||
cursor.execute("SELECT * FROM sync_peers ORDER BY last_seen DESC")
|
row = cursor.fetchone()
|
||||||
return [dict(row) for row in cursor.fetchall()]
|
return dict(row) if row else None
|
||||||
|
|
||||||
def list_sync_peers(self) -> list[dict[str, Any]]:
|
def list_sync_peers(self) -> list[dict[str, Any]]:
|
||||||
"""Alias for list_peers for backwards compatibility."""
|
with self.get_cursor() as cursor:
|
||||||
return self.list_peers()
|
cursor.execute("SELECT * FROM sync_meta ORDER BY last_sync DESC")
|
||||||
|
return [dict(row) for row in cursor.fetchall()]
|
||||||
|
|
||||||
def update_sync_meta(self, peer_id: str, last_sync: str):
|
def close(self):
|
||||||
"""Update sync metadata for a peer."""
|
if hasattr(self._local, "conn") and self._local.conn:
|
||||||
with self.get_connection() as conn:
|
self._local.conn.close()
|
||||||
cursor = conn.cursor()
|
self._local.conn = None
|
||||||
cursor.execute(
|
|
||||||
"UPDATE sync_peers SET last_sync = ? WHERE peer_id = ?",
|
|
||||||
(last_sync, peer_id),
|
_db_instance: Database | None = None
|
||||||
)
|
_db_path: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_database(db_path: str | None = None) -> Database:
|
||||||
|
global _db_instance, _db_path
|
||||||
|
if db_path is not None and _db_path != db_path:
|
||||||
|
_db_instance = Database(db_path)
|
||||||
|
_db_instance.init_schema()
|
||||||
|
_db_path = db_path
|
||||||
|
elif _db_instance is None:
|
||||||
|
_db_instance = Database(db_path)
|
||||||
|
_db_instance.init_schema()
|
||||||
|
_db_path = db_path
|
||||||
|
return _db_instance
|
||||||
|
|||||||
Reference in New Issue
Block a user