Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix FSDP no longer working #35212

Merged
merged 1 commit into from
Dec 13, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2251,7 +2251,7 @@ def _inner_training_loop(
else:
debug_overflow = DebugUnderflowOverflow(self.model) # noqa

delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled
delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled

# We need to reset the scheduler, as its parameters may be different on subsequent calls
if self._created_lr_scheduler:
Expand Down Expand Up @@ -2304,12 +2304,13 @@ def _inner_training_loop(
# In case of auto_find_batch_size=True
# Remove FSDP wrapping from sub-models.
self.model = unwrap_model(self.model, recursive=True)
# configure fsdp plugin for qlora if any
self._fsdp_qlora_plugin_updates()

if delay_optimizer_creation:
if use_accelerator_prepare:
self.model = self.accelerator.prepare(self.model)
# configure fsdp plugin for qlora if any
self._fsdp_qlora_plugin_updates()
if self.accelerator.mixed_precision != "fp8":
self.model = self.accelerator.prepare(self.model)
self.create_optimizer_and_scheduler(num_training_steps=max_steps)

# prepare using `accelerator` prepare
Expand Down Expand Up @@ -4187,7 +4188,7 @@ def evaluation_loop(
start_time = time.time()
model = (
self.accelerator.prepare(model)
if self.is_deepspeed_enabled or self.is_fsdp_enabled
if self.is_deepspeed_enabled or (self.is_fsdp_enabled and self.accelerator.mixed_precision != "fp8")
else self.accelerator.prepare_model(model, evaluation_mode=True)
)
self.model_preparation_time = round(time.time() - start_time, 4)
Expand Down
Loading