Skip to content

Commit

Permalink
Merge branch 'master' into master
Browse files Browse the repository at this point in the history
  • Loading branch information
loadams authored Dec 26, 2024
2 parents 2795563 + 85cc5f9 commit 9090a79
Showing 1 changed file with 10 additions and 11 deletions.
21 changes: 10 additions & 11 deletions deepspeed/runtime/domino/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter
import deepspeed
from deepspeed import comm as dist
import deepspeed.comm as dist
from deepspeed.accelerator import get_accelerator


Expand Down Expand Up @@ -97,7 +96,7 @@ def backward(ctx, grad_output):
return grad_output

# Async All-reduce.
handle = deepspeed.comm.all_reduce(grad_output, group=ctx.mpu.get_tensor_model_parallel_group(), async_op=True)
handle = dist.all_reduce(grad_output, group=ctx.mpu.get_tensor_model_parallel_group(), async_op=True)
ctx.handle_dic[ctx.h_id] = handle
return None, grad_output, None, None

Expand Down Expand Up @@ -249,6 +248,10 @@ def __init__(self,
output_bias=None):
super(DominoTransformerLayer, self).__init__()

if not dist.is_initialized():
dist.init_distributed()
assert dist.is_initialized(), "deepspeed.comm is not initialized!"

self.llama_model = config.llama_model
self.layer_number = layer_number
self.layer_type = layer_type
Expand Down Expand Up @@ -358,18 +361,14 @@ def forward(self, hidden_states, attention_mask, rotary_pos_emb=None):
layernorm_output0,
attention_mask,
rotary_pos_emb=rotary_pos_emb)
handle0 = deepspeed.comm.all_reduce(attention_output0,
group=self.mpu.get_tensor_model_parallel_group(),
async_op=True)
handle0 = dist.all_reduce(attention_output0, group=self.mpu.get_tensor_model_parallel_group(), async_op=True)

attention_output1, attention_bias1 = \
self.self_attention(
layernorm_output1,
attention_mask,
rotary_pos_emb=rotary_pos_emb)
handle1 = deepspeed.comm.all_reduce(attention_output1,
group=self.mpu.get_tensor_model_parallel_group(),
async_op=True)
handle1 = dist.all_reduce(attention_output1, group=self.mpu.get_tensor_model_parallel_group(), async_op=True)
handle0.wait()

# Residual0 connection.
Expand Down Expand Up @@ -413,7 +412,7 @@ def forward(self, hidden_states, attention_mask, rotary_pos_emb=None):
output0 = output0 + bias_c
output0 = self.mlp_activation_func(output0)
output0 = torch.matmul(output0, self.weight_r.t())
handle2 = deepspeed.comm.all_reduce(output0, group=self.mpu.get_tensor_model_parallel_group(), async_op=True)
handle2 = dist.all_reduce(output0, group=self.mpu.get_tensor_model_parallel_group(), async_op=True)

handle1.wait()

Expand All @@ -425,7 +424,7 @@ def forward(self, hidden_states, attention_mask, rotary_pos_emb=None):
if bias_c is not None:
output1 = output1 + bias_c
output1 = torch.matmul(output1, self.weight_r.t())
deepspeed.comm.all_reduce(output1, group=self.mpu.get_tensor_model_parallel_group())
dist.all_reduce(output1, group=self.mpu.get_tensor_model_parallel_group())

handle2.wait()

Expand Down

0 comments on commit 9090a79

Please sign in to comment.