[DPOTrainer] Fix peft + DPO + bf16 if one uses `generate_during_eval` or pre-computed logits (#1203)

* fix peft + DPO + bf16

* fix

* revert old behaviour

* fix tests

* fix

* fix

* fix

* fix
younesbelkada authored Jan 9, 2024
1 parent a236c57 commit d116887
Showing 6 changed files with 191 additions and 41 deletions.
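For context, a minimal sketch of the scenario this commit fixes: a 4-bit base model wrapped with a PEFT adapter, trained with DPO under bf16, and generating during evaluation. The model id, LoRA config, and training arguments mirror the tests added below; the dummy dataset contents and output directory are illustrative placeholders, not part of the commit.

from datasets import Dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import DPOTrainer

model_id = "HuggingFaceM4/tiny-random-LlamaForCausalLM"
model = AutoModelForCausalLM.from_pretrained(model_id, load_in_4bit=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Tiny illustrative preference dataset with the columns DPOTrainer expects.
dummy_dataset = Dataset.from_dict(
    {
        "prompt": ["Hello", "How are you?"],
        "chosen": ["Hi, nice to meet you!", "I am fine, thank you."],
        "rejected": ["Leave me alone.", "No."],
    }
)

training_args = TrainingArguments(
    output_dir="dpo_peft_bf16",
    per_device_train_batch_size=2,
    max_steps=3,
    remove_unused_columns=False,
    gradient_accumulation_steps=4,
    learning_rate=9e-1,
    evaluation_strategy="steps",
    bf16=True,
)

trainer = DPOTrainer(
    model=model,
    ref_model=None,  # with PEFT, the reference model is recovered by disabling the adapter
    beta=0.1,
    args=training_args,
    tokenizer=tokenizer,
    train_dataset=dummy_dataset,
    eval_dataset=dummy_dataset,
    peft_config=LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, bias="none", task_type="CAUSAL_LM"),
    generate_during_eval=True,  # previously failed (or silently upcast) with 4-bit + bf16
)
trainer.train()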
125 changes: 124 additions & 1 deletion tests/test_dpo_trainer.py
@@ -22,7 +22,7 @@

from trl import DPOTrainer

from .testing_utils import require_no_wandb, require_peft
from .testing_utils import require_bitsandbytes, require_no_wandb, require_peft


class DPOTrainerTester(unittest.TestCase):
@@ -313,3 +313,126 @@ def test_dpo_lora_save(self):
AutoModelForCausalLM.from_pretrained(tmp_dir)
except OSError:
self.fail("Loading the saved peft adapter failed")

@require_peft
@require_bitsandbytes
@mark.peft_test
def test_dpo_lora_bf16_autocast_llama(self):
# Note this test only works on compute capability > 7 GPU devices
from peft import LoraConfig

model_id = "HuggingFaceM4/tiny-random-LlamaForCausalLM"
tokenizer = AutoTokenizer.from_pretrained(model_id)

lora_config = LoraConfig(
r=16,
lora_alpha=32,
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
)

# lora model
model = AutoModelForCausalLM.from_pretrained(model_id, load_in_4bit=True)

with tempfile.TemporaryDirectory() as tmp_dir:
training_args = TrainingArguments(
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_steps=3,
remove_unused_columns=False,
gradient_accumulation_steps=4,
learning_rate=9e-1,
evaluation_strategy="steps",
bf16=True,
)

dummy_dataset = self._init_dummy_dataset()

# dpo train lora model with a lora config
trainer = DPOTrainer(
model=model,
ref_model=None,
beta=0.1,
args=training_args,
tokenizer=tokenizer,
train_dataset=dummy_dataset,
eval_dataset=dummy_dataset,
peft_config=lora_config,
generate_during_eval=True,
)

# train the model
trainer.train()

# save peft adapter
trainer.save_model()

@parameterized.expand(
[
["gpt2", "sigmoid", False, False],
["gpt2", "sigmoid", False, True],
["gpt2", "sigmoid", True, False],
["gpt2", "sigmoid", True, True],
["gpt2", "ipo", False, False],
["gpt2", "ipo", False, True],
["gpt2", "ipo", True, False],
["gpt2", "ipo", True, True],
["gpt2", "kto_pair", False, False],
["gpt2", "kto_pair", False, True],
["gpt2", "kto_pair", True, False],
["gpt2", "kto_pair", True, True],
]
)
@require_bitsandbytes
@require_peft
@mark.peft_test
def test_dpo_lora_bf16_autocast(self, name, loss_type, pre_compute, gen_during_eval):
# Note this test only works on compute capability > 7 GPU devices
from peft import LoraConfig

lora_config = LoraConfig(
r=16,
lora_alpha=32,
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
)

# lora model
model = AutoModelForCausalLM.from_pretrained(self.model_id, load_in_4bit=True)

with tempfile.TemporaryDirectory() as tmp_dir:
training_args = TrainingArguments(
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_steps=3,
remove_unused_columns=False,
gradient_accumulation_steps=4,
learning_rate=9e-1,
evaluation_strategy="steps",
bf16=True,
)

dummy_dataset = self._init_dummy_dataset()

# dpo train lora model with a lora config
trainer = DPOTrainer(
model=model,
ref_model=None,
beta=0.1,
args=training_args,
tokenizer=self.tokenizer,
train_dataset=dummy_dataset,
eval_dataset=dummy_dataset,
peft_config=lora_config,
generate_during_eval=gen_during_eval,
loss_type=loss_type,
precompute_ref_log_probs=pre_compute,
)

# train the model
trainer.train()

# save peft adapter
trainer.save_model()
28 changes: 16 additions & 12 deletions tests/testing_utils.py
@@ -15,7 +15,13 @@

import torch

from trl import is_diffusers_available, is_peft_available, is_wandb_available, is_xpu_available
from trl import (
is_bitsandbytes_available,
is_diffusers_available,
is_peft_available,
is_wandb_available,
is_xpu_available,
)


def require_peft(test_case):
@@ -27,6 +33,15 @@ def require_peft(test_case):
return test_case


def require_bitsandbytes(test_case):
"""
Decorator marking a test that requires bnb. Skips the test if bnb is not available.
"""
if not is_bitsandbytes_available():
test_case = unittest.skip("test requires bnb")(test_case)
return test_case


def require_diffusers(test_case):
"""
Decorator marking a test that requires diffusers. Skips the test if diffusers is not available.
@@ -55,17 +70,6 @@ def require_no_wandb(test_case):
return require_wandb(test_case, required=False)


def require_bitsandbytes(test_case):
"""
Decorator marking a test that requires bitsandbytes. Skips the test if bitsandbytes is not available.
"""
try:
import bitsandbytes # noqa: F401
except ImportError:
test_case = unittest.skip("test requires bitsandbytes")(test_case)
return test_case


def require_torch_multi_gpu(test_case):
"""
Decorator marking a test that requires multiple GPUs. Skips the test if there aren't enough GPUs.
1 change: 1 addition & 0 deletions trl/__init__.py
@@ -6,6 +6,7 @@
from .environment import TextEnvironment, TextHistory
from .extras import BestOfNSampler
from .import_utils import (
is_bitsandbytes_available,
is_diffusers_available,
is_npu_available,
is_peft_available,
5 changes: 4 additions & 1 deletion trl/import_utils.py
@@ -63,7 +63,10 @@ def is_diffusers_available() -> bool:


def is_bitsandbytes_available() -> bool:
return importlib.util.find_spec("bitsandbytes") is not None
import torch

# bnb can be imported without GPU but is not usable.
return importlib.util.find_spec("bitsandbytes") is not None and torch.cuda.is_available()


def is_torchvision_available() -> bool:
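As a usage note, the stricter check means callers can gate bitsandbytes-dependent paths on both the package and a CUDA device with a single call. A small sketch under that assumption (the model id is reused from the tests above):

from transformers import AutoModelForCausalLM
from trl import is_bitsandbytes_available

model_id = "HuggingFaceM4/tiny-random-LlamaForCausalLM"

# Request 4-bit weights only when bitsandbytes is importable *and* CUDA is
# available, which is exactly what the updated check guarantees; otherwise
# fall back to loading the model in full precision.
if is_bitsandbytes_available():
    model = AutoModelForCausalLM.from_pretrained(model_id, load_in_4bit=True)
else:
    model = AutoModelForCausalLM.from_pretrained(model_id)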
69 changes: 44 additions & 25 deletions trl/trainer/dpo_trainer.py
@@ -193,6 +193,10 @@ def __init__(
)
ref_model = AutoModelForCausalLM.from_pretrained(ref_model, **ref_model_init_kwargs)

# Initialize this variable to False. This helps track whether `peft_module_casting_to_bf16`
# has been called, so that autocast can be applied properly when needed.
self._peft_has_been_casted_to_bf16 = False

if not is_peft_available() and peft_config is not None:
raise ValueError(
"PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
@@ -230,6 +234,8 @@ def make_inputs_require_grad(module, input, output):
model = get_peft_model(model, peft_config)
if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
peft_module_casting_to_bf16(model)
# If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
self._peft_has_been_casted_to_bf16 = True

# For models that use gradient_checkpointing, we need to attach a hook that enables input
# to explicitly have `requires_grad=True`, otherwise training will either silently
@@ -726,8 +732,10 @@ def null_ref_context(self):

def compute_reference_log_probs(self, padded_batch: Dict) -> Dict:
"""Computes log probabilities of the reference model for a single padded batch of a DPO specific dataset."""
compte_ref_context_manager = torch.cuda.amp.autocast if self._peft_has_been_casted_to_bf16 else nullcontext

# compute reference logps
with torch.no_grad():
with torch.no_grad(), compte_ref_context_manager():
if self.ref_model is None:
with self.null_ref_context():
(
@@ -1040,7 +1048,11 @@ def compute_loss(
"compute_loss is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than "
"DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
)
loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")

compute_loss_context_manager = torch.cuda.amp.autocast if self._peft_has_been_casted_to_bf16 else nullcontext

with compute_loss_context_manager():
loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")

# force log the metrics
if self.accelerator.is_main_process:
@@ -1053,35 +1065,40 @@ def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]:
def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]:
"""Generate samples from the model and reference model for the given batch of inputs."""

policy_output = model.generate(
input_ids=batch["prompt_input_ids"],
attention_mask=batch["prompt_attention_mask"],
max_length=self.max_length,
do_sample=True,
pad_token_id=self.tokenizer.pad_token_id,
)
# If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
# the torch cuda amp context manager, as some hidden states are silently cast to full precision.
generate_context_manager = nullcontext if not self._peft_has_been_casted_to_bf16 else torch.cuda.amp.autocast

with generate_context_manager():
policy_output = model.generate(
input_ids=batch["prompt_input_ids"],
attention_mask=batch["prompt_attention_mask"],
max_length=self.max_length,
do_sample=True,
pad_token_id=self.tokenizer.pad_token_id,
)

# if reference_output in batch use that otherwise use the reference model
if "reference_output" in batch:
reference_output = batch["reference_output"]
else:
if self.ref_model is None:
with self.null_ref_context():
reference_output = self.model.generate(
# if reference_output in batch use that otherwise use the reference model
if "reference_output" in batch:
reference_output = batch["reference_output"]
else:
if self.ref_model is None:
with self.null_ref_context():
reference_output = self.model.generate(
input_ids=batch["prompt_input_ids"],
attention_mask=batch["prompt_attention_mask"],
max_length=self.max_length,
do_sample=True,
pad_token_id=self.tokenizer.pad_token_id,
)
else:
reference_output = self.ref_model.generate(
input_ids=batch["prompt_input_ids"],
attention_mask=batch["prompt_attention_mask"],
max_length=self.max_length,
do_sample=True,
pad_token_id=self.tokenizer.pad_token_id,
)
else:
reference_output = self.ref_model.generate(
input_ids=batch["prompt_input_ids"],
attention_mask=batch["prompt_attention_mask"],
max_length=self.max_length,
do_sample=True,
pad_token_id=self.tokenizer.pad_token_id,
)

policy_output = pad_to_length(policy_output, self.max_length, self.tokenizer.pad_token_id)
policy_output_decoded = self.tokenizer.batch_decode(policy_output, skip_special_tokens=True)
@@ -1109,7 +1126,9 @@ def prediction_step(
else:
ignore_keys = []

with torch.no_grad():
prediction_context_manager = torch.cuda.amp.autocast if self._peft_has_been_casted_to_bf16 else nullcontext

with torch.no_grad(), prediction_context_manager():
loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="eval")

# force log the metrics
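The dpo_trainer.py changes above all apply the same pattern: when `_peft_has_been_casted_to_bf16` is set, the forward/generate call is wrapped in `torch.cuda.amp.autocast`; otherwise a no-op context manager is used. A standalone sketch of that selection logic, with a hypothetical helper name and placeholder `model`/`inputs`:

from contextlib import nullcontext

import torch


def forward_in_eval_precision(model, inputs, peft_casted_to_bf16: bool):
    # Pick torch.cuda.amp.autocast only when the PEFT adapter layers were cast
    # to bf16 (4-bit base model + args.bf16); otherwise use a no-op context.
    amp_context = torch.cuda.amp.autocast if peft_casted_to_bf16 else nullcontext
    with torch.no_grad(), amp_context():
        return model(**inputs)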
4 changes: 2 additions & 2 deletions trl/trainer/utils.py
@@ -641,9 +641,9 @@ def peft_module_casting_to_bf16(model):
for name, module in model.named_modules():
if isinstance(module, BaseTunerLayer):
module = module.to(torch.bfloat16)
if "norm" in name:
elif isinstance(module, torch.nn.LayerNorm) or "norm" in name:
module = module.to(torch.float32)
if any(x in name for x in ["lm_head", "embed_tokens", "wte", "wpe"]):
elif any(x in name for x in ["lm_head", "embed_tokens", "wte", "wpe"]):
if hasattr(module, "weight"):
if module.weight.dtype == torch.float32:
module = module.to(torch.bfloat16)
