diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py
index b96f8e0c8a..dd9420e01a 100644
--- a/trl/trainer/dpo_trainer.py
+++ b/trl/trainer/dpo_trainer.py
@@ -41,7 +41,13 @@
 from ..import_utils import is_peft_available, is_wandb_available
 from ..models import PreTrainedModelWrapper, create_reference_model
-from .utils import DPODataCollatorWithPadding, disable_dropout_in_model, pad_to_length, trl_sanitze_kwargs_for_tagging
+from .utils import (
+    DPODataCollatorWithPadding,
+    disable_dropout_in_model,
+    pad_to_length,
+    peft_module_casting_to_bf16,
+    trl_sanitze_kwargs_for_tagging,
+)


 if is_peft_available():
@@ -216,6 +222,8 @@ def make_inputs_require_grad(module, input, output):

             # get peft model with the given config
             model = get_peft_model(model, peft_config)
+            if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
+                peft_module_casting_to_bf16(model)

         # For models that use gradient_checkpoiting, we need to attach a hook that enables input
         # to explicitly have `requires_grad=True`, otherwise training will either silently
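
For context, the branch introduced by this diff only fires when the trainer is handed a 4-bit quantized base model, a `peft_config`, and `bf16=True` in the training arguments; in that case the freshly created PEFT model is passed through `peft_module_casting_to_bf16`. Below is a rough usage sketch (not part of the diff) of such a setup. The model name, dataset rows, and hyperparameters are illustrative placeholders, and the keyword names assume a TRL version where `beta`, `tokenizer`, and `peft_config` are still `DPOTrainer` constructor arguments.

import torch
from datasets import Dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments
from trl import DPOTrainer

model_id = "facebook/opt-350m"  # placeholder model
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16),
    device_map="auto",  # 4-bit loading expects a CUDA device
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Tiny in-memory preference dataset with the prompt/chosen/rejected columns DPOTrainer expects.
train_dataset = Dataset.from_dict(
    {
        "prompt": ["What is the capital of France?"],
        "chosen": [" Paris."],
        "rejected": [" London."],
    }
)

training_args = TrainingArguments(
    output_dir="dpo-4bit-bf16",
    per_device_train_batch_size=1,
    bf16=True,  # together with the model's `is_loaded_in_4bit` flag, this triggers peft_module_casting_to_bf16
    remove_unused_columns=False,
)

trainer = DPOTrainer(
    model=model,
    ref_model=None,  # with a PEFT model, the base model with adapters disabled serves as the reference
    beta=0.1,
    args=training_args,
    train_dataset=train_dataset,
    tokenizer=tokenizer,
    peft_config=LoraConfig(r=16, lora_alpha=32, task_type="CAUSAL_LM"),
)
trainer.train()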