timm to pytorch conversion for vit model fix
staghado committed Oct 18, 2023
1 parent de55ead commit 1ed2469
Showing 1 changed file with 35 additions and 53 deletions.
src/transformers/models/vit/convert_vit_timm_to_pytorch.py
@@ -60,13 +60,13 @@ def create_rename_keys(config, base_model=False):
     )
 
     if base_model:
-        # layernorm + pooler
+        # layernorm
         rename_keys.extend(
             [
                 ("norm.weight", "layernorm.weight"),
                 ("norm.bias", "layernorm.bias"),
-                ("pre_logits.fc.weight", "pooler.dense.weight"),
-                ("pre_logits.fc.bias", "pooler.dense.bias"),
+                # ("pre_logits.fc.weight", "pooler.dense.weight"),
+                # ("pre_logits.fc.bias", "pooler.dense.bias"),
             ]
         )
 
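With the pooler entries commented out, create_rename_keys no longer maps timm's pre_logits head, so base models are converted without pooler weights. For context, a minimal sketch of how these (timm_name, hf_name) pairs are consumed; a rename_key helper in this style appears elsewhere in the script, but treat the exact signature as an assumption, and assume config and state_dict have already been built:

def rename_key(dct, old, new):
    val = dct.pop(old)  # drop the timm-style key
    dct[new] = val      # re-add the tensor under the HF-style key

# Apply every rename pair produced for a base (headless) model.
for src, dest in create_rename_keys(config, base_model=True):
    rename_key(state_dict, src, dest)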

@@ -140,60 +140,39 @@ def convert_vit_checkpoint(vit_name, pytorch_dump_folder_path):
     # define default ViT configuration
     config = ViTConfig()
     base_model = False
-    # dataset (ImageNet-21k only or also fine-tuned on ImageNet 2012), patch_size and image_size
-    if vit_name[-5:] == "in21k":
-        base_model = True
-        config.patch_size = int(vit_name[-12:-10])
-        config.image_size = int(vit_name[-9:-6])
-    else:
-        config.num_labels = 1000
+
+    # load original model from timm
+    timm_model = timm.create_model(vit_name, pretrained=True)
+    timm_model.eval()
+
+    # get patch size and image size from the patch embedding submodule
+    config.patch_size = timm_model.patch_embed.patch_size[0]
+    config.image_size = timm_model.patch_embed.img_size[0]
+
+    # retrieve architecture-specific parameters from the timm model
+    config.hidden_size = timm_model.embed_dim
+    config.intermediate_size = timm_model.blocks[0].mlp.fc1.out_features
+    config.num_hidden_layers = len(timm_model.blocks)
+    config.num_attention_heads = timm_model.blocks[0].attn.num_heads
+
+    # check whether the model has a classification head or not
+    if timm_model.num_classes != 0:
+        config.num_labels = timm_model.num_classes
         repo_id = "huggingface/label-files"
-        filename = "imagenet-1k-id2label.json"
+        # .__ceil__() avoids having to import math
+        filename = f"imagenet-{(config.num_labels / 1000).__ceil__()}k-id2label.json"
         id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
         id2label = {int(k): v for k, v in id2label.items()}
        config.id2label = id2label
         config.label2id = {v: k for k, v in id2label.items()}
-        config.patch_size = int(vit_name[-6:-4])
-        config.image_size = int(vit_name[-3:])
-    # size of the architecture
-    if "deit" in vit_name:
-        if vit_name[9:].startswith("tiny"):
-            config.hidden_size = 192
-            config.intermediate_size = 768
-            config.num_hidden_layers = 12
-            config.num_attention_heads = 3
-        elif vit_name[9:].startswith("small"):
-            config.hidden_size = 384
-            config.intermediate_size = 1536
-            config.num_hidden_layers = 12
-            config.num_attention_heads = 6
-        else:
-            pass
-    else:
-        if vit_name[4:].startswith("small"):
-            config.hidden_size = 768
-            config.intermediate_size = 2304
-            config.num_hidden_layers = 8
-            config.num_attention_heads = 8
-        elif vit_name[4:].startswith("base"):
-            pass
-        elif vit_name[4:].startswith("large"):
-            config.hidden_size = 1024
-            config.intermediate_size = 4096
-            config.num_hidden_layers = 24
-            config.num_attention_heads = 16
-        elif vit_name[4:].startswith("huge"):
-            config.hidden_size = 1280
-            config.intermediate_size = 5120
-            config.num_hidden_layers = 32
-            config.num_attention_heads = 16
-
-    # load original model from timm
-    timm_model = timm.create_model(vit_name, pretrained=True)
-    timm_model.eval()
+    else:
+        print(f"{vit_name} is going to be converted as a feature extractor only. This is not guaranteed to work.")
+        base_model = True
 
-    # load state_dict of original model, remove and rename some keys
+    # load state_dict of original model
    state_dict = timm_model.state_dict()
+
+    # remove and rename some keys in the state dict
     if base_model:
         remove_classification_head_(state_dict)
     rename_keys = create_rename_keys(config, base_model)
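The substance of the fix is above: instead of parsing patch size, image size, and architecture hyperparameters out of the checkpoint name (which broke for names that do not follow the old naming pattern), the script now reads them directly off the instantiated timm module. A standalone sketch of that inspection, with the model name chosen purely for illustration:

import timm

# Any timm ViT checkpoint exposes the same attributes; this name is an example.
timm_model = timm.create_model("vit_base_patch16_224", pretrained=False)

print(timm_model.patch_embed.patch_size[0])       # 16   -> config.patch_size
print(timm_model.patch_embed.img_size[0])         # 224  -> config.image_size
print(timm_model.embed_dim)                       # 768  -> config.hidden_size
print(timm_model.blocks[0].mlp.fc1.out_features)  # 3072 -> config.intermediate_size
print(len(timm_model.blocks))                     # 12   -> config.num_hidden_layers
print(timm_model.blocks[0].attn.num_heads)        # 12   -> config.num_attention_heads
print(timm_model.num_classes)                     # 1000 -> classification head present

# The f-string picks the matching label file: 1000 labels -> "imagenet-1k-id2label.json",
# while a 21k checkpoint with 21843 labels gives ceil(21.843) = 22 -> "imagenet-22k-id2label.json".
print((timm_model.num_classes / 1000).__ceil__())  # 1

A checkpoint shipped without a classifier reports num_classes == 0, which the new else branch turns into a feature-extractor-only conversion.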
@@ -202,8 +181,9 @@ def convert_vit_checkpoint(vit_name, pytorch_dump_folder_path):
     read_in_q_k_v(state_dict, config, base_model)
 
     # load HuggingFace model
-    if vit_name[-5:] == "in21k":
-        model = ViTModel(config).eval()
+    if base_model:
+        model = ViTModel(config, add_pooling_layer=False).eval()
+        # print(model.state_dict().keys())
     else:
         model = ViTForImageClassification(config).eval()
     model.load_state_dict(state_dict)
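add_pooling_layer=False builds ViTModel without the pooler, which keeps the strict load_state_dict call passing: the rename table no longer produces pooler.dense.* entries, so a model built with a pooler would fail with missing keys. A small sketch of the difference, using ViTModel's public keyword:

from transformers import ViTConfig, ViTModel

config = ViTConfig()

with_pooler = ViTModel(config)                              # has pooler.dense.{weight,bias}
without_pooler = ViTModel(config, add_pooling_layer=False)  # encoder only

print(any("pooler" in k for k in with_pooler.state_dict()))     # True
print(any("pooler" in k for k in without_pooler.state_dict()))  # False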
@@ -219,8 +199,10 @@ def convert_vit_checkpoint(vit_name, pytorch_dump_folder_path):

     if base_model:
         timm_pooled_output = timm_model.forward_features(pixel_values)
-        assert timm_pooled_output.shape == outputs.pooler_output.shape
-        assert torch.allclose(timm_pooled_output, outputs.pooler_output, atol=1e-3)
+        print(timm_pooled_output)
+        print(outputs.last_hidden_state)
+        assert timm_pooled_output.shape == outputs.last_hidden_state.shape
+        assert torch.allclose(timm_pooled_output, outputs.last_hidden_state, atol=1e-1)
     else:
         timm_logits = timm_model(pixel_values)
         assert timm_logits.shape == outputs.logits.shape
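Since the converted base model has no pooler, the check now compares timm's forward_features output against last_hidden_state rather than the removed pooler_output, with the tolerance relaxed from 1e-3 to 1e-1. In recent timm versions forward_features returns the full normalized token sequence of shape (batch, 1 + num_patches, hidden_size), which is what last_hidden_state holds. A self-contained sketch of the timm side of that comparison; the model name and COCO image URL are illustrative assumptions:

import requests
import timm
import torch
from PIL import Image
from transformers import ViTImageProcessor

timm_model = timm.create_model("vit_base_patch16_224", pretrained=True).eval()

# Prepare pixel_values the way the conversion script's verification step does.
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
pixel_values = ViTImageProcessor()(images=image, return_tensors="pt").pixel_values

with torch.no_grad():
    timm_features = timm_model.forward_features(pixel_values)

print(timm_features.shape)  # torch.Size([1, 197, 768]), same shape as last_hidden_state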
