Introduce Knowledge Distillation Base #417

Open · austin362667 wants to merge 4 commits into base: main

Conversation

@austin362667 (Collaborator) commented Dec 2, 2024

Summary

Thanks to @Tcc0403 and @hongpeng-guo for the nice suggestions. This PR is the first split from #408 and focuses solely on introducing the Knowledge Distillation base class; as a result, it does not include any tests at the moment.

Code Changes

  1. Refactor beta into two weights, weight_hard_loss and weight_soft_loss, which act as coefficients for hard_loss and soft_loss. @Tcc0403 also pointed out that we could use torch.lerp where applicable (see the sketch after this list).

  2. Pass teacher_logits and student_logits directly to the divergence loss function. This avoids the redundant computation of converting logits to log probabilities and then back to raw logits. Note, however, that we do not reuse the student_log_probs value computed during ce_loss in the distillation base.

    1. Remove the unnecessary get_batch_logps in test/utils.py.
  3. Change the chunking dimension from B to B * T. Thanks to @hongpeng-guo for the great advice.

    1. Fix the loss calculation to use per-token values instead of averaging across the sequence length dimension.
  4. Normalize the distillation_loss using (full_target != ignore_index).sum().
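
To make items 1 and 4 above concrete, here is a minimal sketch of how the per-token hard and soft losses could be weighted and normalized. The function name `combine_losses` and the `ignore_index=-100` default are assumptions for illustration, not the PR's actual API.

```python
import torch

def combine_losses(hard_loss, soft_loss, weight_hard_loss, weight_soft_loss,
                   full_target, ignore_index=-100):
    # hard_loss / soft_loss: per-token losses over the flattened B * T tokens
    # (chunking is done along B * T rather than B).
    # Normalize by the number of non-ignored target tokens.
    n_non_ignore = (full_target != ignore_index).sum()
    hard = hard_loss.sum() / n_non_ignore
    soft = soft_loss.sum() / n_non_ignore
    # Weighted combination of the supervised (hard) and distillation (soft) terms.
    return weight_hard_loss * hard + weight_soft_loss * soft
```

When `weight_hard_loss + weight_soft_loss == 1`, the last line is equivalent to `torch.lerp(hard, soft, weight_soft_loss)`, which is the simplification @Tcc0403 suggested.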

TODO

  1. Although a slight slowdown is expected, we need to investigate why this PR's implementation is significantly slower than the naive approach. Thanks to @Tcc0403 for the clarification.

Knowledge Distillation

Knowledge Distillation (KD; Hinton et al. 2015, Gou et al. 2020) is a straightforward way to build a smaller, cheaper model (“student model”) to speed up inference by transferring skills from a pre-trained expensive model (“teacher model”) into the student.

In knowledge distillation, a student model is trained to replicate the outputs of a teacher model using a distillation loss. Neural networks typically include a softmax layer; for instance, a large language model produces a probability distribution over tokens. Let z_t and z_s represent the logits before the softmax layer for the teacher and student models, respectively. The distillation loss reduces the discrepancy between the two softmax outputs at a high temperature T. When ground truth labels y are available, this approach can be combined with a supervised learning objective, such as cross-entropy, to compare the student’s outputs with the ground truth.

The combined loss function is defined as:

$$\mathcal{L}_{\text{knowledge distillation}} = w_{\text{soft}} \cdot \mathcal{L}_{\text{distill}}(\mathbf{z_t}, \mathbf{z_s}, T) + w_{\text{hard}} \cdot \mathcal{L}_{\text{cross entropy}}(\mathbf{y}, \mathbf{z_s}),$$

Here, we pass in the raw logits directly rather than log probabilities (cc @Tcc0403).
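
As a minimal sketch of what this looks like (using a forward KL divergence purely for illustration; the concrete divergence, e.g. JSD, comes in follow-up PRs), the soft loss can be computed from the logits without any intermediate round trip:

```python
import torch.nn.functional as F

def soft_loss_from_logits(student_logits, teacher_logits, temperature=1.0):
    # Operate directly on logits: no logits -> log-probs -> logits round trip.
    student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
    teacher_log_probs = F.log_softmax(teacher_logits / temperature, dim=-1)
    # Per-token KL(teacher || student), summed over the vocabulary dimension.
    per_token_kl = F.kl_div(student_log_probs, teacher_log_probs,
                            reduction="none", log_target=True).sum(dim=-1)
    # The T^2 factor keeps gradient magnitudes comparable across temperatures.
    return per_token_kl * (temperature ** 2)
```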

Shared DistillationBase

To support various distillation learning objectives, this PR adds a LigerFusedLinearDistillationBase, which is essentially the same as the design proposed by @hongpeng-guo in discussion #371 (comment). Thank you @hongpeng-guo for thinking this through.
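
For orientation, here is a plain-PyTorch sketch of the overall structure such a base class could have: chunk the flattened B * T hidden states, project each chunk through the student/teacher lm_head weights, and accumulate the weighted hard + soft loss normalized by the non-ignored token count. This is illustrative only; every name, signature, and default below is an assumption, not the actual LigerFusedLinearDistillationBase API (which fuses the projection with the loss and handles chunked gradients).

```python
import torch
import torch.nn.functional as F


class DistillationBaseSketch(torch.nn.Module):
    """Illustrative sketch only, not the fused/chunked implementation in this PR."""

    def __init__(self, distillation_loss_fn, weight_hard_loss=0.5,
                 weight_soft_loss=0.5, ignore_index=-100, chunk_size=1024):
        super().__init__()
        # distillation_loss_fn: per-token divergence on raw logits,
        # e.g. the soft_loss_from_logits sketch above.
        self.distillation_loss_fn = distillation_loss_fn
        self.weight_hard_loss = weight_hard_loss
        self.weight_soft_loss = weight_soft_loss
        self.ignore_index = ignore_index
        self.chunk_size = chunk_size

    def forward(self, student_hidden, student_weight,
                teacher_hidden, teacher_weight, target):
        # Inputs are flattened to (B * T, H) / (B * T,); chunking is done
        # along B * T rather than B.
        n_non_ignore = (target != self.ignore_index).sum()
        total = student_hidden.new_zeros(())
        for s_h, t_h, tgt in zip(student_hidden.split(self.chunk_size),
                                 teacher_hidden.split(self.chunk_size),
                                 target.split(self.chunk_size)):
            student_logits = s_h @ student_weight.t()
            with torch.no_grad():
                teacher_logits = t_h @ teacher_weight.t()
            # Hard (cross-entropy) loss against the ground-truth labels.
            hard = F.cross_entropy(student_logits, tgt,
                                   ignore_index=self.ignore_index,
                                   reduction="sum")
            # Soft (distillation) loss on raw logits, masked to valid tokens.
            mask = tgt != self.ignore_index
            soft = (self.distillation_loss_fn(student_logits,
                                              teacher_logits) * mask).sum()
            total = total + (self.weight_hard_loss * hard
                             + self.weight_soft_loss * soft)
        # Normalize by (target != ignore_index).sum(), as in code change 4.
        return total / n_non_ignore
```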

Testing Done

I'll post JSD tests and benchmark results in the next PR.

  • Hardware Type: L40S
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence
