
Knowledge distillation for vision guide #25619

Merged: 29 commits into huggingface:main, Oct 18, 2023

Conversation

merveenoyan (Contributor)

This is a draft PR that I opened in the past on the KD guide for CV, but I accidentally removed my fork. I prioritized TGI docs, so this PR might stay stale for a while; I will ask for a review after I iterate over the comments left by @sayakpaul in my previous PR (mainly training MobileNet with random initial weights rather than with pre-trained weights from transformers).

@merveenoyan marked this pull request as ready for review on September 12, 2023, 14:15
@merveenoyan (Contributor, Author)

@sayakpaul I changed the setup and didn't observe much of a difference, but I felt it would still be cool to show how to distill a model. WDYT?

@amyeroberts (Collaborator)

cc @rafaelpadilla for reference

@rafaelpadilla (Contributor) left a comment

Fantastic to see knowledge distillation being discussed—such an exciting topic! 🚀
Just shared a few comments and suggestions that might enhance readability. Most are related to writing style.
I appreciate the straightforward example you've provided. 👍

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@merveenoyan (Contributor, Author)

@rafaelpadilla @NielsRogge can we merge this if this looks good?

@rafaelpadilla (Contributor)

> @rafaelpadilla @NielsRogge can we merge this if this looks good?

Yes, it's OK with me.
My comments were merely about writing style.

```python
from datasets import load_dataset

dataset = load_dataset("beans")
```

We can use either of the processors given that they return the same output. We will use the `map()` method of `dataset` to apply the preprocessing to every split of the dataset.
Contributor

This sentence is actually not true; ResNet and MobileNet each have their own image processors.

@merveenoyan (Contributor, Author), Oct 5, 2023

They do return the same thing, because both processors apply the same preprocessing at the same resolution. Check this out:

```python
from transformers import AutoFeatureExtractor
from PIL import Image
import requests
import numpy as np

# Feature extractors for the teacher (ResNet-50) and the student (MobileNetV2)
teacher_extractor = AutoFeatureExtractor.from_pretrained("microsoft/resnet-50")
student_extractor = AutoFeatureExtractor.from_pretrained("google/mobilenet_v2_1.4_224")

# Preprocess the same sample image with both extractors
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
sample = Image.open(requests.get(url, stream=True).raw)

np.array_equal(teacher_extractor(sample)["pixel_values"], student_extractor(sample)["pixel_values"])
# True
```
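Applying the preprocessing to every split of the dataset could then look roughly like this; a minimal sketch assuming a batched `process` function along these lines (the guide's actual function may differ):

```python
def process(examples):
    # Preprocess a batch of PIL images; the student's extractor would
    # produce identical pixel values, as shown above
    processed = teacher_extractor([image.convert("RGB") for image in examples["image"]])
    processed["labels"] = examples["labels"]
    return processed

processed_datasets = dataset.map(process, batched=True)
```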

@NielsRogge (Contributor) left a comment

Thanks for writing this up! ❤️

@LysandreJik requested a review from MKhalusova on October 5, 2023, 08:29
@MKhalusova (Contributor) left a comment

Great work on the guide! While reading it I had a few questions that I feel other folks may have, and it would be great to address them :)

```python
processed_datasets = dataset.map(process, batched=True)
```

Essentially, we want the student model (a randomly initialized MobileNet) to mimic the teacher model (pre-trained ResNet). To achieve this, we first get the logits output by the teacher and the student. Then, we divide each of them by the parameter `temperature`, which controls the importance of each soft target. We will use the KL loss to compute the divergence between the student and teacher. A parameter called `lambda` weighs the importance of the distillation loss. In this example, we will use `temperature=5` and `lambda=0.5`.
Contributor

It would be cool to link KL loss to a page that gives a definition of what it is, for people who are not familiar with it.
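For readers who are not familiar with it, here is a minimal sketch of the temperature-scaled KL divergence loss described in the excerpt above (an illustration in PyTorch, not necessarily the guide's exact code):

```python
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, temperature=5.0):
    # Soften both distributions with the temperature
    soft_teacher = F.softmax(teacher_logits / temperature, dim=-1)
    soft_student = F.log_softmax(student_logits / temperature, dim=-1)
    # KL divergence between the softened distributions; scaling by temperature**2
    # keeps gradient magnitudes comparable across temperatures
    return F.kl_div(soft_student, soft_teacher, reduction="batchmean") * temperature**2
```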

Contributor

Since you're customizing the Trainer, it would also be nice to link to this page https://huggingface.co/docs/transformers/en/main_classes/trainer#trainer
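As an illustration of the customization under discussion, here is a minimal sketch of a distillation `Trainer` subclass. The names `ImageDistilTrainer`, `temperature`, and `lambda_param` are assumptions for this sketch, not necessarily the guide's exact implementation:

```python
import torch
import torch.nn.functional as F
from transformers import Trainer

class ImageDistilTrainer(Trainer):
    def __init__(self, teacher_model=None, temperature=5.0, lambda_param=0.5, **kwargs):
        super().__init__(**kwargs)
        # Teacher is frozen during distillation (in practice it also needs
        # to be moved to the training device)
        self.teacher = teacher_model.eval()
        self.temperature = temperature
        self.lambda_param = lambda_param

    def compute_loss(self, model, inputs, return_outputs=False):
        student_output = model(**inputs)
        with torch.no_grad():
            teacher_output = self.teacher(**inputs)
        # Temperature-scaled KL divergence between teacher and student logits
        soft_teacher = F.softmax(teacher_output.logits / self.temperature, dim=-1)
        soft_student = F.log_softmax(student_output.logits / self.temperature, dim=-1)
        distill_loss = F.kl_div(soft_student, soft_teacher, reduction="batchmean") * self.temperature**2
        # Mix the distillation loss with the student's own cross-entropy loss
        loss = (1.0 - self.lambda_param) * student_output.loss + self.lambda_param * distill_loss
        return (loss, student_output) if return_outputs else loss
```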

Contributor

The first sentence would be great to have somewhere in the introduction - how the distillation works. Something like: "To distill knowledge from one model to another, we take a pre-trained teacher model, and randomly initialize a student model. Next, we train the student model to minimize the difference between its outputs and the teacher's outputs, thus making it mimic the behavior. "


```python
trainer.evaluate(processed_datasets["test"])
```
Contributor

Maybe also push the final model to the Hub with `trainer.push_to_hub()`?

@merveenoyan (Contributor, Author)

I think the final model is already pushed when we set `push_to_hub` to `True` (I also have the save strategy set to every epoch, so it's triggered at every epoch as well), no?

Contributor

AFAIK `trainer.push_to_hub()` also creates a basic model card, e.g. with metrics and some training results.
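For reference, a minimal sketch of the two options being discussed, with illustrative argument values rather than the guide's exact configuration:

```python
from transformers import TrainingArguments

# Option 1: push checkpoints to the Hub automatically during training
training_args = TrainingArguments(
    output_dir="distilled-mobilenet",  # hypothetical output directory / repo name
    save_strategy="epoch",
    push_to_hub=True,
)

# Option 2: after training, push the final model and let the Trainer
# generate a basic model card with metrics and training results
# trainer.push_to_hub()
```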

@MKhalusova (Contributor) left a comment

Thank you for iterating on this! This revision looks fantastic :)

@rafaelpadilla (Contributor) left a comment

Looks good to me. :)

@merveenoyan (Contributor, Author)

@LysandreJik can you give a review or ask for another reviewer if needed?

@LysandreJik (Member) left a comment

Thank you @merveenoyan!

@LysandreJik (Member)

Please resolve the merge conflicts and merge, @merveenoyan.

@LysandreJik merged commit 280c757 into huggingface:main on Oct 18, 2023
3 checks passed
EduardoPach pushed a commit to EduardoPach/transformers that referenced this pull request on Nov 19, 2023:
* Knowledge distillation for vision guide

* Update knowledge_distillation_for_image_classification.md

* Update docs/source/en/tasks/knowledge_distillation_for_image_classification.md

Co-authored-by: Rafael Padilla <[email protected]>

* Update docs/source/en/tasks/knowledge_distillation_for_image_classification.md

Co-authored-by: Rafael Padilla <[email protected]>

* Iterated on Rafael's comments

* Added to toctree

* Update docs/source/en/tasks/knowledge_distillation_for_image_classification.md

Co-authored-by: Rafael Padilla <[email protected]>

* Addressed comments

* Update knowledge_distillation_for_image_classification.md

* Update docs/source/en/tasks/knowledge_distillation_for_image_classification.md

Co-authored-by: Rafael Padilla <[email protected]>

* Update docs/source/en/tasks/knowledge_distillation_for_image_classification.md

Co-authored-by: NielsRogge <[email protected]>

* Update docs/source/en/tasks/knowledge_distillation_for_image_classification.md

Co-authored-by: NielsRogge <[email protected]>

* Update docs/source/en/tasks/knowledge_distillation_for_image_classification.md

Co-authored-by: NielsRogge <[email protected]>

* Update docs/source/en/tasks/knowledge_distillation_for_image_classification.md

Co-authored-by: NielsRogge <[email protected]>

* Update docs/source/en/tasks/knowledge_distillation_for_image_classification.md

Co-authored-by: NielsRogge <[email protected]>

* Update docs/source/en/tasks/knowledge_distillation_for_image_classification.md

Co-authored-by: NielsRogge <[email protected]>

* Update docs/source/en/tasks/knowledge_distillation_for_image_classification.md

Co-authored-by: NielsRogge <[email protected]>

* Update docs/source/en/tasks/knowledge_distillation_for_image_classification.md

Co-authored-by: NielsRogge <[email protected]>

* Update knowledge_distillation_for_image_classification.md

* Update knowledge_distillation_for_image_classification.md

* Update docs/source/en/tasks/knowledge_distillation_for_image_classification.md

Co-authored-by: Maria Khalusova <[email protected]>

* Update docs/source/en/tasks/knowledge_distillation_for_image_classification.md

Co-authored-by: Maria Khalusova <[email protected]>

* Update docs/source/en/tasks/knowledge_distillation_for_image_classification.md

Co-authored-by: Maria Khalusova <[email protected]>

* Update docs/source/en/tasks/knowledge_distillation_for_image_classification.md

Co-authored-by: Maria Khalusova <[email protected]>

* Address comments

* Update knowledge_distillation_for_image_classification.md

* Explain KL Div

---------

Co-authored-by: Rafael Padilla <[email protected]>
Co-authored-by: NielsRogge <[email protected]>
Co-authored-by: Maria Khalusova <[email protected]>