
Commit

Detect non-supported ViTs in transformers & better handle id2label mapping
staghado committed Oct 25, 2023
1 parent 330eaf2 commit d5ce0e8
Showing 1 changed file with 36 additions and 12 deletions.
48 changes: 36 additions & 12 deletions src/transformers/models/vit/convert_vit_timm_to_pytorch.py
@@ -16,14 +16,13 @@
 
 
 import argparse
-import json
 from pathlib import Path
 
 import requests
 import timm
 import torch
-from huggingface_hub import hf_hub_download
 from PIL import Image
+from timm.data import ImageNetInfo, infer_imagenet_subset
 
 from transformers import DeiTImageProcessor, ViTConfig, ViTForImageClassification, ViTImageProcessor, ViTModel
 from transformers.utils import logging
@@ -65,8 +64,6 @@ def create_rename_keys(config, base_model=False):
             [
                 ("norm.weight", "layernorm.weight"),
                 ("norm.bias", "layernorm.bias"),
-                # ("pre_logits.fc.weight", "pooler.dense.weight"),
-                # ("pre_logits.fc.bias", "pooler.dense.bias"),
             ]
         )
 
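For context on the pairs above: the conversion script applies each (old, new) entry by popping the timm key and reassigning its tensor under the transformers name. A minimal standalone sketch of that mechanic, using a made-up two-tensor state dict (the pop-and-reassign mirrors what the script's rename_key helper does):

import torch

# Toy state dict standing in for a real timm checkpoint.
state_dict = {"norm.weight": torch.ones(3), "norm.bias": torch.zeros(3)}
rename_keys = [("norm.weight", "layernorm.weight"), ("norm.bias", "layernorm.bias")]

# Pop each old key and reassign its tensor under the new name.
for old, new in rename_keys:
    state_dict[new] = state_dict.pop(old)

print(sorted(state_dict))  # ['layernorm.bias', 'layernorm.weight']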
@@ -145,6 +142,35 @@ def convert_vit_checkpoint(vit_name, pytorch_dump_folder_path):
     timm_model = timm.create_model(vit_name, pretrained=True)
     timm_model.eval()
 
+    # detect unsupported ViT models in transformers
+    # fc_norm is present
+    if not isinstance(getattr(timm_model, "fc_norm", None), torch.nn.Identity):
+        raise ValueError(f"{vit_name} is not supported in transformers because of the presence of fc_norm.")
+
+    # use of global average pooling in combination with (or without) the class token
+    if getattr(timm_model, "global_pool", None) == "avg":
+        raise ValueError(f"{vit_name} is not supported in transformers because of the use of global average pooling.")
+
+    # CLIP-style ViT with a norm_pre layer present
+    if "clip" in vit_name and not isinstance(getattr(timm_model, "norm_pre", None), torch.nn.Identity):
+        raise ValueError(
+            f"{vit_name} is not supported in transformers because it's a CLIP-style ViT with a norm_pre layer."
+        )
+
+    # SigLIP-style ViT with an attn_pool layer present
+    if "siglip" in vit_name and getattr(timm_model, "global_pool", None) == "map":
+        raise ValueError(
+            f"{vit_name} is not supported in transformers because it's a SigLIP-style ViT with attn_pool."
+        )
+
+    # use of layer scale in the ViT model blocks
+    if not isinstance(getattr(timm_model.blocks[0], "ls1", None), torch.nn.Identity) or not isinstance(
+        getattr(timm_model.blocks[0], "ls2", None), torch.nn.Identity
+    ):
+        raise ValueError(f"{vit_name} is not supported in transformers because it uses layer scale in its blocks.")
+
+    # non-overlapping position and class token embedding (to be added)
+
     # get patch size and image size from the patch embedding submodule
     config.patch_size = timm_model.patch_embed.patch_size[0]
     config.image_size = timm_model.patch_embed.img_size[0]
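For intuition about the new guards, here is a small standalone sketch of what the probed attributes look like on a plain, supported checkpoint (vit_base_patch16_224 is only an illustrative model name, and the Identity defaults assume a recent timm release):

import timm
import torch

model = timm.create_model("vit_base_patch16_224", pretrained=False)

# On a vanilla ViT none of the guards fire:
print(isinstance(getattr(model, "fc_norm", None), torch.nn.Identity))        # True: no fc_norm head norm
print(getattr(model, "global_pool", None))                                   # "token", not "avg" or "map"
print(isinstance(getattr(model, "norm_pre", None), torch.nn.Identity))       # True: no CLIP-style pre-norm
print(isinstance(getattr(model.blocks[0], "ls1", None), torch.nn.Identity))  # True: no layer scale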
@@ -158,15 +184,13 @@ def convert_vit_checkpoint(vit_name, pytorch_dump_folder_path):
     # check whether the model has a classification head or not
     if timm_model.num_classes != 0:
         config.num_labels = timm_model.num_classes
-        repo_id = "huggingface/label-files"
-        # .__ceil__() avoids having to import math
-        filename = f"imagenet-{(config.num_labels / 1000).__ceil__()}k-id2label.json"
-        id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
-        id2label = {int(k): v for k, v in id2label.items()}
-        config.id2label = id2label
-        config.label2id = {v: k for k, v in id2label.items()}
+        # infer ImageNet subset from timm model
+        imagenet_subset = infer_imagenet_subset(timm_model)
+        dataset_info = ImageNetInfo(imagenet_subset)
+        config.id2label = {i: dataset_info.index_to_label_name(i) for i in range(dataset_info.num_classes())}
+        config.label2id = {v: k for k, v in config.id2label.items()}
     else:
-        print(f"{vit_name} is going to be converted as a feature extractor only. This is not guaranteed to work.")
+        print(f"{vit_name} is going to be converted as a feature extractor only.")
         base_model = True
 
     # load state_dict of original model
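The new id2label path leans on timm's bundled dataset metadata rather than downloading label JSONs from the Hub. A hedged sketch of how the two calls fit together (model name illustrative; assumes a timm release that exports ImageNetInfo and infer_imagenet_subset from timm.data, as the import hunk above requires):

import timm
from timm.data import ImageNetInfo, infer_imagenet_subset

model = timm.create_model("vit_base_patch16_224", pretrained=False)

# timm infers which ImageNet variant the checkpoint was trained on
# (e.g. "imagenet-1k" vs "imagenet-12k") from the model's pretrained config.
imagenet_subset = infer_imagenet_subset(model)
dataset_info = ImageNetInfo(imagenet_subset)
id2label = {i: dataset_info.index_to_label_name(i) for i in range(dataset_info.num_classes())}
print(len(id2label))  # 1000 classes for the imagenet-1k subset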
