diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py index ed8ddd3c47dea3..a2ba1c0532933c 100644 --- a/src/transformers/models/blip_2/modeling_blip_2.py +++ b/src/transformers/models/blip_2/modeling_blip_2.py @@ -2186,29 +2186,14 @@ def forward( # step 3: use the language model, conditioned on the query outputs and the prompt language_model_inputs = self.language_projection(query_output) - language_model_attention_mask = torch.ones( - language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device - ) + inputs_embeds = self.language_model.get_input_embeddings()(input_ids) if attention_mask is None: attention_mask = torch.ones_like(input_ids) - # if the model already has "image_token_index" then the input is expanded to account for image embeds - # otherwise we expand manually by concating - if getattr(self.config, "image_token_index", None) is not None: - special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds) - language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs) - else: - logger.warning_once( - "Expanding inputs for image tokens in BLIP-2 should be done in processing. " - "Please follow instruction here (https://gist.github.com/zucchini-nlp/e9f20b054fa322f84ac9311d9ab67042) to update your BLIP-2 model. " - "Using processors without these attributes in the config is deprecated and will throw an error in v4.50." - ) - inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1) - attention_mask = torch.cat( - [language_model_attention_mask, attention_mask.to(language_model_attention_mask.device)], dim=1 - ) + special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds) + language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs) if self.config.use_decoder_only_language_model: outputs = self.language_model( @@ -2304,9 +2289,6 @@ def generate( query_output = query_outputs.last_hidden_state language_model_inputs = self.language_projection(query_output) - language_attention_mask = torch.ones( - language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device - ) if input_ids is None: start_tokens = [self.config.text_config.bos_token_id] @@ -2319,36 +2301,15 @@ def generate( if attention_mask is None: attention_mask = torch.ones_like(input_ids) - # if the model already has "image_token_index" then the input is expanded to account for image embeds - # otherwise we expand manually by concatenating - if getattr(self.config, "image_token_index", None) is not None: - special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds) - inputs_embeds[special_image_mask] = language_model_inputs.flatten() - else: - logger.warning_once( - "Expanding inputs for image tokens in BLIP-2 should be done in processing. " - "Please follow instruction here (https://gist.github.com/zucchini-nlp/e9f20b054fa322f84ac9311d9ab67042) to update your BLIP-2 model. " - "Using processors without these attributes in the config is deprecated and will throw an error in v4.50." 
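For context on the pattern the new forward path relies on: torch.Tensor.masked_scatter copies source elements, in order, into every position where the mask is True. A minimal self-contained sketch with toy shapes and a hypothetical placeholder id (not code from this diff):

import torch

image_token_id = 50                                  # hypothetical placeholder id
batch, num_query_tokens, hidden = 1, 2, 4

# prompt: two image placeholders followed by three text tokens
input_ids = torch.tensor([[image_token_id, image_token_id, 7, 8, 9]])
inputs_embeds = torch.zeros(batch, input_ids.shape[1], hidden)      # stand-in for the LM token embeddings
language_model_inputs = torch.arange(                               # stand-in for the projected Q-Former output
    batch * num_query_tokens * hidden, dtype=torch.float
).reshape(batch, num_query_tokens, hidden)

special_image_mask = (input_ids == image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)

# the two placeholder positions now carry the query embeddings; text positions are untouched
assert torch.equal(inputs_embeds[0, :2], language_model_inputs[0])
assert torch.equal(inputs_embeds[0, 2:], torch.zeros(3, hidden))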
- ) - inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1) - attention_mask = torch.cat( - [language_attention_mask, attention_mask.to(language_attention_mask.device)], dim=1 - ) - - # add image_embeds length to max_length, so that the final max_length in counted only on token embeds - # -1 is to account for the prepended BOS after `generate.` - # TODO (joao, raushan): refactor `generate` to avoid these operations with VLMs - if not self.language_model.config.is_encoder_decoder: - generate_kwargs["max_length"] = ( - generate_kwargs.get("max_length", 20) + language_model_inputs.shape[1] - 1 - ) - generate_kwargs["min_length"] = generate_kwargs.get("min_length", 0) + language_model_inputs.shape[1] + special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds) + inputs_embeds[special_image_mask] = language_model_inputs.flatten() inputs = {"inputs_embeds": inputs_embeds, "attention_mask": attention_mask} if not self.language_model.config.is_encoder_decoder: inputs["input_ids"] = input_ids outputs = self.language_model.generate(**inputs, **generate_kwargs) + return outputs diff --git a/src/transformers/models/blip_2/processing_blip_2.py b/src/transformers/models/blip_2/processing_blip_2.py index 4129920f9b3663..de3e686b177b19 100644 --- a/src/transformers/models/blip_2/processing_blip_2.py +++ b/src/transformers/models/blip_2/processing_blip_2.py @@ -81,6 +81,10 @@ def __init__(self, image_processor, tokenizer, num_query_tokens=None, **kwargs): self.image_token = tokenizer.image_token self.num_query_tokens = num_query_tokens + # We'll add the BOS manually as it has to be after image tokens + tokenizer.add_bos_token = False + self.bos_token = tokenizer.bos_token + super().__init__(image_processor, tokenizer) def __call__( @@ -118,11 +122,9 @@ def __call__( tokenizer_init_kwargs=self.tokenizer.init_kwargs, **kwargs, ) + # BC for explicit return_tensors - if "return_tensors" in output_kwargs["common_kwargs"]: - return_tensors = output_kwargs["common_kwargs"].pop("return_tensors", None) - else: - return_tensors = None + return_tensors = output_kwargs["common_kwargs"].pop("return_tensors", None) encoding = BatchFeature(tensor_type=return_tensors) if text is not None: if isinstance(text, str): @@ -130,37 +132,16 @@ def __call__( elif not isinstance(text, list) and not isinstance(text[0], str): raise ValueError("Invalid input text. Please provide a string, or a list of strings") - text_encoding = {} - - return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None) - _text_encoding = self.tokenizer(text, **output_kwargs["text_kwargs"], return_tensors=None) - output_kwargs["text_kwargs"]["return_tensors"] = return_tensors - - # if we know how many query tokens, expand text inside processor. We need this hacky manipulation - # because BLIP expects image tokens to be at the beginning even before BOS token - if self.num_query_tokens is not None: - image_tokens = self.image_token.content * self.num_query_tokens - image_token_encoding = self.tokenizer( - [image_tokens] * len(text), add_special_tokens=False, return_tensors=None - ) - for k in _text_encoding: - text_encoding[k] = [ - img_encoding + txt_encoding - for img_encoding, txt_encoding in zip(image_token_encoding[k], _text_encoding[k]) - ] - else: - text_encoding = _text_encoding - logger.warning_once( - "Expanding inputs for image tokens in BLIP-2 should be done in processing. 
" - "Please follow instruction here (https://gist.github.com/zucchini-nlp/e9f20b054fa322f84ac9311d9ab67042) to update your BLIP-2 model. " - "Using processors without these attributes in the config is deprecated and will throw an error in v4.50." - ) + # We need this hacky manipulation because BLIP expects image tokens to be at the beginning even before BOS token + image_tokens = self.image_token.content * self.num_query_tokens + text = [f"{image_tokens}{self.bos_token}{sample}" for sample in text] + text_encoding = self.tokenizer(text, **output_kwargs["text_kwargs"]) # cast to desired return tensors type encoding.update(BatchEncoding(text_encoding, tensor_type=return_tensors)) + # add pixel_values encoding. If we also have text_encoding, update image encoding and return it. # else, return the text encoding. - if images is not None: image_encoding = self.image_processor(images, **output_kwargs["images_kwargs"]) encoding.update(image_encoding) diff --git a/src/transformers/models/instructblip/modeling_instructblip.py b/src/transformers/models/instructblip/modeling_instructblip.py index acce24cc42f5d8..5dfccf8320b70a 100644 --- a/src/transformers/models/instructblip/modeling_instructblip.py +++ b/src/transformers/models/instructblip/modeling_instructblip.py @@ -1454,29 +1454,13 @@ def forward( # step 3: use the language model, conditioned on the query outputs and the prompt language_model_inputs = self.language_projection(query_output) - language_model_attention_mask = torch.ones( - language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device - ) inputs_embeds = self.language_model.get_input_embeddings()(input_ids) if attention_mask is None: attention_mask = torch.ones_like(input_ids) - # if the model already has "image_token_index" then the input is expanded to account for image embeds - # otherwise we expand manually by concatenating - if getattr(self.config, "image_token_index", None) is not None: - special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds) - inputs_embeds[special_image_mask] = language_model_inputs.flatten() - else: - logger.warning_once( - "Expanding inputs for image tokens in InstructBLIP should be done in processing. " - "Please follow instruction here (https://gist.github.com/zucchini-nlp/e9f20b054fa322f84ac9311d9ab67042) to update your InstructBLIP model. " - "Using processors without these attributes in the config is deprecated and will throw an error in v4.50." 
- ) - inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1) - attention_mask = torch.cat( - [language_model_attention_mask, attention_mask.to(language_model_attention_mask.device)], dim=1 - ) + special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds) + inputs_embeds[special_image_mask] = language_model_inputs.flatten() if self.config.use_decoder_only_language_model: outputs = self.language_model( @@ -1586,9 +1570,6 @@ def generate( query_output = query_outputs.last_hidden_state[:, : query_tokens.size(1), :] language_model_inputs = self.language_projection(query_output) - language_attention_mask = torch.ones( - language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device - ) if input_ids is None: start_tokens = [self.config.text_config.bos_token_id] @@ -1602,29 +1583,8 @@ def generate( inputs_embeds = self.get_input_embeddings()(input_ids) - # if the model already has "image_token_index" then the input is expanded to account for image embeds - # otherwise we expand manually by concatenating - if getattr(self.config, "image_token_index", None) is not None: - special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds) - inputs_embeds[special_image_mask] = language_model_inputs.flatten() - else: - logger.warning_once( - "Expanding inputs for image tokens in InstructBLIP should be done in processing. " - "Please follow instruction here (https://gist.github.com/zucchini-nlp/e9f20b054fa322f84ac9311d9ab67042) to update your InstructBLIP model. " - "Using processors without these attributes in the config is deprecated and will throw an error in v4.50." - ) - inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1) - attention_mask = torch.cat( - [language_attention_mask, attention_mask.to(language_attention_mask.device)], dim=1 - ) - - # add image_embeds length to max_length, so that the final max_length in counted only on token embeds - # -1 is to account for the prepended BOS after `generate.` - if not self.language_model.config.is_encoder_decoder: - generate_kwargs["max_length"] = ( - generate_kwargs.get("max_length", 20) + language_model_inputs.shape[1] - 1 - ) - generate_kwargs["min_length"] = generate_kwargs.get("min_length", 0) + language_model_inputs.shape[1] + special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds) + inputs_embeds[special_image_mask] = language_model_inputs.flatten() inputs = {"inputs_embeds": inputs_embeds, "attention_mask": attention_mask} if not self.language_model.config.is_encoder_decoder: diff --git a/src/transformers/models/instructblip/processing_instructblip.py b/src/transformers/models/instructblip/processing_instructblip.py index a96d97fb07e1d9..032069a33ce42a 100644 --- a/src/transformers/models/instructblip/processing_instructblip.py +++ b/src/transformers/models/instructblip/processing_instructblip.py @@ -24,7 +24,6 @@ from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack from ...tokenization_utils_base import ( AddedToken, - BatchEncoding, PreTokenizedInput, TextInput, ) @@ -84,6 +83,10 @@ def __init__(self, image_processor, tokenizer, qformer_tokenizer, num_query_toke else: self.image_token = tokenizer.image_token self.num_query_tokens = num_query_tokens + + # We'll add the BOS manually as it has to be after image tokens + tokenizer.add_bos_token = False + self.bos_token = 
tokenizer.bos_token super().__init__(image_processor, tokenizer, qformer_tokenizer) def __call__( @@ -125,34 +128,10 @@ def __call__( elif not isinstance(text, list) and not isinstance(text[0], str): raise ValueError("Invalid input text. Please provide a string, or a list of strings") - # we have to concatenate lists - so we keep track of return_tensors here - return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None) - _text_encoding = self.tokenizer(text, **output_kwargs["text_kwargs"], return_tensors=None) - output_kwargs["text_kwargs"]["return_tensors"] = return_tensors - # if we know how many query tokens, expand text inside processor. We need this hacky manipulation - # because BLIP expects image tokens to be at the beginning even before BOS token - if self.num_query_tokens is not None and images is not None: - text_encoding = {} - image_tokens = self.image_token.content * self.num_query_tokens - image_token_encoding = self.tokenizer( - [image_tokens] * len(text), add_special_tokens=False, return_tensors=None - ) - for k in _text_encoding: - text_encoding[k] = [ - img_encoding + txt_encoding - for img_encoding, txt_encoding in zip(image_token_encoding[k], _text_encoding[k]) - ] - else: - text_encoding = _text_encoding - if images is not None: - logger.warning_once( - "Expanding inputs for image tokens in InstructBLIP should be done in processing. " - "Please follow instruction here (https://gist.github.com/zucchini-nlp/e9f20b054fa322f84ac9311d9ab67042) to update your InstructBLIP model. " - "Using processors without these attributes in the config is deprecated and will throw an error in v4.50." - ) - - # cast to desired return tensors type after concatenating - text_encoding = BatchEncoding(text_encoding, tensor_type=return_tensors) + # We need this hacky manipulation because BLIP expects image tokens to be at the beginning even before BOS token + image_tokens = self.image_token.content * self.num_query_tokens + text_with_images = [f"{image_tokens}{self.bos_token}{sample}" for sample in text] + text_encoding = self.tokenizer(text_with_images, **output_kwargs["text_kwargs"]) encoding.update(text_encoding) qformer_text_encoding = self.qformer_tokenizer(text, **output_kwargs["text_kwargs"]) diff --git a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py index e91b05bc015263..8227fee1a34660 100644 --- a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py +++ b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py @@ -1481,29 +1481,13 @@ def forward( # unbatch inputs back, each video-frame gets `num_query_tokens` seq length language_model_inputs = language_model_inputs.reshape(batch_size, self.config.num_query_tokens * frames, -1) - language_model_attention_mask = torch.ones( - language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device - ) inputs_embeds = self.language_model.get_input_embeddings()(input_ids) if attention_mask is None: attention_mask = torch.ones_like(input_ids) - # if the model already has "video_token_index" then the input is expanded to account for image embeds - # otherwise we expand manually by concatenating - if getattr(self.config, "video_token_index", None) is not None: - special_image_mask = (input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds) - inputs_embeds[special_image_mask] = language_model_inputs.flatten() - else: - logger.warning_once( - "Expanding 
inputs for video tokens in InstructBLIPVideo should be done in processing. " - "Please follow instruction here (https://gist.github.com/zucchini-nlp/65f22892b054dc0d68228af56fbeaac2) to update your InstructBLIPVideo model. " - "Using processors without these attributes in the config is deprecated and will throw an error in v4.47." - ) - inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1) - attention_mask = torch.cat( - [language_model_attention_mask, attention_mask.to(language_model_attention_mask.device)], dim=1 - ) + special_image_mask = (input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds) + inputs_embeds[special_image_mask] = language_model_inputs.flatten() if self.config.use_decoder_only_language_model: outputs = self.language_model( @@ -1621,9 +1605,6 @@ def generate( # unbatch the embeddings back by moving frames to seq-len language_model_inputs = language_model_inputs.reshape(batch_size, self.config.num_query_tokens * frames, -1) - language_attention_mask = torch.ones( - language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device - ) if input_ids is None: start_tokens = [self.config.text_config.bos_token_id] @@ -1637,29 +1618,8 @@ def generate( inputs_embeds = self.get_input_embeddings()(input_ids) - # if the model already has "video_token_index" then the input is expanded to account for image embeds - # otherwise we expand manually by concatenating - if getattr(self.config, "video_token_index", None) is not None: - special_image_mask = (input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds) - inputs_embeds[special_image_mask] = language_model_inputs.flatten() - else: - logger.warning_once( - "Expanding inputs for video tokens in InstructBLIPVideo should be done in processing. " - "Please follow instruction here (https://gist.github.com/zucchini-nlp/65f22892b054dc0d68228af56fbeaac2) to update your InstructBLIPVideo model. " - "Using processors without these attributes in the config is deprecated and will throw an error in v4.47." 
- ) - inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1) - attention_mask = torch.cat( - [language_attention_mask, attention_mask.to(language_attention_mask.device)], dim=1 - ) - - # add image_embeds length to max_length, so that the final max_length in counted only on token embeds - # -1 is to account for the prepended BOS after `generate.` - if not self.language_model.config.is_encoder_decoder: - generate_kwargs["max_length"] = ( - generate_kwargs.get("max_length", 20) + language_model_inputs.shape[1] - 1 - ) - generate_kwargs["min_length"] = generate_kwargs.get("min_length", 0) + language_model_inputs.shape[1] + special_image_mask = (input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds) + inputs_embeds[special_image_mask] = language_model_inputs.flatten() inputs = {"inputs_embeds": inputs_embeds, "attention_mask": attention_mask} if not self.language_model.config.is_encoder_decoder: diff --git a/src/transformers/models/instructblipvideo/modular_instructblipvideo.py b/src/transformers/models/instructblipvideo/modular_instructblipvideo.py index 7184955af3aa56..2b9721913fd3aa 100644 --- a/src/transformers/models/instructblipvideo/modular_instructblipvideo.py +++ b/src/transformers/models/instructblipvideo/modular_instructblipvideo.py @@ -294,29 +294,13 @@ def forward( # unbatch inputs back, each video-frame gets `num_query_tokens` seq length language_model_inputs = language_model_inputs.reshape(batch_size, self.config.num_query_tokens * frames, -1) - language_model_attention_mask = torch.ones( - language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device - ) inputs_embeds = self.language_model.get_input_embeddings()(input_ids) if attention_mask is None: attention_mask = torch.ones_like(input_ids) - # if the model already has "video_token_index" then the input is expanded to account for image embeds - # otherwise we expand manually by concatenating - if getattr(self.config, "video_token_index", None) is not None: - special_image_mask = (input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds) - inputs_embeds[special_image_mask] = language_model_inputs.flatten() - else: - logger.warning_once( - "Expanding inputs for video tokens in InstructBLIPVideo should be done in processing. " - "Please follow instruction here (https://gist.github.com/zucchini-nlp/65f22892b054dc0d68228af56fbeaac2) to update your InstructBLIPVideo model. " - "Using processors without these attributes in the config is deprecated and will throw an error in v4.47." 
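The generate() paths (and the video models' forward) use in-place boolean indexing rather than masked_scatter; the effect is the same provided the number of placeholder positions matches the number of feature vectors. A toy sketch of that assignment together with the per-frame reshape InstructBlipVideo performs (shapes assumed for illustration):

import torch

video_token_id = 99                        # hypothetical placeholder id
batch, frames, num_query_tokens, hidden = 1, 3, 2, 4

# Q-Former output comes out per frame: (batch * frames, num_query_tokens, hidden)
per_frame = torch.randn(batch * frames, num_query_tokens, hidden)
# unbatch: each frame contributes num_query_tokens consecutive positions in the sequence
language_model_inputs = per_frame.reshape(batch, num_query_tokens * frames, hidden)

placeholders = [video_token_id] * (num_query_tokens * frames)
input_ids = torch.tensor([placeholders + [11, 12]])                 # placeholders, then text tokens
inputs_embeds = torch.zeros(batch, input_ids.shape[1], hidden)

special_image_mask = (input_ids == video_token_id).unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds[special_image_mask] = language_model_inputs.flatten()

assert torch.equal(inputs_embeds[0, : num_query_tokens * frames], language_model_inputs[0])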
- ) - inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1) - attention_mask = torch.cat( - [language_model_attention_mask, attention_mask.to(language_model_attention_mask.device)], dim=1 - ) + special_image_mask = (input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds) + inputs_embeds[special_image_mask] = language_model_inputs.flatten() if self.config.use_decoder_only_language_model: outputs = self.language_model( @@ -434,9 +418,6 @@ def generate( # unbatch the embeddings back by moving frames to seq-len language_model_inputs = language_model_inputs.reshape(batch_size, self.config.num_query_tokens * frames, -1) - language_attention_mask = torch.ones( - language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device - ) if input_ids is None: start_tokens = [self.config.text_config.bos_token_id] @@ -450,29 +431,8 @@ def generate( inputs_embeds = self.get_input_embeddings()(input_ids) - # if the model already has "video_token_index" then the input is expanded to account for image embeds - # otherwise we expand manually by concatenating - if getattr(self.config, "video_token_index", None) is not None: - special_image_mask = (input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds) - inputs_embeds[special_image_mask] = language_model_inputs.flatten() - else: - logger.warning_once( - "Expanding inputs for video tokens in InstructBLIPVideo should be done in processing. " - "Please follow instruction here (https://gist.github.com/zucchini-nlp/65f22892b054dc0d68228af56fbeaac2) to update your InstructBLIPVideo model. " - "Using processors without these attributes in the config is deprecated and will throw an error in v4.47." - ) - inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1) - attention_mask = torch.cat( - [language_attention_mask, attention_mask.to(language_attention_mask.device)], dim=1 - ) - - # add image_embeds length to max_length, so that the final max_length in counted only on token embeds - # -1 is to account for the prepended BOS after `generate.` - if not self.language_model.config.is_encoder_decoder: - generate_kwargs["max_length"] = ( - generate_kwargs.get("max_length", 20) + language_model_inputs.shape[1] - 1 - ) - generate_kwargs["min_length"] = generate_kwargs.get("min_length", 0) + language_model_inputs.shape[1] + special_image_mask = (input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds) + inputs_embeds[special_image_mask] = language_model_inputs.flatten() inputs = {"inputs_embeds": inputs_embeds, "attention_mask": attention_mask} if not self.language_model.config.is_encoder_decoder: diff --git a/src/transformers/models/instructblipvideo/processing_instructblipvideo.py b/src/transformers/models/instructblipvideo/processing_instructblipvideo.py index 1d4e59e26b4621..63931a60a68062 100644 --- a/src/transformers/models/instructblipvideo/processing_instructblipvideo.py +++ b/src/transformers/models/instructblipvideo/processing_instructblipvideo.py @@ -24,7 +24,6 @@ from ...processing_utils import ProcessorMixin from ...tokenization_utils_base import ( AddedToken, - BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, @@ -69,6 +68,10 @@ def __init__(self, image_processor, tokenizer, qformer_tokenizer, num_query_toke else: self.video_token = tokenizer.video_token self.num_query_tokens = num_query_tokens + + # We'll add the BOS manually as it has to be after 
image tokens + tokenizer.add_bos_token = False + self.bos_token = tokenizer.bos_token super().__init__(image_processor, tokenizer, qformer_tokenizer) def __call__( @@ -108,8 +111,13 @@ def __call__( elif not isinstance(text, list) and not isinstance(text[0], str): raise ValueError("Invalid input text. Please provide a string, or a list of strings") - _text_encoding = self.tokenizer( - text=text, + # We need this hacky manipulation because BLIP expects image tokens to be at the beginning even before BOS token + # InstrucBLIP works with 4 frames only + video_tokens = self.video_token.content * self.num_query_tokens * 4 + text_with_videos = [f"{video_tokens}{self.bos_token}{sample}" for sample in text] + + text_encoding = self.tokenizer( + text=text_with_videos, add_special_tokens=add_special_tokens, padding=padding, truncation=truncation, @@ -127,32 +135,6 @@ def __call__( **kwargs, ) - # if we know how many query tokens, expand text inside processor. We need this hacky manipulation - # because BLIP expects image tokens to be at the beginning even before BOS token - if self.num_query_tokens is not None and images is not None: - text_encoding = {} - video_tokens = ( - self.video_token.content * self.num_query_tokens * 4 - ) # InstrucBLIP works with 4 frames only - video_token_encoding = self.tokenizer( - [video_tokens] * len(text), add_special_tokens=False, return_tensors=None - ) - for k in _text_encoding: - text_encoding[k] = [ - img_encoding + txt_encoding - for img_encoding, txt_encoding in zip(video_token_encoding[k], _text_encoding[k]) - ] - else: - text_encoding = _text_encoding - if images is not None: - logger.warning_once( - "Expanding inputs for video tokens in InstructBLIPVideo should be done in processing. " - "Please follow instruction here (https://gist.github.com/zucchini-nlp/65f22892b054dc0d68228af56fbeaac2) to update your InstructBLIPVideo model. " - "Using processors without these attributes in the config is deprecated and will throw an error in v4.47." 
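For the video processor above, the placeholder count is num_query_tokens per frame times the fixed four frames the model consumes, and BOS again comes after the video tokens. A small sketch with assumed numbers:

# assumed values for illustration only
video_token = "<video>"
bos_token = "</s>"
num_query_tokens = 32
num_frames = 4                                       # InstructBlipVideo works with 4 frames only

video_tokens = video_token * num_query_tokens * num_frames          # 128 placeholders
text = ["What is happening in this clip?"]
prompts = [f"{video_tokens}{bos_token}{sample}" for sample in text]

assert prompts[0].count(video_token) == num_query_tokens * num_frames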
- ) - - # cast to desired return tensors type after concatenating - text_encoding = BatchEncoding(text_encoding, tensor_type=return_tensors) encoding.update(text_encoding) qformer_text_encoding = self.qformer_tokenizer( text=text, diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py index e8536ee50f94bb..07b61321d0de83 100644 --- a/src/transformers/models/llava/modeling_llava.py +++ b/src/transformers/models/llava/modeling_llava.py @@ -461,18 +461,9 @@ def forward( "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" ) - legacy_processing = False if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) - # if the number of image tokens is more than image embeddings seq length, then prob we expanded it in processing - # not very reliable, but we don't expect one to actually pass 500+ images for one prompt - # In case we're in decoding stage, legacy behavior is checked by presence of pixel values even if use_cache=True - legacy_processing = ( - (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length - ) or (input_ids.shape[-1] == 1 and pixel_values is not None) - - image_features = None if pixel_values is not None: image_features = self.get_image_features( pixel_values=pixel_values, @@ -480,56 +471,8 @@ def forward( vision_feature_select_strategy=vision_feature_select_strategy, ) - if legacy_processing: - logger.warning_once( - "Expanding inputs for image tokens in LLaVa should be done in processing. " - "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly " - "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. " - "Using processors without these attributes in the config is deprecated and will throw an error in v4.50." 
- ) - # prefill stage vs decoding stage (legacy behavior copied) - if input_ids.shape[1] != 1: - inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features( - image_features, inputs_embeds, input_ids, attention_mask, labels - ) - cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device) - else: - # Retrieve the first layer to inspect the logits and mask out the hidden states - # that are set to 0 - first_layer_past_key_value = past_key_values[0][0][:, :, :, 0] - - # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941 - batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0) - - # Get the target length - target_length = input_ids.shape[1] - past_length = first_layer_past_key_value.shape[-1] - - extended_attention_mask = torch.ones( - (attention_mask.shape[0], past_length), - dtype=attention_mask.dtype, - device=attention_mask.device, - ) - - # Filter out only the tokens that can be un-attended, this can happen - # if one uses Llava + Fused modules where the cache on the - # first iteration is already big enough, or if one passes custom cache - valid_indices = non_attended_tokens < extended_attention_mask.size(-1) - new_batch_index = batch_index[valid_indices] - new_non_attended_tokens = non_attended_tokens[valid_indices] - - # Zero-out the places where we don't need to attend - extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0 - - attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1) - position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 - cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[-target_length:] - - # TODO: @raushan retain only the new behavior after v4.47 - elif image_features is not None: n_image_tokens = (input_ids == self.config.image_token_index).sum().item() n_image_features = image_features.shape[0] * image_features.shape[1] - if n_image_tokens != n_image_features: raise ValueError( f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" diff --git a/src/transformers/models/llava/processing_llava.py b/src/transformers/models/llava/processing_llava.py index 08caa3d1d8a75a..d024aa9a56d70f 100644 --- a/src/transformers/models/llava/processing_llava.py +++ b/src/transformers/models/llava/processing_llava.py @@ -154,27 +154,17 @@ def __call__( # try to expand inputs in processing if we have the necessary parts prompt_strings = text if image_inputs.get("pixel_values") is not None: - if self.patch_size is not None and self.vision_feature_select_strategy is not None: - # Replace the image token with the expanded image token sequence - pixel_values = image_inputs["pixel_values"] - height, width = get_image_size(to_numpy_array(pixel_values[0])) - num_image_tokens = (height // self.patch_size) * ( - width // self.patch_size - ) + self.num_additional_image_tokens - if self.vision_feature_select_strategy == "default": - num_image_tokens -= self.num_additional_image_tokens - - prompt_strings = [] - for sample in text: - sample = sample.replace(self.image_token, self.image_token * num_image_tokens) - prompt_strings.append(sample) - else: - logger.warning_once( - "Expanding inputs for image tokens in LLaVa should be done in processing. 
" - "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly " - "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. " - "Using processors without these attributes in the config is deprecated and will throw an error in v4.50." - ) + # Replace the image token with the expanded image token sequence + pixel_values = image_inputs["pixel_values"] + height, width = get_image_size(to_numpy_array(pixel_values[0])) + num_image_tokens = (height // self.patch_size) * (width // self.patch_size) + 1 + if self.vision_feature_select_strategy == "default": + num_image_tokens -= 1 + + prompt_strings = [] + for sample in text: + sample = sample.replace(self.image_token, self.image_token * num_image_tokens) + prompt_strings.append(sample) text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"]) return BatchFeature(data={**text_inputs, **image_inputs}) diff --git a/src/transformers/models/llava_next/modeling_llava_next.py b/src/transformers/models/llava_next/modeling_llava_next.py index 269663c7d6141a..a860e8978ccb24 100644 --- a/src/transformers/models/llava_next/modeling_llava_next.py +++ b/src/transformers/models/llava_next/modeling_llava_next.py @@ -835,18 +835,9 @@ def forward( "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" ) - legacy_processing = False if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) - # if the number of image tokens is more than image embeddings seq length, then prob we expanded it in processing - # not very reliable, but we don't expect one to actually pass 500+ images for one prompt - # In case we're in decoding stage, legacy behavior is checked by presence of pixel values even if use_cache=True - legacy_processing = ( - (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length - ) or (input_ids.shape[-1] == 1 and pixel_values is not None) - - image_features = None if pixel_values is not None and pixel_values.size(0) > 0: image_features = self.get_image_features( pixel_values, @@ -863,58 +854,6 @@ def forward( image_newline=self.image_newline, ) - if legacy_processing: - logger.warning_once( - "Expanding inputs for image tokens in LLaVa-NeXT should be done in processing. " - "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly " - "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. " - "Using processors without these attributes in the config is deprecated and will throw an error in v4.50." 
- ) - if input_ids.shape[1] != 1: - inputs_embeds = inputs_embeds.to(image_features.dtype) - inputs_embeds, attention_mask, position_ids, labels, _ = self._merge_input_ids_with_image_features( - image_features, - feature_lens, - inputs_embeds, - input_ids, - attention_mask, - position_ids, - labels=labels, - ) - cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device) - else: - # Retrieve the first layer to inspect the logits and mask out the hidden states - # that are set to 0 - first_layer_past_key_value = past_key_values[0][0][:, :, :, 0] - - # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941 - batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0) - - # Get the target length - target_length = input_ids.shape[1] - past_length = first_layer_past_key_value.shape[-1] - - extended_attention_mask = torch.ones( - (attention_mask.shape[0], past_length), - dtype=attention_mask.dtype, - device=attention_mask.device, - ) - - # Filter out only the tokens that can be un-attended, this can happen - # if one uses Llava + Fused modules where the cache on the - # first iteration is already big enough, or if one passes custom cache - valid_indices = non_attended_tokens < extended_attention_mask.size(-1) - new_batch_index = batch_index[valid_indices] - new_non_attended_tokens = non_attended_tokens[valid_indices] - - # Zero-out the places where we don't need to attend - extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0 - attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1) - position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 - cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[-target_length:] - - # TODO: @raushan retain only the new behavior after v4.47 - elif image_features is not None: n_image_tokens = (input_ids == self.config.image_token_index).sum().item() n_image_features = image_features.shape[0] if n_image_tokens != n_image_features: diff --git a/src/transformers/models/llava_next/processing_llava_next.py b/src/transformers/models/llava_next/processing_llava_next.py index 38173cbd861fc1..6889a57148553e 100644 --- a/src/transformers/models/llava_next/processing_llava_next.py +++ b/src/transformers/models/llava_next/processing_llava_next.py @@ -149,30 +149,19 @@ def __call__( prompt_strings = text if image_inputs: - if self.patch_size is None or self.vision_feature_select_strategy is None: - logger.warning_once( - "Expanding inputs for image tokens in LLaVa-NeXT should be done in processing. " - "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly " - "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. " - "Using processors without these attributes in the config is deprecated and will throw an error in v4.50." 
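The LLaVa-NeXT processor hunk just below expands each image token to a per-image count (image sizes differ, so the counts differ), going through a temporary "<placeholder>" marker so the while loop does not re-match tokens it has already expanded. A standalone sketch of that two-step swap, with a made-up counting helper standing in for _get_number_of_features:

image_token = "<image>"                              # hypothetical placeholder string

def fake_number_of_features(orig_height, orig_width):
    # stand-in for LlavaNextProcessor._get_number_of_features, which depends on the patch grid
    return 4 if orig_height >= orig_width else 3

text = [f"{image_token} first image, {image_token} second image"]
image_sizes = iter([(672, 336), (336, 672)])

prompt_strings = []
for sample in text:
    while image_token in sample:
        orig_height, orig_width = next(image_sizes)
        num_image_tokens = fake_number_of_features(orig_height, orig_width)
        # expand into a marker first; expanding directly into image_token would be matched again
        sample = sample.replace(image_token, "<placeholder>" * num_image_tokens, 1)
    prompt_strings.append(sample)
prompt_strings = [sample.replace("<placeholder>", image_token) for sample in prompt_strings]

assert prompt_strings[0].count(image_token) == 7     # 4 for the first image + 3 for the second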
- ) - else: - image_sizes = iter(image_inputs["image_sizes"]) - height, width = get_image_size(to_numpy_array(image_inputs["pixel_values"][0][0])) - prompt_strings = [] - for sample in text: - while self.image_token in sample: - image_size = next(image_sizes) - if not isinstance(image_size, (list, tuple)): - # cast to list to avoid numerical precision errors when calculating unpadding - image_size = image_size.tolist() - orig_height, orig_width = image_size - num_image_tokens = self._get_number_of_features(orig_height, orig_width, height, width) - if self.vision_feature_select_strategy == "default": - num_image_tokens -= self.num_additional_image_tokens - sample = sample.replace(self.image_token, "<placeholder>" * num_image_tokens, 1) - prompt_strings.append(sample) - prompt_strings = [sample.replace("<placeholder>", self.image_token) for sample in prompt_strings] + image_sizes = iter(image_inputs["image_sizes"]) + height, width = get_image_size(to_numpy_array(image_inputs["pixel_values"][0][0])) + prompt_strings = [] + for sample in text: + while self.image_token in sample: + image_size = next(image_sizes) + orig_height, orig_width = image_size + num_image_tokens = self._get_number_of_features(orig_height, orig_width, height, width) + if self.vision_feature_select_strategy == "default": + num_image_tokens -= 1 + sample = sample.replace(self.image_token, "<placeholder>" * num_image_tokens, 1) + prompt_strings.append(sample) + prompt_strings = [sample.replace("<placeholder>", self.image_token) for sample in prompt_strings] text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"]) diff --git a/src/transformers/models/llava_next_video/modeling_llava_next_video.py b/src/transformers/models/llava_next_video/modeling_llava_next_video.py index b0a20d6c5ccd93..1fd083558ccc25 100644 --- a/src/transformers/models/llava_next_video/modeling_llava_next_video.py +++ b/src/transformers/models/llava_next_video/modeling_llava_next_video.py @@ -909,25 +909,9 @@ def forward( "and must specify either one" ) - legacy_processing = False if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) - # if the number of image/video tokens is more than image embeddings seq length, then prob we expanded it in processing - # not very reliable, but we don't expect one to actually pass 500+ images for one prompt - img_token_not_enough = (input_ids == self.config.image_token_index).sum( - 1 - ).max() < self.config.image_seq_length - video_token_not_enough = (input_ids == self.config.video_token_index).sum( - 1 - ).max() < self.config.video_seq_length - inputs_not_expanded = (img_token_not_enough and pixel_values is not None) or ( - video_token_not_enough and pixel_values_videos is not None - ) - pixels_present = input_ids.shape[-1] == 1 and (pixel_values is not None or pixel_values_videos is not None) - legacy_processing = inputs_not_expanded or pixels_present - - image_features = feature_lens = None if pixel_values is not None and pixel_values.size(0) > 0: image_features = self.get_image_features( pixel_values, @@ -942,7 +926,21 @@ def forward( image_newline=self.image_newline, ) - video_features = video_feature_lens = None + n_image_tokens = (input_ids == self.config.image_token_index).sum().item() + n_image_features = image_features.shape[0] + if n_image_tokens != n_image_features: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + special_image_mask = ( + (input_ids == self.config.image_token_index) + .unsqueeze(-1) + .expand_as(inputs_embeds) + 
.to(inputs_embeds.device) + ) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + if pixel_values_videos is not None and pixel_values_videos.size(0) > 0: video_features = self.get_video_features( pixel_values_videos, @@ -954,95 +952,20 @@ def forward( video_features = torch.cat(video_features, dim=0) video_feature_lens = torch.tensor(video_feature_lens, dtype=torch.long, device=video_features.device) - if legacy_processing: - logger.warning_once( - "Expanding inputs for image.video tokens in LLaVa-NeXT-Video should be done in processing. " - "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly " - "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. " - "Using processors without these attributes in the config is deprecated and will throw an error in v4.47." - ) - if input_ids.shape[1] != 1: - iterator = ( - (image_features, feature_lens, self.config.image_token_index), - (video_features, video_feature_lens, self.config.video_token_index), - ) - for features, lens, special_token in iterator: - if features is not None: - ( - inputs_embeds, - attention_mask, - position_ids, - labels, - input_ids, - ) = self._merge_input_ids_with_image_features( - features, - lens, - inputs_embeds, - input_ids, - attention_mask, - position_ids, - labels=labels, - image_token_index=special_token, - ) - cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device) - else: - # Retrieve the first layer to inspect the logits and mask out the hidden states that are set to 0 - first_layer_past_key_value = past_key_values[0][0][:, :, :, 0] - # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941 - batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0) - # Get the target length - target_length = input_ids.shape[1] - past_length = first_layer_past_key_value.shape[-1] - extended_attention_mask = torch.ones( - (attention_mask.shape[0], past_length), - dtype=attention_mask.dtype, - device=attention_mask.device, - ) - # Filter out only the tokens that can be un-attended, this can happen - # if one uses Llava + Fused modules where the cache on the - # first iteration is already big enough, or if one passes custom cache - valid_indices = non_attended_tokens < extended_attention_mask.size(-1) - new_batch_index = batch_index[valid_indices] - new_non_attended_tokens = non_attended_tokens[valid_indices] - # Zero-out the places where we don't need to attend - extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0 - attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1) - position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 - cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[-target_length:] - - # TODO: @raushan retain only the new behavior after v4.47 - else: - if image_features is not None: - n_image_tokens = (input_ids == self.config.image_token_index).sum().item() - n_image_features = image_features.shape[0] - - if n_image_tokens != n_image_features: - raise ValueError( - f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" - ) - special_image_mask = ( - (input_ids == 
self.config.image_token_index) - .unsqueeze(-1) - .expand_as(inputs_embeds) - .to(inputs_embeds.device) - ) - image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) - if video_features is not None: - n_video_tokens = (input_ids == self.config.video_token_index).sum().item() - n_video_features = video_features.shape[0] - if n_video_tokens != n_video_features: - raise ValueError( - f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" - ) - special_image_mask = ( - (input_ids == self.config.video_token_index) - .unsqueeze(-1) - .expand_as(inputs_embeds) - .to(inputs_embeds.device) + n_video_tokens = (input_ids == self.config.video_token_index).sum().item() + n_video_features = video_features.shape[0] + if n_video_tokens != n_video_features: + raise ValueError( + f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" ) - video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features) + special_image_mask = ( + (input_ids == self.config.video_token_index) + .unsqueeze(-1) + .expand_as(inputs_embeds) + .to(inputs_embeds.device) + ) + video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features) outputs = self.language_model( attention_mask=attention_mask, diff --git a/src/transformers/models/llava_next_video/modular_llava_next_video.py b/src/transformers/models/llava_next_video/modular_llava_next_video.py index 3d6431d7ea29ba..d21d29c71630ff 100644 --- a/src/transformers/models/llava_next_video/modular_llava_next_video.py +++ b/src/transformers/models/llava_next_video/modular_llava_next_video.py @@ -426,25 +426,9 @@ def forward( "and must specify either one" ) - legacy_processing = False if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) - # if the number of image/video tokens is more than image embeddings seq length, then prob we expanded it in processing - # not very reliable, but we don't expect one to actually pass 500+ images for one prompt - img_token_not_enough = (input_ids == self.config.image_token_index).sum( - 1 - ).max() < self.config.image_seq_length - video_token_not_enough = (input_ids == self.config.video_token_index).sum( - 1 - ).max() < self.config.video_seq_length - inputs_not_expanded = (img_token_not_enough and pixel_values is not None) or ( - video_token_not_enough and pixel_values_videos is not None - ) - pixels_present = input_ids.shape[-1] == 1 and (pixel_values is not None or pixel_values_videos is not None) - legacy_processing = inputs_not_expanded or pixels_present - - image_features = feature_lens = None if pixel_values is not None and pixel_values.size(0) > 0: image_features = self.get_image_features( pixel_values, @@ -459,7 +443,21 @@ def forward( image_newline=self.image_newline, ) - video_features = video_feature_lens = None + n_image_tokens = (input_ids == self.config.image_token_index).sum().item() + n_image_features = image_features.shape[0] + if n_image_tokens != n_image_features: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + special_image_mask = ( + (input_ids == self.config.image_token_index) + .unsqueeze(-1) + .expand_as(inputs_embeds) + 
.to(inputs_embeds.device) + ) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + if pixel_values_videos is not None and pixel_values_videos.size(0) > 0: video_features = self.get_video_features( pixel_values_videos, @@ -471,95 +469,20 @@ def forward( video_features = torch.cat(video_features, dim=0) video_feature_lens = torch.tensor(video_feature_lens, dtype=torch.long, device=video_features.device) - if legacy_processing: - logger.warning_once( - "Expanding inputs for image.video tokens in LLaVa-NeXT-Video should be done in processing. " - "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly " - "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. " - "Using processors without these attributes in the config is deprecated and will throw an error in v4.47." - ) - if input_ids.shape[1] != 1: - iterator = ( - (image_features, feature_lens, self.config.image_token_index), - (video_features, video_feature_lens, self.config.video_token_index), + n_video_tokens = (input_ids == self.config.video_token_index).sum().item() + n_video_features = video_features.shape[0] + if n_video_tokens != n_video_features: + raise ValueError( + f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" ) - for features, lens, special_token in iterator: - if features is not None: - ( - inputs_embeds, - attention_mask, - position_ids, - labels, - input_ids, - ) = self._merge_input_ids_with_image_features( - features, - lens, - inputs_embeds, - input_ids, - attention_mask, - position_ids, - labels=labels, - image_token_index=special_token, - ) - cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device) - else: - # Retrieve the first layer to inspect the logits and mask out the hidden states that are set to 0 - first_layer_past_key_value = past_key_values[0][0][:, :, :, 0] - # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941 - batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0) - # Get the target length - target_length = input_ids.shape[1] - past_length = first_layer_past_key_value.shape[-1] - extended_attention_mask = torch.ones( - (attention_mask.shape[0], past_length), - dtype=attention_mask.dtype, - device=attention_mask.device, - ) - # Filter out only the tokens that can be un-attended, this can happen - # if one uses Llava + Fused modules where the cache on the - # first iteration is already big enough, or if one passes custom cache - valid_indices = non_attended_tokens < extended_attention_mask.size(-1) - new_batch_index = batch_index[valid_indices] - new_non_attended_tokens = non_attended_tokens[valid_indices] - # Zero-out the places where we don't need to attend - extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0 - attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1) - position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 - cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[-target_length:] - - # TODO: @raushan retain only the new behavior after v4.47 - else: - if image_features is not None: - n_image_tokens = (input_ids == 
self.config.image_token_index).sum().item() - n_image_features = image_features.shape[0] - - if n_image_tokens != n_image_features: - raise ValueError( - f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" - ) - special_image_mask = ( - (input_ids == self.config.image_token_index) - .unsqueeze(-1) - .expand_as(inputs_embeds) - .to(inputs_embeds.device) - ) - image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) - if video_features is not None: - n_video_tokens = (input_ids == self.config.video_token_index).sum().item() - n_video_features = video_features.shape[0] - if n_video_tokens != n_video_features: - raise ValueError( - f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" - ) - special_image_mask = ( - (input_ids == self.config.video_token_index) - .unsqueeze(-1) - .expand_as(inputs_embeds) - .to(inputs_embeds.device) - ) - video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features) + special_image_mask = ( + (input_ids == self.config.video_token_index) + .unsqueeze(-1) + .expand_as(inputs_embeds) + .to(inputs_embeds.device) + ) + video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features) outputs = self.language_model( attention_mask=attention_mask, diff --git a/src/transformers/models/llava_next_video/processing_llava_next_video.py b/src/transformers/models/llava_next_video/processing_llava_next_video.py index 65195b77240721..464f84cff757b1 100644 --- a/src/transformers/models/llava_next_video/processing_llava_next_video.py +++ b/src/transformers/models/llava_next_video/processing_llava_next_video.py @@ -173,48 +173,33 @@ def __call__( elif not isinstance(text, list) and not isinstance(text[0], str): raise ValueError("Invalid input text. Please provide a string, or a list of strings") - if self.patch_size is None or self.vision_feature_select_strategy is None: - logger.warning_once( - "Expanding inputs for image/video tokens in LLaVa-NeXT-Video should be done in processing. " - "Please add `patch_size`, `num_additional_image_tokens` and `vision_feature_select_strategy` to the model's processing config or set directly " - "with `processor.patch_size = {{patch_size}}`, `processor.num_additional_image_tokens = {{num_additional_image_tokens}}` " - "and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. " - "Using processors without these attributes in the config is deprecated and will throw an error in v4.47." 
- ) - else: - # images expand taking into account num_of_patches in each image - if image_inputs: - image_sizes = iter(image_inputs["image_sizes"]) - height, width = get_image_size(to_numpy_array(image_inputs["pixel_values"][0][0])) - prompt_strings = [] - for sample in text: - while self.image_token in sample: - image_size = next(image_sizes) - if not isinstance(image_size, (list, tuple)): - # cast to list to avoid numerical precision errors when calculating unpadding - image_size = image_size.tolist() - orig_height, orig_width = image_size - num_image_tokens = self._get_number_of_features(orig_height, orig_width, height, width) - if self.vision_feature_select_strategy == "default": - num_image_tokens -= self.num_additional_image_tokens - sample = sample.replace(self.image_token, "<placeholder>" * num_image_tokens, 1) - prompt_strings.append(sample) - text = [sample.replace("<placeholder>", self.image_token) for sample in prompt_strings] - - # videos are easier, simply get frames and multiply - if videos_inputs: - one_video = to_numpy_array(videos_inputs.get("pixel_values_videos")[0]) - height, width = get_image_size(one_video[0]) - num_frames = one_video.shape[0] # frame dim is always after batch dim + if image_inputs: + image_sizes = iter(image_inputs["image_sizes"]) + height, width = get_image_size(to_numpy_array(image_inputs["pixel_values"][0][0])) + prompt_strings = [] + for sample in text: + while self.image_token in sample: + image_size = next(image_sizes) + orig_height, orig_width = image_size + num_image_tokens = self._get_number_of_features(orig_height, orig_width, height, width) + if self.vision_feature_select_strategy == "default": + num_image_tokens -= 1 + sample = sample.replace(self.image_token, "<placeholder>" * num_image_tokens, 1) + prompt_strings.append(sample) + text = [sample.replace("<placeholder>", self.image_token) for sample in prompt_strings] - # no `self.num_additional_image_tokens` added because video always has a default feature selection strategy - num_image_tokens = (height // self.patch_size) * (width // self.patch_size) - num_video_tokens = num_image_tokens // 4 * num_frames # divide by 4 needed for avg pooling layer - prompt_strings = [] - for sample in text: - sample = sample.replace(self.video_token, self.video_token * num_video_tokens) - prompt_strings.append(sample) - text = prompt_strings + # videos are easier, simply get frames and multiply + if videos_inputs: + one_video = to_numpy_array(videos_inputs.get("pixel_values_videos")[0]) + height, width = get_image_size(one_video[0]) + num_frames = one_video.shape[0] # frame dim is always after batch dim + num_image_tokens = (height // self.patch_size) * (width // self.patch_size) + num_video_tokens = num_image_tokens // 4 * num_frames # divide by 4 needed for avg pooling layer + prompt_strings = [] + for sample in text: + sample = sample.replace(self.video_token, self.video_token * num_video_tokens) + prompt_strings.append(sample) + text = prompt_strings text_inputs = self.tokenizer( text, diff --git a/src/transformers/models/video_llava/modeling_video_llava.py b/src/transformers/models/video_llava/modeling_video_llava.py index 30adcb6ab5c089..c7f2cea69b9adc 100644 --- a/src/transformers/models/video_llava/modeling_video_llava.py +++ b/src/transformers/models/video_llava/modeling_video_llava.py @@ -538,127 +538,49 @@ def forward( "time, and must specify either one" ) - legacy_processing = False if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) - # if the number of image/video tokens is more than image embeddings seq length, 
then prob we expanded it in processing - # not very reliable, but we don't expect one to actually pass 500+ images for one prompt - img_token_not_enough = (input_ids == self.config.image_token_index).sum( - 1 - ).max() < self.config.image_seq_length - video_token_not_enough = (input_ids == self.config.video_token_index).sum( - 1 - ).max() < self.config.video_seq_length - inputs_not_expanded = (img_token_not_enough and pixel_values_images is not None) or ( - video_token_not_enough and pixel_values_videos is not None - ) - pixels_present = input_ids.shape[-1] == 1 and ( - pixel_values_images is not None or pixel_values_videos is not None - ) - legacy_processing = inputs_not_expanded or pixels_present - - image_features = None if pixel_values_images is not None: image_features = self.get_image_features( pixel_values_images, vision_feature_layer=vision_feature_layer, vision_feature_select_strategy=vision_feature_select_strategy, ) + n_image_tokens = (input_ids == self.config.image_token_index).sum().item() + n_image_features = image_features.shape[0] * image_features.shape[1] + if n_image_tokens != n_image_features: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + special_image_mask = ( + (input_ids == self.config.image_token_index) + .unsqueeze(-1) + .expand_as(inputs_embeds) + .to(inputs_embeds.device) + ) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) - video_features = None - num_frames = 0 if pixel_values_videos is not None: video_features, num_frames = self.get_video_features( pixel_values_videos=pixel_values_videos, vision_feature_layer=vision_feature_layer ) - if legacy_processing: - logger.warning_once( - "Expanding inputs for image tokens in Video-LLaVa should be done in processing. " - "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly " - "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. " - "Using processors without these attributes in the config is deprecated and will throw an error in v4.50." 
- ) - if input_ids.shape[1] != 1: - for features, frames in ((image_features, 1), (video_features, num_frames)): - if features is not None: - ( - inputs_embeds, - attention_mask, - labels, - position_ids, - input_ids, - ) = self._merge_input_ids_with_visual_features( - features, - inputs_embeds, - input_ids, - attention_mask, - labels, - num_frames=frames, - ) - cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device) - else: - # Retrieve the first layer to inspect the logits and mask out the hidden states - # that are set to 0 - first_layer_past_key_value = past_key_values[0][0][:, :, :, 0] - - # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941 - batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0) - target_length = input_ids.shape[1] - past_length = first_layer_past_key_value.shape[-1] - extended_attention_mask = torch.ones( - (attention_mask.shape[0], past_length), - dtype=attention_mask.dtype, - device=attention_mask.device, - ) - - # Filter out only the tokens that can be un-attended, this can happen - # if one uses Llava + Fused modules where the cache on the - # first iteration is already big enough, or if one passes custom cache - valid_indices = non_attended_tokens < extended_attention_mask.size(-1) - new_batch_index = batch_index[valid_indices] - new_non_attended_tokens = non_attended_tokens[valid_indices] - - # Zero-out the places where we don't need to attend - extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0 - attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1) - position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 - cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[-target_length:] - - # TODO: @raushan retain only the new behavior after v4.47 - else: - if pixel_values_images is not None: - n_image_tokens = (input_ids == self.config.image_token_index).sum().item() - n_image_features = image_features.shape[0] * image_features.shape[1] - if n_image_tokens != n_image_features: - raise ValueError( - f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" - ) - special_image_mask = ( - (input_ids == self.config.image_token_index) - .unsqueeze(-1) - .expand_as(inputs_embeds) - .to(inputs_embeds.device) - ) - image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) - - if pixel_values_videos is not None: - n_video_tokens = (input_ids == self.config.video_token_index).sum().item() - n_video_features = video_features.shape[0] * video_features.shape[1] - if n_video_tokens != n_video_features: - raise ValueError( - f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" - ) - special_image_mask = ( - (input_ids == self.config.video_token_index) - .unsqueeze(-1) - .expand_as(inputs_embeds) - .to(inputs_embeds.device) + n_video_tokens = (input_ids == self.config.video_token_index).sum().item() + n_video_features = video_features.shape[0] * video_features.shape[1] + if n_video_tokens != n_video_features: + raise ValueError( + f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" ) - video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype) - 
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features) + special_image_mask = ( + (input_ids == self.config.video_token_index) + .unsqueeze(-1) + .expand_as(inputs_embeds) + .to(inputs_embeds.device) + ) + video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features) outputs = self.language_model( attention_mask=attention_mask, diff --git a/src/transformers/models/video_llava/processing_video_llava.py b/src/transformers/models/video_llava/processing_video_llava.py index 3e1884271efe2b..4ef19387036b3c 100644 --- a/src/transformers/models/video_llava/processing_video_llava.py +++ b/src/transformers/models/video_llava/processing_video_llava.py @@ -158,16 +158,8 @@ def __call__( raise ValueError("Invalid input text. Please provide a string, or a list of strings") prompt_strings = text - if encoded_images is not None and (self.patch_size is None or self.vision_feature_select_strategy is None): - logger.warning_once( - "Expanding inputs for image tokens in Video-LLaVa should be done in processing. " - "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set " - "directly with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = " - "{{vision_feature_select_strategy}}`. Using processors without these attributes in the config is " - "deprecated and will throw an error in v4.50." - ) - # Replace the image/video tokens with the expanded token sequence - elif encoded_images is not None: + + if encoded_images is not None: if "pixel_values_images" in encoded_images.keys(): height, width = get_image_size(to_numpy_array(encoded_images.get("pixel_values_images")[0])) num_frames = 1 diff --git a/src/transformers/models/vipllava/modeling_vipllava.py b/src/transformers/models/vipllava/modeling_vipllava.py index b45325d2194e24..9b9054b6bb36a4 100644 --- a/src/transformers/models/vipllava/modeling_vipllava.py +++ b/src/transformers/models/vipllava/modeling_vipllava.py @@ -455,68 +455,14 @@ def forward( "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" ) - legacy_processing = False if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) - # if the number of image tokens is more than image embeddings seq length, then prob we expanded it in processing - # not very reliable, but we don't expect one to actually pass 500+ images for one prompt - # In case we're in decoding stage, legacy behavior is checked by presence of pixel values even if use_cache=True - legacy_processing = ( - (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length - ) or (input_ids.shape[-1] == 1 and pixel_values is not None) - - image_features = None if pixel_values is not None: image_features = self.get_image_features( pixel_values=pixel_values, vision_feature_layers=vision_feature_layers ) - if legacy_processing: - logger.warning_once( - "Expanding inputs for image tokens in VipLLaVa should be done in processing. " - "Please add `patch_size` and `vision_feature_select_strategy` to the model's image processing config. " - "Using processors without these attributes in the config is deprecated and will throw an error in v4.50." 
- ) - # prefill stage vs decoding stage (legacy behavior copied) - if input_ids.shape[1] != 1: - inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features( - image_features, inputs_embeds, input_ids, attention_mask, labels - ) - cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device) - else: - # Retrieve the first layer to inspect the logits and mask out the hidden states - # that are set to 0 - first_layer_past_key_value = past_key_values[0][0][:, :, :, 0] - - # Sum all dimensions of head_dim (-1) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941 - batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0) - - target_length = input_ids.shape[1] - past_length = first_layer_past_key_value.shape[-1] - - extended_attention_mask = torch.ones( - (attention_mask.shape[0], past_length), - dtype=attention_mask.dtype, - device=attention_mask.device, - ) - - # Filter out only the tokens that can be un-attended, this can happen - # in the case one uses Llava + Fused modules where the cache on the - # first iteration is already big enough, or if one passes custom cache - valid_indices = non_attended_tokens < extended_attention_mask.size(-1) - new_batch_index = batch_index[valid_indices] - new_non_attended_tokens = non_attended_tokens[valid_indices] - - # Zero-out the places where we don't need to attend - extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0 - - attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1) - position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 - cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[-target_length:] - - # TODO: @raushan retain only the new behavior after v4.47 - elif image_features is not None: n_image_tokens = (input_ids == self.config.image_token_index).sum().item() n_image_features = image_features.shape[0] * image_features.shape[1] if n_image_tokens != n_image_features: diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 76ab793e3a36c0..ba1e10e2e486e2 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -458,6 +458,7 @@ def test_greedy_generate(self): if model.config.is_encoder_decoder: self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1) else: + print(output_generate.shape[-1], self.max_new_tokens, inputs_dict["input_ids"].shape[-1]) self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]) @pytest.mark.generate diff --git a/tests/models/blip_2/test_modeling_blip_2.py b/tests/models/blip_2/test_modeling_blip_2.py index a1ea708efd665b..cc18ec31a861dd 100644 --- a/tests/models/blip_2/test_modeling_blip_2.py +++ b/tests/models/blip_2/test_modeling_blip_2.py @@ -1004,7 +1004,14 @@ def get_config(self): # this model tester uses an encoder-decoder language model (T5) class Blip2ModelTester: def __init__( - self, parent, vision_kwargs=None, qformer_kwargs=None, text_kwargs=None, is_training=True, num_query_tokens=10 + self, + parent, + vision_kwargs=None, + qformer_kwargs=None, + text_kwargs=None, + is_training=True, + num_query_tokens=10, + image_token_index=4, ): if vision_kwargs is None: vision_kwargs = {} @@ -1021,6 +1028,7 @@ def __init__( self.seq_length = self.text_model_tester.seq_length # need seq_length for common tests self.is_training = is_training 
self.num_query_tokens = num_query_tokens + self.image_token_index = image_token_index def prepare_config_and_inputs(self): _, pixel_values = self.vision_model_tester.prepare_config_and_inputs() @@ -1043,6 +1051,7 @@ def get_config(self): qformer_config=self.qformer_model_tester.get_config(), text_config=self.text_model_tester.get_config(), num_query_tokens=self.num_query_tokens, + image_token_index=self.image_token_index, ) def create_and_check_for_conditional_generation( diff --git a/tests/models/blip_2/test_processor_blip_2.py b/tests/models/blip_2/test_processor_blip_2.py index 7eb5bedc2be7a7..54f9de09861b78 100644 --- a/tests/models/blip_2/test_processor_blip_2.py +++ b/tests/models/blip_2/test_processor_blip_2.py @@ -47,6 +47,9 @@ def get_tokenizer(self, **kwargs): def get_image_processor(self, **kwargs): return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor + def prepare_processor_dict(self): + return {"num_query_tokens": 1} + def tearDown(self): shutil.rmtree(self.tmpdirname) @@ -81,26 +84,12 @@ def test_image_processor(self): for key in input_feat_extract.keys(): self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2) - def test_tokenizer(self): - image_processor = self.get_image_processor() - tokenizer = self.get_tokenizer() - - processor = Blip2Processor(tokenizer=tokenizer, image_processor=image_processor) - - input_str = "lower newer" - - encoded_processor = processor(text=input_str) - - encoded_tok = tokenizer(input_str, return_token_type_ids=False) - - for key in encoded_tok.keys(): - self.assertListEqual(encoded_tok[key], encoded_processor[key][0]) - def test_processor(self): image_processor = self.get_image_processor() tokenizer = self.get_tokenizer() + processor_kwargs = self.prepare_processor_dict() - processor = Blip2Processor(tokenizer=tokenizer, image_processor=image_processor) + processor = Blip2Processor(tokenizer=tokenizer, image_processor=image_processor, **processor_kwargs) input_str = "lower newer" image_input = self.prepare_image_inputs() @@ -116,8 +105,9 @@ def test_processor(self): def test_tokenizer_decode(self): image_processor = self.get_image_processor() tokenizer = self.get_tokenizer() + processor_kwargs = self.prepare_processor_dict() - processor = Blip2Processor(tokenizer=tokenizer, image_processor=image_processor) + processor = Blip2Processor(tokenizer=tokenizer, image_processor=image_processor, **processor_kwargs) predicted_ids = [[1, 4, 5, 8, 1, 0, 8], [3, 4, 3, 1, 1, 8, 9]] @@ -129,8 +119,9 @@ def test_tokenizer_decode(self): def test_model_input_names(self): image_processor = self.get_image_processor() tokenizer = self.get_tokenizer() + processor_kwargs = self.prepare_processor_dict() - processor = Blip2Processor(tokenizer=tokenizer, image_processor=image_processor) + processor = Blip2Processor(tokenizer=tokenizer, image_processor=image_processor, **processor_kwargs) input_str = "lower newer" image_input = self.prepare_image_inputs() diff --git a/tests/models/instructblip/test_processor_instructblip.py b/tests/models/instructblip/test_processor_instructblip.py index ffec4b01112c2f..72586924992cd3 100644 --- a/tests/models/instructblip/test_processor_instructblip.py +++ b/tests/models/instructblip/test_processor_instructblip.py @@ -58,6 +58,9 @@ def get_image_processor(self, **kwargs): def get_qformer_tokenizer(self, **kwargs): return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).qformer_tokenizer + def prepare_processor_dict(self): + return {"num_query_tokens": 1} + def 
tearDown(self): shutil.rmtree(self.tmpdirname) @@ -87,9 +90,13 @@ def test_image_processor(self): image_processor = self.get_image_processor() tokenizer = self.get_tokenizer() qformer_tokenizer = self.get_qformer_tokenizer() + processor_kwargs = self.prepare_processor_dict() processor = InstructBlipProcessor( - tokenizer=tokenizer, image_processor=image_processor, qformer_tokenizer=qformer_tokenizer + tokenizer=tokenizer, + image_processor=image_processor, + qformer_tokenizer=qformer_tokenizer, + **processor_kwargs, ) image_input = self.prepare_image_inputs() @@ -100,35 +107,17 @@ def test_image_processor(self): for key in input_feat_extract.keys(): self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2) - def test_tokenizer(self): - image_processor = self.get_image_processor() - tokenizer = self.get_tokenizer() - qformer_tokenizer = self.get_qformer_tokenizer() - - processor = InstructBlipProcessor( - tokenizer=tokenizer, image_processor=image_processor, qformer_tokenizer=qformer_tokenizer - ) - - input_str = ["lower newer"] - - encoded_processor = processor(text=input_str) - - encoded_tokens = tokenizer(input_str, return_token_type_ids=False) - encoded_tokens_qformer = qformer_tokenizer(input_str, return_token_type_ids=False) - - for key in encoded_tokens.keys(): - self.assertListEqual(encoded_tokens[key], encoded_processor[key]) - - for key in encoded_tokens_qformer.keys(): - self.assertListEqual(encoded_tokens_qformer[key], encoded_processor["qformer_" + key]) - def test_processor(self): image_processor = self.get_image_processor() tokenizer = self.get_tokenizer() qformer_tokenizer = self.get_qformer_tokenizer() + processor_kwargs = self.prepare_processor_dict() processor = InstructBlipProcessor( - tokenizer=tokenizer, image_processor=image_processor, qformer_tokenizer=qformer_tokenizer + tokenizer=tokenizer, + image_processor=image_processor, + qformer_tokenizer=qformer_tokenizer, + **processor_kwargs, ) input_str = "lower newer" @@ -149,9 +138,13 @@ def test_tokenizer_decode(self): image_processor = self.get_image_processor() tokenizer = self.get_tokenizer() qformer_tokenizer = self.get_qformer_tokenizer() + processor_kwargs = self.prepare_processor_dict() processor = InstructBlipProcessor( - tokenizer=tokenizer, image_processor=image_processor, qformer_tokenizer=qformer_tokenizer + tokenizer=tokenizer, + image_processor=image_processor, + qformer_tokenizer=qformer_tokenizer, + **processor_kwargs, ) predicted_ids = [[1, 4, 5, 8, 1, 0, 8], [3, 4, 3, 1, 1, 8, 9]] @@ -165,9 +158,13 @@ def test_model_input_names(self): image_processor = self.get_image_processor() tokenizer = self.get_tokenizer() qformer_tokenizer = self.get_qformer_tokenizer() + processor_kwargs = self.prepare_processor_dict() processor = InstructBlipProcessor( - tokenizer=tokenizer, image_processor=image_processor, qformer_tokenizer=qformer_tokenizer + tokenizer=tokenizer, + image_processor=image_processor, + qformer_tokenizer=qformer_tokenizer, + **processor_kwargs, ) input_str = "lower newer" diff --git a/tests/models/instructblipvideo/test_processor_instructblipvideo.py b/tests/models/instructblipvideo/test_processor_instructblipvideo.py index d613d878223213..443a6e12b82338 100644 --- a/tests/models/instructblipvideo/test_processor_instructblipvideo.py +++ b/tests/models/instructblipvideo/test_processor_instructblipvideo.py @@ -59,6 +59,9 @@ def get_image_processor(self, **kwargs): def get_qformer_tokenizer(self, **kwargs): return AutoProcessor.from_pretrained(self.tmpdirname, 
**kwargs).qformer_tokenizer + def prepare_processor_dict(self): + return {"num_query_tokens": 1} + def tearDown(self): shutil.rmtree(self.tmpdirname) @@ -88,9 +91,13 @@ def test_image_processor(self): image_processor = self.get_image_processor() tokenizer = self.get_tokenizer() qformer_tokenizer = self.get_qformer_tokenizer() + processor_kwargs = self.prepare_processor_dict() processor = InstructBlipVideoProcessor( - tokenizer=tokenizer, image_processor=image_processor, qformer_tokenizer=qformer_tokenizer + tokenizer=tokenizer, + image_processor=image_processor, + qformer_tokenizer=qformer_tokenizer, + **processor_kwargs, ) image_input = self.prepare_image_inputs() @@ -101,35 +108,17 @@ def test_image_processor(self): for key in input_feat_extract.keys(): self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2) - def test_tokenizer(self): - image_processor = self.get_image_processor() - tokenizer = self.get_tokenizer() - qformer_tokenizer = self.get_qformer_tokenizer() - - processor = InstructBlipVideoProcessor( - tokenizer=tokenizer, image_processor=image_processor, qformer_tokenizer=qformer_tokenizer - ) - - input_str = ["lower newer"] - - encoded_processor = processor(text=input_str) - - encoded_tokens = tokenizer(input_str, return_token_type_ids=False) - encoded_tokens_qformer = qformer_tokenizer(input_str, return_token_type_ids=False) - - for key in encoded_tokens.keys(): - self.assertListEqual(encoded_tokens[key], encoded_processor[key]) - - for key in encoded_tokens_qformer.keys(): - self.assertListEqual(encoded_tokens_qformer[key], encoded_processor["qformer_" + key]) - def test_processor(self): image_processor = self.get_image_processor() tokenizer = self.get_tokenizer() qformer_tokenizer = self.get_qformer_tokenizer() + processor_kwargs = self.prepare_processor_dict() processor = InstructBlipVideoProcessor( - tokenizer=tokenizer, image_processor=image_processor, qformer_tokenizer=qformer_tokenizer + tokenizer=tokenizer, + image_processor=image_processor, + qformer_tokenizer=qformer_tokenizer, + **processor_kwargs, ) input_str = "lower newer" @@ -150,9 +139,13 @@ def test_tokenizer_decode(self): image_processor = self.get_image_processor() tokenizer = self.get_tokenizer() qformer_tokenizer = self.get_qformer_tokenizer() + processor_kwargs = self.prepare_processor_dict() processor = InstructBlipVideoProcessor( - tokenizer=tokenizer, image_processor=image_processor, qformer_tokenizer=qformer_tokenizer + tokenizer=tokenizer, + image_processor=image_processor, + qformer_tokenizer=qformer_tokenizer, + **processor_kwargs, ) predicted_ids = [[1, 4, 5, 8, 1, 0, 8], [3, 4, 3, 1, 1, 8, 9]] @@ -166,9 +159,13 @@ def test_model_input_names(self): image_processor = self.get_image_processor() tokenizer = self.get_tokenizer() qformer_tokenizer = self.get_qformer_tokenizer() + processor_kwargs = self.prepare_processor_dict() processor = InstructBlipVideoProcessor( - tokenizer=tokenizer, image_processor=image_processor, qformer_tokenizer=qformer_tokenizer + tokenizer=tokenizer, + image_processor=image_processor, + qformer_tokenizer=qformer_tokenizer, + **processor_kwargs, ) input_str = "lower newer" diff --git a/tests/models/llava/test_modeling_llava.py b/tests/models/llava/test_modeling_llava.py index 3d08ab35e0f630..8a796597ab7bb5 100644 --- a/tests/models/llava/test_modeling_llava.py +++ b/tests/models/llava/test_modeling_llava.py @@ -330,9 +330,6 @@ def test_small_model_integration_test(self): raw_image = Image.open(requests.get(image_file, stream=True).raw) 
inputs = self.processor(images=raw_image, text=prompt, return_tensors="pt") - EXPECTED_INPUT_IDS = torch.tensor([[1, 32000, 28705, 13, 11123, 28747, 1824, 460, 272, 1722,315, 1023, 347, 13831, 925, 684, 739, 315, 3251, 456,1633, 28804, 13, 4816, 8048, 12738, 28747]]) # fmt: skip - self.assertTrue(torch.equal(inputs["input_ids"], EXPECTED_INPUT_IDS)) - output = model.generate(**inputs, max_new_tokens=20) EXPECTED_DECODED_TEXT = "\nUSER: What are the things I should be cautious about when I visit this place?\nASSISTANT: When visiting this place, there are a few things one should be cautious about. Firstly," # fmt: skip @@ -509,32 +506,18 @@ def test_llava_merge_inputs_error_bug(self): # This is a reproducer of https://github.com/huggingface/transformers/pull/28333 and makes sure it does not happen anymore model_id = "llava-hf/llava-1.5-7b-hf" model = LlavaForConditionalGeneration.from_pretrained(model_id, load_in_4bit=True) + processor = AutoProcessor.from_pretrained(model_id) - # Simulate some user inputs - pixel_values = torch.randn( - (1, 3, 336, 336), - dtype=torch.float, - device=torch_device, - ) - input_ids = torch.tensor( - [ - [32001, 32001, 1, 15043, 7084, 32000, 29871, 13, 7900], - ], - dtype=torch.long, - device=torch_device, - ) - attention_mask = torch.tensor( - [[0, 0, 1, 1, 1, 1, 1, 1, 1]], - dtype=torch.long, - device=torch_device, - ) + prompt = "USER: \nDescribe the imageASSISTANT:" + image_file = "http://images.cocodataset.org/val2017/000000039769.jpg" + + raw_image = Image.open(requests.get(image_file, stream=True).raw) + inputs = processor(images=raw_image, text=prompt, return_tensors="pt").to(torch_device, torch.float16) # Make sure that the loss is properly computed loss = model( - pixel_values=pixel_values, - input_ids=input_ids, - attention_mask=attention_mask, - labels=input_ids, + **inputs, + labels=inputs.input_ids.clone(), ).loss loss.backward() diff --git a/tests/models/llava/test_processor_llava.py b/tests/models/llava/test_processor_llava.py index d3a66a16df9a64..3e6c1a9a969f0b 100644 --- a/tests/models/llava/test_processor_llava.py +++ b/tests/models/llava/test_processor_llava.py @@ -50,7 +50,7 @@ def tearDown(self): shutil.rmtree(self.tmpdirname) def prepare_processor_dict(self): - return {"chat_template": "dummy_template"} + return {"chat_template": "dummy_template", "patch_size": 3, "vision_feature_select_strategy": "default"} @unittest.skip( "Skip because the model has no processor kwargs except for chat template and" diff --git a/tests/models/llava_next/test_processor_llava_next.py b/tests/models/llava_next/test_processor_llava_next.py index 45faa24526305c..234e4791100054 100644 --- a/tests/models/llava_next/test_processor_llava_next.py +++ b/tests/models/llava_next/test_processor_llava_next.py @@ -27,7 +27,7 @@ if is_vision_available(): - from transformers import CLIPImageProcessor + from transformers import LlavaNextImageProcessor @require_vision @@ -37,7 +37,7 @@ class LlavaNextProcessorTest(ProcessorTesterMixin, unittest.TestCase): def setUp(self): self.tmpdirname = tempfile.mkdtemp() - image_processor = CLIPImageProcessor() + image_processor = LlavaNextImageProcessor() tokenizer = LlamaTokenizerFast.from_pretrained("huggyllama/llama-7b") processor_kwargs = self.prepare_processor_dict() processor = LlavaNextProcessor(image_processor, tokenizer, **processor_kwargs) @@ -50,7 +50,7 @@ def get_image_processor(self, **kwargs): return LlavaNextProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor def prepare_processor_dict(self): - return 
{"chat_template": "dummy_template"} + return {"chat_template": "dummy_template", "patch_size": 3, "vision_feature_select_strategy": "default"} @unittest.skip( "Skip because the model has no processor kwargs except for chat template and" diff --git a/tests/models/video_llava/test_modeling_video_llava.py b/tests/models/video_llava/test_modeling_video_llava.py index 14b079665ab6d6..eb4b9e3ed083d7 100644 --- a/tests/models/video_llava/test_modeling_video_llava.py +++ b/tests/models/video_llava/test_modeling_video_llava.py @@ -127,7 +127,6 @@ def __init__( self.num_image_tokens = (vision_config["image_size"] // vision_config["patch_size"]) ** 2 self.num_video_tokens = (self.num_image_tokens + 1) * self.num_frames self.seq_length = seq_length + self.num_image_tokens + self.num_video_tokens - self.encoder_seq_length = self.seq_length def get_config(self): return VideoLlavaConfig( @@ -185,22 +184,6 @@ def prepare_config_and_inputs_for_common(self): } return config, inputs_dict - def prepare_config_and_inputs_for_batched_test(self): - config_and_inputs = self.prepare_config_and_inputs() - config, _, pixel_values_videos = config_and_inputs - input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 1) + 1 - attention_mask = input_ids.ne(1).to(torch_device) - - # make sure no other special tokens are set - input_ids[(input_ids == 0) | (input_ids == 1)] = 3 - input_ids[:, 0] = config.video_token_index - inputs_dict = { - "pixel_values_videos": pixel_values_videos, - "input_ids": input_ids, - "attention_mask": attention_mask, - } - return config, inputs_dict - @require_torch class VideoLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): @@ -339,7 +322,7 @@ def recursive_check(batched_object, single_row_object, model_name, key): ), ) - config, batched_input = self.model_tester.prepare_config_and_inputs_for_batched_test() + config, batched_input = self.model_tester.prepare_config_and_inputs_for_common() for model_class in self.all_model_classes: config.output_hidden_states = True diff --git a/tests/models/vipllava/test_modeling_vipllava.py b/tests/models/vipllava/test_modeling_vipllava.py index 4f501fc10a028f..8436a6655d9f25 100644 --- a/tests/models/vipllava/test_modeling_vipllava.py +++ b/tests/models/vipllava/test_modeling_vipllava.py @@ -322,7 +322,7 @@ def test_small_model_integration_test(self): outputs = model.generate(**inputs, max_new_tokens=10) - EXPECTED_OUTPUT = "USER: \nCan you please describe this image?\nASSISTANT: The image features a brown and white cat sitting on" + EXPECTED_OUTPUT = "USER: \nCan you please describe this image?\nASSISTANT: The image features a brown and white cat sitting on" self.assertEqual(processor.decode(outputs[0], skip_special_tokens=True), EXPECTED_OUTPUT) @slow @@ -331,32 +331,18 @@ def test_vipllava_merge_inputs_error_bug(self): # This is a reproducer of https://github.com/huggingface/transformers/pull/28333 and makes sure it does not happen anymore model_id = "llava-hf/vip-llava-7b-hf" model = VipLlavaForConditionalGeneration.from_pretrained(model_id, load_in_4bit=True) + processor = AutoProcessor.from_pretrained(model_id) - # Simulate some user inputs - pixel_values = torch.randn( - (1, 3, 336, 336), - dtype=torch.float, - device=torch_device, - ) - input_ids = torch.tensor( - [ - [32001, 32001, 1, 15043, 7084, 32000, 29871, 13, 7900], - ], - dtype=torch.long, - device=torch_device, - ) - attention_mask = torch.tensor( - [[0, 0, 1, 1, 1, 1, 1, 1, 1]], - dtype=torch.long, - 
device=torch_device, - ) + url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/compel-neg.png" + image = Image.open(requests.get(url, stream=True).raw) + prompt = "USER: \nCan you please describe this image?\nASSISTANT:" + + inputs = processor(prompt, image, return_tensors="pt").to(torch_device, torch.float16) # Make sure that the loss is properly computed loss = model( - pixel_values=pixel_values, - input_ids=input_ids, - attention_mask=attention_mask, - labels=input_ids, + **inputs, + labels=inputs.input_ids.clone(), ).loss loss.backward() diff --git a/tests/test_processing_common.py b/tests/test_processing_common.py index 4c4f6fac49813f..d81386f9e0c7db 100644 --- a/tests/test_processing_common.py +++ b/tests/test_processing_common.py @@ -174,8 +174,9 @@ def test_tokenizer_defaults_preserved_by_kwargs(self): self.skipTest(f"image_processor attribute not present in {self.processor_class}") processor_components = self.prepare_components() processor_components["tokenizer"] = self.get_component("tokenizer", max_length=117, padding="max_length") + processor_kwargs = self.prepare_processor_dict() - processor = self.processor_class(**processor_components) + processor = self.processor_class(**processor_components, **processor_kwargs) self.skip_processor_without_typed_kwargs(processor) input_str = self.prepare_text_inputs() image_input = self.prepare_image_inputs() @@ -195,8 +196,9 @@ def test_image_processor_defaults_preserved_by_image_kwargs(self): "image_processor", do_rescale=True, rescale_factor=-1 ) processor_components["tokenizer"] = self.get_component("tokenizer", max_length=117, padding="max_length") + processor_kwargs = self.prepare_processor_dict() - processor = self.processor_class(**processor_components) + processor = self.processor_class(**processor_components, **processor_kwargs) self.skip_processor_without_typed_kwargs(processor) input_str = self.prepare_text_inputs() @@ -210,8 +212,9 @@ def test_kwargs_overrides_default_tokenizer_kwargs(self): self.skipTest(f"image_processor attribute not present in {self.processor_class}") processor_components = self.prepare_components() processor_components["tokenizer"] = self.get_component("tokenizer", padding="longest") + processor_kwargs = self.prepare_processor_dict() - processor = self.processor_class(**processor_components) + processor = self.processor_class(**processor_components, **processor_kwargs) self.skip_processor_without_typed_kwargs(processor) input_str = self.prepare_text_inputs() image_input = self.prepare_image_inputs() @@ -228,8 +231,9 @@ def test_kwargs_overrides_default_image_processor_kwargs(self): "image_processor", do_rescale=True, rescale_factor=1 ) processor_components["tokenizer"] = self.get_component("tokenizer", max_length=117, padding="max_length") + processor_kwargs = self.prepare_processor_dict() - processor = self.processor_class(**processor_components) + processor = self.processor_class(**processor_components, **processor_kwargs) self.skip_processor_without_typed_kwargs(processor) input_str = self.prepare_text_inputs() @@ -242,7 +246,8 @@ def test_unstructured_kwargs(self): if "image_processor" not in self.processor_class.attributes: self.skipTest(f"image_processor attribute not present in {self.processor_class}") processor_components = self.prepare_components() - processor = self.processor_class(**processor_components) + processor_kwargs = self.prepare_processor_dict() + processor = self.processor_class(**processor_components, **processor_kwargs) 
self.skip_processor_without_typed_kwargs(processor) input_str = self.prepare_text_inputs() @@ -264,7 +269,8 @@ def test_unstructured_kwargs_batched(self): if "image_processor" not in self.processor_class.attributes: self.skipTest(f"image_processor attribute not present in {self.processor_class}") processor_components = self.prepare_components() - processor = self.processor_class(**processor_components) + processor_kwargs = self.prepare_processor_dict() + processor = self.processor_class(**processor_components, **processor_kwargs) self.skip_processor_without_typed_kwargs(processor) input_str = self.prepare_text_inputs(batch_size=2) @@ -289,7 +295,8 @@ def test_doubly_passed_kwargs(self): if "image_processor" not in self.processor_class.attributes: self.skipTest(f"image_processor attribute not present in {self.processor_class}") processor_components = self.prepare_components() - processor = self.processor_class(**processor_components) + processor_kwargs = self.prepare_processor_dict() + processor = self.processor_class(**processor_components, **processor_kwargs) self.skip_processor_without_typed_kwargs(processor) input_str = [self.prepare_text_inputs()] @@ -307,7 +314,8 @@ def test_structured_kwargs_nested(self): if "image_processor" not in self.processor_class.attributes: self.skipTest(f"image_processor attribute not present in {self.processor_class}") processor_components = self.prepare_components() - processor = self.processor_class(**processor_components) + processor_kwargs = self.prepare_processor_dict() + processor = self.processor_class(**processor_components, **processor_kwargs) self.skip_processor_without_typed_kwargs(processor) input_str = self.prepare_text_inputs() @@ -330,7 +338,8 @@ def test_structured_kwargs_nested_from_dict(self): if "image_processor" not in self.processor_class.attributes: self.skipTest(f"image_processor attribute not present in {self.processor_class}") processor_components = self.prepare_components() - processor = self.processor_class(**processor_components) + processor_kwargs = self.prepare_processor_dict() + processor = self.processor_class(**processor_components, **processor_kwargs) self.skip_processor_without_typed_kwargs(processor) input_str = self.prepare_text_inputs() image_input = self.prepare_image_inputs() diff --git a/tests/utils/tiny_model_summary.json b/tests/utils/tiny_model_summary.json index f27f720ec3d593..6c36448e5aba7e 100644 --- a/tests/utils/tiny_model_summary.json +++ b/tests/utils/tiny_model_summary.json @@ -626,7 +626,7 @@ "model_classes": [ "Blip2ForConditionalGeneration" ], - "sha": "35e1ef43da3554af62eb29a7b3dbbef3f3bef48e" + "sha": "d0de11fd1f8ca481231c07ee0934924be96cb281" }, "Blip2Model": { "tokenizer_classes": [ diff --git a/utils/check_config_attributes.py b/utils/check_config_attributes.py index 1c81c08fd845b1..0b8ce82e54c204 100644 --- a/utils/check_config_attributes.py +++ b/utils/check_config_attributes.py @@ -295,6 +295,9 @@ def check_attribute_being_used(config_class, attributes, default_value, source_s "unk_index", "mask_index", "image_token_index", # for VLMs + "video_token_index", + "image_seq_length", + "video_seq_length", "image_size", "use_cache", "out_features",
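The modeling changes above all converge on the same two-step pattern: the processor expands each image/video placeholder token to as many copies as the vision tower will emit features, and the model then checks the counts and scatters the projected features into exactly those embedding positions with `masked_scatter`. The snippet below is a minimal, self-contained sketch of that flow with toy sizes; `IMAGE_TOKEN_ID`, `NUM_IMAGE_FEATURES`, and the helper names are made up for illustration and are not part of the transformers API.

# Toy sketch of the "expand placeholders in the processor, scatter features in the model"
# pattern; only torch is required, and all ids/sizes below are hypothetical.
import torch
import torch.nn as nn

IMAGE_TOKEN_ID = 7        # stand-in for config.image_token_index
NUM_IMAGE_FEATURES = 4    # features the (toy) vision tower produces per image
HIDDEN_SIZE = 8
VOCAB_SIZE = 16


def expand_prompt(token_ids):
    """Processor side: repeat every image placeholder NUM_IMAGE_FEATURES times."""
    expanded = []
    for tok in token_ids:
        expanded.extend([tok] * (NUM_IMAGE_FEATURES if tok == IMAGE_TOKEN_ID else 1))
    return expanded


def merge_image_features(input_ids, inputs_embeds, image_features):
    """Model side: verify token/feature counts match, then scatter features into the placeholder slots."""
    n_image_tokens = (input_ids == IMAGE_TOKEN_ID).sum().item()
    n_image_features = image_features.shape[0] * image_features.shape[1]
    if n_image_tokens != n_image_features:
        raise ValueError(f"tokens: {n_image_tokens}, features: {n_image_features}")
    special_image_mask = (input_ids == IMAGE_TOKEN_ID).unsqueeze(-1).expand_as(inputs_embeds)
    image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
    return inputs_embeds.masked_scatter(special_image_mask, image_features)


# Usage: one prompt containing a single image placeholder.
prompt_ids = expand_prompt([2, IMAGE_TOKEN_ID, 5, 9])             # -> [2, 7, 7, 7, 7, 5, 9]
input_ids = torch.tensor([prompt_ids])                            # (1, 7)
inputs_embeds = nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)(input_ids)  # (1, 7, HIDDEN_SIZE)
image_features = torch.randn(1, NUM_IMAGE_FEATURES, HIDDEN_SIZE)  # (num_images, features_per_image, hidden)
merged = merge_image_features(input_ids, inputs_embeds, image_features)
assert merged.shape == inputs_embeds.shape

Because the placeholders are expanded before the model ever sees the ids, the sequence length never changes inside `forward`, which is why the legacy `_merge_input_ids_with_visual_features` paths and the extended attention-mask bookkeeping removed above are no longer needed.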