diff --git a/README.md b/README.md
index fb0429bd..e943eeb8 100644
--- a/README.md
+++ b/README.md
@@ -69,7 +69,30 @@ How to setup Tutel MoE for Pytorch and [run examples](tutel/examples), or [enabl
 ```
 
-How to import Tutel-optimized MoE in Pytorch:
+#### How to convert checkpoint files to adapt to different distributed world sizes:
+```
+# First, train a model with 16 global experts on 2 GPUs (each GPU holds 8 local experts), saving checkpoint files at the end:
+mpiexec -bind-to none -host localhost -x LOCAL_SIZE=2 python3 -m tutel.launcher.run -m tutel.examples.helloworld --num_local_experts=8 --checkpoint=./states/{rank}-of-{size}.ckpt --device=cuda
+
+# Second, convert the checkpoint files (based on 2 GPUs) into a single checkpoint file containing all parameters:
+python3 -m tutel.checkpoint.gather --inputs=./states/{rank}-of-{size}.ckpt --input_size=2 --output ./model-synthetis.ckpt
+
+# Optionally, test the synthesized checkpoint on a single CPU device; note that all 16 experts are local in this case:
+python3 -m tutel.examples.helloworld --num_local_experts=16 --checkpoint=./model-synthetis.ckpt --device=cpu --eval
+
+# Next, convert the synthesized checkpoint file into checkpoint files adapted to distributed training on 8 GPUs:
+python3 -m tutel.checkpoint.scatter --input=./model-synthetis.ckpt --output_size=8 --outputs=./adapted-for-8-gpus/{rank}-of-{size}.ckpt
+
+# Then, use the generated checkpoint files to train/eval on 8 GPUs; note that there are 2 local experts per GPU this time:
+mpiexec -bind-to none -host localhost -x LOCAL_SIZE=8 python3 -m tutel.launcher.run -m tutel.examples.helloworld --num_local_experts=2 --checkpoint=./adapted-for-8-gpus/{rank}-of-{size}.ckpt --device=cuda
+
+# Similarly, the conversion tool also supports adapting X global experts to Y GPUs where Y % X == 0, in which case num_local_experts should be set to -(Y / X):
+python3 -m tutel.checkpoint.scatter --input=./model-synthetis.ckpt --output_size=32 --outputs=./adapted-for-32-gpus/{rank}-of-{size}.ckpt +mpiexec -bind-to none -host localhost -x LOCAL_SIZE=32 python3 -m tutel.launcher.run -m tutel.examples.helloworld --num_local_experts=-2 --checkpoint=./adapted-for-32-gpus/{rank}-of-{size}.ckpt --device=cuda + +``` + +#### How to import Tutel-optimized MoE in Pytorch: ``` # Input Example: import torch @@ -104,7 +127,7 @@ y = moe_layer(x) print(y) ``` -Usage of MOELayer: +#### Usage of MOELayer: ``` * Usage of MOELayer Args: @@ -130,7 +153,7 @@ Usage of MOELayer: activation_fn : the custom-defined activation function between two linear layers (used for type == 'ffn' only) ``` -For Deepspeed MoE Acceleration (Deepspeed MoE Top-1 Gate has integrated Tutel acceleration): +#### For Deepspeed MoE Acceleration (Deepspeed MoE Top-1 Gate has integrated Tutel acceleration): ```sh # Without Tutel optimization: python3 -m tutel.examples.helloworld_deepspeed --top=1 @@ -139,25 +162,6 @@ python3 -m tutel.examples.helloworld_deepspeed --top=1 python3 -m tutel.examples.helloworld_deepspeed --top=1 --use_tutel ``` - -### Single-GPU Throughput (batches/sec) with default settings on NVIDIA A100 (40GB): -| batch-size | helloworld (top2) | helloworld_ddp (top2) | helloworld_deepspeed (top2) | -| :--------: | :--------: | :------------: | :------------------: | -| 8 | 672.75 | 672.24 | 188.27 | -| 16 | 715.86 | 714.95 | 115.43 | -| 24 | 725.95 | 725.04 | 81.02 | -| 32 | 729.02 | 729.02 | OOM | -| 64 | 687.92 | 686.31 | OOM | -| 128 | 619.75 | 619.03 | OOM | -| 256 | 577.08 | 577.49 | OOM | - -How to reproduce these results: -```shell -$ python3 -m tutel.examples.helloworld --batch_size= -$ python3 -m tutel.examples.helloworld_ddp --batch_size= -$ python3 -m tutel.examples.helloworld_deepspeed --batch_size= -``` - ## Reference You can consult this [paper](https://arxiv.org/pdf/2206.03382.pdf) below to get to know more technical details about Tutel: ``` diff --git a/setup.py b/setup.py index 99ec8192..4fed37cf 100755 --- a/setup.py +++ b/setup.py @@ -72,12 +72,6 @@ def install(use_cuda, use_nccl): ext_libs += ['nccl'] ext_args += ['-DUSE_NCCL'] - for folder in ('build', 'dist',): - try: - shutil.rmtree(os.path.join(root_path, folder)) - except: - pass - setup( name='tutel', version='0.1', diff --git a/tutel/checkpoint/__init__.py b/tutel/checkpoint/__init__.py new file mode 100644 index 00000000..c45e0a75 --- /dev/null +++ b/tutel/checkpoint/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + diff --git a/tutel/checkpoint/gather.py b/tutel/checkpoint/gather.py new file mode 100644 index 00000000..70bd9f16 --- /dev/null +++ b/tutel/checkpoint/gather.py @@ -0,0 +1,60 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+ +import os +import argparse +import torch +import re + +from tutel.system import apply_rank_size_from_pattern + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--input_size', type=int, required=True) + parser.add_argument('--inputs', type=str, required=True) + parser.add_argument('--output', type=str, required=True) + args = parser.parse_args() + args.size = args.input_size + + mutate_size, expert_dict = {}, {} + + input_file = apply_rank_size_from_pattern(args.inputs, rank=0, size=args.size) + state_dict = torch.load(input_file, map_location=torch.device('cpu')) + for k in state_dict: + if k.endswith('._num_global_experts'): + entry = k[:k.rindex('.')] + '.experts.' + mutate_size[entry] = int(state_dict[k]) + + if not mutate_size: + raise Exception('No any Tutel MoE layer is found, as the provided checkpoint may be in legacy format. You need to reload this legacy checkpoint by corresponding application, re-checkpoint model\'s state_dict and get the latest format.') + + for rank in range(args.size): + input_file = apply_rank_size_from_pattern(args.inputs, rank=rank, size=args.size) + state_dict = torch.load(input_file, map_location=torch.device('cpu')) + for k in state_dict: + for e in mutate_size: + if k.startswith(e): + expert_dict[k] = expert_dict.get(k, [mutate_size[e],]) + [state_dict[k],] + + expert_dict = [(i, k, expert_dict[k]) for i, k in enumerate(expert_dict)] + for i, k, v in expert_dict: + num_global_experts, pieces = v[0], v[1:] + if num_global_experts % args.size == 0: + expert_dict[i] = torch.concat(pieces, dim=0).contiguous().clone() + assert expert_dict[i].size(0) == num_global_experts, "Unexpected group size of expert with num_global_experts: %d v.s. %d. Maybe you set a wrong --size value." % (expert_dict[i].size(0), num_global_experts) + elif args.size % num_global_experts == 0: + expert_dict[i] = torch.concat(pieces, dim=0).contiguous() + expert_dict[i] = expert_dict[i].view([num_global_experts, -1] + list(expert_dict[i].shape)[2:]).clone() + else: + raise Exception(f'Neither of "global_experts({num_global_experts}) / args.size({args.size})" nor "args.size({args.size}) / global_experts({num_global_experts})" is evenly divisible.') + expert_dict[i] = (k, expert_dict[i]) + + expert_dict = dict(expert_dict) + for k in state_dict: + if k in expert_dict: + state_dict[k] = expert_dict[k] + torch.save(state_dict, args.output) + +if __name__ == "__main__": + main() + diff --git a/tutel/checkpoint/scatter.py b/tutel/checkpoint/scatter.py new file mode 100644 index 00000000..89e671ef --- /dev/null +++ b/tutel/checkpoint/scatter.py @@ -0,0 +1,61 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import os +import argparse +import torch +import re + +from tutel.system import apply_rank_size_from_pattern + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--output_size', type=int, required=True) + parser.add_argument('--input', type=str, required=True) + parser.add_argument('--outputs', type=str, required=True) + args = parser.parse_args() + args.size = args.output_size + + state_dict = torch.load(args.input, map_location=torch.device('cpu')) + mutate_size, expert_dict = {}, {} + + for k in state_dict: + if k.endswith('._num_global_experts'): + entry = k[:k.rindex('.')] + '.experts.' + mutate_size[entry] = int(state_dict[k]) + + if not mutate_size: + raise Exception('No any Tutel MoE layer is found, as the provided checkpoint may be in legacy format. 
You need to reload this legacy checkpoint by corresponding application, re-checkpoint model\'s state_dict and get the latest format.') + + for k in state_dict: + for e in mutate_size: + if k.startswith(e): + state = state_dict[k] + shape = state.shape + if shape[0] % args.size == 0: + state = state.view([args.size, shape[0] // args.size] + list(shape)[1:]) + elif args.size % shape[0] == 0: + divisor = args.size // shape[0] + for i in range(1, len(shape)): + if shape[i] <= 1: + continue + assert shape[i] % divisor == 0, f"The second non-squeezable dimension is to be sliced to {divisor} pieces from an parameter of shape {shape}, which isn't divisible evenly." + state = state.view([args.size,] + list(shape)[1:i] + [shape[i] // divisor,] + list(shape)[i+1:]) + else: + raise Exception(f'Neither of "global_experts({int(shape[0])}) / args.size({args.size})" nor "args.size({args.size}) / global_experts({int(shape[0])})" is evenly divisible.') + expert_dict[k] = state + + for rank in range(args.size): + generate_dict = dict() + for k in state_dict: + if k not in expert_dict: + generate_dict[k] = state_dict[k] + else: + generate_dict[k] = expert_dict[k][rank, :].contiguous().clone() + + output_file = apply_rank_size_from_pattern(args.outputs, rank=rank, size=args.size) + torch.save(generate_dict, output_file) + +if __name__ == "__main__": + main() + diff --git a/tutel/examples/helloworld.py b/tutel/examples/helloworld.py index cd0b696f..11d5535f 100755 --- a/tutel/examples/helloworld.py +++ b/tutel/examples/helloworld.py @@ -30,7 +30,7 @@ parser.add_argument('--allreduce_degree', type=int, default=1) parser.add_argument('--num_steps', type=int, default=100) parser.add_argument('--parallel_type', type=str, default='auto') -parser.add_argument('--save_load_checkpoint', default=False, action='store_true') +parser.add_argument('--checkpoint_path', type=str, default='') parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu') parser.add_argument('--use_2dh', default=False, action='store_true') parser.add_argument('--eval', default=False, action='store_true') @@ -89,12 +89,12 @@ def forward(self, input): model = ExampleModel().to(device) dist_print(model) -if args.save_load_checkpoint: - checkpoint_path = './distributed-hellworld-%d-in-%d.ckpt' % (parallel_env.global_rank, parallel_env.global_size) +if args.checkpoint_path: + checkpoint_path = system.apply_rank_size_from_pattern(args.checkpoint_path, rank=parallel_env.global_rank, size=parallel_env.global_size) if os.path.exists(checkpoint_path): model.load_state_dict(torch.load(checkpoint_path)) else: - print('Checkpoint not loaded: file `%s` is not found' % checkpoint_path) + print('Checkpoint not loaded: file `%s` is not found. Will train the model from start.' % checkpoint_path) optimizer = torch.optim.SGD(model.parameters(), lr=1e-5) @@ -145,5 +145,5 @@ def forward(self, input): average_time /= 10 dist_print('\n[Summary] Average synchronized step_time = %s sec.' 
% average_time) -if args.save_load_checkpoint: +if args.checkpoint_path: torch.save(model.state_dict(), checkpoint_path) diff --git a/tutel/examples/helloworld_ddp_tutel.py b/tutel/examples/helloworld_ddp_tutel.py index fdb1af43..df36d72d 100755 --- a/tutel/examples/helloworld_ddp_tutel.py +++ b/tutel/examples/helloworld_ddp_tutel.py @@ -30,7 +30,6 @@ parser.add_argument('--allreduce_degree', type=int, default=1) parser.add_argument('--num_steps', type=int, default=100) parser.add_argument('--parallel_type', type=str, default='auto') -parser.add_argument('--save_load_checkpoint', default=False, action='store_true') parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu') parser.add_argument('--use_2dh', default=False, action='store_true') parser.add_argument('--eval', default=False, action='store_true') @@ -88,12 +87,6 @@ def forward(self, input): model = ExampleModel().to(device) dist_print(model) -if args.save_load_checkpoint: - checkpoint_path = './distributed-hellworld-%d-in-%d.ckpt' % (parallel_env.global_rank, parallel_env.global_size) - if os.path.exists(checkpoint_path): - model.load_state_dict(torch.load(checkpoint_path)) - else: - print('Checkpoint not loaded: file `%s` is not found' % checkpoint_path) optimizer = net.TutelDistributedOptimizer(model.parameters(), group=None, average_shared=True).warp_local(torch.optim.SGD, lr=1e-5) @@ -134,6 +127,3 @@ def forward(self, input): average_time /= 10 dist_print('\n[Summary] Average synchronized step_time = %s sec.' % average_time) - -if args.save_load_checkpoint: - torch.save(model.state_dict(), checkpoint_path) diff --git a/tutel/experts/ffn.py b/tutel/experts/ffn.py index d5e2907c..1fe8a04a 100644 --- a/tutel/experts/ffn.py +++ b/tutel/experts/ffn.py @@ -29,14 +29,14 @@ def update(self, ctx): fc1_weight = torch.empty(1, local_experts, hidden_size, model_dim) fc2_weight = torch.empty(1, local_experts, hidden_size, self.output_dim) - fc1_bias = torch.empty(1, local_experts, 1, hidden_size) - fc2_bias = torch.empty(1, local_experts, 1, (self.output_dim + ctx.sharded_count - 1) // ctx.sharded_count) + fc1_bias = torch.empty(1, local_experts, hidden_size) + fc2_bias = torch.empty(1, local_experts, (self.output_dim + ctx.sharded_count - 1) // ctx.sharded_count) for i in range(local_experts): fc1 = torch.nn.Linear(model_dim, hidden_size) fc2 = torch.nn.Linear(hidden_size, self.output_dim) - fc1_weight[0, i, :, :], fc1_bias[0, i, :, :] = fc1.weight, fc1.bias - fc2_weight[0, i, :, :], fc2_bias[0, i, :, :] = fc2.weight.t(), fc2.bias[:fc2_bias.size(-1)] + fc1_weight[0, i, :, :], fc1_bias[0, i, :] = fc1.weight, fc1.bias + fc2_weight[0, i, :, :], fc2_bias[0, i, :] = fc2.weight.t(), fc2.bias[:fc2_bias.size(-1)] self.register_parameter(name='batched_fc1_w', param=torch.nn.Parameter(fc1_weight.squeeze(0))) self.register_parameter(name='batched_fc2_w', param=torch.nn.Parameter(fc2_weight.squeeze(0))) @@ -54,8 +54,8 @@ def forward(self, x, ctx): batched_fc1_w = self.batched_fc1_w batched_fc2_w = self.batched_fc2_w - batched_fc1_bias = self.batched_fc1_bias - batched_fc2_bias = self.batched_fc2_bias + batched_fc1_bias = self.batched_fc1_bias.unsqueeze(1) + batched_fc2_bias = self.batched_fc2_bias.unsqueeze(1) if ctx.ffn_zero_group is not None: if not ctx.use_model_parallel: @@ -64,7 +64,7 @@ def forward(self, x, ctx): batched_fc1_bias = zero_gather(batched_fc1_bias, group=ctx.ffn_zero_group).view(1, 1, -1) batched_fc2_bias = zero_gather(batched_fc2_bias, group=ctx.ffn_zero_group) - batched_fc2_bias = 
batched_fc2_bias.view(self.batched_fc2_bias.size(0), self.batched_fc2_bias.size(1), -1) + batched_fc2_bias = batched_fc2_bias.view(self.batched_fc2_bias.size(0), 1, -1) if batched_fc2_bias.size(-1) != self.output_dim: batched_fc2_bias = batched_fc2_bias[:, :, :self.output_dim] diff --git a/tutel/impls/communicate.py b/tutel/impls/communicate.py index 29c3db79..667069a4 100644 --- a/tutel/impls/communicate.py +++ b/tutel/impls/communicate.py @@ -170,7 +170,7 @@ def simple_reduce_scatter(input, group=None, op=torch.distributed.ReduceOp.SUM): input = input.contiguous() assert input.size(0) % world_size == 0, "Cannot evenly devide dim length %s into %s slices" % (input.size(0), world_size) if not input.is_cuda: - return simple_split(simple_all_reduce(input, group, op=op)) + return simple_split(simple_all_reduce(input, group, op=op), group=group) chunks = list(input.chunk(chunks=world_size, dim=0)) output = torch.empty_like(chunks[0]) dist.reduce_scatter(output=output, input_list=chunks, group=group, op=op) @@ -183,7 +183,7 @@ def simple_all_gather(input, group=None): input = input.contiguous() output = torch.empty([world_size, input.numel()], device=input.device, dtype=input.dtype) tensor_list = list(torch.chunk(output, chunks=world_size, dim=0)) - dist.all_gather(tensor_list=tensor_list, tensor=input, group=group) + dist.all_gather(tensor_list=tensor_list, tensor=input.view(1, -1), group=group) return output.view([-1,] + list(input.shape[1:])) class AllToAllStatus: diff --git a/tutel/impls/moe_layer.py b/tutel/impls/moe_layer.py index cc9f8c75..c8a31d50 100644 --- a/tutel/impls/moe_layer.py +++ b/tutel/impls/moe_layer.py @@ -38,6 +38,31 @@ def global_expert_count(num_local_experts, group=None): assert world_size % -num_local_experts == 0, "Excepting {-num_local_experts} devices to share an expert param, while global device count is {world_size}." return world_size // -num_local_experts + def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): + buff_name = prefix + '_num_global_experts' + if buff_name not in state_dict: + logging.warning(f"\033[31mYou are loading a legacy format of checkpoint with at least one Tutel MoE layer inside, which wouldn't support new Tutel feature allowing the number of experts per checkpoint file to mutate.\033[0m") + logging.warning(f"\033[31m The next time you overwrite it with new checkpoint, the recording format will be updated automatically.\033[0m") + logging.warning(f"\033[31m However, the new format won't be compatible with early Tutel versions, unless you force loading it with `model.load_state_dict(.., strict=False)`.\033[0m") + state_dict[buff_name] = self._num_global_experts + else: + state_experts, expect_experts = int(state_dict[buff_name]), self.num_global_experts + assert state_experts == expect_experts, "Failed to load state from checkpoint: the number of global experts mismatch (%s <- %s)" % (expect_experts, state_experts) + + for name, param in self.experts.named_parameters(): + buff_name = prefix + 'experts.' + name + assert buff_name in state_dict, "Could not find parameter `%s` in state_dict." 
% buff_name + if state_dict[buff_name].numel() == param.numel(): + state_dict[buff_name] = state_dict[buff_name].view(param.shape) + return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) + + def state_dict(self, destination=None, prefix='', keep_vars=False): + return super().state_dict(destination, prefix, keep_vars) + + @property + def num_global_experts(self): + return int(self._num_global_experts) + def __init__( self, gate_type, @@ -71,7 +96,7 @@ def __init__( self.skip_moe = (int(os.environ.get('SKIP_MOE', '0')) != 0) self.num_local_experts = experts.pop('count_per_node', 1) - self.num_global_experts = MOELayer.global_expert_count(self.num_local_experts, self.group) + self.register_buffer('_num_global_experts', torch.tensor(MOELayer.global_expert_count(self.num_local_experts, self.group))) self.world_size = C.get_world_size(self.group) if self.num_global_experts < self.world_size: diff --git a/tutel/moe.py b/tutel/moe.py index 05cefa61..3b0905ce 100644 --- a/tutel/moe.py +++ b/tutel/moe.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. -# Level-level Ops +# Low-level Ops from .jit_kernels.gating import fast_cumsum_sub_one from .impls.fast_dispatch import fast_dispatcher, extract_critical, fast_encode, fast_decode diff --git a/tutel/system.py b/tutel/system.py index 324be4a6..c6276e95 100644 --- a/tutel/system.py +++ b/tutel/system.py @@ -87,3 +87,18 @@ def load(path, device=None): import torch npv = np.load(path) return torch.tensor(npv, device=device) + + +def apply_rank_size_from_pattern(filename, rank, size, create_dir=True): + if not re.search(r'\{rank\}', filename): + logging.warning('Keyword `{rank}` is not found in file pattern: %s, which may cause collision in file access.' % filename) + + filename = re.sub(r'\{rank\}', str(rank), re.sub(r'\{size\}', str(size), filename)) + if create_dir: + filedir = os.path.dirname(filename) + if filedir: + try: + os.makedirs(filedir) + except FileExistsError: + pass + return filename
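
For reference, a minimal illustrative sketch (not part of the patch itself) of how the `{rank}`/`{size}` placeholders used by `tutel.checkpoint.gather`, `tutel.checkpoint.scatter`, and the `--checkpoint` option of `tutel.examples.helloworld` are resolved by the `apply_rank_size_from_pattern` helper added to `tutel/system.py`; the pattern string below is just an example path:
```
# Illustrative sketch only: resolve per-rank checkpoint paths from a pattern.
from tutel.system import apply_rank_size_from_pattern

# Expand the pattern for each of 8 ranks; by default the helper also creates
# the target directory (create_dir=True), so the paths are ready for torch.save().
for rank in range(8):
    path = apply_rank_size_from_pattern('./adapted-for-8-gpus/{rank}-of-{size}.ckpt', rank=rank, size=8)
    print(path)  # e.g. './adapted-for-8-gpus/0-of-8.ckpt'
```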