Enforce saving at end of training if saving option chosen #30160

Merged · 6 commits · May 21, 2024
Changes from all commits
3 changes: 3 additions & 0 deletions src/transformers/trainer_callback.py
@@ -544,6 +544,9 @@ def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        # End training
        if state.global_step >= state.max_steps:
            control.should_training_stop = True
            # Save the model at the end if we have a save strategy
            if args.save_strategy != IntervalStrategy.NO:
                control.should_save = True

        return control

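The added check only flips flags on the control object; the `Trainer` loop reads those flags after the callback returns and performs the actual checkpoint. Below is a minimal, standalone sketch (not the real Trainer internals) of the decision these three added lines encode; `FlowControl` and `end_of_step_flow` are hypothetical names used only for illustration.

# Sketch of the end-of-training save decision added above, under the assumption
# that save_strategy is passed as a plain string ("no", "steps", or "epoch").
from dataclasses import dataclass


@dataclass
class FlowControl:
    should_training_stop: bool = False
    should_save: bool = False


def end_of_step_flow(global_step: int, max_steps: int, save_strategy: str) -> FlowControl:
    control = FlowControl()
    if global_step >= max_steps:
        control.should_training_stop = True
        # New behavior from this PR: always checkpoint the final step
        # whenever saving is enabled at all.
        if save_strategy != "no":
            control.should_save = True
    return control


# With save_steps=5 and max_steps=9, step 9 previously triggered no save; now it does.
print(end_of_step_flow(9, 9, "steps"))  # FlowControl(should_training_stop=True, should_save=True)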
3 changes: 3 additions & 0 deletions src/transformers/training_args.py
@@ -333,6 +333,9 @@ class TrainingArguments:
- `"no"`: No save is done during training.
- `"epoch"`: Save is done at the end of each epoch.
- `"steps"`: Save is done every `save_steps`.

If `"epoch"` or `"steps"` is chosen, saving will also be performed at the
very end of training, always.
save_steps (`int` or `float`, *optional*, defaults to 500):
Number of updates steps before two checkpoint saves if `save_strategy="steps"`. Should be an integer or a
float in range `[0,1)`. If smaller than 1, will be interpreted as ratio of total training steps.
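As a quick illustration of the documented behavior, here is a hedged usage sketch (the output directory and step counts below are arbitrary examples, and `transformers` must be installed): with a `"steps"` strategy whose `save_steps` does not evenly divide `max_steps`, the final step still gets a checkpoint.

# Usage sketch: checkpoints are written at steps 500 and 1000, and (with this
# change) also at the final step 1001, because save_strategy is not "no".
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="out",
    save_strategy="steps",
    save_steps=500,
    max_steps=1001,
)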
51 changes: 51 additions & 0 deletions tests/trainer/test_trainer.py
@@ -128,6 +128,7 @@
if is_safetensors_available():
    import safetensors.torch


# for version specific tests in TrainerIntegrationTest
require_accelerate_version_min_0_28 = partial(require_accelerate, min_version="0.28")
GRAD_ACCUM_KWARGS_VERSION_AVAILABLE = is_accelerate_available("0.28")
@@ -1968,6 +1969,56 @@ def test_safe_checkpoints(self):
                    tmpdir, 5, int(self.n_epochs * 64 / self.batch_size), False, safe_weights=save_safetensors
                )

    def test_load_best_model_with_save(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            trainer = get_regression_trainer(
                output_dir=tmpdir,
                save_steps=5,
                evaluation_strategy="steps",
                eval_steps=5,
                max_steps=9,
            )
            trainer.train()
            # Check that we have the last known step:
            assert os.path.exists(
                os.path.join(tmpdir, f"checkpoint-{trainer.state.max_steps}")
), f"Could not find checkpoint-{trainer.state.max_steps}"
# And then check the last step
assert os.path.exists(os.path.join(tmpdir, "checkpoint-9")), "Could not find checkpoint-9"

        # Now test that using a limit works
        # Should result in:
        # - save at step 5 (but is deleted)
        # - save at step 10 (loaded in at the end when `load_best_model=True`)
        # - save at step 11
        with tempfile.TemporaryDirectory() as tmpdir:
            trainer = get_regression_trainer(
                output_dir=tmpdir,
                save_steps=5,
                evaluation_strategy="steps",
                eval_steps=5,
                load_best_model_at_end=True,
                save_total_limit=2,
                max_steps=11,
            )
            trainer.train()
            # Check that we have the last known step:
            assert os.path.exists(os.path.join(tmpdir, "checkpoint-11")), "Could not find checkpoint-11"
            # And then check the last multiple
            assert os.path.exists(os.path.join(tmpdir, "checkpoint-10")), "Could not find checkpoint-10"
            # Finally check that we don't have an old one
            assert not os.path.exists(os.path.join(tmpdir, "checkpoint-5")), "Found checkpoint-5, limit not respected"

            # Finally check that the right model was loaded in, checkpoint-10
            # this goes by the last `eval` step check to do so, so it won't be
            # the last model *saved*
            model_state = trainer.model.state_dict()
            final_model_weights = safetensors.torch.load_file(
                os.path.join(tmpdir, "checkpoint-10", "model.safetensors")
            )
            for k, v in model_state.items():
                assert torch.allclose(v, final_model_weights[k]), f"{k} is not the same"

    @require_torch_multi_accelerator
    def test_run_seq2seq_double_train_wrap_once(self):
        # test that we don't wrap the model more than once
2 changes: 1 addition & 1 deletion tests/trainer/test_trainer_callback.py
@@ -153,7 +153,7 @@ def get_expected_events(self, trainer):
expected_events.append("on_log")
if trainer.args.eval_strategy == IntervalStrategy.STEPS and step % trainer.args.eval_steps == 0:
expected_events += evaluation_events.copy()
if step % trainer.args.save_steps == 0:
if step % trainer.args.save_steps == 0 or step == trainer.state.max_steps:
expected_events.append("on_save")
expected_events.append("on_epoch_end")
if trainer.args.eval_strategy == IntervalStrategy.EPOCH:
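For intuition, here is a small hypothetical helper (not part of the test suite) that mirrors the updated expectation: `on_save` fires at every multiple of `save_steps` and, with this PR, also at the final step even when it is not a multiple.

# Hypothetical illustration of which step numbers produce an "on_save" event
# under the updated expectation.
def expected_save_step_numbers(max_steps: int, save_steps: int) -> list[int]:
    steps = [s for s in range(1, max_steps + 1) if s % save_steps == 0]
    if max_steps not in steps:
        steps.append(max_steps)
    return steps


print(expected_save_step_numbers(9, 5))   # [5, 9]
print(expected_save_step_numbers(10, 5))  # [5, 10]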