Skip to content

Commit

Permalink
Enforce saving at end of training if saving option chosen (#30160)
Browse files Browse the repository at this point in the history
* Enforce saving at end of training

* Fix test

* Rework test

* Fixup tests'

* Update comment based on sourab feedback

* Clean
  • Loading branch information
muellerzr authored and Ita Zaporozhets committed May 24, 2024
1 parent 62fa86e commit 4cf1249
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 1 deletion.
3 changes: 3 additions & 0 deletions src/transformers/trainer_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,6 +544,9 @@ def on_step_end(self, args: TrainingArguments, state: TrainerState, control: Tra
# End training
if state.global_step >= state.max_steps:
control.should_training_stop = True
# Save the model at the end if we have a save strategy
if args.save_strategy != IntervalStrategy.NO:
control.should_save = True

return control

Expand Down
3 changes: 3 additions & 0 deletions src/transformers/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,9 @@ class TrainingArguments:
- `"no"`: No save is done during training.
- `"epoch"`: Save is done at the end of each epoch.
- `"steps"`: Save is done every `save_steps`.
If `"epoch"` or `"steps"` is chosen, saving will also be performed at the
very end of training, always.
save_steps (`int` or `float`, *optional*, defaults to 500):
Number of updates steps before two checkpoint saves if `save_strategy="steps"`. Should be an integer or a
float in range `[0,1)`. If smaller than 1, will be interpreted as ratio of total training steps.
Expand Down
51 changes: 51 additions & 0 deletions tests/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@
if is_safetensors_available():
import safetensors.torch


# for version specific tests in TrainerIntegrationTest
require_accelerate_version_min_0_28 = partial(require_accelerate, min_version="0.28")
GRAD_ACCUM_KWARGS_VERSION_AVAILABLE = is_accelerate_available("0.28")
Expand Down Expand Up @@ -2016,6 +2017,56 @@ def test_safe_checkpoints(self):
tmpdir, 5, int(self.n_epochs * 64 / self.batch_size), False, safe_weights=save_safetensors
)

def test_load_best_model_with_save(self):
with tempfile.TemporaryDirectory() as tmpdir:
trainer = get_regression_trainer(
output_dir=tmpdir,
save_steps=5,
evaluation_strategy="steps",
eval_steps=5,
max_steps=9,
)
trainer.train()
# Check that we have the last known step:
assert os.path.exists(
os.path.join(tmpdir, f"checkpoint-{trainer.state.max_steps}")
), f"Could not find checkpoint-{trainer.state.max_steps}"
# And then check the last step
assert os.path.exists(os.path.join(tmpdir, "checkpoint-9")), "Could not find checkpoint-9"

# Now test that using a limit works
# Should result in:
# - save at step 5 (but is deleted)
# - save at step 10 (loaded in at the end when `load_best_model=True`)
# - save at step 11
with tempfile.TemporaryDirectory() as tmpdir:
trainer = get_regression_trainer(
output_dir=tmpdir,
save_steps=5,
evaluation_strategy="steps",
eval_steps=5,
load_best_model_at_end=True,
save_total_limit=2,
max_steps=11,
)
trainer.train()
# Check that we have the last known step:
assert os.path.exists(os.path.join(tmpdir, "checkpoint-11")), "Could not find checkpoint-11"
# And then check the last multiple
assert os.path.exists(os.path.join(tmpdir, "checkpoint-10")), "Could not find checkpoint-10"
# Finally check that we don't have an old one
assert not os.path.exists(os.path.join(tmpdir, "checkpoint-5")), "Found checkpoint-5, limit not respected"

# Finally check that the right model was loaded in, checkpoint-10
# this goes by the last `eval` step check to do so, so it won't be
# the last model *saved*
model_state = trainer.model.state_dict()
final_model_weights = safetensors.torch.load_file(
os.path.join(tmpdir, "checkpoint-10", "model.safetensors")
)
for k, v in model_state.items():
assert torch.allclose(v, final_model_weights[k]), f"{k} is not the same"

@require_torch_multi_accelerator
def test_run_seq2seq_double_train_wrap_once(self):
# test that we don't wrap the model more than once
Expand Down
2 changes: 1 addition & 1 deletion tests/trainer/test_trainer_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def get_expected_events(self, trainer):
expected_events.append("on_log")
if trainer.args.eval_strategy == IntervalStrategy.STEPS and step % trainer.args.eval_steps == 0:
expected_events += evaluation_events.copy()
if step % trainer.args.save_steps == 0:
if step % trainer.args.save_steps == 0 or step == trainer.state.max_steps:
expected_events.append("on_save")
expected_events.append("on_epoch_end")
if trainer.args.eval_strategy == IntervalStrategy.EPOCH:
Expand Down

0 comments on commit 4cf1249

Please sign in to comment.