Skip to content

Commit

Permalink
timm to pytorch conversion for vit model fix (#26908)
Browse files Browse the repository at this point in the history
* timm to pytorch conversion for vit model fix

* remove unnecessary print statements

* Detect non-supported ViTs in transformers & better handle id2label mapping

* detect non-supported hybrid ResNet-ViT models in conversion script

* remove check for overlap between cls token and pos embed
  • Loading branch information
staghado authored Nov 20, 2023
1 parent e66984f commit 93f2de8
Showing 1 changed file with 62 additions and 57 deletions.
119 changes: 62 additions & 57 deletions src/transformers/models/vit/convert_vit_timm_to_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,13 @@


import argparse
import json
from pathlib import Path

import requests
import timm
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from timm.data import ImageNetInfo, infer_imagenet_subset

from transformers import DeiTImageProcessor, ViTConfig, ViTForImageClassification, ViTImageProcessor, ViTModel
from transformers.utils import logging
Expand Down Expand Up @@ -60,13 +59,11 @@ def create_rename_keys(config, base_model=False):
)

if base_model:
# layernorm + pooler
# layernorm
rename_keys.extend(
[
("norm.weight", "layernorm.weight"),
("norm.bias", "layernorm.bias"),
("pre_logits.fc.weight", "pooler.dense.weight"),
("pre_logits.fc.bias", "pooler.dense.bias"),
]
)

Expand Down Expand Up @@ -140,60 +137,68 @@ def convert_vit_checkpoint(vit_name, pytorch_dump_folder_path):
# define default ViT configuration
config = ViTConfig()
base_model = False
# dataset (ImageNet-21k only or also fine-tuned on ImageNet 2012), patch_size and image_size
if vit_name[-5:] == "in21k":
base_model = True
config.patch_size = int(vit_name[-12:-10])
config.image_size = int(vit_name[-9:-6])
else:
config.num_labels = 1000
repo_id = "huggingface/label-files"
filename = "imagenet-1k-id2label.json"
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
id2label = {int(k): v for k, v in id2label.items()}
config.id2label = id2label
config.label2id = {v: k for k, v in id2label.items()}
config.patch_size = int(vit_name[-6:-4])
config.image_size = int(vit_name[-3:])
# size of the architecture
if "deit" in vit_name:
if vit_name[9:].startswith("tiny"):
config.hidden_size = 192
config.intermediate_size = 768
config.num_hidden_layers = 12
config.num_attention_heads = 3
elif vit_name[9:].startswith("small"):
config.hidden_size = 384
config.intermediate_size = 1536
config.num_hidden_layers = 12
config.num_attention_heads = 6
else:
pass
else:
if vit_name[4:].startswith("small"):
config.hidden_size = 768
config.intermediate_size = 2304
config.num_hidden_layers = 8
config.num_attention_heads = 8
elif vit_name[4:].startswith("base"):
pass
elif vit_name[4:].startswith("large"):
config.hidden_size = 1024
config.intermediate_size = 4096
config.num_hidden_layers = 24
config.num_attention_heads = 16
elif vit_name[4:].startswith("huge"):
config.hidden_size = 1280
config.intermediate_size = 5120
config.num_hidden_layers = 32
config.num_attention_heads = 16

# load original model from timm
timm_model = timm.create_model(vit_name, pretrained=True)
timm_model.eval()

# load state_dict of original model, remove and rename some keys
# detect unsupported ViT models in transformers
# fc_norm is present
if not isinstance(getattr(timm_model, "fc_norm", None), torch.nn.Identity):
raise ValueError(f"{vit_name} is not supported in transformers because of the presence of fc_norm.")

# use of global average pooling, with or without a class token
if getattr(timm_model, "global_pool", None) == "avg":
raise ValueError(f"{vit_name} is not supported in transformers because of use of global average pooling.")

# CLIP style vit with norm_pre layer present
if "clip" in vit_name and not isinstance(getattr(timm_model, "norm_pre", None), torch.nn.Identity):
raise ValueError(
f"{vit_name} is not supported in transformers because it's a CLIP style ViT with norm_pre layer."
)

# SigLIP style vit with attn_pool layer present
if "siglip" in vit_name and getattr(timm_model, "global_pool", None) == "map":
raise ValueError(
f"{vit_name} is not supported in transformers because it's a SigLIP style ViT with attn_pool."
)

# use of layer scale in ViT model blocks
if not isinstance(getattr(timm_model.blocks[0], "ls1", None), torch.nn.Identity) or not isinstance(
getattr(timm_model.blocks[0], "ls2", None), torch.nn.Identity
):
raise ValueError(f"{vit_name} is not supported in transformers because it uses a layer scale in its blocks.")

# Hybrid ResNet-ViTs
if not isinstance(timm_model.patch_embed, timm.layers.PatchEmbed):
raise ValueError(f"{vit_name} is not supported in transformers because it is a hybrid ResNet-ViT.")

# get patch size and image size from the patch embedding submodule
config.patch_size = timm_model.patch_embed.patch_size[0]
config.image_size = timm_model.patch_embed.img_size[0]

# retrieve architecture-specific parameters from the timm model
config.hidden_size = timm_model.embed_dim
config.intermediate_size = timm_model.blocks[0].mlp.fc1.out_features
config.num_hidden_layers = len(timm_model.blocks)
config.num_attention_heads = timm_model.blocks[0].attn.num_heads

# check whether the model has a classification head or not
if timm_model.num_classes != 0:
config.num_labels = timm_model.num_classes
# infer ImageNet subset from timm model
imagenet_subset = infer_imagenet_subset(timm_model)
dataset_info = ImageNetInfo(imagenet_subset)
config.id2label = {i: dataset_info.index_to_label_name(i) for i in range(dataset_info.num_classes())}
config.label2id = {v: k for k, v in config.id2label.items()}
else:
print(f"{vit_name} is going to be converted as a feature extractor only.")
base_model = True

# load state_dict of original model
state_dict = timm_model.state_dict()

# remove and rename some keys in the state dict
if base_model:
remove_classification_head_(state_dict)
rename_keys = create_rename_keys(config, base_model)
Expand All @@ -202,8 +207,8 @@ def convert_vit_checkpoint(vit_name, pytorch_dump_folder_path):
read_in_q_k_v(state_dict, config, base_model)

# load HuggingFace model
if vit_name[-5:] == "in21k":
model = ViTModel(config).eval()
if base_model:
model = ViTModel(config, add_pooling_layer=False).eval()
else:
model = ViTForImageClassification(config).eval()
model.load_state_dict(state_dict)
Expand All @@ -219,8 +224,8 @@ def convert_vit_checkpoint(vit_name, pytorch_dump_folder_path):

if base_model:
timm_pooled_output = timm_model.forward_features(pixel_values)
assert timm_pooled_output.shape == outputs.pooler_output.shape
assert torch.allclose(timm_pooled_output, outputs.pooler_output, atol=1e-3)
assert timm_pooled_output.shape == outputs.last_hidden_state.shape
assert torch.allclose(timm_pooled_output, outputs.last_hidden_state, atol=1e-1)
else:
timm_logits = timm_model(pixel_values)
assert timm_logits.shape == outputs.logits.shape
Expand Down

0 comments on commit 93f2de8

Please sign in to comment.