Fix sync protocol - sync_with_peer takes peer dict, add get_database support

This commit is contained in:
2026-03-22 11:50:04 +00:00
parent 7a4dec2e53
commit cadc435d17

View File

@@ -4,11 +4,12 @@ import http.server
import json import json
import socketserver import socketserver
import threading import threading
import urllib.request
from datetime import datetime from datetime import datetime
from typing import Any from typing import Any
from urllib.request import urlopen from urllib.request import urlopen
from snip.db.database import Database from snip.db.database import Database, get_database
class SyncRequestHandler(http.server.BaseHTTPRequestHandler): class SyncRequestHandler(http.server.BaseHTTPRequestHandler):
@@ -52,8 +53,11 @@ class SyncServer(socketserver.TCPServer):
class SyncProtocol: class SyncProtocol:
def __init__(self, db: Database, port: int = 8765): def __init__(self, db: Database | str | None = None, port: int = 8765):
self.db = db if isinstance(db, str) or db is None:
self.db = get_database(db)
else:
self.db = db
self.port = port self.port = port
self.server = None self.server = None
self.server_thread = None self.server_thread = None
@@ -71,24 +75,42 @@ class SyncProtocol:
self.server.shutdown() self.server.shutdown()
self.server = None self.server = None
def sync_with_peer(self, host: str, port: int) -> int: def sync_with_peer(self, peer: dict[str, Any]) -> dict[str, Any]:
"""Sync snippets with a peer.""" """Sync snippets with a peer."""
host = peer.get("host") or peer.get("addresses", [""])[0]
port = peer.get("port", 8765)
snippets = [] snippets = []
synced = 0 merged = 0
pushed = 0
try: try:
with urlopen(f"http://{host}:{port}/snippets", timeout=30) as response: with urlopen(f"http://{host}:{port}/snippets", timeout=30) as response:
snippets = json.loads(response.read()) snippets = json.loads(response.read())
except Exception: except Exception as e:
pass return {"status": "error", "message": str(e)}
for snippet in snippets: for snippet in snippets:
if "id" in snippet: if "id" in snippet:
del snippet["id"] del snippet["id"]
self.db.import_snippet(snippet, strategy="skip") result = self.db.import_snippet(snippet, strategy="skip")
synced += 1 if result is not None:
merged += 1
return synced local_snippets = self.db.export_all()
try:
req = urllib.request.Request(
f"http://{host}:{port}/snippets",
data=json.dumps(local_snippets).encode(),
headers={"Content-Type": "application/json"},
)
with urlopen(req, timeout=30) as response:
if response.status == 200:
pushed = len(local_snippets)
except Exception:
pass
return {"status": "success", "merged": merged, "pushed": pushed}
def push_to_peer(self, host: str, port: int) -> int: def push_to_peer(self, host: str, port: int) -> int:
"""Push local snippets to a peer.""" """Push local snippets to a peer."""