From 7e405f7435d61bf54cd225e469f9a71ef867d147 Mon Sep 17 00:00:00 2001
From: Zach Mueller
Date: Mon, 29 Apr 2024 13:28:37 -0400
Subject: [PATCH] Rework test

---
 tests/trainer/test_trainer.py | 35 ++++++++++++++++++++---------------
 1 file changed, 20 insertions(+), 15 deletions(-)

diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 9fff7a1c9ceacc..69672f2b7164b1 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -34,6 +34,7 @@
 from huggingface_hub import HfFolder, ModelCard, delete_repo, list_repo_commits, list_repo_files
 from parameterized import parameterized
 from requests.exceptions import HTTPError
+from safetensors.torch import load_file
 
 from transformers import (
     AutoTokenizer,
@@ -1778,19 +1779,21 @@ def test_load_best_model_with_save(self):
                 save_steps=5,
                 evaluation_strategy="steps",
                 eval_steps=5,
+                max_steps=9,
             )
             trainer.train()
             # Check that we have the last known step:
             assert os.path.exists(
                 os.path.join(tmpdir, f"checkpoint-{trainer.state.max_steps}")
             ), f"Could not find checkpoint-{trainer.state.max_steps}"
-            # And then check the last multiple
-            last_multiple = trainer.state.max_steps - trainer.state.max_steps % 5
-            assert os.path.exists(
-                os.path.join(tmpdir, f"checkpoint-{last_multiple}")
-            ), f"Could not find checkpoint-{last_multiple}"
+            # And then check the last step
+            assert os.path.exists(os.path.join(tmpdir, "checkpoint-9")), "Could not find checkpoint-9"
 
         # Now test that using a limit works
+        # Should result in:
+        # - save at step 5 (but is deleted)
+        # - save at step 10 (loaded in at the end when `load_best_model=True`)
+        # - save at step 11
         with tempfile.TemporaryDirectory() as tmpdir:
             trainer = get_regression_trainer(
                 output_dir=tmpdir,
@@ -1799,21 +1802,23 @@ def test_load_best_model_with_save(self):
                 save_steps=5,
                 eval_steps=5,
                 load_best_model_at_end=True,
                 save_total_limit=2,
+                max_steps=11,
             )
             trainer.train()
             # Check that we have the last known step:
-            assert os.path.exists(
-                os.path.join(tmpdir, f"checkpoint-{trainer.state.max_steps}")
-            ), f"Could not find checkpoint-{trainer.state.max_steps}"
+            assert os.path.exists(os.path.join(tmpdir, "checkpoint-11")), "Could not find checkpoint-11"
             # And then check the last multiple
-            last_multiple = trainer.state.max_steps - trainer.state.max_steps % 5
-            assert os.path.exists(
-                os.path.join(tmpdir, f"checkpoint-{last_multiple}")
-            ), f"Could not find checkpoint-{last_multiple}"
+            assert os.path.exists(os.path.join(tmpdir, "checkpoint-10")), "Could not find checkpoint-10"
             # Finally check that we don't have an old one
-            assert not os.path.exists(
-                os.path.join(tmpdir, f"checkpoint-{trainer.state.max_steps-10}")
-            ), f"Found checkpoint-{trainer.state.max_steps-10}, limit not respected"
+            assert not os.path.exists(os.path.join(tmpdir, "checkpoint-5")), "Found checkpoint-5, limit not respected"
+
+            # Finally check that the right model was loaded in, checkpoint-10
+            # this goes by the last `eval` step check to do so, so it won't be
+            # the last model *saved*
+            model_state = trainer.model.state_dict()
+            final_model_weights = load_file(os.path.join(tmpdir, "checkpoint-10", "model.safetensors"))
+            for k, v in model_state.items():
+                assert torch.allclose(v, final_model_weights[k]), f"{k} is not the same"
 
     @require_torch_multi_accelerator
     def test_run_seq2seq_double_train_wrap_once(self):
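
Note on the checkpoints the reworked test expects to survive: with `load_best_model_at_end=True` the Trainer protects the best-metric checkpoint from rotation, so a `save_total_limit` of 2 ends up keeping the best checkpoint (step 10) plus the most recent one (step 11), while the oldest non-best checkpoint (step 5) is pruned. A minimal sketch of that rule, written for this note as an illustration rather than the Trainer's actual rotation code (`surviving_checkpoints` is a hypothetical helper):

def surviving_checkpoints(saved_steps, best_step, save_total_limit):
    """Return the checkpoint steps left on disk after rotation.

    Simplified model: the best checkpoint is never deleted; beyond that,
    only the most recent checkpoints are kept, up to save_total_limit total.
    """
    # Sort so the best checkpoint comes first, then newest-to-oldest.
    ordered = sorted(saved_steps, key=lambda s: (s == best_step, s), reverse=True)
    return sorted(ordered[:save_total_limit])


# Saves happen at steps 5 and 10 (evals at both, 10 is best) and 11 (end of training):
print(surviving_checkpoints([5, 10, 11], best_step=10, save_total_limit=2))  # -> [10, 11]

This matches the asserts above: checkpoint-10 and checkpoint-11 exist, checkpoint-5 does not, and the weights loaded at the end come from checkpoint-10.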