diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py
index e1babf7556b7f2..82fe86ef93a939 100644
--- a/src/transformers/models/llava/modeling_llava.py
+++ b/src/transformers/models/llava/modeling_llava.py
@@ -86,9 +86,13 @@ class LlavaCausalLMOutputWithPast(ModelOutput):
 class LlavaMultiModalProjector(nn.Module):
     def __init__(self, config: LlavaConfig):
         super().__init__()
-        self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias)
+        self.linear_1 = nn.Linear(
+            config.vision_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias
+        )
         self.act = ACT2FN[config.projector_hidden_act]
-        self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias)
+        self.linear_2 = nn.Linear(
+            config.text_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias
+        )
 
     def forward(self, image_features):
         hidden_states = self.linear_1(image_features)
@@ -287,7 +291,9 @@ def get_image_features(
             image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
         """
         self.vision_tower = self.vision_tower.to("cuda")
-        image_outputs = self.vision_tower([[im.to("cuda") for im in sample] for sample in pixel_values], output_hidden_states=True)
+        image_outputs = self.vision_tower(
+            [[im.to("cuda") for im in sample] for sample in pixel_values], output_hidden_states=True
+        )
         self.vision_tower = self.vision_tower.to("cpu")
         # this is not memory efficient at all (output_hidden_states=True) will save all the hidden stated.
         selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py
index 10de58da93407e..90c38895b4280b 100644
--- a/src/transformers/models/mistral/modeling_mistral.py
+++ b/src/transformers/models/mistral/modeling_mistral.py
@@ -242,8 +242,6 @@ def forward(
 
         hidden_states = self.input_layernorm(hidden_states)
 
-
-
         # Self Attention
         hidden_states, self_attn_weights = self.self_attn(
             hidden_states=hidden_states,
@@ -578,7 +576,6 @@ def forward(
                 **flash_attn_kwargs,
             )
 
-
             hidden_states = layer_outputs[0]
 
             if output_attentions:
diff --git a/src/transformers/models/pixtral/convert_pixtral_weights_to_hf.py b/src/transformers/models/pixtral/convert_pixtral_weights_to_hf.py
index 1a2e0a14736bdf..448393daaff54b 100644
--- a/src/transformers/models/pixtral/convert_pixtral_weights_to_hf.py
+++ b/src/transformers/models/pixtral/convert_pixtral_weights_to_hf.py
@@ -13,8 +13,9 @@
 # limitations under the License.
 import argparse
 import json
-import regex as re
 import os
+
+import regex as re
 import torch
 from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
 from safetensors.torch import load_file as safe_load_file
@@ -28,7 +29,6 @@
     PixtralImageProcessor,
     PixtralProcessor,
     PixtralVisionConfig,
-    PreTrainedTokenizerFast,
 )
 from transformers.convert_slow_tokenizer import bytes_to_unicode
 
@@ -159,6 +159,7 @@ def converted(self) -> Tokenizer:
 
 def convert_mistral_tokenizer(model_file):
     from transformers import LlamaTokenizer
+
     mistral_tokenizer = MistralTokenizer.from_file(model_file)
     vocab = mistral_tokenizer.instruct_tokenizer.tokenizer.vocab()
     control_token_ids = mistral_tokenizer.instruct_tokenizer.tokenizer._control_tokens
@@ -204,25 +205,35 @@ def convert_dictionnary(original_state_dict, vision_config, text_config):
         new_dict[new_key] = value
     return new_dict
 
+
 MISTRAL_CONFIG_MAPPING = {
-    "dim":"hidden_size",
-    "hidden_dim":"intermediate_size",
-    "n_kv_heads":"num_key_value_heads",
-    "n_heads":"num_attention_heads",
-    "n_layers":"num_hidden_layers"
+    "dim": "hidden_size",
+    "hidden_dim": "intermediate_size",
+    "n_kv_heads": "num_key_value_heads",
+    "n_heads": "num_attention_heads",
+    "n_layers": "num_hidden_layers",
 }
 
+
 def convert_mistral_model(input_dir, output_dir):
     vision_config = {}
     if os.path.isfile(f"{input_dir}/params.json"):
-        with open(f'{input_dir}/params.json') as f: param_json = json.load(f)
+        with open(f"{input_dir}/params.json") as f:
+            param_json = json.load(f)
         vision_config = param_json.pop("vision_encoder")
-        for k,v in MISTRAL_CONFIG_MAPPING.items():
+        for k, v in MISTRAL_CONFIG_MAPPING.items():
             value = param_json.pop(k)
             param_json[v] = value
         if "hidden_act" not in vision_config:
             vision_config["hidden_act"] = "silu"
-        text_config = MistralConfig(**param_json, hidden_act="silu", sliding_window=None,tie_word_embeddings=False, is_composition=True, rms_norm_eps=1e-5)
+        text_config = MistralConfig(
+            **param_json,
+            hidden_act="silu",
+            sliding_window=None,
+            tie_word_embeddings=False,
+            is_composition=True,
+            rms_norm_eps=1e-5,
+        )
     else:
         text_config = MistralConfig(
             attention_dropout=0.0,
@@ -258,8 +269,8 @@ def convert_mistral_model(input_dir, output_dir):
     config.architectures = ["LlavaForConditionalGeneration"]
     config.save_pretrained(output_dir)
     full_original_state_dict = {}
-    safetensors_files = sorted([file for file in os.listdir(input_dir) if file.endswith(".safetensors")]) 
-    if len(safetensors_files)==1:
+    safetensors_files = sorted([file for file in os.listdir(input_dir) if file.endswith(".safetensors")])
+    if len(safetensors_files) == 1:
         full_original_state_dict = safe_load_file(f"{input_dir}/consolidated.safetensors")
     else:
         for file in safetensors_files:
@@ -271,7 +282,8 @@ def convert_mistral_model(input_dir, output_dir):
     model = LlavaForConditionalGeneration(config)
     model.load_state_dict(new_dict, strict=True, assign=True)
     model.save_pretrained(output_dir)
-    
+
+
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument(
@@ -282,10 +294,7 @@ def main():
         "--output_dir",
         help="Location to write HF model and tokenizer",
     )
-    parser.add_argument(
-        "--tokenizer_file",
-        help="Location of the specific tokenizer model file to use."
-    )
+    parser.add_argument("--tokenizer_file", help="Location of the specific tokenizer model file to use.")
     parser.add_argument(
         "--chat_template_file",
         help="Optional file containing a raw chat template. Will be set as the processor's chat template.",
@@ -302,6 +311,5 @@ def main():
 
     processor.save_pretrained(args.output_dir)
 
-
 if __name__ == "__main__":
     main()
diff --git a/src/transformers/models/pixtral/image_processing_pixtral.py b/src/transformers/models/pixtral/image_processing_pixtral.py
index fa66ce5b0f1f36..69eb7c28ee24c9 100644
--- a/src/transformers/models/pixtral/image_processing_pixtral.py
+++ b/src/transformers/models/pixtral/image_processing_pixtral.py
@@ -37,7 +37,7 @@
     validate_kwargs,
     validate_preprocess_arguments,
 )
-from ...utils import TensorType, is_torch_device, is_torch_dtype, is_torch_tensor, is_vision_available, logging
+from ...utils import TensorType, is_torch_device, is_torch_dtype, is_vision_available, logging
 from ...utils.import_utils import requires_backends
 
 
diff --git a/src/transformers/models/pixtral/modeling_pixtral.py b/src/transformers/models/pixtral/modeling_pixtral.py
index a117cf47614df6..24e067294e0dd2 100644
--- a/src/transformers/models/pixtral/modeling_pixtral.py
+++ b/src/transformers/models/pixtral/modeling_pixtral.py
@@ -197,7 +197,7 @@ def forward(
         attn_weights = torch.matmul(query_states.float(), key_states.float().transpose(2, 3)) * self.scale
 
         if attention_mask is not None:
-            attn_weights = attn_weights + attention_mask 
+            attn_weights = attn_weights + attention_mask
 
         # upcast attention to fp32
         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32)
@@ -504,9 +504,8 @@ def forward(
 
         position_embedding = self.patch_positional_embedding(patch_embeds, position_ids)
 
-
         attention_mask = generate_block_attention_mask(
             [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], patch_embeds
         )
         out = self.transformer(patch_embeds, attention_mask, position_embedding)
-        return out
\ No newline at end of file
+        return out
diff --git a/src/transformers/models/pixtral/processing_pixtral.py b/src/transformers/models/pixtral/processing_pixtral.py
index f1898412b6d8a6..38f3201247e24d 100644
--- a/src/transformers/models/pixtral/processing_pixtral.py
+++ b/src/transformers/models/pixtral/processing_pixtral.py
@@ -22,7 +22,7 @@
 from ...image_utils import ImageInput, is_valid_image, load_image
 from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order
 from ...tokenization_utils_base import PreTokenizedInput, TextInput
-from ...utils import is_torch_device, is_torch_dtype, is_torch_tensor, logging, requires_backends
+from ...utils import is_torch_device, is_torch_dtype, logging, requires_backends
 
 
 logger = logging.get_logger(__name__)
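
For reference, a minimal smoke-test sketch of the conversion path touched above: convert a local checkpoint with the script, then load the result through the standard LlavaForConditionalGeneration / AutoProcessor API. All paths, the prompt string (including the [IMG] placeholder), and the image URL are illustrative assumptions, not part of this diff. Note also that with the get_image_features change above the vision tower is moved to "cuda" unconditionally, so inference currently assumes a CUDA device is available.

# Assumed invocation of the conversion script (placeholder paths):
#   python src/transformers/models/pixtral/convert_pixtral_weights_to_hf.py \
#       --input_dir /path/to/original_pixtral \
#       --output_dir /path/to/pixtral_hf \
#       --tokenizer_file /path/to/original_pixtral/tokenizer.json
from transformers import AutoProcessor, LlavaForConditionalGeneration

# Load the converted checkpoint written by convert_mistral_model() / main() above.
model = LlavaForConditionalGeneration.from_pretrained("/path/to/pixtral_hf")
processor = AutoProcessor.from_pretrained("/path/to/pixtral_hf")

# Prompt format and image token depend on the converted tokenizer / chat template (assumption).
prompt = "<s>[INST]Describe the image.\n[IMG][/INST]"
image_url = "https://example.com/street_scene.png"  # placeholder image URL

inputs = processor(text=prompt, images=[image_url], return_tensors="pt")
generate_ids = model.generate(**inputs, max_new_tokens=32)
print(processor.batch_decode(generate_ids, skip_special_tokens=True)[0])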