"""Tests for import/export functionality.""" import json import os import tempfile import pytest from snip.db.database import Database from snip.export.handlers import export_snippets, import_snippets @pytest.fixture def db(): with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f: db_path = f.name database = Database(db_path) database.init_db() yield database os.unlink(db_path) @pytest.fixture def export_file(): with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as f: file_path = f.name yield file_path if os.path.exists(file_path): os.unlink(file_path) def test_export_all(db, export_file): """Test exporting all snippets.""" db.add_snippet(title="Test 1", code="code1", tags=["test"]) db.add_snippet(title="Test 2", code="code2", tags=["test"]) snippets = db.export_all() export_snippets(snippets, export_file) with open(export_file, "r") as f: data = json.load(f) assert data["version"] == "1.0" assert "exported_at" in data assert len(data["snippets"]) == 2 def test_import_skip_strategy(db, export_file): """Test import with skip strategy.""" db.add_snippet(title="Existing", code="existing_code") snippets = [{"title": "Existing", "code": "new_code"}, {"title": "New", "code": "new_code"}] export_snippets(snippets, export_file) imported, skipped = import_snippets(db, export_file, strategy="skip") assert imported == 1 assert skipped == 1 def test_import_replace_strategy(db, export_file): """Test import with replace strategy.""" snippet_id = db.add_snippet(title="Existing", code="old_code") snippets = [{"title": "Existing", "code": "new_code"}] export_snippets(snippets, export_file) imported, skipped = import_snippets(db, export_file, strategy="replace") assert imported == 1 updated = db.get_snippet(snippet_id) assert updated["code"] == "new_code" def test_import_duplicate_strategy(db, export_file): """Test import with duplicate strategy.""" db.add_snippet(title="Existing", code="existing") snippets = [{"title": "Existing", "code": "existing"}] export_snippets(snippets, export_file) imported, skipped = import_snippets(db, export_file, strategy="duplicate") assert imported == 1 all_snippets = db.list_snippets(limit=100) assert len(all_snippets) == 2