diff --git a/setup.py b/setup.py index 313c958c6e..cb13dd902b 100644 --- a/setup.py +++ b/setup.py @@ -85,7 +85,7 @@ "diffusers": ["diffusers>=0.18.0"], "judges": ["openai>=1.23.2", "llm-blender>=0.0.2"], # liger-kernel depends on triton, which is only available on Linux https://github.com/triton-lang/triton#compatibility - "liger": ["liger-kernel>=0.4.0; sys_platform != 'win32'"], + "liger": ["liger-kernel>=0.5.1; sys_platform != 'win32'"], "mergekit": ["mergekit>=0.0.5.1"], "peft": ["peft>=0.8.0"], "quantization": ["bitsandbytes"], diff --git a/tests/test_orpo_trainer.py b/tests/test_orpo_trainer.py index d2eaee3947..a4f5ad6973 100644 --- a/tests/test_orpo_trainer.py +++ b/tests/test_orpo_trainer.py @@ -19,7 +19,7 @@ from datasets import load_dataset from parameterized import parameterized from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer -from transformers.testing_utils import require_peft +from transformers.testing_utils import require_liger_kernel, require_peft from trl import ORPOConfig, ORPOTrainer from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE @@ -148,3 +148,36 @@ def test_orpo_trainer_with_lora(self, config_name): # check the params have changed - ignore 0 biases if param.sum() != 0: self.assertFalse(torch.equal(param, new_param)) + + @require_liger_kernel + def test_orpo_trainer_with_liger(self): + """Test ORPO trainer with Liger loss enabled.""" + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = ORPOConfig( + output_dir=tmp_dir, + report_to="none", + use_liger_loss=True, # Enable Liger loss + ) + + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference") + + trainer = ORPOTrainer( + model=self.model, + args=training_args, + processing_class=self.tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + # check the params have changed - ignore 0 biases + if param.sum() != 0: + self.assertFalse(torch.equal(param, new_param)) diff --git a/trl/trainer/orpo_config.py b/trl/trainer/orpo_config.py index b7e2ef7ad0..1cd5fa0a00 100644 --- a/trl/trainer/orpo_config.py +++ b/trl/trainer/orpo_config.py @@ -61,6 +61,11 @@ class ORPOConfig(TrainingArguments): string. dataset_num_proc (`Optional[int]`, *optional*, defaults to `None`): Number of processes to use for processing the dataset. + use_liger_loss (`bool`, *optional*, defaults to `False`): + Whether to use Liger loss. + base_model_attribute_name (`str`, *optional*, defaults to `"model"`): + Name of the attribute in the model that contains the base model. This is used to get the base model from the + model when the model does not have a `get_decoder` method in the case when `use_liger_loss` is `True`. """ learning_rate: float = 1e-6 @@ -76,3 +81,5 @@ class ORPOConfig(TrainingArguments): is_encoder_decoder: Optional[bool] = None model_init_kwargs: Optional[dict[str, Any]] = None dataset_num_proc: Optional[int] = None + use_liger_loss: bool = False + base_model_attribute_name: str = "model" diff --git a/trl/trainer/orpo_trainer.py b/trl/trainer/orpo_trainer.py index 50392526db..ac21f40831 100644 --- a/trl/trainer/orpo_trainer.py +++ b/trl/trainer/orpo_trainer.py @@ -47,7 +47,7 @@ ) from transformers.trainer_callback import TrainerCallback from transformers.trainer_utils import EvalLoopOutput -from transformers.utils import is_peft_available, is_torch_fx_proxy +from transformers.utils import is_liger_kernel_available, is_peft_available, is_torch_fx_proxy from transformers.utils.deprecation import deprecate_kwarg from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt @@ -68,7 +68,6 @@ if is_peft_available(): from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training - if is_wandb_available(): import wandb @@ -78,6 +77,9 @@ if is_torch_xla_available(): import torch_xla.core.xla_model as xm +if is_liger_kernel_available(): + from liger_kernel.chunked_loss import LigerFusedLinearORPOLoss + class ORPOTrainer(Trainer): r""" @@ -357,6 +359,15 @@ def make_inputs_require_grad(module, input, output): "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`." ) + # Import Liger loss if enabled + if self.args.use_liger_loss: + if not is_liger_kernel_available(): + raise ValueError( + "You set `use_liger_loss=True` but the liger kernel is not available. " + "Please install liger-kernel first: `pip install liger-kernel`" + ) + self.orpo_loss_fn = LigerFusedLinearORPOLoss(ignore_index=self.label_pad_token_id, beta=self.beta) + def _prepare_deepspeed(self, model: PreTrainedModelWrapper): # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473 deepspeed_plugin = self.accelerator.state.deepspeed_plugin @@ -752,53 +763,90 @@ def concatenated_forward( if self.aux_loss_enabled: model_kwargs["output_router_logits"] = True - outputs = model( - concatenated_batch["concatenated_input_ids"], - attention_mask=concatenated_batch["concatenated_attention_mask"], - use_cache=False, - **model_kwargs, - ) - all_logits = outputs.logits - - def cross_entropy_loss(logits, labels): - if not self.is_encoder_decoder: - # Shift so that tokens < n predict n - logits = logits[..., :-1, :].contiguous() - labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = nn.CrossEntropyLoss() - logits = logits.view(-1, logits.shape[-1]) - labels = labels.view(-1) - # Enable model parallelism - labels = labels.to(logits.device) - loss = loss_fct(logits, labels) - return loss - - labels = concatenated_batch["concatenated_labels"].clone() - chosen_nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen]) - - all_logps = self.get_batch_logps( - all_logits, - labels, - average_log_prob=True, - is_encoder_decoder=self.is_encoder_decoder, - label_pad_token_id=self.label_pad_token_id, - ) + if self.args.use_liger_loss: + # skip the lm head and get the last hidden state + if hasattr(model, "get_decoder"): + base_model = model.get_decoder() + else: + base_model = getattr(model, self.args.base_model_attribute_name) + outputs = base_model( + concatenated_batch["concatenated_input_ids"], + attention_mask=concatenated_batch["concatenated_attention_mask"], + use_cache=False, + **model_kwargs, + ) + lm_head = model.get_output_embeddings() + + # return the final loss and aux_outputs tuple + loss, aux_outputs = self.orpo_loss_fn( + lm_head.weight, + outputs.last_hidden_state[:, :-1] if not self.is_encoder_decoder else outputs.last_hidden_state, + concatenated_batch["concatenated_labels"][:, 1:] + if not self.is_encoder_decoder + else concatenated_batch["concatenated_labels"], + lm_head.bias if hasattr(lm_head, "bias") else None, + ) - chosen_logps = all_logps[:len_chosen] - rejected_logps = all_logps[len_chosen:] + if self.aux_loss_enabled: + loss += self.aux_loss_coef * outputs.aux_loss - if not self.is_encoder_decoder: - chosen_logits = all_logits[:len_chosen, :-1, :] - rejected_logits = all_logits[len_chosen:, :-1, :] + return loss, aux_outputs else: - chosen_logits = all_logits[:len_chosen] - rejected_logits = all_logits[len_chosen:] + outputs = model( + concatenated_batch["concatenated_input_ids"], + attention_mask=concatenated_batch["concatenated_attention_mask"], + use_cache=False, + output_hidden_states=False, + **model_kwargs, + ) + all_logits = outputs.logits + + def cross_entropy_loss(logits, labels): + if not self.is_encoder_decoder: + # Shift so that tokens < n predict n + logits = logits[..., :-1, :].contiguous() + labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = nn.CrossEntropyLoss() + logits = logits.view(-1, logits.shape[-1]) + labels = labels.view(-1) + # Enable model parallelism + labels = labels.to(logits.device) + loss = loss_fct(logits, labels) + return loss + + labels = concatenated_batch["concatenated_labels"].clone() + chosen_nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen]) + + all_logps = self.get_batch_logps( + all_logits, + labels, + average_log_prob=True, + is_encoder_decoder=self.is_encoder_decoder, + label_pad_token_id=self.label_pad_token_id, + ) - if self.aux_loss_enabled: - return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss, outputs.aux_loss) + chosen_logps = all_logps[:len_chosen] + rejected_logps = all_logps[len_chosen:] + + if not self.is_encoder_decoder: + chosen_logits = all_logits[:len_chosen, :-1, :] + rejected_logits = all_logits[len_chosen:, :-1, :] + else: + chosen_logits = all_logits[:len_chosen] + rejected_logits = all_logits[len_chosen:] + + if self.aux_loss_enabled: + return ( + chosen_logps, + rejected_logps, + chosen_logits, + rejected_logits, + chosen_nll_loss, + outputs.aux_loss, + ) - return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss) + return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss) def get_batch_loss_metrics( self, @@ -810,21 +858,41 @@ def get_batch_loss_metrics( metrics = {} forward_output = self.concatenated_forward(model, batch) - ( - policy_chosen_logps, - policy_rejected_logps, - policy_chosen_logits, - policy_rejected_logits, - policy_nll_loss, - ) = forward_output[:5] - if self.aux_loss_enabled: - aux_loss = forward_output[5] + if self.args.use_liger_loss: + # full ORPO loss and aux outputs + ( + loss, + ( + policy_chosen_logps, + policy_rejected_logps, + policy_chosen_logits, + policy_rejected_logits, + policy_nll_loss, + chosen_rewards, + rejected_rewards, + log_odds_ratio, + log_odds_chosen, + ), + ) = forward_output + else: + ( + policy_chosen_logps, + policy_rejected_logps, + policy_chosen_logits, + policy_rejected_logits, + policy_nll_loss, + ) = forward_output[:5] + if self.aux_loss_enabled: + aux_loss = forward_output[5] + + losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = self.odds_ratio_loss( + policy_chosen_logps, policy_rejected_logps + ) + # full ORPO loss + loss = policy_nll_loss - losses.mean() - losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = self.odds_ratio_loss( - policy_chosen_logps, policy_rejected_logps - ) - # full ORPO loss - loss = policy_nll_loss - losses.mean() + if self.aux_loss_enabled: + loss += self.aux_loss_coef * aux_loss reward_accuracies = (chosen_rewards > rejected_rewards).float() @@ -844,8 +912,6 @@ def get_batch_loss_metrics( xm.mark_step() # needed because .item() calls for k, v in metrics.items(): metrics[k] = v.item() - if self.aux_loss_enabled: - loss += self.aux_loss_coef * aux_loss return loss, metrics @@ -857,7 +923,6 @@ def compute_loss( num_items_in_batch=None, ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]: compute_loss_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext() - with compute_loss_context_manager: loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")