diff --git a/src/transformers/models/vit/convert_vit_timm_to_pytorch.py b/src/transformers/models/vit/convert_vit_timm_to_pytorch.py index b73c5f346dba57..c7df0356f2c984 100644 --- a/src/transformers/models/vit/convert_vit_timm_to_pytorch.py +++ b/src/transformers/models/vit/convert_vit_timm_to_pytorch.py @@ -60,13 +60,13 @@ def create_rename_keys(config, base_model=False): ) if base_model: - # layernorm + pooler + # layernorm rename_keys.extend( [ ("norm.weight", "layernorm.weight"), ("norm.bias", "layernorm.bias"), - ("pre_logits.fc.weight", "pooler.dense.weight"), - ("pre_logits.fc.bias", "pooler.dense.bias"), + # ("pre_logits.fc.weight", "pooler.dense.weight"), + # ("pre_logits.fc.bias", "pooler.dense.bias"), ] ) @@ -140,60 +140,39 @@ def convert_vit_checkpoint(vit_name, pytorch_dump_folder_path): # define default ViT configuration config = ViTConfig() base_model = False - # dataset (ImageNet-21k only or also fine-tuned on ImageNet 2012), patch_size and image_size - if vit_name[-5:] == "in21k": - base_model = True - config.patch_size = int(vit_name[-12:-10]) - config.image_size = int(vit_name[-9:-6]) - else: - config.num_labels = 1000 + + # load original model from timm + timm_model = timm.create_model(vit_name, pretrained=True) + timm_model.eval() + + # get patch size and image size from the patch embedding submodule + config.patch_size = timm_model.patch_embed.patch_size[0] + config.image_size = timm_model.patch_embed.img_size[0] + + # retrieve architecture-specific parameters from the timm model + config.hidden_size = timm_model.embed_dim + config.intermediate_size = timm_model.blocks[0].mlp.fc1.out_features + config.num_hidden_layers = len(timm_model.blocks) + config.num_attention_heads = timm_model.blocks[0].attn.num_heads + + # check whether the model has a classification head or not + if timm_model.num_classes != 0: + config.num_labels = timm_model.num_classes repo_id = "huggingface/label-files" - filename = "imagenet-1k-id2label.json" + # .__ceil__() avoids having to import math + filename = f"imagenet-{(config.num_labels / 1000).__ceil__()}k-id2label.json" id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r")) id2label = {int(k): v for k, v in id2label.items()} config.id2label = id2label config.label2id = {v: k for k, v in id2label.items()} - config.patch_size = int(vit_name[-6:-4]) - config.image_size = int(vit_name[-3:]) - # size of the architecture - if "deit" in vit_name: - if vit_name[9:].startswith("tiny"): - config.hidden_size = 192 - config.intermediate_size = 768 - config.num_hidden_layers = 12 - config.num_attention_heads = 3 - elif vit_name[9:].startswith("small"): - config.hidden_size = 384 - config.intermediate_size = 1536 - config.num_hidden_layers = 12 - config.num_attention_heads = 6 - else: - pass else: - if vit_name[4:].startswith("small"): - config.hidden_size = 768 - config.intermediate_size = 2304 - config.num_hidden_layers = 8 - config.num_attention_heads = 8 - elif vit_name[4:].startswith("base"): - pass - elif vit_name[4:].startswith("large"): - config.hidden_size = 1024 - config.intermediate_size = 4096 - config.num_hidden_layers = 24 - config.num_attention_heads = 16 - elif vit_name[4:].startswith("huge"): - config.hidden_size = 1280 - config.intermediate_size = 5120 - config.num_hidden_layers = 32 - config.num_attention_heads = 16 - - # load original model from timm - timm_model = timm.create_model(vit_name, pretrained=True) - timm_model.eval() + print(f"{vit_name} is going to be converted as a feature extractor only. This is not guaranteed to work.") + base_model = True - # load state_dict of original model, remove and rename some keys + # load state_dict of original model state_dict = timm_model.state_dict() + + # remove and rename some keys in the state dict if base_model: remove_classification_head_(state_dict) rename_keys = create_rename_keys(config, base_model) @@ -202,8 +181,9 @@ def convert_vit_checkpoint(vit_name, pytorch_dump_folder_path): read_in_q_k_v(state_dict, config, base_model) # load HuggingFace model - if vit_name[-5:] == "in21k": - model = ViTModel(config).eval() + if base_model: + model = ViTModel(config, add_pooling_layer=False).eval() + # print(model.state_dict().keys()) else: model = ViTForImageClassification(config).eval() model.load_state_dict(state_dict) @@ -219,8 +199,10 @@ def convert_vit_checkpoint(vit_name, pytorch_dump_folder_path): if base_model: timm_pooled_output = timm_model.forward_features(pixel_values) - assert timm_pooled_output.shape == outputs.pooler_output.shape - assert torch.allclose(timm_pooled_output, outputs.pooler_output, atol=1e-3) + print(timm_pooled_output) + print(outputs.last_hidden_state) + assert timm_pooled_output.shape == outputs.last_hidden_state.shape + assert torch.allclose(timm_pooled_output, outputs.last_hidden_state, atol=1e-1) else: timm_logits = timm_model(pixel_values) assert timm_logits.shape == outputs.logits.shape