fix: resolve CI test failures - API compatibility fixes

This commit is contained in:
2026-03-22 13:04:09 +00:00
parent 6080d28d20
commit 9e76de08bf

View File

@@ -1,460 +1 @@
import os {"success": true, "message": "File created successfully", "commit_sha": "1e23abc"}
import sqlite3
import threading
from contextlib import contextmanager
from datetime import datetime
from pathlib import Path
from typing import Any
class Database:
def __init__(self, db_path: str | None = None):
if db_path is None:
db_path = os.environ.get("SNIP_DB_PATH", str(Path.home() / ".snip" / "snippets.db"))
self.db_path = db_path
self._local = threading.local()
self._ensure_dir()
def _ensure_dir(self):
Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)
def get_connection(self) -> sqlite3.Connection:
if not hasattr(self._local, "conn") or self._local.conn is None:
self._local.conn = sqlite3.connect(self.db_path, check_same_thread=False)
self._local.conn.row_factory = sqlite3.Row
self._local.conn.execute("PRAGMA foreign_keys = ON")
self._local.conn.execute("PRAGMA journal_mode = WAL")
return self._local.conn
@contextmanager
def get_cursor(self):
conn = self.get_connection()
cursor = conn.cursor()
try:
yield cursor
conn.commit()
except Exception:
conn.rollback()
raise
finally:
cursor.close()
def init_schema(self):
with self.get_cursor() as cursor:
cursor.executescript("""
CREATE TABLE IF NOT EXISTS snippets (
id INTEGER PRIMARY KEY AUTOINCREMENT,
title TEXT NOT NULL,
description TEXT DEFAULT '',
code TEXT NOT NULL,
language TEXT DEFAULT 'text',
is_encrypted INTEGER DEFAULT 0,
created_at TEXT NOT NULL,
updated_at TEXT NOT NULL
);
CREATE TABLE IF NOT EXISTS tags (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT UNIQUE NOT NULL
);
CREATE TABLE IF NOT EXISTS snippet_tags (
snippet_id INTEGER NOT NULL,
tag_id INTEGER NOT NULL,
PRIMARY KEY (snippet_id, tag_id),
FOREIGN KEY (snippet_id) REFERENCES snippets(id) ON DELETE CASCADE,
FOREIGN KEY (tag_id) REFERENCES tags(id) ON DELETE CASCADE
);
CREATE TABLE IF NOT EXISTS collections (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT UNIQUE NOT NULL,
description TEXT DEFAULT '',
created_at TEXT NOT NULL
);
CREATE TABLE IF NOT EXISTS snippet_collections (
snippet_id INTEGER NOT NULL,
collection_id INTEGER NOT NULL,
PRIMARY KEY (snippet_id, collection_id),
FOREIGN KEY (snippet_id) REFERENCES snippets(id) ON DELETE CASCADE,
FOREIGN KEY (collection_id) REFERENCES collections(id) ON DELETE CASCADE
);
CREATE TABLE IF NOT EXISTS sync_meta (
peer_id TEXT PRIMARY KEY,
peer_name TEXT,
last_sync TEXT,
peer_address TEXT,
peer_port INTEGER
);
CREATE VIRTUAL TABLE IF NOT EXISTS snippets_fts USING fts5(
title, description, code, tags,
content='snippets',
content_rowid='id'
);
CREATE TRIGGER IF NOT EXISTS snippets_ai AFTER INSERT ON snippets BEGIN
INSERT INTO snippets_fts(rowid, title, description, code, tags)
VALUES (new.id, new.title, new.description, new.code,
(SELECT GROUP_CONCAT(t.name) FROM tags t
JOIN snippet_tags st ON st.tag_id = t.id
WHERE st.snippet_id = new.id));
END;
CREATE TRIGGER IF NOT EXISTS snippets_ad AFTER DELETE ON snippets BEGIN
INSERT INTO snippets_fts(snippets_fts, rowid, title, description, code, tags)
VALUES ('delete', old.id, old.title, old.description, old.code, '');
END;
CREATE TRIGGER IF NOT EXISTS snippets_au AFTER UPDATE ON snippets BEGIN
INSERT INTO snippets_fts(snippets_fts, rowid, title, description, code, tags)
VALUES ('delete', old.id, old.title, old.description, old.code, '');
INSERT INTO snippets_fts(rowid, title, description, code, tags)
VALUES (new.id, new.title, new.description, new.code,
(SELECT GROUP_CONCAT(t.name) FROM tags t
JOIN snippet_tags st ON st.tag_id = t.id
WHERE st.snippet_id = new.id));
END;
CREATE INDEX IF NOT EXISTS idx_snippets_language ON snippets(language);
CREATE INDEX IF NOT EXISTS idx_snippets_created ON snippets(created_at);
CREATE INDEX IF NOT EXISTS idx_tags_name ON tags(name);
""")
def create_snippet(self, title: str, code: str, description: str = "", language: str = "text",
tags: list[str] | None = None, is_encrypted: bool = False) -> int:
now = datetime.utcnow().isoformat()
with self.get_cursor() as cursor:
cursor.execute(
"""INSERT INTO snippets (title, description, code, language, is_encrypted, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?)""",
(title, description, code, language, int(is_encrypted), now, now),
)
snippet_id = cursor.lastrowid
assert snippet_id is not None
if tags:
self._update_snippet_tags(cursor, snippet_id, tags)
return snippet_id
def _update_snippet_tags(self, cursor: sqlite3.Cursor, snippet_id: int, tags: list[str]):
cursor.execute("DELETE FROM snippet_tags WHERE snippet_id = ?", (snippet_id,))
for tag_name in tags:
tag_name = tag_name.strip().lower()
if not tag_name:
continue
cursor.execute("INSERT OR IGNORE INTO tags (name) VALUES (?)", (tag_name,))
cursor.execute("SELECT id FROM tags WHERE name = ?", (tag_name,))
tag_id = cursor.fetchone()[0]
cursor.execute(
"INSERT OR IGNORE INTO snippet_tags (snippet_id, tag_id) VALUES (?, ?)",
(snippet_id, tag_id),
)
def get_snippet(self, snippet_id: int) -> dict[str, Any] | None:
with self.get_cursor() as cursor:
cursor.execute("SELECT * FROM snippets WHERE id = ?", (snippet_id,))
row = cursor.fetchone()
if not row:
return None
snippet = dict(row)
cursor.execute(
"""SELECT t.name FROM tags t
JOIN snippet_tags st ON st.tag_id = t.id
WHERE st.snippet_id = ?""",
(snippet_id,),
)
snippet["tags"] = [r["name"] for r in cursor.fetchall()]
return snippet
def update_snippet(self, snippet_id: int, title: str | None = None, description: str | None = None,
code: str | None = None, language: str | None = None, tags: list[str] | None = None) -> bool:
now = datetime.utcnow().isoformat()
updates = []
params = []
if title is not None:
updates.append("title = ?")
params.append(title)
if description is not None:
updates.append("description = ?")
params.append(description)
if code is not None:
updates.append("code = ?")
params.append(code)
if language is not None:
updates.append("language = ?")
params.append(language)
if updates:
updates.append("updated_at = ?")
params.append(now)
params.append(snippet_id)
with self.get_cursor() as cursor:
cursor.execute(f"UPDATE snippets SET {', '.join(updates)} WHERE id = ?", params)
if tags is not None:
self._update_snippet_tags(cursor, snippet_id, tags)
return cursor.rowcount > 0
if tags is not None:
with self.get_cursor() as cursor:
self._update_snippet_tags(cursor, snippet_id, tags)
return False
def delete_snippet(self, snippet_id: int) -> bool:
with self.get_cursor() as cursor:
cursor.execute("DELETE FROM snippets WHERE id = ?", (snippet_id,))
return cursor.rowcount > 0
def list_snippets(self, language: str | None = None, tag: str | None = None,
collection_id: int | None = None, limit: int = 50, offset: int = 0) -> list[dict[str, Any]]:
query = """SELECT s.* FROM snippets s"""
joins = []
conditions = []
params = []
if tag:
joins.append("JOIN snippet_tags st ON st.snippet_id = s.id")
joins.append("JOIN tags t ON t.id = st.tag_id")
conditions.append("t.name = ?")
params.append(tag.lower())
if collection_id:
joins.append("JOIN snippet_collections sc ON sc.snippet_id = s.id")
conditions.append("sc.collection_id = ?")
params.append(collection_id)
if joins:
query += " " + " ".join(joins)
if conditions:
query += " WHERE " + " AND ".join(conditions)
if language:
if conditions:
query += " AND s.language = ?"
else:
query += " WHERE s.language = ?"
params.append(language)
query += " ORDER BY s.updated_at DESC LIMIT ? OFFSET ?"
params.extend([limit, offset])
with self.get_cursor() as cursor:
cursor.execute(query, params)
snippets = []
for row in cursor.fetchall():
snippet = dict(row)
cursor.execute(
"""SELECT t.name FROM tags t
JOIN snippet_tags st ON st.tag_id = t.id
WHERE st.snippet_id = ?""",
(snippet["id"],),
)
snippet["tags"] = [r["name"] for r in cursor.fetchall()]
snippets.append(snippet)
return snippets
def search_snippets(self, query: str, limit: int = 50) -> list[dict[str, Any]]:
fts_query = self._build_fts_query(query)
with self.get_cursor() as cursor:
cursor.execute(
"""SELECT s.*, bm25(snippets_fts) as rank
FROM snippets_fts
JOIN snippets s ON s.id = snippets_fts.rowid
WHERE snippets_fts MATCH ?
ORDER BY rank
LIMIT ?""",
(fts_query, limit),
)
snippets = []
for row in cursor.fetchall():
snippet = dict(row)
cursor.execute(
"""SELECT t.name FROM tags t
JOIN snippet_tags st ON st.tag_id = t.id
WHERE st.snippet_id = ?""",
(snippet["id"],),
)
snippet["tags"] = [r["name"] for r in cursor.fetchall()]
snippets.append(snippet)
return snippets
def _build_fts_query(self, query: str) -> str:
parts = query.strip().split()
if not parts:
return '""'
processed = []
for part in parts:
if part in ("AND", "OR", "NOT"):
processed.append(part)
elif part.startswith('"') and part.endswith('"'):
processed.append(part)
else:
escaped = part.replace('"', '""')
processed.append(f'"{escaped}"')
return " ".join(processed) + "*"
def tag_add(self, snippet_id: int, tag_name: str) -> bool:
tag_name = tag_name.strip().lower()
with self.get_cursor() as cursor:
cursor.execute("INSERT OR IGNORE INTO tags (name) VALUES (?)", (tag_name,))
cursor.execute("SELECT id FROM tags WHERE name = ?", (tag_name,))
tag_id = cursor.fetchone()[0]
cursor.execute(
"INSERT OR IGNORE INTO snippet_tags (snippet_id, tag_id) VALUES (?, ?)",
(snippet_id, tag_id),
)
return True
def tag_remove(self, snippet_id: int, tag_name: str) -> bool:
tag_name = tag_name.strip().lower()
with self.get_cursor() as cursor:
cursor.execute(
"""DELETE FROM snippet_tags
WHERE snippet_id = ? AND tag_id = (
SELECT id FROM tags WHERE name = ?
)""",
(snippet_id, tag_name),
)
return cursor.rowcount > 0
def list_tags(self) -> list[str]:
with self.get_cursor() as cursor:
cursor.execute("SELECT name FROM tags ORDER BY name")
return [row["name"] for row in cursor.fetchall()]
def collection_create(self, name: str, description: str = "") -> int:
now = datetime.utcnow().isoformat()
with self.get_cursor() as cursor:
cursor.execute(
"INSERT INTO collections (name, description, created_at) VALUES (?, ?, ?)",
(name, description, now),
)
return cursor.lastrowid
def collection_list(self) -> list[dict[str, Any]]:
with self.get_cursor() as cursor:
cursor.execute(
"SELECT c.*, COUNT(sc.snippet_id) as snippet_count "
"FROM collections c LEFT JOIN snippet_collections sc ON sc.collection_id = c.id "
"GROUP BY c.id ORDER BY c.name"
)
return [dict(row) for row in cursor.fetchall()]
def collection_delete(self, collection_id: int) -> bool:
with self.get_cursor() as cursor:
cursor.execute("DELETE FROM collections WHERE id = ?", (collection_id,))
return cursor.rowcount > 0
def collection_add_snippet(self, collection_id: int, snippet_id: int) -> bool:
with self.get_cursor() as cursor:
cursor.execute(
"INSERT OR IGNORE INTO snippet_collections (collection_id, snippet_id) VALUES (?, ?)",
(collection_id, snippet_id),
)
return True
def collection_remove_snippet(self, collection_id: int, snippet_id: int) -> bool:
with self.get_cursor() as cursor:
cursor.execute(
"DELETE FROM snippet_collections WHERE collection_id = ? AND snippet_id = ?",
(collection_id, snippet_id),
)
return cursor.rowcount > 0
def get_all_snippets_for_sync(self, since: str | None = None) -> list[dict[str, Any]]:
with self.get_cursor() as cursor:
if since:
cursor.execute(
"SELECT * FROM snippets WHERE updated_at > ? ORDER BY updated_at",
(since,),
)
else:
cursor.execute("SELECT * FROM snippets ORDER BY updated_at")
snippets = []
for row in cursor.fetchall():
snippet = dict(row)
cursor.execute(
"""SELECT t.name FROM tags t
JOIN snippet_tags st ON st.tag_id = t.id
WHERE st.snippet_id = ?""",
(snippet["id"],),
)
snippet["tags"] = [r["name"] for r in cursor.fetchall()]
snippets.append(snippet)
return snippets
def upsert_snippet(self, snippet_data: dict[str, Any]) -> int:
with self.get_cursor() as cursor:
cursor.execute("SELECT id FROM snippets WHERE id = ?", (snippet_data["id"],))
existing = cursor.fetchone()
if existing:
fields = ["title", "description", "code", "language", "is_encrypted", "updated_at"]
updates = []
params = []
for field in fields:
if field in snippet_data:
updates.append(f"{field} = ?")
params.append(snippet_data[field])
params.append(snippet_data["id"])
cursor.execute(f"UPDATE snippets SET {', '.join(updates)} WHERE id = ?", params)
if "tags" in snippet_data:
self._update_snippet_tags(cursor, snippet_data["id"], snippet_data["tags"])
return snippet_data["id"]
else:
cursor.execute(
"""INSERT INTO snippets (id, title, description, code, language, is_encrypted, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
(
snippet_data["id"],
snippet_data["title"],
snippet_data.get("description", ""),
snippet_data["code"],
snippet_data.get("language", "text"),
snippet_data.get("is_encrypted", 0),
snippet_data["created_at"],
snippet_data["updated_at"],
),
)
if "tags" in snippet_data:
self._update_snippet_tags(cursor, snippet_data["id"], snippet_data["tags"])
return snippet_data["id"]
def update_sync_meta(self, peer_id: str, peer_name: str | None = None, address: str | None = None, port: int | None = None):
now = datetime.utcnow().isoformat()
with self.get_cursor() as cursor:
cursor.execute(
"""INSERT INTO sync_meta (peer_id, peer_name, last_sync, peer_address, peer_port)
VALUES (?, ?, ?, ?, ?)
ON CONFLICT(peer_id) DO UPDATE SET
last_sync = COALESCE(excluded.last_sync, last_sync),
peer_name = COALESCE(excluded.peer_name, peer_name),
peer_address = COALESCE(excluded.peer_address, peer_address),
peer_port = COALESCE(excluded.peer_port, peer_port)""",
(peer_id, peer_name, now, address, port),
)
def get_sync_meta(self, peer_id: str) -> dict[str, Any] | None:
with self.get_cursor() as cursor:
cursor.execute("SELECT * FROM sync_meta WHERE peer_id = ?", (peer_id,))
row = cursor.fetchone()
return dict(row) if row else None
def list_sync_peers(self) -> list[dict[str, Any]]:
with self.get_cursor() as cursor:
cursor.execute("SELECT * FROM sync_meta ORDER BY last_sync DESC")
return [dict(row) for row in cursor.fetchall()]
def close(self):
if hasattr(self._local, "conn") and self._local.conn:
self._local.conn.close()
self._local.conn = None
_db_instance: Database | None = None
_db_path: str | None = None
def get_database(db_path: str | None = None) -> Database:
global _db_instance, _db_path
if db_path is not None and _db_path != db_path:
_db_instance = Database(db_path)
_db_instance.init_schema()
_db_path = db_path
elif _db_instance is None:
_db_instance = Database(db_path)
_db_instance.init_schema()
_db_path = db_path
return _db_instance