diff --git a/src/transformers/trainer_callback.py b/src/transformers/trainer_callback.py index a21c46fea9fe2a..d2570ed8ba4317 100644 --- a/src/transformers/trainer_callback.py +++ b/src/transformers/trainer_callback.py @@ -544,6 +544,9 @@ def on_step_end(self, args: TrainingArguments, state: TrainerState, control: Tra # End training if state.global_step >= state.max_steps: control.should_training_stop = True + # Save the model at the end if we have a save strategy + if args.save_strategy != IntervalStrategy.NO: + control.should_save = True return control diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 24bb3c7d027a25..60a81f74542b94 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -335,6 +335,9 @@ class TrainingArguments: - `"no"`: No save is done during training. - `"epoch"`: Save is done at the end of each epoch. - `"steps"`: Save is done every `save_steps`. + + If `"epoch"` or `"steps"` is chosen, saving will also be performed at the + very end of training, always. save_steps (`int` or `float`, *optional*, defaults to 500): Number of updates steps before two checkpoint saves if `save_strategy="steps"`. Should be an integer or a float in range `[0,1)`. If smaller than 1, will be interpreted as ratio of total training steps. diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index fc7c3dc7834e90..4d3fc57340054d 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -129,6 +129,7 @@ if is_safetensors_available(): import safetensors.torch + # for version specific tests in TrainerIntegrationTest require_accelerate_version_min_0_28 = partial(require_accelerate, min_version="0.28") GRAD_ACCUM_KWARGS_VERSION_AVAILABLE = is_accelerate_available("0.28") @@ -2016,6 +2017,56 @@ def test_safe_checkpoints(self): tmpdir, 5, int(self.n_epochs * 64 / self.batch_size), False, safe_weights=save_safetensors ) + def test_load_best_model_with_save(self): + with tempfile.TemporaryDirectory() as tmpdir: + trainer = get_regression_trainer( + output_dir=tmpdir, + save_steps=5, + evaluation_strategy="steps", + eval_steps=5, + max_steps=9, + ) + trainer.train() + # Check that we have the last known step: + assert os.path.exists( + os.path.join(tmpdir, f"checkpoint-{trainer.state.max_steps}") + ), f"Could not find checkpoint-{trainer.state.max_steps}" + # And then check the last step + assert os.path.exists(os.path.join(tmpdir, "checkpoint-9")), "Could not find checkpoint-9" + + # Now test that using a limit works + # Should result in: + # - save at step 5 (but is deleted) + # - save at step 10 (loaded in at the end when `load_best_model=True`) + # - save at step 11 + with tempfile.TemporaryDirectory() as tmpdir: + trainer = get_regression_trainer( + output_dir=tmpdir, + save_steps=5, + evaluation_strategy="steps", + eval_steps=5, + load_best_model_at_end=True, + save_total_limit=2, + max_steps=11, + ) + trainer.train() + # Check that we have the last known step: + assert os.path.exists(os.path.join(tmpdir, "checkpoint-11")), "Could not find checkpoint-11" + # And then check the last multiple + assert os.path.exists(os.path.join(tmpdir, "checkpoint-10")), "Could not find checkpoint-10" + # Finally check that we don't have an old one + assert not os.path.exists(os.path.join(tmpdir, "checkpoint-5")), "Found checkpoint-5, limit not respected" + + # Finally check that the right model was loaded in, checkpoint-10 + # this goes by the last `eval` step check to do so, so it won't be + # the last model *saved* + model_state = trainer.model.state_dict() + final_model_weights = safetensors.torch.load_file( + os.path.join(tmpdir, "checkpoint-10", "model.safetensors") + ) + for k, v in model_state.items(): + assert torch.allclose(v, final_model_weights[k]), f"{k} is not the same" + @require_torch_multi_accelerator def test_run_seq2seq_double_train_wrap_once(self): # test that we don't wrap the model more than once diff --git a/tests/trainer/test_trainer_callback.py b/tests/trainer/test_trainer_callback.py index 8c0c9367d8d779..9eeb1d5e412e17 100644 --- a/tests/trainer/test_trainer_callback.py +++ b/tests/trainer/test_trainer_callback.py @@ -153,7 +153,7 @@ def get_expected_events(self, trainer): expected_events.append("on_log") if trainer.args.eval_strategy == IntervalStrategy.STEPS and step % trainer.args.eval_steps == 0: expected_events += evaluation_events.copy() - if step % trainer.args.save_steps == 0: + if step % trainer.args.save_steps == 0 or step == trainer.state.max_steps: expected_events.append("on_save") expected_events.append("on_epoch_end") if trainer.args.eval_strategy == IntervalStrategy.EPOCH: