diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 009e24ade045c5..5844599071273e 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -1507,6 +1507,10 @@ def train(
             and not self.is_fsdp_enabled
         ):
             self._load_from_checkpoint(resume_from_checkpoint)
+            # In case of repeating the find_executable_batch_size, set `self._train_batch_size` properly
+            state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
+            if state.train_batch_size is not None:
+                self._train_batch_size = state.train_batch_size

         # If model was re-initialized, put it on the right device and update self.model_wrapped
         if model_reloaded:
@@ -1542,6 +1546,8 @@ def _inner_training_loop(
     ):
         self.accelerator.free_memory()
         self._train_batch_size = batch_size
+        if self.args.auto_find_batch_size:
+            self.state.train_batch_size = self._train_batch_size
         logger.debug(f"Currently training with a batch size of: {self._train_batch_size}")
         # Data loader and number of training steps
         train_dataloader = self.get_train_dataloader()
@@ -1618,6 +1624,7 @@

         self.state = TrainerState()
         self.state.is_hyper_param_search = trial is not None
+        self.state.train_batch_size = self._train_batch_size

         # Compute absolute values for logging, eval, and save if given as ratio
         if args.logging_steps is not None:
diff --git a/src/transformers/trainer_callback.py b/src/transformers/trainer_callback.py
index 13b2dcb6b0896b..7533d7219c19db 100644
--- a/src/transformers/trainer_callback.py
+++ b/src/transformers/trainer_callback.py
@@ -59,6 +59,9 @@ class TrainerState:
             Run an evaluation every X steps.
         save_steps (`int`, *optional*, defaults to 500):
             Save checkpoint every X updates steps.
+        train_batch_size (`int`, *optional*):
+            The batch size for the training dataloader. Only needed when
+            `auto_find_batch_size` has been used.
         num_input_tokens_seen (`int`, *optional*, defaults to 0):
            The number of tokens seen during training (number of input tokens, not the number of prediction tokens).
        total_flos (`float`, *optional*, defaults to 0):
@@ -88,6 +91,7 @@ class TrainerState:
     logging_steps: int = 500
     eval_steps: int = 500
     save_steps: int = 500
+    train_batch_size: int = None
     num_train_epochs: int = 0
     num_input_tokens_seen: int = 0
     total_flos: float = 0
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 129f40fc40968e..22c43071aabeac 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -38,6 +38,7 @@
     AutoTokenizer,
     IntervalStrategy,
     PretrainedConfig,
+    TrainerCallback,
     TrainingArguments,
     get_polynomial_decay_schedule_with_warmup,
     is_torch_available,
@@ -1546,6 +1547,41 @@ def test_auto_batch_size_finder(self):
         with patch.object(sys, "argv", testargs):
             run_glue.main()

+    def test_auto_batch_size_with_resume_from_checkpoint(self):
+        train_dataset = RegressionDataset(length=128)
+
+        config = RegressionModelConfig(a=0, b=2)
+        model = RegressionRandomPreTrainedModel(config)
+
+        tmp_dir = self.get_auto_remove_tmp_dir()
+
+        class MockCudaOOMCallback(TrainerCallback):
+            def on_step_end(self, args, state, control, **kwargs):
+                # Simulate an OOM error while the batch size is still at its starting value of 16
+                if state.train_batch_size == 16:
+                    raise RuntimeError("CUDA out of memory.")
+
+        args = RegressionTrainingArguments(
+            tmp_dir,
+            do_train=True,
+            max_steps=2,
+            save_steps=1,
+            per_device_train_batch_size=16,
+            auto_find_batch_size=True,
+        )
+        trainer = Trainer(model, args, train_dataset=train_dataset, callbacks=[MockCudaOOMCallback()])
+        trainer.train()
+        # After `auto_find_batch_size` has run, we should now be at 8
+        self.assertEqual(trainer._train_batch_size, 8)
+
+        # We can then make a new Trainer
+        trainer = Trainer(model, args, train_dataset=train_dataset)
+        # Check we are at 16 to start
+        self.assertEqual(trainer._train_batch_size, 16)
+        trainer.train(resume_from_checkpoint=True)
+        # We should be back to 8 again, picking up where the last Trainer run left off
+        self.assertEqual(trainer._train_batch_size, 8)
+
     # regression for this issue: https://github.com/huggingface/transformers/issues/12970
     def test_training_with_resume_from_checkpoint_false(self):
         train_dataset = RegressionDataset(length=128)
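For context, a minimal sketch of the halve-and-retry behavior that the patch persists across restarts, using `accelerate.utils.find_executable_batch_size` (the utility referenced in the comment added to `train()` above). The `run_training` function and its forced OOM are hypothetical stand-ins for `Trainer._inner_training_loop`:

```python
from accelerate.utils import find_executable_batch_size


@find_executable_batch_size(starting_batch_size=16)
def run_training(batch_size):
    # With this patch, Trainer records the batch size actually in use on
    # `self.state.train_batch_size`, so it is serialized into the
    # `trainer_state.json` saved with every checkpoint.
    if batch_size > 8:
        # The wrapper catches this, halves `batch_size`, and calls us again.
        raise RuntimeError("CUDA out of memory.")
    return batch_size


print(run_training())  # 16 raises a (simulated) OOM; the retry at 8 succeeds
```

On resume, `train()` now reads `train_batch_size` back out of the checkpoint's trainer state (via `TrainerState.load_from_json`), so the second `Trainer` in the test above continues at the batch size the first run settled on (8) rather than restarting from `per_device_train_batch_size=16`.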