From 703272faf696d66ff6dcc9e7e797edea69fd1d0c Mon Sep 17 00:00:00 2001
From: Mecoli1219
Date: Thu, 19 Dec 2024 22:11:57 -0800
Subject: [PATCH] Integrate Liger CPO & SimPO

Signed-off-by: Mecoli1219
---
Review notes (free-form area, ignored by `git am`):
  * NOTE(review): in the Liger branch of `concatenated_forward`, `outputs` is
    produced by the base model / decoder rather than by the full LM wrapper.
    Please confirm that this output object exposes `aux_loss` before combining
    `use_liger_loss=True` with router auxiliary loss -- TODO confirm.
  * The new test was named/documented after ORPO although it exercises
    CPOTrainer; it is named `test_cpo_trainer_with_liger` here.

 tests/test_cpo_trainer.py  |  51 +++++++++-
 trl/trainer/cpo_config.py  |   7 ++
 trl/trainer/cpo_trainer.py | 187 +++++++++++++++++++++++++------------
 3 files changed, 185 insertions(+), 60 deletions(-)

diff --git a/tests/test_cpo_trainer.py b/tests/test_cpo_trainer.py
index 744ee07aa6..2648927014 100644
--- a/tests/test_cpo_trainer.py
+++ b/tests/test_cpo_trainer.py
@@ -19,7 +19,7 @@
 from datasets import load_dataset
 from parameterized import parameterized
 from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
-from transformers.testing_utils import require_peft
+from transformers.testing_utils import require_liger_kernel, require_peft
 
 from trl import CPOConfig, CPOTrainer
 from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE
@@ -154,3 +154,52 @@ def test_cpo_trainer_with_lora(self, config_name):
                 # check the params have changed - ignore 0 biases
                 if param.sum() != 0:
                     self.assertFalse(torch.equal(param, new_param))
+
+    @parameterized.expand(
+        [
+            ("qwen", "sigmoid", "standard_preference"),
+            ("qwen", "simpo", "standard_preference"),
+            ("t5", "simpo", "standard_implicit_prompt_preference"),
+        ]
+    )
+    @require_liger_kernel
+    def test_cpo_trainer_with_liger(self, name, loss_type, config_name):
+        """Test CPO trainer with Liger loss enabled."""
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            training_args = CPOConfig(
+                output_dir=tmp_dir,
+                report_to="none",
+                loss_type=loss_type,
+                use_liger_loss=True,  # Enable Liger loss
+            )
+
+            dummy_dataset = load_dataset("trl-internal-testing/zen", config_name)
+
+            if name == "qwen":
+                model = self.model
+                tokenizer = self.tokenizer
+            elif name == "t5":
+                model = self.t5_model
+                tokenizer = self.t5_tokenizer
+                training_args.is_encoder_decoder = True
+
+            trainer = CPOTrainer(
+                model=model,
+                args=training_args,
+                processing_class=tokenizer,
+                train_dataset=dummy_dataset["train"],
+                eval_dataset=dummy_dataset["test"],
+            )
+
+            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+            trainer.train()
+
+            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+
+            # check the params have changed
+            for n, param in previous_trainable_params.items():
+                new_param = trainer.model.get_parameter(n)
+                # check the params have changed - ignore 0 biases
+                if param.sum() != 0:
+                    self.assertFalse(torch.equal(param, new_param))
diff --git a/trl/trainer/cpo_config.py b/trl/trainer/cpo_config.py
index a288a8c75f..20b07e4941 100644
--- a/trl/trainer/cpo_config.py
+++ b/trl/trainer/cpo_config.py
@@ -76,6 +76,11 @@ class CPOConfig(TrainingArguments):
             string.
         dataset_num_proc (`Optional[int]`, *optional*, defaults to `None`):
             Number of processes to use for processing the dataset.
