Commit
fix code quality issues
Ronald Rogers committed Dec 17, 2024
1 parent 38ca81d commit 6ac713e
Showing 3 changed files with 37 additions and 22 deletions.
src/transformers/integrations/deepspeed.py (20 changes: 18 additions & 2 deletions)
@@ -494,12 +494,28 @@ def support_deepspeed_ulysses(module):
 
     original_forward = module.forward
 
-    def wrapped_forward(*args, **kwargs):
+    def wrapped_forward(
+        query_states: torch.Tensor,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        attention_mask: torch.Tensor,
+        query_length: int,
+        is_causal: bool,
+        **kwargs,
+    ):
         # lazily set if sequence parallelism is enabled to ensure deepspeed is initialized first
         if is_deepspeed_sp_enabled():
             module.sp_group_size = deepspeed_groups._get_sequence_parallel_world_size()
 
-        return original_forward(*args, **kwargs)
+        return original_forward(
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            query_length,
+            is_causal,
+            **kwargs,
+        )
 
     module.forward = wrapped_forward
 
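For context, the change above keeps the existing pattern of capturing a module's original forward and swapping in a wrapper, only with an explicit signature instead of bare *args/**kwargs. The standalone sketch below illustrates that pattern on a toy module; ToyAttention and sp_world_size are hypothetical stand-ins, not the library code.

# Illustrative sketch only: mirrors the forward-wrapping pattern in support_deepspeed_ulysses.
# ToyAttention and sp_world_size are hypothetical stand-ins for the real attention module
# and DeepSpeed's sequence-parallel group query.
import torch
import torch.nn as nn


def sp_world_size():
    return 2  # assumption: would come from the sequence-parallel process group


class ToyAttention(nn.Module):
    def forward(self, query_states, key_states, **kwargs):
        return query_states + key_states


def wrap_forward(module):
    original_forward = module.forward

    def wrapped_forward(query_states, key_states, **kwargs):
        # set lazily, mirroring the deferred sp_group_size assignment above
        module.sp_group_size = sp_world_size()
        return original_forward(query_states, key_states, **kwargs)

    module.forward = wrapped_forward


attn = ToyAttention()
wrap_forward(attn)
out = attn(torch.ones(1, 4), torch.ones(1, 4))
print(attn.sp_group_size, out.shape)  # 2 torch.Size([1, 4])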
src/transformers/trainer.py (21 changes: 10 additions & 11 deletions)
@@ -3644,21 +3644,20 @@ def _finalize_inputs(
         labels: Optional[torch.LongTensor] = None,
         **model_kwargs,
     ):
+        inputs = {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "loss_mask": loss_mask,
+            "position_ids": position_ids,
+            "input_embeds": input_embeds,
+            "labels": labels,
+            **model_kwargs,
+        }
         if is_deepspeed_sp_enabled():
             ds_plugin = self.accelerator.state.deepspeed_plugin
             num_shards = ds_plugin.sequence_parallel_size
             rank = ds_plugin.sequence_parallel_rank
-            inputs = shard_inputs(
-                num_shards,
-                rank,
-                input_ids=input_ids,
-                attention_mask=attention_mask,
-                loss_mask=loss_mask,
-                position_ids=position_ids,
-                inputs_embeds=input_embeds,
-                labels=labels,
-                **model_kwargs,
-            )
+            inputs = shard_inputs(num_shards, rank, **inputs)
         return inputs
 
     def compute_loss_context_manager(self):
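The refactor above builds the inputs dict once and, when sequence parallelism is enabled, forwards it to shard_inputs via **-unpacking rather than repeating every keyword argument. A minimal sketch of that flow, with shard_inputs_stub as a hypothetical stand-in for trainer_utils.shard_inputs:

# Sketch only: shard_inputs_stub stands in for trainer_utils.shard_inputs and simply
# keeps a contiguous slice of each sequence for the given rank.
def shard_inputs_stub(num_shards, rank, **named_inputs):
    def split(seq):
        step = len(seq) // num_shards
        return seq[rank * step:(rank + 1) * step]

    return {k: (split(v) if v is not None else None) for k, v in named_inputs.items()}


def finalize_inputs(input_ids, attention_mask=None, labels=None, sp_enabled=False, **model_kwargs):
    # build the dict once, then return it unchanged or sharded
    inputs = {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels, **model_kwargs}
    if sp_enabled:
        inputs = shard_inputs_stub(num_shards=2, rank=0, **inputs)
    return inputs


print(finalize_inputs([10, 11, 12, 13], labels=[10, 11, 12, 13], sp_enabled=True))
# {'input_ids': [10, 11], 'attention_mask': None, 'labels': [10, 11]}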
src/transformers/trainer_utils.py (18 changes: 9 additions & 9 deletions)
@@ -898,12 +898,12 @@ def shard_tensor(tensor, num_shards, rank, dim=1):
 def shard_inputs(
     num_shards=1,
     rank=0,
-    input_ids: torch.LongTensor = None,
-    attention_mask: Optional[torch.Tensor] = None,
-    loss_mask: Optional[torch.Tensor] = None,
-    position_ids: Optional[torch.LongTensor] = None,
-    inputs_embeds: Optional[torch.FloatTensor] = None,
-    labels: Optional[torch.LongTensor] = None,
+    input_ids=None,
+    attention_mask=None,
+    loss_mask=None,
+    position_ids=None,
+    input_embeds=None,
+    labels=None,
     **kwargs,
 ):
     if num_shards == 1:
@@ -912,7 +912,7 @@
             "attention_mask": attention_mask,
             "loss_mask": loss_mask,
             "position_ids": position_ids,
-            "inputs_embeds": inputs_embeds,
+            "input_embeds": input_embeds,
             "labels": labels,
             **kwargs,
         }
@@ -922,8 +922,8 @@
         position_ids = position_ids.unsqueeze(0).expand(input_ids.shape[0], -1)
     result = kwargs
     for key, value in zip(
-        ["input_ids", "attention_mask", "loss_mask", "position_ids", "inputs_embeds", "labels"],
-        [input_ids, attention_mask, loss_mask, position_ids, inputs_embeds, labels],
+        ["input_ids", "attention_mask", "loss_mask", "position_ids", "input_embeds", "labels"],
+        [input_ids, attention_mask, loss_mask, position_ids, input_embeds, labels],
     ):
         if value is not None:
             result[key] = shard_tensor(value, num_shards, rank)
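For readers following the parameter rename, here is a short sketch of how shard_inputs' underlying helper could be exercised. It assumes shard_tensor splits its input into num_shards equal chunks along dim=1 (the sequence dimension) and returns the chunk for the given rank; the real shard_tensor body is outside this diff.

# Assumption: shard_tensor chunks evenly along the sequence dimension; the actual
# implementation in trainer_utils.py is not shown in this commit.
import torch


def shard_tensor(tensor, num_shards, rank, dim=1):
    return tensor.chunk(num_shards, dim=dim)[rank]


input_ids = torch.arange(8).reshape(1, 8)  # batch_size=1, seq_len=8
labels = input_ids.clone()

# with 2 sequence-parallel ranks, each rank keeps half of the sequence
print(shard_tensor(input_ids, num_shards=2, rank=0))  # tensor([[0, 1, 2, 3]])
print(shard_tensor(labels, num_shards=2, rank=1))     # tensor([[4, 5, 6, 7]])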
