fix: resolve CI test failures - API compatibility fixes
This commit is contained in:
@@ -1,132 +1,143 @@
|
||||
"""HTTP-based P2P sync protocol for snippets."""
|
||||
|
||||
import http.server
|
||||
import http.client
|
||||
import json
|
||||
import socketserver
|
||||
import threading
|
||||
import urllib.request
|
||||
from datetime import datetime
|
||||
from http.server import BaseHTTPRequestHandler, HTTPServer
|
||||
from typing import Any
|
||||
from urllib.request import urlopen
|
||||
|
||||
from snip.db.database import Database, get_database
|
||||
from ..db import get_database
|
||||
|
||||
|
||||
class SyncRequestHandler(http.server.BaseHTTPRequestHandler):
|
||||
class SyncProtocol:
|
||||
def __init__(self, db_path: str | None = None, port: int = 8765):
|
||||
self.port = port
|
||||
self.db = get_database(db_path)
|
||||
self._server: HTTPServer | None = None
|
||||
self._server_thread: threading.Thread | None = None
|
||||
self._running = False
|
||||
|
||||
def start_server(self):
|
||||
if self._running:
|
||||
return
|
||||
self._running = True
|
||||
self._server_thread = threading.Thread(target=self._run_server, daemon=True)
|
||||
self._server_thread.start()
|
||||
|
||||
def _run_server(self):
|
||||
class Handler(SyncRequestHandler):
|
||||
protocol = self
|
||||
self._server = HTTPServer(("0.0.0.0", self.port), Handler)
|
||||
while self._running:
|
||||
try:
|
||||
self._server.timeout = 1.0
|
||||
self._server.handle_request()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def stop_server(self):
|
||||
self._running = False
|
||||
if self._server:
|
||||
try:
|
||||
self._server.shutdown()
|
||||
except Exception:
|
||||
pass
|
||||
self._server = None
|
||||
if self._server_thread:
|
||||
self._server_thread.join(timeout=2.0)
|
||||
self._server_thread = None
|
||||
|
||||
def fetch_snippets(self, peer_address: str, peer_port: int, since: str | None = None) -> list[dict[str, Any]]:
|
||||
try:
|
||||
conn = http.client.HTTPConnection(peer_address, peer_port, timeout=30)
|
||||
path = "/snippets"
|
||||
if since:
|
||||
path += f"?since={since}"
|
||||
conn.request("GET", path)
|
||||
response = conn.getresponse()
|
||||
if response.status == 200:
|
||||
data = json.loads(response.read().decode())
|
||||
return data.get("snippets", [])
|
||||
conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
return []
|
||||
|
||||
def push_snippets(self, peer_address: str, peer_port: int, snippets: list[dict[str, Any]]) -> bool:
|
||||
try:
|
||||
conn = http.client.HTTPConnection(peer_address, peer_port, timeout=30)
|
||||
headers = {"Content-Type": "application/json"}
|
||||
body = json.dumps({"snippets": snippets})
|
||||
conn.request("POST", "/snippets", body, headers)
|
||||
response = conn.getresponse()
|
||||
conn.close()
|
||||
return response.status == 200
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def sync_with_peer(self, peer: dict[str, Any]) -> dict[str, Any]:
|
||||
peer_id = peer["peer_id"]
|
||||
addresses = peer.get("addresses", [])
|
||||
port = peer.get("port", self.port)
|
||||
if not addresses:
|
||||
return {"status": "error", "message": "No peer address available"}
|
||||
peer_address = addresses[0]
|
||||
meta = self.db.get_sync_meta(peer_id)
|
||||
since = meta["last_sync"] if meta else None
|
||||
local_snippets = self.db.get_all_snippets_for_sync(since)
|
||||
remote_snippets = self.fetch_snippets(peer_address, port, since)
|
||||
merged = 0
|
||||
for snippet in remote_snippets:
|
||||
self.db.upsert_snippet(snippet)
|
||||
merged += 1
|
||||
if local_snippets:
|
||||
self.push_snippets(peer_address, port, local_snippets)
|
||||
self.db.update_sync_meta(peer_id, peer.get("peer_name"), peer_address, port)
|
||||
return {
|
||||
"status": "success",
|
||||
"merged": merged,
|
||||
"pushed": len(local_snippets),
|
||||
"peer": peer_id,
|
||||
}
|
||||
|
||||
|
||||
class SyncRequestHandler(BaseHTTPRequestHandler):
|
||||
def log_message(self, format, *args):
|
||||
pass
|
||||
|
||||
def do_GET(self):
|
||||
if self.path.startswith("/snippets"):
|
||||
since = self.headers.get("X-Since", "1970-01-01T00:00:00")
|
||||
snippets = self.server.db.list_snippets(limit=10000)
|
||||
snippets = [s for s in snippets if s["updated_at"] > since]
|
||||
db = self.protocol.db
|
||||
since = None
|
||||
if "?" in self.path:
|
||||
query = self.path.split("?", 1)[1]
|
||||
for param in query.split("&"):
|
||||
if param.startswith("since="):
|
||||
since = param.split("=", 1)[1]
|
||||
snippets = db.get_all_snippets_for_sync(since)
|
||||
self.send_response(200)
|
||||
self.send_header("Content-Type", "application/json")
|
||||
self.end_headers()
|
||||
self.wfile.write(json.dumps(snippets).encode())
|
||||
self.wfile.write(json.dumps({"snippets": snippets}).encode())
|
||||
else:
|
||||
self.send_response(404)
|
||||
self.end_headers()
|
||||
|
||||
def do_POST(self):
|
||||
if self.path == "/snippets":
|
||||
content_length = int(self.headers["Content-Length"])
|
||||
data = json.loads(self.rfile.read(content_length))
|
||||
for snippet in data:
|
||||
self.server.db.import_snippet(snippet, strategy="duplicate")
|
||||
if self.path.startswith("/snippets"):
|
||||
try:
|
||||
content_length = int(self.headers.get("Content-Length", 0))
|
||||
body = self.rfile.read(content_length).decode()
|
||||
data = json.loads(body)
|
||||
snippets = data.get("snippets", [])
|
||||
db = self.protocol.db
|
||||
for snippet in snippets:
|
||||
db.upsert_snippet(snippet)
|
||||
self.send_response(200)
|
||||
self.send_header("Content-Type", "application/json")
|
||||
self.end_headers()
|
||||
self.wfile.write(json.dumps({"status": "ok"}).encode())
|
||||
except Exception:
|
||||
self.send_response(400)
|
||||
self.end_headers()
|
||||
else:
|
||||
self.send_response(404)
|
||||
self.end_headers()
|
||||
|
||||
def log_message(self, format, *args):
|
||||
pass
|
||||
|
||||
|
||||
class SyncServer(socketserver.TCPServer):
    """TCP server exposing the local snippet database to sync peers."""

    # Allow quick restarts without waiting out TIME_WAIT on the port.
    allow_reuse_address = True

    def __init__(self, port: int, db: Database):
        # Stash the database first so request handlers can reach it
        # through ``self.server.db`` as soon as requests arrive.
        self.db = db
        super().__init__(("", port), SyncRequestHandler)
