386 lines
14 KiB
Python
386 lines
14 KiB
Python
"""SQLite database with FTS5 search for snippet storage."""
|
|
|
|
import json
import os
import sqlite3
import time
from contextlib import contextmanager
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
|
|
|
|
|
|
class Database:
|
|
def __init__(self, db_path: str | None = None):
|
|
if db_path is None:
|
|
db_path = os.environ.get("SNIP_DB_PATH", "~/.snip/snippets.db")
|
|
self.db_path = os.path.expanduser(db_path)
|
|
self._ensure_dir()
|
|
self.conn = None
|
|
|
|
def _ensure_dir(self):
|
|
Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)
|
|
|
|
@contextmanager
|
|
def get_connection(self):
|
|
if self.conn is None:
|
|
self.conn = sqlite3.connect(self.db_path)
|
|
self.conn.row_factory = sqlite3.Row
|
|
try:
|
|
yield self.conn
|
|
self.conn.commit()
|
|
except Exception:
|
|
self.conn.rollback()
|
|
raise
|
|
|
|
def init_db(self):
|
|
with self.get_connection() as conn:
|
|
cursor = conn.cursor()
|
|
cursor.execute("""
|
|
CREATE TABLE IF NOT EXISTS snippets (
|
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
title TEXT NOT NULL,
|
|
description TEXT,
|
|
code TEXT NOT NULL,
|
|
language TEXT,
|
|
tags TEXT DEFAULT '[]',
|
|
is_encrypted INTEGER DEFAULT 0,
|
|
created_at TEXT NOT NULL,
|
|
updated_at TEXT NOT NULL
|
|
)
|
|
""")
|
|
|
|
cursor.execute("""
|
|
CREATE TABLE IF NOT EXISTS collections (
|
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
name TEXT NOT NULL UNIQUE,
|
|
description TEXT,
|
|
created_at TEXT NOT NULL
|
|
)
|
|
""")
|
|
|
|
cursor.execute("""
|
|
CREATE TABLE IF NOT EXISTS snippet_collections (
|
|
snippet_id INTEGER NOT NULL,
|
|
collection_id INTEGER NOT NULL,
|
|
PRIMARY KEY (snippet_id, collection_id),
|
|
FOREIGN KEY (snippet_id) REFERENCES snippets(id) ON DELETE CASCADE,
|
|
FOREIGN KEY (collection_id) REFERENCES collections(id) ON DELETE CASCADE
|
|
)
|
|
""")
|
|
|
|
cursor.execute("""
|
|
CREATE TABLE IF NOT EXISTS sync_peers (
|
|
peer_id TEXT PRIMARY KEY,
|
|
host TEXT NOT NULL,
|
|
port INTEGER NOT NULL,
|
|
last_seen TEXT NOT NULL
|
|
)
|
|
""")
|
|
|
|
cursor.execute("""
|
|
CREATE VIRTUAL TABLE IF NOT EXISTS snippets_fts USING fts5(
|
|
title, description, code, tags,
|
|
content='snippets',
|
|
content_rowid='id'
|
|
)
|
|
""")
|
|
|
|
cursor.execute("""
|
|
CREATE TRIGGER IF NOT EXISTS snippets_ai AFTER INSERT ON snippets BEGIN
|
|
INSERT INTO snippets_fts(rowid, title, description, code, tags)
|
|
VALUES (new.id, new.title, new.description, new.code, new.tags);
|
|
END
|
|
""")
|
|
|
|
cursor.execute("""
|
|
CREATE TRIGGER IF NOT EXISTS snippets_ad AFTER DELETE ON snippets BEGIN
|
|
INSERT INTO snippets_fts(snippets_fts, rowid, title, description, code, tags)
|
|
VALUES ('delete', old.id, old.title, old.description, old.code, old.tags);
|
|
END
|
|
""")
|
|
|
|
cursor.execute("""
|
|
CREATE TRIGGER IF NOT EXISTS snippets_au AFTER UPDATE ON snippets BEGIN
|
|
INSERT INTO snippets_fts(snippets_fts, rowid, title, description, code, tags)
|
|
VALUES ('delete', old.id, old.title, old.description, old.code, old.tags);
|
|
INSERT INTO snippets_fts(rowid, title, description, code, tags)
|
|
VALUES (new.id, new.title, new.description, new.code, new.tags);
|
|
END
|
|
""")
|
|
|
|
def add_snippet(
|
|
self,
|
|
title: str,
|
|
code: str,
|
|
description: str = "",
|
|
language: str = "",
|
|
tags: list[str] | None = None,
|
|
is_encrypted: bool = False,
|
|
) -> int:
|
|
tags = tags or []
|
|
now = datetime.utcnow().isoformat()
|
|
with self.get_connection() as conn:
|
|
cursor = conn.cursor()
|
|
cursor.execute(
|
|
"""
|
|
INSERT INTO snippets (title, description, code, language, tags, is_encrypted, created_at, updated_at)
|
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
|
""",
|
|
(title, description, code, language, json.dumps(tags), int(is_encrypted), now, now),
|
|
)
|
|
return cursor.lastrowid
|
|
|
|
def get_snippet(self, snippet_id: int) -> dict[str, Any] | None:
|
|
with self.get_connection() as conn:
|
|
cursor = conn.cursor()
|
|
cursor.execute("SELECT * FROM snippets WHERE id = ?", (snippet_id,))
|
|
row = cursor.fetchone()
|
|
if row:
|
|
return dict(row)
|
|
return None
|
|
|
|
def list_snippets(self, limit: int = 50, offset: int = 0, tag: str | None = None) -> list[dict[str, Any]]:
|
|
with self.get_connection() as conn:
|
|
cursor = conn.cursor()
|
|
if tag:
|
|
cursor.execute(
|
|
"SELECT * FROM snippets WHERE tags LIKE ? ORDER BY updated_at DESC LIMIT ? OFFSET ?",
|
|
(f'%"{tag}"%', limit, offset),
|
|
)
|
|
else:
|
|
cursor.execute(
|
|
"SELECT * FROM snippets ORDER BY updated_at DESC LIMIT ? OFFSET ?",
|
|
(limit, offset),
|
|
)
|
|
return [dict(row) for row in cursor.fetchall()]
|
|
|
|
def update_snippet(
|
|
self,
|
|
snippet_id: int,
|
|
title: str | None = None,
|
|
description: str | None = None,
|
|
code: str | None = None,
|
|
language: str | None = None,
|
|
tags: list[str] | None = None,
|
|
) -> bool:
|
|
snippet = self.get_snippet(snippet_id)
|
|
if not snippet:
|
|
return False
|
|
|
|
now = datetime.utcnow().isoformat()
|
|
with self.get_connection() as conn:
|
|
cursor = conn.cursor()
|
|
cursor.execute(
|
|
"""
|
|
UPDATE snippets SET
|
|
title = COALESCE(?, title),
|
|
description = COALESCE(?, description),
|
|
code = COALESCE(?, code),
|
|
language = COALESCE(?, language),
|
|
tags = COALESCE(?, tags),
|
|
updated_at = ?
|
|
WHERE id = ?
|
|
""",
|
|
(
|
|
title,
|
|
description,
|
|
code,
|
|
language,
|
|
json.dumps(tags) if tags is not None else None,
|
|
now,
|
|
snippet_id,
|
|
),
|
|
)
|
|
return True
|
|
|
|
def delete_snippet(self, snippet_id: int) -> bool:
|
|
with self.get_connection() as conn:
|
|
cursor = conn.cursor()
|
|
cursor.execute("DELETE FROM snippets WHERE id = ?", (snippet_id,))
|
|
return cursor.rowcount > 0
|
|
|
|
def search_snippets(
|
|
self,
|
|
query: str,
|
|
limit: int = 50,
|
|
language: str | None = None,
|
|
tag: str | None = None,
|
|
) -> list[dict[str, Any]]:
|
|
with self.get_connection() as conn:
|
|
cursor = conn.cursor()
|
|
if language:
|
|
fts_query = f"{query} AND language:{language}"
|
|
else:
|
|
fts_query = query
|
|
|
|
cursor.execute(
|
|
"""
|
|
SELECT s.*, bm25(snippets_fts) as rank
|
|
FROM snippets s
|
|
JOIN snippets_fts ON s.id = snippets_fts.rowid
|
|
WHERE snippets_fts MATCH ?
|
|
ORDER BY rank
|
|
LIMIT ?
|
|
""",
|
|
(fts_query, limit),
|
|
)
|
|
results = [dict(row) for row in cursor.fetchall()]
|
|
|
|
if tag:
|
|
results = [r for r in results if tag in json.loads(r.get("tags", "[]"))]
|
|
|
|
return results
|
|
|
|
def add_tag(self, snippet_id: int, tag: str) -> bool:
|
|
snippet = self.get_snippet(snippet_id)
|
|
if not snippet:
|
|
return False
|
|
tags = json.loads(snippet["tags"])
|
|
if tag not in tags:
|
|
tags.append(tag)
|
|
return self.update_snippet(snippet_id, tags=tags)
|
|
return True
|
|
|
|
def remove_tag(self, snippet_id: int, tag: str) -> bool:
|
|
snippet = self.get_snippet(snippet_id)
|
|
if not snippet:
|
|
return False
|
|
tags = json.loads(snippet["tags"])
|
|
if tag in tags:
|
|
tags.remove(tag)
|
|
return self.update_snippet(snippet_id, tags=tags)
|
|
return True
|
|
|
|
def list_tags(self) -> list[str]:
|
|
with self.get_connection() as conn:
|
|
cursor = conn.cursor()
|
|
cursor.execute("SELECT tags FROM snippets")
|
|
all_tags: set[str] = set()
|
|
for row in cursor.fetchall():
|
|
all_tags.update(json.loads(row["tags"]))
|
|
return sorted(all_tags)
|
|
|
|
def create_collection(self, name: str, description: str = "") -> int:
|
|
now = datetime.utcnow().isoformat()
|
|
with self.get_connection() as conn:
|
|
cursor = conn.cursor()
|
|
cursor.execute(
|
|
"INSERT INTO collections (name, description, created_at) VALUES (?, ?, ?)",
|
|
(name, description, now),
|
|
)
|
|
return cursor.lastrowid
|
|
|
|
def list_collections(self) -> list[dict[str, Any]]:
|
|
with self.get_connection() as conn:
|
|
cursor = conn.cursor()
|
|
cursor.execute("SELECT * FROM collections ORDER BY name")
|
|
return [dict(row) for row in cursor.fetchall()]
|
|
|
|
def get_collection(self, collection_id: int) -> dict[str, Any] | None:
|
|
with self.get_connection() as conn:
|
|
cursor = conn.cursor()
|
|
cursor.execute("SELECT * FROM collections WHERE id = ?", (collection_id,))
|
|
row = cursor.fetchone()
|
|
if row:
|
|
return dict(row)
|
|
return None
|
|
|
|
def delete_collection(self, collection_id: int) -> bool:
|
|
with self.get_connection() as conn:
|
|
cursor = conn.cursor()
|
|
cursor.execute("DELETE FROM collections WHERE id = ?", (collection_id,))
|
|
return cursor.rowcount > 0
|
|
|
|
def add_snippet_to_collection(self, snippet_id: int, collection_id: int) -> bool:
|
|
snippet = self.get_snippet(snippet_id)
|
|
collection = self.get_collection(collection_id)
|
|
if not snippet or not collection:
|
|
return False
|
|
try:
|
|
with self.get_connection() as conn:
|
|
cursor = conn.cursor()
|
|
cursor.execute(
|
|
"INSERT INTO snippet_collections (snippet_id, collection_id) VALUES (?, ?)",
|
|
(snippet_id, collection_id),
|
|
)
|
|
return True
|
|
except sqlite3.IntegrityError:
|
|
return True
|
|
|
|
def remove_snippet_from_collection(self, snippet_id: int, collection_id: int) -> bool:
|
|
with self.get_connection() as conn:
|
|
cursor = conn.cursor()
|
|
cursor.execute(
|
|
"DELETE FROM snippet_collections WHERE snippet_id = ? AND collection_id = ?",
|
|
(snippet_id, collection_id),
|
|
)
|
|
return cursor.rowcount > 0
|
|
|
|
def get_collection_snippets(self, collection_id: int) -> list[dict[str, Any]]:
|
|
with self.get_connection() as conn:
|
|
cursor = conn.cursor()
|
|
cursor.execute(
|
|
"""
|
|
SELECT s.* FROM snippets s
|
|
JOIN snippet_collections sc ON s.id = sc.snippet_id
|
|
WHERE sc.collection_id = ?
|
|
ORDER BY s.updated_at DESC
|
|
""",
|
|
(collection_id,),
|
|
)
|
|
return [dict(row) for row in cursor.fetchall()]
|
|
|
|
def export_all(self) -> list[dict[str, Any]]:
|
|
return self.list_snippets(limit=10000)
|
|
|
|
def import_snippet(
|
|
self,
|
|
data: dict[str, Any],
|
|
strategy: str = "skip",
|
|
) -> int | None:
|
|
existing = None
|
|
if "title" in data:
|
|
with self.get_connection() as conn:
|
|
cursor = conn.cursor()
|
|
cursor.execute("SELECT id FROM snippets WHERE title = ?", (data["title"],))
|
|
existing = cursor.fetchone()
|
|
|
|
if existing:
|
|
if strategy == "skip":
|
|
return None
|
|
elif strategy == "replace":
|
|
self.update_snippet(
|
|
existing["id"],
|
|
title=data.get("title"),
|
|
description=data.get("description"),
|
|
code=data.get("code"),
|
|
language=data.get("language"),
|
|
tags=data.get("tags"),
|
|
)
|
|
return existing["id"]
|
|
|
|
return self.add_snippet(
|
|
title=data.get("title", "Untitled"),
|
|
code=data.get("code", ""),
|
|
description=data.get("description", ""),
|
|
language=data.get("language", ""),
|
|
tags=data.get("tags", []),
|
|
)
|
|
|
|
def add_peer(self, peer_id: str, host: str, port: int):
|
|
now = datetime.utcnow().isoformat()
|
|
with self.get_connection() as conn:
|
|
cursor = conn.cursor()
|
|
cursor.execute(
|
|
"INSERT OR REPLACE INTO sync_peers (peer_id, host, port, last_seen) VALUES (?, ?, ?, ?)",
|
|
(peer_id, host, port, now),
|
|
)
|
|
|
|
def list_peers(self) -> list[dict[str, Any]]:
|
|
with self.get_connection() as conn:
|
|
cursor = conn.cursor()
|
|
cursor.execute("SELECT * FROM sync_peers ORDER BY last_seen DESC")
|
|
return [dict(row) for row in cursor.fetchall()]
|