Allow resume_from_checkpoint to handle auto_find_batch_size (#27568)
* Fulfill request

* Add test

* Better test

* Apply suggestions from code review

Co-authored-by: amyeroberts <[email protected]>
Co-authored-by: Arthur <[email protected]>

* Better test

* Better test

* More comments

---------

Co-authored-by: amyeroberts <[email protected]>
Co-authored-by: Arthur <[email protected]>
3 people authored Dec 8, 2023
1 parent aa7ab98 commit 6757ed2
Showing 3 changed files with 47 additions and 0 deletions.
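For context, here is a minimal end-to-end sketch of the behavior this commit enables, mirroring the new test below; the model, dataset, and output directory are placeholders rather than part of the diff:

```python
from transformers import Trainer, TrainingArguments

# `model` and `train_dataset` stand in for any model/dataset pair.
args = TrainingArguments(
    "out",
    per_device_train_batch_size=16,
    auto_find_batch_size=True,  # retry with a smaller batch size on CUDA OOM
    save_steps=1,
    max_steps=2,
)

# First run: an OOM at 16 makes auto_find_batch_size drop to 8, and the
# checkpoints now record that effective batch size.
Trainer(model, args, train_dataset=train_dataset).train()

# Resumed run: the batch size of 8 is restored from the checkpoint's
# trainer_state.json instead of restarting the search at 16.
Trainer(model, args, train_dataset=train_dataset).train(resume_from_checkpoint=True)
```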
7 changes: 7 additions & 0 deletions src/transformers/trainer.py
@@ -1507,6 +1507,10 @@ def train(
and not self.is_fsdp_enabled
):
self._load_from_checkpoint(resume_from_checkpoint)
# In case of repeating the find_executable_batch_size, set `self._train_batch_size` properly
state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
if state.train_batch_size is not None:
self._train_batch_size = state.train_batch_size

# If model was re-initialized, put it on the right device and update self.model_wrapped
if model_reloaded:
@@ -1542,6 +1546,8 @@ def _inner_training_loop(
):
self.accelerator.free_memory()
self._train_batch_size = batch_size
if self.args.auto_find_batch_size:
self.state.train_batch_size = self._train_batch_size
logger.debug(f"Currently training with a batch size of: {self._train_batch_size}")
# Data loader and number of training steps
train_dataloader = self.get_train_dataloader()
@@ -1618,6 +1624,7 @@ def _inner_training_loop(

self.state = TrainerState()
self.state.is_hyper_param_search = trial is not None
self.state.train_batch_size = self._train_batch_size

# Compute absolute values for logging, eval, and save if given as ratio
if args.logging_steps is not None:
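For reference, `TRAINER_STATE_NAME` above resolves to `trainer_state.json` inside each checkpoint directory, so the restored value can be inspected directly; a small sketch (the checkpoint path is an assumption):

```python
from transformers import TrainerState

# Read the state written alongside a checkpoint (path is illustrative).
state = TrainerState.load_from_json("out/checkpoint-1/trainer_state.json")
print(state.train_batch_size)  # e.g. 8 if auto_find_batch_size backed off from 16
```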
4 changes: 4 additions & 0 deletions src/transformers/trainer_callback.py
@@ -59,6 +59,9 @@ class TrainerState:
Run an evaluation every X steps.
save_steps (`int`, *optional*, defaults to 500):
Save checkpoint every X updates steps.
train_batch_size (`int`, *optional*):
The batch size for the training dataloader. Only needed when
`auto_find_batch_size` has been used.
num_input_tokens_seen (`int`, *optional*, defaults to 0):
The number of tokens seen during training (number of input tokens, not the number of prediction tokens).
total_flos (`float`, *optional*, defaults to 0):
@@ -88,6 +91,7 @@ class TrainerState:
logging_steps: int = 500
eval_steps: int = 500
save_steps: int = 500
train_batch_size: int = None
num_train_epochs: int = 0
num_input_tokens_seen: int = 0
total_flos: float = 0
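Because `TrainerState` is serialized to JSON with the rest of the checkpoint, adding the dataclass field is enough for the value to round-trip across a resume; a quick sketch:

```python
from transformers import TrainerState

state = TrainerState()
state.train_batch_size = 8  # set by _inner_training_loop when auto_find_batch_size is enabled
state.save_to_json("trainer_state.json")

restored = TrainerState.load_from_json("trainer_state.json")
assert restored.train_batch_size == 8
```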
36 changes: 36 additions & 0 deletions tests/trainer/test_trainer.py
@@ -38,6 +38,7 @@
AutoTokenizer,
IntervalStrategy,
PretrainedConfig,
TrainerCallback,
TrainingArguments,
get_polynomial_decay_schedule_with_warmup,
is_torch_available,
@@ -1546,6 +1547,41 @@ def test_auto_batch_size_finder(self):
with patch.object(sys, "argv", testargs):
run_glue.main()

def test_auto_batch_size_with_resume_from_checkpoint(self):
train_dataset = RegressionDataset(length=128)

config = RegressionModelConfig(a=0, b=2)
model = RegressionRandomPreTrainedModel(config)

tmp_dir = self.get_auto_remove_tmp_dir()

class MockCudaOOMCallback(TrainerCallback):
def on_step_end(self, args, state, control, **kwargs):
# simulate OOM on the first step
if state.train_batch_size == 16:
raise RuntimeError("CUDA out of memory.")

args = RegressionTrainingArguments(
tmp_dir,
do_train=True,
max_steps=2,
save_steps=1,
per_device_train_batch_size=16,
auto_find_batch_size=True,
)
trainer = Trainer(model, args, train_dataset=train_dataset, callbacks=[MockCudaOOMCallback()])
trainer.train()
# After `auto_find_batch_size` has run we should now be at 8
self.assertEqual(trainer._train_batch_size, 8)

# We can then make a new Trainer
trainer = Trainer(model, args, train_dataset=train_dataset)
# Check we are at 16 to start
self.assertEqual(trainer._train_batch_size, 16)
trainer.train(resume_from_checkpoint=True)
# We should be back to 8 again, picking up where the last Trainer run left off
self.assertEqual(trainer._train_batch_size, 8)

# regression for this issue: https://github.com/huggingface/transformers/issues/12970
def test_training_with_resume_from_checkpoint_false(self):
train_dataset = RegressionDataset(length=128)
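The 16 → 8 expectation in the test comes from the halving retry loop behind `auto_find_batch_size`; a standalone sketch of that mechanism using `accelerate.utils.find_executable_batch_size` (values illustrative, not part of the diff):

```python
from accelerate.utils import find_executable_batch_size

@find_executable_batch_size(starting_batch_size=16)
def train_step(batch_size):
    # Simulate the first attempt not fitting in memory; accelerate recognizes
    # this message and retries with half the batch size.
    if batch_size == 16:
        raise RuntimeError("CUDA out of memory.")
    return batch_size

assert train_step() == 8
```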
