Add Sequence-Level KD (#2220)
* Fix templates for dpo, etc.

* Update dpo.py

Add the third issue fix

* Make this a utility.

* Add Sequence-Level KD

* Add to the docstrings and the documentation

* Reviewed

* Update docs/source/gkd_trainer.md

Co-authored-by: Quentin Gallouédec <[email protected]>

---------

Co-authored-by: Kashif Rasul <[email protected]>
Co-authored-by: Quentin Gallouédec <[email protected]>
3 people authored Oct 11, 2024
1 parent 70036bf commit 7f0d246
Showing 3 changed files with 15 additions and 1 deletion.
docs/source/gkd_trainer.md: 2 additions, 1 deletion
@@ -19,8 +19,9 @@ This post-training method was contributed by [Kashif Rasul](https://huggingface.
 
 ## Usage tips
 
-The GKD Trainer is a wrapper around the [`SFTTrainer`] class that takes in a teacher model argument. It needs two parameters to be set via the [`GKDConfig`] namely:
+The [`GKDTrainer`] is a wrapper around the [`SFTTrainer`] class that takes in a teacher model argument. It needs three parameters to be set via the [`GKDConfig`], namely:
 * `lmbda`: controls the student data fraction, i.e., the proportion of on-policy student-generated outputs. When `lmbda=0.0`, the loss reduces to supervised JSD, where the student is trained with the token-level probabilities of the teacher. When `lmbda=1.0`, the loss reduces to on-policy JSD, where the student generates output sequences and receives token-specific feedback on these sequences from the teacher. For values in between [0, 1], the choice is made at random for each batch based on the `lmbda` value.
+* `seq_kd`: controls whether to perform Sequence-Level KD (can be viewed as supervised FT on teacher-generated output). When `seq_kd=True` and `lmbda=0.0`, the loss reduces to supervised JSD, where the teacher generates output sequences and the student receives token-specific feedback on these sequences from the teacher.
 * `beta`: controls the interpolation in the generalized Jensen-Shannon Divergence. When `beta=0.0` the loss approximates forward KL divergence, while for `beta=1.0` the loss approximates reverse KL divergence. For values in between [0, 1] it interpolates between the two.
 
 The authors find that on-policy data (high `lmbda`) performs better, and the optimal `beta` varies depending on the task and evaluation method.
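For reference, the interpolation that `beta` controls can be written out explicitly. A common formulation of the generalized JSD, following the GKD paper with $P_T$ the teacher and $P_S$ the student distribution (this gloss is ours, not text from the commit):

$$
\mathrm{JSD}_\beta(P_T \,\|\, P_S) = \beta \, \mathrm{KL}(P_T \,\|\, M) + (1 - \beta) \, \mathrm{KL}(P_S \,\|\, M), \qquad M = \beta P_T + (1 - \beta) P_S .
$$

As $\beta \to 0$, $M \to P_S$ and the loss behaves (up to scale) like the forward KL $\mathrm{KL}(P_T \,\|\, P_S)$; as $\beta \to 1$, $M \to P_T$ and it behaves like the reverse KL $\mathrm{KL}(P_S \,\|\, P_T)$, matching the `beta` bullet above.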
trl/trainer/gkd_config.py: 4 additions
@@ -41,6 +41,9 @@ class GKDConfig(SFTConfig):
             from a string.
         disable_dropout (`bool`, *optional*, defaults to `True`):
             Whether or not to disable dropouts in `model`.
+        seq_kd (`bool`, *optional*, defaults to `False`):
+            Whether to perform Sequence-Level KD (can be viewed as supervised FT on teacher-generated
+            output).
     """
 
     temperature: float = 0.9
@@ -50,6 +53,7 @@ class GKDConfig(SFTConfig):
     teacher_model_name_or_path: Optional[str] = None
     teacher_model_init_kwargs: Optional[Dict[str, Any]] = None
     disable_dropout: bool = True
+    seq_kd: bool = False
 
     def __post_init__(self):
         super().__post_init__()
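To show how the new flag fits together with the other options, here is a minimal, hypothetical training setup; the model names and the tiny dummy dataset are illustrative assumptions, not part of this commit:

```python
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import GKDConfig, GKDTrainer

# Placeholder student/teacher checkpoints; any causal LMs sharing a tokenizer work.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
teacher_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-1.5B-Instruct")

# A one-example conversational dataset, just to make the sketch self-contained.
train_dataset = Dataset.from_dict(
    {
        "messages": [
            [
                {"role": "user", "content": "What colour is the sky?"},
                {"role": "assistant", "content": "The sky is blue."},
            ]
        ]
    }
)

training_args = GKDConfig(
    output_dir="gkd-model",
    per_device_train_batch_size=1,
    lmbda=0.0,    # no on-policy student generations
    beta=0.0,     # generalized JSD close to forward KL
    seq_kd=True,  # the new flag: train on teacher-generated sequences
)

trainer = GKDTrainer(
    model=model,
    teacher_model=teacher_model,
    args=training_args,
    processing_class=tokenizer,
    train_dataset=train_dataset,
)
trainer.train()
```

With `seq_kd=True` and `lmbda=0.0` as above, every batch is replaced by teacher generations before the loss is computed, which is the "supervised FT on teacher-generated output" view described in the docstring.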
trl/trainer/gkd_trainer.py: 9 additions
@@ -136,6 +136,7 @@ def __init__(
         self.lmbda = args.lmbda
         self.beta = args.beta
         self.temperature = args.temperature
+        self.seq_kd = args.seq_kd
 
         self.generation_config = GenerationConfig(
             max_new_tokens=args.max_new_tokens,
@@ -280,6 +281,14 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor,
         With probability `self.lmbda`, it generates new responses using the student model,
         which are then used for training instead of the original inputs.
         """
+        if self.seq_kd:
+            with unwrap_model_for_generation(self.teacher_model, self.accelerator) as unwrapped_model:
+                new_input_ids, new_attention_mask, new_labels = self.generate_on_policy_outputs(
+                    unwrapped_model, inputs, self.generation_config, self.processing_class.pad_token_id
+                )
+            inputs["input_ids"] = new_input_ids
+            inputs["attention_mask"] = new_attention_mask
+            inputs["labels"] = new_labels
         if random.random() <= self.lmbda:
             with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
                 new_input_ids, new_attention_mask, new_labels = self.generate_on_policy_outputs(
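Conceptually, the new branch composes with the existing on-policy sampling as in the following simplified sketch; this is our paraphrase of the control flow (ignoring model unwrapping and the actual generation call), not the trainer's code:

```python
import random

def select_training_batch(inputs, seq_kd, lmbda, generate_with_teacher, generate_with_student):
    """Simplified view of how training_step picks the sequences to train on."""
    if seq_kd:
        # Sequence-Level KD: overwrite the batch with teacher-generated
        # completions, so the student is supervised on the teacher's outputs.
        inputs = generate_with_teacher(inputs)
    if random.random() <= lmbda:
        # On-policy GKD: with probability lmbda, train on sequences the
        # student generated itself; the teacher scores them token by token.
        inputs = generate_with_student(inputs)
    return inputs
```

Note that the two branches are independent: when both `seq_kd=True` and `lmbda > 0.0`, the teacher-generated batch can itself be replaced by student generations for that step.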
