From 8282db5cc9cd17a99dab7e6260af0fe11cd02b2a Mon Sep 17 00:00:00 2001 From: Pablo Montalvo <39954772+molbap@users.noreply.github.com> Date: Wed, 22 May 2024 19:37:15 +0200 Subject: [PATCH] Paligemma causal attention mask (#30967) * PaliGemma working causal attention * Formatting * Style * Docstrings + remove commented code * Update docstring for PaliGemma Config * PaliGemma - add separator ind to model/labels * Refactor + docstring paligemma processor method * Style * return token type ids when tokenizing labels * use token type ids when building causal mask * add token type ids to tester * remove separator from config * fix style * don't ignore separator * add processor documentation * simplify tokenization * fix causal mask * style * fix label propagation, revert suffix naming * fix style * fix labels tokenization * [run-slow]paligemma * add eos if suffixes are present * [run-slow]paligemma * [run-slow]paligemma * add misssing tokens to fast version * Apply suggestions from code review Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * fix style * [run-slow]paligemma --------- Co-authored-by: Peter Robicheaux Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- .../models/paligemma/modeling_paligemma.py | 53 +++++++-- .../models/paligemma/processing_paligemma.py | 105 ++++++++++++------ .../paligemma/test_modeling_paligemma.py | 2 + 3 files changed, 113 insertions(+), 47 deletions(-) diff --git a/src/transformers/models/paligemma/modeling_paligemma.py b/src/transformers/models/paligemma/modeling_paligemma.py index b2f2904b2f83ac..532f9e6c80b6a7 100644 --- a/src/transformers/models/paligemma/modeling_paligemma.py +++ b/src/transformers/models/paligemma/modeling_paligemma.py @@ -282,9 +282,14 @@ def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_m self.vocab_size = model_embeds.num_embeddings return model_embeds - def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels): + def _merge_input_ids_with_image_features( + self, image_features, inputs_embeds, input_ids, attention_mask, labels, token_type_ids, cache_position + ): _, _, embed_dim = image_features.shape batch_size, sequence_length = input_ids.shape + dtype, device = inputs_embeds.dtype, inputs_embeds.device + min_dtype = torch.finfo(dtype).min + scaled_image_features = image_features / (self.config.hidden_size**0.5) final_embedding = torch.zeros( batch_size, sequence_length, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device @@ -305,24 +310,43 @@ def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, in image_mask.unsqueeze(-1).expand_as(final_embedding), scaled_image_features ) final_embedding = torch.where(pad_mask_expanded, torch.zeros_like(final_embedding), final_embedding) + if attention_mask is not None: + position_ids = (attention_mask.cumsum(-1)).masked_fill_((attention_mask == 0), 1) + else: + position_ids = None - final_attention_mask_4d = attention_mask.unsqueeze(1).unsqueeze(2) * attention_mask.unsqueeze(1).unsqueeze(-1) - final_attention_mask_4d = final_attention_mask_4d.float().expand( - -1, self.config.text_config.num_key_value_heads, -1, -1 - ) - - # position_ids = torch.arange(0, sequence_length, device=input_ids.device).expand(batch_size, -1) - # position_ids = torch.where(input_ids == self.pad_token_id, torch.ones_like(position_ids), position_ids) - position_ids = (attention_mask.cumsum(-1)).masked_fill_((attention_mask == 0), 1) + if token_type_ids is not None and labels is not None: + # we are training thus we need to create a full mask on the image + prefix but causal on suffix + target_length = cache_position[-1] + 1 + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(inputs_embeds.shape[0], 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + # unmask the prefill + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + token_type_ids[:, None, None, :] == 0, 0 + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) - if labels is not None: final_labels = torch.full( (batch_size, sequence_length), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device ) final_labels = torch.where(input_ids != self.pad_token_id, labels, final_labels) else: + causal_mask = attention_mask.unsqueeze(1).unsqueeze(2) * attention_mask.unsqueeze(1).unsqueeze(-1) + causal_mask = causal_mask.to(dtype).expand(-1, self.config.text_config.num_key_value_heads, -1, -1) final_labels = None - return final_embedding, final_attention_mask_4d, final_labels, position_ids + return final_embedding, causal_mask, final_labels, position_ids @add_start_docstrings_to_model_forward(PALIGEMMA_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=PaliGemmaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) @@ -333,6 +357,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None, + token_type_ids: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, @@ -396,8 +421,10 @@ def forward( selected_image_feature = image_outputs.last_hidden_state image_features = self.multi_modal_projector(selected_image_feature) + if cache_position is None: + cache_position = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features( - image_features, inputs_embeds, input_ids, attention_mask, labels + image_features, inputs_embeds, input_ids, attention_mask, labels, token_type_ids, cache_position ) else: @@ -486,6 +513,7 @@ def prepare_inputs_for_generation( cache_position=None, pixel_values=None, attention_mask=None, + token_type_ids=None, **kwargs, ): past_length = 0 @@ -544,6 +572,7 @@ def prepare_inputs_for_generation( "use_cache": kwargs.get("use_cache"), "attention_mask": attention_mask, "pixel_values": pixel_values, + "token_type_ids": token_type_ids, } ) return model_inputs diff --git a/src/transformers/models/paligemma/processing_paligemma.py b/src/transformers/models/paligemma/processing_paligemma.py index 258954d8569a2a..e4e2d5d15b6db4 100644 --- a/src/transformers/models/paligemma/processing_paligemma.py +++ b/src/transformers/models/paligemma/processing_paligemma.py @@ -23,13 +23,20 @@ from ...feature_extraction_utils import BatchFeature from ...image_utils import ImageInput, is_valid_image from ...processing_utils import ProcessorMixin -from ...tokenization_utils_base import AddedToken, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy +from ...tokenization_utils_base import ( + AddedToken, + PaddingStrategy, + PreTokenizedInput, + TextInput, + TruncationStrategy, +) from ...utils import TensorType logger = logging.getLogger(__name__) IMAGE_TOKEN = "" +EXTRA_TOKENS = ['', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', ''] # fmt: skip # Copied from transformers.models.idefics2.processing_idefics2.is_url @@ -64,7 +71,7 @@ def build_string_from_input(prompt, bos_token, image_seq_len, image_token): image_seq_len (`int`): The length of the image sequence. image_token (`str`): The image token. """ - return f"{image_token * image_seq_len}{bos_token}{prompt}" + return f"{image_token * image_seq_len}{bos_token}{prompt}\n" class PaliGemmaProcessor(ProcessorMixin): @@ -85,7 +92,11 @@ class PaliGemmaProcessor(ProcessorMixin): image_processor_class = "SiglipImageProcessor" tokenizer_class = ("GemmaTokenizer", "GemmaTokenizerFast") - def __init__(self, image_processor=None, tokenizer=None): + def __init__( + self, + image_processor=None, + tokenizer=None, + ): if image_processor is None: raise ValueError("You need to specify an `image_processor`.") if tokenizer is None: @@ -98,7 +109,10 @@ def __init__(self, image_processor=None, tokenizer=None): image_token = AddedToken(IMAGE_TOKEN, normalized=False, special=True) tokens_to_add = {"additional_special_tokens": [image_token]} tokenizer.add_special_tokens(tokens_to_add) + tokenizer.add_tokens(EXTRA_TOKENS) self.image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN) + tokenizer.add_bos_token = False + tokenizer.add_eos_token = False super().__init__(image_processor, tokenizer) @@ -116,12 +130,15 @@ def __call__( image_mean: Optional[Union[float, List[float]]] = None, image_std: Optional[Union[float, List[float]]] = None, data_format: Optional["ChannelDimension"] = "channels_first", # noqa: F821 - input_data_format: Optional[Union[str, "ChannelDimension"]] = None, # noqa: F821 + input_data_format: Optional[ + Union[str, "ChannelDimension"] # noqa: F821 + ] = None, resample: "PILImageResampling" = None, # noqa: F821 do_convert_rgb: bool = None, do_thumbnail: bool = None, do_align_long_axis: bool = None, do_rescale: bool = None, + suffix: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None, ) -> BatchFeature: """ Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` @@ -130,6 +147,25 @@ def __call__( SiglipImageProcessor's [`~SiglipImageProcessor.__call__`] if `images` is not `None`. Please refer to the doctsring of the above two methods for more information. + The usage for PaliGemma fine-tuning preparation is slightly different than usual. suffix passed are suffixes to + the prompt in `text`, and will be placed after the prompt. This is because attention is handled differently for + the prefix and the suffix. For instance, + ```python + image = PIL_cow_image + prompt = "answer en Where is the cow standing?" + suffix = "on the beach" + inputs = processor(text=prompt, images=image, suffix=suffix) + ``` + Here `inputs` will contain the `input_ids` and `token_type_ids` that follow + ```python + inputs["input_ids"][:, 256:] + # tensor([[ 2, 6006, 603, 573, 13910, 9980, 235336, 108, 477, 573, 8318]]) + inputs["token_type_ids"][:, 256:] + tensor([[0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1]]) + ``` + Meaning the last three tokens are of "label" ("suffix") type while the other ones are of "prefix" type. + + Args: text (`str`, `List[str]`, `List[List[str]]`): The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings @@ -161,16 +197,24 @@ def __call__( - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return NumPy `np.ndarray` objects. - `'jax'`: Return JAX `jnp.ndarray` objects. + suffix (`str`, `List[str]`, `List[List[str]]`): + The suffixes or batch of suffixes to be encoded. Only necessary for finetuning. See https://github.com/google-research/big_vision/blob/main/big_vision/configs/proj/paligemma/README.md + for more information. If your prompt is " What is on the image", the suffix corresponds to the expected prediction "a cow sitting on a bench". Returns: [`BatchFeature`]: A [`BatchFeature`] with the following fields: - - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. + - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. If `suffix` + is provided, the `input_ids` will also contain the suffix input ids. - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not `None`). - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. + - **labels** -- Labels compatible with training if `suffix` is not None """ + + return_token_type_ids = True if suffix is not None else False + if images is None: raise ValueError("`images` are expected as arguments to a `PaliGemmaProcessor` instance.") if text is None: @@ -188,6 +232,11 @@ def __call__( text = [text] elif isinstance(text, list) and _is_str_or_image(text[0]): pass + if suffix is not None and _is_str_or_image(suffix): + suffix = [suffix] + if suffix is not None: + suffix = [sfx + self.tokenizer.eos_token for sfx in suffix] + input_strings = [ build_string_from_input( prompt=prompt, @@ -214,36 +263,22 @@ def __call__( if max_length is not None: max_length += self.image_seq_length # max_length has to account for the image tokens - if tokenize_newline_separately: - inputs = self.tokenizer( - input_strings, - add_special_tokens=False, - return_tensors=None, - padding="do_not_pad", - max_length=max_length, - truncation=truncation, - ) - newline_token = self.tokenizer.convert_tokens_to_ids("\n") - concatenated_ids = [ids + [newline_token] for ids in inputs["input_ids"]] - concatenated_attention_masks = [mask + [1] for mask in inputs["attention_mask"]] - - text_inputs = self.tokenizer.pad( - {"input_ids": concatenated_ids, "attention_mask": concatenated_attention_masks}, - max_length=max_length, - padding=padding, - return_tensors=return_tensors, - ) - else: - text_inputs = self.tokenizer( - input_strings, - add_special_tokens=False, - return_tensors=return_tensors, - padding=padding, - max_length=max_length, - truncation=truncation, - ) - - return BatchFeature(data={**text_inputs, "pixel_values": pixel_values}) + inputs = self.tokenizer( + input_strings, + text_pair=suffix, + return_tensors=return_tensors, + padding=padding, + max_length=max_length, + truncation=truncation, + return_token_type_ids=return_token_type_ids, + ) + + return_data = {**inputs, "pixel_values": pixel_values} + + if return_token_type_ids: + labels = inputs["input_ids"].masked_fill(inputs["token_type_ids"] == 0, -100) + return_data.update({"labels": labels}) + return BatchFeature(data=return_data) # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Gemma def batch_decode(self, *args, **kwargs): diff --git a/tests/models/paligemma/test_modeling_paligemma.py b/tests/models/paligemma/test_modeling_paligemma.py index 0653e1057cb667..a1a69766fcb78d 100644 --- a/tests/models/paligemma/test_modeling_paligemma.py +++ b/tests/models/paligemma/test_modeling_paligemma.py @@ -163,6 +163,8 @@ def prepare_config_and_inputs_for_common(self): "pixel_values": pixel_values, "input_ids": input_ids, "attention_mask": attention_mask, + "labels": input_ids, + "token_type_ids": torch.zeros_like(input_ids), } return config, inputs_dict