diff --git a/src/llamafactory/train/utils.py b/src/llamafactory/train/utils.py index 23834f2d0a..230fdc1eb0 100644 --- a/src/llamafactory/train/utils.py +++ b/src/llamafactory/train/utils.py @@ -379,6 +379,7 @@ def create_custom_scheduler( optimizer=optimizer_dict[param], num_warmup_steps=training_args.get_warmup_steps(num_training_steps), num_training_steps=num_training_steps, + scheduler_specific_kwargs=training_args.lr_scheduler_kwargs, ) def scheduler_hook(param: "torch.nn.Parameter"):