Skip to content

Commit

Permalink
switch to uvloop (faster asyncio event loop) and optimise grpc settings
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexCheema committed Dec 17, 2024
1 parent 58f0a0f commit 0a07223
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 38 deletions.
8 changes: 4 additions & 4 deletions exo/inference/mlx/sharded_inference_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import mlx.optimizers as optim
from ..inference_engine import InferenceEngine
from .sharded_utils import load_shard, get_image_from_str
from .losses import loss_fns
from .losses import loss_fns
from ..shard import Shard
from typing import Dict, Optional, Tuple
from exo.download.shard_download import ShardDownloader
Expand Down Expand Up @@ -56,7 +56,7 @@ async def save_checkpoint(self, shard: Shard, path: str):
async def load_checkpoint(self, shard: Shard, path: str):
await self.ensure_shard(shard)
self.model.load_weights(path)

async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
await self.ensure_shard(shard)
state = await self.poll_state(request_id)
Expand Down Expand Up @@ -102,7 +102,7 @@ def train_step(inp, tar, lng):

score, gradients = await loop.run_in_executor(self.executor, train_step, x, y, l)
#print(f"{score=}")

layers = [{k: v["weight"] for k,v in l.items() if 'weight' in v} for l in gradients if l]
#print(layers[0])

Expand All @@ -117,7 +117,7 @@ async def ensure_shard(self, shard: Shard):
if self.shard != shard:
model_shard, self.tokenizer = await load_shard(model_path, shard)
self.shard = shard
self.model = model_shard
self.model = model_shard
self.caches = OrderedDict()
self.session = {}

72 changes: 54 additions & 18 deletions exo/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import numpy as np
from functools import partial
from tqdm import tqdm
from tqdm.asyncio import tqdm_asyncio
from exo.train.dataset import load_dataset, iterate_batches, compose
from exo.networking.manual.manual_discovery import ManualDiscovery
from exo.networking.manual.network_topology_config import NetworkTopology
Expand All @@ -33,6 +32,41 @@
from exo.models import build_base_shard, get_repo
from exo.viz.topology_viz import TopologyViz
from exo.download.hf.hf_helpers import has_hf_home_read_access, has_hf_home_write_access, get_hf_home, move_models_to_hf
import uvloop
from contextlib import asynccontextmanager
import concurrent.futures
import socket
import resource
import psutil

# Configure uvloop for maximum performance
def configure_uvloop():
    """Install uvloop and return a freshly created, tuned event loop.

    Steps, in order:
      1. Install uvloop as the asyncio event loop policy.
      2. Create a new event loop and make it the current one.
      3. On non-Windows platforms, raise the soft open-file limit on a
         best-effort basis (failures are ignored).
      4. Give the loop a default thread pool for blocking work, sized at
         ``min(32, cpu_count * 4)``.

    Returns:
        The new event loop, already set as the current loop.
    """
    # Install uvloop as event loop policy
    uvloop.install()

    # Create new event loop
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)

    # Increase file descriptor limits on Unix systems
    if not psutil.WINDOWS:
        _raise_nofile_limit()

    # Configure thread pool for blocking operations
    loop.set_default_executor(
        concurrent.futures.ThreadPoolExecutor(
            max_workers=min(32, (os.cpu_count() or 1) * 4)
        )
    )

    return loop


def _raise_nofile_limit(fallback=8192):
    """Best-effort raise of the soft RLIMIT_NOFILE; never lowers it.

    First tries to lift the soft limit to the hard limit. If the system
    rejects that, tries ``fallback`` instead. The hard limit itself is
    never changed.

    Args:
        fallback: Soft limit to try when raising to the hard limit fails.
    """
    soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
    unlimited = resource.RLIM_INFINITY
    if soft == unlimited:
        return  # nothing to raise

    targets = [hard]
    # Only consider the fallback when it does not exceed the hard limit.
    if hard == unlimited or fallback <= hard:
        targets.append(fallback)

    for target in targets:
        # Skip targets that would lower (or merely keep) the current limit.
        if target != unlimited and target <= soft:
            continue
        try:
            resource.setrlimit(resource.RLIMIT_NOFILE, (target, hard))
        except (ValueError, OSError):
            # ValueError: limit rejected; OSError: underlying syscall failed.
            # This is a tuning step, so a failure must not abort start-up.
            continue
        return

