From e7d165f478af8c4b2bf51f564a8dc40e1e289d4a Mon Sep 17 00:00:00 2001
From: ghostplant
Date: Sat, 25 Dec 2021 10:19:44 +0000
Subject: [PATCH] add fast launch method based on openmpi (#69)

---
 README.md                  |  10 ++-
 tutel/launcher/__init__.py |   3 +
 tutel/launcher/execl.py    |  42 +++++++++++
 tutel/launcher/run.py      |  25 +++++++
 tutel/system_init.py       | 150 +++++++++++++++++++------------------
 5 files changed, 153 insertions(+), 77 deletions(-)
 create mode 100644 tutel/launcher/__init__.py
 create mode 100644 tutel/launcher/execl.py
 create mode 100644 tutel/launcher/run.py

diff --git a/README.md b/README.md
index 09760404..dde12c7e 100644
--- a/README.md
+++ b/README.md
@@ -32,13 +32,15 @@ How to setup Tutel MoE for Pytorch:
 
 * Run Tutel MoE in Distributed Mode:
 
-        (Single-Node Multi-GPU based on standard Pytorch distributed launcher:)
-        $ python3 -m torch.distributed.launch --nproc_per_node=8 -m tutel.examples.helloworld --batch_size=16
-
-        (Multi-Node Multi-GPU based on standard Pytorch distributed launcher:)
+        (Method A - Torch launcher for `Multi-Node x Multi-GPU`:)
         $ ssh <node-ip-0> python3 -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr=<node-ip-0> -m tutel.examples.helloworld --batch_size=16
         $ ssh <node-ip-1> python3 -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr=<node-ip-0> -m tutel.examples.helloworld --batch_size=16
 
+        (Method B - Tutel launcher for `Multi-Node x Multi-GPU`, requiring package `openmpi-bin`:)
+        $ mpiexec -host <node-ip-0>,<node-ip-1>,.. \
+            -x LOCAL_SIZE=8 -x MASTER_ADDR=<node-ip-0> \
+            python3 -m tutel.launcher.run -m tutel.examples.helloworld --batch_size=16
+
 ```
 
 How to import Tutel-optimized MoE in Pytorch:
diff --git a/tutel/launcher/__init__.py b/tutel/launcher/__init__.py
new file mode 100644
index 00000000..c45e0a75
--- /dev/null
+++ b/tutel/launcher/__init__.py
@@ -0,0 +1,3 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
diff --git a/tutel/launcher/execl.py b/tutel/launcher/execl.py
new file mode 100644
index 00000000..b1d020f3
--- /dev/null
+++ b/tutel/launcher/execl.py
@@ -0,0 +1,42 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+ +import os, re, sys +import logging +import argparse + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('-m', default=False, action='store_true') + parser.add_argument('rest', nargs=argparse.REMAINDER) + args = parser.parse_args() + + local_rank = int(os.environ['LOCAL_RANK']) + local_size = int(os.environ['LOCAL_SIZE']) + + os.environ['TUTEL_CUDA_SANDBOX'] = '1' + os.environ['CUDA_VISIBLE_DEVICES'] = str(local_rank) + + cmd_args = [] + try: + if not os.path.exists('/usr/bin/numactl'): + raise + local_size = int(os.environ['LOCAL_SIZE']) + cpu_nodes = sorted([str(x[4:]) for x in os.listdir('/sys/devices/system/node') if re.match('node[0-9]+', x)]) + if len(cpu_nodes) <= local_size: + sel_nodes = cpu_nodes[(local_rank // (local_size // len(cpu_nodes))) % len(cpu_nodes)] + else: + sel_nodes = cpu_nodes[local_rank::local_size] + sel_nodes = ','.join(sel_nodes) + + cmd_args = ['/usr/bin/numactl', '--cpunodebind=%s' % sel_nodes] + except Exception as ex: + if local_rank == 0: + logging.warning('`numactl` is not enabled by tutel.launcher.execl') + + cmd_args += [sys.executable, '-m'] if args.m else [] + cmd_args += args.rest + os.execl(cmd_args[0], *cmd_args) + +if __name__ == "__main__": + main() diff --git a/tutel/launcher/run.py b/tutel/launcher/run.py new file mode 100644 index 00000000..4e1f345f --- /dev/null +++ b/tutel/launcher/run.py @@ -0,0 +1,25 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import os, sys + +def main(): + host_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) + host_rank = int(os.environ['OMPI_COMM_WORLD_RANK']) + local_size = int(os.environ.get('LOCAL_SIZE', 1)) + + master_addr = os.environ['MASTER_ADDR'] if host_size > 1 else 'localhost' + master_port = int(os.environ.get('MASTER_PORT', 23232)) + + cmd_args = [sys.executable, '-m', 'torch.distributed.launch', '--use_env', + '--nproc_per_node=%d' % local_size, + '--nnodes=%d' % host_size, + '--node_rank=%d' % host_rank, + '--master_addr=%s' % master_addr, + '--master_port=%s' % master_port, + '-m', 'tutel.launcher.execl', + ] + sys.argv[1:] + os.execl(cmd_args[0], *cmd_args) + +if __name__ == "__main__": + main() diff --git a/tutel/system_init.py b/tutel/system_init.py index 8153e87d..d1300720 100644 --- a/tutel/system_init.py +++ b/tutel/system_init.py @@ -25,77 +25,81 @@ def init_affinity_at_program_beginning(): logging.warning('Failed to set NUMA status: %s' % ex) def init_data_model_parallel(group_count=1, backend='nccl'): - import torch - import torch.distributed as dist - try: - if ('LOCAL_RANK' not in os.environ) and ('OMPI_COMM_WORLD_SIZE' in os.environ): - dist.init_process_group(backend=backend, - init_method='tcp://%s:%s' % (os.environ['MASTER_ADDR'], os.environ.get('MASTER_PORT', '23456')), - rank=int(os.environ['OMPI_COMM_WORLD_RANK']), world_size=int(os.environ['OMPI_COMM_WORLD_SIZE'])) - dist_local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) + import torch + import torch.distributed as dist + try: + if ('LOCAL_RANK' not in os.environ) and ('OMPI_COMM_WORLD_SIZE' in os.environ): + dist.init_process_group(backend=backend, + init_method='tcp://%s:%s' % (os.environ['MASTER_ADDR'], os.environ.get('MASTER_PORT', '23456')), + rank=int(os.environ['OMPI_COMM_WORLD_RANK']), world_size=int(os.environ['OMPI_COMM_WORLD_SIZE'])) + dist_local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) + else: + dist.init_process_group(backend=backend) + dist_local_rank = int(os.environ.get('LOCAL_RANK', 0)) + if TUTEL_CUDA_SANDBOX: + dist_local_rank = 0 + 
glob_world_size, glob_world_rank = dist.get_world_size(), dist.get_rank() + is_distributed = True + + def dist_print(*args): + if glob_world_rank == 0: + print(*args) + except ValueError: + glob_world_size, glob_world_rank, dist_local_rank = 1, 0, 0 + is_distributed = False + dist_print = print + + assert glob_world_size % group_count == 0, f"Expected to evenly divide devices into {group_count} groups, while the world size of current sesion is {glob_world_size}." + + dist_group_size = group_count + dist_world_size = glob_world_size // dist_group_size + dist_world_rank = glob_world_rank % dist_world_size + dist_group_rank = glob_world_rank // dist_world_size + + if is_distributed: + global_group = model_group = data_group = dist.group.WORLD + + if dist_group_size != glob_world_size: + groups, inner_ranks = [], [] + for gr in range(dist_group_size): + group_ranks = [x for x in range(gr * dist_world_size, (gr + 1) * dist_world_size)] + groups += [dist.new_group(ranks=group_ranks)] + inner_ranks += [group_ranks] + model_group = groups[dist_group_rank] + + if dist_world_size != glob_world_size: + groups, outer_ranks = [], [] + for gr in range(dist_world_size): + group_ranks = [x for x in range(gr, dist_world_size * dist_group_size, dist_world_size)] + groups += [dist.new_group(ranks=group_ranks)] + outer_ranks += [group_ranks] + data_group = groups[dist_world_rank] else: - dist.init_process_group(backend=backend) - dist_local_rank = int(os.environ.get('LOCAL_RANK', 0)) - if TUTEL_CUDA_SANDBOX: - dist_local_rank = 0 - glob_world_size, glob_world_rank = dist.get_world_size(), dist.get_rank() - is_distributed = True - - def dist_print(*args): - if glob_world_rank == 0: - print(*args) - except ValueError: - glob_world_size, glob_world_rank, dist_local_rank = 1, 0, 0 - is_distributed = False - dist_print = print - - assert glob_world_size % group_count == 0, f"Expected to evenly divide devices into {group_count} groups, while the world size of current sesion is {glob_world_size}." 
- - dist_group_size = group_count - dist_world_size = glob_world_size // dist_group_size - dist_world_rank = glob_world_rank % dist_world_size - dist_group_rank = glob_world_rank // dist_world_size - - if is_distributed: - global_group = model_group = data_group = dist.group.WORLD - - if dist_group_size != glob_world_size: - groups, inner_ranks = [], [] - for gr in range(dist_group_size): - group_ranks = [x for x in range(gr * dist_world_size, (gr + 1) * dist_world_size)] - groups += [dist.new_group(ranks=group_ranks)] - inner_ranks += [group_ranks] - model_group = groups[dist_group_rank] - - if dist_world_size != glob_world_size: - groups, outer_ranks = [], [] - for gr in range(dist_world_size): - group_ranks = [x for x in range(gr, dist_world_size * dist_group_size, dist_world_size)] - groups += [dist.new_group(ranks=group_ranks)] - outer_ranks += [group_ranks] - data_group = groups[dist_world_rank] - else: - model_group, data_group, global_group = None, None, None - - result = init_data_model_parallel - result.global_size = glob_world_size - result.global_rank = glob_world_rank - result.group_count = dist_group_size - result.data_rank = dist_group_rank - result.model_rank = dist_world_rank - - if backend == 'nccl': - result.local_device = torch.device('cuda', dist_local_rank) - torch.cuda.set_device(result.local_device) - else: - result.local_device = torch.device('cpu') - - result.data_group = data_group - result.model_group = model_group - result.global_group = global_group - - result.is_distributed = is_distributed - result.dist_print = dist_print - - logging.critical(f'Registering device global rank {result.global_rank}: data_rank = {result.data_rank}, model_rank = {result.model_rank}') - return result + model_group, data_group, global_group = None, None, None + + result = init_data_model_parallel + result.global_size = glob_world_size + result.global_rank = glob_world_rank + result.group_count = dist_group_size + result.data_rank = dist_group_rank + result.model_rank = dist_world_rank + + if backend == 'nccl': + result.local_device = torch.device('cuda', dist_local_rank) + torch.cuda.set_device(result.local_device) + else: + result.local_device = torch.device('cpu') + + result.data_group = data_group + result.model_group = model_group + result.global_group = global_group + + result.is_distributed = is_distributed + result.dist_print = dist_print + + # Temp work around for: https://github.com/pytorch/pytorch/issues/56390 + import atexit + atexit.register(lambda *args: os._exit(0)) + + logging.critical(f'Registering device global rank {result.global_rank}: data_rank = {result.data_rank}, model_rank = {result.model_rank}') + return result
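
The group arithmetic that the patched `init_data_model_parallel` relies on can be checked in isolation. Below is a minimal sketch (plain Python, no `torch` required) that reproduces the model-group/data-group membership derived above; the 8-rank, 2-group sizes are hypothetical values chosen only for illustration.

```python
# Standalone sketch of the rank/group arithmetic in init_data_model_parallel.
# The world size (8) and group_count (2) are hypothetical example values.
glob_world_size, group_count = 8, 2

dist_world_size = glob_world_size // group_count                # ranks per model group -> 4
model_groups = [list(range(g * dist_world_size, (g + 1) * dist_world_size))
                for g in range(group_count)]                    # [[0, 1, 2, 3], [4, 5, 6, 7]]
data_groups = [list(range(g, glob_world_size, dist_world_size))
               for g in range(dist_world_size)]                 # [[0, 4], [1, 5], [2, 6], [3, 7]]

for glob_world_rank in range(glob_world_size):
    model_rank = glob_world_rank % dist_world_size              # position inside its model group
    data_rank = glob_world_rank // dist_world_size              # index of the model group it joins
    assert glob_world_rank in model_groups[data_rank]
    assert glob_world_rank in data_groups[model_rank]
```

Each rank's `data_rank` selects its model group and its `model_rank` selects its data group, mirroring how `inner_ranks` and `outer_ranks` are indexed in the patch.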
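As a usage illustration for Method B, here is a minimal sketch of a training entry point that consumes the object returned by `init_data_model_parallel`. The script name `my_train_script`, the choice `group_count=1`, and the toy all-reduce are assumptions for illustration only, not part of this patch.

```python
# Hypothetical entry point (sketch only), launched for example as:
#   mpiexec -host <node-ip-0>,<node-ip-1> -x LOCAL_SIZE=8 -x MASTER_ADDR=<node-ip-0> \
#       python3 -m tutel.launcher.run -m my_train_script
import torch
import torch.distributed as dist
from tutel import system_init

parallel_env = system_init.init_data_model_parallel(group_count=1, backend='nccl')

# Reduce across the data-parallel group only; parallel_env.model_group stays
# available for expert/model sharding when group_count > 1.
value = torch.ones(1, device=parallel_env.local_device)
if parallel_env.is_distributed:
    dist.all_reduce(value, group=parallel_env.data_group)

parallel_env.dist_print('global_rank %d / %d reduced value = %g'
                        % (parallel_env.global_rank, parallel_env.global_size, value.item()))
```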