Final weights
aymeric-roucher committed Oct 25, 2024
1 parent 0c8aa0a commit 41a4733
Showing 4 changed files with 34 additions and 57 deletions.
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
@@ -172,6 +172,7 @@
"models.aria": [
"AriaConfig",
"AriaTextConfig",
"AriaProcessor",
],
"models.audio_spectrogram_transformer": [
"ASTConfig",
@@ -5023,6 +5024,7 @@
from .models.aria import (
AriaConfig,
AriaTextConfig,
AriaProcessor,
)
from .models.audio_spectrogram_transformer import (
ASTConfig,
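With AriaProcessor registered in both the lazy import structure and the TYPE_CHECKING imports, the processor becomes reachable from the package root. A minimal import sketch, assuming a transformers build that contains this commit:

# --- usage sketch, not part of the commit ---
# AriaProcessor is now exposed at the top level next to the existing Aria configs.
from transformers import AriaConfig, AriaProcessor, AriaTextConfig

print(AriaProcessor.__module__)  # expected: transformers.models.aria.processing_aria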
67 changes: 27 additions & 40 deletions src/transformers/models/aria/convert_aria_weights_to_hf.py
@@ -26,9 +26,12 @@
AutoImageProcessor,
AutoTokenizer,
LlavaProcessor,
SiglipVisionConfig,
Idefics3VisionConfig,
AriaProcessor,
)
from huggingface_hub import login

login("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")

GitHub Actions / trufflehog warning on line 34: Found verified HuggingFace result 🐷🔑

EPILOG_TXT = """Example:
python transformers/src/transformers/models/aria/convert_aria_weights_to_hf.py --text_model_id lmsys/vicuna-7b-v1.5 --vision_model_id openai/clip-vit-large-patch14-336 --output_hub_path org/aria-v1.5-7b-conv --old_state_dict_id liuhaotian/aria-v1.5-7b
@@ -50,15 +53,7 @@
"""

KEYS_TO_MODIFY_MAPPING = {
"model.vision_tower.": "",
".vision_resampler": "", # all lmms-lab models do avg pooling, so no vision_resampler
"model.mm_projector": "multi_modal_projector",
"model": "model.model",
"vision_model.model": "vision_model",
"lm_head": "language_model.lm_head",
"model.model": "language_model.model",
"multi_modal_projector.0": "multi_modal_projector.linear_1",
"multi_modal_projector.2": "multi_modal_projector.linear_2",
"vision_tower.vision_model": "vision_tower",
}


@@ -72,13 +67,6 @@ def load_original_state_dict(model_id):
for key in f.keys():
original_state_dict[key] = f.get_tensor(key)

# tied weights so lm_head is not saved. Let's clone to load state dict
if "lm_head.weight" not in original_state_dict:
original_state_dict["lm_head.weight"] = original_state_dict["model.embed_tokens.weight"].clone()

if "model.image_newline" in original_state_dict:
# not used in the original implementation because "merge_type=flat"
del original_state_dict["model.image_newline"]
return original_state_dict


@@ -94,33 +82,33 @@ def convert_state_dict_to_hf(state_dict):
key = key.replace(key_to_modify, new_key)

new_state_dict[key] = value
new_state_dict['vision_tower.post_layernorm.weight'] = torch.zeros((1152,))
new_state_dict['vision_tower.post_layernorm.bias'] = torch.zeros((1152,))

return new_state_dict


def convert_aria_llama_to_hf(text_model_id, vision_model_id, output_hub_path, old_state_dict_id):
torch.set_default_dtype(torch.float16)
text_config = AutoConfig.from_pretrained(text_model_id)
text_config = AutoConfig.from_pretrained(text_model_id).text_config

tokenizer = AutoTokenizer.from_pretrained(text_model_id)
tokenizer.add_tokens(AddedToken("<image>", special=True, normalized=False), special_tokens=True)
if "Qwen" not in text_model_id: # qwen already has a pad token
tokenizer.add_special_tokens({"pad_token": "<pad>"})

image_processor = AutoImageProcessor.from_pretrained(vision_model_id)
processor = LlavaProcessor(tokenizer=tokenizer, image_processor=image_processor)

if "siglip" in vision_model_id:
vision_config = SiglipVisionConfig(
hidden_size=1152,
image_size=384,
intermediate_size=4304,
num_attention_heads=16,
num_hidden_layers=26,
patch_size=14,
vision_use_head=False,
).to_dict()
else:
vision_config = None
processor = AriaProcessor.from_pretrained(
text_model_id, tokenizer_path=text_model_id,
)

vision_config = Idefics3VisionConfig(
hidden_size=1152,
image_size=980,
intermediate_size=4304,
num_attention_heads=16,
num_hidden_layers=27,
patch_size=14,
).to_dict()

config = AriaConfig(
text_config=text_config,
@@ -140,14 +128,10 @@ def convert_aria_llama_to_hf(text_model_id, vision_model_id, output_hub_path, old_state_dict_id):
with torch.device("meta"):
model = AriaForConditionalGeneration(config)

if "Qwen" in text_model_id:
state_dict = load_original_state_dict(old_state_dict_id)
else:
state_dict_path = hf_hub_download(old_state_dict_id, "model_state_dict.bin")
state_dict = torch.load(state_dict_path, map_location="cpu")
state_dict = load_original_state_dict(old_state_dict_id)

state_dict = convert_state_dict_to_hf(state_dict)
model.load_state_dict(state_dict, strict=True, assign=True)
model.load_state_dict(state_dict, strict=False, assign=True)

pre_expansion_embeddings = model.language_model.model.embed_tokens.weight.data
mu = torch.mean(pre_expansion_embeddings, dim=0).float()
@@ -169,7 +153,6 @@ def convert_aria_llama_to_hf(text_model_id, vision_model_id, output_hub_path, old_state_dict_id):
tuple((dist.sample() for _ in range(model.language_model.lm_head.weight.data[vocab_size:].shape[0]))),
dim=0,
)

model.push_to_hub(output_hub_path)
processor.push_to_hub(output_hub_path)

@@ -181,18 +164,22 @@ def main():
)
parser.add_argument(
"--text_model_id",
default="rhymes-ai/Aria",
help="Hub location of the text model",
)
parser.add_argument(
"--vision_model_id",
default="rhymes-ai/Aria",
help="Hub location of the vision model",
)
parser.add_argument(
"--output_hub_path",
default="m-ric/Aria_hf",
help="Location on the hub of the converted model",
)
parser.add_argument(
"--old_state_dict_id",
default="rhymes-ai/Aria",
help="Location on the hub of the raw state dict of the original model. The filename needs to be `model_state_dict.bin`",
)
args = parser.parse_args()
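For reference, a self-contained sketch of the reduced key-remapping step in convert_aria_weights_to_hf.py after this commit: only the vision_tower.vision_model prefix rewrite remains, and two zero-initialized post_layernorm tensors of width 1152 are appended. The toy state dict below is illustrative only.

# --- illustration only, not part of the commit ---
import torch

KEYS_TO_MODIFY_MAPPING = {
    "vision_tower.vision_model": "vision_tower",
}

def convert_state_dict_to_hf(state_dict):
    # Rewrite matching key prefixes, then append the post_layernorm tensors that
    # the converted checkpoint expects but the source state dict does not provide.
    new_state_dict = {}
    for key, value in state_dict.items():
        for key_to_modify, new_key in KEYS_TO_MODIFY_MAPPING.items():
            if key_to_modify in key:
                key = key.replace(key_to_modify, new_key)
        new_state_dict[key] = value
    new_state_dict["vision_tower.post_layernorm.weight"] = torch.zeros((1152,))
    new_state_dict["vision_tower.post_layernorm.bias"] = torch.zeros((1152,))
    return new_state_dict

toy = {"vision_tower.vision_model.embeddings.patch_embedding.weight": torch.zeros(3)}
print(list(convert_state_dict_to_hf(toy)))
# ['vision_tower.embeddings.patch_embedding.weight',
#  'vision_tower.post_layernorm.weight',
#  'vision_tower.post_layernorm.bias']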
20 changes: 3 additions & 17 deletions src/transformers/models/aria/modeling_aria.py
@@ -2481,18 +2481,6 @@ class AriaCausalLMOutputWithPast(ModelOutput):
image_hidden_states: Optional[torch.FloatTensor] = None


class Idefics3Wrapper(AriaPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.vision_model = AutoModel.from_config(
config.vision_config, attn_implementation=config._attn_implementation
)
self.post_init()

def forward(self, pixel_values, **kwargs):
return self.vision_model(pixel_values, **kwargs)


class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
"""
Aria model for conditional generation tasks.
@@ -2509,12 +2497,10 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
def __init__(self, config: AriaConfig):
super().__init__(config)

self.vision_tower = Idefics3Wrapper(
config
self.vision_tower = AutoModel.from_config(
config.vision_config, attn_implementation=config.vision_config._attn_implementation
)
print("PREFIX", self.vision_tower.base_model_prefix)
print(dir(self.vision_tower))
# self.vision_tower.base_model_prefix = "vision_tower.vision_model"

self.multi_modal_projector = AriaProjector(
patch_to_query_dict=config.projector_patch_to_query_dict,
embed_dim=config.vision_config.hidden_size,
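With the Idefics3Wrapper indirection removed, the vision tower is now built directly from the vision config. A standalone sketch of that construction, assuming the Idefics3 vision model is registered with AutoModel; the hyperparameters are the ones hard-coded in the conversion script above.

# --- illustration only, not part of the commit ---
# Mirrors how AriaForConditionalGeneration.__init__ now instantiates the tower.
from transformers import AutoModel, Idefics3VisionConfig

vision_config = Idefics3VisionConfig(
    hidden_size=1152,
    image_size=980,
    intermediate_size=4304,
    num_attention_heads=16,
    num_hidden_layers=27,
    patch_size=14,
)
vision_tower = AutoModel.from_config(vision_config, attn_implementation="eager")
print(type(vision_tower).__name__)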
2 changes: 2 additions & 0 deletions src/transformers/models/aria/processing_aria.py
@@ -256,6 +256,8 @@ def __call__(
self,
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]],
images: ImageInput = None,
audio= None,
videos = None,
padding: Union[bool, str, PaddingStrategy] = False,
truncation: Union[bool, str, TruncationStrategy] = None,
max_length: Optional[int] = None,
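The added audio and videos parameters are accepted but unused, keeping the call signature aligned with the generic processor interface. A usage sketch follows; m-ric/Aria_hf is simply the conversion script's default output_hub_path and may not be a published checkpoint, and the image placeholder token shown is an assumption that depends on the checkpoint's tokenizer.

# --- usage sketch, not part of the commit ---
from PIL import Image
from transformers import AriaProcessor

processor = AriaProcessor.from_pretrained("m-ric/Aria_hf")  # hypothetical repo, see note above
inputs = processor(
    text="<image>\nDescribe this image.",  # placeholder token is checkpoint-specific
    images=Image.new("RGB", (980, 980)),
    padding=True,
    return_tensors="pt",
)
print(sorted(inputs.keys()))  # e.g. attention_mask, input_ids, pixel_values, ...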
