Skip to content

Commit

Permalink
Integrate Liger CPO & SimPO
Browse files Browse the repository at this point in the history
Signed-off-by: Mecoli1219 <[email protected]>
  • Loading branch information
Mecoli1219 committed Dec 20, 2024
1 parent 8c49ea3 commit 703272f
Show file tree
Hide file tree
Showing 3 changed files with 185 additions and 60 deletions.
51 changes: 50 additions & 1 deletion tests/test_cpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from datasets import load_dataset
from parameterized import parameterized
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
from transformers.testing_utils import require_peft
from transformers.testing_utils import require_liger_kernel, require_peft

from trl import CPOConfig, CPOTrainer
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE
Expand Down Expand Up @@ -154,3 +154,52 @@ def test_cpo_trainer_with_lora(self, config_name):
# check the params have changed - ignore 0 biases
if param.sum() != 0:
self.assertFalse(torch.equal(param, new_param))

@parameterized.expand(
    [
        ("qwen", "sigmoid", "standard_preference"),
        ("qwen", "simpo", "standard_preference"),
        ("t5", "simpo", "standard_implicit_prompt_preference"),
    ]
)
@require_liger_kernel
def test_orpo_trainer_with_liger(self, name, loss_type, config_name):
    """Test CPO trainer with Liger loss enabled.

    Trains a `CPOTrainer` configured with `use_liger_loss=True` and checks
    that a training loss was logged and that the model parameters were
    updated.

    Args:
        name: Which model fixture to train: `"qwen"` uses `self.model` /
            `self.tokenizer`, `"t5"` uses `self.t5_model` /
            `self.t5_tokenizer`.
        loss_type: Value forwarded to `CPOConfig(loss_type=...)`.
        config_name: Configuration of the `trl-internal-testing/zen`
            dataset to load.

    Raises:
        ValueError: If `name` is not one of the supported fixtures.
    """
    # NOTE(review): the method name still says "orpo" although this exercises
    # CPOTrainer; it is kept unchanged so existing test IDs stay stable.
    with tempfile.TemporaryDirectory() as tmp_dir:
        training_args = CPOConfig(
            output_dir=tmp_dir,
            report_to="none",
            loss_type=loss_type,
            use_liger_loss=True,  # Enable Liger loss
        )

        dummy_dataset = load_dataset("trl-internal-testing/zen", config_name)

        if name == "qwen":
            model = self.model
            tokenizer = self.tokenizer
        elif name == "t5":
            model = self.t5_model
            tokenizer = self.t5_tokenizer
            training_args.is_encoder_decoder = True
        else:
            # Fail loudly instead of hitting an unbound `model` further down.
            raise ValueError(f"Unsupported model fixture name: {name!r}")

        trainer = CPOTrainer(
            model=model,
            args=training_args,
            processing_class=tokenizer,
            train_dataset=dummy_dataset["train"],
            eval_dataset=dummy_dataset["test"],
        )

        # Snapshot the weights before training so they can be compared afterwards.
        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

        trainer.train()

        self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

        # check the params have changed
        for n, param in previous_trainable_params.items():
            new_param = trainer.model.get_parameter(n)
            # ignore parameters whose values sum to zero (e.g. 0 biases)
            if param.sum() != 0:
                self.assertFalse(torch.equal(param, new_param))
7 changes: 7 additions & 0 deletions trl/trainer/cpo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,11 @@ class CPOConfig(TrainingArguments):
string.
dataset_num_proc (`Optional[int]`, *optional*, defaults to `None`):
Number of processes to use for processing the dataset.
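use_liger_loss (`bool`, *optional*, defaults to `False`):
Whether to compute the loss with the Liger kernel's fused linear CPO/SimPO loss. Requires `liger-kernel` to be
installed and is only available for the `"sigmoid"` and `"simpo"` loss types.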
base_model_attribute_name (`str`, *optional*, defaults to `"model"`):
Name of the model attribute that holds the base model. Only used when `use_liger_loss` is `True` and the model
does not have a `get_decoder` method, in which case the base model is retrieved through this attribute.
"""

learning_rate: float = 1e-6
Expand All @@ -95,3 +100,5 @@ class CPOConfig(TrainingArguments):
is_encoder_decoder: Optional[bool] = None
model_init_kwargs: Optional[dict[str, Any]] = None
dataset_num_proc: Optional[int] = None
use_liger_loss: bool = False
base_model_attribute_name: str = "model"
187 changes: 128 additions & 59 deletions trl/trainer/cpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
)
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import EvalLoopOutput
from transformers.utils import is_peft_available, is_torch_fx_proxy
from transformers.utils import is_liger_kernel_available, is_peft_available, is_torch_fx_proxy

from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt
from .cpo_config import CPOConfig
Expand All @@ -68,6 +68,10 @@
import wandb


if is_liger_kernel_available():
from liger_kernel.chunked_loss import LigerFusedLinearCPOLoss, LigerFusedLinearSimPOLoss


class CPOTrainer(Trainer):
r"""
Initialize CPOTrainer.
Expand Down Expand Up @@ -362,6 +366,24 @@ def make_inputs_require_grad(module, input, output):
"Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
)

# Import Liger loss if enabled
if self.args.use_liger_loss:
if not is_liger_kernel_available():
raise ValueError(
"You set `use_liger_loss=True` but the liger kernel is not available. "
"Please install liger-kernel first: `pip install liger-kernel`"
)
if args.loss_type == "sigmoid":
self.cpo_loss_fn = LigerFusedLinearCPOLoss(
ignore_index=self.label_pad_token_id, beta=self.beta, alpha=self.cpo_alpha
)
elif args.loss_type == "simpo":
self.cpo_loss_fn = LigerFusedLinearSimPOLoss(
ignore_index=self.label_pad_token_id, beta=self.beta, alpha=self.cpo_alpha, gamma=self.simpo_gamma
)
else:
raise ValueError("Liger loss is only available for sigmoid and simpo loss types.")

def build_tokenized_answer(self, prompt, answer):
"""
Llama tokenizer does satisfy `enc(a + b) = enc(a) + enc(b)`.
Expand Down Expand Up @@ -736,53 +758,84 @@ def concatenated_forward(
if self.aux_loss_enabled:
model_kwargs["output_router_logits"] = True

outputs = model(
concatenated_batch["concatenated_input_ids"],
attention_mask=concatenated_batch["concatenated_attention_mask"],
use_cache=False,
**model_kwargs,
)
all_logits = outputs.logits

def cross_entropy_loss(logits, labels):
if not self.is_encoder_decoder:
# Shift so that tokens < n predict n
logits = logits[..., :-1, :].contiguous()
labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss()
logits = logits.view(-1, logits.shape[-1])
labels = labels.view(-1)
# Enable model parallelism
labels = labels.to(logits.device)
loss = loss_fct(logits, labels)
return loss

labels = concatenated_batch["concatenated_labels"].clone()

if self.cpo_alpha == 0:
nll_loss = torch.tensor(0.0).to(self.accelerator.device)
if self.args.use_liger_loss:
# skip the lm head and get the last hidden state; the lm head weight is passed to the Liger loss below
if hasattr(model, "get_decoder"):
base_model = model.get_decoder()
else:
base_model = getattr(model, self.args.base_model_attribute_name)
outputs = base_model(
concatenated_batch["concatenated_input_ids"],
attention_mask=concatenated_batch["concatenated_attention_mask"],
use_cache=False,
**model_kwargs,
)
lm_head = model.get_output_embeddings()

# return the final loss and aux_outputs tuple
loss, aux_outputs = self.cpo_loss_fn(
lm_head.weight,
outputs.last_hidden_state[:, :-1] if not self.is_encoder_decoder else outputs.last_hidden_state,
concatenated_batch["concatenated_labels"][:, 1:]
if not self.is_encoder_decoder
else concatenated_batch["concatenated_labels"],
lm_head.bias if hasattr(lm_head, "bias") else None,
)

if self.aux_loss_enabled:
loss += self.aux_loss_coef * outputs.aux_loss

return loss, aux_outputs

else:
nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])
outputs = model(
concatenated_batch["concatenated_input_ids"],
attention_mask=concatenated_batch["concatenated_attention_mask"],
use_cache=False,
**model_kwargs,
)
all_logits = outputs.logits

def cross_entropy_loss(logits, labels):
if not self.is_encoder_decoder:
# Shift so that tokens < n predict n
logits = logits[..., :-1, :].contiguous()
labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss()
logits = logits.view(-1, logits.shape[-1])
labels = labels.view(-1)
# Enable model parallelism
labels = labels.to(logits.device)
loss = loss_fct(logits, labels)
return loss

labels = concatenated_batch["concatenated_labels"].clone()

if self.cpo_alpha == 0:
nll_loss = torch.tensor(0.0).to(self.accelerator.device)
else:
nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])

all_logps = self.get_batch_logps(
all_logits,
concatenated_batch["concatenated_labels"],
average_log_prob=self.loss_type in ["ipo", "simpo"],
is_encoder_decoder=self.is_encoder_decoder,
label_pad_token_id=self.label_pad_token_id,
)
all_logps = self.get_batch_logps(
all_logits,
concatenated_batch["concatenated_labels"],
average_log_prob=self.loss_type in ["ipo", "simpo"],
is_encoder_decoder=self.is_encoder_decoder,
label_pad_token_id=self.label_pad_token_id,
)

chosen_logps = all_logps[:len_chosen]
rejected_logps = all_logps[len_chosen:]
chosen_logps = all_logps[:len_chosen]
rejected_logps = all_logps[len_chosen:]

chosen_logits = all_logits[:len_chosen]
rejected_logits = all_logits[len_chosen:]
chosen_logits = all_logits[:len_chosen]
rejected_logits = all_logits[len_chosen:]

if self.aux_loss_enabled:
return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss, outputs.aux_loss)
if self.aux_loss_enabled:
return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss, outputs.aux_loss)

return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss)
return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss)

def get_batch_loss_metrics(
self,
Expand All @@ -794,22 +847,41 @@ def get_batch_loss_metrics(
metrics = {}

forward_output = self.concatenated_forward(model, batch)
(
policy_chosen_logps,
policy_rejected_logps,
policy_chosen_logits,
policy_rejected_logits,
policy_nll_loss,
) = forward_output[:5]
if self.aux_loss_enabled:
aux_loss = forward_output[5]
if self.args.use_liger_loss:
# full CPO loss and aux outputs
(
loss,
(
policy_chosen_logps,
policy_rejected_logps,
policy_chosen_logits,
policy_rejected_logits,
policy_nll_loss,
chosen_rewards,
rejected_rewards,
),
) = forward_output
else:
(
policy_chosen_logps,
policy_rejected_logps,
policy_chosen_logits,
policy_rejected_logits,
policy_nll_loss,
) = forward_output[:5]
if self.aux_loss_enabled:
aux_loss = forward_output[5]

losses, chosen_rewards, rejected_rewards = self.cpo_loss(
policy_chosen_logps,
policy_rejected_logps,
)

losses, chosen_rewards, rejected_rewards = self.cpo_loss(
policy_chosen_logps,
policy_rejected_logps,
)
loss = losses.mean() + self.cpo_alpha * policy_nll_loss

if self.aux_loss_enabled:
loss += self.aux_loss_coef * aux_loss

loss = losses.mean() + self.cpo_alpha * policy_nll_loss
reward_accuracies = (chosen_rewards > rejected_rewards).float()

prefix = "eval_" if train_eval == "eval" else ""
Expand All @@ -823,9 +895,6 @@ def get_batch_loss_metrics(
metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean().cpu()
metrics[f"{prefix}nll_loss"] = policy_nll_loss.detach().mean().cpu()

if self.aux_loss_enabled:
loss += self.aux_loss_coef * aux_loss

return loss, metrics

def compute_loss(
Expand Down

0 comments on commit 703272f

Please sign in to comment.