[WIP] Integrate Liger CPO & SimPO #2506

Draft: wants to merge 2 commits into main
51 changes: 50 additions & 1 deletion tests/test_cpo_trainer.py
@@ -19,7 +19,7 @@
from datasets import load_dataset
from parameterized import parameterized
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
from transformers.testing_utils import require_peft
from transformers.testing_utils import require_liger_kernel, require_peft

from trl import CPOConfig, CPOTrainer
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE
@@ -154,3 +154,52 @@ def test_cpo_trainer_with_lora(self, config_name):
# check the params have changed - ignore 0 biases
if param.sum() != 0:
self.assertFalse(torch.equal(param, new_param))

@parameterized.expand(
[
("qwen", "sigmoid", "standard_preference"),
("qwen", "simpo", "standard_preference"),
("t5", "simpo", "standard_implicit_prompt_preference"),
]
)
@require_liger_kernel
def test_cpo_trainer_with_liger(self, name, loss_type, config_name):
"""Test the CPO trainer with the Liger loss enabled."""
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = CPOConfig(
output_dir=tmp_dir,
report_to="none",
loss_type=loss_type,
use_liger_loss=True, # Enable Liger loss
)

dummy_dataset = load_dataset("trl-internal-testing/zen", config_name)

if name == "qwen":
model = self.model
tokenizer = self.tokenizer
elif name == "t5":
model = self.t5_model
tokenizer = self.t5_tokenizer
training_args.is_encoder_decoder = True

trainer = CPOTrainer(
model=model,
args=training_args,
processing_class=tokenizer,
train_dataset=dummy_dataset["train"],
eval_dataset=dummy_dataset["test"],
)

previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

trainer.train()

self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

# check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
# check the params have changed - ignore 0 biases
if param.sum() != 0:
self.assertFalse(torch.equal(param, new_param))
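
Assuming `liger-kernel` is installed, the new cases can be run in isolation with pytest's keyword filter; the `-k` expression below is just a substring match on the test name, and the `@require_liger_kernel` decorator skips the tests (rather than failing) when the package is missing:

import pytest

# Run only the Liger-related CPO tests from the repository root.
pytest.main(["tests/test_cpo_trainer.py", "-k", "liger"])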
7 changes: 7 additions & 0 deletions trl/trainer/cpo_config.py
@@ -76,6 +76,11 @@ class CPOConfig(TrainingArguments):
string.
dataset_num_proc (`Optional[int]`, *optional*, defaults to `None`):
Number of processes to use for processing the dataset.
use_liger_loss (`bool`, *optional*, defaults to `False`):
Whether to use the Liger fused linear CPO/SimPO loss from the `liger-kernel` package.
base_model_attribute_name (`str`, *optional*, defaults to `"model"`):
Name of the model attribute that holds the base model. It is used to retrieve the base model when
`use_liger_loss` is `True` and the model does not expose a `get_decoder` method.
"""

learning_rate: float = 1e-6
@@ -95,3 +100,5 @@ class CPOConfig(TrainingArguments):
is_encoder_decoder: Optional[bool] = None
model_init_kwargs: Optional[dict[str, Any]] = None
dataset_num_proc: Optional[int] = None
use_liger_loss: bool = False
base_model_attribute_name: str = "model"
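
For context, here is a minimal sketch of how the new options would be used from a training script, assuming the surrounding CPO setup follows the existing TRL examples; the checkpoint and dataset names are illustrative placeholders, not part of this PR:

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import CPOConfig, CPOTrainer

# Illustrative checkpoint and preference dataset; any causal LM and preference
# dataset accepted by CPOTrainer should work the same way.
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
dataset = load_dataset("trl-internal-testing/zen", "standard_preference")

training_args = CPOConfig(
    output_dir="cpo-liger",
    loss_type="simpo",    # the Liger path supports "sigmoid" and "simpo"
    use_liger_loss=True,  # new flag added in this PR
)

trainer = CPOTrainer(
    model=model,
    args=training_args,
    processing_class=tokenizer,
    train_dataset=dataset["train"],
)
trainer.train()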
194 changes: 135 additions & 59 deletions trl/trainer/cpo_trainer.py
@@ -44,7 +44,7 @@
)
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import EvalLoopOutput
from transformers.utils import is_peft_available, is_torch_fx_proxy
from transformers.utils import is_liger_kernel_available, is_peft_available, is_torch_fx_proxy

from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt
from .cpo_config import CPOConfig
@@ -68,6 +68,10 @@
import wandb


if is_liger_kernel_available():
from liger_kernel.chunked_loss import LigerFusedLinearCPOLoss, LigerFusedLinearSimPOLoss


class CPOTrainer(Trainer):
r"""
Initialize CPOTrainer.
@@ -362,6 +366,31 @@ def make_inputs_require_grad(module, input, output):
"Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
)

# Import Liger loss if enabled
if self.args.use_liger_loss:
if not is_liger_kernel_available():
raise ValueError(
"You set `use_liger_loss=True` but the liger kernel is not available. "
"Please install liger-kernel first: `pip install liger-kernel`"
)
if args.loss_type == "sigmoid":
self.cpo_loss_fn = LigerFusedLinearCPOLoss(
ignore_index=self.label_pad_token_id,
beta=self.beta,
alpha=self.cpo_alpha,
label_smoothing=self.label_smoothing,
)
elif args.loss_type == "simpo":
self.cpo_loss_fn = LigerFusedLinearSimPOLoss(
ignore_index=self.label_pad_token_id,
beta=self.beta,
alpha=self.cpo_alpha,
gamma=self.simpo_gamma,
label_smoothing=self.label_smoothing,
)
else:
raise ValueError("Liger loss is only available for sigmoid and simpo loss types.")

def build_tokenized_answer(self, prompt, answer):
"""
Llama tokenizer does satisfy `enc(a + b) = enc(a) + enc(b)`.
@@ -736,53 +765,84 @@ def concatenated_forward(
if self.aux_loss_enabled:
model_kwargs["output_router_logits"] = True

outputs = model(
concatenated_batch["concatenated_input_ids"],
attention_mask=concatenated_batch["concatenated_attention_mask"],
use_cache=False,
**model_kwargs,
)
all_logits = outputs.logits

def cross_entropy_loss(logits, labels):
if not self.is_encoder_decoder:
# Shift so that tokens < n predict n
logits = logits[..., :-1, :].contiguous()
labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss()
logits = logits.view(-1, logits.shape[-1])
labels = labels.view(-1)
# Enable model parallelism
labels = labels.to(logits.device)
loss = loss_fct(logits, labels)
return loss

labels = concatenated_batch["concatenated_labels"].clone()

if self.cpo_alpha == 0:
nll_loss = torch.tensor(0.0).to(self.accelerator.device)
if self.args.use_liger_loss:
# Skip the LM head and run the base model to get the last hidden state
if hasattr(model, "get_decoder"):
base_model = model.get_decoder()
else:
base_model = getattr(model, self.args.base_model_attribute_name)
outputs = base_model(
concatenated_batch["concatenated_input_ids"],
attention_mask=concatenated_batch["concatenated_attention_mask"],
use_cache=False,
**model_kwargs,
)
lm_head = model.get_output_embeddings()

# return the final loss and aux_outputs tuple
loss, aux_outputs = self.cpo_loss_fn(
lm_head.weight,
outputs.last_hidden_state[:, :-1] if not self.is_encoder_decoder else outputs.last_hidden_state,
concatenated_batch["concatenated_labels"][:, 1:]
if not self.is_encoder_decoder
else concatenated_batch["concatenated_labels"],
lm_head.bias if hasattr(lm_head, "bias") else None,
)

if self.aux_loss_enabled:
loss += self.aux_loss_coef * outputs.aux_loss

return loss, aux_outputs

else:
nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])
outputs = model(
concatenated_batch["concatenated_input_ids"],
attention_mask=concatenated_batch["concatenated_attention_mask"],
use_cache=False,
**model_kwargs,
)
all_logits = outputs.logits

def cross_entropy_loss(logits, labels):
if not self.is_encoder_decoder:
# Shift so that tokens < n predict n
logits = logits[..., :-1, :].contiguous()
labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss()
logits = logits.view(-1, logits.shape[-1])
labels = labels.view(-1)
# Enable model parallelism
labels = labels.to(logits.device)
loss = loss_fct(logits, labels)
return loss

labels = concatenated_batch["concatenated_labels"].clone()

if self.cpo_alpha == 0:
nll_loss = torch.tensor(0.0).to(self.accelerator.device)
else:
nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])

all_logps = self.get_batch_logps(
all_logits,
concatenated_batch["concatenated_labels"],
average_log_prob=self.loss_type in ["ipo", "simpo"],
is_encoder_decoder=self.is_encoder_decoder,
label_pad_token_id=self.label_pad_token_id,
)
all_logps = self.get_batch_logps(
all_logits,
concatenated_batch["concatenated_labels"],
average_log_prob=self.loss_type in ["ipo", "simpo"],
is_encoder_decoder=self.is_encoder_decoder,
label_pad_token_id=self.label_pad_token_id,
)

chosen_logps = all_logps[:len_chosen]
rejected_logps = all_logps[len_chosen:]
chosen_logps = all_logps[:len_chosen]
rejected_logps = all_logps[len_chosen:]

chosen_logits = all_logits[:len_chosen]
rejected_logits = all_logits[len_chosen:]
chosen_logits = all_logits[:len_chosen]
rejected_logits = all_logits[len_chosen:]

if self.aux_loss_enabled:
return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss, outputs.aux_loss)
if self.aux_loss_enabled:
return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss, outputs.aux_loss)

return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss)
return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss)

def get_batch_loss_metrics(
self,
@@ -794,22 +854,41 @@ def get_batch_loss_metrics(
metrics = {}

forward_output = self.concatenated_forward(model, batch)
(
policy_chosen_logps,
policy_rejected_logps,
policy_chosen_logits,
policy_rejected_logits,
policy_nll_loss,
) = forward_output[:5]
if self.aux_loss_enabled:
aux_loss = forward_output[5]
if self.args.use_liger_loss:
# full CPO loss and aux outputs
(
loss,
(
policy_chosen_logps,
policy_rejected_logps,
policy_chosen_logits,
policy_rejected_logits,
policy_nll_loss,
chosen_rewards,
rejected_rewards,
),
) = forward_output
else:
(
policy_chosen_logps,
policy_rejected_logps,
policy_chosen_logits,
policy_rejected_logits,
policy_nll_loss,
) = forward_output[:5]
if self.aux_loss_enabled:
aux_loss = forward_output[5]

losses, chosen_rewards, rejected_rewards = self.cpo_loss(
policy_chosen_logps,
policy_rejected_logps,
)

losses, chosen_rewards, rejected_rewards = self.cpo_loss(
policy_chosen_logps,
policy_rejected_logps,
)
loss = losses.mean() + self.cpo_alpha * policy_nll_loss

if self.aux_loss_enabled:
loss += self.aux_loss_coef * aux_loss

loss = losses.mean() + self.cpo_alpha * policy_nll_loss
reward_accuracies = (chosen_rewards > rejected_rewards).float()

prefix = "eval_" if train_eval == "eval" else ""
@@ -823,9 +902,6 @@ def get_batch_loss_metrics(
metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean().cpu()
metrics[f"{prefix}nll_loss"] = policy_nll_loss.detach().mean().cpu()

if self.aux_loss_enabled:
loss += self.aux_loss_coef * aux_loss

return loss, metrics

def compute_loss(
Expand Down