diff --git a/deepspeed/runtime/pipe/engine.py b/deepspeed/runtime/pipe/engine.py index f8e2b4725f42..dd1fd3dff5df 100644 --- a/deepspeed/runtime/pipe/engine.py +++ b/deepspeed/runtime/pipe/engine.py @@ -549,14 +549,14 @@ def _aggregate_total_loss(self): agg_loss /= self.dp_world_size assert self.global_rank in self.grid.pp_group - losses = torch.Tensor([self.dp_group_loss, agg_loss]).to(self.device) + losses = torch.stack([self.dp_group_loss, agg_loss]) if self.is_pipe_parallel: dist.broadcast(tensor=losses, src=self.global_rank, group=self.mpu.get_pipe_parallel_group()) else: # Get loss from last stage src_rank = self.grid.stage_to_global(self.num_stages - 1) assert src_rank in self.grid.pp_group - losses = torch.stack([self.dp_group_loss, agg_loss]) + losses = torch.Tensor([0., 0.]).to(self.device) dist.broadcast(tensor=losses, src=src_rank, group=self.grid.get_pipe_parallel_group()) self.dp_group_loss = losses[0].clone().detach() agg_loss = losses[1].clone().detach()