Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Knowledge distillation for vision guide #25619

Merged
merged 29 commits into from
Oct 18, 2023
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
aad783d
Knowledge distillation for vision guide
merveenoyan Aug 20, 2023
01080da
Update knowledge_distillation_for_image_classification.md
merveenoyan Sep 12, 2023
6eca2fa
Update docs/source/en/tasks/knowledge_distillation_for_image_classifi…
merveenoyan Sep 15, 2023
b595a1a
Update docs/source/en/tasks/knowledge_distillation_for_image_classifi…
merveenoyan Sep 15, 2023
36979c6
Iterated on Rafael's comments
merveenoyan Sep 15, 2023
06d7659
Added to toctree
merveenoyan Sep 15, 2023
8702960
Update docs/source/en/tasks/knowledge_distillation_for_image_classifi…
merveenoyan Oct 2, 2023
742dc93
Addressed comments
merveenoyan Oct 2, 2023
23085f0
Update knowledge_distillation_for_image_classification.md
merveenoyan Oct 2, 2023
ed113cd
Merge branch 'main' into knowledge-distillation
merveenoyan Oct 2, 2023
c01e4cd
Update docs/source/en/tasks/knowledge_distillation_for_image_classifi…
merveenoyan Oct 3, 2023
4e46a06
Update docs/source/en/tasks/knowledge_distillation_for_image_classifi…
merveenoyan Oct 4, 2023
10cc3e0
Update docs/source/en/tasks/knowledge_distillation_for_image_classifi…
merveenoyan Oct 4, 2023
5c36920
Update docs/source/en/tasks/knowledge_distillation_for_image_classifi…
merveenoyan Oct 4, 2023
cacbe86
Update docs/source/en/tasks/knowledge_distillation_for_image_classifi…
merveenoyan Oct 5, 2023
3bc1928
Update docs/source/en/tasks/knowledge_distillation_for_image_classifi…
merveenoyan Oct 5, 2023
f07351b
Update docs/source/en/tasks/knowledge_distillation_for_image_classifi…
merveenoyan Oct 5, 2023
836cb90
Update docs/source/en/tasks/knowledge_distillation_for_image_classifi…
merveenoyan Oct 5, 2023
c8b5098
Update docs/source/en/tasks/knowledge_distillation_for_image_classifi…
merveenoyan Oct 5, 2023
ea0b75e
Update knowledge_distillation_for_image_classification.md
merveenoyan Oct 5, 2023
c4bce38
Update knowledge_distillation_for_image_classification.md
merveenoyan Oct 5, 2023
70c1c1b
Update docs/source/en/tasks/knowledge_distillation_for_image_classifi…
merveenoyan Oct 9, 2023
04582a5
Update docs/source/en/tasks/knowledge_distillation_for_image_classifi…
merveenoyan Oct 9, 2023
1cb7469
Update docs/source/en/tasks/knowledge_distillation_for_image_classifi…
merveenoyan Oct 9, 2023
72a419d
Update docs/source/en/tasks/knowledge_distillation_for_image_classifi…
merveenoyan Oct 9, 2023
518017d
Address comments
merveenoyan Oct 12, 2023
0be1027
Update knowledge_distillation_for_image_classification.md
merveenoyan Oct 12, 2023
9cae56d
Explain KL Div
merveenoyan Oct 12, 2023
f0c2a9e
Merge branch 'main' into knowledge-distillation
merveenoyan Oct 17, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 18 additions & 16 deletions docs/source/en/_toctree.yml
Original file line number Diff line number Diff line change
Expand Up @@ -55,22 +55,24 @@
- local: tasks/asr
title: Automatic speech recognition
title: Audio
- isExpanded: false
sections:
- local: tasks/image_classification
title: Image classification
- local: tasks/semantic_segmentation
title: Semantic segmentation
- local: tasks/video_classification
title: Video classification
- local: tasks/object_detection
title: Object detection
- local: tasks/zero_shot_object_detection
title: Zero-shot object detection
- local: tasks/zero_shot_image_classification
title: Zero-shot image classification
- local: tasks/monocular_depth_estimation
title: Depth estimation
isExpanded: false
- sections:
- local: tasks/image_classification
title: Image classification
- local: tasks/semantic_segmentation
title: Semantic segmentation
- local: tasks/video_classification
title: Video classification
- local: tasks/object_detection
title: Object detection
- local: tasks/zero_shot_object_detection
title: Zero-shot object detection
- local: tasks/zero_shot_image_classification
title: Zero-shot image classification
- local: tasks/monocular_depth_estimation
title: Depth estimation
- local: tasks/knowledge_distillation_for_image_classification
title: Knowledge Distillation for Computer Vision
title: Computer Vision
- isExpanded: false
sections:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->
# Knowledge Distillation for Computer Vision
merveenoyan marked this conversation as resolved.
Show resolved Hide resolved

