Fix error caused by all_reduce call in domino (#6880)
Fix #6851 
Initialize the communication backend to fix the error caused by the all_reduce call
in the Domino transformer layer.
Verified correctness in a local test.

---------

Co-authored-by: Olatunji Ruwase <[email protected]>
Co-authored-by: Logan Adams <[email protected]>
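
In essence, the fix adds a lazy-initialization guard before the layer ever issues an all_reduce through deepspeed.comm. A minimal sketch of that pattern is below, assuming the process was launched as part of a distributed job; ensure_comm_ready and async_reduce are illustrative helper names, not part of the commit.

import torch
import deepspeed.comm as dist


def ensure_comm_ready():
    # Lazily bring up the deepspeed.comm backend, mirroring the guard this
    # commit adds to DominoTransformerLayer.__init__.
    if not dist.is_initialized():
        dist.init_distributed()
    assert dist.is_initialized(), "deepspeed.comm is not initialized!"


def async_reduce(tensor: torch.Tensor, group=None):
    # all_reduce is only safe once the backend is initialized.
    ensure_comm_ready()
    handle = dist.all_reduce(tensor, group=group, async_op=True)
    return handle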
3 people authored Dec 26, 2024
1 parent eea5304 commit 85cc5f9
Showing 1 changed file with 10 additions and 11 deletions.
deepspeed/runtime/domino/transformer.py (10 additions, 11 deletions)

@@ -6,8 +6,7 @@
 import torch
 import torch.nn.functional as F
 from torch.nn.parameter import Parameter
-import deepspeed
-from deepspeed import comm as dist
+import deepspeed.comm as dist
 from deepspeed.accelerator import get_accelerator
 
 
@@ -97,7 +96,7 @@ def backward(ctx, grad_output):
             return grad_output
 
         # Async All-reduce.
-        handle = deepspeed.comm.all_reduce(grad_output, group=ctx.mpu.get_tensor_model_parallel_group(), async_op=True)
+        handle = dist.all_reduce(grad_output, group=ctx.mpu.get_tensor_model_parallel_group(), async_op=True)
         ctx.handle_dic[ctx.h_id] = handle
         return None, grad_output, None, None
 
@@ -249,6 +248,10 @@ def __init__(self,
                  output_bias=None):
         super(DominoTransformerLayer, self).__init__()
 
+        if not dist.is_initialized():
+            dist.init_distributed()
+        assert dist.is_initialized(), "deepspeed.comm is not initialized!"
+
         self.llama_model = config.llama_model
         self.layer_number = layer_number
         self.layer_type = layer_type
@@ -358,18 +361,14 @@ def forward(self, hidden_states, attention_mask, rotary_pos_emb=None):
                 layernorm_output0,
                 attention_mask,
                 rotary_pos_emb=rotary_pos_emb)
-        handle0 = deepspeed.comm.all_reduce(attention_output0,
-                                            group=self.mpu.get_tensor_model_parallel_group(),
-                                            async_op=True)
+        handle0 = dist.all_reduce(attention_output0, group=self.mpu.get_tensor_model_parallel_group(), async_op=True)
 
         attention_output1, attention_bias1 = \
             self.self_attention(
                 layernorm_output1,
                 attention_mask,
                 rotary_pos_emb=rotary_pos_emb)
-        handle1 = deepspeed.comm.all_reduce(attention_output1,
-                                            group=self.mpu.get_tensor_model_parallel_group(),
-                                            async_op=True)
+        handle1 = dist.all_reduce(attention_output1, group=self.mpu.get_tensor_model_parallel_group(), async_op=True)
         handle0.wait()
 
         # Residual0 connection.
@@ -413,7 +412,7 @@ def forward(self, hidden_states, attention_mask, rotary_pos_emb=None):
             output0 = output0 + bias_c
         output0 = self.mlp_activation_func(output0)
         output0 = torch.matmul(output0, self.weight_r.t())
-        handle2 = deepspeed.comm.all_reduce(output0, group=self.mpu.get_tensor_model_parallel_group(), async_op=True)
+        handle2 = dist.all_reduce(output0, group=self.mpu.get_tensor_model_parallel_group(), async_op=True)
 
         handle1.wait()
 
@@ -425,7 +424,7 @@ def forward(self, hidden_states, attention_mask, rotary_pos_emb=None):
         if bias_c is not None:
             output1 = output1 + bias_c
         output1 = torch.matmul(output1, self.weight_r.t())
-        deepspeed.comm.all_reduce(output1, group=self.mpu.get_tensor_model_parallel_group())
+        dist.all_reduce(output1, group=self.mpu.get_tensor_model_parallel_group())
 
         handle2.wait()
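
The forward-pass changes above preserve the existing compute/communication overlap: each all_reduce is issued with async_op=True and only waited on right before its result is consumed. A rough sketch of that pattern, with illustrative names rather than the actual layer code:

import torch
import deepspeed.comm as dist


def overlapped_two_chunk_step(compute_fn, chunk0, chunk1, group=None):
    out0 = compute_fn(chunk0)
    # Start reducing chunk 0 asynchronously ...
    handle0 = dist.all_reduce(out0, group=group, async_op=True)

    # ... while chunk 1 is still being computed.
    out1 = compute_fn(chunk1)
    handle1 = dist.all_reduce(out1, group=group, async_op=True)

    handle0.wait()  # out0 is fully reduced from here on
    handle1.wait()  # out1 is fully reduced from here on
    return out0, out1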
