Support MobileCLIP S1 & S2 models via timm integration #886

Merged · 2 commits · Jun 7, 2024
56 changes: 55 additions & 1 deletion src/open_clip/big_vision.py → src/open_clip/convert.py
@@ -1,7 +1,11 @@
""" Conversion functions for 3rd part state-dicts and non-torch native checkpoint formats.
"""
from typing import Union

import torch
import numpy as np

from .model import CustomTextCLIP
from .model import CLIP, CustomTextCLIP
from .transformer import TextTransformer, Transformer


@@ -134,3 +138,53 @@ def _convert_openclip_txt(module: TextTransformer, prefix):
    model.logit_scale.copy_(_n2p(w['params/t'])[0])


@torch.no_grad()
def convert_mobile_clip_state_dict(model: CustomTextCLIP, state_dict, fastvit: bool = True):

    def _convert_timm_img(state_dict):
        if fastvit:
            from timm.models.fastvit import checkpoint_filter_fn
        else:
            from timm.models.vision_transformer_hybrid import checkpoint_filter_fn
        timm_state_dict = checkpoint_filter_fn(state_dict, model.visual.trunk)
        timm_state_dict = {'visual.trunk.' + k: v for k, v in timm_state_dict.items()}
        return timm_state_dict

    def _convert_openclip_txt(state_dict, prefix='text_encoder.'):
        text_dict = {}
        for k, v in state_dict.items():
            if not k.startswith(prefix):
                continue
            k = k.replace(prefix, '')
            k = k.replace('projection_layer', 'text_projection')
            k = k.replace('embedding_layer', 'token_embedding')
            if k.startswith('positional_embedding.pos_embed.pos_embed'):
                k = k.replace('positional_embedding.pos_embed.pos_embed', 'positional_embedding')
                v = v.squeeze()
            k = k.replace('final_layer_norm', 'ln_final')
            k = k.replace('pre_norm_mha.0', 'ln_1')
            k = k.replace('pre_norm_mha.1', 'attn')
            k = k.replace('pre_norm_ffn.0', 'ln_2')
            k = k.replace('pre_norm_ffn.1', 'mlp.c_fc')
            k = k.replace('pre_norm_ffn.4', 'mlp.c_proj')
            k = k.replace('qkv_proj.weight', 'in_proj_weight')
            k = k.replace('qkv_proj.bias', 'in_proj_bias')
            k = k.replace('transformer.', 'transformer.resblocks.')
            text_dict['text.' + k] = v
        return text_dict

    image_dict = _convert_timm_img(state_dict)
    text_dict = _convert_openclip_txt(state_dict)
    out_dict = {**image_dict, **text_dict}
    out_dict['logit_scale'] = state_dict['logit_scale']
    return out_dict


def convert_state_dict(model: Union[CustomTextCLIP, CLIP], state_dict):
    if 'image_encoder.model.patch_embed.0.rbr_conv.0.conv.weight' in state_dict:
        # Apple MobileCLIP s1 & s2 state_dicts (FastViT image tower; s0 not currently supported)
        state_dict = convert_mobile_clip_state_dict(model, state_dict)
    if 'image_encoder.model.patch_emb.0.block.conv.weight' in state_dict:
        # Apple MobileCLIP B state_dict (non-FastViT image tower)
        state_dict = convert_mobile_clip_state_dict(model, state_dict, fastvit=False)
    return state_dict
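
The text-tower conversion in `convert_mobile_clip_state_dict()` above is pure key renaming: strip the `text_encoder.` prefix, translate MobileCLIP's module names into open_clip's `ln_1` / `attn` / `mlp` naming, and re-home everything under `text.`. A minimal standalone sketch of that mapping (the input key below is illustrative, not read from a real checkpoint, and the positional-embedding squeeze step is omitted):

```python
# Sketch of the key remapping done by _convert_openclip_txt above; the example key is hypothetical.
REPLACEMENTS = [
    ('projection_layer', 'text_projection'),
    ('embedding_layer', 'token_embedding'),
    ('final_layer_norm', 'ln_final'),
    ('pre_norm_mha.0', 'ln_1'),
    ('pre_norm_mha.1', 'attn'),
    ('pre_norm_ffn.0', 'ln_2'),
    ('pre_norm_ffn.1', 'mlp.c_fc'),
    ('pre_norm_ffn.4', 'mlp.c_proj'),
    ('qkv_proj.weight', 'in_proj_weight'),
    ('qkv_proj.bias', 'in_proj_bias'),
    ('transformer.', 'transformer.resblocks.'),
]


def remap_text_key(key: str, prefix: str = 'text_encoder.') -> str:
    key = key.replace(prefix, '')       # drop the MobileCLIP text-tower prefix
    for old, new in REPLACEMENTS:       # applied in the same order as the function above
        key = key.replace(old, new)
    return 'text.' + key                # open_clip keeps the text tower under `text.`


print(remap_text_key('text_encoder.transformer.0.pre_norm_mha.1.qkv_proj.weight'))
# -> text.transformer.resblocks.0.attn.in_proj_weight
```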
21 changes: 18 additions & 3 deletions src/open_clip/factory.py
@@ -10,6 +10,7 @@
import torch

from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
from .convert import convert_state_dict
from .model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\
    resize_pos_embed, get_cast_dtype, resize_text_pos_embed, set_model_preprocess_cfg
from .coca_model import CoCa
@@ -139,25 +140,39 @@ def load_state_dict(checkpoint_path: str, map_location='cpu'):
    return state_dict


def load_checkpoint(model, checkpoint_path, strict=True):
def load_checkpoint(
        model: Union[CLIP, CustomTextCLIP],
        checkpoint_path: str,
        strict: bool = True,
):
    if Path(checkpoint_path).suffix in ('.npz', '.npy'):
        from .big_vision import load_big_vision_weights
        # Separate path loading numpy big_vision (SigLIP) weights
        from open_clip.convert import load_big_vision_weights
        load_big_vision_weights(model, checkpoint_path)
        return {}

    state_dict = load_state_dict(checkpoint_path)
    # detect old format and make compatible with new format

    # Detect & convert 3rd party state_dicts -> open_clip
    state_dict = convert_state_dict(model, state_dict)

    # Detect old format and make compatible with new format
    if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):
        state_dict = convert_to_custom_text_state_dict(state_dict)

    # If loading a non-SigLIP model for SigLIP training. See https://github.com/mlfoundations/open_clip/issues/712
    if 'logit_bias' not in state_dict and model.logit_bias is not None:
        state_dict["logit_bias"] = torch.zeros_like(state_dict["logit_scale"])

    # Certain text transformers no longer expect position_ids after transformers==4.31
    position_id_key = 'text.transformer.embeddings.position_ids'
    if position_id_key in state_dict and not hasattr(model, position_id_key):
        del state_dict[position_id_key]

    resize_pos_embed(state_dict, model)
    resize_text_pos_embed(state_dict, model)

    # Finally, load the massaged state_dict into model
    incompatible_keys = model.load_state_dict(state_dict, strict=strict)
    return incompatible_keys

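With `convert_state_dict()` wired into `load_checkpoint()`, a locally downloaded Apple MobileCLIP checkpoint can go through the normal factory path and have its keys remapped on the fly. A hedged usage sketch; the checkpoint filename and its internal layout are assumptions, only the config name `MobileCLIP-S1` comes from this PR:

```python
# Usage sketch (assumption: a local file mobileclip_s1.pt holding an Apple MobileCLIP-S1 state_dict).
import open_clip
from open_clip.factory import load_checkpoint

model, _, preprocess = open_clip.create_model_and_transforms('MobileCLIP-S1')
tokenizer = open_clip.get_tokenizer('MobileCLIP-S1')

# load_checkpoint() calls convert_state_dict(), which detects the Apple key layout and
# remaps it before model.load_state_dict().
load_checkpoint(model, 'mobileclip_s1.pt')
```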
2 changes: 0 additions & 2 deletions src/open_clip/model.py
@@ -272,9 +272,7 @@ def encode_text(self, text, normalize: bool = False):
        x = self.token_embedding(text).to(cast_dtype)  # [batch_size, n_ctx, d_model]

        x = x + self.positional_embedding.to(cast_dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x, attn_mask=self.attn_mask)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x)  # [batch_size, n_ctx, transformer.width]
        x, _ = text_global_pool(x, text, self.text_pool_type)
        if self.text_projection is not None:
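The two deleted `permute` calls are the NLD ↔ LND reshuffles around the text Transformer; dropping them implies the Transformer now consumes batch-first input directly (that change is outside the excerpt shown here), so `encode_text` keeps `[batch_size, n_ctx, d_model]` activations throughout. A small shape sketch under that assumption:

```python
# Illustrative shapes only; assumes the Transformer accepts batch-first [N, L, D] input.
import torch

N, L, D = 2, 77, 512                     # batch size, context length, width
x = torch.randn(N, L, D)                 # after token_embedding + positional_embedding
# previously: x = x.permute(1, 0, 2)     # NLD -> LND before self.transformer(...)
# previously: x = x.permute(1, 0, 2)     # LND -> NLD afterwards
assert x.shape == (N, L, D)              # ln_final and pooling now see batch-first tensors
```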
21 changes: 21 additions & 0 deletions src/open_clip/model_configs/MobileCLIP-B.json
@@ -0,0 +1,21 @@
{
    "embed_dim": 512,
    "vision_cfg": {
        "timm_model_name": "vit_base_mci_224",
        "timm_model_pretrained": false,
        "timm_pool": "token",
        "timm_proj": null,
        "timm_drop": 0.0,
        "timm_drop_path": 0.0,
        "image_size": 224
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 512,
        "heads": 8,
        "layers": 12,
        "no_causal_mask": false
    },
    "custom_text": true
}
21 changes: 21 additions & 0 deletions src/open_clip/model_configs/MobileCLIP-S1.json
@@ -0,0 +1,21 @@
{
    "embed_dim": 512,
    "vision_cfg": {
        "timm_model_name": "fastvit_mci1",
        "timm_model_pretrained": false,
        "timm_pool": "avg",
        "timm_proj": null,
        "timm_drop": 0.0,
        "timm_drop_path": 0.0,
        "image_size": 256
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 512,
        "heads": 8,
        "layers": 12,
        "no_causal_mask": true
    },
    "custom_text": true
}
21 changes: 21 additions & 0 deletions src/open_clip/model_configs/MobileCLIP-S2.json
@@ -0,0 +1,21 @@
{
    "embed_dim": 512,
    "vision_cfg": {
        "timm_model_name": "fastvit_mci2",
        "timm_model_pretrained": false,
        "timm_pool": "avg",
        "timm_proj": null,
        "timm_drop": 0.0,
        "timm_drop_path": 0.0,
        "image_size": 256
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 512,
        "heads": 8,
        "layers": 12,
        "no_causal_mask": true
    },
    "custom_text": true
}
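
The MobileCLIP configs above (and the ViTamin configs that follow) hand the image tower to timm via `timm_model_name`, so the installed timm needs to ship those backbones (`fastvit_mci1`, `fastvit_mci2`, `vit_base_mci_224`, and the `vitamin_*` family). A quick availability check, as a sketch; the listed names reflect recent timm releases and may vary by version:

```python
# Sketch: confirm the timm backbones referenced by these configs exist in the installed timm.
import timm

needed = ['fastvit_mci1', 'fastvit_mci2', 'vit_base_mci_224']
available = set(timm.list_models('*mci*') + timm.list_models('vitamin_*'))
missing = [name for name in needed if name not in available]
print('missing timm backbones:', missing or 'none')
```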
20 changes: 20 additions & 0 deletions src/open_clip/model_configs/ViTamin-B-LTT.json
@@ -0,0 +1,20 @@
{
    "embed_dim": 768,
    "vision_cfg": {
        "timm_model_name": "vitamin_base_224",
        "timm_model_pretrained": false,
        "timm_pool": "",
        "timm_proj": "linear",
        "timm_drop": 0.0,
        "timm_drop_path": 0.1,
        "image_size": 224
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 768,
        "heads": 12,
        "layers": 12
    },
    "custom_text": true
}
20 changes: 20 additions & 0 deletions src/open_clip/model_configs/ViTamin-B.json
@@ -0,0 +1,20 @@
{
    "embed_dim": 512,
    "vision_cfg": {
        "timm_model_name": "vitamin_base_224",
        "timm_model_pretrained": false,
        "timm_pool": "",
        "timm_proj": "linear",
        "timm_drop": 0.0,
        "timm_drop_path": 0.1,
        "image_size": 224
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 512,
        "heads": 8,
        "layers": 12
    },
    "custom_text": true
}
20 changes: 20 additions & 0 deletions src/open_clip/model_configs/ViTamin-L-256.json
@@ -0,0 +1,20 @@
{
    "embed_dim": 768,
    "vision_cfg": {
        "timm_model_name": "vitamin_large_256",
        "timm_model_pretrained": false,
        "timm_pool": "",
        "timm_proj": "linear",
        "timm_drop": 0.0,
        "timm_drop_path": 0.1,
        "image_size": 256
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 768,
        "heads": 12,
        "layers": 12
    },
    "custom_text": true
}
20 changes: 20 additions & 0 deletions src/open_clip/model_configs/ViTamin-L-336.json
@@ -0,0 +1,20 @@
{
    "embed_dim": 768,
    "vision_cfg": {
        "timm_model_name": "vitamin_large_336",
        "timm_model_pretrained": false,
        "timm_pool": "",
        "timm_proj": "linear",
        "timm_drop": 0.0,
        "timm_drop_path": 0.1,
        "image_size": 336
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 768,
        "heads": 12,
        "layers": 12
    },
    "custom_text": true
}
20 changes: 20 additions & 0 deletions src/open_clip/model_configs/ViTamin-L.json
@@ -0,0 +1,20 @@
{
    "embed_dim": 768,
    "vision_cfg": {
        "timm_model_name": "vitamin_large_224",
        "timm_model_pretrained": false,
        "timm_pool": "",
        "timm_proj": "linear",
        "timm_drop": 0.0,
        "timm_drop_path": 0.1,
        "image_size": 224
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 768,
        "heads": 12,
        "layers": 12
    },
    "custom_text": true
}
20 changes: 20 additions & 0 deletions src/open_clip/model_configs/ViTamin-L2-256.json
@@ -0,0 +1,20 @@
{
    "embed_dim": 1024,
    "vision_cfg": {
        "timm_model_name": "vitamin_large2_256",
        "timm_model_pretrained": false,
        "timm_pool": "",
        "timm_proj": "linear",
        "timm_drop": 0.0,
        "timm_drop_path": 0.1,
        "image_size": 256
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 1024,
        "heads": 16,
        "layers": 24
    },
    "custom_text": true
}
20 changes: 20 additions & 0 deletions src/open_clip/model_configs/ViTamin-L2-336.json
@@ -0,0 +1,20 @@
{
    "embed_dim": 1024,
    "vision_cfg": {
        "timm_model_name": "vitamin_large2_336",
        "timm_model_pretrained": false,
        "timm_pool": "",
        "timm_proj": "linear",
        "timm_drop": 0.0,
        "timm_drop_path": 0.1,
        "image_size": 336
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 1024,
        "heads": 16,
        "layers": 24
    },
    "custom_text": true
}
20 changes: 20 additions & 0 deletions src/open_clip/model_configs/ViTamin-L2.json
@@ -0,0 +1,20 @@
{
    "embed_dim": 1024,
    "vision_cfg": {
        "timm_model_name": "vitamin_large2_224",
        "timm_model_pretrained": false,
        "timm_pool": "",
        "timm_proj": "linear",
        "timm_drop": 0.0,
        "timm_drop_path": 0.1,
        "image_size": 224
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 1024,
        "heads": 16,
        "layers": 24
    },
    "custom_text": true
}
20 changes: 20 additions & 0 deletions src/open_clip/model_configs/ViTamin-S-LTT.json
@@ -0,0 +1,20 @@
{
    "embed_dim": 768,
    "vision_cfg": {
        "timm_model_name": "vitamin_small_224",
        "timm_model_pretrained": false,
        "timm_pool": "",
        "timm_proj": "linear",
        "timm_drop": 0.0,
        "timm_drop_path": 0.1,
        "image_size": 224
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 768,
        "heads": 12,
        "layers": 12
    },
    "custom_text": true
}