Knowledge distillation is a technique used to transfer knowledge from a larger, more complex model (teacher) to a smaller, simpler model (student). In the context of image classification, the goal is to train a student model which mimics the behavior of the teacher model. It was first introduced in [Distilling the Knowledge in a Neural Network by Hinton et al](https://arxiv.org/abs/1503.02531). In this notebook, we will do task-specific knowledge distillation. We will use the [beans dataset](https://huggingface.co/datasets/beans) for this.
merveenoyan marked this conversation as resolved.
Show resolved Hide resolved

This tutorial aims to demonstrate how to distill a [ResNet](https://huggingface.co/microsoft/resnet-50) (teacher model) to a [MobileNet](https://huggingface.co/google/mobilenet_v2_1.4_224) (student model) using the `Trainer` API of 🤗 Transformers.
merveenoyan marked this conversation as resolved.
Show resolved Hide resolved

Let's install the libraries needed for distillation and evaluating the process.

```bash
pip install transformers datasets accelerate tensorboard evaluate --upgrade
```

In this example, we are using the `microsoft/resnet-50` model, trained on the [ImageNet-1k dataset](https://huggingface.co/datasets/imagenet-1k) with a resolution of 224x224.
merveenoyan marked this conversation as resolved.
Show resolved Hide resolved

```python
merveenoyan marked this conversation as resolved.
Show resolved Hide resolved
from PIL import Image
import requests
import numpy as np

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
sample = Image.open(requests.get(url, stream=True).raw)
```

We will now load and pre-process the dataset.

```python
from datasets import load_dataset

dataset = load_dataset("beans")
```

We can use either of the processors, given they return the same output. We will use the `map()` method of `dataset` to apply the preprocessing to every split of the dataset.
merveenoyan marked this conversation as resolved.
Show resolved Hide resolved

```python
from transformers import AutoImageProcessor
teacher_processor = AutoImageProcessor.from_pretrained("microsoft/resnet-50")

def process(examples):
processed_inputs = teacher_processor(examples["image"])
return processed_inputs

processed_datasets = dataset.map(process, batched=True)
```

Essentially, we want the student model (a randomly initialized MobileNet) to mimic the teacher model (pre-trained ResNet). To achieve this, we first get the logits output by the teacher and the student. Then, we divide each of them by the parameter `temperature`, which controls the importance of each soft target. We will use the KL loss to compute the divergence between the student and teacher. A parameter called `lambda` weighs the importance of the distillation loss. In this example, we will use `temperature=5` and `lambda=0.5`.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would be cool to link KL loss to some page that gives a definition of what that is for people who are not familiar.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since you're customizing the Trainer, it would also be nice to link to this page https://huggingface.co/docs/transformers/en/main_classes/trainer#trainer

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The first sentence would be great to have somewhere in the introduction - how the distillation works. Something like: "To distill knowledge from one model to another, we take a pre-trained teacher model, and randomly initialize a student model. Next, we train the student model to minimize the difference between its outputs and the teacher's outputs, thus making it mimic the behavior. "



merveenoyan marked this conversation as resolved.
Show resolved Hide resolved
```python
from transformers import TrainingArguments, Trainer
import torch
import torch.nn as nn
import torch.nn.functional as F


class ImageDistilTrainer(Trainer):
def __init__(self, *args, teacher_model=None, **kwargs):
super().__init__(*args, **kwargs)
self.teacher = teacher_model
self.student = student_model
self.loss_function = nn.KLDivLoss(reduction="batchmean")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.teacher.to(device)
self.teacher.eval()
self.temperature = temperature
self.lambda_param = lambda_param

def compute_loss(self, student, inputs, return_outputs=False):
student_output = self.student(**inputs)

with torch.no_grad():
teacher_output = self.teacher(**inputs)

# Compute soft targets for teacher and student
soft_teacher = F.softmax(teacher_output.logits / self.temperature, dim=-1)
soft_student = F.log_softmax(student_output.logits / self.temperature, dim=-1)

# Compute the loss
distillation_loss = self.loss_function(soft_student, soft_teacher) * (self.temperature ** 2)

# Compute the true label loss
student_target_loss = student_output.loss

# Calculate final loss
loss = (1. - self.lambda_param) * student_target_loss + self.lambda_param * distillation_loss
return (loss, student_output) if return_outputs else loss
```

We will now login to Hugging Face Hub so we can push our model to the Hugging Face Hub through the `Trainer`.

```python
from huggingface_hub import notebook_login

notebook_login()
```

Let's initialize the data collator and `Trainer`. We will use `DefaultDataCollator` for this.
merveenoyan marked this conversation as resolved.
Show resolved Hide resolved

```python
from transformers import AutoModelForImageClassification, DefaultDataCollator, MobileNetV2Config, MobileNetV2ForImageClassification

training_args = TrainingArguments(
output_dir="my-awesome-model",
num_train_epochs=30,
fp16=True,
logging_dir=f"{repo_name}/logs",
logging_strategy="epoch",
evaluation_strategy="epoch",
save_strategy="epoch",
load_best_model_at_end=True,
metric_for_best_model="accuracy",
report_to="tensorboard",
push_to_hub=True,
hub_strategy="every_save",
hub_model_id=repo_name,
)

num_labels = len(processed_datasets["train"].features["labels"].names)

# initialize models
teacher_model = AutoModelForImageClassification.from_pretrained(
"microsoft/resnet-50",
num_labels=num_labels,
ignore_mismatched_sizes=True
)

# training MobileNetV2 from scratch
student_config = MobileNetV2Config()
student_config.num_labels = num_labels
student_model = MobileNetV2ForImageClassification(student_config)
```

We can use `compute_metrics` function to evaluate our model on the test set. This function will be used during the training process to compute the `accuracy` & `f1` of our model.

```python
import evaluate
import numpy as np

accuracy = evaluate.load("accuracy")

def compute_metrics(eval_pred):
predictions, labels = eval_pred
acc = accuracy.compute(references=labels, predictions=np.argmax(predictions, axis=1))
return {"accuracy": acc["accuracy"]}
```

Let's initialize the `Trainer` with the training arguments we defined.

```python
data_collator = DefaultDataCollator()
trainer = ImageDistilTrainer(
student_model=student_model,
teacher_model=teacher_model,
training_args=training_args,
train_dataset=processed_datasets["train"],
eval_dataset=processed_datasets["validation"],
data_collator=data_collator,
tokenizer=teacher_extractor,
compute_metrics=compute_metrics,
temperature=5,
lambda_param=0.5
)
```

We can now train our model.

```python
trainer.train()
```

We can evaluate the model on the test set.

```python
trainer.evaluate(processed_datasets["test"])
```
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe also push the final model to hub?
trainer.push_to_hub()?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the final model is pushed already when we set push_to_hub to True (I also have save strategy enabled for every epoch so it's triggered every epoch as well), no?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AFAIK trainer.push_to_hub() also creates a basic model card, e.g. with metrics, and some training results.


On test set, our model reaches 65 percent accuracy. To have a sanity check over efficiency of distillation, we also trained MobileNet on the beans dataset from scratch with the same hyperparameters and observed 63 percent accuracy on the test set. We invite the readers to try different pre-trained teacher models, student architectures, distillation parameters and report their findings. The training logs can be found in [this repository](https://huggingface.co/merve/resnet-mobilenet-beans).