Enable customized optimizer for DeepSpeed (huggingface#32049)
* transformers: enable custom optimizer for DeepSpeed

* transformers: modify error message

---------

Co-authored-by: datakim1201 <[email protected]>
dataKim1201 and datakim1201 authored Oct 7, 2024
1 parent 7bae833 commit 55be7c4
Showing 1 changed file with 2 additions and 2 deletions.
src/transformers/trainer.py (4 changes: 2 additions & 2 deletions)

@@ -599,11 +599,11 @@ def __init__(
                     " `Trainer`. Make sure the lines `import torch_xla.core.xla_model as xm` and"
                     " `model.to(xm.xla_device())` is performed before the optimizer creation in your script."
                 )
-        if (self.is_deepspeed_enabled or self.is_fsdp_xla_enabled or self.is_fsdp_enabled) and (
+        if (self.is_fsdp_xla_enabled or self.is_fsdp_enabled) and (
             self.optimizer is not None or self.lr_scheduler is not None
         ):
             raise RuntimeError(
-                "Passing `optimizers` is not allowed if Deepspeed or PyTorch FSDP is enabled. "
+                "Passing `optimizers` is not allowed if PyTorch FSDP is enabled. "
                 "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
             )
         default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
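
For context, a minimal sketch (not part of the commit) of what this change permits: a custom optimizer and LR scheduler can now be passed to `Trainer` via the `optimizers` argument while DeepSpeed is enabled, rather than having to subclass `Trainer`. It assumes DeepSpeed and Accelerate are installed; the model name, hyperparameters, and the "ds_config.json" path are illustrative placeholders.

import torch
from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments

# Placeholder model; any PyTorch model accepted by Trainer works the same way.
model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")

# Build the optimizer and LR scheduler outside of Trainer.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5, weight_decay=0.01)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=1000)

args = TrainingArguments(
    output_dir="out",
    per_device_train_batch_size=8,
    deepspeed="ds_config.json",  # hypothetical DeepSpeed JSON config; enables DeepSpeed
)

# Before this commit, constructing Trainer like this with DeepSpeed enabled raised the
# RuntimeError shown in the diff; after it, only FSDP still rejects passed-in optimizers.
trainer = Trainer(
    model=model,
    args=args,
    optimizers=(optimizer, scheduler),
)
# trainer.train() would then be called with a train_dataset supplied as usual.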
