Fix sync protocol - sync_with_peer takes peer dict, add get_database support
This commit is contained in:
@@ -4,11 +4,12 @@ import http.server
|
|||||||
import json
|
import json
|
||||||
import socketserver
|
import socketserver
|
||||||
import threading
|
import threading
|
||||||
|
import urllib.request
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from urllib.request import urlopen
|
from urllib.request import urlopen
|
||||||
|
|
||||||
from snip.db.database import Database
|
from snip.db.database import Database, get_database
|
||||||
|
|
||||||
|
|
||||||
class SyncRequestHandler(http.server.BaseHTTPRequestHandler):
|
class SyncRequestHandler(http.server.BaseHTTPRequestHandler):
|
||||||
@@ -52,8 +53,11 @@ class SyncServer(socketserver.TCPServer):
|
|||||||
|
|
||||||
|
|
||||||
class SyncProtocol:
|
class SyncProtocol:
|
||||||
def __init__(self, db: Database, port: int = 8765):
|
def __init__(self, db: Database | str | None = None, port: int = 8765):
|
||||||
self.db = db
|
if isinstance(db, str) or db is None:
|
||||||
|
self.db = get_database(db)
|
||||||
|
else:
|
||||||
|
self.db = db
|
||||||
self.port = port
|
self.port = port
|
||||||
self.server = None
|
self.server = None
|
||||||
self.server_thread = None
|
self.server_thread = None
|
||||||
@@ -71,24 +75,42 @@ class SyncProtocol:
|
|||||||
self.server.shutdown()
|
self.server.shutdown()
|
||||||
self.server = None
|
self.server = None
|
||||||
|
|
||||||
def sync_with_peer(self, host: str, port: int) -> int:
|
def sync_with_peer(self, peer: dict[str, Any]) -> dict[str, Any]:
|
||||||
"""Sync snippets with a peer."""
|
"""Sync snippets with a peer."""
|
||||||
|
host = peer.get("host") or peer.get("addresses", [""])[0]
|
||||||
|
port = peer.get("port", 8765)
|
||||||
|
|
||||||
snippets = []
|
snippets = []
|
||||||
synced = 0
|
merged = 0
|
||||||
|
pushed = 0
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with urlopen(f"http://{host}:{port}/snippets", timeout=30) as response:
|
with urlopen(f"http://{host}:{port}/snippets", timeout=30) as response:
|
||||||
snippets = json.loads(response.read())
|
snippets = json.loads(response.read())
|
||||||
except Exception:
|
except Exception as e:
|
||||||
pass
|
return {"status": "error", "message": str(e)}
|
||||||
|
|
||||||
for snippet in snippets:
|
for snippet in snippets:
|
||||||
if "id" in snippet:
|
if "id" in snippet:
|
||||||
del snippet["id"]
|
del snippet["id"]
|
||||||
self.db.import_snippet(snippet, strategy="skip")
|
result = self.db.import_snippet(snippet, strategy="skip")
|
||||||
synced += 1
|
if result is not None:
|
||||||
|
merged += 1
|
||||||
|
|
||||||
return synced
|
local_snippets = self.db.export_all()
|
||||||
|
try:
|
||||||
|
req = urllib.request.Request(
|
||||||
|
f"http://{host}:{port}/snippets",
|
||||||
|
data=json.dumps(local_snippets).encode(),
|
||||||
|
headers={"Content-Type": "application/json"},
|
||||||
|
)
|
||||||
|
with urlopen(req, timeout=30) as response:
|
||||||
|
if response.status == 200:
|
||||||
|
pushed = len(local_snippets)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return {"status": "success", "merged": merged, "pushed": pushed}
|
||||||
|
|
||||||
def push_to_peer(self, host: str, port: int) -> int:
|
def push_to_peer(self, host: str, port: int) -> int:
|
||||||
"""Push local snippets to a peer."""
|
"""Push local snippets to a peer."""
|
||||||
@@ -107,4 +129,4 @@ class SyncProtocol:
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
return pushed
|
return pushed
|
||||||
Reference in New Issue
Block a user