diff --git a/.github/workflows/hpu-gaudi2-nightly.yml b/.github/workflows/hpu-gaudi2-nightly.yml index 5c5caff1ebb0..c0576360cd61 100644 --- a/.github/workflows/hpu-gaudi2-nightly.yml +++ b/.github/workflows/hpu-gaudi2-nightly.yml @@ -21,7 +21,7 @@ jobs: # The type of runner that the job will run on runs-on: [self-hosted, intel, gaudi2] container: - image: vault.habana.ai/gaudi-docker/1.18.0/ubuntu22.04/habanalabs/pytorch-installer-2.4.0:latest + image: vault.habana.ai/gaudi-docker/1.19.0/ubuntu22.04/habanalabs/pytorch-installer-2.5.1:latest ports: - 80 options: --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice diff --git a/.github/workflows/hpu-gaudi2.yml b/.github/workflows/hpu-gaudi2.yml index a06f871b7c56..b8b6f3cb5502 100644 --- a/.github/workflows/hpu-gaudi2.yml +++ b/.github/workflows/hpu-gaudi2.yml @@ -39,7 +39,7 @@ jobs: # The type of runner that the job will run on runs-on: [self-hosted, intel, gaudi2] container: - image: vault.habana.ai/gaudi-docker/1.18.0/ubuntu22.04/habanalabs/pytorch-installer-2.4.0:latest + image: vault.habana.ai/gaudi-docker/1.19.0/ubuntu22.04/habanalabs/pytorch-installer-2.5.1:latest ports: - 80 options: --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice diff --git a/deepspeed/runtime/domino/transformer.py b/deepspeed/runtime/domino/transformer.py index 8eb95e49c29d..88c5494c8147 100644 --- a/deepspeed/runtime/domino/transformer.py +++ b/deepspeed/runtime/domino/transformer.py @@ -6,8 +6,7 @@ import torch import torch.nn.functional as F from torch.nn.parameter import Parameter -import deepspeed -from deepspeed import comm as dist +import deepspeed.comm as dist from deepspeed.accelerator import get_accelerator @@ -97,7 +96,7 @@ def backward(ctx, grad_output): return grad_output # Async All-reduce. - handle = deepspeed.comm.all_reduce(grad_output, group=ctx.mpu.get_tensor_model_parallel_group(), async_op=True) + handle = dist.all_reduce(grad_output, group=ctx.mpu.get_tensor_model_parallel_group(), async_op=True) ctx.handle_dic[ctx.h_id] = handle return None, grad_output, None, None @@ -249,6 +248,10 @@ def __init__(self, output_bias=None): super(DominoTransformerLayer, self).__init__() + if not dist.is_initialized(): + dist.init_distributed() + assert dist.is_initialized(), "deepspeed.comm is not initialized!" + self.llama_model = config.llama_model self.layer_number = layer_number self.layer_type = layer_type @@ -358,18 +361,14 @@ def forward(self, hidden_states, attention_mask, rotary_pos_emb=None): layernorm_output0, attention_mask, rotary_pos_emb=rotary_pos_emb) - handle0 = deepspeed.comm.all_reduce(attention_output0, - group=self.mpu.get_tensor_model_parallel_group(), - async_op=True) + handle0 = dist.all_reduce(attention_output0, group=self.mpu.get_tensor_model_parallel_group(), async_op=True) attention_output1, attention_bias1 = \ self.self_attention( layernorm_output1, attention_mask, rotary_pos_emb=rotary_pos_emb) - handle1 = deepspeed.comm.all_reduce(attention_output1, - group=self.mpu.get_tensor_model_parallel_group(), - async_op=True) + handle1 = dist.all_reduce(attention_output1, group=self.mpu.get_tensor_model_parallel_group(), async_op=True) handle0.wait() # Residual0 connection. @@ -413,7 +412,7 @@ def forward(self, hidden_states, attention_mask, rotary_pos_emb=None): output0 = output0 + bias_c output0 = self.mlp_activation_func(output0) output0 = torch.matmul(output0, self.weight_r.t()) - handle2 = deepspeed.comm.all_reduce(output0, group=self.mpu.get_tensor_model_parallel_group(), async_op=True) + handle2 = dist.all_reduce(output0, group=self.mpu.get_tensor_model_parallel_group(), async_op=True) handle1.wait() @@ -425,7 +424,7 @@ def forward(self, hidden_states, attention_mask, rotary_pos_emb=None): if bias_c is not None: output1 = output1 + bias_c output1 = torch.matmul(output1, self.weight_r.t()) - deepspeed.comm.all_reduce(output1, group=self.mpu.get_tensor_model_parallel_group()) + dist.all_reduce(output1, group=self.mpu.get_tensor_model_parallel_group()) handle2.wait()