diff --git a/docs/source/en/index.md b/docs/source/en/index.md index 419d3d5b1dc2cc..9adb669e2cad66 100644 --- a/docs/source/en/index.md +++ b/docs/source/en/index.md @@ -160,7 +160,7 @@ Flax), PyTorch, and/or TensorFlow. | [HerBERT](model_doc/herbert) | ✅ | ✅ | ✅ | | [Hubert](model_doc/hubert) | ✅ | ✅ | ❌ | | [I-BERT](model_doc/ibert) | ✅ | ❌ | ❌ | -| [IDEFICS](model_doc/idefics) | ✅ | ❌ | ❌ | +| [IDEFICS](model_doc/idefics) | ✅ | ✅ | ❌ | | [Idefics2](model_doc/idefics2) | ✅ | ❌ | ❌ | | [ImageGPT](model_doc/imagegpt) | ✅ | ❌ | ❌ | | [Informer](model_doc/informer) | ✅ | ❌ | ❌ | diff --git a/docs/source/en/model_doc/idefics.md b/docs/source/en/model_doc/idefics.md index 9989f89d682e8f..ab66bd555a71d5 100644 --- a/docs/source/en/model_doc/idefics.md +++ b/docs/source/en/model_doc/idefics.md @@ -52,6 +52,16 @@ To train a new IDEFICS model from scratch use the m4 codebase (a link will be pr [[autodoc]] IdeficsForVisionText2Text - forward +## TFIdeficsModel + +[[autodoc]] TFIdeficsModel + - call + +## TFIdeficsForVisionText2Text + +[[autodoc]] TFIdeficsForVisionText2Text + - call + ## IdeficsImageProcessor [[autodoc]] IdeficsImageProcessor diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 21222be3fb414a..97a4e89684eb7e 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -3862,6 +3862,15 @@ "TFHubertPreTrainedModel", ] ) + + _import_structure["models.idefics"].extend( + [ + "TFIdeficsForVisionText2Text", + "TFIdeficsModel", + "TFIdeficsPreTrainedModel", + ] + ) + _import_structure["models.layoutlm"].extend( [ "TFLayoutLMForMaskedLM", @@ -7905,6 +7914,11 @@ TFHubertModel, TFHubertPreTrainedModel, ) + from .models.idefics import ( + TFIdeficsForVisionText2Text, + TFIdeficsModel, + TFIdeficsPreTrainedModel, + ) from .models.layoutlm import ( TFLayoutLMForMaskedLM, TFLayoutLMForQuestionAnswering, diff --git a/src/transformers/models/auto/modeling_tf_auto.py b/src/transformers/models/auto/modeling_tf_auto.py index a3df614b9b7922..756da20dbc51a6 100644 --- a/src/transformers/models/auto/modeling_tf_auto.py +++ b/src/transformers/models/auto/modeling_tf_auto.py @@ -58,6 +58,7 @@ ("gptj", "TFGPTJModel"), ("groupvit", "TFGroupViTModel"), ("hubert", "TFHubertModel"), + ("idefics", "TFIdeficsModel"), ("layoutlm", "TFLayoutLMModel"), ("layoutlmv3", "TFLayoutLMv3Model"), ("led", "TFLEDModel"), @@ -112,6 +113,7 @@ ("funnel", "TFFunnelForPreTraining"), ("gpt-sw3", "TFGPT2LMHeadModel"), ("gpt2", "TFGPT2LMHeadModel"), + ("idefics", "TFIdeficsForVisionText2Text"), ("layoutlm", "TFLayoutLMForMaskedLM"), ("lxmert", "TFLxmertForPreTraining"), ("mobilebert", "TFMobileBertForPreTraining"), diff --git a/src/transformers/models/idefics/__init__.py b/src/transformers/models/idefics/__init__.py index 7a4e8056f540d5..3b32064789cabe 100644 --- a/src/transformers/models/idefics/__init__.py +++ b/src/transformers/models/idefics/__init__.py @@ -13,7 +13,13 @@ # limitations under the License. 
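For orientation, a minimal sketch of what the new top-level exports and the `("idefics", "TFIdeficsModel")` / `("idefics", "TFIdeficsForVisionText2Text")` auto-mappings above enable. The checkpoint id and the on-the-fly `from_pt` conversion are illustrative assumptions, not something this diff ships:

```python
# Illustrative only: assumes a PyTorch IDEFICS checkpoint is available and that
# torch is installed so its weights can be converted with `from_pt=True`.
from transformers import TFAutoModel, TFIdeficsForVisionText2Text

checkpoint = "HuggingFaceM4/idefics-9b"  # assumed checkpoint id

# Resolves to TFIdeficsModel through the new ("idefics", "TFIdeficsModel") mapping.
base_model = TFAutoModel.from_pretrained(checkpoint, from_pt=True)

# The vision-text generation head exposed by the new top-level export.
lm_model = TFIdeficsForVisionText2Text.from_pretrained(checkpoint, from_pt=True)
```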
from typing import TYPE_CHECKING -from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available +from ...utils import ( + OptionalDependencyNotAvailable, + _LazyModule, + is_tf_available, + is_torch_available, + is_vision_available, +) _import_structure = {"configuration_idefics": ["IdeficsConfig"]} @@ -39,6 +45,17 @@ ] _import_structure["processing_idefics"] = ["IdeficsProcessor"] +try: + if not is_tf_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_tf_idefics"] = [ + "TFIdeficsForVisionText2Text", + "TFIdeficsModel", + "TFIdeficsPreTrainedModel", + ] if TYPE_CHECKING: from .configuration_idefics import IdeficsConfig @@ -64,6 +81,17 @@ ) from .processing_idefics import IdeficsProcessor + try: + if not is_tf_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_tf_idefics import ( + TFIdeficsForVisionText2Text, + TFIdeficsModel, + TFIdeficsPreTrainedModel, + ) else: import sys diff --git a/src/transformers/models/idefics/image_processing_idefics.py b/src/transformers/models/idefics/image_processing_idefics.py index ee8dfbb4077c66..f4998020daf642 100644 --- a/src/transformers/models/idefics/image_processing_idefics.py +++ b/src/transformers/models/idefics/image_processing_idefics.py @@ -92,8 +92,9 @@ def preprocess( image_mean: Optional[Union[float, List[float]]] = None, image_std: Optional[Union[float, List[float]]] = None, transform: Callable = None, + return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH, **kwargs, - ) -> TensorType.PYTORCH: + ) -> TensorType: """ Preprocess a batch of images. @@ -162,7 +163,6 @@ def preprocess( images = [self.rescale(image=image, scale=1 / 255) for image in images] images = [self.normalize(x, mean=image_mean, std=image_std) for x in images] images = [to_channel_dimension_format(x, ChannelDimension.FIRST) for x in images] - # TODO: this converts to torch tensors - switch to convert_to_tensors once it becomes available - images = BatchFeature(data={"pixel_values": images}, tensor_type=TensorType.PYTORCH)["pixel_values"] + images = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)["pixel_values"] return images diff --git a/src/transformers/models/idefics/modeling_tf_idefics.py b/src/transformers/models/idefics/modeling_tf_idefics.py new file mode 100644 index 00000000000000..8d9322b0edc272 --- /dev/null +++ b/src/transformers/models/idefics/modeling_tf_idefics.py @@ -0,0 +1,1812 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
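A small sketch of what the `return_tensors` argument added to `IdeficsImageProcessor.preprocess` above allows: the same preprocessing can now feed the TF model directly, while the default stays PyTorch so existing callers are unchanged. The dummy image is an assumption for illustration:

```python
# Illustrative only: a dummy image run through the processor with TF output.
import numpy as np
from PIL import Image
from transformers import IdeficsImageProcessor

image_processor = IdeficsImageProcessor()
image = Image.fromarray(np.zeros((224, 224, 3), dtype=np.uint8))

# New in this change: request TF tensors instead of the previously hard-coded PyTorch output.
pixel_values = image_processor(image, return_tensors="tf")  # tf.Tensor, channels-first
```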
+""" TF 2.0 Idefics model. """ + +from __future__ import annotations + +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import tensorflow as tf + +from ... import TFPreTrainedModel +from ...activations_tf import get_tf_activation +from ...modeling_tf_outputs import ModelOutput +from ...modeling_tf_utils import ( + TFCausalLanguageModelingLoss, + TFModelInputType, + keras_serializable, + shape_list, + unpack_inputs, +) +from ...tf_utils import invert_attention_mask, scaled_dot_product_attention +from ...utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from .configuration_idefics import IdeficsConfig +from .perceiver_tf import TFIdeficsPerceiverResampler +from .vision_tf import TFIdeficsVisionTransformer + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "IdeficsConfig" + + +@dataclass +class TFIdeficsBaseModelOutputWithPast(ModelOutput): + """ + Base class for Idefics model's outputs that may also contain a past key/values (to speed up sequential decoding). + + Args: + last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + + If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, + hidden_size)` is output. + past_key_values (`tuple(tuple(tf.Tensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(tf.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if + `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, + encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if + `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` + input) to speed up sequential decoding. + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + image_hidden_states (`tuple(tf.Tensor)`, *optional*): + Tuple of `tf.Tensor` (one for the output of the image embeddings, `(batch_size, num_images, + sequence_length, hidden_size)`. 
+ + image_hidden_states of the model produced by the vision encoder, and optionally by the perceiver + """ + + last_hidden_state: tf.Tensor = None + past_key_values: Optional[Tuple[Tuple[tf.Tensor]]] = None + hidden_states: Optional[Tuple[tf.Tensor]] = None + attentions: Optional[Tuple[tf.Tensor]] = None + image_hidden_states: Optional[Tuple[tf.Tensor]] = None + + +@dataclass +class TFIdeficsCausalLMOutputWithPast(ModelOutput): + """ + Base class for Idefics causal language model (or autoregressive) outputs. + + Args: + loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`tuple(tuple(tf.Tensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(tf.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + image_hidden_states (`tuple(tf.Tensor)`, *optional*): + Tuple of `tf.Tensor` (one for the output of the image embeddings, `(batch_size, num_images, + sequence_length, hidden_size)`. 
+ + image_hidden_states of the model produced by the vision encoder, and optionally by the perceiver + """ + + loss: Optional[tf.Tensor] = None + logits: tf.Tensor = None + past_key_values: Optional[List[tf.Tensor]] = None + hidden_states: Optional[Tuple[tf.Tensor]] = None + attentions: Optional[Tuple[tf.Tensor]] = None + image_hidden_states: Optional[Tuple[tf.Tensor]] = None + + +def expand_inputs_for_generation( + input_ids, + expand_size=1, + is_encoder_decoder=False, + attention_mask=None, + encoder_outputs=None, + **model_kwargs, +): + expanded_return_idx = tf.reshape(tf.repeat(tf.range(tf.shape(input_ids)[0]), expand_size), [-1]) + input_ids = tf.gather(input_ids, expanded_return_idx) + model_kwargs["pixel_values"] = model_kwargs.get("pixel_values", None) + model_kwargs["image_encoder_embeddings"] = model_kwargs.get("image_encoder_embeddings", None) + model_kwargs["perceiver_embeddings"] = model_kwargs.get("perceiver_embeddings", None) + model_kwargs["image_attention_mask"] = model_kwargs.get("image_attention_mask", None) + + if "token_type_ids" in model_kwargs: + token_type_ids = model_kwargs["token_type_ids"] + model_kwargs["token_type_ids"] = tf.gather(token_type_ids, expanded_return_idx) + + if attention_mask is not None: + model_kwargs["attention_mask"] = tf.gather(attention_mask, expanded_return_idx) + + if model_kwargs["image_attention_mask"] is not None: + model_kwargs["image_attention_mask"] = tf.gather(model_kwargs["image_attention_mask"], expanded_return_idx) + + if model_kwargs["pixel_values"] is not None: + model_kwargs["pixel_values"] = tf.gather(model_kwargs["pixel_values"], expanded_return_idx) + + elif model_kwargs["image_encoder_embeddings"] is not None: + model_kwargs["image_encoder_embeddings"] = tf.gather( + model_kwargs["image_encoder_embeddings"], expanded_return_idx + ) + + elif model_kwargs["perceiver_embeddings"] is not None: + model_kwargs["perceiver_embeddings"] = tf.gather(model_kwargs["perceiver_embeddings"], expanded_return_idx) + + return input_ids, model_kwargs + + +def update_model_kwargs_for_generation(outputs, model_kwargs): + # must have this key set to at least None + if "past_key_values" in outputs: + model_kwargs["past_key_values"] = outputs.past_key_values + else: + model_kwargs["past_key_values"] = None + + # update token_type_ids with last value + if "token_type_ids" in model_kwargs: + token_type_ids = model_kwargs["token_type_ids"] + model_kwargs["token_type_ids"] = tf.concat([token_type_ids, token_type_ids[:, -1:, ...]], axis=-1) + + # update attention masks + if "attention_mask" in model_kwargs: + attention_mask = model_kwargs["attention_mask"] + model_kwargs["attention_mask"] = tf.concat( + [attention_mask, tf.ones_like(attention_mask[:, -1:, ...])], axis=-1 + ) + if "image_attention_mask" in model_kwargs: + image_attention_mask = model_kwargs["image_attention_mask"] + last_mask = image_attention_mask[:, -1:, ...] 
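+        # only the mask for the most recently generated position is kept; it is fed back in at the next decoding step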
+ model_kwargs["image_attention_mask"] = last_mask + + # Get the precomputed image_hidden_states + model_kwargs["image_hidden_states"] = outputs.image_hidden_states + + return model_kwargs + + +def prepare_inputs_for_generation(input_ids, past_key_values=None, **kwargs): + token_type_ids = kwargs.get("token_type_ids", None) + # only last token for inputs_ids if past is defined in kwargs + if past_key_values is not None: + input_ids = input_ids[:, -1:] + if token_type_ids is not None: + token_type_ids = token_type_ids[:, -1:] + + attention_mask = kwargs.get("attention_mask", None) + position_ids = kwargs.get("position_ids", None) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = tf.math.cumsum(tf.cast(attention_mask, dtype=tf.int64), axis=-1) - 1 + position_ids = tf.where(attention_mask == 0, 1, position_ids) + if past_key_values is not None: + position_ids = position_ids[:, -1:] + + pixel_values = kwargs.get("pixel_values", None) + image_encoder_embeddings = kwargs.get("image_encoder_embeddings", None) + perceiver_embeddings = kwargs.get("perceiver_embeddings", None) + image_attention_mask = kwargs.get("image_attention_mask", None) + interpolate_pos_encoding = kwargs.get("interpolate_pos_encoding", False) + + return { + "input_ids": input_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "position_ids": position_ids, + "attention_mask": attention_mask, + "token_type_ids": token_type_ids, + "pixel_values": pixel_values, + "image_encoder_embeddings": image_encoder_embeddings, + "perceiver_embeddings": perceiver_embeddings, + "image_attention_mask": image_attention_mask, + "interpolate_pos_encoding": interpolate_pos_encoding, + } + + +def freeze_model(model, module_exceptions=[]): + mapping = { + "LayerNorm": tf.keras.layers.LayerNormalization, + "Dense": tf.keras.layers.Dense, + "Embedding": tf.keras.layers.Embedding, + } + module_exceptions_mapped = [mapping[m] for m in module_exceptions] + if not hasattr(model, "layers"): + model.trainable = False # It is just a layer + return model + for layer in model.layers: + if module_exceptions and any(isinstance(layer, t) for t in module_exceptions_mapped): + layer.trainable = True # Explicitly setting it to true to avoid any mistakes + else: + layer.trainable = False + return model + + +class TFIdeficsDecoupledEmbedding(tf.keras.layers.Embedding): + """ + Implements a decoupling of parameters to allow freezing (or not) a subset of the embeddings. In practise, the + regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `num_additional_embeddings` > 0, + then it will create `num_additional_embeddings` additional parameters that are always trained. If + `num_additional_embeddings=0`, then the module defaults back to the regular behavior of `tf.keras.layers.Embedding`. + """ + + def __init__( + self, + num_embeddings, + num_additional_embeddings, + embedding_dim, + partially_freeze: Optional[bool] = False, + dtype=None, + **kwargs, + ) -> None: + """ + Args: + num_embeddings (`int`): + Size of the dictionary of embeddings + num_additional_embeddings (`int`): + Number of additional embeddings. Only useful when you `partially_freeze=True`. + embedding_dim (`int`): + The size of each embedding vector + partially_freeze: (`bool`, *optional*, defaults to `False`): + If `True`, the regular `weight` will be frozen. `additional_weight` is never frozen. 
+ + Note: there are a lot of other parameters to initialize a standard `tf.keras.layers.Embedding` such as `mask_zero`, + `input_length` or `embeddings_initializer`. We are not supporting these. + """ + super().__init__( + input_dim=num_embeddings, + output_dim=embedding_dim, + dtype=dtype, + **kwargs, + ) + self.num_embeddings = num_embeddings + self.num_additional_embeddings = num_additional_embeddings + self.partially_freeze = partially_freeze + + if partially_freeze: + self.trainable = False + + if self.num_additional_embeddings > 0: + self.additional_embedding = tf.keras.layers.Embedding( + input_dim=self.num_additional_embeddings, + output_dim=embedding_dim, + dtype=dtype, + name="additional_embedding", + ) + + def call(self, input_ids): + """ + we have 2 embeddings, with different indices - one pretrained self.weight and another + self.additional_embedding.weight that is being trained. + + in order to make a lookup of the input ids, we: + 1. find out the indices of the entries belonging to the 2nd embedding + 2. extract those values while subtracting the size of the first embedding (num_embeddings), since the 2nd + embedding starts from 0 and not num_embeddings + 3. perform the 2nd embedding lookup + 4. now we handle the 1st embedding, we overwrite indices belonging to the 2nd embedding with a padding index + 5. perform the 1st embedding lookup + 6. now we overwrite the values in the 1st embedding lookup with the values of the 2nd embedding lookup + + note: for the 1st embedding lookup we could have looked up only the low indices and not do the padding, but + then we have to create a new tensor and populate it with 2 tensors that are spread out across various indices - + i.e. not a simple concat - I haven't benchmarked the complex case if it's any faster, given that seqlens are + usually relatively short it's probably not faster or if faster not by much - but might be a good idea to + measure. + + """ + if self.num_additional_embeddings == 0: + return super().call(input_ids) + + # Clone so that we don't modify the original input_ids later on + input_ids = tf.identity(input_ids) + additional_vocab_indices = tf.where(input_ids >= self.num_embeddings) + input_ids_additional_vocab = tf.gather_nd(input_ids, additional_vocab_indices) + additional_embeddings = self.additional_embedding(input_ids_additional_vocab - self.num_embeddings) + + # for successful lookup replace input_ids with 0, the results of these will be discarded anyway + input_ids = tf.tensor_scatter_nd_update( + input_ids, + additional_vocab_indices, + # tensor filled with 0, having the same length as additional_vocab_indices + tf.zeros(tf.shape(additional_vocab_indices)[0], dtype=input_ids.dtype), + ) + full_vector = super().call(input_ids) + + # overwrite the records with high indices + full_vector = tf.tensor_scatter_nd_update(full_vector, additional_vocab_indices, additional_embeddings) + + return full_vector + + def extra_repr(self) -> str: + return "num_embeddings={}, num_additional_embeddings={}, embedding_dim={}, partially_freeze={}".format( + self.num_embeddings, + self.num_additional_embeddings, + self.output_dim, + self.partially_freeze, + ) + + +class TFIdeficsDecoupledLinear(tf.keras.layers.Layer): + """ + Implements a decoupling of parameters to allow freezing (or not) a subset of the parameters. In practise, the + regular `weight` can be trained or frozen (i.e. 
`partially_freeze=True`), and if `out_additional_features` > 0, + then it will create `out_additional_features * in_features` additional parameters that are always trained. If + `out_additional_features=0`, then the module defaults back to the regular behavior of `tf.keras.layers.Dense`. + """ + + def __init__( + self, + in_features: int, + out_features: int, + out_additional_features: int = 0, + bias: bool = True, + partially_freeze: bool = True, + **kwargs, + ) -> None: + """ + out_additional_features: int. Number of additional trainable dimensions. Only makes sense when + `partially_freeze=True`. partially_freeze: bool. If True, the regular `weight` will be frozen and extra + parameters (if any) will be trainable. If False, default to the regular behavior of tf.keras.layers.Dense. + """ + super().__init__(**kwargs) + self.out_additional_features = out_additional_features + self.partially_freeze = partially_freeze + + self.in_features = in_features + self.out_features = out_features + self.use_bias = bias + + if out_additional_features > 0: + self.additional_fc = tf.keras.layers.Dense( + units=out_additional_features, use_bias=bias, name="additional_fc" + ) + + def call(self, inputs: tf.Tensor) -> tf.Tensor: + output = tf.linalg.matmul(a=inputs, b=self.weight, transpose_b=True) + if self.bias is not None: + output = tf.nn.bias_add(output, self.bias) + + if self.out_additional_features > 0: + additional_features = self.additional_fc(inputs) + output = tf.concat([output, additional_features], axis=-1) + + return output + + def get_config(self): + config = super().get_config() + config.update( + { + "in_features": self.in_features, + "out_features": self.out_features, + "out_additional_features": self.out_additional_features, + "bias": self.bias is not None, + "partially_freeze": self.partially_freeze, + } + ) + return config + + def extra_repr(self) -> str: + """Overwriting `nn.Linear.extra_repr` to include new parameters.""" + return "in_features={}, out_features={}, out_additional_features={}, bias={}, partially_freeze={}".format( + self.in_features, + self.out_features, + self.out_additional_features, + self.bias is not None, + self.partially_freeze, + ) + + @classmethod + def from_config(cls, config): + return cls(**config) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + self.weight = self.add_weight( + shape=(self.out_features, self.in_features), trainable=not self.partially_freeze, name="weight" + ) + if self.use_bias: + self.bias = self.add_weight(shape=(self.out_features,), trainable=not self.partially_freeze, name="bias") + else: + self.bias = None + if getattr(self, "additional_fc", None) is not None: + with tf.name_scope(self.additional_fc.name): + self.additional_fc.build(self.in_features) + + +def _make_causal_mask(input_ids_shape, dtype, past_key_values_length=0): + """ + Make causal mask used for bi-directional self-attention, supporting both static and dynamic shapes. 
+ """ + bsz, tgt_len = input_ids_shape + + # Create a matrix where only the lower triangle and diagonal are filled with zeros (causal mask) + mask = tf.fill((tgt_len, tgt_len), tf.dtypes.as_dtype(dtype).min) + mask_cond = tf.range(tgt_len) + mask = tf.where(mask_cond[:, None] >= mask_cond[None, :], 0.0, mask) + + if past_key_values_length > 0: + mask = tf.concat([tf.zeros((tgt_len, past_key_values_length), dtype=dtype), mask], axis=-1) + + if bsz is None: + # When batch size is dynamic, expand and tile + # so we can compile a functional model + mask = tf.expand_dims(mask, 0) + mask = tf.expand_dims(mask, 0) # shape: (1, 1, tgt_len, tgt_len + past_key_values_length) + mask = tf.tile(mask, [bsz, 1, 1, 1]) + else: + # When batch size is static, directly use broadcast_to + mask = tf.broadcast_to(mask[None, None, :, :], (bsz, 1, tgt_len, tgt_len + past_key_values_length)) + + return mask + + +def _expand_mask(mask, dtype, tgt_len=None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = shape_list(mask) + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = tf.expand_dims(tf.expand_dims(mask, 1), 1) + expanded_mask = tf.broadcast_to(expanded_mask, [bsz, 1, tgt_len, src_len]) + + inverted_mask = 1.0 - tf.cast(expanded_mask, dtype) + + return tf.where( + tf.cast(inverted_mask, bool), tf.fill(dims=shape_list(inverted_mask), value=tf.float32.min), inverted_mask + ) + + +class TFIdeficsRMSNorm(tf.keras.layers.Layer): + def __init__(self, hidden_size, eps=1e-6, **kwargs): + """ + TFIdeficsRMSNorm is equivalent to T5LayerNorm + """ + super().__init__(**kwargs) + self.hidden_size = hidden_size + self.variance_epsilon = eps + + def build(self, input_shape): + if self.built: + return + self.built = True + self.weight = self.add_weight(name="weight", shape=[self.hidden_size], initializer="ones") + + super().build(input_shape) + + def call(self, hidden_states): + variance = tf.math.reduce_mean(tf.math.square(tf.cast(hidden_states, tf.float32)), axis=-1, keepdims=True) + hidden_states = hidden_states * tf.math.rsqrt(variance + self.variance_epsilon) + + # convert into half-precision if necessary + if self.weight.dtype in [tf.float16, tf.bfloat16]: + hidden_states = tf.cast(hidden_states, self.weight.dtype) + + return self.weight * hidden_states + + +class TFIdeficsEmbedding(tf.keras.layers.Layer): + def __init__(self, dim, max_position_embeddings=2048, base=10000, **kwargs): + super().__init__(**kwargs) + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self.inv_freq = tf.constant( + 1.0 / (self.base ** (tf.range(start=0, limit=self.dim, delta=2, dtype=tf.float32) / self.dim)) + ) + + def _compute_cos_sin(self, seq_len): + t = tf.range(seq_len, dtype=self.inv_freq.dtype) + freqs = tf.einsum("i, j -> ij", t, self.inv_freq) # Outer multiplication + emb = tf.concat((freqs, freqs), axis=-1) + + return tf.cos(emb), tf.sin(emb) + + def call(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len is None: + seq_len = shape_list(x)[2] + return self._compute_cos_sin(seq_len=seq_len) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return tf.concat((-x2, x1), axis=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids): + cos = tf.gather(cos, position_ids) # [seq_len, dim] -> [batch_size, 1, seq_len, head_dim] + sin = tf.gather(sin, position_ids) + 
cos = tf.expand_dims(cos, 1) + sin = tf.expand_dims(sin, 1) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class TFIdeficsMLP(tf.keras.layers.Layer): + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + **kwargs, + ): + super().__init__(**kwargs) + self.gate_proj = tf.keras.layers.Dense(intermediate_size, use_bias=False, name="gate_proj") + self.down_proj = tf.keras.layers.Dense(hidden_size, use_bias=False, name="down_proj") + self.up_proj = tf.keras.layers.Dense(intermediate_size, use_bias=False, name="up_proj") + self.act_fn = get_tf_activation(hidden_act) + self.intermediate_size = intermediate_size + self.hidden_size = hidden_size + + def call(self, x): + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "gate_proj", None) is not None: + with tf.name_scope(self.gate_proj.name): + self.gate_proj.build(self.hidden_size) + if getattr(self, "down_proj", None) is not None: + with tf.name_scope(self.down_proj.name): + self.down_proj.build(self.intermediate_size) + if getattr(self, "up_proj", None) is not None: + with tf.name_scope(self.up_proj.name): + self.up_proj.build(self.hidden_size) + + +class TFIdeficsAttention(tf.keras.layers.Layer): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + hidden_size: int, + num_heads: int, + dropout: float = 0.0, + is_cross_attention: bool = False, + config: IdeficsConfig = None, + qk_layer_norms: bool = False, + **kwargs, + ): + super().__init__(**kwargs) + self.hidden_size = hidden_size + self.num_heads = num_heads + self.head_dim = hidden_size // num_heads + self.dropout = dropout + self.config = config + self.is_causal = True + + if (self.head_dim * num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {num_heads})." 
+            )
+
+        self.is_cross_attention = is_cross_attention
+
+        self.q_proj = tf.keras.layers.Dense(
+            num_heads * self.head_dim,
+            use_bias=False,
+            name="q_proj",
+        )
+        self.k_proj = tf.keras.layers.Dense(
+            num_heads * self.head_dim,
+            use_bias=False,
+            name="k_proj",
+        )
+        self.v_proj = tf.keras.layers.Dense(
+            num_heads * self.head_dim,
+            use_bias=False,
+            name="v_proj",
+        )
+        self.o_proj = tf.keras.layers.Dense(
+            hidden_size,
+            use_bias=False,
+            name="o_proj",
+        )
+        self.rotary_emb = TFIdeficsEmbedding(self.head_dim, name="rotary_emb")
+
+        self.qk_layer_norms = qk_layer_norms
+        if self.qk_layer_norms:
+            self.q_layer_norm = TFIdeficsRMSNorm(self.head_dim, eps=config.rms_norm_eps, name="q_layer_norm")
+            self.k_layer_norm = TFIdeficsRMSNorm(self.head_dim, eps=config.rms_norm_eps, name="k_layer_norm")
+
+    def _shape(self, tensor: tf.Tensor, seq_len: int, bsz: int):
+        return tf.transpose(tf.reshape(tensor, (bsz, seq_len, self.num_heads, self.head_dim)), perm=[0, 2, 1, 3])
+
+    def call(
+        self,
+        hidden_states: tf.Tensor,
+        key_value_states: Optional[tf.Tensor] = None,
+        attention_mask: Optional[tf.Tensor] = None,
+        position_ids: Optional[tf.Tensor] = None,
+        past_key_value: Optional[Tuple[tf.Tensor]] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+    ) -> Tuple[tf.Tensor, Optional[tf.Tensor], Optional[Tuple[tf.Tensor]]]:
+        # if key_value_states are provided this layer is used as a cross-attention layer
+        is_cross_attention = self.is_cross_attention or key_value_states is not None
+
+        bsz, q_len, _ = shape_list(hidden_states)
+
+        query_states = self._shape(self.q_proj(hidden_states), q_len, bsz)
+        if not is_cross_attention:
+            key_states = self._shape(self.k_proj(hidden_states), q_len, bsz)
+            value_states = self._shape(self.v_proj(hidden_states), q_len, bsz)
+        else:
+            _, kv_len, _ = shape_list(key_value_states)  # Note that, in this case, `kv_len` == `kv_seq_len`
+            key_states = self._shape(self.k_proj(key_value_states), kv_len, bsz)
+            value_states = self._shape(self.v_proj(key_value_states), kv_len, bsz)
+
+        kv_seq_len = shape_list(key_states)[-2]
+        if past_key_value is not None:
+            kv_seq_len += shape_list(past_key_value[0])[-2]
+        if not is_cross_attention:
+            # Below is to allow symbolic tensors compilation
+            if tf.is_tensor(kv_seq_len):
+                seq_len = tf.maximum(kv_seq_len, q_len)
+            else:
+                seq_len = max(kv_seq_len, q_len)
+            cos, sin = self.rotary_emb(value_states, seq_len)
+            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+        # [bsz, nh, t, hd]
+
+        if past_key_value is not None:
+            # reuse k, v, self_attention
+            key_states = tf.concat([past_key_value[0], key_states], axis=2)
+            value_states = tf.concat([past_key_value[1], value_states], axis=2)
+
+        past_key_value = (key_states, value_states) if use_cache else None
+
+        if self.qk_layer_norms:
+            query_states = self.q_layer_norm(query_states)
+            key_states = self.k_layer_norm(key_states)
+
+        tf.debugging.assert_equal(
+            tf.shape(attention_mask),
+            [bsz, 1, q_len, kv_seq_len],
+            message=f"Attention mask should be of size {[bsz, 1, q_len, kv_seq_len]}, but is {tf.shape(attention_mask)}",
+        )
+
+        attn_output = scaled_dot_product_attention(
+            query_states,
+            key_states,
+            value_states,
+            attn_mask=attention_mask,
+            # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
+ is_causal=self.is_causal and attention_mask is None and q_len > 1, + ) + + tf.debugging.assert_equal( + tf.shape(attn_output), + [bsz, self.num_heads, q_len, self.head_dim], + message=f"Attention weights should be of size {[bsz, self.num_heads, q_len, self.head_dim]}, but is {tf.shape(attn_output)}", + ) + + attn_output = tf.reshape(tf.transpose(attn_output, perm=[0, 2, 1, 3]), (bsz, q_len, self.hidden_size)) + + attn_output = self.o_proj(attn_output) + + attn_weights = None + if output_attentions: + logger.warning_once( + "attn_weights are not extracted in scaled_dot_product_attention. The model returns None instead" + ) + + return attn_output, attn_weights, past_key_value + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if self.is_cross_attention: + kv_input_dim = ( + self.hidden_size + if not hasattr(self.config.vision_config, "embed_dim") + else self.config.vision_config.embed_dim + ) + else: + kv_input_dim = self.hidden_size + if getattr(self, "o_proj", None) is not None: + with tf.name_scope(self.o_proj.name): + self.o_proj.build(self.num_heads * self.head_dim) + if getattr(self, "q_proj", None) is not None: + with tf.name_scope(self.q_proj.name): + self.q_proj.build(self.hidden_size) + if getattr(self, "k_proj", None) is not None: + with tf.name_scope(self.k_proj.name): + self.k_proj.build(kv_input_dim) + if getattr(self, "v_proj", None) is not None: + with tf.name_scope(self.v_proj.name): + self.v_proj.build(kv_input_dim) + if getattr(self, "rotary_emb", None) is not None: + with tf.name_scope(self.rotary_emb.name): + self.rotary_emb.build(None) + + +class TFIdeficsDecoderLayer(tf.keras.layers.Layer): + def __init__(self, config: IdeficsConfig, **kwargs): + super().__init__(**kwargs) + self.hidden_size = config.hidden_size + self.self_attn = TFIdeficsAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + dropout=config.dropout, + config=config, + name="self_attn", + ) + self.mlp = TFIdeficsMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + name="mlp", + ) + self.input_layernorm = TFIdeficsRMSNorm(config.hidden_size, eps=config.rms_norm_eps, name="input_layernorm") + self.post_attention_layernorm = TFIdeficsRMSNorm( + config.hidden_size, eps=config.rms_norm_eps, name="post_attention_layernorm" + ) + self.dropout = config.dropout + + def call( + self, + hidden_states: tf.Tensor, + attention_mask: Optional[tf.Tensor] = None, + position_ids: Optional[tf.Tensor] = None, + past_key_value: Optional[Tuple[tf.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + training=False, + ) -> Tuple[tf.Tensor, Optional[Tuple[tf.Tensor, tf.Tensor]]]: + """ + Args: + hidden_states (`tf.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`tf.Tensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). 
+ past_key_value (`Tuple(tf.Tensor)`, *optional*): cached past key and value projection states + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = tf.nn.dropout(hidden_states, rate=self.dropout) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = tf.nn.dropout(hidden_states, rate=self.dropout) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "self_attn", None) is not None: + with tf.name_scope(self.self_attn.name): + self.self_attn.build(None) + if getattr(self, "mlp", None) is not None: + with tf.name_scope(self.mlp.name): + self.mlp.build(None) + if getattr(self, "input_layernorm", None) is not None: + with tf.name_scope(self.input_layernorm.name): + self.input_layernorm.build(None) + if getattr(self, "post_attention_layernorm", None) is not None: + with tf.name_scope(self.post_attention_layernorm.name): + self.post_attention_layernorm.build(None) + + +class TFIdeficsGatedCrossAttentionLayer(tf.keras.layers.Layer): + def __init__(self, config: IdeficsConfig, **kwargs): + super().__init__(**kwargs) + self.hidden_size = config.hidden_size + self.cross_attn = TFIdeficsAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + is_cross_attention=True, + dropout=config.dropout, + config=config, + qk_layer_norms=config.qk_layer_norms, + name="cross_attn", + ) + self.mlp = TFIdeficsMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + name="mlp", + ) + self.input_layernorm = TFIdeficsRMSNorm(config.hidden_size, eps=config.rms_norm_eps, name="input_layernorm") + self.post_attention_layernorm = TFIdeficsRMSNorm( + config.hidden_size, eps=config.rms_norm_eps, name="post_attention_layernorm" + ) + self.config = config.dropout + + self.act_cross_attn = tf.keras.activations.tanh + self.act_dense = tf.keras.activations.tanh + + self.alpha_initializer = config.alpha_initializer + self.alpha_type = config.alpha_type + self.alphas_initializer_range = config.alphas_initializer_range + + def build(self, input_shape): + if self.built: + return + self.built = True + if self.alpha_initializer == "zeros": + if self.alpha_type == "vector": + self.alpha_cross_attn = self.add_weight( + shape=(1, 1, self.hidden_size), initializer="zeros", trainable=True, name="alpha_cross_attn" + ) + self.alpha_dense = self.add_weight( + shape=(1, 1, self.hidden_size), initializer="zeros", trainable=True, name="alpha_dense" + ) + elif self.alpha_type == "float": + self.alpha_cross_attn = self.add_weight( + shape=(1,), initializer="zeros", trainable=True, name="alpha_cross_attn" + ) + self.alpha_dense = self.add_weight(shape=(1,), initializer="zeros", trainable=True, name="alpha_dense") + else: + raise ValueError(f"Unknown value for `alpha_type` ({self.alpha_type})") + + elif 
self.alpha_initializer == "ones": + if self.alpha_type == "vector": + self.alpha_cross_attn = self.add_weight( + shape=(1, 1, self.hidden_size), initializer="ones", trainable=True, name="alpha_cross_attn" + ) + self.alpha_dense = self.add_weight( + shape=(1, 1, self.hidden_size), initializer="ones", trainable=True, name="alpha_dense" + ) + elif self.alpha_type == "float": + self.alpha_cross_attn = self.add_weight( + shape=(1,), initializer="ones", trainable=True, name="alpha_cross_attn" + ) + self.alpha_dense = self.add_weight(shape=(1,), initializer="ones", trainable=True, name="alpha_dense") + else: + raise ValueError(f"Unknown value for `alpha_type` ({self.alpha_type})") + + elif self.alpha_initializer in {"normal", "gaussian", "random"}: + if self.alpha_type == "vector": + self.alpha_cross_attn = self.add_weight( + shape=(1, 1, self.hidden_size), + initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=self.alphas_initializer_range), + trainable=True, + name="alpha_cross_attn", + ) + self.alpha_dense = self.add_weight( + shape=(1, 1, self.hidden_size), + initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=self.alphas_initializer_range), + trainable=True, + name="alpha_dense", + ) + elif self.alpha_type == "float": + self.alpha_cross_attn = self.add_weight( + shape=(1,), + initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=self.alphas_initializer_range), + trainable=True, + name="alpha_type", + ) + self.alpha_dense = self.add_weight( + shape=(1,), + initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=self.alphas_initializer_range), + trainable=True, + name="alpha_dense", + ) + else: + raise ValueError(f"Unknown value for `alpha_type` ({self.alpha_type})") + + else: + raise NotImplementedError(f"Alpha initialization scheme {self.alpha_initializer} not yet implemented!") + + if not (hasattr(self, "alpha_cross_attn") and hasattr(self, "alpha_dense")): + raise ValueError("Alpha parameters not initialized correctly!") + with tf.name_scope(self.cross_attn.name): + self.cross_attn.build(None) + with tf.name_scope(self.mlp.name): + self.mlp.build(None) + with tf.name_scope(self.input_layernorm.name): + self.input_layernorm.build(None) + with tf.name_scope(self.post_attention_layernorm.name): + self.post_attention_layernorm.build(None) + super().build(input_shape) + + def call( + self, + hidden_states: tf.Tensor, + attention_mask: Optional[tf.Tensor] = None, + image_hidden_states: Optional[tf.Tensor] = None, + image_attention_mask: Optional[tf.Tensor] = None, + cross_attention_gate: Optional[tf.Tensor] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + past_key_value: Optional[Tuple[tf.Tensor]] = None, + ) -> Tuple[tf.Tensor, Optional[Tuple[tf.Tensor, tf.Tensor]]]: + """ + Args: + hidden_states (`tf.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`tf.Tensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). 
+ past_key_value (`Tuple(tf.Tensor)`, *optional*): cached past key and value projection states + no_images (`bool`, *optional*, defaults to `False`): If `True` the vision part is ignored + """ + if image_hidden_states is None: + raise ValueError( + "`image_hidden_states` is required for Idefics cross attention module which are visual features to be" + " conditioned on." + ) + + if cross_attention_gate is None: + raise ValueError( + "`cross_attention_gate` is required for Idefics cross attention module to zero-out the cross-attention hidden_states attending to no images." + ) + + if past_key_value is not None: + raise NotImplementedError("Past key value states are not implemented for Idefics cross attention module.") + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.cross_attn( + hidden_states=hidden_states, + key_value_states=image_hidden_states, + attention_mask=image_attention_mask, + output_attentions=output_attentions, + ) + hidden_states = tf.nn.dropout(hidden_states, rate=self.config) + mask = tf.cast(cross_attention_gate == 0, dtype=hidden_states.dtype) + # Expand dimensions of mask to match hidden_states + mask = tf.expand_dims(mask, -1) + hidden_states = tf.where( + tf.broadcast_to(mask, tf.shape(hidden_states)) == 1, tf.zeros_like(hidden_states), hidden_states + ) + # when there are no images the model is used in pure language mode + # gate = 0 if no_images else 1 + hidden_states = residual + self.act_cross_attn(self.alpha_cross_attn) * hidden_states + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = tf.nn.dropout(hidden_states, rate=self.config) + hidden_states = residual + self.act_dense(self.alpha_dense) * hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +LLAMA_START_DOCSTRING = r""" + This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a TensorFlow [tf.keras.layers.Layer](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer) subclass. + Use it as a regular TensorFlow Layer and refer to the TensorFlow documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`IdeficsConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~TFPreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, +) +class TFIdeficsPreTrainedModel(TFPreTrainedModel): + config_class = IdeficsConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["TFIdeficsDecoderLayer", "TFIdeficsGatedCrossAttentionLayer"] + + +LLAMA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. 
+ + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) + past_key_values (`tuple(tuple(tf.Tensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(tf.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, +) +@keras_serializable +class TFIdeficsMainLayer(tf.keras.layers.Layer): + """ + Transformer decoder consisting of `config.num_hidden_layers` layers. 
Each layer is a [`TFIdeficsDecoderLayer`]
+
+    Args:
+        config: IdeficsConfig
+    """
+
+    config_class = IdeficsConfig
+
+    def __init__(self, config: IdeficsConfig, add_pooling_layer: bool = True, **kwargs):
+        super().__init__(**kwargs)
+        self.config = config
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+
+        self.embed_tokens = TFIdeficsDecoupledEmbedding(
+            num_embeddings=config.vocab_size,
+            num_additional_embeddings=config.additional_vocab_size,
+            embedding_dim=config.hidden_size,
+            partially_freeze=config.freeze_text_layers,
+            name="embed_tokens",
+        )
+
+        self.image_size = config.vision_config.image_size
+        self.vision_config = config.vision_config
+        self.vision_model = TFIdeficsVisionTransformer(config.vision_config, name="vision_model")
+
+        # Perceiver Resampler
+        if config.use_resampler:
+            perceiver_config = config.perceiver_config
+            self.perceiver_resampler = TFIdeficsPerceiverResampler(
+                config,
+                config.vision_config.embed_dim,
+                perceiver_config.resampler_depth,
+                perceiver_config.resampler_n_heads,
+                perceiver_config.resampler_head_dim,
+                perceiver_config.resampler_n_latents,
+                name="perceiver_resampler",
+            )
+
+        self.decoder_layers = [
+            TFIdeficsDecoderLayer(config, name=f"layers.{i}") for i in range(config.num_hidden_layers)
+        ]
+
+        self.cross_layer_interval = config.cross_layer_interval
+        num_cross_layers = config.num_hidden_layers // self.cross_layer_interval
+        self.gated_cross_attn_layers = [
+            TFIdeficsGatedCrossAttentionLayer(config, name=f"gated_cross_attn_layers.{i}")
+            for i in range(num_cross_layers)
+        ]
+
+        self.norm = TFIdeficsRMSNorm(config.hidden_size, eps=config.rms_norm_eps, name="norm")
+
+        self.gradient_checkpointing = False
+        self.freeze_relevant_params(config)
+
+    def freeze_relevant_params(self, config=None):
+        if config is None:
+            config = self.config
+
+        if config.freeze_text_layers:
+            self.freeze_text_layers(config.freeze_text_module_exceptions)
+
+        if config.freeze_vision_layers:
+            freeze_model(self.vision_model, module_exceptions=config.freeze_vision_module_exceptions)
+
+    def freeze_text_layers(self, module_exceptions=[]):
+        for module in [self.decoder_layers, self.norm]:
+            freeze_model(module, module_exceptions=module_exceptions)
+
+    def freeze_vision_layers(self, module_exceptions=[]):
+        freeze_model(self.vision_model, module_exceptions=module_exceptions)
+
+    def get_input_embeddings(self):
+        return self.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.embed_tokens = value
+
+    def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
+        # create causal mask
+        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+        combined_attention_mask = None
+        # if input_shape[-1] > 1:
+        combined_attention_mask = _make_causal_mask(
+            input_shape,
+            inputs_embeds.dtype,
+            past_key_values_length=past_key_values_length,
+        )
+
+        if attention_mask is not None:
+            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
+            combined_attention_mask = (
+                expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
+            )
+
+        return combined_attention_mask
+
+    @unpack_inputs
+    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
+    def call(
+        self,
+        input_ids: TFModelInputType | None = None,
+        attention_mask: Optional[tf.Tensor] = None,
+        position_ids: Optional[tf.Tensor] = None,
+        past_key_values:
Optional[List[tf.Tensor]] = None, + inputs_embeds: Optional[tf.Tensor] = None, + pixel_values: Optional[tf.Tensor] = None, + image_encoder_embeddings: Optional[tf.Tensor] = None, + perceiver_embeddings: Optional[tf.Tensor] = None, + image_attention_mask: Optional[tf.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = False, + return_dict: Optional[bool] = None, + training: Optional[bool] = None, + ) -> Union[TFIdeficsBaseModelOutputWithPast, Tuple[tf.Tensor]]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = shape_list(input_ids) + elif inputs_embeds is not None: + batch_size, seq_length, _ = shape_list(inputs_embeds) + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + seq_length_with_past = seq_length + past_key_values_length = 0 + + if past_key_values is not None: + past_key_values_length = shape_list(past_key_values[0][0])[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = tf.math.cumsum(tf.cast(attention_mask, dtype=tf.int32), axis=-1) - 1 + position_ids = tf.where(attention_mask == 0, 1, position_ids) + elif position_ids is None: + position_ids = tf.range(past_key_values_length, seq_length + past_key_values_length, dtype=tf.int32) + position_ids = tf.expand_dims(position_ids, 0) + + no_images = False + if ( + sum((int(pixel_values is None), int(image_encoder_embeddings is None), int(perceiver_embeddings is None))) + != 2 + ): + raise ValueError( + "Exactly 1 of pixel_values, image_encoder_embeddings or perceiver_embeddings has to be not-None." 
+ ) + + elif pixel_values is not None: + no_images = tf.reduce_sum(tf.cast(pixel_values, dtype=tf.int32)) == 0 + pixel_values = tf.cast(pixel_values, dtype=self.dtype) # fp16 compatibility + # Below hack is because when cross-loading pytorch weights, there is an + # initial forward pass with dummy input and code below is here to handle that + if len(pixel_values.shape) == 4: + batch_size = shape_list(pixel_values)[0] + num_images = shape_list(pixel_values)[0] + # pixel_values = tf.reshape(pixel_values, [batch_size * num_images, *pixel_values.shape[1:]]) + elif len(pixel_values.shape) == 5: + batch_size, num_images = shape_list(pixel_values)[:2] + pixel_values = tf.reshape(pixel_values, [batch_size * num_images, *pixel_values.shape[2:]]) + + # Get sequence from the vision encoder + image_hidden_states = self.vision_model( + pixel_values=pixel_values, interpolate_pos_encoding=interpolate_pos_encoding + ).last_hidden_state + + elif image_encoder_embeddings is not None: + batch_size, num_images, image_seq_len, image_hidden_size = shape_list(image_encoder_embeddings) + image_hidden_states = tf.cast(image_encoder_embeddings, dtype=self.dtype) + image_hidden_states = tf.reshape( + image_hidden_states, (batch_size * num_images, image_seq_len, image_hidden_size) + ) + + if self.config.use_resampler: + if perceiver_embeddings is None: + perceiver_embeddings = self.perceiver_resampler(image_hidden_states) + image_seq_len, image_hidden_size = shape_list(perceiver_embeddings)[1:3] + else: + batch_size, num_images, image_seq_len, image_hidden_size = shape_list(perceiver_embeddings) + image_hidden_states = perceiver_embeddings + elif perceiver_embeddings is None: + image_seq_len, image_hidden_size = shape_list(image_hidden_states)[1:3] + else: + raise ValueError("If `perceiver_embeddings` are passed, use_resampler should be True") + + image_hidden_states = tf.reshape( + image_hidden_states, (batch_size, num_images * image_seq_len, image_hidden_size) + ) + # # Hack to use the model in full language modeling mode + # image_attention_mask = tf.zeros((batch_size, seq_length, 1), dtype=tf.int32) + + # this is to account for the dummy inputs + if pixel_values is not None and len(pixel_values.shape) == 4 and image_attention_mask is None: + image_attention_mask = tf.zeros((batch_size, seq_length, 1), dtype=tf.int32) + + text_seq_len = shape_list(image_attention_mask)[1] + image_attention_mask = tf.expand_dims(image_attention_mask, -1) + image_attention_mask = tf.repeat(image_attention_mask, repeats=image_seq_len) + image_attention_mask = tf.reshape(image_attention_mask, (batch_size, text_seq_len, num_images * image_seq_len)) + + if image_hidden_states is not None: + image_batch_size, image_sequence_length, _ = shape_list(image_hidden_states) + image_hidden_shape = (image_batch_size, image_sequence_length) + if image_attention_mask is None: + image_attention_mask = tf.ones(image_hidden_shape, dtype=tf.int32) + image_attention_mask = invert_attention_mask(image_attention_mask) + else: + image_attention_mask = None + + cross_attention_gate = tf.squeeze( + tf.cast(tf.reduce_any(image_attention_mask == 0, axis=-1), dtype=self.dtype), axis=1 + ) + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + # embed positions + if attention_mask is None: + attention_mask = tf.ones((batch_size, seq_length_with_past), dtype=tf.bool) + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length + ) + + hidden_states = inputs_embeds 
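+        # At this point `attention_mask` is the additive causal + padding mask of shape
+        # (batch_size, 1, seq_length, seq_length + past_key_values_length), `image_attention_mask`
+        # has been turned by `invert_attention_mask` into an additive mask over the flattened image
+        # sequence, and `cross_attention_gate` has shape (batch_size, seq_length) with 1.0 for tokens
+        # that attend to at least one image and 0.0 otherwise, so the gated cross-attention layers
+        # can zero their contribution for text-only tokens.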
+ + if self.gradient_checkpointing and training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + for idx, decoder_layer in enumerate(self.decoder_layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + def vblock( + main_block, + hidden_states, + attention_mask, + position_ids, + past_key_value, + image_hidden_states, + image_attention_mask, + cross_attention_gate, + output_attentions, + use_cache, + layer_idx, + cross_layer_interval, + gated_cross_attn_layers, + ): + # TODO(ls): Add cross attention values to respective lists + if layer_idx % cross_layer_interval == 0: + xblock = gated_cross_attn_layers[layer_idx // cross_layer_interval] + outputs = xblock( + hidden_states, + attention_mask=attention_mask, + image_hidden_states=image_hidden_states, + image_attention_mask=image_attention_mask, + cross_attention_gate=cross_attention_gate, + output_attentions=output_attentions, + use_cache=use_cache, + past_key_value=None, # not implemented + ) + hidden_states = outputs[0] + + layer_outputs = main_block( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + return layer_outputs + + if self.gradient_checkpointing and training: + past_key_value = None + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + layer_outputs = tf.recompute_grad( + vblock, + decoder_layer, + hidden_states, + attention_mask, + position_ids, + past_key_value, + image_hidden_states, + image_attention_mask, + output_attentions, + use_cache, + no_images, + idx, + self.cross_layer_interval, + self.gated_cross_attn_layers, + ) + else: + layer_outputs = vblock( + decoder_layer, + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + image_hidden_states=image_hidden_states, + image_attention_mask=image_attention_mask, + cross_attention_gate=cross_attention_gate, + output_attentions=output_attentions, + use_cache=use_cache, + layer_idx=idx, + cross_layer_interval=self.cross_layer_interval, + gated_cross_attn_layers=self.gated_cross_attn_layers, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + image_hidden_states = tf.reshape( + image_hidden_states, (batch_size, num_images, image_seq_len, image_hidden_size) + ) + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, image_hidden_states] + if v is not None + ) + return TFIdeficsBaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + image_hidden_states=image_hidden_states, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "embed_tokens", None) is not None: + with tf.name_scope(self.embed_tokens.name): + self.embed_tokens.build(None) + if getattr(self, "vision_model", None) is not None: + with tf.name_scope(self.vision_model.name): + self.vision_model.build(None) + if getattr(self, "norm", None) is not None: + with tf.name_scope(self.norm.name): + self.norm.build(None) + if getattr(self, "perceiver_resampler", None) is not None: + with tf.name_scope(self.perceiver_resampler.name): + self.perceiver_resampler.build(None) + if getattr(self, "decoder_layers", None) is not None: + for layer in self.decoder_layers: + with tf.name_scope(layer.name): + layer.build(None) + if getattr(self, "gated_cross_attn_layers", None) is not None: + for layer in self.gated_cross_attn_layers: + with tf.name_scope(layer.name): + layer.build(None) + + +class TFIdeficsModel(TFIdeficsPreTrainedModel): + def __init__(self, config: IdeficsConfig, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.model = TFIdeficsMainLayer(config, name="model") + + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: Optional[tf.Tensor] = None, + position_ids: Optional[tf.Tensor] = None, + past_key_values: Optional[List[tf.Tensor]] = None, + inputs_embeds: Optional[tf.Tensor] = None, + pixel_values: Optional[tf.Tensor] = None, + image_encoder_embeddings: Optional[tf.Tensor] = None, + perceiver_embeddings: Optional[tf.Tensor] = None, + image_attention_mask: Optional[tf.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = False, + return_dict: Optional[bool] = None, + 
training: Optional[bool] = None, + ) -> Union[TFIdeficsBaseModelOutputWithPast, Tuple[tf.Tensor]]: + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + pixel_values=pixel_values, + image_encoder_embeddings=image_encoder_embeddings, + perceiver_embeddings=perceiver_embeddings, + image_attention_mask=image_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=return_dict, + training=training, + ) + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "model", None) is not None: + with tf.name_scope(self.model.name): + self.model.build(None) + + +class TFIdeficsForVisionText2Text(TFPreTrainedModel, TFCausalLanguageModelingLoss): + _keys_to_ignore_on_load_missing = [r"lm_head.weight"] + _tied_weights_keys = ["model.embed_tokens.weight", "lm_head.weight"] + config_class = IdeficsConfig + + def __init__(self, config, vision_model=None, **kwargs): + super().__init__(config, **kwargs) + self.model = TFIdeficsMainLayer(config, name="model") + self.lm_head = TFIdeficsDecoupledLinear( + config.hidden_size, + config.vocab_size, + config.additional_vocab_size, + bias=False, + partially_freeze=config.freeze_lm_head, + name="lm_head", + ) + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def tie_weights(self): + """ + Overwrite `transformers.modeling_utils.PreTrainedModel.tie_weights` to handle the case of + IdeficsDecoupledLinear and IdeficsDecoupledEmbedding. 
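+
+        When `config.tie_word_embeddings` is set, the base `lm_head` weight is tied to the token
+        embedding matrix and, if additional embeddings are used, `lm_head.additional_fc` is tied to
+        the additional embedding matrix.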
+ """ + output_embeddings = self.get_output_embeddings() + input_embeddings = self.get_input_embeddings() + + if getattr(self.config, "tie_word_embeddings", True): + output_embeddings.weight = input_embeddings.weight + if input_embeddings.num_additional_embeddings > 0: + assert output_embeddings.out_additional_features == input_embeddings.num_additional_embeddings + output_embeddings.additional_fc.weight = input_embeddings.additional_embedding.weight + + if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"): + output_embeddings.out_features = input_embeddings.num_embeddings + if hasattr(output_embeddings, "out_additional_features") and hasattr( + input_embeddings, "num_additional_embeddings" + ): + output_embeddings.out_additional_features = input_embeddings.num_additional_embeddings + + @unpack_inputs + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=TFIdeficsCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: Optional[tf.Tensor] = None, + position_ids: Optional[tf.Tensor] = None, + past_key_values: Optional[List[tf.Tensor]] = None, + inputs_embeds: Optional[tf.Tensor] = None, + pixel_values: Optional[tf.Tensor] = None, + image_encoder_embeddings: Optional[tf.Tensor] = None, + perceiver_embeddings: Optional[tf.Tensor] = None, + image_attention_mask: Optional[tf.Tensor] = None, + labels: Optional[tf.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = False, + return_dict: Optional[bool] = None, + training=False, + ) -> Union[TFIdeficsCausalLMOutputWithPast, Tuple[tf.Tensor]]: + r""" + Args: + labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >> from transformers import AutoTokenizer, TFIdeficsForVisionText2Text + + >> model = TFIdeficsForVisionText2Text.from_pretrained("HuggingFaceM4/idefics-9b") + >> tokenizer = AutoTokenizer.from_pretrained("HuggingFaceM4/idefics-9b") + + >> prompt = "Hey, are you consciours? Can you talk to me?" + >> inputs = tokenizer(prompt, return_tensors="tf") + + >> # Generate + >> generate_ids = model.generate(inputs.input_ids, max_length=30) + >> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you." 
+ ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + pixel_values=pixel_values, + image_encoder_embeddings=image_encoder_embeddings, + perceiver_embeddings=perceiver_embeddings, + image_attention_mask=image_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=return_dict, + training=training, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + if attention_mask is not None: + shift_attention_mask = attention_mask[..., 1:] + shift_logits = logits[..., :-1, :][shift_attention_mask != 0] + shift_labels = labels[..., 1:][shift_attention_mask != 0] + else: + shift_logits = logits[..., :-1, :] + shift_labels = labels[..., 1:] + # Flatten the tokens + loss = self.hf_compute_loss( + labels=tf.reshape(shift_labels, [-1]), logits=tf.reshape(shift_logits, [-1, shift_logits.shape[-1]]) + ) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return TFIdeficsCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=outputs.image_hidden_states, + ) + + def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs): + image_hidden_states = kwargs.pop("image_hidden_states", None) + if image_hidden_states is not None: + if self.config.use_resampler: + kwargs["perceiver_embeddings"] = image_hidden_states + else: + kwargs["image_encoder_embeddings"] = image_hidden_states + kwargs["pixel_values"] = None + inputs = prepare_inputs_for_generation(input_ids, past=past, **kwargs) + unwanted_kwargs = ["token_type_ids"] + for kwarg in unwanted_kwargs: + inputs.pop(kwarg, None) + return inputs + + @staticmethod + def _expand_inputs_for_generation( + *args, + **model_kwargs, + ): + return expand_inputs_for_generation(*args, **model_kwargs) + + @staticmethod + def _update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder): + return update_model_kwargs_for_generation(outputs, model_kwargs) + + @staticmethod + def _reorder_cache(past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += (tuple(tf.gather(past_state, beam_idx) for past_state in layer_past),) + return reordered_past + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "model", None) is not None: + with tf.name_scope(self.model.name): + self.model.build(None) + if getattr(self, "lm_head", None) is not None: + with tf.name_scope(self.lm_head.name): + self.lm_head.build(None) diff --git a/src/transformers/models/idefics/perceiver_tf.py b/src/transformers/models/idefics/perceiver_tf.py new file mode 100644 index 00000000000000..c9e76004a70ddc --- /dev/null +++ 
b/src/transformers/models/idefics/perceiver_tf.py @@ -0,0 +1,194 @@ +# This code was adapted from https://github.com/lucidrains/flamingo-pytorch licensed under the MIT License. +# +# MIT License +# +# Copyright (c) 2020 The Google AI Language Team Authors, The HuggingFace Inc. team and github/lonePatient +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + + +""" + +Generic interface to various configurations of the Perceiver Resampler, that simply takes in a series of (potentially +time-indexed) contextual embeddings, and "resamples" (compresses) them down to a pre-specified number of latents! Note +that the Perceiver in general resamples based solely off the *long-range* context; there's a nice opportunity here to +prime the Perceiver Resampler with say a single layer's worth of language embeddings (the target domain), and use that +to softly "retrieve & compress" what we need --> this would be a novel contribution we should explore. + +References: + - DeepMind's Flamingo: https://www.deepmind.com/blog/tackling-multiple-tasks-with-a-single-visual-language-model + - Code borrowed w/ love from: https://github.com/lucidrains/flamingo-pytorch + +""" +from typing import Optional, Tuple + +import tensorflow as tf + +from ...modeling_tf_utils import shape_list +from .configuration_idefics import IdeficsConfig + + +class TFIdeficsPerceiverResampler(tf.keras.layers.Layer): + def __init__( + self, config: IdeficsConfig, embed_dim: int, depth: int, n_heads: int, head_dim: int, n_latents: int, **kwargs + ) -> None: + """ + Instantiates a Perceiver Resampler that operates over a sequence of embeddings (say from a ResNet or ViT or + MAE) of a given dimension, performs `depth` blocks of cross-attention with a fixed `n_latents` inputs, then + returns a Tensor of shape [bsz, n_latents, embed_dim]. :param embed_dim: Dimensionality of embeddings being fed + to the Perceiver Resampler (also dimensionality of latent embeddings *returned* by the Perceiver Resampler. + Could be e.g., VIT embed_dim, ResNet pool dim, and so on. + + Args: + config (`IdeficsConfig`): config object + embed_dim (`int`): The size of each embedding vector + depth (`int`): Depth of the Perceiver Resampler (Transformer w/ cross attention). Should be shallow (< 3). + n_heads (`int`): Number of heads in each Transformer block (for multi-headed self-attention). + head_dim (`int`): Dimensionality of each head projection in the Transformer block. 
+ n_latents (`int`): + Number of latent embeddings to resample ("compress") the input sequence to (usually < 128). + + """ + super().__init__(**kwargs) + self.embed_dim, self.n_heads, self.head_dim, self.n_latents = embed_dim, n_heads, head_dim, n_latents + self.qk_layer_norms = config.perceiver_config.qk_layer_norms_perceiver + + self.intermediate_dim = ( + self.embed_dim * 4 + if not hasattr(config.vision_config, "embed_dim") + else config.vision_config.embed_dim * 4 + ) + # Create Transformer Blocks + self.blocks = [] + for i in range(depth): + self.blocks.append( + [ + TFIdeficsPerceiverAttention( + self.embed_dim, self.n_heads, self.head_dim, self.qk_layer_norms, name=f"blocks.{i}.0" + ), + TFIdeficsMLP(self.intermediate_dim, config, name=f"blocks.{i}.1"), + ] + ) + + self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layer_norm") + + def build(self, input_shape): + # Create Latents for Perceiver + self.latents = self.add_weight( + shape=(self.n_latents, self.embed_dim), initializer="random_normal", trainable=True, name="latents" + ) + super().build(input_shape) + + def call(self, context: tf.Tensor) -> tf.Tensor: + """Resample arbitrary length context & *compress* down to self.n_latents latent embeddings""" + # tf.repeat(self.latents, "seq embed -> bsz seq embed", bsz=context.shape[0]) + latents = tf.expand_dims(self.latents, axis=0) + latents = tf.tile(latents, [tf.shape(context)[0], 1, 1]) + # Feed through Perceiver Attention blocks... + for attn, ff in self.blocks: + latents = attn(context, latents) + latents + latents = ff(latents) + latents + return self.layer_norm(latents) + + +class TFIdeficsPerceiverAttention(tf.keras.layers.Layer): + def __init__(self, embed_dim: int, n_heads: int, head_dim: int, qk_layer_norms: bool, **kwargs) -> None: + """Perceiver Cross-Attention Module --> let long-form inputs be `context`, resampled embeddings be `latents`""" + super().__init__(**kwargs) + self.embed_dim, self.n_heads, self.head_dim = embed_dim, n_heads, head_dim + self.qk_layer_norms = qk_layer_norms + # Normalization & Scaling + self.context_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="context_layer_norm") + self.latents_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="latents_layer_norm") + if self.qk_layer_norms: + self.q_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="q_layer_norm") + self.k_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="k_layer_norm") + + self.qk_scale = self.head_dim**-0.5 + + # Q, K, V Projection (no bias -- detail from Perceiver/Flamingo Papers). + self.q_proj = tf.keras.layers.Dense(self.n_heads * self.head_dim, use_bias=False, name="q_proj") + self.k_proj = tf.keras.layers.Dense(self.n_heads * self.head_dim, use_bias=False, name="k_proj") + self.v_proj = tf.keras.layers.Dense(self.n_heads * self.head_dim, use_bias=False, name="v_proj") + + self.output_proj = tf.keras.layers.Dense(embed_dim, use_bias=False, name="output_proj") + + def call(self, context: tf.Tensor, latents: tf.Tensor) -> tf.Tensor: + """ + Runs Perceiver Self-Attention, with special (context, latents) appended along the `seq` dimension! + + Args: + context (`tf.Tensor`): + Tensor of shape `[bsz, seq, embed_dim]` representing long-form context to resample. + latents (`tf.Tensor`): + Tensor of shape `[bsz, n_latents, embed_dim]` representing fixed length latents to compress to. 
+ + Returns: + `tf.Tensor`: Tensor of shape `[bsz, n_latents, embed_dim]` representing attention over latents w/ cross + from context. + """ + context = self.context_layer_norm(context) + latents = self.latents_layer_norm(latents) + batch_size, seq_length, embed_dim = shape_list(context) + + # Query, Key, Value Projections --> Note that in Flamingo, latents are *concatenated* with context prior to attn! + # Note: This results in queries w/ `seq = n_latents`, and keys, values with `seq = len(context) + n_latents` + q = self.q_proj(latents) + k = self.k_proj(tf.concat([context, latents], axis=-2)) + v = self.v_proj(tf.concat([context, latents], axis=-2)) + + # Multiheaded Self-Attention w/ stable softmax (subtract per-row max -- `amax` -- before softmax call) + # =>> `attn` should be a 2D matrix of shape [n_latents x (context + n_latents)] + q, k, v = [ + tf.transpose(tf.reshape(x, (batch_size, x.shape[1], self.n_heads, self.head_dim)), perm=[0, 2, 1, 3]) + for x in (q, k, v) + ] + + if self.qk_layer_norms: + q = self.q_layer_norm(q) + k = self.k_layer_norm(k) + + scores = tf.einsum("... i d, ... j d -> ... i j", q * self.qk_scale, k) + stabilized_scores = scores - tf.reduce_max(scores, axis=-1, keepdims=True) + attn = tf.nn.softmax(stabilized_scores, axis=-1) + + # Attend & project back to output... + resampled = tf.einsum("... i j, ... j d -> ... i d", attn, v) + return self.output_proj( + tf.reshape(tf.transpose(resampled, perm=[0, 2, 1, 3]), (batch_size, -1, self.n_heads * self.head_dim)) + ) + + +class TFIdeficsMLP(tf.keras.layers.Layer): + def __init__(self, intermediate_size, config: IdeficsConfig, **kwargs): + """Simple MLP block with intermediate_size and embedding size""" + super().__init__(**kwargs) + self.embed_dim = config.vision_config.embed_dim + self.ln = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="ln") + self.fc = tf.keras.layers.Dense(intermediate_size, use_bias=False, name="fc") + self.act = tf.keras.layers.ReLU(name="act") + self.c_proj = tf.keras.layers.Dense(self.embed_dim, use_bias=False, name="c_proj") + + def call(self, hidden_states: Optional[Tuple[tf.Tensor]]) -> tf.Tensor: + hidden_states = self.ln(hidden_states) + hidden_states = self.fc(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.c_proj(hidden_states) + + return hidden_states diff --git a/src/transformers/models/idefics/processing_idefics.py b/src/transformers/models/idefics/processing_idefics.py index d7fd8c8de6555e..2afe2a49781245 100644 --- a/src/transformers/models/idefics/processing_idefics.py +++ b/src/transformers/models/idefics/processing_idefics.py @@ -22,34 +22,53 @@ from ...feature_extraction_utils import BatchFeature from ...processing_utils import ProcessorMixin from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, TextInput, TruncationStrategy -from ...utils import TensorType, is_torch_available +from ...utils import is_tf_available, is_torch_available if is_torch_available(): import torch +if is_tf_available(): + import tensorflow as tf IMAGE_TOKEN = "" # copied from m4.training.packing -def incremental_to_binary_attention_mask(incremental_mask, num_classes=-1): - # This function converts: [-1, 0, 1] => [[0, 0], [1, 0], [0, 1]] - - # If any of images index are more than num_classes, set them to -1. 
- # Words after the max number of images allowed have been seen don't attend on anything +def incremental_to_binary_attention_mask(incremental_mask, return_tensors, num_classes=-1): + # Set elements >= num_classes to -1 if num_classes != -1: - incremental_mask[incremental_mask >= num_classes] = -1 + if return_tensors == "pt": + incremental_mask[incremental_mask >= num_classes] = -1 + elif return_tensors == "tf": + incremental_mask = tf.where(incremental_mask >= num_classes, -1, incremental_mask) + + # Create mask for negative values + if return_tensors == "pt": + negatives = incremental_mask == -1 + incremental_mask[negatives] = 0 + attn_mask = torch.nn.functional.one_hot(incremental_mask, num_classes=num_classes) + attn_mask[negatives, :] = 0 + elif return_tensors == "tf": + negatives = tf.equal(incremental_mask, -1) + incremental_mask = tf.where(negatives, 0, incremental_mask) + attn_mask = tf.one_hot(incremental_mask, depth=num_classes) + # Reshape 'negatives' to add an extra dimension, making it [batch_size, seq_length, 1] + negatives_expanded = tf.expand_dims(negatives, -1) + attn_mask = tf.where(negatives_expanded, tf.zeros_like(attn_mask), attn_mask) - negatives = incremental_mask == -1 - incremental_mask[negatives] = 0 - attn_mask = torch.nn.functional.one_hot(incremental_mask, num_classes=num_classes) - attn_mask[negatives, :] = 0 return attn_mask # copied from m4.training.packing -def image_attention_mask_for_packed_input_ids(input_ids, tokenizer): +def image_attention_mask_for_packed_input_ids(input_ids, tokenizer, return_tensors): + if return_tensors == "pt": + return image_attention_mask_for_packed_input_ids_pt(input_ids, tokenizer) + elif return_tensors == "tf": + return image_attention_mask_for_packed_input_ids_tf(input_ids, tokenizer) + + +def image_attention_mask_for_packed_input_ids_pt(input_ids, tokenizer): image_attention_mask = torch.full_like(input_ids, fill_value=-1) next_image_attention_mask = torch.full_like(input_ids, fill_value=-1) image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN) @@ -96,6 +115,39 @@ def image_attention_mask_for_packed_input_ids(input_ids, tokenizer): return image_attention_mask, next_image_attention_mask +def image_attention_mask_for_packed_input_ids_tf(input_ids, tokenizer): + image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN) + eod_token_id = tokenizer.eos_token_id + batch_size = tf.shape(input_ids)[0] + image_attention_mask = tf.fill(tf.shape(input_ids), -1) + next_image_attention_mask = tf.fill(tf.shape(input_ids), -1) + + for batch_idx in range(batch_size): + count = -1 + seen_eod = False + seq_length = tf.shape(input_ids)[1] + + for idx in range(seq_length - 1, -1, -1): + token_id = input_ids[batch_idx, idx].numpy() + if token_id == image_token_id: + count += 1 + indices = [[batch_idx, idx]] + updates = [count] + image_attention_mask = tf.tensor_scatter_nd_update(image_attention_mask, indices, updates) + next_image_attention_mask = tf.tensor_scatter_nd_update(next_image_attention_mask, indices, updates) + elif token_id == eod_token_id and not seen_eod: + seen_eod = True + count = 0 + indices = [[batch_idx, idx]] + updates = [count] + next_image_attention_mask = tf.tensor_scatter_nd_update(next_image_attention_mask, indices, updates) + if seen_eod and token_id != eod_token_id: + indices = [[batch_idx, idx]] + updates = [-1] + next_image_attention_mask = tf.tensor_scatter_nd_update(next_image_attention_mask, indices, updates) + return image_attention_mask, next_image_attention_mask + + def is_url(string): """Checks if 
the passed string contains a valid url and nothing else. e.g. if space is included it's immediately invalidated the url""" @@ -156,7 +208,7 @@ def __call__( add_eos_token=False, add_end_of_utterance_token=None, debug=False, - return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH, + return_tensors="pt", ) -> BatchEncoding: """This method takes batched or non-batched prompts made of text and images and converts them into prompts that the model was trained on and prepares the image pixel values for the model to process. @@ -268,7 +320,6 @@ def __call__( # if the value isn't overriden by the user, check if the tokenizer was trained with this token and then use it if add_end_of_utterance_token is None: add_end_of_utterance_token = self.tokenizer_was_trained_with_end_of_utterance_token - # turn non-batched prompts into batched if not any(isinstance(i, list) for i in prompts): prompts = [prompts] @@ -322,7 +373,7 @@ def image_tokens(last_was_image): if debug is True: print(f"{full_text=}") - image_objects = self.image_processor(image_objects, transform=transform) + image_objects = self.image_processor(image_objects, transform=transform, return_tensors=return_tensors) all_prompts.append(full_text) all_images.append(image_objects) @@ -345,39 +396,72 @@ def image_tokens(last_was_image): output_input_ids = [] output_images = [] output_attention_masks = [] + for text, attention_mask, images in zip(all_texts, all_attention_masks, all_images): padded_input_ids = text - image_count = padded_input_ids.count(self.image_token_id) local_max_num_images = min(image_count, max_num_images) current_images = images[:local_max_num_images] if len(current_images) > 0: - padded_image_tensor = torch.zeros(max_num_images, *current_images.size()[1:]) - padded_image_tensor[: current_images.size(0)] = current_images + if return_tensors == "pt": + padded_image_tensor = torch.zeros(max_num_images, *current_images.size()[1:]) + padded_image_tensor[: current_images.size(0)] = current_images + elif return_tensors == "tf": + # Assuming current_images is a TensorFlow tensor + # Get the shape of current_images, excluding the first dimension + image_shape = tf.shape(current_images)[1:] + # Create a shape for the padded_image_tensor + padded_shape = tf.concat([[max_num_images], image_shape], axis=0) + # Create the padded_image_tensor of zeros + padded_image_tensor = tf.zeros(padded_shape, dtype=current_images.dtype) + # Get the number of images (assuming current_images has shape [num_images, height, width, channels]) + num_images = tf.shape(current_images)[0] + # Update the padded_image_tensor with the values from current_images + indices = tf.reshape(tf.range(num_images), (-1, 1)) + updates = current_images + padded_image_tensor = tf.tensor_scatter_nd_update(padded_image_tensor, indices, updates) else: - padded_image_tensor = torch.zeros(max_num_images, *self.default_image_dims) + if return_tensors == "pt": + padded_image_tensor = torch.zeros(max_num_images, *self.default_image_dims) + elif return_tensors == "tf": + padded_image_tensor = tf.zeros((max_num_images, *self.default_image_dims)) output_images.append(padded_image_tensor) - output_input_ids.append(torch.tensor(padded_input_ids)) - output_attention_masks.append(torch.tensor(attention_mask)) - - output_input_ids = torch.stack(output_input_ids) - output_images = torch.stack(output_images) - output_attention_masks = torch.stack(output_attention_masks) + if return_tensors == "pt": + output_input_ids.append(torch.tensor(padded_input_ids)) + 
output_attention_masks.append(torch.tensor(attention_mask)) + elif return_tensors == "tf": + output_input_ids.append(tf.convert_to_tensor(padded_input_ids, dtype=tf.int32)) + output_attention_masks.append(attention_mask) + + if return_tensors == "pt": + output_input_ids = torch.stack(output_input_ids) + output_images = torch.stack(output_images) + output_attention_masks = torch.stack(output_attention_masks) + elif return_tensors == "tf": + output_input_ids = tf.stack(output_input_ids) + output_images = tf.stack(output_images) + output_attention_masks = tf.stack(output_attention_masks) if at_least_one_image: - image_attention_mask, _ = image_attention_mask_for_packed_input_ids(output_input_ids, self.tokenizer) + image_attention_mask, _ = image_attention_mask_for_packed_input_ids( + output_input_ids, self.tokenizer, return_tensors + ) image_attention_mask = incremental_to_binary_attention_mask( - image_attention_mask, num_classes=max_num_images + image_attention_mask, return_tensors, num_classes=max_num_images ) else: # in full language mode we set the image mask to all-0s - image_attention_mask = torch.zeros( - output_input_ids.shape[0], output_input_ids.shape[1], 1, dtype=torch.bool - ) - + if return_tensors == "pt": + image_attention_mask = torch.zeros( + output_input_ids.shape[0], output_input_ids.shape[1], 1, dtype=torch.bool + ) + elif return_tensors == "tf": + image_attention_mask = tf.zeros( + (output_input_ids.shape[0], output_input_ids.shape[1], 1), dtype=tf.bool + ) return BatchFeature( data={ "input_ids": output_input_ids, diff --git a/src/transformers/models/idefics/vision_tf.py b/src/transformers/models/idefics/vision_tf.py new file mode 100644 index 00000000000000..0060bb7ac9a7fb --- /dev/null +++ b/src/transformers/models/idefics/vision_tf.py @@ -0,0 +1,573 @@ +# coding=utf-8 +# Copyright 2021 The OpenAI Team Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" TF IdeficsVision model: a copy of CLIPVisionModel using a simpler config object""" + + +import math +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import tensorflow as tf + +from ...activations_tf import get_tf_activation +from ...modeling_tf_outputs import TFBaseModelOutput, TFBaseModelOutputWithPooling +from ...modeling_tf_utils import TFPreTrainedModel, shape_list +from ...tf_utils import flatten +from ...utils import ModelOutput, logging +from .configuration_idefics import IdeficsVisionConfig + + +logger = logging.get_logger(__name__) + + +@dataclass +class TFIdeficsVisionModelOutput(ModelOutput): + """ + Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states. + + Args: + image_embeds (`tf.Tensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): + The image embeddings obtained by applying the projection layer to the pooler_output. 
+ last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + image_embeds: Optional[tf.Tensor] = None + last_hidden_state: tf.Tensor = None + hidden_states: Optional[Tuple[tf.Tensor]] = None + attentions: Optional[Tuple[tf.Tensor]] = None + + +class TFIdeficsVisionEmbeddings(tf.keras.layers.Layer): + def __init__(self, config: IdeficsVisionConfig, **kwargs): + super().__init__(**kwargs) + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.patch_embedding = tf.keras.layers.Conv2D( + filters=self.embed_dim, + kernel_size=self.patch_size, + strides=self.patch_size, + use_bias=False, + padding="valid", + data_format="channels_last", + name="patch_embedding", + ) + + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.num_positions = self.num_patches + 1 + self.position_embedding = tf.keras.layers.Embedding( + self.num_positions, self.embed_dim, name="position_embedding" + ) + # self.position_ids = tf.range(self.num_positions)[tf.newaxis, :] + + def interpolate_pos_encoding(self, embeddings: tf.Tensor, height: int, width: int) -> tf.Tensor: + num_patches = shape_list(embeddings)[1] - 1 + pos_embed = self.position_embedding(self.position_ids) + num_positions = shape_list(pos_embed)[1] - 1 + if num_patches == num_positions and height == width: + return pos_embed + class_pos_embed = pos_embed[:, 0] + patch_pos_embed = pos_embed[:, 1:] + + embed_dim = shape_list(embeddings)[-1] + num_h_patches = height // self.config.patch_size + num_w_patches = width // self.config.patch_size + num_h_patches, num_w_patches = num_h_patches + 0.1, num_w_patches + 0.1 + sqrt_num_positions = math.sqrt(float(num_positions)) + patch_pos_embed = tf.reshape(patch_pos_embed, (1, int(sqrt_num_positions), int(sqrt_num_positions), embed_dim)) + + scale_height = num_h_patches / sqrt_num_positions + scale_width = num_w_patches / sqrt_num_positions + original_height = tf.cast(tf.shape(patch_pos_embed)[1], tf.float32) + original_width = tf.cast(tf.shape(patch_pos_embed)[2], tf.float32) + # Apply scaling + new_height = tf.cast(original_height * scale_height, tf.int32) + new_width = tf.cast(original_width * scale_width, tf.int32) + + patch_pos_embed = tf.image.resize( + patch_pos_embed, size=[new_height, new_width], method=tf.image.ResizeMethod.BICUBIC + ) + + if ( + int(num_h_patches) != shape_list(patch_pos_embed)[-3] + or int(num_w_patches) != shape_list(patch_pos_embed)[-2] + ): + raise ValueError( + f"Number of patches for images ({int(num_h_patches), int(num_w_patches)}) don't match the " + f"shape of position 
embedding ({shape_list(patch_pos_embed)[-2], shape_list(patch_pos_embed)[-1]})" + ) + patch_pos_embed = tf.reshape(patch_pos_embed, (1, -1, embed_dim)) + return tf.concat((class_pos_embed[tf.newaxis, :], patch_pos_embed), axis=1) + + def call(self, pixel_values: tf.Tensor, interpolate_pos_encoding: bool = False) -> tf.Tensor: + # Input `pixel_values` is NCHW format which doesn't run on CPU so first thing we do is + # transpose it to change it to NHWC. We don't care to transpose it back because + # the Conv2D layer is only hit once for each query + + if isinstance(pixel_values, dict): + pixel_values = pixel_values["pixel_values"] + + pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1)) + batch_size, height, width, num_channels = shape_list(pixel_values) + if not interpolate_pos_encoding: + if height != self.image_size or width != self.image_size: + raise ValueError( + f"Input image size ({height}*{width}) doesn't match model" + f" ({self.image_size}*{self.image_size}). You should try to set `interpolate_pos_encoding=True`" + ) + + patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid] + # Change the 2D spatial dimensions to a single temporal dimension. + # shape = (batch_size, num_patches, out_channels=embed_dim) + patch_embeds = flatten(patch_embeds, 1, 2) + + class_embeds = tf.broadcast_to( + self.class_embedding[tf.newaxis, tf.newaxis, :], [batch_size, 1, self.embed_dim] + ) + embeddings = tf.concat([class_embeds, patch_embeds], axis=1) + + # add positional encoding to each token + if interpolate_pos_encoding: + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + else: + embeddings = embeddings + self.position_embedding(self.position_ids) + + return embeddings + + def build(self, input_shape=None): + if self.built: + return + self.built = True + self.position_ids = tf.range(self.num_positions, name="self.position_ids")[tf.newaxis, :] + self.class_embedding = self.add_weight(shape=(self.embed_dim,), name="class_embedding") + if getattr(self, "patch_embedding", None) is not None: + with tf.name_scope(self.patch_embedding.name): + self.patch_embedding.build([None, None, None, self.config.num_channels]) + if getattr(self, "position_embedding", None) is not None: + with tf.name_scope(self.position_embedding.name): + self.position_embedding.build(None) + + +class TFIdeficsVisionAttention(tf.keras.layers.Layer): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." 
+ ) + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + + self.k_proj = tf.keras.layers.Dense(self.embed_dim, name="k_proj") + self.v_proj = tf.keras.layers.Dense(self.embed_dim, name="v_proj") + self.q_proj = tf.keras.layers.Dense(self.embed_dim, name="q_proj") + self.out_proj = tf.keras.layers.Dense(self.embed_dim, name="out_proj") + + def _shape(self, tensor: tf.Tensor, seq_len: int, bsz: int): + return tf.transpose(tf.reshape(tensor, (bsz, seq_len, self.num_heads, self.head_dim)), perm=[0, 2, 1, 3]) + + def call( + self, + hidden_states: tf.Tensor, + attention_mask: Optional[tf.Tensor] = None, + causal_attention_mask: Optional[tf.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[tf.Tensor, Optional[tf.Tensor], Optional[Tuple[tf.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + bsz, tgt_len, embed_dim = shape_list(hidden_states) + + # get query proj + query_states = self.q_proj(hidden_states) * self.scale + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = tf.reshape(self._shape(query_states, tgt_len, bsz), proj_shape) + key_states = tf.reshape(key_states, proj_shape) + value_states = tf.reshape(value_states, proj_shape) + + src_len = shape_list(key_states)[1] + attn_weights = tf.linalg.matmul(query_states, key_states, transpose_b=True) + + tf.debugging.assert_equal( + tf.shape(attn_weights), + [bsz * self.num_heads, tgt_len, src_len], + message=f"Attention weights should be of size {[bsz * self.num_heads, tgt_len, src_len]}, but is {tf.shape(attn_weights)}", + ) + + # apply the causal_attention_mask first + if causal_attention_mask is not None: + if shape_list(causal_attention_mask) != [bsz, 1, tgt_len, src_len]: + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is" + f" {shape_list(causal_attention_mask)}" + ) + attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + causal_attention_mask + attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len)) + + if attention_mask is not None: + if shape_list(attention_mask) != [bsz, 1, tgt_len, src_len]: + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {shape_list(attention_mask)}" + ) + attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask + attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len)) + + attn_weights = tf.nn.softmax(attn_weights, axis=-1) + + if output_attentions: + # this operation is a bit akward, but it's required to + # make sure that attn_weights keeps its gradient. 
+ # In order to do so, attn_weights have to reshaped + # twice and have to be reused in the following + attn_weights_reshaped = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attn_weights = tf.reshape(attn_weights_reshaped, (bsz * self.num_heads, tgt_len, src_len)) + else: + attn_weights_reshaped = None + + attn_probs = tf.nn.dropout(attn_weights, rate=self.dropout) + + attn_output = tf.linalg.matmul(attn_probs, value_states) + + tf.debugging.assert_equal( + tf.shape(attn_output), + [bsz * self.num_heads, tgt_len, self.head_dim], + message=f"Attention weights should be of size {[bsz * self.num_heads, tgt_len, self.head_dim]}, but is {tf.shape(attn_output)}", + ) + + attn_output = tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)) + attn_output = tf.transpose(attn_output, perm=[0, 2, 1, 3]) + attn_output = tf.reshape(attn_output, (bsz, tgt_len, embed_dim)) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "k_proj", None) is not None: + with tf.name_scope(self.k_proj.name): + self.k_proj.build((self.embed_dim, self.embed_dim)) + if getattr(self, "v_proj", None) is not None: + with tf.name_scope(self.v_proj.name): + self.v_proj.build((self.embed_dim, self.embed_dim)) + if getattr(self, "q_proj", None) is not None: + with tf.name_scope(self.q_proj.name): + self.q_proj.build((self.embed_dim, self.embed_dim)) + if getattr(self, "out_proj", None) is not None: + with tf.name_scope(self.out_proj.name): + self.out_proj.build((self.embed_dim, self.embed_dim)) + + +class TFIdeficsVisionMLP(tf.keras.layers.Layer): + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + self.config = config + self.activation_fn = get_tf_activation(config.hidden_act) + self.fc1 = tf.keras.layers.Dense(config.intermediate_size, name="fc1") + self.fc2 = tf.keras.layers.Dense(config.hidden_size, name="fc2") + + def call(self, hidden_states: tf.Tensor) -> tf.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "fc1", None) is not None: + with tf.name_scope(self.fc1.name): + self.fc1.build(self.config.hidden_size) + if getattr(self, "fc2", None) is not None: + with tf.name_scope(self.fc2.name): + self.fc2.build(self.config.intermediate_size) + + +class TFIdeficsVisionEncoderLayer(tf.keras.layers.Layer): + def __init__(self, config: IdeficsVisionConfig, **kwargs): + super().__init__(**kwargs) + self.embed_dim = config.hidden_size + self.self_attn = TFIdeficsVisionAttention(config, name="self_attn") + self.layer_norm1 = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm1") + self.mlp = TFIdeficsVisionMLP(config, name="mlp") + self.layer_norm2 = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm2") + + def call( + self, + hidden_states: tf.Tensor, + attention_mask: tf.Tensor, + causal_attention_mask: tf.Tensor, + output_attentions: Optional[bool] = False, + ) -> Tuple[tf.Tensor]: + """ + Args: + hidden_states (`tf.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`tf.Tensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. 
+ `(config.encoder_attention_heads,)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + causal_attention_mask=causal_attention_mask, + output_attentions=output_attentions, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "layer_norm1", None) is not None: + with tf.name_scope(self.layer_norm1.name): + self.layer_norm1.build([None, None, self.embed_dim]) + if getattr(self, "layer_norm2", None) is not None: + with tf.name_scope(self.layer_norm2.name): + self.layer_norm2.build([None, None, self.embed_dim]) + + +class TFIdeficsVisionEncoder(tf.keras.layers.Layer): + """ + Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a + [`TFIdeficsVisionEncoderLayer`]. + + Args: + config: IdeficsVisionConfig + """ + + def __init__(self, config: IdeficsVisionConfig, **kwargs): + super().__init__(**kwargs) + self.config = config + self.layers = [ + TFIdeficsVisionEncoderLayer(config, name=f"layers.{i}") for i in range(config.num_hidden_layers) + ] + self.gradient_checkpointing = False + + def call( + self, + inputs_embeds, + attention_mask: Optional[tf.Tensor] = None, + causal_attention_mask: Optional[tf.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: Optional[bool] = None, + ) -> Union[Tuple, TFBaseModelOutput]: + r""" + Args: + inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + causal_attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Causal mask for the text model. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + hidden_states = inputs_embeds + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + if self.gradient_checkpointing and training: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + layer_outputs = tf.recompute_grad( + create_custom_forward(encoder_layer), + hidden_states, + attention_mask, + causal_attention_mask, + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + causal_attention_mask, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return TFBaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "layers", None) is not None: + for layer in self.layers: + with tf.name_scope(layer.name): + layer.build(None) + + +class TFIdeficsVisionTransformer(TFPreTrainedModel): + def __init__(self, config: IdeficsVisionConfig, **kwargs): + super().__init__(config, **kwargs) + self.config = config + self.embed_dim = config.hidden_size + + self.embeddings = TFIdeficsVisionEmbeddings(config, name="embeddings") + self.pre_layrnorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="pre_layrnorm") + self.encoder = TFIdeficsVisionEncoder(config, name="encoder") + self.post_layernorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="post_layernorm") + + # Adapted from transformers.models.clip.modeling_clip.CLIPVisionTransformer.forward + def call( + self, + pixel_values: Optional[tf.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = False, + return_dict: Optional[bool] = None, + training: Optional[bool] = False, + ) -> Union[Tuple, TFBaseModelOutputWithPooling]: + r""" + Returns: + + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) + hidden_states = self.pre_layrnorm(hidden_states) + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + last_hidden_state = 
encoder_outputs[0] + pooled_output = last_hidden_state[:, 0, :] + pooled_output = self.post_layernorm(pooled_output) + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return TFBaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "embeddings", None) is not None: + with tf.name_scope(self.embeddings.name): + self.embeddings.build(None) + if getattr(self, "pre_layrnorm", None) is not None: + with tf.name_scope(self.pre_layrnorm.name): + self.pre_layrnorm.build([None, None, self.embed_dim]) + if getattr(self, "encoder", None) is not None: + with tf.name_scope(self.encoder.name): + self.encoder.build(None) + if getattr(self, "post_layernorm", None) is not None: + with tf.name_scope(self.post_layernorm.name): + self.post_layernorm.build([None, self.embed_dim]) diff --git a/src/transformers/tf_utils.py b/src/transformers/tf_utils.py index 75e302947e8066..b91a2ea520f0d0 100644 --- a/src/transformers/tf_utils.py +++ b/src/transformers/tf_utils.py @@ -104,6 +104,33 @@ def functional_layernorm(inputs, weight, bias, epsilon=1e-5, axis=-1): return outputs +def scaled_dot_product_attention( + query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale: float = None +): + """TF equivalent for torch's nn.functional.scaled_dot_product_attention""" + if dropout_p != 0.0: + raise ValueError( + "Dropout is not supported in this implementation - file an issue " + "with Transformers and ping @Rocketknight1 if you need it for a port!" + ) + if is_causal and attn_mask is not None: + raise ValueError("You cannot specify an attn_mask and is_causal at the same time!") + if is_causal: + attn_mask = tf.ones((tf.shape(query)[-2], tf.shape(key)[-2]), dtype=tf.int32) + attn_mask = tf.experimental.numpy.tril(attn_mask, k=0) + if attn_mask is not None and (attn_mask.dtype.is_integer or attn_mask.dtype.is_bool): + # Convert boolean mask to a negative logit bias + attn_mask = tf.where(attn_mask > 0, tf.cast(0.0, query.dtype), tf.cast(-1000.0, query.dtype)) + logits = tf.einsum("...qd, ...kd -> ...qk", query, key) + if scale is None: + scale = tf.cast(tf.shape(key)[-1], logits.dtype) ** -0.5 + logits *= scale # scale by 1/sqrt(key_dim) + if attn_mask is not None: + logits += attn_mask + probs = tf.nn.softmax(logits) + return probs @ value + + def flatten(input, start_dim=0, end_dim=-1): # Replicates the behavior of torch.flatten in TF diff --git a/src/transformers/utils/dummy_tf_objects.py b/src/transformers/utils/dummy_tf_objects.py index 5d4c28cbcc4595..e0b396c7164a75 100644 --- a/src/transformers/utils/dummy_tf_objects.py +++ b/src/transformers/utils/dummy_tf_objects.py @@ -1542,6 +1542,27 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["tf"]) +class TFIdeficsForVisionText2Text(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + +class TFIdeficsModel(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + +class TFIdeficsPreTrainedModel(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + class TFLayoutLMForMaskedLM(metaclass=DummyObject): _backends = ["tf"] diff --git 
a/tests/models/idefics/test_image_processing_idefics.py b/tests/models/idefics/test_image_processing_idefics.py index 6c682ce4a8f8c6..de42a421cd877e 100644 --- a/tests/models/idefics/test_image_processing_idefics.py +++ b/tests/models/idefics/test_image_processing_idefics.py @@ -152,7 +152,7 @@ def test_torchvision_numpy_transforms_equivalency(self): # they both do the same image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False) - image_processor = self.image_processing_class(**self.image_processor_dict) + image_processor = self.image_processing_class(**self.image_processor_dict, return_tensors="pt") print(image_inputs) @@ -181,8 +181,8 @@ def convert_to_rgb(image): ] ) - pixel_values_transform_implied = image_processor(image_inputs, transform=None) - pixel_values_transform_supplied = image_processor(image_inputs, transform=transform) + pixel_values_transform_implied = image_processor(image_inputs, transform=None, return_tensors="pt") + pixel_values_transform_supplied = image_processor(image_inputs, transform=transform, return_tensors="pt") torch.testing.assert_close(pixel_values_transform_implied, pixel_values_transform_supplied, rtol=0.0, atol=0.0) diff --git a/tests/models/idefics/test_modeling_idefics.py b/tests/models/idefics/test_modeling_idefics.py index 9f8f177617d200..5c3d45d2e81bcb 100644 --- a/tests/models/idefics/test_modeling_idefics.py +++ b/tests/models/idefics/test_modeling_idefics.py @@ -21,6 +21,7 @@ from transformers import BitsAndBytesConfig, IdeficsConfig, is_torch_available, is_vision_available from transformers.testing_utils import ( TestCasePlus, + is_pt_tf_cross_test, require_bitsandbytes, require_torch, require_torch_sdpa, @@ -559,6 +560,11 @@ def check_hidden_states_output(inputs_dict, config, model_class): check_hidden_states_output(inputs_dict, config, model_class) + @is_pt_tf_cross_test + def test_pt_tf_model_equivalence(self, allow_missing_keys=False): + self.has_attentions = False + super().test_pt_tf_model_equivalence(allow_missing_keys=allow_missing_keys) + @slow def test_model_from_pretrained(self): model_name = "HuggingFaceM4/idefics-9b" diff --git a/tests/models/idefics/test_modeling_tf_idefics.py b/tests/models/idefics/test_modeling_tf_idefics.py new file mode 100644 index 00000000000000..eeb3faafa223d9 --- /dev/null +++ b/tests/models/idefics/test_modeling_tf_idefics.py @@ -0,0 +1,565 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Testing suite for the TF Idefics model. 
""" + +import os +import tempfile +import unittest +from importlib import import_module + +from transformers import IdeficsConfig, is_tf_available, is_vision_available +from transformers.testing_utils import TestCasePlus, is_pt_tf_cross_test, require_tf, require_vision, slow +from transformers.utils import cached_property + +from ...test_configuration_common import ConfigTester +from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask +from ...test_pipeline_mixin import PipelineTesterMixin + + +if is_tf_available(): + import tensorflow as tf + + from transformers import IdeficsProcessor, TFIdeficsForVisionText2Text, TFIdeficsModel + from transformers.modeling_tf_utils import keras + from transformers.models.idefics.configuration_idefics import IdeficsPerceiverConfig, IdeficsVisionConfig + +if is_vision_available(): + from PIL import Image + + +IDEFICS_TINY_RANDOM_MODEL = "HuggingFaceM4/tiny-random-idefics" + + +class IdeficsModelTester: + def __init__( + self, + parent, + batch_size=1, + seq_length=7, + image_size=30, + patch_size=2, + num_channels=3, + is_training=True, + use_input_mask=True, + use_token_type_ids=True, + use_labels=True, + vocab_size=99, + hidden_size=32, + num_hidden_layers=5, + num_attention_heads=4, + intermediate_size=37, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=16, + type_sequence_label_size=2, + initializer_range=0.02, + num_labels=3, + scope=None, + modality_type_vocab_size=2, + vision_embed_dim=32, + vision_patch_size=2, + vision_image_size=30, + vision_num_attention_heads=4, + vision_num_hidden_layers=5, + vision_intermediate_size=37, + perceiver_qk_layer_norms_perceiver=False, + perceiver_resampler_depth=2, + perceiver_resampler_head_dim=8, + perceiver_resampler_n_heads=2, + perceiver_resampler_n_latents=16, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.is_training = is_training + self.use_input_mask = use_input_mask + self.use_token_type_ids = use_token_type_ids + self.use_labels = use_labels + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.type_sequence_label_size = type_sequence_label_size + self.initializer_range = initializer_range + self.num_labels = num_labels + self.scope = scope + self.modality_type_vocab_size = modality_type_vocab_size + + self.vision_embed_dim = vision_embed_dim + self.vision_patch_size = vision_patch_size + self.vision_image_size = vision_image_size + self.vision_num_attention_heads = vision_num_attention_heads + self.vision_num_hidden_layers = vision_num_hidden_layers + self.vision_intermediate_size = vision_intermediate_size + + self.vision_config = IdeficsVisionConfig( + embed_dim=self.vision_embed_dim, + patch_size=self.vision_patch_size, + image_size=self.vision_image_size, + num_attention_heads=self.vision_num_attention_heads, + num_hidden_layers=self.vision_num_hidden_layers, + intermediate_size=self.vision_intermediate_size, + ) + + 
self.perceiver_qk_layer_norms_perceiver = perceiver_qk_layer_norms_perceiver + self.perceiver_resampler_depth = perceiver_resampler_depth + self.perceiver_resampler_head_dim = perceiver_resampler_head_dim + self.perceiver_resampler_n_heads = perceiver_resampler_n_heads + self.perceiver_resampler_n_latents = perceiver_resampler_n_latents + + self.perceiver_config = IdeficsPerceiverConfig( + qk_layer_norms_perceiver=self.perceiver_qk_layer_norms_perceiver, + resampler_depth=self.perceiver_resampler_depth, + resampler_head_dim=self.perceiver_resampler_head_dim, + resampler_n_heads=self.perceiver_resampler_n_heads, + resampler_n_latents=self.perceiver_resampler_n_latents, + ) + + # we set the expected sequence length (which is used in several tests) + # this is equal to the seq length of the text tokens + number of image patches + 1 for the CLS token + self.expected_seq_len = self.seq_length + (self.image_size // self.patch_size) ** 2 + 1 + + def prepare_config_and_inputs(self, num_images=1, interpolate_pos_encoding=False, image_expansion=0): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + + pixel_values = floats_tensor( + [ + self.batch_size, + num_images, + self.num_channels, + self.image_size + image_expansion, + self.image_size + image_expansion, + ] + ) + input_mask = None + if self.use_input_mask: + input_mask = random_attention_mask([self.batch_size, self.seq_length]) + + image_attention_mask = random_attention_mask([self.batch_size, self.seq_length, num_images]) + + config = self.get_config() + return (config, input_ids, input_mask, pixel_values, image_attention_mask, interpolate_pos_encoding) + + def get_config(self): + return IdeficsConfig( + image_size=self.image_size, + patch_size=self.patch_size, + num_channels=self.num_channels, + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + intermediate_size=self.intermediate_size, + hidden_act=self.hidden_act, + hidden_dropout_prob=self.hidden_dropout_prob, + attention_probs_dropout_prob=self.attention_probs_dropout_prob, + max_position_embeddings=self.max_position_embeddings, + type_vocab_size=self.type_vocab_size, + is_decoder=False, + initializer_range=self.initializer_range, + num_labels=self.num_labels, + modality_type_vocab_size=self.modality_type_vocab_size, + vision_config=self.vision_config, + ) + + def create_and_check_model( + self, + config, + input_ids, + input_mask, + pixel_values, + image_attention_mask, + interpolate_pos_encoding, + ): + model = TFIdeficsModel(config=config) + result = model( + input_ids, + attention_mask=input_mask, + pixel_values=pixel_values, + image_attention_mask=image_attention_mask, + interpolate_pos_encoding=interpolate_pos_encoding, + ) + self.parent.assertEqual( + result.last_hidden_state.shape, (self.batch_size, input_ids.shape[1], self.hidden_size) + ) + + def create_and_check_model_gen( + self, + config, + input_ids, + input_mask, + pixel_values, + image_attention_mask, + interpolate_pos_encoding, + ): + model = TFIdeficsForVisionText2Text(config) + model.generate( + input_ids, + attention_mask=input_mask, + pixel_values=pixel_values, + image_attention_mask=image_attention_mask, + interpolate_pos_encoding=interpolate_pos_encoding, + max_length=self.seq_length + 2, + ) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + input_ids, + input_mask, + pixel_values, + image_attention_mask, + 
interpolate_pos_encoding, + ) = config_and_inputs + inputs_dict = { + "input_ids": input_ids, + "attention_mask": input_mask, + "pixel_values": pixel_values, + "image_attention_mask": image_attention_mask, + "interpolate_pos_encoding": interpolate_pos_encoding, + } + return config, inputs_dict + + def prepare_pixel_values(self): + return floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) + + +@require_tf +class TFIdeficsModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase): + all_model_classes = (TFIdeficsModel, TFIdeficsForVisionText2Text) if is_tf_available() else () + pipeline_model_mapping = {"feature-extraction": TFIdeficsModel} if is_tf_available() else {} + test_pruning = False + test_headmasking = False + test_onnx = False + test_resize_embeddings = False + + def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): + inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels) + # XXX: IdeficsForVisionText2TextTest has no MODEL_FOR group yet, but it should be the same + # as MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, so for now manually changing to do the right thing + # as super won't do it + if return_labels: + inputs_dict["labels"] = tf.zeros( + (self.model_tester.batch_size, self.model_tester.seq_length), dtype=tf.int64 + ) + return inputs_dict + + def test_model_outputs_equivalence(self): + try: + orig = self.all_model_classes + # IdeficsModel.forward doesn't have labels input arg - only IdeficsForVisionText2Text does + self.all_model_classes = (TFIdeficsForVisionText2Text,) if is_tf_available() else () + super().test_model_outputs_equivalence() + finally: + self.all_model_classes = orig + + def setUp(self): + self.model_tester = IdeficsModelTester(self) + self.config_tester = ConfigTester(self, config_class=IdeficsConfig, hidden_size=37) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_model_single_image(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs( + num_images=1, interpolate_pos_encoding=False, image_expansion=0 + ) + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_model_multiple_images(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs( + num_images=2, interpolate_pos_encoding=False, image_expansion=0 + ) + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_model_with_image_pos_embeddings_interpolation_single_image(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs( + num_images=1, interpolate_pos_encoding=True, image_expansion=2 + ) + self.model_tester.create_and_check_model(*config_and_inputs) + config_and_inputs = self.model_tester.prepare_config_and_inputs( + num_images=1, interpolate_pos_encoding=True, image_expansion=0 + ) + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_model_with_image_pos_embeddings_interpolation_multiple_images(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs( + num_images=2, interpolate_pos_encoding=True, image_expansion=2 + ) + self.model_tester.create_and_check_model(*config_and_inputs) + config_and_inputs = self.model_tester.prepare_config_and_inputs( + num_images=2, interpolate_pos_encoding=True, image_expansion=0 + ) + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_generate_with_image_pos_embeddings_interpolation_single_image(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs( + num_images=1, 
interpolate_pos_encoding=True, image_expansion=2
+        )
+        self.model_tester.create_and_check_model_gen(*config_and_inputs)
+
+    def test_generate_with_image_pos_embeddings_interpolation_multiple_images(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs(
+            num_images=2, interpolate_pos_encoding=True, image_expansion=2
+        )
+        self.model_tester.create_and_check_model_gen(*config_and_inputs)
+
+    def test_training_gradient_checkpointing(self):
+        pass
+
+    @unittest.skip(reason="""IDEFICS does not support retaining the gradients of the hidden states and attention""")
+    def test_retain_grad_hidden_states_attentions(self):
+        return
+
+    @unittest.skip(reason="IDEFICS uses out-of-bounds embeddings deliberately.")
+    def test_embeddings_out_of_bounds_raise_exception(self):
+        pass
+
+    @unittest.skip(reason="IDEFICS attention weights are not extracted in scaled_dot_product_attention")
+    def test_prepare_serving_output(self):
+        pass
+
+    def test_model_common_attributes(self):
+        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            self.assertIsInstance(model.get_input_embeddings(), (tf.keras.layers.Layer))
+            x = model.get_output_embeddings()
+            self.assertTrue(x is None or isinstance(x, tf.keras.layers.Layer))
+
+    def test_attention_outputs(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        config.return_dict = True
+
+        for model_class in self.all_model_classes:
+            inputs_dict["output_attentions"] = True
+            inputs_dict["output_hidden_states"] = False
+            config.return_dict = True
+            model = model_class(config)
+            outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+            attentions = outputs.attentions
+
+            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
+
+            # check that output_attentions also works using config
+            del inputs_dict["output_attentions"]
+            config.output_attentions = True
+            model = model_class(config)
+            outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+            attentions = outputs.attentions
+            # IDEFICS does not support outputting attention scores because it uses SDPA under the hood
+            self.assertTrue(attentions[0] is None)
+            out_len = len(outputs)
+
+            # Check attention is always last and order is fine
+            inputs_dict["output_attentions"] = True
+            inputs_dict["output_hidden_states"] = True
+            model = model_class(config)
+            outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+            self.assertEqual(out_len + 1, len(outputs))
+
+            self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+
+            self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
+            # IDEFICS does not support outputting attention scores because it uses SDPA under the hood
+            self.assertTrue(self_attentions[0] is None)
+
+    def test_hidden_states_output(self):
+        def check_hidden_states_output(inputs_dict, config, model_class):
+            model = model_class(config)
+            outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+
+            hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
+
+            expected_num_layers = getattr(
+                self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
+            )
+            self.assertEqual(len(hidden_states), expected_num_layers)
+
+            seq_length = self.model_tester.seq_length
+
+            self.assertListEqual(
+                list(hidden_states[0].shape[-2:]),
+                [seq_length, self.model_tester.hidden_size],
+            )
+
+        config,
inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + inputs_dict["output_hidden_states"] = True + check_hidden_states_output(inputs_dict, config, model_class) + + # check that output_hidden_states also work using config + del inputs_dict["output_hidden_states"] + config.output_hidden_states = True + + check_hidden_states_output(inputs_dict, config, model_class) + + @is_pt_tf_cross_test + def test_pt_tf_model_equivalence(self, allow_missing_keys=False): + self.has_attentions = False + super().test_pt_tf_model_equivalence(allow_missing_keys=allow_missing_keys) + + def test_keras_save_load(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + tf_main_layer_classes = { + module_member + for model_class in self.all_model_classes + for module in (import_module(model_class.__module__),) + for module_member_name in dir(module) + if module_member_name.endswith("MainLayer") + for module_member in (getattr(module, module_member_name),) + if isinstance(module_member, type) + and keras.layers.Layer in module_member.__bases__ + and getattr(module_member, "_keras_serializable", False) + } + + for main_layer_class in tf_main_layer_classes: + main_layer = main_layer_class(config) + + symbolic_inputs = { + name: keras.Input(tensor.shape[1:], dtype=tensor.dtype, batch_size=2) + for name, tensor in inputs_dict.items() + if tf.is_tensor(tensor) + } + model = keras.Model(symbolic_inputs, outputs=main_layer(symbolic_inputs)) + outputs = model(inputs_dict) + + with tempfile.TemporaryDirectory() as tmpdirname: + filepath = os.path.join(tmpdirname, "keras_model.h5") + model.save(filepath) + model = keras.models.load_model(filepath, custom_objects={main_layer_class.__name__: main_layer_class}) + assert isinstance(model, keras.Model) + after_outputs = model(inputs_dict) + self.assert_outputs_same(after_outputs, outputs) + + @unittest.skip(reason="IDEFICS test_keras_fit testing done in TFIdeficsForVisionText2TextTest") + def test_keras_fit(self): + pass + + @slow + def test_model_from_pretrained(self): + model = TFIdeficsModel.from_pretrained(IDEFICS_TINY_RANDOM_MODEL, from_pt=True) + self.assertIsNotNone(model) + + @unittest.skip(reason="Currently `saved_model` doesn't work with nested outputs.") + def test_saved_model_creation(self): + pass + + @unittest.skip(reason="""IDEFICS loss computation not implemented yet""") + def test_loss_computation(self): + pass + + +@require_tf +class TFIdeficsForVisionText2TextTest(TFIdeficsModelTest, unittest.TestCase): + all_model_classes = (TFIdeficsForVisionText2Text,) if is_tf_available() else () + test_resize_embeddings = False + + def setUp(self): + self.model_tester = IdeficsModelTester( + self, + modality_type_vocab_size=3, + ) + self.config_tester = ConfigTester(self, config_class=IdeficsConfig, hidden_size=37) + + @unittest.skip("We only test the model that takes in multiple images") + def test_model(self): + pass + + @unittest.skip("We only test the model that takes in multiple images") + def test_for_token_classification(self): + pass + + @unittest.skip(reason="""IDEFICS does not support retaining the gradients of the hidden states and attention""") + def test_retain_grad_hidden_states_attentions(self): + pass + + @unittest.skip(reason="""IDEFICS loss computation not implemented yet""") + def test_loss_computation(self): + pass + + @slow + def test_keras_fit(self): + super().test_keras_fit() + + +# Below is the expected output for the integration test 
TFIdeficsModelIntegrationTest. +# Since we are using tiny-random to be able to fit it on the CI GPU,it is better to assert on the +# ids because the generated text is gibberish + +# fmt: off +EXPECTED_GENERATED_IDS = [[0, 0, 1, 4911, 29901, 32000, 32001, 32000, 20355, 915, 445, 1967, 29889, 13, 7900, 22137, 29901, 530, 1967, 310, 1023, 26361, 29889, 13, 2659, 29901, 32000, 32001, 32000, 20355, 915, 445, 1967, 29889, 13, 7900, 22137, 29901, 25519, 22326, 8071, 26357, 28004, 4428, 5916, 14383, 1033, 12358, 10536, 21834, 10447, 21201, 18102, 16886, 8875, 25388, 25914, 28304, 8558, 31048, 1322, 25952, 189, 31600, 3600, 12824, 7045, 28090, 20228, 32001, 5385, 29186, 2165, 11822, 13825, 23077, 7883, 22504, 2078, 18893, 2179, 10556, 9515, 7672, 3491, 12403, 5398, 27299, 6463, 16349, 23037, 28956, 16960, 22664, 7724, 17587, 17424, 10175, 17417, 5930, 30855, 17695, 16170, 14474, 29996, 313, 14502, 3241, 13618, 32001, 5385, 29186, 2165, 11822, 13825, 19934, 4875, 27142, 3230, 2709, 28054, 3270, 19148, 10917, 1060, 26443, 12259, 1347, 28482, 3830, 25519, 199, 12782, 9144, 12289, 1142, 18400, 21390, 19129, 7292, 28430, 24711, 5551, 30349, 30533, 13271, 17697, 4982, 8713, 5380, 17869, 12490, 5398, 27299, 11593, 19918, 15924, 29430, 10175, 17417, 5930, 30855, 17695, 16170, 14474, 19234], + [1, 4911, 29901, 32000, 32001, 32000, 20355, 915, 445, 1967, 29889, 13, 7900, 22137, 29901, 530, 1967, 310, 1023, 413, 986, 575, 29889, 13, 2659, 29901, 32000, 32001, 32000, 20355, 915, 445, 1967, 29889, 13, 7900, 22137, 29901, 25519, 22326, 8071, 26357, 28004, 4428, 17554, 20500, 21714, 27834, 4798, 12195, 30379, 5427, 20228, 10473, 14351, 8049, 15605, 14491, 212, 2711, 32000, 21714, 31259, 24368, 19036, 22970, 26083, 19394, 20372, 7672, 9939, 25388, 30533, 8200, 30271, 2114, 24749, 13224, 10603, 21118, 2179, 3759, 16515, 6587, 1287, 23998, 17793, 32001, 5385, 29186, 2165, 11822, 13825, 29732, 17503, 2729, 6722, 2943, 1221, 16043, 18244, 24965, 14383, 19840, 5980, 13488, 28531, 735, 26146, 22504, 2078, 18893, 20372, 7672, 32001, 5385, 29186, 2165, 11822, 13825, 29732, 17503, 2729, 6722, 19551, 220, 10528, 28940, 4453, 28266, 15416, 18693, 8199, 1153, 27706, 29231, 29186, 2165, 11822, 13825, 29732, 17503, 2729, 6722, 19551, 8231, 10739, 31992, 25906, 22254, 23127, 7689, 19614, 1149, 18844, 23037, 28956, 16960, 22664, 6975, 28938, 24002, 11026, 15020, 21964, 16307], ] + +@require_tf +@require_vision +class TFIdeficsModelIntegrationTest(TestCasePlus): + @cached_property + def default_processor(self): + return IdeficsProcessor.from_pretrained(IDEFICS_TINY_RANDOM_MODEL) if is_vision_available() else None + + @slow + def test_inference_natural_language_visual_reasoning(self): + cat_image_path = self.tests_dir / "fixtures/tests_samples/COCO/000000039769.png" + cats_image_obj = Image.open(cat_image_path) # 2 cats + dogs_image_url = "https://huggingface.co/datasets/hf-internal-testing/fixtures_nlvr2/raw/main/image1.jpeg" + + prompts = [ + [ + "User:", + dogs_image_url, + "Describe this image.\nAssistant: An image of two dogs.\n", + "User:", + cats_image_obj, + "Describe this image.\nAssistant:", + ], + [ + "User:", + cats_image_obj, + "Describe this image.\nAssistant: An image of two kittens.\n", + "User:", + dogs_image_url, + "Describe this image.\nAssistant:", + ], + ] + + model = TFIdeficsForVisionText2Text.from_pretrained(IDEFICS_TINY_RANDOM_MODEL, from_pt=True) + processor = self.default_processor + inputs = processor(prompts, return_tensors="tf") + generated_ids = model.generate(**inputs, max_length=100) + 
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True) + + # keep for debugging + for i, t in enumerate(generated_text): + t = bytes(t, "utf-8").decode("unicode_escape") + print(f"{i}:\n{t}\n") + + self.assertListEqual(EXPECTED_GENERATED_IDS[0], generated_ids[0].numpy().tolist()) + self.assertListEqual(EXPECTED_GENERATED_IDS[1], generated_ids[1].numpy().tolist()) diff --git a/tests/models/idefics/test_processor_idefics.py b/tests/models/idefics/test_processor_idefics.py index 2e319413d4c5e2..26dcbb1c0f1566 100644 --- a/tests/models/idefics/test_processor_idefics.py +++ b/tests/models/idefics/test_processor_idefics.py @@ -41,7 +41,7 @@ def setUp(self): self.checkpoint_path = self.get_auto_remove_tmp_dir() - image_processor = IdeficsImageProcessor() + image_processor = IdeficsImageProcessor(return_tensors="pt") tokenizer = LlamaTokenizerFast.from_pretrained("HuggingFaceM4/tiny-random-idefics") processor = IdeficsProcessor(image_processor, tokenizer) @@ -132,7 +132,7 @@ def test_tokenizer_decode(self): image_processor = self.get_image_processor() tokenizer = self.get_tokenizer() - processor = IdeficsProcessor(tokenizer=tokenizer, image_processor=image_processor) + processor = IdeficsProcessor(tokenizer=tokenizer, image_processor=image_processor, return_tensors="pt") predicted_ids = [[1, 4, 5, 8, 1, 0, 8], [3, 4, 3, 1, 1, 8, 9]] @@ -145,7 +145,7 @@ def test_tokenizer_padding(self): image_processor = self.get_image_processor() tokenizer = self.get_tokenizer(padding_side="right") - processor = IdeficsProcessor(tokenizer=tokenizer, image_processor=image_processor) + processor = IdeficsProcessor(tokenizer=tokenizer, image_processor=image_processor, return_tensors="pt") predicted_tokens = [ " Describe this image.\nAssistant:", @@ -156,8 +156,9 @@ def test_tokenizer_padding(self): ([1] * 10) + ([0] * 10), ] prompts = [[prompt] for prompt in self.prepare_prompts()[2]] - max_length = processor(prompts, padding="max_length", truncation=True, max_length=20) - longest = processor(prompts, padding="longest", truncation=True, max_length=30) + + max_length = processor(prompts, padding="max_length", truncation=True, max_length=20, return_tensors="pt") + longest = processor(prompts, padding="longest", truncation=True, max_length=30, return_tensors="pt") decoded_max_length = processor.tokenizer.decode(max_length["input_ids"][-1]) decoded_longest = processor.tokenizer.decode(longest["input_ids"][-1]) @@ -203,7 +204,7 @@ def test_model_input_names(self): processor = IdeficsProcessor(tokenizer=tokenizer, image_processor=image_processor) prompts = self.prepare_prompts() - inputs = processor(prompts, padding="longest") + inputs = processor(prompts, padding="longest", return_tensors="pt") # For now the processor supports only ['pixel_values', 'input_ids', 'attention_mask'] self.assertSetEqual(set(inputs.keys()), set(self.input_keys)) diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index f396875570c98d..2cf272f4aac10d 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -380,7 +380,9 @@ def test_keras_save_load(self): main_layer = main_layer_class(config) symbolic_inputs = { - name: keras.Input(tensor.shape[1:], dtype=tensor.dtype) for name, tensor in inputs_dict.items() + name: keras.Input(tensor.shape[1:], dtype=tensor.dtype) + for name, tensor in inputs_dict.items() + if tf.is_tensor(tensor) } model = keras.Model(symbolic_inputs, outputs=main_layer(symbolic_inputs)) @@ -1689,7 +1691,11 @@ def 
test_dataset_conversion(self): tf_inputs_dict = self._prepare_for_class(inputs_dict, model_class, return_labels=True) if "labels" not in tf_inputs_dict: return # This model isn't giving us labels after all, don't try training with it - tf_inputs_dict = {key: val for key, val in tf_inputs_dict.items() if "head_mask" not in key} + tf_inputs_dict = { + key: val + for key, val in tf_inputs_dict.items() + if "head_mask" not in key and isinstance(val, tf.Tensor) + } tf_inputs_dict["extra_unwanted_column"] = list(tf_inputs_dict.values())[0] # Use a random other tensor input_dataset = Dataset.from_dict(tf_inputs_dict) tf_dataset = model.prepare_tf_dataset( @@ -1853,8 +1859,8 @@ def ids_tensor(shape, vocab_size, rng=None, name=None, dtype=None): def random_attention_mask(shape, rng=None, name=None, dtype=None): attn_mask = ids_tensor(shape, vocab_size=2, rng=None, name=None, dtype=dtype) - # make sure that at least one token is attended to for each batch - attn_mask = tf.concat([attn_mask[:, :-1], tf.ones_like(attn_mask[:, -1:], dtype=dtype)], axis=-1) + # Mark the first token as 1 (matches behaviour of PyTorch/Flax function) + attn_mask = tf.concat([tf.ones_like(attn_mask[:, :1]), attn_mask[:, 1:]], axis=1) return attn_mask
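
For reviewers who want to try the port locally, below is a minimal usage sketch of the new TF classes, mirroring the calls made in `TFIdeficsModelIntegrationTest` above. It assumes the TF weights are converted on the fly from the PyTorch checkpoint via `from_pt=True`, exactly as the tests do; the checkpoint name, prompt contents, and `max_length` value are illustrative only, not part of this diff.

# Minimal sketch of exercising the TF port; based on the integration test above, not an official example.
from transformers import IdeficsProcessor, TFIdeficsForVisionText2Text

checkpoint = "HuggingFaceM4/tiny-random-idefics"  # tiny checkpoint used by the tests above
processor = IdeficsProcessor.from_pretrained(checkpoint)
# TF weights are obtained by converting the PyTorch checkpoint, as the tests do
model = TFIdeficsForVisionText2Text.from_pretrained(checkpoint, from_pt=True)

# Prompts interleave text with images (URLs or PIL images), as in the integration test
prompts = [
    [
        "User:",
        "https://huggingface.co/datasets/hf-internal-testing/fixtures_nlvr2/raw/main/image1.jpeg",
        "Describe this image.\nAssistant:",
    ],
]
inputs = processor(prompts, return_tensors="tf")
generated_ids = model.generate(**inputs, max_length=50)
print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])

Note that `outputs.attentions` comes back as `None` in this port because attention is computed through the new `scaled_dot_product_attention` helper added to `tf_utils.py`, which does not materialize attention weights; the skipped and adjusted attention tests above reflect that design choice.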