add peft_module_casting_to_bf16 in DPOTrainer (#1143)
* add peft_module_casting_to_bf16 in DPOTrainer

Signed-off-by: Wang, Yi A <[email protected]>

* Update trl/trainer/dpo_trainer.py

---------

Signed-off-by: Wang, Yi A <[email protected]>
Co-authored-by: Kashif Rasul <[email protected]>
sywangyi and kashif authored Dec 26, 2023
1 parent 3539f3e commit 95ec857
Showing 1 changed file with 9 additions and 1 deletion.
trl/trainer/dpo_trainer.py (9 additions, 1 deletion)
@@ -41,7 +41,13 @@
 
 from ..import_utils import is_peft_available, is_wandb_available
 from ..models import PreTrainedModelWrapper, create_reference_model
-from .utils import DPODataCollatorWithPadding, disable_dropout_in_model, pad_to_length, trl_sanitze_kwargs_for_tagging
+from .utils import (
+    DPODataCollatorWithPadding,
+    disable_dropout_in_model,
+    pad_to_length,
+    peft_module_casting_to_bf16,
+    trl_sanitze_kwargs_for_tagging,
+)
 
 
 if is_peft_available():
@@ -216,6 +222,8 @@ def make_inputs_require_grad(module, input, output):
 
             # get peft model with the given config
             model = get_peft_model(model, peft_config)
+            if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
+                peft_module_casting_to_bf16(model)
 
         # For models that use gradient_checkpointing, we need to attach a hook that enables input
         # to explicitly have `requires_grad=True`, otherwise training will either silently
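
For context, the new guard only fires when the trainer is configured for bf16 training and the base model was loaded in 4-bit. The sketch below illustrates the kind of dtype normalization a helper like peft_module_casting_to_bf16 typically performs in QLoRA-style recipes; it is an assumption for illustration, not a copy of the implementation in trl/trainer/utils.py.

import torch
import torch.nn as nn

def cast_peft_modules_to_bf16(model: nn.Module) -> None:
    # Hypothetical stand-in for trl's peft_module_casting_to_bf16: cast
    # trainable adapter weights to bf16 so they match the bf16 compute
    # dtype, while keeping normalization layers in fp32 for stability.
    for name, module in model.named_modules():
        if isinstance(module, nn.LayerNorm) or "norm" in name.lower():
            module.to(torch.float32)
        elif any(p.requires_grad for p in module.parameters(recurse=False)):
            module.to(torch.bfloat16)

Without a cast like this, LoRA adapters created on top of a 4-bit model default to fp32, and mixing fp32 adapter weights with bf16 activations can raise dtype errors or waste memory during training.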
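As a usage note, the code path touched here is exercised when a quantized base model is combined with a LoRA config and bf16 mixed precision. A minimal, hedged sketch follows; the model name, dataset, and hyperparameters are illustrative placeholders, and the DPOTrainer keyword arguments reflect the API around this commit.

from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import DPOTrainer

model_id = "facebook/opt-350m"  # illustrative base model
model = AutoModelForCausalLM.from_pretrained(model_id, load_in_4bit=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)
# assumed preference dataset with "prompt", "chosen", "rejected" columns
train_dataset = load_dataset("your-org/your-preference-dataset", split="train")

trainer = DPOTrainer(
    model,
    ref_model=None,  # with a PEFT model, the frozen base acts as the reference
    args=TrainingArguments(output_dir="dpo-out", bf16=True, per_device_train_batch_size=1),
    beta=0.1,
    train_dataset=train_dataset,
    tokenizer=tokenizer,
    peft_config=LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, task_type="CAUSAL_LM"),
)
trainer.train()  # args.bf16 plus model.is_loaded_in_4bit triggers the new cast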
