From d57a181163e2a16c626b6fb7fbcddd391018f3d3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Thu, 31 Oct 2024 23:10:11 +0100 Subject: [PATCH] =?UTF-8?q?=F0=9F=A7=A9=20Add=20`optimizer=5Fcls=5Fand=5Fk?= =?UTF-8?q?wargs`=20attribute=20to=20`PPOTrainer`=20and=20`RLOOTrainer`=20?= =?UTF-8?q?(#2302)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- trl/trainer/ppo_trainer.py | 1 + trl/trainer/rloo_trainer.py | 1 + 2 files changed, 2 insertions(+) diff --git a/trl/trainer/ppo_trainer.py b/trl/trainer/ppo_trainer.py index b36be8ffff..5b3b6f05f5 100644 --- a/trl/trainer/ppo_trainer.py +++ b/trl/trainer/ppo_trainer.py @@ -131,6 +131,7 @@ def __init__( self.data_collator = data_collator self.eval_dataset = eval_dataset self.optimizer, self.lr_scheduler = optimizers + self.optimizer_cls_and_kwargs = None # needed for transformers >= 4.47 ######### # calculate various batch sizes diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index 941a90e0a7..7bbd39264d 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -110,6 +110,7 @@ def __init__( self.data_collator = data_collator self.eval_dataset = eval_dataset self.optimizer, self.lr_scheduler = optimizers + self.optimizer_cls_and_kwargs = None # needed for transformers >= 4.47 ######### # calculate various batch sizes