
timm to pytorch conversion for vit model fix #26908

Merged · 5 commits · Nov 20, 2023

Conversation

@staghado (Contributor) commented Oct 18, 2023

This PR fixes issue #26219 with the timm to PyTorch conversion. It removes the need for hard-coded model dimensions by reading them from the attributes of the timm model, without needing the model name.

It does the following:

  • Extracts the model dims directly from the timm model, removing the need for per-model if statements.
  • Decides whether the converted model will be a classification model or only a feature extractor, using the num_classes attribute of the timm model.
  • For feature-extractor-only models: removes the pooling layers from the PyTorch model and compares the output against the last hidden state instead.

This works for a large number of models in the ViT family.
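A minimal sketch of the idea, assuming timm's standard VisionTransformer attribute names (this illustrates the approach, not the PR's exact code):

```python
import timm

# Read the config directly from a timm ViT instead of hard-coding dims
# per model name; attribute names follow timm's VisionTransformer and
# may differ across timm versions.
model = timm.create_model("vit_base_patch16_224", pretrained=False)

hidden_size = model.embed_dim                        # transformer width
num_hidden_layers = len(model.blocks)                # depth
num_attention_heads = model.blocks[0].attn.num_heads
intermediate_size = model.blocks[0].mlp.fc1.out_features
patch_size = model.patch_embed.patch_size[0]
image_size = model.patch_embed.img_size[0]

# num_classes == 0 in timm means the classification head was removed,
# so the converted model should be a feature extractor only.
is_classification_model = model.num_classes > 0
```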

@ArthurZucker, @amyeroberts, @rwightman

@ArthurZucker (Collaborator) left a comment

Thanks for the cleanup, LGTM but I need a second look from @rwightman 🤗

@ArthurZucker previously approved these changes Oct 19, 2023
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@ArthurZucker mentioned this pull request Oct 19, 2023
@rwightman commented Oct 19, 2023

Overall looking much better, should be more robust now.

A few things: there are a number of ViT configurations supported in timm that are not, to my knowledge, supported in transformers. Should there be an attempt to detect them? Thinking of some examples (a detection sketch follows the list):

  • if fc_norm is present (norm after pooling)
  • use of global average pooling in combination (or without) class token
  • non-overlapping position and class token embedding
  • CLIP style vit with norm_pre layer present
  • SigLIP style vit with attn_pool layer present
  • and soon, use of 'registers' via reg_token param
  • use of layer scale in ViT model blocks
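For illustration, a hedged sketch of how some of these could be detected from a timm VisionTransformer's attributes (fc_norm, norm_pre, attn_pool, global_pool, and the per-block ls1 are timm internals whose names may vary across versions; this is not the PR's code):

```python
import torch.nn as nn

def find_unsupported_features(model):
    """Best-effort detection of timm ViT options with no transformers ViT equivalent."""
    found = []
    # fc_norm: norm applied after pooling (nn.Identity when unused)
    if not isinstance(getattr(model, "fc_norm", nn.Identity()), nn.Identity):
        found.append("fc_norm (norm after pooling)")
    # norm_pre: CLIP-style norm before the transformer blocks (nn.Identity when unused)
    if not isinstance(getattr(model, "norm_pre", nn.Identity()), nn.Identity):
        found.append("norm_pre (CLIP-style ViT)")
    # attn_pool: SigLIP-style attention pooling head (None when unused)
    if getattr(model, "attn_pool", None) is not None:
        found.append("attn_pool (SigLIP-style ViT)")
    # global average pooling instead of the class token
    if getattr(model, "global_pool", "token") == "avg":
        found.append("global average pooling")
    # layer scale: the blocks' ls1/ls2 are nn.Identity when init_values is None
    if not isinstance(getattr(model.blocks[0], "ls1", nn.Identity()), nn.Identity):
        found.append("layer scale in blocks")
    return found
```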

@ArthurZucker ArthurZucker dismissed their stale review October 25, 2023 08:07

The git history is completely messed up!

@ArthurZucker commented:

Hey! Make sure to rebase to only have your changes! 😉

@staghado force-pushed the timm-pytorch-conversion-fix branch from 44432e3 to 330eaf2 on October 25, 2023 at 14:21
@staghado (Contributor, Author) commented Oct 25, 2023

> Hey! Make sure to rebase to only have your changes! 😉

I have reset the branch's history and left only my changes, which fix the issue here.

@staghado commented:

> Overall looking much better, should be more robust now. A few things: there are a number of ViT configurations supported in timm that are not supported in transformers … (quoting @rwightman's comment above)

I have tried to add some checks before converting the model from timm to Hugging Face (a sketch follows the list). Checks added:
1. non-overlapping position and class token embedding
2. use of 'registers' via the reg_token param
3. models with a convolutional feature extractor, like ResNet50
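A sketch of what such guards might look like (hypothetical helper; no_embed_class and reg_token are attributes of timm's VisionTransformer in recent versions, so the exact names are assumptions):

```python
import timm

def assert_convertible(model):
    # 1. position embeddings that skip the class token (non-overlapping layout)
    if getattr(model, "no_embed_class", False):
        raise ValueError("non-overlapping position/class token embedding is not supported")
    # 2. register tokens ('registers' via reg_token) have no transformers equivalent
    if getattr(model, "reg_token", None) is not None:
        raise ValueError("register tokens (reg_token) are not supported")
    # 3. hybrid backbones replace PatchEmbed with a convolutional feature extractor
    if not isinstance(model.patch_embed, timm.layers.PatchEmbed):
        raise ValueError("convolutional feature extractors (e.g. ResNet50 hybrids) are not supported")
```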

I have tested the script on the pre-trained ViTs and only the following give errors (see the sweep sketch after the list):

  • vit_base_r50_s16_224.orig_in21k (contains a resnet block)
  • vit_base_r50_s16_384.orig_in21k_ft_in1k (contains a resnet block)
  • vit_small_r26_s32_224.augreg_in21k
  • vit_small_r26_s32_224.augreg_in21k_ft_in1k
  • vit_small_r26_s32_384.augreg_in21k_ft_in1k
  • vit_tiny_r_s16_p8_224.augreg_in21k
  • vit_tiny_r_s16_p8_224.augreg_in21k_ft_in1k
  • vit_tiny_r_s16_p8_384.augreg_in21k_ft_in1k
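A hedged sketch of such a sweep, using public timm APIs (the hybrid check anticipates the isinstance guard @rwightman suggests below):

```python
import timm

# Instantiate each pretrained ViT config and flag the hybrid models,
# which use a convolutional stem instead of the standard patch embedding.
for name in timm.list_models("vit_*", pretrained=True):
    try:
        model = timm.create_model(name, pretrained=False)
        if not isinstance(model.patch_embed, timm.layers.PatchEmbed):
            raise ValueError("hybrid model with a convolutional feature extractor")
        print(f"OK    {name}")
    except Exception as err:
        print(f"FAIL  {name}: {err}")
```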

@rwightman commented:
@staghado looking good, those hybrid ResNet-ViT models should be possible to catch with a meaningful error (see the check below). Other than that, looks ready to go.

if not isinstance(model.patch_embed, timm.layers.PatchEmbed): ...  # raise a meaningful error here

@rwightman left a comment

Looks good from the timm perspective.

@amyeroberts (Collaborator) left a comment

Thanks for adding this and making the script more general!

Just a small question on an outstanding to-do. Otherwise LGTM!

Review comment on src/transformers/models/vit/convert_vit_timm_to_pytorch.py (outdated, resolved)
@staghado commented:

@ArthurZucker

@ArthurZucker (Collaborator) left a comment

Thanks for improving this script! 🚀

@ArthurZucker merged commit 93f2de8 into huggingface:main on Nov 20, 2023 (3 checks passed).