+        use_liger_loss (`bool`, *optional*, defaults to `False`):
+            Whether to use Liger loss.
+        base_model_attribute_name (`str`, *optional*, defaults to `"model"`):
+            Name of the attribute in the model that contains the base model. This is used to get the base model from the
+            model when the model does not have a `get_decoder` method in the case when `use_liger_loss` is `True`.
     """
 
     learning_rate: float = 1e-6
@@ -95,3 +100,5 @@ class CPOConfig(TrainingArguments):
     is_encoder_decoder: Optional[bool] = None
     model_init_kwargs: Optional[dict[str, Any]] = None
     dataset_num_proc: Optional[int] = None
+    use_liger_loss: bool = False
+    base_model_attribute_name: str = "model"
diff --git a/trl/trainer/cpo_trainer.py b/trl/trainer/cpo_trainer.py
index 6d236cfb37..8a31e272df 100644
--- a/trl/trainer/cpo_trainer.py
+++ b/trl/trainer/cpo_trainer.py
@@ -44,7 +44,7 @@
 )
 from transformers.trainer_callback import TrainerCallback
 from transformers.trainer_utils import EvalLoopOutput
-from transformers.utils import is_peft_available, is_torch_fx_proxy
+from transformers.utils import is_liger_kernel_available, is_peft_available, is_torch_fx_proxy
 
 from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt
 from .cpo_config import CPOConfig
@@ -68,6 +68,10 @@
     import wandb
 
 
+if is_liger_kernel_available():
+    from liger_kernel.chunked_loss import LigerFusedLinearCPOLoss, LigerFusedLinearSimPOLoss
+
+
 class CPOTrainer(Trainer):
     r"""
     Initialize CPOTrainer.
@@ -362,6 +366,24 @@ def make_inputs_require_grad(module, input, output):
                 "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
             )
 
