Commit

Fix FSDP failing
muellerzr committed Dec 11, 2024
1 parent 6181c6b commit e38294e
Showing 1 changed file with 6 additions and 5 deletions.
src/transformers/trainer.py (6 additions, 5 deletions)

@@ -2251,7 +2251,7 @@ def _inner_training_loop(
             else:
                 debug_overflow = DebugUnderflowOverflow(self.model)  # noqa
 
-        delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled
+        delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled
 
         # We need to reset the scheduler, as its parameters may be different on subsequent calls
         if self._created_lr_scheduler:
@@ -2304,12 +2304,13 @@ def _inner_training_loop(
             # In case of auto_find_batch_size=True
             # Remove FSDP wrapping from sub-models.
             self.model = unwrap_model(self.model, recursive=True)
+            # configure fsdp plugin for qlora if any
+            self._fsdp_qlora_plugin_updates()
 
         if delay_optimizer_creation:
             if use_accelerator_prepare:
-                self.model = self.accelerator.prepare(self.model)
-            # configure fsdp plugin for qlora if any
-            self._fsdp_qlora_plugin_updates()
+                if self.accelerator.mixed_precision != "fp8":
+                    self.model = self.accelerator.prepare(self.model)
             self.create_optimizer_and_scheduler(num_training_steps=max_steps)
 
         # prepare using `accelerator` prepare
@@ -4187,7 +4188,7 @@ def evaluation_loop(
             start_time = time.time()
             model = (
                 self.accelerator.prepare(model)
-                if self.is_deepspeed_enabled or self.is_fsdp_enabled
+                if self.is_deepspeed_enabled or (self.is_fsdp_enabled and self.accelerator.mixed_precision != "fp8")
                 else self.accelerator.prepare_model(model, evaluation_mode=True)
             )
             self.model_preparation_time = round(time.time() - start_time, 4)
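
For orientation, the standalone sketch below traces the preparation order that the patched conditions imply, assuming use_accelerator_prepare is true. It is hypothetical illustration code, not Trainer code: PrepFlags and prepare_order are invented names, and only the two conditions taken from the diff above (is_fsdp_enabled joining delay_optimizer_creation, and the mixed_precision != "fp8" guard around accelerator.prepare) are grounded in the commit.

# Hypothetical sketch only: PrepFlags and prepare_order do not exist in
# transformers. The two conditions below mirror the diff; everything else is
# a simplifying assumption.
from dataclasses import dataclass


@dataclass
class PrepFlags:
    is_sagemaker_mp_enabled: bool = False
    is_fsdp_xla_enabled: bool = False
    is_fsdp_enabled: bool = False
    mixed_precision: str = "no"


def prepare_order(flags: PrepFlags) -> list:
    """Return the main preparation steps, in order, under the patched logic."""
    steps = []
    # After the fix, plain FSDP (not only FSDP-XLA) also delays optimizer creation,
    # so the optimizer is built only after accelerator.prepare() has wrapped the model.
    delay_optimizer_creation = (
        flags.is_sagemaker_mp_enabled or flags.is_fsdp_xla_enabled or flags.is_fsdp_enabled
    )
    if delay_optimizer_creation:
        # The patch skips this early, model-only prepare under fp8 mixed precision.
        if flags.mixed_precision != "fp8":
            steps.append("accelerator.prepare(model)")
        steps.append("create_optimizer_and_scheduler()")
    else:
        # Assumption: in the non-delayed path the optimizer is created first and the
        # model and optimizer are prepared together later in _inner_training_loop.
        steps.append("create_optimizer_and_scheduler()")
        steps.append("accelerator.prepare(model, optimizer, ...)")
    return steps


if __name__ == "__main__":
    print(prepare_order(PrepFlags(is_fsdp_enabled=True)))
    # ['accelerator.prepare(model)', 'create_optimizer_and_scheduler()']
    print(prepare_order(PrepFlags(is_fsdp_enabled=True, mixed_precision="fp8")))
    # ['create_optimizer_and_scheduler()']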