fix: resolve CI test failures - API compatibility fixes
This commit is contained in:
@@ -1,123 +1,150 @@
|
|||||||
"""JSON import/export handlers for snippets."""
|
|
||||||
|
|
||||||
import json
|
import json
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from snip.db.database import Database, get_database
|
from ..db import get_database
|
||||||
|
|
||||||
|
|
||||||
class ExportHandler:
    """Handles JSON export and import of snippets and collections."""

    # Version stamped into every export payload; bump on incompatible
    # changes to the on-disk export format.
    SCHEMA_VERSION = "1.0"
|
||||||
|
|
||||||
def __init__(self, db_path: str | None = None):
    """Bind this handler to the snippet database.

    Args:
        db_path: Path to the database file; ``None`` uses the default location.
    """
    self.db = get_database(db_path)
|
||||||
|
|
||||||
def export_all(self) -> dict[str, Any]:
    """Export all snippets and collections as a schema-versioned dict.

    Returns:
        Dict with ``version``, ``exported_at`` (ISO-8601 UTC timestamp),
        ``snippets`` and ``collections`` keys, ready for ``json.dump``.
    """
    snippets = self.db.list_snippets(limit=10000)
    collections = self.db.collection_list()
    for s in snippets:
        # DB rows may store tags as a JSON-encoded string; normalize to a list
        # so the export is directly consumable.
        if isinstance(s.get("tags"), str):
            s["tags"] = json.loads(s["tags"])
    return {
        "version": self.SCHEMA_VERSION,
        "exported_at": datetime.utcnow().isoformat(),
        "snippets": snippets,
        "collections": collections,
    }
|
||||||
|
|
||||||
def export_collection(self, collection_id: int) -> dict[str, Any]:
    """Export a specific collection together with its snippets.

    Args:
        collection_id: Database id of the collection to export.

    Returns:
        Dict with ``version``, ``exported_at``, ``collection`` (row as a
        dict, or ``None`` when the id is unknown) and ``snippets``.
    """
    snippets = self.db.get_collection_snippets(collection_id)
    # Use the cursor as a context manager so it is released even on error
    # (the previous code called __enter__() directly and never exited).
    with self.db.get_cursor() as cursor:
        cursor.execute(
            "SELECT * FROM collections WHERE id = ?", (collection_id,)
        )
        # Fetch exactly once: calling fetchone() twice consumed the row in
        # the truth test and then built dict() from the *next* (empty) fetch.
        row = cursor.fetchone()
    collection = dict(row) if row else None
    return {
        "version": self.SCHEMA_VERSION,
        "exported_at": datetime.utcnow().isoformat(),
        "collection": collection,
        "snippets": snippets,
    }