# parse args
parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")
Expand Down Expand Up @@ -223,7 +257,7 @@ def clean_path(path):
async def hold_outstanding(node: Node):
while node.outstanding_requests:
await asyncio.sleep(.5)
return
return

async def run_iter(node: Node, shard: Shard, train: bool, data, batch_size=1):
losses = []
Expand All @@ -234,7 +268,7 @@ async def run_iter(node: Node, shard: Shard, train: bool, data, batch_size=1):
tokens.append(np.sum(lengths))
total_tokens = np.sum(tokens)
total_loss = np.sum(losses) / total_tokens

return total_loss, total_tokens

async def eval_model_cli(node: Node, inference_engine: InferenceEngine, model_name, dataloader, batch_size, num_batches=-1):
Expand Down Expand Up @@ -270,7 +304,7 @@ async def train_model_cli(node: Node, inference_engine: InferenceEngine, model_n
await hold_outstanding(node)
await hold_outstanding(node)


async def main():
loop = asyncio.get_running_loop()

Expand All @@ -285,7 +319,7 @@ async def main():
{"❌ No read access" if not has_read else ""}
{"❌ No write access" if not has_write else ""}
""")

if not args.models_seed_dir is None:
try:
models_seed_dir = clean_path(args.models_seed_dir)
Expand Down Expand Up @@ -330,29 +364,31 @@ def handle_exit():
print("Error: This train ain't leaving the station without a model")
return
await train_model_cli(node, inference_engine, model_name, dataloader, args.batch_size, args.iters, save_interval=args.save_every, checkpoint_dir=args.save_checkpoint_dir)

else:
asyncio.create_task(api.run(port=args.chatgpt_api_port)) # Start the API server as a non-blocking task
await asyncio.Event().wait()

if args.wait_for_peers > 0:
print("Cooldown to allow peers to exit gracefully")
for i in tqdm(range(50)):
await asyncio.sleep(.1)

@asynccontextmanager
async def setup_node(args):
    """Placeholder async context manager for node setup and teardown.

    The real setup logic is not implemented yet. For now this yields
    ``None`` once so that ``async with setup_node(args):`` can be entered
    and exited without error.

    Args:
        args: Parsed command-line arguments (currently unused).

    Yields:
        None.
    """
    # TODO: Rest of setup_node implementation...
    # The bare yield makes this an async generator, which the
    # asynccontextmanager decorator requires.
    yield

def run():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(main())

except KeyboardInterrupt:
print("Received keyboard interrupt. Shutting down...")
finally:
loop.run_until_complete(shutdown(signal.SIGTERM, loop, node.server))
loop.close()

loop = None
try:
loop = configure_uvloop()
loop.run_until_complete(main())
except KeyboardInterrupt:
print("\nShutdown requested... exiting")
finally:
if loop:
loop.close()

if __name__ == "__main__":
run()
31 changes: 25 additions & 6 deletions exo/networking/grpc/grpc_peer_handle.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,19 @@ def __init__(self, _id: str, address: str, desc: str, device_capabilities: Devic
self._device_capabilities = device_capabilities
self.channel = None
self.stub = None
self.channel_options = [
("grpc.max_metadata_size", 64 * 1024 * 1024),
("grpc.max_receive_message_length", 256 * 1024 * 1024),
("grpc.max_send_message_length", 256 * 1024 * 1024),
("grpc.max_concurrent_streams", 100),
("grpc.http2.min_time_between_pings_ms", 10000),
("grpc.keepalive_time_ms", 20000),
("grpc.keepalive_timeout_ms", 10000),
("grpc.keepalive_permit_without_calls", 1),
("grpc.http2.max_pings_without_data", 0),
("grpc.tcp_nodelay", 1),
("grpc.optimization_target", "throughput"),
]

def id(self) -> str:
return self._id
Expand All @@ -36,11 +49,11 @@ def device_capabilities(self) -> DeviceCapabilities:

async def connect(self):
if self.channel is None:
self.channel = grpc.aio.insecure_channel(self.address, options=[
("grpc.max_metadata_size", 32*1024*1024),
('grpc.max_receive_message_length', 32*1024*1024),
('grpc.max_send_message_length', 32*1024*1024)
])
self.channel = grpc.aio.insecure_channel(
self.address,
options=self.channel_options,
compression=grpc.Compression.Gzip
)
self.stub = node_service_pb2_grpc.NodeServiceStub(self.channel)
await self.channel.channel_ready()

Expand All @@ -54,7 +67,13 @@ async def disconnect(self):
self.stub = None

async def _ensure_connected(self):
if not await self.is_connected(): await asyncio.wait_for(self.connect(), timeout=5)
if not await self.is_connected():
try:
await asyncio.wait_for(self.connect(), timeout=10.0)
except asyncio.TimeoutError:
if DEBUG >= 2: print(f"Connection timeout for {self._id}@{self.address}")
await self.disconnect()
raise

async def health_check(self) -> bool:
try:
Expand Down
23 changes: 13 additions & 10 deletions exo/networking/udp/udp_discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def __init__(self, message: str, broadcast_port: int):
def connection_made(self, transport):
sock = transport.get_extra_info("socket")
sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
transport.sendto(self.message.encode("utf-8"), ("<broadcast>", self.broadcast_port))
transport.sendto(self.message.encode("utf-8"), ("255.255.255.255", self.broadcast_port))


class UDPDiscovery(Discovery):
Expand Down Expand Up @@ -84,36 +84,39 @@ async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
return [peer_handle for peer_handle, _, _, _ in self.known_peers.values()]

async def task_broadcast_presence(self):
if DEBUG_DISCOVERY >= 2: print("Starting task_broadcast_presence...")

while True:
# Explicitly broadcasting on all assigned ips since broadcasting on `0.0.0.0` on MacOS does not broadcast over
# the Thunderbolt bridge when other connection modalities exist such as WiFi or Ethernet
for addr, interface_name in get_all_ip_addresses_and_interfaces():
interface_priority, interface_type = await get_interface_priority_and_type(interface_name)
message = json.dumps({
"type": "discovery",
"node_id": self.node_id,
"grpc_port": self.node_port,
"device_capabilities": self.device_capabilities.to_dict(),
"priority": interface_priority, # TODO: Prioritise interfaces based on bandwidth, latency, and jitter e.g. prioritise Thunderbolt over WiFi.
"priority": interface_priority,
"interface_name": interface_name,
"interface_type": interface_type,
})
if DEBUG_DISCOVERY >= 3: print(f"Broadcasting presence at ({addr} - {interface_name} - {interface_priority}): {message}")

transport = None
try:
transport, _ = await asyncio.get_event_loop().create_datagram_endpoint(lambda: BroadcastProtocol(message, self.broadcast_port), local_addr=(addr, 0), family=socket.AF_INET)
if DEBUG_DISCOVERY >= 3: print(f"Broadcasting presence at ({addr} - {interface_name} - {interface_priority})")
# Create socket with explicit broadcast permission
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
sock.bind((addr, 0))

# Create transport with the pre-configured socket
transport, _ = await asyncio.get_event_loop().create_datagram_endpoint(
lambda: BroadcastProtocol(message, self.broadcast_port),
sock=sock
)
except Exception as e:
print(f"Error in broadcast presence ({addr} - {interface_name} - {interface_priority}): {e}")
finally:
if transport:
try: transport.close()
except Exception as e:
if DEBUG_DISCOVERY >= 2: print(f"Error closing transport: {e}")
if DEBUG_DISCOVERY >= 2: traceback.print_exc()

await asyncio.sleep(self.broadcast_interval)

async def on_listen_message(self, data, addr):
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
"tqdm==4.66.4",
"transformers==4.46.3",
"uuid==1.30",
"uvloop==0.21.0",
"tinygrad @ git+https://github.com/tinygrad/tinygrad.git@3b26e51fcebfc6576f4e0f99693e6f1406d61d79",
]

Expand Down

0 comments on commit 0a07223

Please sign in to comment.