From 574e68d554b1b52503e49708faa3cb88e86447fb Mon Sep 17 00:00:00 2001 From: Apoorv Khandelwal Date: Thu, 11 Jul 2024 17:13:06 -0400 Subject: [PATCH] Allow `Trainer.get_optimizer_cls_and_kwargs` to be overridden (#31875) * Change `Trainer.get_optimizer_cls_and_kwargs` to `self.` * Make `get_optimizer_cls_and_kwargs` an instance method * Fixing typo * Revert `get_optimizer_cls_and_kwargs` to staticmethod * restore newline to trainer.py eof --- src/transformers/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 70f4745f9a30bb..485f6cd61e0e7b 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1079,7 +1079,7 @@ def create_optimizer(self): }, ] - optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args, opt_model) + optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs(self.args, opt_model) # Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs` # e.g. for GaLore optimizer.