Fixing the if/else block on which the optimization should take place
nelyahu committed Oct 9, 2023
Parent 75e4ef0 · commit 932ca43
Showing 1 changed file with 2 additions and 2 deletions: deepspeed/runtime/pipe/engine.py
@@ -549,14 +549,14 @@ def _aggregate_total_loss(self):
                 agg_loss /= self.dp_world_size
 
             assert self.global_rank in self.grid.pp_group
-            losses = torch.Tensor([self.dp_group_loss, agg_loss]).to(self.device)
+            losses = torch.stack([self.dp_group_loss, agg_loss])
             if self.is_pipe_parallel:
                 dist.broadcast(tensor=losses, src=self.global_rank, group=self.mpu.get_pipe_parallel_group())
         else:
             # Get loss from last stage
             src_rank = self.grid.stage_to_global(self.num_stages - 1)
             assert src_rank in self.grid.pp_group
-            losses = torch.stack([self.dp_group_loss, agg_loss])
+            losses = torch.Tensor([0., 0.]).to(self.device)
             dist.broadcast(tensor=losses, src=src_rank, group=self.grid.get_pipe_parallel_group())
             self.dp_group_loss = losses[0].clone().detach()
             agg_loss = losses[1].clone().detach()
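
The change moves the on-device torch.stack to the branch that actually holds the loss values (the broadcast source) and leaves a plain placeholder allocation on the receiving ranks, whose buffer is overwritten by dist.broadcast anyway. Below is a minimal sketch of the two allocation patterns outside any distributed setup; the device choice and the loss values are illustrative assumptions, not part of the commit:

import torch

# Illustrative device choice; in DeepSpeed this would be self.device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Sender side (last pipeline stage): both losses already live on `device`,
# so torch.stack builds the 2-element broadcast payload without a
# host-to-device round trip.
dp_group_loss = torch.tensor(0.25, device=device)
agg_loss = torch.tensor(0.30, device=device)
payload = torch.stack([dp_group_loss, agg_loss])

# Receiver side: the buffer contents are overwritten by the broadcast,
# so a cheap host-allocated placeholder copied to the device suffices.
recv_buffer = torch.Tensor([0., 0.]).to(device)

Before this commit the placements were swapped: the sender paid for a host-side tensor plus a .to(device) copy, while the receiver stacked values that the broadcast immediately overwrites.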