Skip to content

Commit

Permalink
[draft] Test positional embedding and patch resizing
Browse files Browse the repository at this point in the history
  • Loading branch information
filipradenovic committed Mar 8, 2023
1 parent d93415c commit 5bdea21
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 0 deletions.
86 changes: 86 additions & 0 deletions diht/model_zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@

import hashlib
import logging
import math
import os
import urllib
from pathlib import Path

import torch

from omegaconf import OmegaConf
from timm.models.layers import to_2tuple
from torchvision.transforms import InterpolationMode
from tqdm import tqdm

Expand Down Expand Up @@ -84,6 +86,12 @@ def _download_model_ckpt(checkpoint_url, root):

def _load_model_ckpt(model, checkpoint_path):
checkpoint_state_dict = torch.load(checkpoint_path)
resize_pos_embed(
checkpoint_state_dict, model, interpolation="bilinear", align_corners=False
)
resize_patch_embed(
checkpoint_state_dict, model, interpolation="bilinear", align_corners=False
)
model.load_state_dict(checkpoint_state_dict)
return model

Expand Down Expand Up @@ -153,3 +161,81 @@ def load_model(model_name, is_train=False, download_root=None):
model.eval()

return text_tokenizer, transform, model


def resize_pos_embed(
state_dict, model, interpolation: str = "bicubic", align_corners: bool = False
):
# Rescale the grid of position embeddings when loading from state_dict
old_pos_embed = state_dict.get("visual.positional_embedding", None)
if old_pos_embed is None or not hasattr(model.visual, "grid_size"):
return
grid_size = to_2tuple(model.visual.grid_size)
extra_tokens = (
1 # FIXME detect different token configs (ie no class token, or more)
)
new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
if new_seq_len == old_pos_embed.shape[0]:
return

if extra_tokens:
pos_emb_tok, pos_emb_img = (
old_pos_embed[:extra_tokens],
old_pos_embed[extra_tokens:],
)
else:
pos_emb_tok, pos_emb_img = None, old_pos_embed
old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))

logging.info(
"Resizing position embedding grid-size from %s to %s", old_grid_size, grid_size
)
pos_emb_img = pos_emb_img.reshape(
1, old_grid_size[0], old_grid_size[1], -1
).permute(0, 3, 1, 2)
pos_emb_img = torch.nn.functional.interpolate(
pos_emb_img,
size=grid_size,
mode=interpolation,
align_corners=align_corners,
)
pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(
1, grid_size[0] * grid_size[1], -1
)[0]
if pos_emb_tok is not None:
new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
else:
new_pos_embed = pos_emb_img
state_dict["visual.positional_embedding"] = new_pos_embed


def resize_patch_embed(
state_dict, model, interpolation: str = "bicubic", align_corners: bool = False
):
# interpolate patch embeddings
old_patch_embed = state_dict.get("visual.conv1.weight", None)
if old_patch_embed is None:
return

height, width = old_patch_embed.shape[-2:]
patch_size = model.visual.patch_size[0]
assert (
height == width
), f"Patch embeddings are supposed to be square, got height {height} width {width}"

if height == patch_size:
logging.info("Patch size of pretrained weights matches config.")
# Patch size already matches config, no need to interpolate embs
pass
else:
# Interpolate patch embeddings to match config
logging.info(
f"Resizing patch size from ({height}, {width}) to ({patch_size}, {patch_size})"
)
old_patch_embed = torch.nn.functional.interpolate(
old_patch_embed,
size=(patch_size, patch_size),
mode=interpolation,
align_corners=align_corners,
)
state_dict["visual.conv1.weight"] = old_patch_embed
22 changes: 22 additions & 0 deletions diht/model_zoo_configs/diht_vitb16_448px.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
checkpoint_url: https://dl.fbaipublicfiles.com/diht/diht_vitb16_224px_d2f68609e5d0469c824cd16353375cf7c47b468666fdfcd8baecf06705e4a6f0.ckpt
image_transform:
image_size: 448
mean: [0.48145466, 0.4578275, 0.40821073]
std: [0.26862954, 0.26130258, 0.27577711]
text_transform:
text_tokenizer_name: "clip"
model_cfg:
name: DiHT
params:
embed_dim: 512
vision_cfg:
image_size: 448
layers: 12
width: 768
patch_size: 16
text_cfg:
context_length: 77
vocab_size: 49408
width: 512
heads: 8
layers: 12
22 changes: 22 additions & 0 deletions diht/model_zoo_configs/diht_vitb8_224px.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
checkpoint_url: https://dl.fbaipublicfiles.com/diht/diht_vitb16_224px_d2f68609e5d0469c824cd16353375cf7c47b468666fdfcd8baecf06705e4a6f0.ckpt
image_transform:
image_size: 224
mean: [0.48145466, 0.4578275, 0.40821073]
std: [0.26862954, 0.26130258, 0.27577711]
text_transform:
text_tokenizer_name: "clip"
model_cfg:
name: DiHT
params:
embed_dim: 512
vision_cfg:
image_size: 224
layers: 12
width: 768
patch_size: 8
text_cfg:
context_length: 77
vocab_size: 49408
width: 512
heads: 8
layers: 12

0 comments on commit 5bdea21

Please sign in to comment.