adds trainer method to finalize inputs

Ronald Rogers committed Dec 15, 2024
1 parent 29a8f8f commit d3b00ec
Showing 2 changed files with 79 additions and 8 deletions.
44 changes: 36 additions & 8 deletions src/transformers/trainer.py
@@ -142,6 +142,7 @@
     number_of_arguments,
     seed_worker,
     set_seed,
+    shard_inputs,
     speed_metrics,
 )
 from .training_args import OptimizerNames, ParallelMode, TrainingArguments
@@ -256,7 +257,7 @@
 
 if is_deepspeed_available():
     from accelerate.utils import DeepSpeedSchedulerWrapper
-    from deepspeed.utils import groups as deepspeed_groups
+    from deepspeed.utils import groups as deepspeed_mpu
 
 if is_accelerate_available("0.28.0"):
     from accelerate.utils import DataLoaderConfiguration
@@ -2323,23 +2324,21 @@ def _inner_training_loop(
 
         if delay_optimizer_creation:
             if use_accelerator_prepare:
-                self.model, train_dataloader = self.accelerator.prepare(self.model, train_dataloader)
+                self.model = self.accelerator.prepare(self.model)
             self.create_optimizer_and_scheduler(num_training_steps=max_steps)
 
         # prepare using `accelerator` prepare
         if use_accelerator_prepare:
             self.model.train()
             if hasattr(self.lr_scheduler, "step"):
                 if self.use_apex:
-                    model, train_dataloader = self.accelerator.prepare(self.model, train_dataloader)
+                    model = self.accelerator.prepare(self.model)
                 else:
-                    model, self.optimizer, train_dataloader = self.accelerator.prepare(
-                        self.model, self.optimizer, train_dataloader
-                    )
+                    model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
             else:
                 # to handle cases wherein we pass "DummyScheduler" such as when it is specified in DeepSpeed config.
-                model, self.optimizer, self.lr_scheduler, train_dataloader = self.accelerator.prepare(
-                    self.model, self.optimizer, self.lr_scheduler, train_dataloader
+                model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
+                    self.model, self.optimizer, self.lr_scheduler
                 )
         elif self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
             # In this case we are in DDP + LOMO, which should be supported
@@ -3617,6 +3616,33 @@ def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[s
 
         return inputs
 
+    def _finalize_inputs(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        loss_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        **model_kwargs,
+    ):
+        num_shards, rank = 1, 0
+        if is_deepspeed_sp_enabled():
+            ds_plugin = self.accelerator.state.deepspeed_plugin
+            num_shards = ds_plugin.sequence_parallel_size
+            rank = ds_plugin.sequence_parallel_rank
+        inputs = shard_inputs(
+            num_shards, rank,
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            loss_mask=loss_mask,
+            position_ids=position_ids,
+            inputs_embeds=inputs_embeds,
+            labels=labels,
+            **model_kwargs,
+        )
+        return inputs
+
     def compute_loss_context_manager(self):
         """
         A helper wrapper to group together context managers.
Expand Down Expand Up @@ -3721,7 +3747,9 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
loss_kwargs["num_items_in_batch"] = num_items_in_batch
inputs = {**inputs, **loss_kwargs}

inputs = self._finalize_inputs(**inputs)
outputs = model(**inputs)

# Save past state if it exists
# TODO: this needs to be fixed and made cleaner later.
if self.args.past_index >= 0:
Expand Down
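
For orientation, the sketch below is not the commit's code: it is a minimal, self-contained reproduction (hypothetical helper name, plain torch) of the call order this change introduces in compute_loss. The batch is routed through an input-finalization step before model(**inputs); when sequence parallelism is inactive the inputs pass through unchanged, and an active shard keeps only its contiguous slice of the sequence dimension.

# Minimal sketch of the new call order in compute_loss (hypothetical helper, plain torch).
import torch

def finalize_inputs(inputs, num_shards=1, rank=0):
    # Pass-through when sequence parallelism is off; otherwise keep this rank's
    # contiguous slice of the sequence dimension (dim=1), as shard_inputs does.
    if num_shards == 1:
        return inputs
    sub = inputs["input_ids"].shape[1] // num_shards
    return {k: v[:, rank * sub : (rank + 1) * sub] for k, v in inputs.items()}

inputs = {"input_ids": torch.arange(16).reshape(2, 8), "labels": torch.arange(16).reshape(2, 8)}
print(finalize_inputs(inputs)["input_ids"].shape)                    # torch.Size([2, 8]) -- unchanged
print(finalize_inputs(inputs, num_shards=4, rank=2)["input_ids"])    # columns 4 and 5 of each row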
43 changes: 43 additions & 0 deletions src/transformers/trainer_utils.py
@@ -885,3 +885,46 @@ def check_target_module_exists(optim_target_modules, key: str, return_is_regex:
         return target_module_found, is_regex
 
     return target_module_found
+
+
+def shard_tensor(tensor, num_shards, rank, dim=1):
+    seq_length = tensor.shape[dim]
+    sub_seq_length = seq_length // num_shards
+    indices = [slice(None)] * tensor.ndim
+    indices[dim] = slice(rank * sub_seq_length, (rank + 1) * sub_seq_length)
+    return tensor[tuple(indices)]
+
+
+def shard_inputs(
+    num_shards=1,
+    rank=0,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    loss_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    **kwargs,
+):
+    if num_shards == 1:
+        return {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "loss_mask": loss_mask,
+            "position_ids": position_ids,
+            "inputs_embeds": inputs_embeds,
+            "labels": labels,
+            **kwargs,
+        }
+    if position_ids is None and input_ids is not None:
+        # build default position_ids and expand them to the batch size so they can be sharded per rank
+        position_ids = torch.arange(input_ids.shape[1], dtype=torch.long, device=input_ids.device)
+        position_ids = position_ids.unsqueeze(0).expand(input_ids.shape[0], -1)
+    result = kwargs
+    for key, value in zip(
+        ["input_ids", "attention_mask", "loss_mask", "position_ids", "inputs_embeds", "labels"],
+        [input_ids, attention_mask, loss_mask, position_ids, inputs_embeds, labels],
+    ):
+        if value is not None:
+            result[key] = shard_tensor(value, num_shards, rank)
+    return result
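
Assuming this branch is installed so the new helpers are importable from transformers.trainer_utils, the usage sketch below shows how shard_inputs splits a batch across two sequence-parallel ranks. shard_tensor slices dim=1 into equal contiguous chunks, so the sequence length is expected to be divisible by num_shards, and position_ids are materialized before sharding so each rank keeps its global positions.

# Usage sketch for the helpers above (assumes this branch of transformers is installed).
import torch
from transformers.trainer_utils import shard_inputs

batch = {
    "input_ids": torch.arange(12).reshape(2, 6),
    "attention_mask": torch.ones(2, 6, dtype=torch.long),
}
for rank in range(2):
    shards = shard_inputs(2, rank, **batch)
    # Each rank keeps the full batch dimension but only half of the sequence;
    # entries that were None (labels, loss_mask, inputs_embeds) are dropped.
    print(rank, shards["input_ids"].shape, shards["position_ids"][0])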
