From 6b434d936b8d892725a0ff0020d2a41f6aa43a3e Mon Sep 17 00:00:00 2001
From: ghostplant
Date: Mon, 15 Nov 2021 08:31:01 +0000
Subject: [PATCH] fix parallel methods (#40)

---
 tutel/impls/communicate.py | 46 ++++++++++++++++++++++++++------------
 tutel/impls/moe_layer.py   |  3 ++-
 2 files changed, 34 insertions(+), 15 deletions(-)

diff --git a/tutel/impls/communicate.py b/tutel/impls/communicate.py
index 686bc612..9c6e7634 100644
--- a/tutel/impls/communicate.py
+++ b/tutel/impls/communicate.py
@@ -68,30 +68,48 @@ def backward(ctx: Any, grad_output: Tensor):
 
 class PreAllreduceSum(torch.autograd.Function):
     @staticmethod
-    def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor):
+    def forward(ctx, group, input):
         ctx.group = group
-        return input
-
+        ctx.num_nodes = get_world_size(ctx.group)
+        if ctx.num_nodes <= 1:
+            return input
+        ctx.input_shape = input.shape
+        output = torch.empty([ctx.num_nodes, input.numel()], device=input.device, dtype=input.dtype)
+        tensor_list = [x.contiguous() for x in torch.chunk(output, chunks=ctx.num_nodes, dim=0)]
+        dist.all_gather(tensor_list=tensor_list, tensor=input.contiguous())
+        output = output.view(list(input.shape[:0]) + [input.shape[0] * ctx.num_nodes] + list(input.shape[1:]))
+        return output
     @staticmethod
-    def backward(ctx: Any, grad_output: Tensor):
+    def backward(ctx, doutput):
         if get_world_size(ctx.group) <= 1:
-            return (None, grad_output)
-        dinput = torch.clone(grad_output).contiguous()
-        dist.all_reduce(dinput, op=torch.distributed.ReduceOp.SUM)
+            return (None, doutput)
+        dinput = torch.empty(ctx.input_shape, device=doutput.device, dtype=doutput.dtype)
+        chunks = [x.contiguous() for x in torch.chunk(doutput.view(ctx.num_nodes, -1), chunks=ctx.num_nodes, dim=0)]
+        dist.reduce_scatter(output=dinput, input_list=chunks)
         return (None, dinput)
 
 class PostAllreduceSum(torch.autograd.Function):
     @staticmethod
-    def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor):
-        if get_world_size(group) <= 1:
+    def forward(ctx, group, input):
+        ctx.group = group
+        ctx.num_nodes = get_world_size(ctx.group)
+        if ctx.num_nodes <= 1:
             return input
-        output = torch.clone(input).contiguous()
-        dist.all_reduce(output, op=torch.distributed.ReduceOp.SUM)
+        ctx.input_shape = input.shape
+        ctx.leading_dim = 0
+        chunks = [x.contiguous() for x in torch.chunk(input, chunks=ctx.num_nodes, dim=ctx.leading_dim)]
+        assert len(chunks) == ctx.num_nodes
+        output = torch.empty_like(chunks[0])
+        dist.reduce_scatter(output=output, input_list=list(chunks))
         return output
-
     @staticmethod
-    def backward(ctx: Any, grad_output: Tensor):
-        return (None, grad_output)
+    def backward(ctx, doutput):
+        if ctx.num_nodes <= 1:
+            return (None, doutput)
+        dinput = torch.empty(ctx.input_shape, device=doutput.device, dtype=doutput.dtype)
+        tensor_list = [x.contiguous() for x in torch.chunk(dinput, chunks=ctx.num_nodes, dim=ctx.leading_dim)]
+        dist.all_gather(tensor_list=tensor_list, tensor=doutput)
+        return (None, dinput)
 
 
 # A2A_TYPE: 0 for skip AllToAll, 1 for standard Pytorch AllToAll, 9 for standard Pytorch AllToAll with Timing
diff --git a/tutel/impls/moe_layer.py b/tutel/impls/moe_layer.py
index d6ac8f99..e2e9fd68 100644
--- a/tutel/impls/moe_layer.py
+++ b/tutel/impls/moe_layer.py
@@ -157,7 +157,8 @@ def named_parameters(self):
     def apply_on_expert_fn(self, input, expert_fn, group):
         if self.l_zero is None:
             self.l_zero = torch.tensor(0, dtype=input.dtype, device=input.device)
-        result_output = expert_fn(PreAllreduceSum.apply(group, input))
+        gathered_input = PreAllreduceSum.apply(group, input)
+        result_output = expert_fn(gathered_input)
         result_output = PostAllreduceSum.apply(group, result_output)
         return result_output, self.l_zero
 
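
The following is a minimal usage sketch of the reworked PreAllreduceSum / PostAllreduceSum pair above, kept outside the patch itself. It assumes a CUDA build of PyTorch with the NCCL backend (the Gloo backend does not provide reduce_scatter, which both classes now rely on) and a launch such as torchrun --nproc_per_node=2 example.py; the script name, the [4, 8] shard size, and the "* 2.0" stand-in for expert_fn are illustrative assumptions, not part of the commit. In forward, PreAllreduceSum all-gathers each rank's shard along dim 0 and PostAllreduceSum reduce-scatters the result back to per-rank slices; their backward passes perform the mirrored collectives.

import os

import torch
import torch.distributed as dist

from tutel.impls.communicate import PreAllreduceSum, PostAllreduceSum


def main():
    # NCCL is required: the forward/backward paths above call dist.reduce_scatter,
    # which the Gloo backend does not implement.
    dist.init_process_group('nccl')
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    torch.cuda.set_device(int(os.environ.get('LOCAL_RANK', 0)))
    group = dist.group.WORLD

    # Each rank owns a [4, 8] shard; PreAllreduceSum all-gathers the shards
    # along dim 0, so every rank sees the same [4 * world_size, 8] tensor.
    x = torch.randn(4, 8, device='cuda', requires_grad=True)
    gathered = PreAllreduceSum.apply(group, x)
    assert gathered.shape[0] == 4 * world_size

    # Stand-in for expert_fn: any per-rank computation on the gathered tensor.
    # PostAllreduceSum then reduce-scatters the result, so each rank keeps only
    # its own [4, 8] slice, summed across ranks.
    y = PostAllreduceSum.apply(group, gathered * 2.0)

    # Backward: PostAllreduceSum all-gathers the incoming gradient and
    # PreAllreduceSum reduce-scatters it back onto the local [4, 8] shard.
    y.sum().backward()
    print('rank', rank, 'grad shape:', tuple(x.grad.shape))

    dist.destroy_process_group()


if __name__ == '__main__':
    main()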