Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Liger] add native liger-kernel orpo loss #2482

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@
"diffusers": ["diffusers>=0.18.0"],
"judges": ["openai>=1.23.2", "llm-blender>=0.0.2"],
# liger-kernel depends on triton, which is only available on Linux https://github.com/triton-lang/triton#compatibility
"liger": ["liger-kernel>=0.4.0; sys_platform != 'win32'"],
"liger": ["liger-kernel>=0.5.1; sys_platform != 'win32'"],
"mergekit": ["mergekit>=0.0.5.1"],
"peft": ["peft>=0.8.0"],
"quantization": ["bitsandbytes"],
Expand Down
35 changes: 34 additions & 1 deletion tests/test_orpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from datasets import load_dataset
from parameterized import parameterized
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
from transformers.testing_utils import require_peft
from transformers.testing_utils import require_liger_kernel, require_peft

from trl import ORPOConfig, ORPOTrainer
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE
Expand Down Expand Up @@ -148,3 +148,36 @@ def test_orpo_trainer_with_lora(self, config_name):
# check the params have changed - ignore 0 biases
if param.sum() != 0:
self.assertFalse(torch.equal(param, new_param))

@require_liger_kernel
def test_orpo_trainer_with_liger(self):
"""Test ORPO trainer with Liger loss enabled."""
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = ORPOConfig(

This comment was marked as outdated.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so the plot was with the same parameters with and without the liger_loss flag... rather than in general... so trying to figure out why there is a difference between the two settings...

output_dir=tmp_dir,
report_to="none",
use_liger_loss=True, # Enable Liger loss
)

dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference")

trainer = ORPOTrainer(
model=self.model,
args=training_args,
processing_class=self.tokenizer,
train_dataset=dummy_dataset["train"],
eval_dataset=dummy_dataset["test"],
)

previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

trainer.train()

self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

# check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
# check the params have changed - ignore 0 biases
if param.sum() != 0:
self.assertFalse(torch.equal(param, new_param))
7 changes: 7 additions & 0 deletions trl/trainer/orpo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@ class ORPOConfig(TrainingArguments):
string.
dataset_num_proc (`Optional[int]`, *optional*, defaults to `None`):
Number of processes to use for processing the dataset.
use_liger_loss (`bool`, *optional*, defaults to `False`):
Whether to use Liger loss.
base_model_class_name (`str`, *optional*, defaults to `"model"`):
The name of the base model class (e.g. `"model"` for `LlamaForCausalLM`) to use for skipping the
LM head when `use_liger_loss` is `True`.
"""

learning_rate: float = 1e-6
Expand All @@ -76,3 +81,5 @@ class ORPOConfig(TrainingArguments):
is_encoder_decoder: Optional[bool] = None
model_init_kwargs: Optional[dict[str, Any]] = None
dataset_num_proc: Optional[int] = None
use_liger_loss: bool = False
base_model_class_name: str = "model"
180 changes: 120 additions & 60 deletions trl/trainer/orpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
)
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import EvalLoopOutput
from transformers.utils import is_peft_available, is_torch_fx_proxy
from transformers.utils import is_liger_kernel_available, is_peft_available, is_torch_fx_proxy
from transformers.utils.deprecation import deprecate_kwarg

from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt
Expand All @@ -68,7 +68,6 @@
if is_peft_available():
from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training

kashif marked this conversation as resolved.
Show resolved Hide resolved

if is_wandb_available():
import wandb

Expand All @@ -78,6 +77,9 @@
if is_torch_xla_available():
import torch_xla.core.xla_model as xm

if is_liger_kernel_available():
from liger_kernel.chunked_loss import LigerFusedLinearORPOLoss


class ORPOTrainer(Trainer):
r"""
Expand Down Expand Up @@ -357,6 +359,15 @@ def make_inputs_require_grad(module, input, output):
"Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
)

# Import Liger loss if enabled
kashif marked this conversation as resolved.
Show resolved Hide resolved
if self.args.use_liger_loss:
if not is_liger_kernel_available():
raise ValueError(
"You set `use_liger_loss=True` but the liger kernel is not available. "
"Please install liger-kernel first: `pip install liger-kernel`"
)
self.orpo_loss_fn = LigerFusedLinearORPOLoss(ignore_index=self.label_pad_token_id, beta=self.beta)

