diff --git a/tests/fsdp/test_fsdp.py b/tests/fsdp/test_fsdp.py
index 2a9473c862ffa9..c2d539b3e99e75 100644
--- a/tests/fsdp/test_fsdp.py
+++ b/tests/fsdp/test_fsdp.py
@@ -34,6 +34,7 @@
     slow,
     torch_device,
 )
+from transformers.trainer import FSDP_MODEL_NAME
 from transformers.trainer_callback import TrainerState
 from transformers.trainer_utils import FSDPOption, set_seed
 from transformers.utils import is_accelerate_available, is_torch_bf16_available_on_device
@@ -211,6 +212,19 @@ def test_training_and_can_resume_normally(self, state_dict_type):
         # resume from ckpt
         checkpoint = os.path.join(output_dir, "checkpoint-115")
         resume_args = args + f"--resume_from_checkpoint {checkpoint}".split()
+
+        is_fsdp_ckpt = os.path.isdir(checkpoint) and (
+            # this checks the FSDP state dict when `SHARDED_STATE_DICT` is used
+            any(
+                FSDP_MODEL_NAME in folder_name
+                for folder_name in os.listdir(checkpoint)
+                if os.path.isdir(os.path.join(checkpoint, folder_name))
+            )
+            # this checks the FSDP state dict when `FULL_STATE_DICT` is used
+            or os.path.isfile(os.path.join(checkpoint, f"{FSDP_MODEL_NAME}.bin"))
+        )
+        self.assertTrue(is_fsdp_ckpt)
+
         logs_resume = self.run_cmd_and_get_logs(
             use_accelerate, sharding_strategy, launcher, script, resume_args, output_dir
         )
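For reference, the new assertion can be read as a standalone predicate over the checkpoint directory. A minimal sketch, assuming `FSDP_MODEL_NAME` resolves to `"pytorch_model_fsdp"` (its value in `transformers.trainer`) and that a `SHARDED_STATE_DICT` checkpoint stores the sharded weights in a subfolder whose name contains that constant (e.g. `pytorch_model_fsdp_0`), while `FULL_STATE_DICT` produces a single consolidated `pytorch_model_fsdp.bin` file:

```python
import os

# Assumption: mirrors the constant imported from transformers.trainer in the test.
FSDP_MODEL_NAME = "pytorch_model_fsdp"


def is_fsdp_checkpoint(checkpoint: str) -> bool:
    """Return True if `checkpoint` looks like an FSDP checkpoint directory."""
    if not os.path.isdir(checkpoint):
        return False
    # SHARDED_STATE_DICT: sharded weights live in a subfolder whose name
    # contains FSDP_MODEL_NAME (e.g. "pytorch_model_fsdp_0").
    has_sharded_folder = any(
        FSDP_MODEL_NAME in folder_name
        for folder_name in os.listdir(checkpoint)
        if os.path.isdir(os.path.join(checkpoint, folder_name))
    )
    # FULL_STATE_DICT: a single consolidated state-dict file.
    has_full_state_dict = os.path.isfile(os.path.join(checkpoint, f"{FSDP_MODEL_NAME}.bin"))
    return has_sharded_folder or has_full_state_dict
```

The test asserts this predicate before resuming, so a resume that silently falls back to a non-FSDP checkpoint layout fails early instead of masking the regression in the training logs.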