Skip to content

Commit

Permalink
update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
pacman100 committed Dec 15, 2023
1 parent fe81718 commit 8b6b801
Showing 1 changed file with 14 additions and 0 deletions.
14 changes: 14 additions & 0 deletions tests/fsdp/test_fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
slow,
torch_device,
)
from transformers.trainer import FSDP_MODEL_NAME
from transformers.trainer_callback import TrainerState
from transformers.trainer_utils import FSDPOption, set_seed
from transformers.utils import is_accelerate_available, is_torch_bf16_available_on_device
Expand Down Expand Up @@ -211,6 +212,19 @@ def test_training_and_can_resume_normally(self, state_dict_type):
# resume from ckpt
checkpoint = os.path.join(output_dir, "checkpoint-115")
resume_args = args + f"--resume_from_checkpoint {checkpoint}".split()

is_fsdp_ckpt = os.path.isdir(checkpoint) and (
# this checks the FSDP state dict when `SHARDED_STATE_DICT` is used
any(
FSDP_MODEL_NAME in folder_name
for folder_name in os.listdir(checkpoint)
if os.path.isdir(os.path.join(checkpoint, folder_name))
)
# this checks the FSDP state dict when `FULL_STATE_DICT` is used
or os.path.isfile(os.path.join(checkpoint, f"{FSDP_MODEL_NAME}.bin"))
)
self.assertTrue(is_fsdp_ckpt)

logs_resume = self.run_cmd_and_get_logs(
use_accelerate, sharding_strategy, launcher, script, resume_args, output_dir
)
Expand Down

0 comments on commit 8b6b801

Please sign in to comment.