def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
# Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
deepspeed_plugin = self.accelerator.state.deepspeed_plugin
Expand Down Expand Up @@ -752,55 +763,87 @@ def concatenated_forward(
if self.aux_loss_enabled:
model_kwargs["output_router_logits"] = True

outputs = model(
concatenated_batch["concatenated_input_ids"],
attention_mask=concatenated_batch["concatenated_attention_mask"],
use_cache=False,
**model_kwargs,
)
all_logits = outputs.logits

def cross_entropy_loss(logits, labels):
if not self.is_encoder_decoder:
# Shift so that tokens < n predict n
logits = logits[..., :-1, :].contiguous()
labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss()
logits = logits.view(-1, logits.shape[-1])
labels = labels.view(-1)
# Enable model parallelism
labels = labels.to(logits.device)
loss = loss_fct(logits, labels)
return loss
if self.args.use_liger_loss:
# skip the lm head and get the last hidden state

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice!

I guess we don't have much of an option beyond using a config parameter for now.

Given that we run forward pass on a submodule, it would be very nice to have some validation so that there are no unexpected failures etc with different distributed training settings. But in this case, I feel there might be compatibility issues with FSDP given the limitation from the docs: https://pytorch.org/docs/stable/fsdp.html

"FSDP does not support running the forward pass of a submodule that is contained in an FSDP instance. This is because the submodule’s parameters will be sharded, but the submodule itself is not an FSDP instance, so its forward pass will not all-gather the full parameters appropriately."

(might be fixed by just making the base model attribute an FSDP instance as well, coz why not)

Beyond that this looks fine! I have a couple nits (don't matter that much):

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks @SumanthRH yes you are right get_decoder() will work

Yes next is to verify the distributed training cases

base_model = getattr(model, self.args.base_model_class_name)
outputs = base_model(
concatenated_batch["concatenated_input_ids"],
attention_mask=concatenated_batch["concatenated_attention_mask"],
use_cache=False,
**model_kwargs,
)
lm_head = model.get_output_embeddings()

# return the final loss and aux_outputs tuple
loss, aux_outputs = self.orpo_loss_fn(
lm_head.weight,
outputs[0],
concatenated_batch["concatenated_labels"],
lm_head.bias if hasattr(lm_head, "bias") else None,
)

if self.is_encoder_decoder:
labels = concatenated_batch["concatenated_labels"].clone()
if self.aux_loss_enabled:
loss += self.aux_loss_coef * outputs.aux_loss

return loss, aux_outputs
else:
labels = concatenated_batch["concatenated_input_ids"].clone()
attention_mask = concatenated_batch["concatenated_attention_mask"]
labels = torch.where(attention_mask == 1, labels, self.label_pad_token_id)
outputs = model(
concatenated_batch["concatenated_input_ids"],
attention_mask=concatenated_batch["concatenated_attention_mask"],
use_cache=False,
output_hidden_states=False,
**model_kwargs,
)
all_logits = outputs.logits

def cross_entropy_loss(logits, labels):
if not self.is_encoder_decoder:
# Shift so that tokens < n predict n
logits = logits[..., :-1, :].contiguous()
labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss()
logits = logits.view(-1, logits.shape[-1])
labels = labels.view(-1)
# Enable model parallelism
labels = labels.to(logits.device)
loss = loss_fct(logits, labels)
return loss

if self.is_encoder_decoder:
labels = concatenated_batch["concatenated_labels"].clone()
else:
labels = concatenated_batch["concatenated_input_ids"].clone()
attention_mask = concatenated_batch["concatenated_attention_mask"]
labels = torch.where(attention_mask == 1, labels, self.label_pad_token_id)

chosen_nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])
chosen_nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])

all_logps = self.get_batch_logps(
all_logits,
concatenated_batch["concatenated_labels"],
average_log_prob=True,
is_encoder_decoder=self.is_encoder_decoder,
label_pad_token_id=self.label_pad_token_id,
)
all_logps = self.get_batch_logps(
all_logits,
concatenated_batch["concatenated_labels"],
average_log_prob=True,
is_encoder_decoder=self.is_encoder_decoder,
label_pad_token_id=self.label_pad_token_id,
)

chosen_logps = all_logps[:len_chosen]
rejected_logps = all_logps[len_chosen:]
chosen_logps = all_logps[:len_chosen]
rejected_logps = all_logps[len_chosen:]

chosen_logits = all_logits[:len_chosen]
rejected_logits = all_logits[len_chosen:]
chosen_logits = all_logits[:len_chosen]
rejected_logits = all_logits[len_chosen:]

if self.aux_loss_enabled:
return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss, outputs.aux_loss)
if self.aux_loss_enabled:
return (
chosen_logps,
rejected_logps,
chosen_logits,
rejected_logits,
chosen_nll_loss,
outputs.aux_loss,
)

return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss)
return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss)

def get_batch_loss_metrics(
self,
Expand All @@ -812,21 +855,41 @@ def get_batch_loss_metrics(
metrics = {}

forward_output = self.concatenated_forward(model, batch)
(
policy_chosen_logps,
policy_rejected_logps,
policy_chosen_logits,
policy_rejected_logits,
policy_nll_loss,
) = forward_output[:5]
if self.aux_loss_enabled:
aux_loss = forward_output[5]
if self.args.use_liger_loss:
# full ORPO loss and aux outputs
(
loss,
(
policy_chosen_logps,
policy_rejected_logps,
policy_chosen_logits,
policy_rejected_logits,
policy_nll_loss,
chosen_rewards,
rejected_rewards,
log_odds_ratio,
log_odds_chosen,
),
) = forward_output
else:
(
policy_chosen_logps,
policy_rejected_logps,
policy_chosen_logits,
policy_rejected_logits,
policy_nll_loss,
) = forward_output[:5]
if self.aux_loss_enabled:
aux_loss = forward_output[5]

losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = self.odds_ratio_loss(
policy_chosen_logps, policy_rejected_logps
)
# full ORPO loss
loss = policy_nll_loss - losses.mean()

losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = self.odds_ratio_loss(
policy_chosen_logps, policy_rejected_logps
)
# full ORPO loss
loss = policy_nll_loss - losses.mean()
if self.aux_loss_enabled:
loss += self.aux_loss_coef * aux_loss

reward_accuracies = (chosen_rewards > rejected_rewards).float()

Expand All @@ -846,8 +909,6 @@ def get_batch_loss_metrics(
xm.mark_step() # needed because .item() calls
for k, v in metrics.items():
metrics[k] = v.item()
if self.aux_loss_enabled:
loss += self.aux_loss_coef * aux_loss

return loss, metrics

Expand All @@ -859,7 +920,6 @@ def compute_loss(
num_items_in_batch=None,
) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
compute_loss_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()

with compute_loss_context_manager:
loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")

Expand Down
Loading