Commit

make fixup
Rocketknight1 committed Dec 20, 2024
1 parent f6d61d3 commit 8228140
Showing 6 changed files with 39 additions and 29 deletions.
12 changes: 9 additions & 3 deletions src/transformers/models/llava/modeling_llava.py
@@ -86,9 +86,13 @@ class LlavaCausalLMOutputWithPast(ModelOutput):
 class LlavaMultiModalProjector(nn.Module):
     def __init__(self, config: LlavaConfig):
         super().__init__()
-        self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias)
+        self.linear_1 = nn.Linear(
+            config.vision_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias
+        )
         self.act = ACT2FN[config.projector_hidden_act]
-        self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias)
+        self.linear_2 = nn.Linear(
+            config.text_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias
+        )
 
     def forward(self, image_features):
         hidden_states = self.linear_1(image_features)
@@ -287,7 +291,9 @@ def get_image_features(
             image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
         """
         self.vision_tower = self.vision_tower.to("cuda")
-        image_outputs = self.vision_tower([[im.to("cuda") for im in sample] for sample in pixel_values], output_hidden_states=True)
+        image_outputs = self.vision_tower(
+            [[im.to("cuda") for im in sample] for sample in pixel_values], output_hidden_states=True
+        )
         self.vision_tower = self.vision_tower.to("cpu")
         # this is not memory efficient at all (output_hidden_states=True) will save all the hidden stated.
         selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
3 changes: 0 additions & 3 deletions src/transformers/models/mistral/modeling_mistral.py
@@ -242,8 +242,6 @@ def forward(
 
         hidden_states = self.input_layernorm(hidden_states)
 
-
-
         # Self Attention
         hidden_states, self_attn_weights = self.self_attn(
             hidden_states=hidden_states,
@@ -578,7 +576,6 @@ def forward(
                 **flash_attn_kwargs,
             )
 
-
             hidden_states = layer_outputs[0]
 
             if output_attentions:
44 changes: 26 additions & 18 deletions src/transformers/models/pixtral/convert_pixtral_weights_to_hf.py
@@ -13,8 +13,9 @@
 # limitations under the License.
 import argparse
 import json
-import regex as re
 import os
+
+import regex as re
 import torch
 from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
 from safetensors.torch import load_file as safe_load_file
@@ -28,7 +29,6 @@
     PixtralImageProcessor,
     PixtralProcessor,
     PixtralVisionConfig,
-    PreTrainedTokenizerFast,
 )
 from transformers.convert_slow_tokenizer import bytes_to_unicode
 
@@ -159,6 +159,7 @@ def converted(self) -> Tokenizer:
 
 def convert_mistral_tokenizer(model_file):
     from transformers import LlamaTokenizer
+
     mistral_tokenizer = MistralTokenizer.from_file(model_file)
     vocab = mistral_tokenizer.instruct_tokenizer.tokenizer.vocab()
     control_token_ids = mistral_tokenizer.instruct_tokenizer.tokenizer._control_tokens
@@ -204,25 +205,35 @@ def convert_dictionnary(original_state_dict, vision_config, text_config):
         new_dict[new_key] = value
     return new_dict
 
+
 MISTRAL_CONFIG_MAPPING = {
-    "dim":"hidden_size",
-    "hidden_dim":"intermediate_size",
-    "n_kv_heads":"num_key_value_heads",
-    "n_heads":"num_attention_heads",
-    "n_layers":"num_hidden_layers"
+    "dim": "hidden_size",
+    "hidden_dim": "intermediate_size",
+    "n_kv_heads": "num_key_value_heads",
+    "n_heads": "num_attention_heads",
+    "n_layers": "num_hidden_layers",
 }
 
+
 def convert_mistral_model(input_dir, output_dir):
     vision_config = {}
     if os.path.isfile(f"{input_dir}/params.json"):
-        with open(f'{input_dir}/params.json') as f: param_json = json.load(f)
+        with open(f"{input_dir}/params.json") as f:
+            param_json = json.load(f)
         vision_config = param_json.pop("vision_encoder")
-        for k,v in MISTRAL_CONFIG_MAPPING.items():
+        for k, v in MISTRAL_CONFIG_MAPPING.items():
             value = param_json.pop(k)
             param_json[v] = value
         if "hidden_act" not in vision_config:
             vision_config["hidden_act"] = "silu"
-        text_config = MistralConfig(**param_json, hidden_act="silu", sliding_window=None,tie_word_embeddings=False, is_composition=True, rms_norm_eps=1e-5)
+        text_config = MistralConfig(
+            **param_json,
+            hidden_act="silu",
+            sliding_window=None,
+            tie_word_embeddings=False,
+            is_composition=True,
+            rms_norm_eps=1e-5,
+        )
     else:
         text_config = MistralConfig(
             attention_dropout=0.0,
@@ -258,8 +269,8 @@ def convert_mistral_model(input_dir, output_dir):
     config.architectures = ["LlavaForConditionalGeneration"]
     config.save_pretrained(output_dir)
     full_original_state_dict = {}
-    safetensors_files = sorted([file for file in os.listdir(input_dir) if file.endswith(".safetensors")])
-    if len(safetensors_files)==1:
+    safetensors_files = sorted([file for file in os.listdir(input_dir) if file.endswith(".safetensors")])
+    if len(safetensors_files) == 1:
         full_original_state_dict = safe_load_file(f"{input_dir}/consolidated.safetensors")
     else:
         for file in safetensors_files:
@@ -271,7 +282,8 @@ def convert_mistral_model(input_dir, output_dir):
     model = LlavaForConditionalGeneration(config)
     model.load_state_dict(new_dict, strict=True, assign=True)
     model.save_pretrained(output_dir)
-
+
+
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument(
@@ -282,10 +294,7 @@ def main():
         "--output_dir",
         help="Location to write HF model and tokenizer",
     )
-    parser.add_argument(
-        "--tokenizer_file",
-        help="Location of the specific tokenizer model file to use."
-    )
+    parser.add_argument("--tokenizer_file", help="Location of the specific tokenizer model file to use.")
     parser.add_argument(
         "--chat_template_file",
         help="Optional file containing a raw chat template. Will be set as the processor's chat template.",
@@ -302,6 +311,5 @@
     processor.save_pretrained(args.output_dir)
 
 
-
 if __name__ == "__main__":
     main()
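
The converter changes above are formatting-only, so the script's behavior is unchanged. For orientation, here is a minimal, hypothetical usage sketch that calls the reformatted entry point directly from Python instead of via the CLI; the import path assumes the conversion script ships with the installed transformers package (with mistral_common available), and both directory paths are placeholders.

# Hypothetical usage sketch -- not part of this commit.
from transformers.models.pixtral.convert_pixtral_weights_to_hf import convert_mistral_model

# Placeholder paths: input_dir should contain params.json and the *.safetensors shards
# of the original Mistral-format checkpoint; output_dir receives the converted
# Hugging Face config and model weights (the CLI's main() additionally saves a processor).
convert_mistral_model(
    input_dir="mistral_checkpoints/pixtral-12b",
    output_dir="converted/pixtral-12b-hf",
)
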
2 changes: 1 addition & 1 deletion src/transformers/models/pixtral/image_processing_pixtral.py
@@ -37,7 +37,7 @@
     validate_kwargs,
     validate_preprocess_arguments,
 )
-from ...utils import TensorType, is_torch_device, is_torch_dtype, is_torch_tensor, is_vision_available, logging
+from ...utils import TensorType, is_torch_device, is_torch_dtype, is_vision_available, logging
 from ...utils.import_utils import requires_backends
 
 
5 changes: 2 additions & 3 deletions src/transformers/models/pixtral/modeling_pixtral.py
@@ -197,7 +197,7 @@ def forward(
         attn_weights = torch.matmul(query_states.float(), key_states.float().transpose(2, 3)) * self.scale
 
         if attention_mask is not None:
-            attn_weights = attn_weights + attention_mask
+            attn_weights = attn_weights + attention_mask
 
         # upcast attention to fp32
         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32)
@@ -504,9 +504,8 @@ def forward(
 
         position_embedding = self.patch_positional_embedding(patch_embeds, position_ids)
 
-
         attention_mask = generate_block_attention_mask(
             [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], patch_embeds
         )
         out = self.transformer(patch_embeds, attention_mask, position_embedding)
-        return out
+        return out
2 changes: 1 addition & 1 deletion src/transformers/models/pixtral/processing_pixtral.py
@@ -22,7 +22,7 @@
 from ...image_utils import ImageInput, is_valid_image, load_image
 from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order
 from ...tokenization_utils_base import PreTokenizedInput, TextInput
-from ...utils import is_torch_device, is_torch_dtype, is_torch_tensor, logging, requires_backends
+from ...utils import is_torch_device, is_torch_dtype, logging, requires_backends
 
 
 logger = logging.get_logger(__name__)

0 comments on commit 8228140
