exclude fsdp from delay_optimizer_creation #34140

Merged 34 commits into main from fix_fsdp_with_fp8_in_trainer on Oct 28, 2024
Commits
34 commits
cd0e8bb
exclude fsdp from delay_optimizer_creation
eljandoubi Oct 13, 2024
d18e642
add test case for trainer: FSDP mode and fp8 as mixed precision
eljandoubi Oct 14, 2024
3344110
rearrange imports
eljandoubi Oct 14, 2024
5055e2a
ruff formatted
eljandoubi Oct 14, 2024
656d7cc
Merge branch 'huggingface:main' into fix_fsdp_with_fp8_in_trainer
eljandoubi Oct 15, 2024
22cc58d
Merge branch 'huggingface:main' into fix_fsdp_with_fp8_in_trainer
eljandoubi Oct 16, 2024
4827a39
Merge branch 'huggingface:main' into fix_fsdp_with_fp8_in_trainer
eljandoubi Oct 16, 2024
4a84f0f
adapt _init_fsdp to fp8
eljandoubi Oct 16, 2024
2e91c5f
use _init_fsdp only when resume_from_checkpoint
eljandoubi Oct 16, 2024
f5a3796
Merge branch 'huggingface:main' into fix_fsdp_with_fp8_in_trainer
eljandoubi Oct 16, 2024
af73835
In case of FSDP, self.layer will be CheckpointWrapper which has no len…
eljandoubi Oct 16, 2024
a2f30b0
delete _init_fsdp
eljandoubi Oct 16, 2024
a838ba5
solve conflict
eljandoubi Oct 16, 2024
cc5b4c3
Merge branch 'huggingface:main' into fix_fsdp_with_fp8_in_trainer
eljandoubi Oct 16, 2024
d84336f
Merge branch 'huggingface:main' into fix_fsdp_with_fp8_in_trainer
eljandoubi Oct 17, 2024
acffb63
Merge branch 'huggingface:main' into fix_fsdp_with_fp8_in_trainer
eljandoubi Oct 17, 2024
78eed70
Merge branch 'huggingface:main' into fix_fsdp_with_fp8_in_trainer
eljandoubi Oct 17, 2024
49882f8
Merge branch 'huggingface:main' into fix_fsdp_with_fp8_in_trainer
eljandoubi Oct 17, 2024
9ac4664
Merge branch 'huggingface:main' into fix_fsdp_with_fp8_in_trainer
eljandoubi Oct 18, 2024
5acf8e0
Merge branch 'huggingface:main' into fix_fsdp_with_fp8_in_trainer
eljandoubi Oct 18, 2024
d7a0194
Merge branch 'huggingface:main' into fix_fsdp_with_fp8_in_trainer
eljandoubi Oct 18, 2024
58d18f6
Merge branch 'huggingface:main' into fix_fsdp_with_fp8_in_trainer
eljandoubi Oct 21, 2024
b94376d
fix conflict
eljandoubi Oct 22, 2024
b9b9eb4
Merge branch 'huggingface:main' into fix_fsdp_with_fp8_in_trainer
eljandoubi Oct 22, 2024
f466513
Merge branch 'huggingface:main' into fix_fsdp_with_fp8_in_trainer
eljandoubi Oct 23, 2024
2948b29
Merge branch 'huggingface:main' into fix_fsdp_with_fp8_in_trainer
eljandoubi Oct 24, 2024
748270d
Merge branch 'main' into fix_fsdp_with_fp8_in_trainer
eljandoubi Oct 24, 2024
0ec8e58
make fixup
eljandoubi Oct 24, 2024
09df2ed
Merge branch 'huggingface:main' into fix_fsdp_with_fp8_in_trainer
eljandoubi Oct 25, 2024
a3265d9
Merge branch 'main' into fix_fsdp_with_fp8_in_trainer
eljandoubi Oct 25, 2024
33902fd
Merge branch 'main' into fix_fsdp_with_fp8_in_trainer
eljandoubi Oct 25, 2024
cfd8152
Merge branch 'main' into fix_fsdp_with_fp8_in_trainer
eljandoubi Oct 25, 2024
571e58f
Merge branch 'main' into fix_fsdp_with_fp8_in_trainer
eljandoubi Oct 25, 2024
02a63c7
Merge branch 'main' into fix_fsdp_with_fp8_in_trainer
eljandoubi Oct 28, 2024
8 changes: 8 additions & 0 deletions src/transformers/testing_utils.py
@@ -144,6 +144,7 @@
 
 if is_accelerate_available():
     from accelerate.state import AcceleratorState, PartialState
+    from accelerate.utils.imports import is_fp8_available
 
 
 if is_pytest_available():
@@ -1000,6 +1001,13 @@ def require_torch_fp16(test_case):
     )(test_case)
 
 
+def require_fp8(test_case):
+    """Decorator marking a test that requires supports for fp8"""
+    return unittest.skipUnless(is_accelerate_available() and is_fp8_available(), "test requires fp8 support")(
+        test_case
+    )
+
+
 def require_torch_bf16(test_case):
     """Decorator marking a test that requires a device that supports bf16"""
     return unittest.skipUnless(
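
For context, a minimal sketch of how the new require_fp8 decorator could be applied in a test module. The class name, test body, and the fp8 backend mentioned in the comment are illustrative assumptions, not part of this PR.

import unittest

import torch

from transformers.testing_utils import TestCasePlus, require_fp8


class Fp8SmokeTest(TestCasePlus):
    @require_fp8
    def test_fp8_environment(self):
        # The decorator skips this test unless accelerate is installed and reports
        # fp8 support (for example via TransformerEngine); when it does run, we
        # simply assert that a CUDA device is visible.
        self.assertTrue(torch.cuda.is_available())


if __name__ == "__main__":
    unittest.main()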
7 changes: 5 additions & 2 deletions src/transformers/trainer.py
@@ -2209,7 +2209,7 @@ def _inner_training_loop(
         else:
             debug_overflow = DebugUnderflowOverflow(self.model)  # noqa
 
-        delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled
+        delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled
 
         # We need to reset the scheduler, as its parameters may be different on subsequent calls
         if self._created_lr_scheduler:
@@ -2258,9 +2258,12 @@ def _inner_training_loop(
         # FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX
         use_accelerator_prepare = True if model is self.model else False
 
+        # configure fsdp plugin for qlora if any
+        if use_accelerator_prepare:
+            self._fsdp_qlora_plugin_updates()
+
         if delay_optimizer_creation:
             if use_accelerator_prepare:
-                self._fsdp_qlora_plugin_updates()
                 self.model = self.accelerator.prepare(self.model)
             self.create_optimizer_and_scheduler(num_training_steps=max_steps)
 
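
With plain FSDP removed from delay_optimizer_creation, the optimizer is created before accelerator.prepare, presumably so that Accelerate's fp8 recipe can handle model and optimizer together. A rough, hypothetical sketch of the path this unblocks follows; the model, dataset, and arguments are illustrative assumptions, fp8-capable hardware is required, and the script is assumed to be launched the same way as the new test below.

# Hypothetical minimal script (train_fsdp_fp8.py), intended to be launched with e.g.:
#   accelerate launch --use_fsdp --mixed_precision fp8 \
#       --fsdp_transformer_layer_cls_to_wrap GPT2Block train_fsdp_fp8.py
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)


def main():
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    # Tiny illustrative dataset; any causal-LM text corpus works here.
    raw = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:1%]")
    dataset = raw.map(
        lambda batch: tokenizer(batch["text"], truncation=True, max_length=128),
        batched=True,
        remove_columns=raw.column_names,
    )

    trainer = Trainer(
        model=model,
        args=TrainingArguments(
            output_dir="fsdp_fp8_out",
            per_device_train_batch_size=1,
            max_steps=10,
            report_to="none",
        ),
        train_dataset=dataset,
        data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
    )
    # After this PR, plain FSDP no longer sets delay_optimizer_creation, so the
    # optimizer is created before accelerator.prepare (presumably letting the
    # fp8 preparation see model and optimizer together).
    trainer.train()


if __name__ == "__main__":
    main()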
32 changes: 32 additions & 0 deletions tests/trainer/test_trainer_fsdp.py
@@ -20,6 +20,8 @@
     execute_subprocess_async,
     get_torch_dist_unique_port,
     require_accelerate,
+    require_fp8,
+    require_fsdp,
     require_torch_multi_gpu,
 )
 
@@ -64,6 +66,7 @@ def __getitem__(self, i: int) -> str:
 class TestFSDPTrainer(TestCasePlus):
     @require_accelerate
     @require_torch_multi_gpu
+    @require_fsdp
     def test_trainer(self):
         output_dir = self.get_auto_remove_tmp_dir()
         cmd = [
@@ -86,6 +89,35 @@ def test_trainer(self):
         # successful return here == success - any errors would have caused an error in the sub-call
 
 
+class TestFSDPTrainerFP8(TestCasePlus):
+    @require_accelerate
+    @require_torch_multi_gpu
+    @require_fsdp
+    @require_fp8
+    def test_trainer(self):
+        output_dir = self.get_auto_remove_tmp_dir()
+        cmd = [
+            "accelerate",
+            "launch",
+            "--use_fsdp",
+            "--main_process_port",
+            f"{get_torch_dist_unique_port()}",
+            "--num_processes",
+            f"{torch.cuda.device_count()}",
+            "--mixed_precision",
+            "fp8",
+            "--fsdp_transformer_layer_cls_to_wrap",
+            "GPT2Block",
+            f"{self.test_file_dir}/test_trainer_fsdp.py",
+            "--output_dir",
+            f"{output_dir}",
+            "--report_to",
+            "none",
+        ]
+        execute_subprocess_async(cmd, env=self.get_env())
+        # successful return here == success - any errors would have caused an error in the sub-call
+
+
 if __name__ == "__main__":
     parser = HfArgumentParser((Seq2SeqTrainingArguments,))
     training_args = parser.parse_args_into_dataclasses()[0]