diff --git a/.gitignore b/.gitignore index 337f2ef2c735e8..70bf18f45d42f5 100644 --- a/.gitignore +++ b/.gitignore @@ -167,3 +167,4 @@ tags # ruff .ruff_cache +test.py diff --git a/src/transformers/models/fuyu/image_processing_fuyu.py b/src/transformers/models/fuyu/image_processing_fuyu.py index 0e415980c97fdd..c5105037b8d69e 100644 --- a/src/transformers/models/fuyu/image_processing_fuyu.py +++ b/src/transformers/models/fuyu/image_processing_fuyu.py @@ -53,21 +53,6 @@ logger = logging.get_logger(__name__) -def make_list_of_list_of_images( - images: Union[List[List[ImageInput]], List[ImageInput], ImageInput] -) -> List[List[ImageInput]]: - if is_valid_image(images): - return [[images]] - - if isinstance(images, list) and all(isinstance(image, list) for image in images): - return images - - if isinstance(images, list): - return [make_list_of_images(image) for image in images] - - raise ValueError("images must be a list of list of images or a list of images or an image.") - - class FuyuBatchFeature(BatchFeature): """ BatchFeature class for Fuyu image processor and processor. @@ -356,7 +341,7 @@ def pad_image( input_data_format=input_data_format, ) return padded_image - + def preprocess( self, images, @@ -441,86 +426,54 @@ def preprocess( rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor patch_size = patch_size if patch_size is not None else self.patch_size - if isinstance(images, list) and any(isinstance(elem, list) and len(elem) >= 2 for elem in images): - raise ValueError("Multiple images for a single sample are not yet supported.") - - batch_images = make_list_of_list_of_images(images) - - if do_resize and size is None: - raise ValueError("Size must be specified if do_resize is True.") - - if do_rescale and rescale_factor is None: - raise ValueError("Rescale factor must be specified if do_rescale is True.") - - if do_normalize and image_mean is None or image_std is None: - raise ValueError("image_mean and image_std must be specified if do_normalize is True.") - - # All transformations expect numpy arrays. - batch_images = [[to_numpy_array(image) for image in images] for images in batch_images] - - if is_scaled_image(batch_images[0][0]) and do_rescale: - logger.warning_once( - "It looks like you are trying to rescale already rescaled images. If the input" - " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." - ) - - if input_data_format is None: - # We assume that all images have the same channel dimension format. 
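Editor's note: the helper deleted above normalized `preprocess` inputs into a `List[List[image]]`; after this change `preprocess` assumes callers already pass that nesting (see the new loop over `batch_images = images` further down). For reference only, a standalone sketch equivalent to the removed helper, with the imports it relied on (module path assumed to be `transformers.image_utils`, as in the original file):

```python
# Reference-only sketch of the removed helper: preprocess() now expects its
# `images` argument to already be a List[List[image]], so callers that relied
# on this normalization must perform it themselves.
from typing import List, Union

from transformers.image_utils import ImageInput, is_valid_image, make_list_of_images


def make_list_of_list_of_images(
    images: Union[List[List[ImageInput]], List[ImageInput], ImageInput]
) -> List[List[ImageInput]]:
    # A single image becomes a one-element batch with a one-element group.
    if is_valid_image(images):
        return [[images]]
    # Already nested: return as-is.
    if isinstance(images, list) and all(isinstance(image, list) for image in images):
        return images
    # A flat list of images: wrap each entry in its own group.
    if isinstance(images, list):
        return [make_list_of_images(image) for image in images]
    raise ValueError("images must be a list of list of images or a list of images or an image.")
```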
- input_data_format = infer_channel_dimension_format(batch_images[0][0]) - - original_image_sizes = [get_image_size(images[0], channel_dim=input_data_format) for images in batch_images] - - if do_resize: - batch_images = [ - [self.resize(image, size=size, input_data_format=input_data_format) for image in images] - for images in batch_images - ] - - image_sizes = [get_image_size(images[0], channel_dim=input_data_format) for images in batch_images] - image_unpadded_heights = [[image_size[0]] for image_size in image_sizes] - image_unpadded_widths = [[image_size[1]] for image_size in image_sizes] - - # scale_h is the same as scale_w - image_scale_factors = [ - [resized_size[0] / original_size[0]] - for original_size, resized_size in zip(original_image_sizes, image_sizes) - ] + batch_images = images + original_image_sizes = [] + batch_image_sizes = [] + image_unpadded_heights = [] + image_unpadded_widths = [] + image_scale_factors = [] + + for image_list in batch_images: + original_sizes_per_list = [] + batch_sizes_per_list = [] + unpadded_heights_per_list = [] + unpadded_widths_per_list = [] + scale_factors_per_list = [] + + #If there is no image in the list, make a placeholder image then preprocess the image + for idx, image in enumerate(image_list if image_list else [np.zeros((size['height'], size['width'], 3), dtype=np.uint8)]): + image = to_numpy_array(image) if image_list else image + + original_size = get_image_size(image, channel_dim=input_data_format) + original_sizes_per_list.append(original_size) + + if do_resize: + image = self.resize(image, size=size, resample=resample, input_data_format=input_data_format) + if do_pad: + image = self.pad_image(image, size=size, mode=padding_mode, constant_values=padding_value, input_data_format=input_data_format) + if do_rescale: + image = self.rescale(image, scale=rescale_factor, input_data_format=input_data_format) + if do_normalize: + image = self.normalize(image, mean=image_mean, std=image_std, input_data_format=input_data_format) + if data_format is not None: + image = to_channel_dimension_format(image, data_format, input_data_format) + + batch_size = get_image_size(image, channel_dim=input_data_format) + batch_sizes_per_list.append(batch_size) + unpadded_heights_per_list.append([batch_size[0]]) + unpadded_widths_per_list.append([batch_size[1]]) + scale_factors_per_list.append([batch_size[0] / original_size[0]]) + + if not image_list: + image_list.append(image) + else: + image_list[idx] = image - if do_pad: - batch_images = [ - [ - self.pad_image( - image, - size=size, - mode=padding_mode, - constant_values=padding_value, - input_data_format=input_data_format, - ) - for image in images - ] - for images in batch_images - ] - - if do_rescale: - batch_images = [ - [self.rescale(image, scale=rescale_factor, input_data_format=input_data_format) for image in images] - for images in batch_images - ] - - if do_normalize: - batch_images = [ - [ - self.normalize(image, mean=image_mean, std=image_std, input_data_format=input_data_format) - for image in images - ] - for images in batch_images - ] - - if data_format is not None: - batch_images = [ - [to_channel_dimension_format(image, data_format, input_data_format) for image in images] - for images in batch_images - ] + original_image_sizes.append(original_sizes_per_list) + batch_image_sizes.append(batch_sizes_per_list) + image_unpadded_heights.append(unpadded_heights_per_list) + image_unpadded_widths.append(unpadded_widths_per_list) + image_scale_factors.append(scale_factors_per_list) data = { "images": 
batch_images, @@ -528,6 +481,7 @@ def preprocess( "image_unpadded_widths": image_unpadded_widths, "image_scale_factors": image_scale_factors, } + return FuyuBatchFeature(data=data, tensor_type=return_tensors) def get_num_patches(self, image_height: int, image_width: int, patch_size: Dict[str, int] = None) -> int: diff --git a/src/transformers/models/fuyu/processing_fuyu.py b/src/transformers/models/fuyu/processing_fuyu.py index f7078554cbc08d..823176aefccc7c 100644 --- a/src/transformers/models/fuyu/processing_fuyu.py +++ b/src/transformers/models/fuyu/processing_fuyu.py @@ -80,9 +80,11 @@ def full_unpacked_stream_to_tensor( def construct_full_unpacked_stream( - num_real_text_tokens: Union[List[List[int]], "torch.Tensor"], input_stream: "torch.Tensor", image_tokens: List[List["torch.Tensor"]], + image_patch_indices_stream: "torch.Tensor", + image_patch_indices: List[List["torch.Tensor"]], + image_indicator_id: int, batch_size: int, num_sub_sequences: int, ) -> List["torch.Tensor"]: @@ -90,21 +92,36 @@ def construct_full_unpacked_stream( padding to account for images and then unpacks the subsequences to create a single sequence per item in the batch. Returns a list of tensors, one for each item in the batch.""" - all_bi_stream = [] + all_bi_stream, all_bi_image_patch_indices_stream = [], [] for batch_index in range(batch_size): - all_si_stream = [] - - # First, construct full token stream (including image placeholder tokens) and loss mask for each subsequence - # and append to lists. We use lists rather than tensors because each subsequence is variable-sized. - # TODO Remove this logic in a subsequent release since subsequences are not supported. - image_adjustment = image_tokens[batch_index][0] - subsequence_stream = torch.cat([image_adjustment, input_stream[batch_index, 0]], dim=0) - num_real_tokens = image_adjustment.shape[0] + num_real_text_tokens[batch_index][0] - all_si_stream.append(subsequence_stream[:num_real_tokens]) - all_bi_stream.append(torch.cat(all_si_stream, dim=0)) - - return all_bi_stream + # Extract the subsequence from the input_stream; assume only one subsequence in text + subsequence_stream = input_stream[batch_index, 0].clone() + subsequence_patch_indices_stream = image_patch_indices_stream[batch_index, 0].clone() + + + # If there are image tokens for this batch item + if image_tokens[batch_index]: + # Find indices where image_indicator_id appears + image_indicator_indices = (subsequence_stream == image_indicator_id).nonzero(as_tuple=True)[0] + + # Assert that the number of image indicators matches the number of image tokens + if len(image_indicator_indices) > 0: + assert len(image_indicator_indices) == len(image_tokens[batch_index]), \ + "Number of image indicators does not match the number of image tokens." 
+ + # Replace image_indicator_id with actual image tokens + offset = 0 + for idx, image_token, image_patch_indice in zip(image_indicator_indices, image_tokens[batch_index], image_patch_indices[batch_index]): + adjusted_idx = idx + offset + subsequence_stream = torch.cat([subsequence_stream[:adjusted_idx], image_token, subsequence_stream[adjusted_idx+1:]]) + subsequence_patch_indices_stream = torch.cat([subsequence_patch_indices_stream[:adjusted_idx], image_patch_indice, subsequence_patch_indices_stream[adjusted_idx+1:]]) + offset += len(image_token) - 1 # Adjust offset for subsequent replacements + + all_bi_stream.append(subsequence_stream) + all_bi_image_patch_indices_stream.append(subsequence_patch_indices_stream) + + return all_bi_stream, all_bi_image_patch_indices_stream def _replace_string_repr_with_token_tags(prompt: str) -> str: @@ -211,10 +228,7 @@ def _tokenize_prompts_with_image_and_batch( tokenizer, prompts: List[List[str]], scale_factors: Optional[List[List["torch.Tensor"]]], - max_tokens_to_generate: int, - max_position_embeddings: int, - add_BOS: bool, # Same issue with types as above - add_beginning_of_answer_token: bool, + image_tokens: Optional[List["torch.Tensor"]], ) -> Tuple["torch.Tensor", "torch.Tensor"]: """ Given a set of prompts and number of tokens to generate: @@ -238,46 +252,10 @@ def _tokenize_prompts_with_image_and_batch( prompts_tokens = transformed_prompt_tokens - if add_BOS: - bos_token = tokenizer.vocab[""] - else: - bos_token = tokenizer.vocab["|ENDOFTEXT|"] - prompts_tokens = [[[bos_token] + x for x in prompt_seq] for prompt_seq in prompts_tokens] - if add_beginning_of_answer_token: - boa = tokenizer.vocab[BEGINNING_OF_ANSWER_STRING] - # Only add bbox open token to the last subsequence since that is what will be completed - for token_seq in prompts_tokens: - token_seq[-1].append(boa) - - # Now we have a list of list of tokens which each list has a different - # size. We want to extend this list to: - # - incorporate the tokens that need to be generated - # - make all the sequences equal length. - # Get the prompts length. - - prompts_length = [[len(x) for x in prompts_tokens_seq] for prompts_tokens_seq in prompts_tokens] - # Get the max prompts length. - max_prompt_len: int = np.max(prompts_length) - # Number of tokens in the each sample of the batch. - samples_length = min(max_prompt_len + max_tokens_to_generate, max_position_embeddings) - if max_prompt_len + max_tokens_to_generate > max_position_embeddings: - logger.warning( - f"Max subsequence prompt length of {max_prompt_len} + max tokens to generate {max_tokens_to_generate}", - f"exceeds context length of {max_position_embeddings}. Will generate as many tokens as possible.", - ) - # Now update the list of list to be of the same size: samples_length. - for prompt_tokens_seq, prompts_length_seq in zip(prompts_tokens, prompts_length): - for prompt_tokens, prompt_length in zip(prompt_tokens_seq, prompts_length_seq): - if len(prompt_tokens) > samples_length: - raise ValueError("Length of subsequence prompt exceeds sequence length.") - padding_size = samples_length - prompt_length - prompt_tokens.extend([tokenizer.vocab["|ENDOFTEXT|"]] * padding_size) - # Now we are in a structured format, we can convert to tensors. 
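Editor's note: the loop above splices each variable-length block of image tokens into the text stream at the positions of `image_indicator_id`, shifting later insertion points by `len(image_token) - 1`. A self-contained sketch of that bookkeeping on toy tensors (all ids below are invented for illustration):

```python
# Standalone illustration of the locate-and-splice-with-offset bookkeeping
# used in construct_full_unpacked_stream; every id here is made up.
import torch

IMAGE_INDICATOR_ID = 9999  # hypothetical marker id
stream = torch.tensor([1, IMAGE_INDICATOR_ID, 2, IMAGE_INDICATOR_ID, 3])
image_tokens = [torch.tensor([71, 72, 73]), torch.tensor([81, 82])]

# Find every marker position and check it matches the number of image blocks.
positions = (stream == IMAGE_INDICATOR_ID).nonzero(as_tuple=True)[0]
assert len(positions) == len(image_tokens), "indicator / image count mismatch"

offset = 0
for pos, block in zip(positions, image_tokens):
    i = int(pos) + offset                  # account for earlier expansions
    stream = torch.cat([stream[:i], block, stream[i + 1:]])
    offset += len(block) - 1               # each marker expands into len(block) tokens

print(stream.tolist())  # [1, 71, 72, 73, 2, 81, 82, 3]
```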
prompts_tokens_tensor = torch.tensor(prompts_tokens, dtype=torch.int64) - prompts_length_tensor = torch.tensor(prompts_length, dtype=torch.int64) - return prompts_tokens_tensor, prompts_length_tensor + return prompts_tokens_tensor # Simplified assuming self.crop_top = self.padding_top = 0 @@ -333,47 +311,47 @@ def __init__(self, image_processor, tokenizer): self.pad_token_id = 0 self.dummy_image_index = -1 - def _left_pad_inputs_with_attention_mask(self, model_inputs: List[Dict], return_attention_mask: bool): - max_length_input_ids = max(entry["input_ids"].shape[1] for entry in model_inputs) - max_length_image_patch_indices = max(entry["image_patches_indices"].shape[1] for entry in model_inputs) + def _left_pad_inputs_with_attention_mask(self, model_inputs: List[Dict], return_attention_mask: bool, truncation: bool, truncation_length: int): + max_length_input_ids = min(max(entry["input_ids"].shape[1] for entry in model_inputs), truncation_length) + max_length_image_patch_indices = min(max(entry["image_patches_indices"].shape[1] for entry in model_inputs), truncation_length) batched_inputs = {"input_ids": [], "image_patches": [], "image_patches_indices": [], "attention_mask": []} for entry in model_inputs: for key, tensor in entry.items(): - if key == "input_ids": - num_padding_tokens = max_length_input_ids - tensor.shape[1] - padded_input_ids = torch.cat( + if key == "input_ids" or key == "image_patches_indices": + # Truncate if the tensor is longer than the truncation_length + if tensor.shape[1] > truncation_length: + if truncation: + logger.warn(f"Truncating tensor from original length {tensor.shape[1]} to {truncation_length}") + tensor = tensor[:, :truncation_length] + else: + raise ValueError(f"Tensor length {tensor.shape[1]} exceeds truncation_length {truncation_length} (usually max positional embedding length), but truncation is disabled. Please enable truncation or increase truncation_length.") + + # Calculate the number of padding tokens or indices + num_padding = max_length_input_ids - tensor.shape[1] if key == "input_ids" else max_length_image_patch_indices - tensor.shape[1] + + # Pad the tensor + padded_tensor = torch.cat( [ - torch.full((tensor.shape[0], num_padding_tokens), self.pad_token_id, dtype=torch.long), + torch.full((tensor.shape[0], num_padding), self.pad_token_id if key == "input_ids" else self.dummy_image_index, dtype=torch.long), tensor, ], dim=1, ) - batched_inputs[key].append(padded_input_ids) + batched_inputs[key].append(padded_tensor) - attention_mask = torch.cat( - [torch.zeros(tensor.shape[0], num_padding_tokens, dtype=torch.long), torch.ones_like(tensor)], - dim=1, - ) - batched_inputs["attention_mask"].append(attention_mask) + if key == "input_ids": + attention_mask = torch.cat( + [torch.zeros(tensor.shape[0], num_padding, dtype=torch.long), torch.ones_like(tensor)], + dim=1, + ) + batched_inputs["attention_mask"].append(attention_mask) elif key == "image_patches": # For image_patches, we don't pad but just append them to the list. 
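Editor's note: the branch above first truncates `input_ids` / `image_patches_indices` to `truncation_length` (or raises when truncation is disabled), then left-pads every sample to the longest surviving length and mirrors the padding in the attention mask. A minimal standalone sketch of that truncate-then-left-pad behaviour with truncation enabled (pad id and lengths are invented; the real method additionally pads `image_patches_indices` with `dummy_image_index`):

```python
# Toy version of the truncate-then-left-pad logic above; ids and lengths are made up.
import torch

PAD_ID = 0
truncation_length = 6
sequences = [torch.tensor([[5, 6, 7]]), torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]])]

# Truncate anything longer than truncation_length, then pad on the left to the
# longest remaining length, mirroring the attention mask (0 = padding, 1 = real).
sequences = [t[:, :truncation_length] for t in sequences]
max_len = max(t.shape[1] for t in sequences)

input_ids, attention_mask = [], []
for t in sequences:
    pad = max_len - t.shape[1]
    input_ids.append(torch.cat([torch.full((t.shape[0], pad), PAD_ID, dtype=torch.long), t], dim=1))
    attention_mask.append(torch.cat([torch.zeros(t.shape[0], pad, dtype=torch.long), torch.ones_like(t)], dim=1))

print(torch.cat(input_ids).tolist())       # [[0, 0, 0, 5, 6, 7], [1, 2, 3, 4, 5, 6]]
print(torch.cat(attention_mask).tolist())  # [[0, 0, 0, 1, 1, 1], [1, 1, 1, 1, 1, 1]]
```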
batched_inputs[key].append(tensor) - else: # for image_patches_indices - num_padding_indices = max_length_image_patch_indices - tensor.shape[1] - padded_indices = torch.cat( - [ - torch.full( - (tensor.shape[0], num_padding_indices), self.dummy_image_index, dtype=torch.long - ), - tensor, - ], - dim=1, - ) - batched_inputs[key].append(padded_indices) batched_keys = ["input_ids", "image_patches_indices"] if return_attention_mask: batched_keys.append("attention_mask") @@ -388,11 +366,12 @@ def get_sample_encoding( scale_factors, image_unpadded_heights, image_unpadded_widths, + image_indicator_id, image_placeholder_id, image_newline_id, tensor_batch_images, ): - image_present = torch.ones(1, 1, 1) + image_present = torch.ones(tensor_batch_images.shape[0], tensor_batch_images.shape[1], 1) # shape [batch_size, subsequence_size, num_images] model_image_input = self.image_processor.preprocess_with_tokenizer_info( image_input=tensor_batch_images, image_present=image_present, @@ -403,27 +382,18 @@ def get_sample_encoding( variable_sized=True, ) # FIXME max_tokens_to_generate is embedded into this processor's call. - prompt_tokens, prompts_length = _tokenize_prompts_with_image_and_batch( + prompt_tokens = _tokenize_prompts_with_image_and_batch( tokenizer=self.tokenizer, prompts=prompts, scale_factors=scale_factors, - max_tokens_to_generate=self.max_tokens_to_generate, - max_position_embeddings=self.max_position_embeddings, - add_BOS=True, - add_beginning_of_answer_token=True, + image_tokens=model_image_input["image_input_ids"], ) - image_padded_unpacked_tokens = construct_full_unpacked_stream( - num_real_text_tokens=prompts_length, + image_padded_unpacked_tokens, unpacked_image_patch_indices_per_batch = construct_full_unpacked_stream( input_stream=prompt_tokens, image_tokens=model_image_input["image_input_ids"], - batch_size=1, - num_sub_sequences=self.subsequence_length, - ) - # Construct inputs for image patch indices. - unpacked_image_patch_indices_per_batch = construct_full_unpacked_stream( - num_real_text_tokens=prompts_length, - input_stream=torch.full_like(prompt_tokens, -1), - image_tokens=model_image_input["image_patch_indices_per_batch"], + image_patch_indices_stream=torch.full_like(prompt_tokens, -1), + image_patch_indices=model_image_input["image_patch_indices_per_batch"], + image_indicator_id=image_indicator_id, batch_size=1, num_sub_sequences=self.subsequence_length, ) @@ -440,7 +410,7 @@ def get_sample_encoding( new_seq_len=max_seq_len_batch, offset=0, ) - image_patches_tensor = torch.stack([img[0] for img in model_image_input["image_patches"]]) + image_patches_tensor = torch.stack([torch.concat(img, dim=0) for img in model_image_input["image_patches"]]) batch_encoding = { "input_ids": image_padded_unpacked_tokens[0].unsqueeze(0), "image_patches": image_patches_tensor, @@ -496,47 +466,62 @@ def __call__( """ requires_backends(self, ["torch"]) - # --- Check input validity --- + # --- Check input validity and Insert the appropriate number of image place indicators before the prompt--- + if not return_attention_mask: raise ValueError("`return_attention_mask=False` is not supported for this model.") - if text is None and images is None: - raise ValueError("You have to specify either text or images. Both cannot be None.") - if text is not None and images is None: - logger.warning("You are processing a text with no associated image. 
Make sure it is intended.") - self.current_processor = self.tokenizer - text_encoding = self.tokenizer( - text=text, - add_special_tokens=add_special_tokens, - padding=padding, - truncation=truncation, - max_length=max_length, - stride=stride, - pad_to_multiple_of=pad_to_multiple_of, - return_attention_mask=return_attention_mask, - return_overflowing_tokens=return_overflowing_tokens, - return_special_tokens_mask=return_special_tokens_mask, - return_offsets_mapping=return_offsets_mapping, - return_token_type_ids=return_token_type_ids, - return_length=return_length, - verbose=verbose, - return_tensors=return_tensors, - **kwargs, - ) - return text_encoding - if text is None and images is not None: - logger.warning("You are processing an image with no associated text. Make sure it is intended.") - prompts = [[""]] - if text is not None and images is not None: - if isinstance(text, str): - prompts = [[text]] - elif isinstance(text, list): - prompts = [[text_seq] for text_seq in text] + prompts = [] + images_list = [] + for text_item, image_group in zip(text, images): + # Handle the case where there is no text but there are images + if text_item is None and image_group is not None: + prompt = "|IMAGESTART|" * len([img for img in image_group if img is not None]) + prompts.append([prompt]) + images_list.append([img for img in image_group if img is not None]) + + # Handle the case where there is text and possibly images + elif text_item is not None: + # Counting the number of "|IMAGESTART|" in text_item + image_indicator_count = text_item.count('|IMAGESTART|') + + # Text with images + if image_group is not None: + + not_none_image_count = len([img for img in image_group if img is not None]) + + if image_indicator_count > not_none_image_count: + raise ValueError(f"Image place indicators exceed the number of images provided. Have {image_indicator_count} images?") + + elif image_indicator_count < not_none_image_count: + + insert_count = len(image_group) - image_indicator_count + logger.warning(f"Inserting {insert_count} image place indicators before the prompt.") + text_item = "|IMAGESTART|" * insert_count + text_item + + prompt = text_item + prompts.append([prompt]) + images_list.append([img for img in image_group if img is not None]) + + # Text without images + else: + if image_indicator_count > 0: + raise ValueError(f"Image place indicators exceed the number of images provided. Have {image_indicator_count} images?") + + else: + prompt = text_item + prompts.append([prompt]) + images_list.append([]) + + # Handle the case where both text and image are None + else: + raise ValueError("You have to specify either text or images. 
Both cannot be None.") + # --- Preprocess images using self.image_processor --- # FIXME - We hard code "pt" here because the rest of the processing assumes torch tensors - image_encoding = self.image_processor.preprocess(images, return_tensors="pt") + image_encoding = self.image_processor.preprocess(images_list, return_tensors="pt") batch_images = image_encoding["images"] image_unpadded_heights = image_encoding["image_unpadded_heights"] image_unpadded_widths = image_encoding["image_unpadded_widths"] @@ -545,10 +530,10 @@ def __call__( self.batch_size = len(batch_images) # --- Use self.tokenizer to get the ids of special tokens to insert into image ids --- - + image_indicator_id = self.tokenizer("|IMAGESTART|", add_special_tokens=False)["input_ids"][1] image_placeholder_id = self.tokenizer("|SPEAKER|", add_special_tokens=False)["input_ids"][1] image_newline_id = self.tokenizer("|NEWLINE|", add_special_tokens=False)["input_ids"][1] - tensor_batch_images = torch.stack([img[0] for img in batch_images]).unsqueeze(1) + tensor_batch_images = [torch.stack(batch_image).unsqueeze(0) for batch_image in batch_images] # --- Use self.image_processor again to obtain the full token ids and batch inputs --- all_encodings = [] @@ -561,13 +546,14 @@ def __call__( scale_factors=[scale_factor], image_unpadded_heights=torch.tensor([image_unpadded_height]), image_unpadded_widths=torch.tensor([image_unpadded_width]), + image_indicator_id=image_indicator_id, image_placeholder_id=image_placeholder_id, image_newline_id=image_newline_id, - tensor_batch_images=tensor_batch_image.unsqueeze(0), + tensor_batch_images=tensor_batch_image, ) all_encodings.append(sample_encoding) batch_encoding = self._left_pad_inputs_with_attention_mask( - model_inputs=all_encodings, return_attention_mask=return_attention_mask + model_inputs=all_encodings, return_attention_mask=return_attention_mask, truncation=truncation, truncation_length=max_length if max_length is not None else self.max_position_embeddings, ) return FuyuBatchFeature(data=batch_encoding)
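Editor's note: a hypothetical end-to-end usage sketch of the interface introduced by this diff, with one prompt per sample and `|IMAGESTART|` marking where each image belongs. The checkpoint name, keyword handling, and output keys are assumptions inferred from the code above, not guarantees:

```python
# Hypothetical usage of the multi-image processor API shown in this diff.
import torch
from PIL import Image
from transformers import FuyuProcessor

processor = FuyuProcessor.from_pretrained("adept/fuyu-8b")  # assumed checkpoint

img_a = Image.new("RGB", (320, 240))
img_b = Image.new("RGB", (640, 480))

outputs = processor(
    text=["|IMAGESTART|Describe this image.", "Compare |IMAGESTART| and |IMAGESTART| please."],
    images=[[img_a], [img_a, img_b]],  # one list of images per prompt
    return_tensors="pt",
    truncation=True,
)
print(outputs["input_ids"].shape)       # left-padded to a common length
print(outputs["attention_mask"].shape)  # 0 over padding, 1 over real tokens
print(sorted(outputs.keys()))           # expected to include image_patches / image_patches_indices
```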