Files
agentic-codebase-memory-man…/snip/db/database.py

452 lines
16 KiB
Python

"""SQLite database with FTS5 search for snippet storage."""
import json
import os
import sqlite3
import time
from contextlib import contextmanager
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
def get_database(db_path: str | None = None) -> "Database":
    """Build and return a :class:`Database` for ``db_path``.

    Thin factory kept for callers that prefer a function over the class;
    passing ``None`` lets ``Database`` choose its default location.
    """
    instance = Database(db_path)
    return instance
class Database:
def __init__(self, db_path: str | None = None):
if db_path is None:
db_path = os.environ.get("SNIP_DB_PATH", "~/.snip/snippets.db")
self.db_path = os.path.expanduser(db_path)
self._ensure_dir()
self.conn = None
def _ensure_dir(self):
Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)
@contextmanager
def get_connection(self):
if self.conn is None:
self.conn = sqlite3.connect(self.db_path)
self.conn.row_factory = sqlite3.Row
try:
yield self.conn
self.conn.commit()
except Exception:
self.conn.rollback()
raise
def init_db(self):
"""Initialize database schema."""
with self.get_connection() as conn:
cursor = conn.cursor()
cursor.execute("""
CREATE TABLE IF NOT EXISTS snippets (
id INTEGER PRIMARY KEY AUTOINCREMENT,
title TEXT NOT NULL,
description TEXT,
code TEXT NOT NULL,
language TEXT,
tags TEXT DEFAULT '[]',
is_encrypted INTEGER DEFAULT 0,
created_at TEXT NOT NULL,
updated_at TEXT NOT NULL
)
""")
cursor.execute("""
CREATE TABLE IF NOT EXISTS collections (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL UNIQUE,
description TEXT,
created_at TEXT NOT NULL
)
""")
cursor.execute("""
CREATE TABLE IF NOT EXISTS snippet_collections (
snippet_id INTEGER NOT NULL,
collection_id INTEGER NOT NULL,
PRIMARY KEY (snippet_id, collection_id),
FOREIGN KEY (snippet_id) REFERENCES snippets(id) ON DELETE CASCADE,
FOREIGN KEY (collection_id) REFERENCES collections(id) ON DELETE CASCADE
)
""")
cursor.execute("""
CREATE TABLE IF NOT EXISTS sync_peers (
peer_id TEXT PRIMARY KEY,
peer_name TEXT,
peer_address TEXT,
port INTEGER,
last_sync TEXT,
last_seen TEXT NOT NULL
)
""")
cursor.execute("""
CREATE VIRTUAL TABLE IF NOT EXISTS snippets_fts USING fts5(
title, description, code, tags,
content='snippets',
content_rowid='id'
)
""")
cursor.execute("""
CREATE TRIGGER IF NOT EXISTS snippets_ai AFTER INSERT ON snippets BEGIN
INSERT INTO snippets_fts(rowid, title, description, code, tags)
VALUES (new.id, new.title, new.description, new.code, new.tags);
END
""")
cursor.execute("""
CREATE TRIGGER IF NOT EXISTS snippets_ad AFTER DELETE ON snippets BEGIN
INSERT INTO snippets_fts(snippets_fts, rowid, title, description, code, tags)
VALUES ('delete', old.id, old.title, old.description, old.code, old.tags);
END
""")
cursor.execute("""
CREATE TRIGGER IF NOT EXISTS snippets_au AFTER UPDATE ON snippets BEGIN
INSERT INTO snippets_fts(snippets_fts, rowid, title, description, code, tags)
VALUES ('delete', old.id, old.title, old.description, old.code, old.tags);
INSERT INTO snippets_fts(rowid, title, description, code, tags)
VALUES (new.id, new.title, new.description, new.code, new.tags);
END
""")
def init_schema(self):
"""Alias for init_db for backwards compatibility."""
self.init_db()
def create_snippet(
self,
title: str,
code: str,
description: str = "",
language: str = "",
tags: list[str] | None = None,
is_encrypted: bool = False,
) -> int:
"""Alias for add_snippet for backwards compatibility."""
return self.add_snippet(title, code, description, language, tags, is_encrypted)
def add_snippet(
self,
title: str,
code: str,
description: str = "",
language: str = "",
tags: list[str] | None = None,
is_encrypted: bool = False,
) -> int:
tags = tags or []
now = datetime.utcnow().isoformat()
with self.get_connection() as conn:
cursor = conn.cursor()
cursor.execute(
"""
INSERT INTO snippets (title, description, code, language, tags, is_encrypted, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
""",
(title, description, code, language, json.dumps(tags), int(is_encrypted), now, now),
)
return cursor.lastrowid
def get_snippet(self, snippet_id: int) -> dict[str, Any] | None:
with self.get_connection() as conn:
cursor = conn.cursor()
cursor.execute("SELECT * FROM snippets WHERE id = ?", (snippet_id,))
row = cursor.fetchone()
if row:
return dict(row)
return None
def list_snippets(self, limit: int = 50, offset: int = 0, tag: str | None = None, language: str | None = None, collection_id: int | None = None) -> list[dict[str, Any]]:
with self.get_connection() as conn:
cursor = conn.cursor()
query = "SELECT * FROM snippets WHERE 1=1"
params = []
if tag:
query += " AND tags LIKE ?"
params.append(f'%"{tag}"%')
if language:
query += " AND language = ?"
params.append(language)
if collection_id:
query += " AND id IN (SELECT snippet_id FROM snippet_collections WHERE collection_id = ?)"
params.append(collection_id)
query += " ORDER BY updated_at DESC LIMIT ? OFFSET ?"
params.extend([limit, offset])
cursor.execute(query, params)
return [dict(row) for row in cursor.fetchall()]
def update_snippet(
self,
snippet_id: int,
title: str | None = None,
description: str | None = None,
code: str | None = None,
language: str | None = None,
tags: list[str] | None = None,
) -> bool:
snippet = self.get_snippet(snippet_id)
if not snippet:
return False
now = datetime.utcnow().isoformat()
with self.get_connection() as conn:
cursor = conn.cursor()
cursor.execute(
"""
UPDATE snippets SET
title = COALESCE(?, title),
description = COALESCE(?, description),
code = COALESCE(?, code),
language = COALESCE(?, language),
tags = COALESCE(?, tags),
updated_at = ?
WHERE id = ?
""",
(
title,
description,
code,
language,
json.dumps(tags) if tags is not None else None,
now,
snippet_id,
),
)
return True
def delete_snippet(self, snippet_id: int) -> bool:
with self.get_connection() as conn:
cursor = conn.cursor()
cursor.execute("DELETE FROM snippets WHERE id = ?", (snippet_id,))
return cursor.rowcount > 0
def search_snippets(
self,
query: str,
limit: int = 50,
language: str | None = None,
tag: str | None = None,
) -> list[dict[str, Any]]:
with self.get_connection() as conn:
cursor = conn.cursor()
if language:
fts_query = f"{query} AND language:{language}"
else:
fts_query = query
cursor.execute(
"""
SELECT s.*, bm25(snippets_fts) as rank
FROM snippets s
JOIN snippets_fts ON s.id = snippets_fts.rowid
WHERE snippets_fts MATCH ?
ORDER BY rank
LIMIT ?
""",
(fts_query, limit),
)
results = [dict(row) for row in cursor.fetchall()]
if tag:
results = [r for r in results if tag in json.loads(r.get("tags", "[]"))]
return results
def tag_add(self, snippet_id: int, tag: str) -> bool:
snippet = self.get_snippet(snippet_id)
if not snippet:
return False
tags = json.loads(snippet["tags"])
if tag not in tags:
tags.append(tag)
return self.update_snippet(snippet_id, tags=tags)
return True
def tag_remove(self, snippet_id: int, tag: str) -> bool:
snippet = self.get_snippet(snippet_id)
if not snippet:
return False
tags = json.loads(snippet["tags"])
if tag in tags:
tags.remove(tag)
return self.update_snippet(snippet_id, tags=tags)
return True
def list_tags(self) -> list[str]:
with self.get_connection() as conn:
cursor = conn.cursor()
cursor.execute("SELECT tags FROM snippets")
all_tags: set[str] = set()
for row in cursor.fetchall():
all_tags.update(json.loads(row["tags"]))
return sorted(all_tags)
def collection_create(self, name: str, description: str = "") -> int:
"""Alias for create_collection for backwards compatibility."""
return self.create_collection(name, description)
def create_collection(self, name: str, description: str = "") -> int:
now = datetime.utcnow().isoformat()
with self.get_connection() as conn:
cursor = conn.cursor()
cursor.execute(
"INSERT INTO collections (name, description, created_at) VALUES (?, ?, ?)",
(name, description, now),
)
return cursor.lastrowid
def collection_list(self) -> list[dict[str, Any]]:
"""Alias for list_collections for backwards compatibility."""
return self.list_collections()
def list_collections(self) -> list[dict[str, Any]]:
with self.get_connection() as conn:
cursor = conn.cursor()
cursor.execute("""
SELECT c.*, COUNT(sc.snippet_id) as snippet_count
FROM collections c
LEFT JOIN snippet_collections sc ON c.id = sc.collection_id
GROUP BY c.id
ORDER BY c.name
""")
return [dict(row) for row in cursor.fetchall()]
def get_collection(self, collection_id: int) -> dict[str, Any] | None:
with self.get_connection() as conn:
cursor = conn.cursor()
cursor.execute("SELECT * FROM collections WHERE id = ?", (collection_id,))
row = cursor.fetchone()
if row:
return dict(row)
return None
def collection_delete(self, collection_id: int) -> bool:
"""Alias for delete_collection for backwards compatibility."""
return self.delete_collection(collection_id)
def delete_collection(self, collection_id: int) -> bool:
with self.get_connection() as conn:
cursor = conn.cursor()
cursor.execute("DELETE FROM collections WHERE id = ?", (collection_id,))
return cursor.rowcount > 0
def collection_add_snippet(self, collection_id: int, snippet_id: int) -> bool:
"""Alias for add_snippet_to_collection for backwards compatibility."""
return self.add_snippet_to_collection(snippet_id, collection_id)
def add_snippet_to_collection(self, snippet_id: int, collection_id: int) -> bool:
snippet = self.get_snippet(snippet_id)
collection = self.get_collection(collection_id)
if not snippet or not collection:
return False
try:
with self.get_connection() as conn:
cursor = conn.cursor()
cursor.execute(
"INSERT OR IGNORE INTO snippet_collections (snippet_id, collection_id) VALUES (?, ?)",
(snippet_id, collection_id),
)
return True
except sqlite3.IntegrityError:
return True
def collection_remove_snippet(self, collection_id: int, snippet_id: int) -> bool:
"""Alias for remove_snippet_from_collection for backwards compatibility."""
return self.remove_snippet_from_collection(snippet_id, collection_id)
def remove_snippet_from_collection(self, snippet_id: int, collection_id: int) -> bool:
with self.get_connection() as conn:
cursor = conn.cursor()
cursor.execute(
"DELETE FROM snippet_collections WHERE snippet_id = ? AND collection_id = ?",
(snippet_id, collection_id),
)
return cursor.rowcount > 0
def get_collection_snippets(self, collection_id: int) -> list[dict[str, Any]]:
with self.get_connection() as conn:
cursor = conn.cursor()
cursor.execute(
"""
SELECT s.* FROM snippets s
JOIN snippet_collections sc ON s.id = sc.snippet_id
WHERE sc.collection_id = ?
ORDER BY s.updated_at DESC
""",
(collection_id,),
)
return [dict(row) for row in cursor.fetchall()]
def export_all(self) -> list[dict[str, Any]]:
return self.list_snippets(limit=10000)
def import_snippet(
self,
data: dict[str, Any],
strategy: str = "skip",
) -> int | None:
existing = None
if "title" in data:
with self.get_connection() as conn:
cursor = conn.cursor()
cursor.execute("SELECT id FROM snippets WHERE title = ?", (data["title"],))
existing = cursor.fetchone()
if existing:
if strategy == "skip":
return None
elif strategy == "replace":
self.update_snippet(
existing["id"],
title=data.get("title"),
description=data.get("description"),
code=data.get("code"),
language=data.get("language"),
tags=data.get("tags"),
)
return existing["id"]
return self.add_snippet(
title=data.get("title", "Untitled"),
code=data.get("code", ""),
description=data.get("description", ""),
language=data.get("language", ""),
tags=data.get("tags", []),
)
def add_peer(self, peer_id: str, host: str, port: int):
now = datetime.utcnow().isoformat()
with self.get_connection() as conn:
cursor = conn.cursor()
cursor.execute(
"INSERT OR REPLACE INTO sync_peers (peer_id, host, port, last_seen) VALUES (?, ?, ?, ?)",
(peer_id, host, port, now),
)
def list_peers(self) -> list[dict[str, Any]]:
with self.get_connection() as conn:
cursor = conn.cursor()
cursor.execute("SELECT * FROM sync_peers ORDER BY last_seen DESC")
return [dict(row) for row in cursor.fetchall()]
def list_sync_peers(self) -> list[dict[str, Any]]:
"""Alias for list_peers for backwards compatibility."""
return self.list_peers()
def update_sync_meta(self, peer_id: str, last_sync: str):
"""Update sync metadata for a peer."""
with self.get_connection() as conn:
cursor = conn.cursor()
cursor.execute(
"UPDATE sync_peers SET last_sync = ? WHERE peer_id = ?",
(last_sync, peer_id),
)