Fix sync protocol - sync_with_peer takes peer dict, add get_database support
This commit is contained in:
@@ -4,11 +4,12 @@ import http.server
|
||||
import json
|
||||
import socketserver
|
||||
import threading
|
||||
import urllib.request
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from urllib.request import urlopen
|
||||
|
||||
from snip.db.database import Database
|
||||
from snip.db.database import Database, get_database
|
||||
|
||||
|
||||
class SyncRequestHandler(http.server.BaseHTTPRequestHandler):
|
||||
@@ -52,8 +53,11 @@ class SyncServer(socketserver.TCPServer):
|
||||
|
||||
|
||||
class SyncProtocol:
|
||||
def __init__(self, db: Database, port: int = 8765):
|
||||
self.db = db
|
||||
def __init__(self, db: Database | str | None = None, port: int = 8765):
|
||||
if isinstance(db, str) or db is None:
|
||||
self.db = get_database(db)
|
||||
else:
|
||||
self.db = db
|
||||
self.port = port
|
||||
self.server = None
|
||||
self.server_thread = None
|
||||
@@ -71,24 +75,42 @@ class SyncProtocol:
|
||||
self.server.shutdown()
|
||||
self.server = None
|
||||
|
||||
def sync_with_peer(self, host: str, port: int) -> int:
|
||||
def sync_with_peer(self, peer: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Sync snippets with a peer."""
|
||||
host = peer.get("host") or peer.get("addresses", [""])[0]
|
||||
port = peer.get("port", 8765)
|
||||
|
||||
snippets = []
|
||||
synced = 0
|
||||
merged = 0
|
||||
pushed = 0
|
||||
|
||||
try:
|
||||
with urlopen(f"http://{host}:{port}/snippets", timeout=30) as response:
|
||||
snippets = json.loads(response.read())
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
return {"status": "error", "message": str(e)}
|
||||
|
||||
for snippet in snippets:
|
||||
if "id" in snippet:
|
||||
del snippet["id"]
|
||||
self.db.import_snippet(snippet, strategy="skip")
|
||||
synced += 1
|
||||
result = self.db.import_snippet(snippet, strategy="skip")
|
||||
if result is not None:
|
||||
merged += 1
|
||||
|
||||
return synced
|
||||
local_snippets = self.db.export_all()
|
||||
try:
|
||||
req = urllib.request.Request(
|
||||
f"http://{host}:{port}/snippets",
|
||||
data=json.dumps(local_snippets).encode(),
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
with urlopen(req, timeout=30) as response:
|
||||
if response.status == 200:
|
||||
pushed = len(local_snippets)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return {"status": "success", "merged": merged, "pushed": pushed}
|
||||
|
||||
def push_to_peer(self, host: str, port: int) -> int:
|
||||
"""Push local snippets to a peer."""
|
||||
|
||||
Reference in New Issue
Block a user