diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 90bd991926..c9dc6b15f4 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -487,7 +487,10 @@ def create_optimizer(self): elif self.args.alternate_optimizer == "soap": from axolotl.utils.optimizers.soap import SOAP - optim_args = {} + optim_args = { + "lr": optimizer_kwargs.pop("lr"), + "eps": optimizer_kwargs.pop("eps"), + } if self.cfg.optim_args: optim_args.update(self.cfg.optim_args) @@ -1600,6 +1603,7 @@ def build(self, total_num_steps): "ao_adamw_4bit", "ao_adamw_8bit", "ao_adamw_fp8", + "soap", ]: # Set default so transformers doesn't throw training_arguments_kwargs["optim"] = "adamw_hf"