import json
from datetime import datetime, timezone
from pathlib import Path
from typing import Any

from ..db import get_database


class ExportHandler:
    """Export snippets/collections to JSON and import them back into the database.

    Every export shares one envelope: a ``version`` key holding
    :attr:`SCHEMA_VERSION`, an ``exported_at`` UTC timestamp, then the payload.
    """

    SCHEMA_VERSION = "1.0"

    # Conflict strategies accepted by import_data(); applied in _import_snippet().
    IMPORT_STRATEGIES = ("skip", "replace", "duplicate")

    # Maps an _import_snippet() outcome to the counter it increments in the results.
    _RESULT_COUNTERS = {
        "imported": "imported",
        "skipped": "skipped",
        "replaced": "replaced",
        "duplicate": "duplicates",
    }

    def __init__(self, db_path: str | None = None):
        """Open the database handle via ``get_database(db_path)``.

        Args:
            db_path: Forwarded unchanged to ``get_database``; how ``None`` is
                resolved is decided by ``..db`` (not visible here).
        """
        self.db = get_database(db_path)

    def _envelope(self, **payload: Any) -> dict[str, Any]:
        """Wrap *payload* with the schema version and an export timestamp.

        Keyword order of *payload* is preserved in the returned dict.
        """
        # datetime.utcnow() is deprecated since Python 3.12 and yields a naive
        # value; an aware datetime makes the UTC offset explicit in the string.
        return {
            "version": self.SCHEMA_VERSION,
            "exported_at": datetime.now(timezone.utc).isoformat(),
            **payload,
        }

    def export_all(self) -> dict[str, Any]:
        """Return every snippet and collection wrapped in the export envelope."""
        snippets = self.db.get_all_snippets_for_sync()
        collections = self.db.collection_list()
        return self._envelope(snippets=snippets, collections=collections)

    def export_collection(self, collection_id: int, limit: int = 1000) -> dict[str, Any]:
        """Return one collection row and its snippets wrapped in the export envelope.

        Args:
            collection_id: Primary key of the collection to export.
            limit: Maximum number of snippets fetched from the collection
                (forwarded to ``list_snippets``). Defaults to the historical 1000.

        Returns:
            Envelope with ``collection`` (the row as a dict, or ``None`` when no
            row matches) and ``snippets``.
        """
        snippets = self.db.list_snippets(collection_id=collection_id, limit=limit)
        # Use the cursor as a real context manager so it is always released,
        # and fetch exactly once: fetching twice skipped the matching row.
        with self.db.get_cursor() as cursor:
            cursor.execute("SELECT * FROM collections WHERE id = ?", (collection_id,))
            row = cursor.fetchone()
            collection = dict(row) if row else None
        return self._envelope(collection=collection, snippets=snippets)

    def export_snippet(self, snippet_id: int) -> dict[str, Any] | None:
        """Return a single snippet wrapped in the export envelope, or ``None`` if absent."""
        snippet = self.db.get_snippet(snippet_id)
        if not snippet:
            return None
        return self._envelope(snippets=[snippet])

    def write_export(self, filepath: str | Path, export_data: dict[str, Any], indent: int = 2):
        """Serialize *export_data* as JSON to *filepath* (overwriting it)."""
        # Explicit encoding: the platform default is not guaranteed to be UTF-8.
        with open(filepath, "w", encoding="utf-8") as f:
            json.dump(export_data, f, indent=indent)

    def import_data(self, filepath: str | Path, strategy: str = "skip") -> dict[str, Any]:
        """Import snippets and collections from a JSON export file.

        Args:
            filepath: Path of the JSON file to read.
            strategy: How to treat a snippet whose ``id`` already exists:
                ``"skip"`` leaves it, ``"replace"`` upserts over it,
                ``"duplicate"`` creates a new copy titled ``"<title> (imported)"``.

        Returns:
            Counters ``imported``/``skipped``/``replaced``/``duplicates``, a list
            ``errors`` of ``{"snippet": title, "error": message}`` for snippets
            that failed, and ``collection_errors`` of
            ``{"collection": name, "error": message}`` for collections that failed.

        Raises:
            ValueError: Unknown *strategy*, malformed JSON, or a file that does
                not look like an export.
        """
        if strategy not in self.IMPORT_STRATEGIES:
            raise ValueError(
                f"Unknown import strategy {strategy!r}; "
                f"expected one of: {', '.join(self.IMPORT_STRATEGIES)}"
            )

        with open(filepath, encoding="utf-8") as f:
            data = json.load(f)
        if not self._validate_import_data(data):
            raise ValueError("Invalid import file format")

        results: dict[str, Any] = {
            "imported": 0,
            "skipped": 0,
            "replaced": 0,
            "duplicates": 0,
            "errors": [],
            "collection_errors": [],
        }

        # Per-item boundary: one bad snippet must not abort the whole import,
        # so failures are recorded rather than raised.
        for snippet in data.get("snippets", []):
            try:
                outcome = self._import_snippet(snippet, strategy)
            except Exception as e:
                title = snippet.get("title", "unknown") if isinstance(snippet, dict) else "unknown"
                results["errors"].append({"snippet": title, "error": str(e)})
            else:
                results[self._RESULT_COUNTERS[outcome]] += 1

        # Collections are imported after snippets so their snippet_ids can resolve.
        for collection in data.get("collections", []):
            try:
                self._import_collection(collection)
            except Exception as e:
                name = collection.get("name", "unknown") if isinstance(collection, dict) else "unknown"
                results["collection_errors"].append({"collection": name, "error": str(e)})

        return results

    def _import_snippet(self, snippet: dict[str, Any], strategy: str) -> str:
        """Write one snippet according to *strategy* and report what happened.

        Returns:
            ``"skipped"``, ``"replaced"``, ``"duplicate"`` or ``"imported"``.
        """
        snippet_id = snippet.get("id")
        existing = self.db.get_snippet(snippet_id) if snippet_id else None

        if existing:
            if strategy == "skip":
                return "skipped"
            if strategy == "replace":
                self.db.upsert_snippet(snippet)
                return "replaced"
            if strategy == "duplicate":
                # A fresh row is created, so the incoming id is not reused.
                self.db.create_snippet(
                    title=f"{snippet.get('title', 'Snippet')} (imported)",
                    code=snippet["code"],
                    description=snippet.get("description", ""),
                    language=snippet.get("language", "text"),
                    tags=snippet.get("tags", []),
                    is_encrypted=bool(snippet.get("is_encrypted")),
                )
                # Must match the key in _RESULT_COUNTERS ("duplicates" was never counted).
                return "duplicate"

        self.db.upsert_snippet(snippet)
        return "imported"

    def _import_collection(self, collection: dict[str, Any]) -> None:
        """Create a collection and attach its ``snippet_ids``.

        Collections without a name are ignored. Database errors propagate to the
        caller, which records them instead of discarding them silently.
        """
        name = collection.get("name")
        if not name:
            return
        collection_id = self.db.collection_create(
            name=name,
            description=collection.get("description", ""),
        )
        for snippet_id in collection.get("snippet_ids", []):
            self.db.collection_add_snippet(collection_id, snippet_id)

    def _validate_import_data(self, data: Any) -> bool:
        """Return True if *data* has the top-level shape of an export.

        It must be a dict with ``snippets`` and/or ``collections``, and each
        of those keys that is present must map to a list.
        """
        if not isinstance(data, dict):
            return False
        if "snippets" not in data and "collections" not in data:
            return False
        return all(
            isinstance(data[key], list)
            for key in ("snippets", "collections")
            if key in data
        )

    def generate_import_summary(self, results: dict[str, Any], max_errors: int = 5) -> str:
        """Render *results* from import_data() as human-readable text.

        Args:
            results: Dict returned by import_data().
            max_errors: How many individual error lines to list per error kind.
        """
        lines = ["Import Results:", f" Imported: {results['imported']}"]
        for key, label in (("skipped", "Skipped"), ("replaced", "Replaced"), ("duplicates", "Duplicates")):
            if results[key]:
                lines.append(f" {label}: {results[key]}")

        errors = results["errors"]
        if errors:
            lines.append(f" Errors: {len(errors)}")
            lines.extend(f" - {err['snippet']}: {err['error']}" for err in errors[:max_errors])
            if len(errors) > max_errors:
                lines.append(f" ... and {len(errors) - max_errors} more")

        # .get(): results produced before collection_errors existed lack the key.
        collection_errors = results.get("collection_errors", [])
        if collection_errors:
            lines.append(f" Collection errors: {len(collection_errors)}")
            lines.extend(
                f" - {err['collection']}: {err['error']}" for err in collection_errors[:max_errors]
            )
            if len(collection_errors) > max_errors:
                lines.append(f" ... and {len(collection_errors) - max_errors} more")

        return "\n".join(lines)