+        # Import Liger loss if enabled
+        if self.args.use_liger_loss:
+            if not is_liger_kernel_available():
+                raise ValueError(
+                    "You set `use_liger_loss=True` but the liger kernel is not available. "
+                    "Please install liger-kernel first: `pip install liger-kernel`"
+                )
+            if args.loss_type == "sigmoid":
+                self.cpo_loss_fn = LigerFusedLinearCPOLoss(
+                    ignore_index=self.label_pad_token_id, beta=self.beta, alpha=self.cpo_alpha
+                )
+            elif args.loss_type == "simpo":
+                self.cpo_loss_fn = LigerFusedLinearSimPOLoss(
+                    ignore_index=self.label_pad_token_id, beta=self.beta, alpha=self.cpo_alpha, gamma=self.simpo_gamma
+                )
+            else:
+                raise ValueError("Liger loss is only available for sigmoid and simpo loss types.")
+
     def build_tokenized_answer(self, prompt, answer):
         """
         Llama tokenizer does satisfy `enc(a + b) = enc(a) + enc(b)`.
@@ -736,53 +758,84 @@ def concatenated_forward(
         if self.aux_loss_enabled:
             model_kwargs["output_router_logits"] = True
 
-        outputs = model(
-            concatenated_batch["concatenated_input_ids"],
-            attention_mask=concatenated_batch["concatenated_attention_mask"],
-            use_cache=False,
-            **model_kwargs,
-        )
-        all_logits = outputs.logits
-
-        def cross_entropy_loss(logits, labels):
-            if not self.is_encoder_decoder:
-                # Shift so that tokens < n predict n
-                logits = logits[..., :-1, :].contiguous()
-                labels = labels[..., 1:].contiguous()
-            # Flatten the tokens
-            loss_fct = nn.CrossEntropyLoss()
-            logits = logits.view(-1, logits.shape[-1])
-            labels = labels.view(-1)
-            # Enable model parallelism
-            labels = labels.to(logits.device)
-            loss = loss_fct(logits, labels)
-            return loss
-
-        labels = concatenated_batch["concatenated_labels"].clone()
-
-        if self.cpo_alpha == 0:
-            nll_loss = torch.tensor(0.0).to(self.accelerator.device)
+        if self.args.use_liger_loss:
+            # skip the lm head and get the last hidden state
+            # (the fused Liger loss is given `lm_head.weight` and applies the projection itself)
+            if hasattr(model, "get_decoder"):
+                base_model = model.get_decoder()
+            else:
+                base_model = getattr(model, self.args.base_model_attribute_name)
+            outputs = base_model(
+                concatenated_batch["concatenated_input_ids"],
+                attention_mask=concatenated_batch["concatenated_attention_mask"],
+                use_cache=False,
+                **model_kwargs,
+            )
+            lm_head = model.get_output_embeddings()
+
+            # return the final loss and aux_outputs tuple
+            loss, aux_outputs = self.cpo_loss_fn(
+                lm_head.weight,
+                outputs.last_hidden_state[:, :-1] if not self.is_encoder_decoder else outputs.last_hidden_state,
+                concatenated_batch["concatenated_labels"][:, 1:]
+                if not self.is_encoder_decoder
+                else concatenated_batch["concatenated_labels"],
+                lm_head.bias if hasattr(lm_head, "bias") else None,
+            )
+
+            if self.aux_loss_enabled:
+                loss += self.aux_loss_coef * outputs.aux_loss
+
+            return loss, aux_outputs
+
         else:
-            nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])
+            outputs = model(
+                concatenated_batch["concatenated_input_ids"],
+                attention_mask=concatenated_batch["concatenated_attention_mask"],
+                use_cache=False,
+                **model_kwargs,
+            )
+            all_logits = outputs.logits
+
+            def cross_entropy_loss(logits, labels):
+                if not self.is_encoder_decoder:
+                    # Shift so that tokens < n predict n
+                    logits = logits[..., :-1, :].contiguous()
+                    labels = labels[..., 1:].contiguous()
+                # Flatten the tokens
+                loss_fct = nn.CrossEntropyLoss()
+                logits = logits.view(-1, logits.shape[-1])
+                labels = labels.view(-1)
+                # Enable model parallelism
+                labels = labels.to(logits.device)
+                loss = loss_fct(logits, labels)
+                return loss
+
+            labels = concatenated_batch["concatenated_labels"].clone()
+
+            if self.cpo_alpha == 0:
+                nll_loss = torch.tensor(0.0).to(self.accelerator.device)
+            else:
+                nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])
 
-        all_logps = self.get_batch_logps(
-            all_logits,
-            concatenated_batch["concatenated_labels"],
-            average_log_prob=self.loss_type in ["ipo", "simpo"],
-            is_encoder_decoder=self.is_encoder_decoder,
-            label_pad_token_id=self.label_pad_token_id,
-        )
+            all_logps = self.get_batch_logps(
+                all_logits,
+                concatenated_batch["concatenated_labels"],
+                average_log_prob=self.loss_type in ["ipo", "simpo"],
+                is_encoder_decoder=self.is_encoder_decoder,
+                label_pad_token_id=self.label_pad_token_id,
+            )
 
-        chosen_logps = all_logps[:len_chosen]
-        rejected_logps = all_logps[len_chosen:]
+            chosen_logps = all_logps[:len_chosen]
+            rejected_logps = all_logps[len_chosen:]
 
-        chosen_logits = all_logits[:len_chosen]
-        rejected_logits = all_logits[len_chosen:]
+            chosen_logits = all_logits[:len_chosen]
+            rejected_logits = all_logits[len_chosen:]
 
-        if self.aux_loss_enabled:
-            return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss, outputs.aux_loss)
+            if self.aux_loss_enabled:
+                return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss, outputs.aux_loss)
 
-        return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss)
+            return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss)
 
     def get_batch_loss_metrics(
         self,
@@ -794,22 +847,41 @@ def get_batch_loss_metrics(
         metrics = {}
 
         forward_output = self.concatenated_forward(model, batch)
-        (
-            policy_chosen_logps,
-            policy_rejected_logps,
-            policy_chosen_logits,
-            policy_rejected_logits,
-            policy_nll_loss,
-        ) = forward_output[:5]
-        if self.aux_loss_enabled:
-            aux_loss = forward_output[5]
+        if self.args.use_liger_loss:
+            # full CPO loss and aux outputs
+            (
+                loss,
+                (
+                    policy_chosen_logps,
+                    policy_rejected_logps,
+                    policy_chosen_logits,
+                    policy_rejected_logits,
+                    policy_nll_loss,
+                    chosen_rewards,
+                    rejected_rewards,
+                ),
+            ) = forward_output
+        else:
+            (
+                policy_chosen_logps,
+                policy_rejected_logps,
+                policy_chosen_logits,
+                policy_rejected_logits,
+                policy_nll_loss,
+            ) = forward_output[:5]
+            if self.aux_loss_enabled:
+                aux_loss = forward_output[5]
+
+            losses, chosen_rewards, rejected_rewards = self.cpo_loss(
+                policy_chosen_logps,
+                policy_rejected_logps,
+            )
 
-        losses, chosen_rewards, rejected_rewards = self.cpo_loss(
-            policy_chosen_logps,
-            policy_rejected_logps,
-        )
+            loss = losses.mean() + self.cpo_alpha * policy_nll_loss
+
+            if self.aux_loss_enabled:
+                loss += self.aux_loss_coef * aux_loss
 
-        loss = losses.mean() + self.cpo_alpha * policy_nll_loss
         reward_accuracies = (chosen_rewards > rejected_rewards).float()
 
         prefix = "eval_" if train_eval == "eval" else ""
@@ -823,9 +895,6 @@ def get_batch_loss_metrics(
         metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean().cpu()
         metrics[f"{prefix}nll_loss"] = policy_nll_loss.detach().mean().cpu()
 
-        if self.aux_loss_enabled:
-            loss += self.aux_loss_coef * aux_loss
-
         return loss, metrics
 
     def compute_loss(