diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 70f4745f9a30bb..485f6cd61e0e7b 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1079,7 +1079,7 @@ def create_optimizer(self): }, ] - optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args, opt_model) + optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs(self.args, opt_model) # Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs` # e.g. for GaLore optimizer.