# schema2mock/snip/export/handlers.py
# (Original file-browser listing: 151 lines, 5.6 KiB, Python.)

import json
from datetime import datetime, timezone
from pathlib import Path
from typing import Any

from ..db import get_database


class ExportHandler:
SCHEMA_VERSION = "1.0"
def __init__(self, db_path: str | None = None):
self.db = get_database(db_path)
def export_all(self) -> dict[str, Any]:
snippets = self.db.get_all_snippets_for_sync()
collections = self.db.collection_list()
return {
"version": self.SCHEMA_VERSION,
"exported_at": datetime.utcnow().isoformat(),
"snippets": snippets,
"collections": collections,
}
def export_collection(self, collection_id: int) -> dict[str, Any]:
snippets = self.db.list_snippets(collection_id=collection_id, limit=1000)
self.db.get_cursor()
cursor = self.db.get_cursor().__enter__()
cursor.execute(
"SELECT * FROM collections WHERE id = ?", (collection_id,)
)
collection = dict(cursor.fetchone()) if cursor.fetchone() else None
return {
"version": self.SCHEMA_VERSION,
"exported_at": datetime.utcnow().isoformat(),
"collection": collection,
"snippets": snippets,
}
def export_snippet(self, snippet_id: int) -> dict[str, Any] | None:
snippet = self.db.get_snippet(snippet_id)
if not snippet:
return None
return {
"version": self.SCHEMA_VERSION,
"exported_at": datetime.utcnow().isoformat(),
"snippets": [snippet],
}
def write_export(self, filepath: str | Path, export_data: dict[str, Any], indent: int = 2):
with open(filepath, "w") as f:
json.dump(export_data, f, indent=indent)
def import_data(self, filepath: str | Path, strategy: str = "skip") -> dict[str, Any]:
with open(filepath) as f:
data = json.load(f)
if not self._validate_import_data(data):
raise ValueError("Invalid import file format")
snippets = data.get("snippets", [])
collections = data.get("collections", [])
results = {
"imported": 0,
"skipped": 0,
"replaced": 0,
"duplicates": 0,
"errors": [],
}
for snippet in snippets:
try:
result = self._import_snippet(snippet, strategy)
if result == "imported":
results["imported"] += 1
elif result == "skipped":
results["skipped"] += 1
elif result == "replaced":
results["replaced"] += 1
elif result == "duplicate":
results["duplicates"] += 1
except Exception as e:
results["errors"].append({"snippet": snippet.get("title", "unknown"), "error": str(e)})
for collection in collections:
try:
self._import_collection(collection)
except Exception:
pass
return results
def _import_snippet(self, snippet: dict[str, Any], strategy: str) -> str:
snippet_id = snippet.get("id")
existing = None
if snippet_id:
existing = self.db.get_snippet(snippet_id)
if existing and strategy == "skip":
return "skipped"
if existing and strategy == "replace":
self.db.upsert_snippet(snippet)
return "replaced"
if existing and strategy == "duplicate":
new_snippet = dict(snippet)
new_snippet["id"] = None
new_snippet["title"] = f"{snippet.get('title', 'Snippet')} (imported)"
self.db.create_snippet(
title=new_snippet["title"],
code=new_snippet["code"],
description=new_snippet.get("description", ""),
language=new_snippet.get("language", "text"),
tags=new_snippet.get("tags", []),
is_encrypted=bool(new_snippet.get("is_encrypted")),
)
return "duplicates"
self.db.upsert_snippet(snippet)
return "imported"
def _import_collection(self, collection: dict[str, Any]):
name = collection.get("name")
if not name:
return
try:
collection_id = self.db.collection_create(
name=name,
description=collection.get("description", ""),
)
snippet_ids = collection.get("snippet_ids", [])
for snippet_id in snippet_ids:
self.db.collection_add_snippet(collection_id, snippet_id)
except Exception:
pass
def _validate_import_data(self, data: dict[str, Any]) -> bool:
if not isinstance(data, dict):
return False
if "snippets" not in data and "collections" not in data:
return False
if "snippets" in data and not isinstance(data["snippets"], list):
return False
return True
def generate_import_summary(self, results: dict[str, Any]) -> str:
lines = ["Import Results:", f" Imported: {results['imported']}"]
if results["skipped"]:
lines.append(f" Skipped: {results['skipped']}")
if results["replaced"]:
lines.append(f" Replaced: {results['replaced']}")
if results["duplicates"]:
lines.append(f" Duplicates: {results['duplicates']}")
if results["errors"]:
lines.append(f" Errors: {len(results['errors'])}")
for err in results["errors"][:5]:
lines.append(f" - {err['snippet']}: {err['error']}")
return "\n".join(lines)