forked from petals-infra/health.petals.dev
-
Notifications
You must be signed in to change notification settings - Fork 0
/
p2p_utils.py
46 lines (38 loc) · 1.88 KB
/
p2p_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import asyncio
import hivemind
from async_timeout import timeout
from petals.server.handler import TransformerConnectionHandler
# Module-level cache of reachability results, keyed by peer ID.
# check_reachability reads it with info_cache.get(peer_id) and writes each fresh
# result with an expiration time of hivemind.get_dht_time() + expiration.
info_cache = hivemind.TimedStorage()
async def check_reachability(
    peer_id, _, node, *, fetch_info=False, connect_timeout=5, expiration=300, use_cache=True
):
    """Probe whether ``peer_id`` is reachable from ``node`` and return a status dict.

    Args:
        peer_id: ID of the peer to probe.
        _: Unused positional argument (``check_reachability_parallel`` passes its
            ``dht`` here).
        node: Object exposing ``.p2p``, used to open the connection / RPC stub.
        fetch_info: If True (Petals servers), call ``rpc_info`` through the
            ``TransformerConnectionHandler`` stub and return the decoded info.
            If False (DHT-only bootstrap peers), only open and then close a raw
            p2p connection.
        connect_timeout: Seconds allowed for the whole probe.
        expiration: Seconds (added to ``hivemind.get_dht_time()``) that a result
            stays valid in ``info_cache``.
        use_cache: If True, return a still-valid cached result without probing.
            A freshly computed result is stored in the cache either way.

    Returns:
        On success, a dict with ``"ok": True``. When ``fetch_info`` is set, the dict
        also holds the fields decoded from the peer's ``rpc_info`` response.
        On failure, ``{"ok": False, "error": <message>}``.
        A peer answering "protocol not supported" is reported as ``{"ok": True}``
        and is NOT cached, so it is re-probed on the next call.
    """
    if use_cache:
        entry = info_cache.get(peer_id)
        if entry is not None:
            return entry.value

    try:
        # async_timeout documents the `async with` form; the synchronous
        # `with timeout(...)` usage is deprecated.
        async with timeout(connect_timeout):
            if fetch_info:  # For Petals servers
                stub = TransformerConnectionHandler.get_stub(node.p2p, peer_id)
                response = await stub.rpc_info(hivemind.proto.runtime_pb2.ExpertUID())
                rpc_info = hivemind.MSGPackSerializer.loads(response.serialized_info)
                rpc_info["ok"] = True
            else:  # For DHT-only bootstrap peers
                await node.p2p._client.connect(peer_id, [])
                await node.p2p._client.disconnect(peer_id)
                rpc_info = {"ok": True}
    except asyncio.TimeoutError:
        message = f"Failed to connect in {connect_timeout:.0f} sec. Firewall may be blocking connections"
        rpc_info = {"ok": False, "error": message}
    except Exception as e:
        # Actual connection error; fall back to repr() when the exception has no message text
        message = str(e) or repr(e)
        if message == "protocol not supported":
            # This may be returned when a server is joining, see https://github.com/petals-infra/health.petals.dev/issues/1
            # Intentionally not cached, so the peer is re-checked on the next call.
            return {"ok": True}
        rpc_info = {"ok": False, "error": message}

    info_cache.store(peer_id, rpc_info, hivemind.get_dht_time() + expiration)
    return rpc_info
async def check_reachability_parallel(peer_ids, dht, node, *, fetch_info=False):
    """Probe many peers concurrently with ``check_reachability``.

    Args:
        peer_ids: Iterable of peer IDs to probe. It is materialized into a list
            first, because it is iterated twice (once to start the probes, once to
            pair results with IDs). Without this, a one-shot iterator such as a
            generator would produce an empty result.
        dht: Passed through as ``check_reachability``'s second (unused) argument.
        node: Passed through to ``check_reachability``.
        fetch_info: Passed through to ``check_reachability``.

    Returns:
        Dict mapping each peer ID to its ``check_reachability`` result.
    """
    peer_ids = list(peer_ids)
    rpc_infos = await asyncio.gather(
        *(check_reachability(peer_id, dht, node, fetch_info=fetch_info) for peer_id in peer_ids)
    )
    return dict(zip(peer_ids, rpc_infos))