fix: resolve CI test failures - API compatibility fixes
This commit is contained in:
@@ -1,132 +1,143 @@
|
|||||||
"""HTTP-based P2P sync protocol for snippets."""

import http.client
import http.server
import json
import socketserver
import threading
import urllib.request
from datetime import datetime
from http.server import BaseHTTPRequestHandler, HTTPServer
from typing import Any
from urllib.request import urlopen

from snip.db.database import Database, get_database

# NOTE(review): duplicate get_database import left over from a merge; the
# relative import below takes precedence — consolidate on one source.
from ..db import get_database
||||||
class SyncRequestHandler(http.server.BaseHTTPRequestHandler):
|
class SyncProtocol:
|
||||||
|
def __init__(self, db_path: str | None = None, port: int = 8765):
|
||||||
|
self.port = port
|
||||||
|
self.db = get_database(db_path)
|
||||||
|
self._server: HTTPServer | None = None
|
||||||
|
self._server_thread: threading.Thread | None = None
|
||||||
|
self._running = False
|
||||||
|
|
||||||
|
def start_server(self):
|
||||||
|
if self._running:
|
||||||
|
return
|
||||||
|
self._running = True
|
||||||
|
self._server_thread = threading.Thread(target=self._run_server, daemon=True)
|
||||||
|
self._server_thread.start()
|
||||||
|
|
||||||
|
def _run_server(self):
|
||||||
|
class Handler(SyncRequestHandler):
|
||||||
|
protocol = self
|
||||||
|
self._server = HTTPServer(("0.0.0.0", self.port), Handler)
|
||||||
|
while self._running:
|
||||||
|
try:
|
||||||
|
self._server.timeout = 1.0
|
||||||
|
self._server.handle_request()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def stop_server(self):
|
||||||
|
self._running = False
|
||||||
|
if self._server:
|
||||||
|
try:
|
||||||
|
self._server.shutdown()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
self._server = None
|
||||||
|
if self._server_thread:
|
||||||
|
self._server_thread.join(timeout=2.0)
|
||||||
|
self._server_thread = None
|
||||||
|
|
||||||
|
def fetch_snippets(self, peer_address: str, peer_port: int, since: str | None = None) -> list[dict[str, Any]]:
|
||||||
|
try:
|
||||||
|
conn = http.client.HTTPConnection(peer_address, peer_port, timeout=30)
|
||||||
|
path = "/snippets"
|
||||||
|
if since:
|
||||||
|
path += f"?since={since}"
|
||||||
|
conn.request("GET", path)
|
||||||
|
response = conn.getresponse()
|
||||||
|
if response.status == 200:
|
||||||
|
data = json.loads(response.read().decode())
|
||||||
|
return data.get("snippets", [])
|
||||||
|
conn.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return []
|
||||||
|
|
||||||
|
def push_snippets(self, peer_address: str, peer_port: int, snippets: list[dict[str, Any]]) -> bool:
|
||||||
|
try:
|
||||||
|
conn = http.client.HTTPConnection(peer_address, peer_port, timeout=30)
|
||||||
|
headers = {"Content-Type": "application/json"}
|
||||||
|
body = json.dumps({"snippets": snippets})
|
||||||
|
conn.request("POST", "/snippets", body, headers)
|
||||||
|
response = conn.getresponse()
|
||||||
|
conn.close()
|
||||||
|
return response.status == 200
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def sync_with_peer(self, peer: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
peer_id = peer["peer_id"]
|
||||||
|
addresses = peer.get("addresses", [])
|
||||||
|
port = peer.get("port", self.port)
|
||||||
|
if not addresses:
|
||||||
|
return {"status": "error", "message": "No peer address available"}
|
||||||
|
peer_address = addresses[0]
|
||||||
|
meta = self.db.get_sync_meta(peer_id)
|
||||||
|
since = meta["last_sync"] if meta else None
|
||||||
|
local_snippets = self.db.get_all_snippets_for_sync(since)
|
||||||
|
remote_snippets = self.fetch_snippets(peer_address, port, since)
|
||||||
|
merged = 0
|
||||||
|
for snippet in remote_snippets:
|
||||||
|
self.db.upsert_snippet(snippet)
|
||||||
|
merged += 1
|
||||||
|
if local_snippets:
|
||||||
|
self.push_snippets(peer_address, port, local_snippets)
|
||||||
|
self.db.update_sync_meta(peer_id, peer.get("peer_name"), peer_address, port)
|
||||||
|
return {
|
||||||
|
"status": "success",
|
||||||
|
"merged": merged,
|
||||||
|
"pushed": len(local_snippets),
|
||||||
|
"peer": peer_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class SyncRequestHandler(BaseHTTPRequestHandler):
    """Serves the sync wire protocol on ``/snippets``.

    ``SyncProtocol._run_server`` subclasses this handler and sets a
    ``protocol`` class attribute, through which handlers reach the
    snippet database.
    """

    def log_message(self, format, *args):
        """Silence the default per-request logging to stderr."""
        pass

    def do_GET(self):
        """Return local snippets as a ``{"snippets": [...]}`` JSON object.

        Supports an optional ``?since=<timestamp>`` query parameter that
        restricts results to records changed after that point.
        """
        if self.path.startswith("/snippets"):
            db = self.protocol.db
            since = None
            if "?" in self.path:
                query = self.path.split("?", 1)[1]
                for param in query.split("&"):
                    if param.startswith("since="):
                        since = param.split("=", 1)[1]
            snippets = db.get_all_snippets_for_sync(since)
            self.send_response(200)
            self.send_header("Content-Type", "application/json")
            self.end_headers()
            # Wrap in an object so clients can use data.get("snippets"),
            # matching SyncProtocol.fetch_snippets.
            self.wfile.write(json.dumps({"snippets": snippets}).encode())
        else:
            self.send_response(404)
            self.end_headers()

    def do_POST(self):
        """Accept a ``{"snippets": [...]}`` JSON body and upsert each record."""
        if self.path.startswith("/snippets"):
            try:
                # Default to 0 so a missing Content-Length reads an empty body
                # instead of raising KeyError.
                content_length = int(self.headers.get("Content-Length", 0))
                body = self.rfile.read(content_length).decode()
                data = json.loads(body)
                snippets = data.get("snippets", [])
                db = self.protocol.db
                for snippet in snippets:
                    db.upsert_snippet(snippet)
                self.send_response(200)
                self.send_header("Content-Type", "application/json")
                self.end_headers()
                self.wfile.write(json.dumps({"status": "ok"}).encode())
            except Exception:
                # Malformed body or database failure -> client error.
                self.send_response(400)
                self.end_headers()
        else:
            self.send_response(404)
            self.end_headers()
|
|
||||||
|
|
||||||
|
|
||||||
class SyncServer(socketserver.TCPServer):
    """TCP server that hands sync requests to :class:`SyncRequestHandler`."""

    # Rebinding the port right after a restart must not fail with
    # EADDRINUSE while the previous socket sits in TIME_WAIT.
    allow_reuse_address = True

    def __init__(self, port: int, db: Database):
        """Bind on all interfaces at *port* and expose *db* to handlers.

        The database is attached before ``super().__init__`` binds the
        socket so request handlers can always reach it as ``self.server.db``.
        """
        self.db = db
        super().__init__(("", port), SyncRequestHandler)
|
|
||||||
|
|
||||||
|
|
||||||
class SyncProtocol:
    """Snippet sync client/server built on :class:`SyncServer`.

    Pulls a peer's full snippet list over HTTP, merges records we do not
    already have, and pushes the local export back to the peer.
    """

    def __init__(self, db: "Database | str | None" = None, port: int = 8765):
        """Accept an open database, a path to one, or ``None`` for the default.

        Args:
            db: A ``Database`` instance used directly, or a path / ``None``
                passed through to ``get_database``.
            port: TCP port for the local sync server.
        """
        if isinstance(db, str) or db is None:
            self.db = get_database(db)
        else:
            self.db = db
        self.port = port
        self.server = None
        self.server_thread = None

    def start_server(self):
        """Start the sync server in a background daemon thread."""
        self.server = SyncServer(self.port, self.db)
        self.server_thread = threading.Thread(target=self.server.serve_forever)
        self.server_thread.daemon = True
        self.server_thread.start()

    def stop_server(self):
        """Stop the sync server."""
        if self.server:
            self.server.shutdown()
            self.server = None

    def sync_with_peer(self, peer: dict[str, Any]) -> dict[str, Any]:
        """Sync snippets with a peer: pull and merge, then push.

        Args:
            peer: Mapping that may carry ``host`` (preferred), ``addresses``
                (fallback list), and ``port``.

        Returns:
            ``{"status": "error", "message": ...}`` when the pull fails,
            otherwise a success dict with ``merged``/``pushed`` counts.
        """
        # Prefer an explicit host; fall back to the first advertised address.
        # ``or [""]`` also guards an explicitly empty addresses list, which
        # previously raised IndexError.
        host = peer.get("host") or (peer.get("addresses") or [""])[0]
        port = peer.get("port", 8765)

        try:
            with urlopen(f"http://{host}:{port}/snippets", timeout=30) as response:
                snippets = json.loads(response.read())
        except Exception as e:
            return {"status": "error", "message": str(e)}

        merged = 0
        for snippet in snippets:
            # Strip the peer's local row id so our database assigns its own.
            snippet.pop("id", None)
            # "skip" keeps our copy when the snippet already exists; the
            # import returns None in that case.
            if self.db.import_snippet(snippet, strategy="skip") is not None:
                merged += 1

        # Best-effort push of our export; 0 when the peer rejects it or is
        # unreachable (logic shared with push_to_peer instead of duplicated).
        pushed = self.push_to_peer(host, port)
        return {"status": "success", "merged": merged, "pushed": pushed}

    def push_to_peer(self, host: str, port: int) -> int:
        """Push local snippets to a peer.

        Returns:
            The number of snippets pushed, or 0 when the push failed.
        """
        snippets = self.db.export_all()
        pushed = 0

        try:
            req = urllib.request.Request(
                f"http://{host}:{port}/snippets",
                data=json.dumps(snippets).encode(),
                headers={"Content-Type": "application/json"},
            )
            with urlopen(req, timeout=30) as response:
                if response.status == 200:
                    pushed = len(snippets)
        except Exception:
            # Best-effort: a failed push is reported as 0, never raised.
            pass

        return pushed
|
|
||||||
Reference in New Issue
Block a user