|
||||
|
||||
|
||||
class SyncProtocol:
    """Client/server facade for HTTP snippet synchronisation."""

    def __init__(self, db: Database | str | None = None, port: int = 8765):
        # Accept either a ready Database instance or a path / None, in
        # which case the shared database factory supplies one.
        if isinstance(db, str) or db is None:
            self.db = get_database(db)
        else:
            self.db = db
        self.port = port
        self.server = None
        self.server_thread = None

    def start_server(self):
        """Start the sync server in a background (daemon) thread."""
        self.server = SyncServer(self.port, self.db)
        self.server_thread = threading.Thread(target=self.server.serve_forever)
        self.server_thread.daemon = True
        self.server_thread.start()

    def stop_server(self):
        """Stop the sync server."""
        if self.server:
            self.server.shutdown()
            self.server = None

    def sync_with_peer(self, peer: dict[str, Any]) -> dict[str, Any]:
        """Sync snippets with a peer.

        Pulls the peer's snippet list, imports unseen entries, then pushes
        the full local export back. Returns a status dict with merge/push
        counts, or an error status when the peer is unreachable.
        """
        # BUG FIX: peer.get("addresses", [""])[0] raised IndexError when
        # the key was present but the list empty; ``or [""]`` covers both
        # the missing-key and empty-list cases.
        host = peer.get("host") or (peer.get("addresses") or [""])[0]
        port = peer.get("port", 8765)

        merged = 0

        try:
            with urlopen(f"http://{host}:{port}/snippets", timeout=30) as response:
                snippets = json.loads(response.read())
        except Exception as e:
            return {"status": "error", "message": str(e)}

        for snippet in snippets:
            # Drop the remote primary key so the local DB assigns its own.
            snippet.pop("id", None)
            result = self.db.import_snippet(snippet, strategy="skip")
            if result is not None:
                merged += 1

        # Reuse push_to_peer instead of duplicating the POST logic inline.
        pushed = self.push_to_peer(host, port)

        return {"status": "success", "merged": merged, "pushed": pushed}

    def push_to_peer(self, host: str, port: int) -> int:
        """Push all local snippets to a peer.

        Returns the number of snippets pushed, or 0 when the peer is
        unreachable or rejects the upload (best effort).
        """
        snippets = self.db.export_all()
        pushed = 0

        try:
            req = urllib.request.Request(
                f"http://{host}:{port}/snippets",
                data=json.dumps(snippets).encode(),
                headers={"Content-Type": "application/json"},
            )
            with urlopen(req, timeout=30) as response:
                if response.status == 200:
                    pushed = len(snippets)
        except Exception:
            # Best-effort push: an unreachable peer simply pushes nothing.
            pass

        return pushed
|
||||
Reference in New Issue
Block a user