From 8d50fda64433023e19f51789105c1202c7339976 Mon Sep 17 00:00:00 2001
From: AbdelKarim ELJANDOUBI <78537694+eljandoubi@users.noreply.github.com>
Date: Fri, 15 Nov 2024 23:00:03 +0100
Subject: [PATCH] Remove FSDP wrapping from sub-models. (#34452)

* Remove FSDP wrapping from sub-models.

* solve conflict trainer.py

* make fixup

* add unit test for fsdp_auto_wrap_policy when using auto_find_batch_size

* put back extract_model_from_parallel

* use transformers unwrap_model
---
 src/transformers/trainer.py        |  9 ++++++---
 tests/trainer/test_trainer_fsdp.py | 27 +++++++++++++++++++++++++++
 2 files changed, 33 insertions(+), 3 deletions(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index fec4bc4d6b283c..1603a4ec215557 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -66,7 +66,7 @@
 from .integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_available
 from .integrations.tpu import tpu_spmd_dataloader
 from .modelcard import TrainingSummary
-from .modeling_utils import PreTrainedModel, load_sharded_checkpoint
+from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
 from .models.auto.modeling_auto import (
     MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
     MODEL_MAPPING_NAMES,
@@ -2277,8 +2277,11 @@ def _inner_training_loop(
         # FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX
         use_accelerator_prepare = True if model is self.model else False

-        # configure fsdp plugin for qlora if any
-        if use_accelerator_prepare:
+        if use_accelerator_prepare and self.is_fsdp_enabled:
+            # In case of auto_find_batch_size=True
+            # Remove FSDP wrapping from sub-models.
+            self.model = unwrap_model(self.model, recursive=True)
+            # configure fsdp plugin for qlora if any
             self._fsdp_qlora_plugin_updates()

         if delay_optimizer_creation:
diff --git a/tests/trainer/test_trainer_fsdp.py b/tests/trainer/test_trainer_fsdp.py
index 4bcf5de04520e2..eca6a30664f045 100644
--- a/tests/trainer/test_trainer_fsdp.py
+++ b/tests/trainer/test_trainer_fsdp.py
@@ -117,6 +117,33 @@ def test_trainer(self):
         execute_subprocess_async(cmd, env=self.get_env())
         # successful return here == success - any errors would have caused an error in the sub-call


+class TestFSDPTrainerWrap(TestCasePlus):
+    @require_accelerate
+    @require_torch_multi_gpu
+    @require_fsdp
+    def test_trainer(self):
+        output_dir = self.get_auto_remove_tmp_dir()
+        cmd = [
+            "accelerate",
+            "launch",
+            "--use_fsdp",
+            "--main_process_port",
+            f"{get_torch_dist_unique_port()}",
+            "--num_processes",
+            f"{torch.cuda.device_count()}",
+            "--fsdp_transformer_layer_cls_to_wrap",
+            "GPT2Block",
+            f"{self.test_file_dir}/test_trainer_fsdp.py",
+            "--output_dir",
+            f"{output_dir}",
+            "--report_to",
+            "none",
+            "--auto_find_batch_size",
+            "True",
+        ]
+        execute_subprocess_async(cmd, env=self.get_env())
+        # successful return here == success - any errors would have caused an error in the sub-call
+
 if __name__ == "__main__":
     parser = HfArgumentParser((Seq2SeqTrainingArguments,))
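
Note on the trainer change above: when auto_find_batch_size=True retries the inner training loop after an out-of-memory error, the incoming model may still carry FSDP wrappers on its sub-modules from the previous attempt, so the patch unwraps it recursively before accelerate prepares it again. The sketch below only illustrates that idea; it is not the transformers unwrap_model implementation, and the helper name recursively_unwrap plus the assumption that only DDP/FSDP-style wrappers are involved are illustrative.

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel
from torch.nn.parallel import DistributedDataParallel

# Wrapper types that expose the wrapped model as `.module`.
_WRAPPER_TYPES = (DistributedDataParallel, FullyShardedDataParallel)


def recursively_unwrap(model: nn.Module) -> nn.Module:
    # Peel any wrapper layers off the top-level module first.
    while isinstance(model, _WRAPPER_TYPES):
        model = model.module
    # Then replace wrapped children in place so nested sub-models come out
    # clean and can safely be re-wrapped by a later accelerator.prepare() call.
    for name, child in model.named_children():
        unwrapped = recursively_unwrap(child)
        if unwrapped is not child:
            setattr(model, name, unwrapped)
    return model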