diff --git a/src/transformers/integrations/deepspeed.py b/src/transformers/integrations/deepspeed.py
index d65f91aa94ca11..0eaf7db7c90127 100644
--- a/src/transformers/integrations/deepspeed.py
+++ b/src/transformers/integrations/deepspeed.py
@@ -494,12 +494,28 @@ def support_deepspeed_ulysses(module):
     original_forward = module.forward
 
-    def wrapped_forward(*args, **kwargs):
+    def wrapped_forward(
+        query_states: torch.Tensor,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        attention_mask: torch.Tensor,
+        query_length: int,
+        is_causal: bool,
+        **kwargs,
+    ):
         # lazily set if sequence parallelism is enabled to ensure deepspeed is initialized first
         if is_deepspeed_sp_enabled():
             module.sp_group_size = deepspeed_groups._get_sequence_parallel_world_size()
-        return original_forward(*args, **kwargs)
+        return original_forward(
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            query_length,
+            is_causal,
+            **kwargs,
+        )
 
     module.forward = wrapped_forward
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index e93f54830a677a..8114aa21232bae 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -3644,21 +3644,20 @@ def _finalize_inputs(
         labels: Optional[torch.LongTensor] = None,
         **model_kwargs,
     ):
+        inputs = {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "loss_mask": loss_mask,
+            "position_ids": position_ids,
+            "input_embeds": input_embeds,
+            "labels": labels,
+            **model_kwargs,
+        }
         if is_deepspeed_sp_enabled():
             ds_plugin = self.accelerator.state.deepspeed_plugin
             num_shards = ds_plugin.sequence_parallel_size
             rank = ds_plugin.sequence_parallel_rank
-            inputs = shard_inputs(
-                num_shards,
-                rank,
-                input_ids=input_ids,
-                attention_mask=attention_mask,
-                loss_mask=loss_mask,
-                position_ids=position_ids,
-                inputs_embeds=input_embeds,
-                labels=labels,
-                **model_kwargs,
-            )
+            inputs = shard_inputs(num_shards, rank, **inputs)
         return inputs
 
     def compute_loss_context_manager(self):
diff --git a/src/transformers/trainer_utils.py b/src/transformers/trainer_utils.py
index d8d5d453883595..dc793233211300 100644
--- a/src/transformers/trainer_utils.py
+++ b/src/transformers/trainer_utils.py
@@ -898,12 +898,12 @@ def shard_tensor(tensor, num_shards, rank, dim=1):
 def shard_inputs(
     num_shards=1,
     rank=0,
-    input_ids: torch.LongTensor = None,
-    attention_mask: Optional[torch.Tensor] = None,
-    loss_mask: Optional[torch.Tensor] = None,
-    position_ids: Optional[torch.LongTensor] = None,
-    inputs_embeds: Optional[torch.FloatTensor] = None,
-    labels: Optional[torch.LongTensor] = None,
+    input_ids=None,
+    attention_mask=None,
+    loss_mask=None,
+    position_ids=None,
+    input_embeds=None,
+    labels=None,
     **kwargs,
 ):
     if num_shards == 1:
@@ -912,7 +912,7 @@ def shard_inputs(
             "attention_mask": attention_mask,
             "loss_mask": loss_mask,
             "position_ids": position_ids,
-            "inputs_embeds": inputs_embeds,
+            "input_embeds": input_embeds,
             "labels": labels,
             **kwargs,
         }
@@ -922,8 +922,8 @@ def shard_inputs(
         position_ids = position_ids.unsqueeze(0).expand(input_ids.shape[0], -1)
     result = kwargs
     for key, value in zip(
-        ["input_ids", "attention_mask", "loss_mask", "position_ids", "inputs_embeds", "labels"],
-        [input_ids, attention_mask, loss_mask, position_ids, inputs_embeds, labels],
+        ["input_ids", "attention_mask", "loss_mask", "position_ids", "input_embeds", "labels"],
+        [input_ids, attention_mask, loss_mask, position_ids, input_embeds, labels],
     ):
         if value is not None:
             result[key] = shard_tensor(value, num_shards, rank)
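
For context on what `shard_inputs` / `shard_tensor` do to each batch before the model sees it, here is a minimal sketch of sequence-dimension sharding as used by DeepSpeed-Ulysses style sequence parallelism: each sequence-parallel rank keeps one contiguous chunk of the tokens along dim=1. This is an illustration, not the repository's implementation; the helper name `shard_along_sequence` and the assumption that the sequence length divides evenly by the number of shards are mine.

# Minimal sketch, not the repository's shard_tensor implementation: each
# sequence-parallel rank keeps one contiguous slice of the sequence dimension.
# Assumes seq_len is evenly divisible by num_shards.
import torch


def shard_along_sequence(tensor: torch.Tensor, num_shards: int, rank: int, dim: int = 1) -> torch.Tensor:
    shard_len = tensor.size(dim) // num_shards
    return tensor.narrow(dim, rank * shard_len, shard_len)


# With 4 sequence-parallel ranks, each rank keeps 1/4 of the tokens.
input_ids = torch.arange(32).reshape(2, 16)               # (batch=2, seq_len=16)
local_ids = shard_along_sequence(input_ids, num_shards=4, rank=1)
print(local_ids.shape)                                     # torch.Size([2, 4])

Applying the same slice to input_ids, attention_mask, loss_mask, position_ids, input_embeds, and labels (whichever are present) is what lets `_finalize_inputs` hand each rank only its local portion of the sequence when sequence parallelism is enabled.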