Commit de07def
exclude fsdp from delay_optimizer_creation (huggingface#34140)
* exclude fsdp from delay_optimizer_creation

* add test case for trainer: FSDP mode and fp8 as mixed precision

* rearrange imports

* ruff formatted

* adapt _init_fsdp to fp8

* use _init_fsdp only when resume_from_checkpoint

* In case of FSDP, self.layer will be a CheckpointWrapper, which has no len() method

* delete _init_fsdp

* solve conflict

* fix conflict

* make fixup
eljandoubi authored and BernardZach committed Dec 6, 2024
1 parent 1e84943 commit de07def
Showing 3 changed files with 45 additions and 2 deletions.
8 changes: 8 additions & 0 deletions src/transformers/testing_utils.py
@@ -144,6 +144,7 @@
 
 if is_accelerate_available():
     from accelerate.state import AcceleratorState, PartialState
+    from accelerate.utils.imports import is_fp8_available
 
 
 if is_pytest_available():
@@ -1000,6 +1001,13 @@ def require_torch_fp16(test_case):
     )(test_case)
 
 
+def require_fp8(test_case):
+    """Decorator marking a test that requires fp8 support"""
+    return unittest.skipUnless(is_accelerate_available() and is_fp8_available(), "test requires fp8 support")(
+        test_case
+    )
+
+
 def require_torch_bf16(test_case):
     """Decorator marking a test that requires a device that supports bf16"""
     return unittest.skipUnless(
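For reference, a minimal sketch of how the new require_fp8 decorator is intended to be used in a test module. The test class and body below are illustrative only (they are not part of this commit); the decorator simply skips the test unless accelerate is installed and fp8 support is detected.

import unittest

from transformers.testing_utils import require_fp8


class ExampleFP8Test(unittest.TestCase):
    @require_fp8
    def test_runs_only_with_fp8_support(self):
        # Reaching this body means is_accelerate_available() and is_fp8_available()
        # both returned True; otherwise unittest reports the skip reason
        # "test requires fp8 support".
        self.assertTrue(True)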
7 changes: 5 additions & 2 deletions src/transformers/trainer.py
@@ -2209,7 +2209,7 @@ def _inner_training_loop(
             else:
                 debug_overflow = DebugUnderflowOverflow(self.model)  # noqa
 
-        delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled
+        delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled
 
         # We need to reset the scheduler, as its parameters may be different on subsequent calls
         if self._created_lr_scheduler:
@@ -2258,9 +2258,12 @@ def _inner_training_loop(
         # FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX
         use_accelerator_prepare = True if model is self.model else False
 
+        # configure fsdp plugin for qlora if any
+        if use_accelerator_prepare:
+            self._fsdp_qlora_plugin_updates()
+
         if delay_optimizer_creation:
-            if use_accelerator_prepare:
-                self._fsdp_qlora_plugin_updates()
             self.model = self.accelerator.prepare(self.model)
             self.create_optimizer_and_scheduler(num_training_steps=max_steps)
 
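To make the behavioral change easier to follow, here is a standalone sketch of the condition modified above. The helper function is hypothetical (the real expression lives inline in Trainer._inner_training_loop); it just illustrates that FSDP no longer forces delayed optimizer creation, which is what lets FSDP run with fp8 mixed precision through the regular accelerate prepare path.

def should_delay_optimizer_creation(sagemaker_mp_enabled: bool, fsdp_xla_enabled: bool, fsdp_enabled: bool) -> bool:
    # Hypothetical stand-in for the inline expression in the diff above.
    # Before this commit the return value also OR-ed in `fsdp_enabled`;
    # after it, only SageMaker MP and FSDP-XLA still delay optimizer creation.
    return sagemaker_mp_enabled or fsdp_xla_enabled


# Plain FSDP (including FSDP + fp8 mixed precision) no longer delays:
assert should_delay_optimizer_creation(False, False, fsdp_enabled=True) is False
# FSDP-XLA and SageMaker MP still do:
assert should_delay_optimizer_creation(False, True, fsdp_enabled=False) is True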
32 changes: 32 additions & 0 deletions tests/trainer/test_trainer_fsdp.py
@@ -20,6 +20,8 @@
     execute_subprocess_async,
     get_torch_dist_unique_port,
     require_accelerate,
+    require_fp8,
+    require_fsdp,
     require_torch_multi_gpu,
 )

@@ -64,6 +66,7 @@ def __getitem__(self, i: int) -> str:
 class TestFSDPTrainer(TestCasePlus):
     @require_accelerate
     @require_torch_multi_gpu
+    @require_fsdp
     def test_trainer(self):
         output_dir = self.get_auto_remove_tmp_dir()
         cmd = [
@@ -86,6 +89,35 @@ def test_trainer(self):
         # successful return here == success - any errors would have caused an error in the sub-call
 
 
+class TestFSDPTrainerFP8(TestCasePlus):
+    @require_accelerate
+    @require_torch_multi_gpu
+    @require_fsdp
+    @require_fp8
+    def test_trainer(self):
+        output_dir = self.get_auto_remove_tmp_dir()
+        cmd = [
+            "accelerate",
+            "launch",
+            "--use_fsdp",
+            "--main_process_port",
+            f"{get_torch_dist_unique_port()}",
+            "--num_processes",
+            f"{torch.cuda.device_count()}",
+            "--mixed_precision",
+            "fp8",
+            "--fsdp_transformer_layer_cls_to_wrap",
+            "GPT2Block",
+            f"{self.test_file_dir}/test_trainer_fsdp.py",
+            "--output_dir",
+            f"{output_dir}",
+            "--report_to",
+            "none",
+        ]
+        execute_subprocess_async(cmd, env=self.get_env())
+        # successful return here == success - any errors would have caused an error in the sub-call
+
+
 if __name__ == "__main__":
     parser = HfArgumentParser((Seq2SeqTrainingArguments,))
     training_args = parser.parse_args_into_dataclasses()[0]
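As a usage note, one way to run just the new FP8 test locally is sketched below. This invocation is illustrative (it assumes you run from the repository root on a multi-GPU, fp8-capable machine with accelerate installed; otherwise the require_* decorators skip the test) and uses pytest's Python API so the example stays in Python.

import pytest

# Select only the new FP8 FSDP test; the test itself re-launches this file
# under `accelerate launch --use_fsdp --mixed_precision fp8` as a subprocess.
raise SystemExit(pytest.main(["-k", "TestFSDPTrainerFP8", "tests/trainer/test_trainer_fsdp.py"]))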