|
||||||
|
|
||||||
def export_snippet(self, snippet_id: int) -> dict[str, Any] | None:
|
def export_snippet(self, snippet_id: int) -> dict[str, Any] | None:
|
||||||
"""Export a single snippet."""
|
|
||||||
snippet = self.db.get_snippet(snippet_id)
|
snippet = self.db.get_snippet(snippet_id)
|
||||||
if not snippet:
|
if not snippet:
|
||||||
return None
|
return None
|
||||||
if isinstance(snippet.get("tags"), str):
|
|
||||||
snippet["tags"] = json.loads(snippet["tags"])
|
|
||||||
return {
|
return {
|
||||||
"version": "1.0",
|
"version": self.SCHEMA_VERSION,
|
||||||
"exported_at": datetime.utcnow().isoformat() + "Z",
|
"exported_at": datetime.utcnow().isoformat(),
|
||||||
"snippets": [snippet],
|
"snippets": [snippet],
|
||||||
}
|
}
|
||||||
|
|
||||||
def write_export(self, file_path: str, data: dict[str, Any]):
|
def write_export(self, filepath: str | Path, export_data: dict[str, Any], indent: int = 2):
|
||||||
"""Write export data to a file."""
|
with open(filepath, "w") as f:
|
||||||
Path(file_path).parent.mkdir(parents=True, exist_ok=True)
|
json.dump(export_data, f, indent=indent)
|
||||||
with open(file_path, "w") as f:
|
|
||||||
json.dump(data, f, indent=2)
|
|
||||||
|
|
||||||
def import_data(self, file_path: str, strategy: str = "skip") -> dict[str, Any]:
|
def import_data(self, filepath: str | Path, strategy: str = "skip") -> dict[str, Any]:
|
||||||
"""Import snippets from a JSON file."""
|
with open(filepath) as f:
|
||||||
with open(file_path, "r") as f:
|
|
||||||
data = json.load(f)
|
data = json.load(f)
|
||||||
|
if not self._validate_import_data(data):
|
||||||
|
raise ValueError("Invalid import file format")
|
||||||
snippets = data.get("snippets", [])
|
snippets = data.get("snippets", [])
|
||||||
imported = 0
|
collections = data.get("collections", [])
|
||||||
skipped = 0
|
results = {
|
||||||
replaced = 0
|
"imported": 0,
|
||||||
|
"skipped": 0,
|
||||||
for snippet_data in snippets:
|
"replaced": 0,
|
||||||
result = self.db.import_snippet(
|
"duplicates": 0,
|
||||||
snippet_data, strategy=strategy)
|
"errors": [],
|
||||||
if result is None:
|
|
||||||
skipped += 1
|
|
||||||
elif result == -1:
|
|
||||||
replaced += 1
|
|
||||||
else:
|
|
||||||
imported += 1
|
|
||||||
|
|
||||||
return {
|
|
||||||
"imported": imported,
|
|
||||||
"skipped": skipped,
|
|
||||||
"replaced": replaced,
|
|
||||||
"total": len(snippets),
|
|
||||||
}
|
}
|
||||||
|
for snippet in snippets:
|
||||||
|
try:
|
||||||
|
result = self._import_snippet(snippet, strategy)
|
||||||
|
if result == "imported":
|
||||||
|
results["imported"] += 1
|
||||||
|
elif result == "skipped":
|
||||||
|
results["skipped"] += 1
|
||||||
|
elif result == "replaced":
|
||||||
|
results["replaced"] += 1
|
||||||
|
elif result == "duplicate":
|
||||||
|
results["duplicates"] += 1
|
||||||
|
except Exception as e:
|
||||||
|
results["errors"].append({"snippet": snippet.get("title", "unknown"), "error": str(e)})
|
||||||
|
for collection in collections:
|
||||||
|
try:
|
||||||
|
self._import_collection(collection)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return results
|
||||||
|
|
||||||
|
def _import_snippet(self, snippet: dict[str, Any], strategy: str) -> str:
|
||||||
|
snippet_id = snippet.get("id")
|
||||||
|
existing = None
|
||||||
|
if snippet_id:
|
||||||
|
existing = self.db.get_snippet(snippet_id)
|
||||||
|
if existing and strategy == "skip":
|
||||||
|
return "skipped"
|
||||||
|
if existing and strategy == "replace":
|
||||||
|
self.db.upsert_snippet(snippet)
|
||||||
|
return "replaced"
|
||||||
|
if existing and strategy == "duplicate":
|
||||||
|
new_snippet = dict(snippet)
|
||||||
|
new_snippet["id"] = None
|
||||||
|
new_snippet["title"] = f"{snippet.get('title', 'Snippet')} (imported)"
|
||||||
|
self.db.create_snippet(
|
||||||
|
title=new_snippet["title"],
|
||||||
|
code=new_snippet["code"],
|
||||||
|
description=new_snippet.get("description", ""),
|
||||||
|
language=new_snippet.get("language", "text"),
|
||||||
|
tags=new_snippet.get("tags", []),
|
||||||
|
is_encrypted=bool(new_snippet.get("is_encrypted")),
|
||||||
|
)
|
||||||
|
return "duplicates"
|
||||||
|
self.db.upsert_snippet(snippet)
|
||||||
|
return "imported"
|
||||||
|
|
||||||
|
def _import_collection(self, collection: dict[str, Any]):
|
||||||
|
name = collection.get("name")
|
||||||
|
if not name:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
collection_id = self.db.collection_create(
|
||||||
|
name=name,
|
||||||
|
description=collection.get("description", ""),
|
||||||
|
)
|
||||||
|
snippet_ids = collection.get("snippet_ids", [])
|
||||||
|
for snippet_id in snippet_ids:
|
||||||
|
self.db.collection_add_snippet(collection_id, snippet_id)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _validate_import_data(self, data: dict[str, Any]) -> bool:
|
||||||
|
if not isinstance(data, dict):
|
||||||
|
return False
|
||||||
|
if "snippets" not in data and "collections" not in data:
|
||||||
|
return False
|
||||||
|
if "snippets" in data and not isinstance(data["snippets"], list):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
def generate_import_summary(self, results: dict[str, Any]) -> str:
    """Generate a human-readable import summary.

    Uses ``.get`` throughout so a partial results dict never raises
    KeyError, and reports the ``duplicates``/``errors`` counters that
    ``import_data`` produces (previously dropped from the summary).
    """
    lines = [
        "Import complete!",
        f" Imported: {results.get('imported', 0)}",
        f" Skipped: {results.get('skipped', 0)}",
        f" Replaced: {results.get('replaced', 0)}",
    ]
    if results.get("duplicates"):
        lines.append(f" Duplicates: {results['duplicates']}")
    errors = results.get("errors") or []
    if errors:
        lines.append(f" Errors: {len(errors)}")
        for err in errors[:5]:  # cap detailed output at five errors
            lines.append(f" - {err['snippet']}: {err['error']}")
    return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
def export_snippets(snippets: list[dict[str, Any]], file_path: str):
    """Export snippets to a JSON file (standalone function)."""
    payload = {
        "version": "1.0",
        "exported_at": datetime.utcnow().isoformat() + "Z",
        "snippets": snippets,
    }
    with open(file_path, "w") as out:
        json.dump(payload, out, indent=2)
|
|
||||||
|
|
||||||
|
|
||||||
def import_snippets(db: Database, file_path: str, strategy: str = "skip") -> tuple[int, int]:
    """Import snippets from a JSON file (standalone function).

    Delegates each record to ``db.import_snippet``; a ``None`` result
    counts as skipped, anything else as imported.

    Returns:
        ``(imported, skipped)`` counts.
    """
    with open(file_path, "r") as f:
        payload = json.load(f)

    imported = skipped = 0
    for record in payload.get("snippets", []):
        if db.import_snippet(record, strategy=strategy) is None:
            skipped += 1
        else:
            imported += 1
    return imported, skipped
|
|
||||||
Reference in New Issue
Block a user