Remove FSDP wrapping from sub-models. (#34452)
* Remove FSDP wrapping from sub-models.

* resolve merge conflict in trainer.py

* make fixup

* add unit test for fsdp_auto_wrap_policy when using auto_find_batch_size

* put back extract_model_from_parallel

* use transformers unwrap_model
eljandoubi authored Nov 15, 2024
1 parent b0c0ba7 commit 8d50fda
Showing 2 changed files with 33 additions and 3 deletions.
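
The code path touched by this commit only runs when FSDP is enabled together with auto_find_batch_size=True. A hypothetical configuration along these lines would exercise it (the fsdp and fsdp_config values below are illustrative, not taken from this commit):

# Illustrative setup only: FSDP sharding combined with auto_find_batch_size,
# the combination whose retry path this commit fixes.
from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir="out",
    fsdp="full_shard auto_wrap",  # enable FSDP through the Trainer
    fsdp_config={"transformer_layer_cls_to_wrap": ["GPT2Block"]},  # auto-wrap policy target
    auto_find_batch_size=True,  # retry with a smaller batch size on OOM
    report_to="none",
)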
9 changes: 6 additions & 3 deletions src/transformers/trainer.py
@@ -66,7 +66,7 @@
 from .integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_available
 from .integrations.tpu import tpu_spmd_dataloader
 from .modelcard import TrainingSummary
-from .modeling_utils import PreTrainedModel, load_sharded_checkpoint
+from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
 from .models.auto.modeling_auto import (
     MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
     MODEL_MAPPING_NAMES,
@@ -2277,8 +2277,11 @@ def _inner_training_loop(
         # FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX
         use_accelerator_prepare = True if model is self.model else False

-        # configure fsdp plugin for qlora if any
-        if use_accelerator_prepare:
+        if use_accelerator_prepare and self.is_fsdp_enabled:
+            # In case of auto_find_batch_size=True
+            # Remove FSDP wrapping from sub-models.
+            self.model = unwrap_model(self.model, recursive=True)
+            # configure fsdp plugin for qlora if any
             self._fsdp_qlora_plugin_updates()

         if delay_optimizer_creation:
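
The key addition is the unwrap_model(self.model, recursive=True) call: when auto_find_batch_size retries _inner_training_loop after an out-of-memory error, self.model may still carry FSDP wrapping from the failed attempt, and unwrapping it first lets accelerate apply the fsdp_auto_wrap_policy to a clean model on the retry. A minimal sketch of the idea behind recursive unwrapping (not the actual transformers implementation):

# Minimal sketch of recursive unwrapping; illustrates the idea behind
# unwrap_model(model, recursive=True), not the actual transformers code.
import torch.nn as nn


def recursively_unwrap(module: nn.Module) -> nn.Module:
    # Distributed wrappers (FSDP, DDP, ...) commonly expose the wrapped
    # module through a `module` attribute; peel those layers off first.
    while hasattr(module, "module") and isinstance(module.module, nn.Module):
        module = module.module
    # Recurse into children so FSDP-wrapped sub-models are unwrapped too.
    for name, child in list(module.named_children()):
        unwrapped = recursively_unwrap(child)
        if unwrapped is not child:
            setattr(module, name, unwrapped)
    return module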
27 changes: 27 additions & 0 deletions tests/trainer/test_trainer_fsdp.py
@@ -117,6 +117,33 @@ def test_trainer(self):
         execute_subprocess_async(cmd, env=self.get_env())
         # successful return here == success - any errors would have caused an error in the sub-call
+
+class TestFSDPTrainerWrap(TestCasePlus):
+    @require_accelerate
+    @require_torch_multi_gpu
+    @require_fsdp
+    def test_trainer(self):
+        output_dir = self.get_auto_remove_tmp_dir()
+        cmd = [
+            "accelerate",
+            "launch",
+            "--use_fsdp",
+            "--main_process_port",
+            f"{get_torch_dist_unique_port()}",
+            "--num_processes",
+            f"{torch.cuda.device_count()}",
+            "--fsdp_transformer_layer_cls_to_wrap",
+            "GPT2Block",
+            f"{self.test_file_dir}/test_trainer_fsdp.py",
+            "--output_dir",
+            f"{output_dir}",
+            "--report_to",
+            "none",
+            "--auto_find_batch_size",
+            "True",
+        ]
+        execute_subprocess_async(cmd, env=self.get_env())
+        # successful return here == success - any errors would have caused an error in the sub-call


 if __name__ == "__main__":
     parser = HfArgumentParser((Seq2SeqTrainingArguments,))
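
The new test drives this path end to end: it launches the FSDP test script through accelerate with --auto_find_batch_size True, so the inner training loop is rerun when an attempt fails. For reference, the Trainer's auto_find_batch_size support builds on accelerate's find_executable_batch_size decorator; a rough sketch of that retry pattern follows (the wrapper function and starting batch size are illustrative, not the exact Trainer internals):

# Rough sketch of the retry pattern behind auto_find_batch_size
# (based on accelerate.utils.find_executable_batch_size); the trainer
# object and starting batch size here are illustrative.
from accelerate.utils import find_executable_batch_size


def train_with_auto_batch_size(trainer, starting_batch_size=64):
    @find_executable_batch_size(starting_batch_size=starting_batch_size)
    def inner_loop(batch_size):
        # On CUDA OOM the decorator shrinks batch_size and calls inner_loop
        # again; by then the model from the failed attempt may still carry
        # FSDP wrappers, hence the unwrap before re-preparing.
        trainer._train_batch_size = batch_size
        return trainer._inner_training_loop(batch_size=batch_size)

    return inner_loop()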
