diff --git a/docs/source/gkd_trainer.md b/docs/source/gkd_trainer.md
index a9394cff5a..b4171cf87c 100644
--- a/docs/source/gkd_trainer.md
+++ b/docs/source/gkd_trainer.md
@@ -19,8 +19,9 @@ This post-training method was contributed by [Kashif Rasul](https://huggingface.
 
 ## Usage tips
 
-The GKD Trainer is a wrapper around the [`SFTTrainer`] class that takes in a teacher model argument. It needs two parameters to be set via the [`GKDConfig`] namely:
+The [`GKDTrainer`] is a wrapper around the [`SFTTrainer`] class that takes in a teacher model argument. It needs three parameters to be set via the [`GKDConfig`], namely:
 * `lmbda`: controls the student data fraction, i.e., the proportion of on-policy student-generated outputs. When `lmbda=0.0`, the loss reduces to supervised JSD where the student is trained with the token-level probabilities of the teacher. When `lmbda=1.0`, the loss reduces to on-policy JSD, where the student generates output sequences and token-specific feedback on these sequences from the teacher. For values in between [0, 1] it is random between the two based on the `lmbda` value for each batch.
+* `seq_kd`: controls whether to perform Sequence-Level KD (can be viewed as supervised FT on teacher-generated output). When `seq_kd=True` and `lmbda=0.0`, the loss reduces to supervised JSD, where the teacher generates output sequences and the student receives token-specific feedback on these sequences from the teacher.
 * `beta`: controls the interpolation in the generalized Jensen-Shannon Divergence. When `beta=0.0` the loss approximates forward KL divergence, while for `beta=1.0` the loss approximates reverse KL divergence. For values in between [0, 1] it interpolates between the two.
 
 The authors find that on-policy data (high `lmbda`) performs better and the optimal `beta` varied depending on the task and evaluation method.
diff --git a/trl/trainer/gkd_config.py b/trl/trainer/gkd_config.py
index 4318791d84..7230b29640 100644
--- a/trl/trainer/gkd_config.py
+++ b/trl/trainer/gkd_config.py
@@ -41,6 +41,9 @@ class GKDConfig(SFTConfig):
             from a string.
         disable_dropout (`bool`, *optional*, defaults to `True`):
             Whether or not to disable dropouts in `model`.
+        seq_kd (`bool`, *optional*, defaults to `False`):
+            Whether to perform Sequence-Level KD (can be viewed as supervised FT on teacher-generated
+            output).
     """
 
     temperature: float = 0.9
@@ -50,6 +53,7 @@ class GKDConfig(SFTConfig):
     teacher_model_name_or_path: Optional[str] = None
     teacher_model_init_kwargs: Optional[Dict[str, Any]] = None
     disable_dropout: bool = True
+    seq_kd: bool = False
 
     def __post_init__(self):
         super().__post_init__()
diff --git a/trl/trainer/gkd_trainer.py b/trl/trainer/gkd_trainer.py
index 7683c8cdf6..1b7c77557d 100644
--- a/trl/trainer/gkd_trainer.py
+++ b/trl/trainer/gkd_trainer.py
@@ -136,6 +136,7 @@ def __init__(
         self.lmbda = args.lmbda
         self.beta = args.beta
         self.temperature = args.temperature
+        self.seq_kd = args.seq_kd
 
         self.generation_config = GenerationConfig(
             max_new_tokens=args.max_new_tokens,
@@ -280,6 +281,14 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor,
         With probability `self.lmbda`, it generates new responses using the student model,
         which are then used for training instead of the original inputs.
""" + if self.seq_kd: + with unwrap_model_for_generation(self.teacher_model, self.accelerator) as unwrapped_model: + new_input_ids, new_attention_mask, new_labels = self.generate_on_policy_outputs( + unwrapped_model, inputs, self.generation_config, self.processing_class.pad_token_id + ) + inputs["input_ids"] = new_input_ids + inputs["attention_mask"] = new_attention_mask + inputs["labels"] = new_labels if random.random() <= self.lmbda: with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model: new_input_ids, new_attention_mask, new_labels = self.generate_on_policy_outputs(