diff --git a/docs/source/en/model_doc/blip.md b/docs/source/en/model_doc/blip.md index fa06191834f898..0545400b835538 100644 --- a/docs/source/en/model_doc/blip.md +++ b/docs/source/en/model_doc/blip.md @@ -61,6 +61,11 @@ The original code can be found [here](https://github.com/salesforce/BLIP). [[autodoc]] BlipImageProcessor - preprocess +## BlipImageProcessorFast + +[[autodoc]] BlipImageProcessorFast + - preprocess + diff --git a/docs/source/en/model_doc/clip.md b/docs/source/en/model_doc/clip.md index f0829f484aaa51..cd2d56229b4e87 100644 --- a/docs/source/en/model_doc/clip.md +++ b/docs/source/en/model_doc/clip.md @@ -251,6 +251,11 @@ The resource should ideally demonstrate something new instead of duplicating an [[autodoc]] CLIPImageProcessor - preprocess +## CLIPImageProcessorFast + +[[autodoc]] CLIPImageProcessorFast + - preprocess + ## CLIPFeatureExtractor [[autodoc]] CLIPFeatureExtractor diff --git a/docs/source/en/model_doc/convnext.md b/docs/source/en/model_doc/convnext.md index 5222834b1f69d6..f3d10d77b1d2c2 100644 --- a/docs/source/en/model_doc/convnext.md +++ b/docs/source/en/model_doc/convnext.md @@ -64,6 +64,11 @@ If you're interested in submitting a resource to be included here, please feel f [[autodoc]] ConvNextImageProcessor - preprocess +## ConvNextImageProcessorFast + +[[autodoc]] ConvNextImageProcessorFast + - preprocess + diff --git a/docs/source/en/model_doc/deit.md b/docs/source/en/model_doc/deit.md index 6a4e141facaeac..a24632d5f867f1 100644 --- a/docs/source/en/model_doc/deit.md +++ b/docs/source/en/model_doc/deit.md @@ -125,6 +125,11 @@ If you're interested in submitting a resource to be included here, please feel f [[autodoc]] DeiTImageProcessor - preprocess +## DeiTImageProcessorFast + +[[autodoc]] DeiTImageProcessorFast + - preprocess + diff --git a/docs/source/en/model_doc/llava_next.md b/docs/source/en/model_doc/llava_next.md index 88bd63e7101f17..1710def1cf9edd 100644 --- a/docs/source/en/model_doc/llava_next.md +++ b/docs/source/en/model_doc/llava_next.md @@ -288,6 +288,11 @@ model = AutoModelForImageTextToText.from_pretrained( [[autodoc]] LlavaNextImageProcessor - preprocess +## LlavaNextImageProcessorFast + +[[autodoc]] LlavaNextImageProcessorFast + - preprocess + ## LlavaNextProcessor [[autodoc]] LlavaNextProcessor diff --git a/docs/source/en/model_doc/llava_onevision.md b/docs/source/en/model_doc/llava_onevision.md index b6b0a2bfa1d123..466b6adfff6dd3 100644 --- a/docs/source/en/model_doc/llava_onevision.md +++ b/docs/source/en/model_doc/llava_onevision.md @@ -100,8 +100,8 @@ import torch from PIL import Image import requests -processor = AutoProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-7b-ov-hf") -model = LlavaOnevisionForConditionalGeneration.from_pretrained("llava-hf/llava-onevision-qwen2-7b-ov-hf", torch_dtype=torch.float16, low_cpu_mem_usage=True) +processor = AutoProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-7b-ov-hf") +model = LlavaOnevisionForConditionalGeneration.from_pretrained("llava-hf/llava-onevision-qwen2-7b-ov-hf", torch_dtype=torch.float16, low_cpu_mem_usage=True) model.to("cuda:0") # prepare image and text prompt, using the appropriate prompt template @@ -298,8 +298,8 @@ First make sure to install flash-attn. 
Refer to the [original repository of Flas from transformers import LlavaOnevisionForConditionalGeneration model = LlavaOnevisionForConditionalGeneration.from_pretrained( - model_id, - torch_dtype=torch.float16, + model_id, + torch_dtype=torch.float16, low_cpu_mem_usage=True, use_flash_attention_2=True ).to(0) @@ -318,6 +318,11 @@ model = LlavaOnevisionForConditionalGeneration.from_pretrained( [[autodoc]] LlavaOnevisionImageProcessor +## LlavaOnevisionImageProcessorFast + +[[autodoc]] LlavaOnevisionImageProcessorFast + - preprocess + ## LlavaOnevisionVideoProcessor [[autodoc]] LlavaOnevisionVideoProcessor diff --git a/docs/source/en/model_doc/siglip.md b/docs/source/en/model_doc/siglip.md index 88e38cbb590edc..54588854d54c92 100644 --- a/docs/source/en/model_doc/siglip.md +++ b/docs/source/en/model_doc/siglip.md @@ -215,6 +215,11 @@ Below is an expected speedup diagram that compares inference time between the na [[autodoc]] SiglipImageProcessor - preprocess +## SiglipImageProcessorFast + +[[autodoc]] SiglipImageProcessorFast + - preprocess + ## SiglipProcessor [[autodoc]] SiglipProcessor diff --git a/docs/source/ja/model_doc/blip.md b/docs/source/ja/model_doc/blip.md index c145af701f23bb..8e8550318bd4c8 100644 --- a/docs/source/ja/model_doc/blip.md +++ b/docs/source/ja/model_doc/blip.md @@ -61,6 +61,11 @@ BLIP は、次のようなさまざまなマルチモーダル タスクを実 [[autodoc]] BlipImageProcessor - preprocess +## BlipImageProcessorFast + +[[autodoc]] BlipImageProcessorFast + - preprocess + diff --git a/docs/source/ja/model_doc/clip.md b/docs/source/ja/model_doc/clip.md index 697971e9224848..db896c91164a8b 100644 --- a/docs/source/ja/model_doc/clip.md +++ b/docs/source/ja/model_doc/clip.md @@ -133,6 +133,11 @@ CLIP を使い始めるのに役立つ公式 Hugging Face およびコミュニ [[autodoc]] CLIPImageProcessor - preprocess +## CLIPImageProcessorFast + +[[autodoc]] CLIPImageProcessorFast + - preprocess + ## CLIPFeatureExtractor [[autodoc]] CLIPFeatureExtractor diff --git a/docs/source/ja/model_doc/convnext.md b/docs/source/ja/model_doc/convnext.md index 4386a7df8ceadb..efbe3bb0f4b793 100644 --- a/docs/source/ja/model_doc/convnext.md +++ b/docs/source/ja/model_doc/convnext.md @@ -64,6 +64,11 @@ ConvNeXT の使用を開始するのに役立つ公式 Hugging Face およびコ [[autodoc]] ConvNextImageProcessor - preprocess +## ConvNextImageProcessorFast + +[[autodoc]] ConvNextImageProcessorFast + - preprocess + diff --git a/docs/source/ja/model_doc/deit.md b/docs/source/ja/model_doc/deit.md index aa8c66c90be0b8..00fa82e113c53f 100644 --- a/docs/source/ja/model_doc/deit.md +++ b/docs/source/ja/model_doc/deit.md @@ -98,6 +98,11 @@ DeiT を始めるのに役立つ公式 Hugging Face およびコミュニティ [[autodoc]] DeiTImageProcessor - preprocess +## DeiTImageProcessorFast + +[[autodoc]] DeiTImageProcessorFast + - preprocess + diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 7df1af049de626..be1d26b3e74acc 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1278,10 +1278,17 @@ ] else: _import_structure["image_processing_utils_fast"] = ["BaseImageProcessorFast"] + _import_structure["models.blip"].append("BlipImageProcessorFast") + _import_structure["models.clip"].append("CLIPImageProcessorFast") + _import_structure["models.convnext"].append("ConvNextImageProcessorFast") _import_structure["models.deformable_detr"].append("DeformableDetrImageProcessorFast") + _import_structure["models.deit"].append("DeiTImageProcessorFast") _import_structure["models.detr"].append("DetrImageProcessorFast") + _import_structure["models.llava_next"].append("LlavaNextImageProcessorFast") + 
_import_structure["models.llava_onevision"].append("LlavaOnevisionImageProcessorFast") _import_structure["models.pixtral"].append("PixtralImageProcessorFast") _import_structure["models.rt_detr"].append("RTDetrImageProcessorFast") + _import_structure["models.siglip"].append("SiglipImageProcessorFast") _import_structure["models.vit"].append("ViTImageProcessorFast") try: @@ -6298,10 +6305,17 @@ from .utils.dummy_torchvision_objects import * else: from .image_processing_utils_fast import BaseImageProcessorFast + from .models.blip import BlipImageProcessorFast + from .models.clip import CLIPImageProcessorFast + from .models.convnext import ConvNextImageProcessorFast from .models.deformable_detr import DeformableDetrImageProcessorFast + from .models.deit import DeiTImageProcessorFast from .models.detr import DetrImageProcessorFast + from .models.llava_next import LlavaNextImageProcessorFast + from .models.llava_onevision import LlavaOnevisionImageProcessorFast from .models.pixtral import PixtralImageProcessorFast from .models.rt_detr import RTDetrImageProcessorFast + from .models.siglip import SiglipImageProcessorFast from .models.vit import ViTImageProcessorFast try: diff --git a/src/transformers/commands/add_fast_image_processor.py b/src/transformers/commands/add_fast_image_processor.py new file mode 100644 index 00000000000000..87a4f989757009 --- /dev/null +++ b/src/transformers/commands/add_fast_image_processor.py @@ -0,0 +1,692 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import re +from argparse import ArgumentParser, Namespace +from datetime import date +from pathlib import Path + +from ..utils import logging +from . import BaseTransformersCLICommand + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +CURRENT_YEAR = date.today().year +TRANSFORMERS_PATH = Path(__file__).parent.parent +REPO_PATH = TRANSFORMERS_PATH.parent.parent + +DEFAULT_CLASS_DOCSTRING = """r\"\"\" + Constructs a fast {model_name} image processor. + + Args: + do_resize (`bool`, *optional*): + Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the + `do_resize` parameter in the `preprocess` method. + size (`dict`, *optional*): + Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess` + method. + resample (`PILImageResampling`, *optional*): + Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`. Can be + overridden by the `resample` parameter in the `preprocess` method. + do_center_crop (`bool`, *optional*, defaults to `True`): + Whether to center crop the image to the specified `crop_size`. Can be overridden by `do_center_crop` in the + `preprocess` method. + crop_size (`Dict[str, int]` *optional*, defaults to 224): + Size of the output image after applying `center_crop`. Can be overridden by `crop_size` in the `preprocess` + method. 
+ do_rescale (`bool`, *optional*): + Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the + `do_rescale` parameter in the `preprocess` method. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. Only has an effect if `do_rescale` is set to `True`. Can be + overridden by the `rescale_factor` parameter in the `preprocess` method. + do_normalize (`bool`, *optional*): + Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess` + method. Can be overridden by the `do_normalize` parameter in the `preprocess` method. + image_mean (`float` or `List[float]`, *optional*): + Mean to use if normalizing the image. This is a float or list of floats the length of the number of + channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. Can be + overridden by the `image_mean` parameter in the `preprocess` method. + image_std (`float` or `List[float]`, *optional*): + Standard deviation to use if normalizing the image. This is a float or list of floats the length of the + number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. + Can be overridden by the `image_std` parameter in the `preprocess` method. + do_convert_rgb (`bool`, *optional*): + Whether to convert the image to RGB. + \"\"\" +""" + + +def add_import_structure_entry_init(content: str, fast_image_processor_name: str, model_name: str): + """ + Add an entry to the `_import_structure` dictionary in the `__init__.py` file of the transformers package. + """ + # Step 1: Find the block + block_regex = re.compile( + r"if not is_torchvision_available\(\):.*?else:\s*(\n(?P\s+)_import_structure\[.*?\].*?\n(?:\s*(?P=indent)_import_structure\[.*?\].*?\n)*)", + re.DOTALL, + ) + match = block_regex.search(content) + + if not match: + raise ValueError("Couldn't find the '_import_structure' block.") + + # Capture the block content and indentation + block_content = match.group(1) + indent = match.group("indent") + + # Step 2: Parse existing entries + lines = block_content.strip().split("\n") + entries = [] + + import_structure_header = indent + lines[0] + entries = lines[1:] + + # Add the new entry, maintaining alphabetical order + new_entry = f'{indent}_import_structure["models.{model_name}"].append("{fast_image_processor_name}")' + if new_entry not in entries: + entries.append(new_entry) + + entries.sort() + entries = [import_structure_header] + entries + + # Step 3: Reconstruct the block + updated_block = "\n".join(entry for entry in entries) + + # Replace the original block in the content + updated_content = content[: match.start(1)] + "\n" + updated_block + "\n" + content[match.end(1) :] + + return updated_content + + +def add_import_statement_init(content: str, fast_image_processor_name: str, model_name: str): + """ + Add an import statement to the `__init__.py` file of the transformers package. 
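+
+    Example (illustrative, reusing names from this PR; `content` is the text of
+    `src/transformers/__init__.py` as read by the caller):
+
+    ```python
+    content = add_import_statement_init(content, "SiglipImageProcessorFast", "siglip")
+    # inserts `from .models.siglip import SiglipImageProcessorFast` into the
+    # torchvision-guarded import block, keeping the block alphabetically sorted
+    ```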
+ """ + # Step 1: Find the block + block_regex = re.compile( + r"if not is_torchvision_available\(\):\s+raise OptionalDependencyNotAvailable\(\)\s+except OptionalDependencyNotAvailable:\s+from \.utils\.dummy_torchvision_objects import \*\s+else:(?P\s*(\n\s*from .+ import .*\n)+)(?=\s*# Modeling)", + re.DOTALL, + ) + match = block_regex.search(content) + + if match: + block_content = match.group("else_block") # The captured import block + else: + print("Couldn't find the import statement block.") + + # Step 2: Parse existing entries + lines = block_content.strip().split("\n") + entries = [] + + indent = " " * (len(lines[1]) - len(lines[1].lstrip())) + import_structure_header = indent + lines[0] + entries = lines[1:] + + # Add the new entry, maintaining alphabetical order + new_entry = f"{indent}from .models.{model_name} import {fast_image_processor_name}" + if new_entry not in entries: + entries.append(new_entry) + + entries.sort() + entries = [import_structure_header] + entries + + # Step 3: Reconstruct the block + updated_block = "\n".join(entry for entry in entries) + + # Replace the original block in the content + updated_content = ( + content[: match.start("else_block")] + "\n" + updated_block + "\n\n" + content[match.end("else_block") :] + ) + + return updated_content + + +def add_fast_image_processor_to_main_init(fast_image_processor_name: str, model_name: str): + """ + Add the fast image processor to the main __init__.py file of the transformers package. + """ + with open(TRANSFORMERS_PATH / "__init__.py", "r", encoding="utf-8") as f: + content = f.read() + + # add _import_structure entry + content = add_import_structure_entry_init(content, fast_image_processor_name, model_name) + # add import statement + content = add_import_statement_init(content, fast_image_processor_name, model_name) + + # write the updated content + with open(TRANSFORMERS_PATH / "__init__.py", "w", encoding="utf-8") as f: + f.write(content) + + +def add_fast_image_processor_to_model_init( + fast_image_processing_module_file: str, fast_image_processor_name, model_name: str +): + """ + Add the fast image processor to the __init__.py file of the model. 
+ """ + with open(TRANSFORMERS_PATH / "models" / model_name / "__init__.py", "r", encoding="utf-8") as f: + content = f.read() + + fast_image_processing_module_file = fast_image_processing_module_file.split(os.sep)[-1].replace(".py", "") + + if "import *" in content: + # we have an init file in the updated format + # get the indented block after if TYPE_CHECKING: and before else:, append the new import, sort the imports and write the updated content + # Step 1: Find the block + block_regex = re.compile( + r"if TYPE_CHECKING:\n(?P.*?)(?=\s*else:)", + re.DOTALL, + ) + match = block_regex.search(content) + + if not match: + raise ValueError("Couldn't find the 'if TYPE_CHECKING' block.") + + block_content = match.group("if_block") # The captured import block + + # Step 2: Parse existing entries + entries = block_content.split("\n") + indent = " " * (len(entries[0]) - len(entries[0].lstrip())) + new_entry = f"{indent}from .{fast_image_processing_module_file} import *" + if new_entry not in entries: + entries.append(new_entry) + entries.sort() + updated_block = "\n".join(entry for entry in entries) + + # Replace the original block in the content + updated_content = content[: match.start("if_block")] + updated_block + content[match.end("if_block") :] + else: + # we have an init file in the old format + + # add "is_torchvision_available" import to from ...utils import ( + # Regex to match import statements from transformers.utils + pattern = r""" + from\s+\.\.\.utils\s+import\s+ + (?: # Non-capturing group for either: + ([\w, ]+) # 1. Single-line imports (e.g., 'a, b') + | # OR + \((.*?)\) # 2. Multi-line imports (e.g., '(a, ... b)') + ) + """ + regex = re.compile(pattern, re.VERBOSE | re.DOTALL) + + def replacement_function(match): + # Extract existing imports + imports = (match.group(1) or match.group(2)).split(",") + imports = imports[:-1] if imports[-1] == "\n" else imports + imports = [imp.strip() for imp in imports] + + # Add the new import if not already present + if "is_torchvision_available" not in imports: + imports.append("is_torchvision_available") + imports.sort() + + # Convert to multi-line import in all cases + updated_imports = "(\n " + ",\n ".join(imports) + ",\n)" + + return f"from ...utils import {updated_imports}" + + # Replace all matches in the file content + updated_content = regex.sub(replacement_function, content) + + vision_import_structure_block = f' _import_structure["{fast_image_processing_module_file[:-5]}"] = ["{fast_image_processor_name[:-4]}"]\n' + + added_import_structure_block = ( + "try:\n if not is_torchvision_available():\n" + " raise OptionalDependencyNotAvailable()\n" + "except OptionalDependencyNotAvailable:\n" + " pass\n" + "else:\n" + f' _import_structure["{fast_image_processing_module_file}"] = ["{fast_image_processor_name}"]\n' + ) + + if vision_import_structure_block not in updated_content: + raise ValueError("Couldn't find the 'vision _import_structure block' block.") + + if added_import_structure_block not in updated_content: + updated_content = updated_content.replace( + vision_import_structure_block, vision_import_structure_block + "\n" + added_import_structure_block + ) + + vision_import_statement_block = ( + f" from .{fast_image_processing_module_file[:-5]} import {fast_image_processor_name[:-4]}\n" + ) + + added_import_statement_block = ( + " try:\n if not is_torchvision_available():\n" + " raise OptionalDependencyNotAvailable()\n" + " except OptionalDependencyNotAvailable:\n" + " pass\n" + " else:\n" + f" from 
.{fast_image_processing_module_file} import {fast_image_processor_name}\n" + ) + + if vision_import_statement_block not in updated_content: + raise ValueError("Couldn't find the 'vision _import_structure block' block.") + + if added_import_statement_block not in updated_content: + updated_content = updated_content.replace( + vision_import_statement_block, vision_import_statement_block + "\n" + added_import_statement_block + ) + + # write the updated content + with open(TRANSFORMERS_PATH / "models" / model_name / "__init__.py", "w", encoding="utf-8") as f: + f.write(updated_content) + + +def add_fast_image_processor_to_auto(image_processor_name: str, fast_image_processor_name: str): + """ + Add the fast image processor to the auto module. + """ + with open(TRANSFORMERS_PATH / "models" / "auto" / "image_processing_auto.py", "r", encoding="utf-8") as f: + content = f.read() + + # get all lines containing the image processor name + updated_content = content.replace( + f'("{image_processor_name}",)', f'("{image_processor_name}", "{fast_image_processor_name}")' + ) + + # write the updated content + with open(TRANSFORMERS_PATH / "models" / "auto" / "image_processing_auto.py", "w", encoding="utf-8") as f: + f.write(updated_content) + + +def add_fast_image_processor_to_dummy(fast_image_processor_name: str): + """ + Add the fast image processor to the dummy torchvision objects file. + """ + dummy_torchvision_objects_file = TRANSFORMERS_PATH / "utils" / "dummy_torchvision_objects.py" + with open(dummy_torchvision_objects_file, "r", encoding="utf-8") as f: + content = f.read() + + # regex to find objects starting with "class " and ending with "ImageProcessorFast", including "ImageProcessorFast" in the match + image_processor_names = re.findall(r"class (\w*ImageProcessorFast)", content) + image_processor_names.append(fast_image_processor_name) + image_processor_names.sort() + index_new = image_processor_names.index(fast_image_processor_name) + + new_dummy_object = ( + f"class {fast_image_processor_name}(metaclass=DummyObject):\n" + ' _backends = ["torchvision"]\n\n' + " def __init__(self, *args, **kwargs):\n" + ' requires_backends(self, ["torchvision"])\n' + ) + if new_dummy_object not in content: + if index_new != len(image_processor_names) - 1: + # add the dummy object just before the next ImageProcessorFast + first_line = f"class {image_processor_names[index_new+1]}(metaclass=DummyObject):" + updated_content = content.replace(first_line, new_dummy_object + "\n\n" + first_line) + else: + # add the dummy object at the very end + updated_content = content + "\n\n" + new_dummy_object + + # write the updated content + with open(dummy_torchvision_objects_file, "w", encoding="utf-8") as f: + f.write(updated_content) + + +def add_fast_image_processor_to_doc(fast_image_processor_name: str, model_name: str): + """ + Add the fast image processor to the model's doc file. 
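+
+    Example (illustrative; `convnext` is one of the models updated in this PR):
+
+    ```python
+    add_fast_image_processor_to_doc("ConvNextImageProcessorFast", "convnext")
+    # appends a "## ConvNextImageProcessorFast" autodoc section right after the existing
+    # "## ConvNextImageProcessor" section in every matching docs/source/*/model_doc/convnext.md
+    ```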
+ """ + doc_source = REPO_PATH / "docs" / "source" + # find the doc files + doc_files = list(doc_source.glob(f"*/model_doc/{model_name}.md")) + if not doc_files: + # try again with "-" + doc_files = list(doc_source.glob(f"*/model_doc/{model_name.replace('_', '-')}.md")) + if not doc_files: + raise ValueError(f"No doc files found for {model_name}") + + base_doc_string = ( + f"## {fast_image_processor_name[:-4]}\n\n" f"[[autodoc]] {fast_image_processor_name[:-4]}\n" " - preprocess" + ) + fast_doc_string = ( + f"## {fast_image_processor_name}\n\n" f"[[autodoc]] {fast_image_processor_name}\n" " - preprocess" + ) + + for doc_file in doc_files: + with open(doc_file, "r", encoding="utf-8") as f: + content = f.read() + + if fast_doc_string not in content: + # add the fast image processor to the doc + updated_content = content.replace( + base_doc_string, + base_doc_string + "\n\n" + fast_doc_string, + ) + + # write the updated content + with open(doc_file, "w", encoding="utf-8") as f: + f.write(updated_content) + + +def add_fast_image_processor_to_tests(fast_image_processor_name: str, model_name: str): + """ + Add the fast image processor to the image processing tests. + """ + tests_path = REPO_PATH / "tests" / "models" / model_name + test_file = tests_path / f"test_image_processing_{model_name}.py" + if not os.path.exists(test_file): + logger.warning(f"No test file found for {model_name}. Skipping.") + return + + with open(test_file, "r", encoding="utf-8") as f: + content = f.read() + + # add is_torchvision_available import to the imports + # Regex to match import statements from transformers.utils + pattern = r""" + from\s+transformers\.utils\s+import\s+ + (?: # Non-capturing group for either: + ([\w, ]+) # 1. Single-line imports (e.g., 'a, b') + | # OR + \((.*?)\) # 2. Multi-line imports (e.g., '(a, ... b)') + ) + """ + regex = re.compile(pattern, re.VERBOSE | re.DOTALL) + + def replacement_function(match): + # Extract existing imports + existing_imports = (match.group(1) or match.group(2)).split(",") + existing_imports = existing_imports[:-1] if existing_imports[-1] == "\n" else existing_imports + existing_imports = [imp.strip() for imp in existing_imports] + + # Add the new import if not already present + if "is_torchvision_available" not in existing_imports: + existing_imports.append("is_torchvision_available") + existing_imports.sort() + + # Rebuild the import statement + if match.group(1): # Single-line import + updated_imports = ", ".join(existing_imports) + else: # Multi-line import + updated_imports = "(\n " + ",\n ".join(existing_imports) + ",\n)" + + return f"from transformers.utils import {updated_imports}" + + # Replace all matches in the file content + updated_content = regex.sub(replacement_function, content) + + # add the fast image processor to the imports + base_import_string = f" from transformers import {fast_image_processor_name[:-4]}" + fast_import_string = ( + " if is_torchvision_available():\n" f" from transformers import {fast_image_processor_name}" + ) + if fast_import_string not in updated_content: + updated_content = updated_content.replace(base_import_string, base_import_string + "\n\n" + fast_import_string) + + # get line starting with " image_processing_class = " and add a line after it starting with " fast_image_processing_class = " + image_processing_class_line = re.search(r" image_processing_class = .*", updated_content) + if not image_processing_class_line: + logger.warning(f"Couldn't find the 'image_processing_class' line in {test_file}. 
Skipping.") + return + + fast_image_processing_class_line = ( + f" fast_image_processing_class = {fast_image_processor_name} if is_torchvision_available() else None" + ) + if " fast_image_processing_class = " not in updated_content: + updated_content = updated_content.replace( + image_processing_class_line.group(0), + image_processing_class_line.group(0) + "\n" + fast_image_processing_class_line, + ) + + # write the updated content + with open(test_file, "w", encoding="utf-8") as f: + f.write(updated_content) + + +def get_fast_image_processing_content_header(content: str) -> str: + """ + Get the header of the slow image processor file. + """ + # get all lines before and including the line containing """Image processor + content_header = re.search(r"^(.*?\n)*?\"\"\"Image processor.*", content) + content_header = content_header.group(0) + content_header = re.sub(r"# Copyright (\d+)\s", f"# Copyright {CURRENT_YEAR} ", content_header) + content_header = content_header.replace("Image processor", "Fast Image processor") + return content_header + + +def write_default_fast_image_processor_file( + fast_image_processing_module_file: str, fast_image_processor_name: str, content_base_file: str +): + """ + Write a default fast image processor file. Used when encountering a problem while parsing the slow image processor file. + """ + imports = "\n\nfrom ...image_processing_utils_fast import BaseImageProcessorFast\n\n\n" + content_header = get_fast_image_processing_content_header(content_base_file) + content_base_file = ( + f"class {fast_image_processor_name}(BaseImageProcessorFast):\n" + " # To be implemented\n" + " resample = None\n" + " image_mean = None\n" + " image_std = None\n" + " size = None\n" + " default_to_square = None\n" + " crop_size = None\n" + " do_resize = None\n" + " do_center_crop = None\n" + " do_rescale = None\n" + " do_normalize = None\n" + " do_convert_rgb = None\n\n\n" + f'__all__ = ["{fast_image_processor_name}"]\n' + ) + + content = content_header + imports + content_base_file + + with open(fast_image_processing_module_file, "w", encoding="utf-8") as f: + f.write(content) + + +def add_fast_image_processor_file( + fast_image_processing_module_file: str, fast_image_processor_name: str, content_base_file: str +): + """ + Add the fast image processor file to the model's folder. + """ + # if the file already exists, do nothing + if os.path.exists(fast_image_processing_module_file): + print(f"{fast_image_processing_module_file} already exists. 
Skipping.") + return + + regex = rf"class {fast_image_processor_name[:-4]}.*?(\n\S|$)" + match = re.search(regex, content_base_file, re.DOTALL) + if not match: + print(f"Couldn't find the {fast_image_processor_name[:-4]} class in {fast_image_processing_module_file}") + print("Creating a new file with the default content.") + return write_default_fast_image_processor_file( + fast_image_processing_module_file, fast_image_processor_name, content_base_file + ) + # Exclude the last unindented line + slow_class_content = match.group(0).rstrip() + # get default args: + # find the __init__ block which start with def __init__ and ends with def + match = re.search(r"def __init__.*?def ", slow_class_content, re.DOTALL) + if not match: + print( + f"Couldn't find the __init__ block for {fast_image_processor_name[:-4]} in {fast_image_processing_module_file}" + ) + print("Creating a new file with the default content.") + return write_default_fast_image_processor_file( + fast_image_processing_module_file, fast_image_processor_name, content_base_file + ) + init = match.group(0) + init_signature_block = init.split(")")[0] + arg_names = init_signature_block.split(":") + arg_names = [arg_name.split("\n")[-1].strip() for arg_name in arg_names] + # get the default values + default_args = re.findall(r"= (.*?)(?:,|\))", init_signature_block) + + # build default args dict + default_args_dict = dict(zip(arg_names, default_args)) + pattern_default_size = r"size = size if size is not None else\s+(.*)" + match_default_size = re.findall(pattern_default_size, init) + default_args_dict["size"] = match_default_size[0] if match_default_size else None + pattern_default_crop_size = r"crop_size = crop_size if crop_size is not None else\s+(.*)" + match_default_crop_size = re.findall(pattern_default_crop_size, init) + default_args_dict["crop_size"] = match_default_crop_size[0] if match_default_crop_size else None + pattern_default_image_mean = r"self.image_mean = image_mean if image_mean is not None else\s+(.*)" + match_default_image_mean = re.findall(pattern_default_image_mean, init) + default_args_dict["image_mean"] = match_default_image_mean[0] if match_default_image_mean else None + pattern_default_image_std = r"self.image_std = image_std if image_std is not None else\s+(.*)" + match_default_image_std = re.findall(pattern_default_image_std, init) + default_args_dict["image_std"] = match_default_image_std[0] if match_default_image_std else None + default_args_dict["default_to_square"] = False if "(size, default_to_square=False" in init else None + + content_header = get_fast_image_processing_content_header(content_base_file) + class_docstring = DEFAULT_CLASS_DOCSTRING.format( + model_name=fast_image_processor_name.replace("ImageProcessorFast", "") + ) + content_base_file = ( + f"class {fast_image_processor_name}(BaseImageProcessorFast):\n" + f" {class_docstring}\n\n" + " # This generated class can be used as a starting point for the fast image processor.\n" + " # if the image processor is only used for simple augmentations, such as resizing, center cropping, rescaling, or normalizing,\n" + " # only the default values should be set in the class.\n" + " # If the image processor requires more complex augmentations, methods from BaseImageProcessorFast can be overridden.\n" + " # For an example of a fast image processor requiring more complex augmentations, see `LlavaOnevisionImageProcessorFast`.\n\n" + " # Default values should be checked against the slow image processor\n" + " # None values left after checking can be removed\n" 
+ f' resample = {default_args_dict.get("resample")}\n' + f' image_mean = {default_args_dict.get("image_mean")}\n' + f' image_std = {default_args_dict.get("image_std")}\n' + f' size = {default_args_dict.get("size")}\n' + f' default_to_square = {default_args_dict.get("default_to_square")}\n' + f' crop_size = {default_args_dict.get("crop_size")}\n' + f' do_resize = {default_args_dict.get("do_resize")}\n' + f' do_center_crop = {default_args_dict.get("do_center_crop")}\n' + f' do_rescale = {default_args_dict.get("do_rescale")}\n' + f' do_normalize = {default_args_dict.get("do_normalize")}\n' + f' do_convert_rgb = {default_args_dict.get("do_convert_rgb")}\n\n\n' + f'__all__ = ["{fast_image_processor_name}"]\n' + ) + + imports = "\n\nfrom ...image_processing_utils_fast import BaseImageProcessorFast\n" + image_utils_imports = [] + if default_args_dict.get("resample") is not None and "PILImageResampling" in default_args_dict.get("resample"): + image_utils_imports.append("PILImageResampling") + if default_args_dict.get("image_mean") is not None and not any( + char.isdigit() for char in default_args_dict.get("image_mean") + ): + image_utils_imports.append(default_args_dict.get("image_mean")) + if default_args_dict.get("image_std") is not None and not any( + char.isdigit() for char in default_args_dict.get("image_std") + ): + image_utils_imports.append(default_args_dict.get("image_std")) + + if image_utils_imports: + # sort imports + image_utils_imports.sort() + imports += f"from ...image_utils import {', '.join(image_utils_imports)}\n" + + content = content_header + imports + "\n\n" + content_base_file + + with open(fast_image_processing_module_file, "w", encoding="utf-8") as f: + f.write(content) + + +def add_fast_image_processor(model_name: str): + """ + Add the necessary references to the fast image processor in the transformers package, + and create the fast image processor file in the model's folder. 
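+
+    Example (illustrative; equivalent to running the new CLI subcommand
+    `transformers-cli add-fast-image-processor --model-name deit`):
+
+    ```python
+    add_fast_image_processor(model_name="deit")
+    # creates models/deit/image_processing_deit_fast.py and references DeiTImageProcessorFast
+    # in the main and model inits, the auto mapping, the dummy objects, the docs and the tests
+    ```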
+ """ + model_module = TRANSFORMERS_PATH / "models" / model_name + image_processing_module_file = list(model_module.glob("image_processing*.py")) + if not image_processing_module_file: + raise ValueError(f"No image processing module found in {model_module}") + elif len(image_processing_module_file) > 1: + for file_name in image_processing_module_file: + if not str(file_name).endswith("_fast.py"): + image_processing_module_file = str(file_name) + break + else: + image_processing_module_file = str(image_processing_module_file[0]) + + with open(image_processing_module_file, "r", encoding="utf-8") as f: + content_base_file = f.read() + + # regex to find object starting with "class " and ending with "ImageProcessor", including "ImageProcessor" in the match + image_processor_name = re.findall(r"class (\w*ImageProcessor)", content_base_file) + if not image_processor_name: + raise ValueError(f"No ImageProcessor class found in {image_processing_module_file}") + elif len(image_processor_name) > 1: + raise ValueError(f"Multiple ImageProcessor classes found in {image_processing_module_file}") + + image_processor_name = image_processor_name[0] + fast_image_processor_name = image_processor_name + "Fast" + fast_image_processing_module_file = image_processing_module_file.replace(".py", "_fast.py") + + print(f"Adding {fast_image_processor_name} to {fast_image_processing_module_file}") + + add_fast_image_processor_to_main_init( + fast_image_processor_name=fast_image_processor_name, + model_name=model_name, + ) + + add_fast_image_processor_to_model_init( + fast_image_processing_module_file=fast_image_processing_module_file, + fast_image_processor_name=fast_image_processor_name, + model_name=model_name, + ) + + add_fast_image_processor_to_auto( + image_processor_name=image_processor_name, + fast_image_processor_name=fast_image_processor_name, + ) + + add_fast_image_processor_to_dummy(fast_image_processor_name=fast_image_processor_name) + + add_fast_image_processor_to_doc( + fast_image_processor_name=fast_image_processor_name, + model_name=model_name, + ) + + add_fast_image_processor_to_tests( + fast_image_processor_name=fast_image_processor_name, + model_name=model_name, + ) + + add_fast_image_processor_file( + fast_image_processing_module_file=fast_image_processing_module_file, + fast_image_processor_name=fast_image_processor_name, + content_base_file=content_base_file, + ) + + +def add_new_model_like_command_factory(args: Namespace): + return AddFastImageProcessorCommand(model_name=args.model_name) + + +class AddFastImageProcessorCommand(BaseTransformersCLICommand): + @staticmethod + def register_subcommand(parser: ArgumentParser): + add_fast_image_processor_parser = parser.add_parser("add-fast-image-processor") + add_fast_image_processor_parser.add_argument( + "--model-name", + type=str, + required=True, + help="The name of the folder containing the model's implementation.", + ) + add_fast_image_processor_parser.set_defaults(func=add_new_model_like_command_factory) + + def __init__(self, model_name: str, *args): + self.model_name = model_name + + def run(self): + add_fast_image_processor(model_name=self.model_name) diff --git a/src/transformers/commands/transformers_cli.py b/src/transformers/commands/transformers_cli.py index 6e8cfea0c3141a..61e30086f7c8b1 100644 --- a/src/transformers/commands/transformers_cli.py +++ b/src/transformers/commands/transformers_cli.py @@ -15,6 +15,7 @@ from argparse import ArgumentParser +from .add_fast_image_processor import AddFastImageProcessorCommand from 
.add_new_model_like import AddNewModelLikeCommand from .convert import ConvertCommand from .download import DownloadCommand @@ -40,6 +41,7 @@ def main(): AddNewModelLikeCommand.register_subcommand(commands_parser) LfsCommands.register_subcommand(commands_parser) PTtoTFCommand.register_subcommand(commands_parser) + AddFastImageProcessorCommand.register_subcommand(commands_parser) # Let's go args = parser.parse_args() diff --git a/src/transformers/image_processing_utils.py b/src/transformers/image_processing_utils.py index 0279f26a963e35..59aea9b8a5a8d7 100644 --- a/src/transformers/image_processing_utils.py +++ b/src/transformers/image_processing_utils.py @@ -13,13 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +import math from typing import Dict, Iterable, Optional, Union import numpy as np from .image_processing_base import BatchFeature, ImageProcessingMixin from .image_transforms import center_crop, normalize, rescale -from .image_utils import ChannelDimension +from .image_utils import ChannelDimension, get_image_size from .utils import logging @@ -285,3 +286,23 @@ def select_best_resolution(original_size: tuple, possible_resolutions: list) -> best_fit = (height, width) return best_fit + + +def get_patch_output_size(image, target_resolution, input_data_format): + """ + Given an image and a target resolution, calculate the output size of the image after cropping to the target + """ + original_height, original_width = get_image_size(image, channel_dim=input_data_format) + target_height, target_width = target_resolution + + scale_w = target_width / original_width + scale_h = target_height / original_height + + if scale_w < scale_h: + new_width = target_width + new_height = min(math.ceil(original_height * scale_w), target_height) + else: + new_height = target_height + new_width = min(math.ceil(original_width * scale_h), target_width) + + return new_height, new_width diff --git a/src/transformers/image_processing_utils_fast.py b/src/transformers/image_processing_utils_fast.py index 3c1be325b7eb30..e5d3ae8377e1dc 100644 --- a/src/transformers/image_processing_utils_fast.py +++ b/src/transformers/image_processing_utils_fast.py @@ -13,94 +13,57 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import functools -from dataclasses import dataclass -from typing import Any, Iterable, List, Optional, Tuple +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union -from .image_processing_utils import BaseImageProcessor -from .utils.import_utils import is_torch_available, is_torchvision_available +import numpy as np +from .image_processing_utils import ( + BaseImageProcessor, + BatchFeature, + get_size_dict, +) +from .image_transforms import ( + convert_to_rgb, + get_resize_output_image_size, + get_size_with_aspect_ratio, + group_images_by_shape, + reorder_images, +) +from .image_utils import ( + ChannelDimension, + ImageInput, + ImageType, + SizeDict, + get_image_size, + get_image_size_for_max_height_width, + get_image_type, + infer_channel_dimension_format, + make_list_of_images, + validate_fast_preprocess_arguments, + validate_kwargs, +) +from .utils import ( + TensorType, + is_torch_available, + is_torchvision_available, + is_torchvision_v2_available, + is_vision_available, +) -if is_torchvision_available(): - from torchvision.transforms import Compose + +if is_vision_available(): + from .image_utils import PILImageResampling if is_torch_available(): import torch +if is_torchvision_available(): + from .image_utils import pil_torch_interpolation_mapping -@dataclass(frozen=True) -class SizeDict: - """ - Hashable dictionary to store image size information. - """ - - height: int = None - width: int = None - longest_edge: int = None - shortest_edge: int = None - max_height: int = None - max_width: int = None - - def __getitem__(self, key): - if hasattr(self, key): - return getattr(self, key) - raise KeyError(f"Key {key} not found in SizeDict.") - - -class BaseImageProcessorFast(BaseImageProcessor): - _transform_params = None - - def _build_transforms(self, **kwargs) -> "Compose": - """ - Given the input settings e.g. do_resize, build the image transforms. - """ - raise NotImplementedError - - def _validate_params(self, **kwargs) -> None: - for k, v in kwargs.items(): - if k not in self._transform_params: - raise ValueError(f"Invalid transform parameter {k}={v}.") - - @functools.lru_cache(maxsize=1) - def get_transforms(self, **kwargs) -> "Compose": - self._validate_params(**kwargs) - return self._build_transforms(**kwargs) - - def to_dict(self): - encoder_dict = super().to_dict() - encoder_dict.pop("_transform_params", None) - return encoder_dict - - -def get_image_size_for_max_height_width( - image_size: Tuple[int, int], - max_height: int, - max_width: int, -) -> Tuple[int, int]: - """ - Computes the output image size given the input image and the maximum allowed height and width. Keep aspect ratio. - Important, even if image_height < max_height and image_width < max_width, the image will be resized - to at least one of the edges be equal to max_height or max_width. - - For example: - - input_size: (100, 200), max_height: 50, max_width: 50 -> output_size: (25, 50) - - input_size: (100, 200), max_height: 200, max_width: 500 -> output_size: (200, 400) - - Args: - image_size (`Tuple[int, int]`): - The image to resize. - max_height (`int`): - The maximum allowed height. - max_width (`int`): - The maximum allowed width. 
- """ - height, width = image_size - height_scale = max_height / height - width_scale = max_width / width - min_scale = min(height_scale, width_scale) - new_height = int(height * min_scale) - new_width = int(width * min_scale) - return new_height, new_width + if is_torchvision_v2_available(): + from torchvision.transforms.v2 import functional as F + else: + from torchvision.transforms import functional as F def safe_squeeze(tensor: "torch.Tensor", axis: Optional[int] = None) -> "torch.Tensor": @@ -131,3 +94,556 @@ def get_max_height_width(images: List["torch.Tensor"]) -> Tuple[int]: _, max_height, max_width = max_across_indices([img.shape for img in images]) return (max_height, max_width) + + +def divide_to_patches( + image: Union[np.array, "torch.Tensor"], patch_size: int +) -> List[Union[np.array, "torch.Tensor"]]: + """ + Divides an image into patches of a specified size. + + Args: + image (`Union[np.array, "torch.Tensor"]`): + The input image. + patch_size (`int`): + The size of each patch. + Returns: + list: A list of Union[np.array, "torch.Tensor"] representing the patches. + """ + patches = [] + height, width = get_image_size(image, channel_dim=ChannelDimension.FIRST) + for i in range(0, height, patch_size): + for j in range(0, width, patch_size): + patch = image[:, i : i + patch_size, j : j + patch_size] + patches.append(patch) + + return patches + + +class BaseImageProcessorFast(BaseImageProcessor): + r""" + Constructs a fast base image processor. + + Args: + do_resize (`bool`, *optional*): + Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the + `do_resize` parameter in the `preprocess` method. + size (`dict`, *optional*): + Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess` + method. + resample (`PILImageResampling`, *optional*): + Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`. Can be + overridden by the `resample` parameter in the `preprocess` method. + do_center_crop (`bool`, *optional*, defaults to `True`): + Whether to center crop the image to the specified `crop_size`. Can be overridden by `do_center_crop` in the + `preprocess` method. + crop_size (`Dict[str, int]` *optional*, defaults to 224): + Size of the output image after applying `center_crop`. Can be overridden by `crop_size` in the `preprocess` + method. + do_rescale (`bool`, *optional*): + Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the + `do_rescale` parameter in the `preprocess` method. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. Only has an effect if `do_rescale` is set to `True`. Can be + overridden by the `rescale_factor` parameter in the `preprocess` method. + do_normalize (`bool`, *optional*): + Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess` + method. Can be overridden by the `do_normalize` parameter in the `preprocess` method. + image_mean (`float` or `List[float]`, *optional*): + Mean to use if normalizing the image. This is a float or list of floats the length of the number of + channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. Can be + overridden by the `image_mean` parameter in the `preprocess` method. + image_std (`float` or `List[float]`, *optional*): + Standard deviation to use if normalizing the image. 
This is a float or list of floats the length of the + number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. + Can be overridden by the `image_std` parameter in the `preprocess` method. + do_convert_rgb (`bool`, *optional*): + Whether to convert the image to RGB. + """ + + resample = None + image_mean = None + image_std = None + size = None + default_to_square = None + crop_size = None + do_resize = None + do_center_crop = None + do_rescale = None + do_normalize = None + do_convert_rgb = None + model_input_names = ["pixel_values"] + valid_extra_kwargs = ["default_to_square"] + + def __init__( + self, + do_resize: bool = None, + size: Dict[str, int] = None, + resample: Union["PILImageResampling", "F.InterpolationMode"] = None, + do_center_crop: bool = None, + crop_size: Dict[str, int] = None, + do_rescale: bool = None, + rescale_factor: Union[int, float] = 1 / 255, + do_normalize: bool = None, + image_mean: Union[float, List[float]] = None, + image_std: Union[float, List[float]] = None, + do_convert_rgb: bool = None, + **kwargs, + ) -> None: + size = size if size is not None else self.size + default_to_square = kwargs.pop( + "default_to_square", self.default_to_square if self.default_to_square is not None else True + ) + size = get_size_dict(size, default_to_square=default_to_square) if size is not None else None + crop_size = crop_size if crop_size is not None else self.crop_size + crop_size = get_size_dict(crop_size, param_name="crop_size") if crop_size is not None else None + + super().__init__(**kwargs) + self.do_resize = do_resize if do_resize is not None else self.do_resize + self.size = size if size is not None else self.size + self.resample = resample if resample is not None else self.resample + self.do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop + self.crop_size = crop_size if crop_size is not None else self.crop_size + self.do_rescale = do_rescale if do_rescale is not None else self.do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize if do_normalize is not None else self.do_normalize + self.image_mean = image_mean if image_mean is not None else self.image_mean + self.image_std = image_std if image_std is not None else self.image_std + self.do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb + + def resize( + self, + image: "torch.Tensor", + size: SizeDict, + interpolation: "F.InterpolationMode" = None, + **kwargs, + ) -> "torch.Tensor": + """ + Resize an image to `(size["height"], size["width"])`. + + Args: + image (`torch.Tensor`): + Image to resize. + size (`SizeDict`): + Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image. + resample (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`): + `InterpolationMode` filter to use when resizing the image e.g. `InterpolationMode.BICUBIC`. + + Returns: + `torch.Tensor`: The resized image. + """ + interpolation = interpolation if interpolation is not None else F.InterpolationMode.BILINEAR + if size.shortest_edge and size.longest_edge: + # Resize the image so that the shortest edge or the longest edge is of the given size + # while maintaining the aspect ratio of the original image. 
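+            # e.g. DETR-style sizing such as size={"shortest_edge": 800, "longest_edge": 1333}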
+ new_size = get_size_with_aspect_ratio( + image.size()[-2:], + size.shortest_edge, + size.longest_edge, + ) + elif size.shortest_edge: + new_size = get_resize_output_image_size( + image, + size=size.shortest_edge, + default_to_square=False, + input_data_format=ChannelDimension.FIRST, + ) + elif size.max_height and size.max_width: + new_size = get_image_size_for_max_height_width(image.size()[-2:], size.max_height, size.max_width) + elif size.height and size.width: + new_size = (size.height, size.width) + else: + raise ValueError( + "Size must contain 'height' and 'width' keys, or 'max_height' and 'max_width', or 'shortest_edge' key. Got" + f" {size}." + ) + return F.resize(image, new_size, interpolation=interpolation) + + def rescale( + self, + image: "torch.Tensor", + scale: float, + **kwargs, + ) -> "torch.Tensor": + """ + Rescale an image by a scale factor. image = image * scale. + + Args: + image (`torch.Tensor`): + Image to rescale. + scale (`float`): + The scaling factor to rescale pixel values by. + + Returns: + `torch.Tensor`: The rescaled image. + """ + return image * scale + + def normalize( + self, + image: "torch.Tensor", + mean: Union[float, Iterable[float]], + std: Union[float, Iterable[float]], + **kwargs, + ) -> "torch.Tensor": + """ + Normalize an image. image = (image - image_mean) / image_std. + + Args: + image (`torch.Tensor`): + Image to normalize. + mean (`torch.Tensor`, `float` or `Iterable[float]`): + Image mean to use for normalization. + std (`torch.Tensor`, `float` or `Iterable[float]`): + Image standard deviation to use for normalization. + + Returns: + `torch.Tensor`: The normalized image. + """ + return F.normalize(image, mean, std) + + def rescale_and_normalize( + self, + images: "torch.Tensor", + do_rescale: bool, + rescale_factor: float, + do_normalize: bool, + image_mean: Union[float, List[float]], + image_std: Union[float, List[float]], + ) -> "torch.Tensor": + """ + Rescale and normalize images. + """ + if do_rescale and do_normalize: + images = self.normalize(images.to(dtype=torch.float32), image_mean, image_std) + elif do_rescale: + images = images * rescale_factor + elif do_normalize: + images = self.normalize(images, image_mean, image_std) + + return images + + def center_crop( + self, + image: "torch.Tensor", + size: Dict[str, int], + **kwargs, + ) -> "torch.Tensor": + """ + Center crop an image to `(size["height"], size["width"])`. If the input size is smaller than `crop_size` along + any edge, the image is padded with 0's and then center cropped. + + Args: + image (`"torch.Tensor"`): + Image to center crop. + size (`Dict[str, int]`): + Size of the output image. + + Returns: + `torch.Tensor`: The center cropped image. + """ + if size.height is None or size.width is None: + raise ValueError(f"The size dictionary must have keys 'height' and 'width'. Got {size.keys()}") + return F.center_crop(image, (size["height"], size["width"])) + + def convert_to_rgb( + self, + image: ImageInput, + ) -> ImageInput: + """ + Converts an image to RGB format. Only converts if the image is of type PIL.Image.Image, otherwise returns the image + as is. + Args: + image (ImageInput): + The image to convert. + + Returns: + ImageInput: The converted image. + """ + return convert_to_rgb(image) + + def _prepare_images_structure( + self, + images: ImageInput, + ) -> ImageInput: + """ + Prepare the images structure for processing. + + Args: + images (`ImageInput`): + The input images to process. + + Returns: + `ImageInput`: The images with a valid nesting. 
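+
+        Example (illustrative; `pil_image`, `img_a` and `img_b` stand for arbitrary input images):
+
+        ```python
+        self._prepare_images_structure(pil_image)       # -> [pil_image]
+        self._prepare_images_structure([img_a, img_b])  # -> [img_a, img_b]
+        ```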
+ """ + return make_list_of_images(images) + + def _prepare_input_images( + self, + images: ImageInput, + do_convert_rgb: bool = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + device: Optional["torch.device"] = None, + ) -> List["torch.Tensor"]: + """ + Prepare the input images for processing. + """ + images = self._prepare_images_structure(images) + image_type = get_image_type(images[0]) + if image_type not in [ImageType.PIL, ImageType.TORCH, ImageType.NUMPY]: + raise ValueError(f"Unsupported input image type {image_type}") + + if do_convert_rgb: + images = [self.convert_to_rgb(image) for image in images] + + if image_type == ImageType.PIL: + images = [F.pil_to_tensor(image) for image in images] + elif image_type == ImageType.NUMPY: + # not using F.to_tensor as it doesn't handle (C, H, W) numpy arrays + images = [torch.from_numpy(image).contiguous() for image in images] + + # Now that we have torch tensors, we can move them to the right device + if device is not None: + images = [image.to(device) for image in images] + + # We assume that all images have the same channel dimension format. + if input_data_format is None: + input_data_format = infer_channel_dimension_format(images[0]) + if input_data_format == ChannelDimension.LAST: + # We force the channel dimension to be first for torch tensors as this is what torchvision expects. + images = [image.permute(2, 0, 1).contiguous() for image in images] + + return images + + def _prepare_process_arguments( + self, + do_resize: bool = None, + size: Dict[str, int] = None, + resample: Optional[Union["PILImageResampling", "F.InterpolationMode"]] = None, + crop_size: int = None, + do_rescale: bool = None, + rescale_factor: float = None, + do_normalize: bool = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, + device: Optional["torch.device"] = None, + ) -> tuple: + """ + Prepare the arguments for the process method. 
+ """ + # Make hashable for cache + size = SizeDict(**size) if size is not None else None + crop_size = SizeDict(**crop_size) if crop_size is not None else None + image_mean = tuple(image_mean) if isinstance(image_mean, list) else image_mean + image_std = tuple(image_std) if isinstance(image_std, list) else image_std + + validate_fast_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_resize=do_resize, + size=size, + resample=resample, + return_tensors=return_tensors, + data_format=data_format, + ) + + if do_rescale and do_normalize: + # Fused rescale and normalize + image_mean = torch.tensor(image_mean, device=device) * (1.0 / rescale_factor) + image_std = torch.tensor(image_std, device=device) * (1.0 / rescale_factor) + + interpolation = ( + pil_torch_interpolation_mapping[resample] if isinstance(resample, (PILImageResampling, int)) else resample + ) + + return image_mean, image_std, size, crop_size, interpolation + + def preprocess( + self, + images: ImageInput, + do_resize: bool = None, + size: Dict[str, int] = None, + resample: Optional[Union["PILImageResampling", "F.InterpolationMode"]] = None, + do_center_crop: bool = None, + crop_size: int = None, + do_rescale: bool = None, + rescale_factor: float = None, + do_normalize: bool = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_convert_rgb: bool = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + device: Optional["torch.device"] = None, + **kwargs, + ) -> BatchFeature: + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`Dict[str, int]`, *optional*, defaults to `self.size`): + Describes the maximum input dimensions to the model. + resample (`PILImageResampling` or `InterpolationMode`, *optional*, defaults to self.resample): + Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only + has an effect if `do_resize` is set to `True`. + do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`): + Whether to center crop the image. + crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`): + Size of the output image after applying `center_crop`. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): + Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`. + image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to + `True`. 
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): + Whether to convert the image to RGB. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Default to `"pt"` for PyTorch tensors if unset. + Fast image processors only support PyTorch tensors. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + device (`torch.device`, *optional*): + The device to process the images on. If unset, the device is inferred from the input images. + """ + validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self.valid_extra_kwargs) + + do_resize = do_resize if do_resize is not None else self.do_resize + size = size if size is not None else self.size + default_to_square = kwargs.pop( + "default_to_square", self.default_to_square if self.default_to_square is not None else True + ) + size = get_size_dict(size=size, default_to_square=default_to_square) if size is not None else None + resample = resample if resample is not None else self.resample + do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop + crop_size = crop_size if crop_size is not None else self.crop_size + crop_size = get_size_dict(crop_size, param_name="crop_size") if crop_size is not None else None + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb + + images = self._prepare_input_images( + images=images, + do_convert_rgb=do_convert_rgb, + input_data_format=input_data_format, + device=device, + ) + + image_mean, image_std, size, crop_size, interpolation = self._prepare_process_arguments( + do_resize=do_resize, + size=size, + resample=resample, + crop_size=crop_size, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + return_tensors=return_tensors, + data_format=data_format, + device=images[0].device, + ) + + # Group images by size for batched resizing + grouped_images, grouped_images_index = group_images_by_shape(images) + resized_images_grouped = {} + for shape, stacked_images in grouped_images.items(): + if do_resize: + stacked_images = self.resize(image=stacked_images, size=size, interpolation=interpolation) + resized_images_grouped[shape] = stacked_images + resized_images = reorder_images(resized_images_grouped, 
grouped_images_index) + + # Group images by size for further processing + # Needed in case do_resize is False, or resize returns images with different sizes + grouped_images, grouped_images_index = group_images_by_shape(resized_images) + processed_images_grouped = {} + for shape, stacked_images in grouped_images.items(): + if do_center_crop: + stacked_images = self.center_crop(stacked_images, crop_size) + # Fused rescale and normalize + stacked_images = self.rescale_and_normalize( + stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std + ) + processed_images_grouped[shape] = stacked_images + + processed_images = reorder_images(processed_images_grouped, grouped_images_index) + processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images + + return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors) + + def to_dict(self): + encoder_dict = super().to_dict() + encoder_dict.pop("_valid_processor_keys", None) + return encoder_dict + + +class SemanticSegmentationMixin: + def post_process_semantic_segmentation(self, outputs, target_sizes: List[Tuple] = None): + """ + Converts the output of [`MobileNetV2ForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch. + + Args: + outputs ([`MobileNetV2ForSemanticSegmentation`]): + Raw outputs of the model. + target_sizes (`List[Tuple]` of length `batch_size`, *optional*): + List of tuples corresponding to the requested final size (height, width) of each prediction. If unset, + predictions will not be resized. + + Returns: + semantic_segmentation: `List[torch.Tensor]` of length `batch_size`, where each item is a semantic + segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is + specified). Each entry of each `torch.Tensor` correspond to a semantic class id. 
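The "fused rescale and normalize" step used in the loop above works because dividing the mean and std by the rescale factor folds the `image * rescale_factor` multiplication into the normalization itself. A minimal standalone check of that identity, with made-up values:

```python
import torch

x = torch.randint(0, 256, (3, 4, 4), dtype=torch.uint8).float()  # raw pixel values
s = 1 / 255                                                       # rescale factor
mean = torch.tensor([0.5, 0.5, 0.5]).view(-1, 1, 1)
std = torch.tensor([0.3, 0.3, 0.3]).view(-1, 1, 1)

rescale_then_normalize = (x * s - mean) / std  # the naive two-step pipeline
fused = (x - mean / s) / (std / s)             # what the fused fast path computes

assert torch.allclose(rescale_then_normalize, fused, atol=1e-5)
```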
+ """ + logits = outputs.logits + + # Resize logits and compute semantic segmentation maps + if target_sizes is not None: + if len(logits) != len(target_sizes): + raise ValueError( + "Make sure that you pass in as many target sizes as the batch dimension of the logits" + ) + + # if is_torch_tensor(target_sizes): + # target_sizes = target_sizes.numpy() + + semantic_segmentation = [] + + for idx in range(len(logits)): + resized_logits = torch.nn.functional.interpolate( + logits[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False + ) + semantic_map = resized_logits[0].argmax(dim=0) + semantic_segmentation.append(semantic_map) + else: + semantic_segmentation = logits.argmax(dim=1) + semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])] + + return semantic_segmentation diff --git a/src/transformers/image_transforms.py b/src/transformers/image_transforms.py index e7d3a5abb7a8db..e2ed98a3a01488 100644 --- a/src/transformers/image_transforms.py +++ b/src/transformers/image_transforms.py @@ -15,15 +15,17 @@ import warnings from math import ceil -from typing import Iterable, List, Optional, Tuple, Union +from typing import Dict, Iterable, List, Optional, Tuple, Union import numpy as np from .image_utils import ( ChannelDimension, ImageInput, + SizeDict, get_channel_dimension_axis, get_image_size, + get_image_size_for_max_height_width, infer_channel_dimension_format, ) from .utils import ExplicitEnum, TensorType, is_jax_tensor, is_tf_tensor, is_torch_tensor @@ -216,6 +218,45 @@ def to_pil_image( return PIL.Image.fromarray(image, mode=image_mode) +def get_size_with_aspect_ratio(image_size, size, max_size=None) -> Tuple[int, int]: + """ + Computes the output image size given the input image size and the desired output size. + + Args: + image_size (`Tuple[int, int]`): + The input image size. + size (`int`): + The desired output size. + max_size (`int`, *optional*): + The maximum allowed output size. + """ + height, width = image_size + raw_size = None + if max_size is not None: + min_original_size = float(min((height, width))) + max_original_size = float(max((height, width))) + if max_original_size / min_original_size * size > max_size: + raw_size = max_size * min_original_size / max_original_size + size = int(round(raw_size)) + + if (height <= width and height == size) or (width <= height and width == size): + oh, ow = height, width + elif width < height: + ow = size + if max_size is not None and raw_size is not None: + oh = int(raw_size * height / width) + else: + oh = int(size * height / width) + else: + oh = size + if max_size is not None and raw_size is not None: + ow = int(raw_size * width / height) + else: + ow = int(size * width / height) + + return (oh, ow) + + # Logic adapted from torchvision resizing logic: https://github.com/pytorch/vision/blob/511924c1ced4ce0461197e5caa64ce5b9e558aab/torchvision/transforms/functional.py#L366 def get_resize_output_image_size( input_image: np.ndarray, @@ -821,32 +862,37 @@ def _cast_tensor_to_float(x): return x.float() -class FusedRescaleNormalize: +def group_images_by_shape( + images: List["torch.Tensor"], +) -> Tuple[Dict[Tuple[int, int], List["torch.Tensor"]], Dict[int, Tuple[Tuple[int, int], int]]]: """ - Rescale and normalize the input image in one step. + Groups images by shape. 
+ Returns a dictionary with the shape as key and a list of images with that shape as value, + and a dictionary with the index of the image in the original list as key and the shape and index in the grouped list as value. """ - - def __init__(self, mean, std, rescale_factor: float = 1.0, inplace: bool = False): - self.mean = torch.tensor(mean) * (1.0 / rescale_factor) - self.std = torch.tensor(std) * (1.0 / rescale_factor) - self.inplace = inplace - - def __call__(self, image: "torch.Tensor"): - image = _cast_tensor_to_float(image) - return F.normalize(image, self.mean, self.std, inplace=self.inplace) - - -class Rescale: + grouped_images = {} + grouped_images_index = {} + for i, image in enumerate(images): + shape = image.shape[1:] + if shape not in grouped_images: + grouped_images[shape] = [] + grouped_images[shape].append(image) + grouped_images_index[i] = (shape, len(grouped_images[shape]) - 1) + # stack images with the same shape + grouped_images = {shape: torch.stack(images, dim=0) for shape, images in grouped_images.items()} + return grouped_images, grouped_images_index + + +def reorder_images( + processed_images: Dict[Tuple[int, int], "torch.Tensor"], grouped_images_index: Dict[int, Tuple[int, int]] +) -> List["torch.Tensor"]: """ - Rescale the input image by rescale factor: image *= rescale_factor. + Reconstructs a list of images in the original order. """ - - def __init__(self, rescale_factor: float = 1.0): - self.rescale_factor = rescale_factor - - def __call__(self, image: "torch.Tensor"): - image = image * self.rescale_factor - return image + return [ + processed_images[grouped_images_index[i][0]][grouped_images_index[i][1]] + for i in range(len(grouped_images_index)) + ] class NumpyToTensor: @@ -858,3 +904,158 @@ def __call__(self, image: np.ndarray): # Same as in PyTorch, we assume incoming numpy images are in HWC format # c.f. https://github.com/pytorch/vision/blob/61d97f41bc209e1407dcfbd685d2ee2da9c1cdad/torchvision/transforms/functional.py#L154 return torch.from_numpy(image.transpose(2, 0, 1)).contiguous() + + +if is_torch_available(): + + class GroupByShape(torch.nn.Module): + """ + Groups images by shape. + Returns a dictionary with the shape as key and a list of images with that shape as value, + and a dictionary with the index of the image in the original list as key and the shape and index in the grouped list as value. + """ + + def __init__(self): + super().__init__() + + def forward(self, images: List["torch.Tensor"]): + grouped_images, grouped_images_index = group_images_by_shape(images) + return {"grouped_images": grouped_images, "grouped_images_index": grouped_images_index} + + class Resize(torch.nn.Module): + """ + Resize the input images to the specified size. + The input images can be a torch.Tensor or images grouped by shape. + See `GroupByShape` for more information on grouping images by shape. + """ + + def __init__( + self, + size: SizeDict, + interpolation: "F.InterpolationMode" = None, + ): + super().__init__() + self.size = size + self.interpolation = interpolation if interpolation is not None else F.InterpolationMode.BILINEAR + + def forward(self, images: Union["torch.Tensor", dict]): + def _get_size_tuple(image_group: "torch.Tensor", size: SizeDict): + if size.shortest_edge and size.longest_edge: + # Resize the image so that the shortest edge or the longest edge is of the given size + # while maintaining the aspect ratio of the original image. 
+ size_tuple = get_size_with_aspect_ratio( + image_group.size()[-2:], + size.shortest_edge, + size.longest_edge, + ) + elif size.shortest_edge: + size_tuple = get_resize_output_image_size( + image_group, + size=size.shortest_edge, + default_to_square=False, + input_data_format=ChannelDimension.FIRST, + ) + elif size.max_height and size.max_width: + size_tuple = get_image_size_for_max_height_width( + image_group.size()[-2:], size.max_height, size.max_width + ) + elif size.height and size.width: + size_tuple = (size.height, size.width) + else: + raise ValueError( + "Size must contain 'height' and 'width' keys, or 'max_height' and 'max_width', or 'shortest_edge' key. Got" + f" {size}." + ) + return size_tuple + + if isinstance(images, dict): + grouped_images = images["grouped_images"] + grouped_images_index = images["grouped_images_index"] + resized_images = {} + for shape, image_group in grouped_images.items(): + resized_images[shape] = F.resize( + image_group, size=_get_size_tuple(image_group, self.size), interpolation=self.interpolation + ) + return {"grouped_images": resized_images, "grouped_images_index": grouped_images_index} + elif isinstance(images, torch.Tensor): + return F.resize(images, size=_get_size_tuple(images, self.size), interpolation=self.interpolation) + + raise ValueError( + "Inputs to Resize must be a list of torch.Tensor or a dictionary with 'grouped_images' and 'grouped_images_index' keys, got {images}." + ) + + class Normalize(torch.nn.Module): + def __init__(self, mean: Union[float, List[float]], std: Union[float, List[float]]): + super().__init__() + self.mean = mean + self.std = std + + def forward(self, images: Union["torch.Tensor", dict]): + if isinstance(images, dict): + grouped_images = images["grouped_images"] + grouped_images_index = images["grouped_images_index"] + normalized_images = {} + for shape, image_group in grouped_images.items(): + image_group = _cast_tensor_to_float(image_group) + normalized_images[shape] = F.normalize(image_group, mean=self.mean, std=self.std) + + return {"grouped_images": normalized_images, "grouped_images_index": grouped_images_index} + elif isinstance(images, torch.Tensor): + return F.normalize(_cast_tensor_to_float(images), mean=self.mean, std=self.std) + + raise ValueError( + f"Inputs to Normalize must be a list of torch.Tensor or a dictionary with 'grouped_images' and 'grouped_images_index' keys, got {images}." + ) + + class Rescale(torch.nn.Module): + def __init__(self, rescale_factor: float): + super().__init__() + self.rescale_factor = rescale_factor + + def forward(self, images: Union["torch.Tensor", dict]): + if isinstance(images, dict): + grouped_images = images["grouped_images"] + grouped_images_index = images["grouped_images_index"] + rescaled_images = {} + for shape, image_group in grouped_images.items(): + image_group = torch.stack(image_group, dim=0) + rescaled_images[shape] = image_group * self.rescale_factor + return {"grouped_images": rescaled_images, "grouped_images_index": grouped_images_index} + elif isinstance(images, torch.Tensor): + return images * self.rescale_factor + + raise ValueError( + f"Inputs to Rescale must be a list of torch.Tensor or a dictionary with 'grouped_images' and 'grouped_images_index' keys, got {images}." 
+ ) + + class CenterCrop(torch.nn.Module): + def __init__(self, size: Tuple[int, int]): + super().__init__() + self.size = size + + def forward(self, images: Union["torch.Tensor", dict]): + if isinstance(images, dict): + grouped_images = images["grouped_images"] + grouped_images_index = images["grouped_images_index"] + cropped_images = {} + for shape, image_group in grouped_images.items(): + cropped_images[shape] = F.center_crop(image_group, self.size) + return {"grouped_images": cropped_images, "grouped_images_index": grouped_images_index} + elif isinstance(images, torch.Tensor): + return F.center_crop(images, self.size) + + raise ValueError( + f"Inputs to CenterCrop must be a torch.Tensor or a dictionary with 'grouped_images' and 'grouped_images_index' keys, got {images}." + ) + + class ReorderImages(torch.nn.Module): + """ + Reorders images back to the original order. + This transform is used to reorder images after they have been grouped by shape. + """ + + def __init__(self): + super().__init__() + + def forward(self, images: dict): + return reorder_images(images["grouped_images"], images["grouped_images_index"]) diff --git a/src/transformers/image_utils.py b/src/transformers/image_utils.py index 51199d9f3698fc..2e1bbf841fb8e3 100644 --- a/src/transformers/image_utils.py +++ b/src/transformers/image_utils.py @@ -15,6 +15,7 @@ import base64 import os +from dataclasses import dataclass from io import BytesIO from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union @@ -209,6 +210,35 @@ def make_list_of_images(images, expected_ndims: int = 3) -> List[ImageInput]: ) +def make_batched_images(images) -> List[List[ImageInput]]: + """ + Accepts images in list or nested list format, and makes a list of images for preprocessing. + + Args: + images (`Union[List[List[ImageInput]], List[ImageInput], ImageInput]`): + The input image. + + Returns: + list: A list of images. + """ + if isinstance(images, (list, tuple)) and isinstance(images[0], (list, tuple)) and is_valid_image(images[0][0]): + return [img for img_list in images for img in img_list] + + elif isinstance(images, (list, tuple)) and is_valid_image(images[0]): + return images + + elif is_pil_image(images): + return [images] + + elif is_valid_image(images): + if len(images.shape) == 4: + return images + elif len(images.shape) == 3: + return [images] + + raise ValueError(f"Could not make batched video from {images}") + + def to_numpy_array(img) -> np.ndarray: if not is_valid_image(img): raise ValueError(f"Invalid image type: {type(img)}") @@ -303,6 +333,37 @@ def get_image_size(image: np.ndarray, channel_dim: ChannelDimension = None) -> T raise ValueError(f"Unsupported data format: {channel_dim}") +def get_image_size_for_max_height_width( + image_size: Tuple[int, int], + max_height: int, + max_width: int, +) -> Tuple[int, int]: + """ + Computes the output image size given the input image and the maximum allowed height and width. Keep aspect ratio. + Important, even if image_height < max_height and image_width < max_width, the image will be resized + to at least one of the edges be equal to max_height or max_width. + + For example: + - input_size: (100, 200), max_height: 50, max_width: 50 -> output_size: (25, 50) + - input_size: (100, 200), max_height: 200, max_width: 500 -> output_size: (200, 400) + + Args: + image_size (`Tuple[int, int]`): + The image to resize. + max_height (`int`): + The maximum allowed height. + max_width (`int`): + The maximum allowed width. 
+ """ + height, width = image_size + height_scale = max_height / height + width_scale = max_width / width + min_scale = min(height_scale, width_scale) + new_height = int(height * min_scale) + new_width = int(width * min_scale) + return new_height, new_width + + def is_valid_annotation_coco_detection(annotation: Dict[str, Union[List, Tuple]]) -> bool: if ( isinstance(annotation, dict) @@ -474,12 +535,16 @@ def validate_fast_preprocess_arguments( do_normalize=do_normalize, image_mean=image_mean, image_std=image_std, + do_pad=do_pad, + size_divisibility=size_divisibility, + do_center_crop=do_center_crop, + crop_size=crop_size, do_resize=do_resize, size=size, resample=resample, ) # Extra checks for ImageProcessorFast - if return_tensors != "pt": + if return_tensors is not None and return_tensors != "pt": raise ValueError("Only returning PyTorch tensors is currently supported.") if data_format != ChannelDimension.FIRST: @@ -869,3 +934,22 @@ def validate_kwargs(valid_processor_keys: List[str], captured_kwargs: List[str]) unused_key_str = ", ".join(unused_keys) # TODO raise a warning here instead of simply logging? logger.warning(f"Unused or unrecognized kwargs: {unused_key_str}.") + + +@dataclass(frozen=True) +class SizeDict: + """ + Hashable dictionary to store image size information. + """ + + height: int = None + width: int = None + longest_edge: int = None + shortest_edge: int = None + max_height: int = None + max_width: int = None + + def __getitem__(self, key): + if hasattr(self, key): + return getattr(self, key) + raise KeyError(f"Key {key} not found in SizeDict.") diff --git a/src/transformers/models/aria/image_processing_aria.py b/src/transformers/models/aria/image_processing_aria.py index 7b00665aa2859d..8e9c8e2272f979 100644 --- a/src/transformers/models/aria/image_processing_aria.py +++ b/src/transformers/models/aria/image_processing_aria.py @@ -31,7 +31,7 @@ PILImageResampling, get_image_size, infer_channel_dimension_format, - is_valid_image, + make_batched_images, to_numpy_array, valid_images, validate_preprocess_arguments, @@ -39,29 +39,6 @@ from ...utils import TensorType -def make_batched_images(images) -> List[List[ImageInput]]: - """ - Accepts images in list or nested list format, and makes a list of images for preprocessing. - - Args: - images (`Union[List[List[ImageInput]], List[ImageInput], ImageInput]`): - The input image. - - Returns: - list: A list of images. - """ - if isinstance(images, (list, tuple)) and isinstance(images[0], (list, tuple)) and is_valid_image(images[0][0]): - return [img for img_list in images for img in img_list] - - elif isinstance(images, (list, tuple)) and is_valid_image(images[0]): - return images - - elif is_valid_image(images): - return [images] - - raise ValueError(f"Could not make batched video from {images}") - - def divide_to_patches(image: np.array, patch_size: int, input_data_format) -> List[np.array]: """ Divides an image into patches of a specified size. 
diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index db25591eaa3544..bc9ecc4c4afc93 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -59,20 +59,20 @@ ("aria", ("AriaImageProcessor")), ("beit", ("BeitImageProcessor",)), ("bit", ("BitImageProcessor",)), - ("blip", ("BlipImageProcessor",)), - ("blip-2", ("BlipImageProcessor",)), + ("blip", ("BlipImageProcessor", "BlipImageProcessorFast")), + ("blip-2", ("BlipImageProcessor", "BlipImageProcessorFast")), ("bridgetower", ("BridgeTowerImageProcessor",)), ("chameleon", ("ChameleonImageProcessor",)), ("chinese_clip", ("ChineseCLIPImageProcessor",)), - ("clip", ("CLIPImageProcessor",)), + ("clip", ("CLIPImageProcessor", "CLIPImageProcessorFast")), ("clipseg", ("ViTImageProcessor", "ViTImageProcessorFast")), ("conditional_detr", ("ConditionalDetrImageProcessor",)), - ("convnext", ("ConvNextImageProcessor",)), - ("convnextv2", ("ConvNextImageProcessor",)), - ("cvt", ("ConvNextImageProcessor",)), + ("convnext", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")), + ("convnextv2", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")), + ("cvt", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")), ("data2vec-vision", ("BeitImageProcessor",)), ("deformable_detr", ("DeformableDetrImageProcessor", "DeformableDetrImageProcessorFast")), - ("deit", ("DeiTImageProcessor",)), + ("deit", ("DeiTImageProcessor", "DeiTImageProcessorFast")), ("depth_anything", ("DPTImageProcessor",)), ("deta", ("DetaImageProcessor",)), ("detr", ("DetrImageProcessor", "DetrImageProcessorFast")), @@ -85,26 +85,26 @@ ("flava", ("FlavaImageProcessor",)), ("focalnet", ("BitImageProcessor",)), ("fuyu", ("FuyuImageProcessor",)), - ("git", ("CLIPImageProcessor",)), + ("git", ("CLIPImageProcessor", "CLIPImageProcessorFast")), ("glpn", ("GLPNImageProcessor",)), ("grounding-dino", ("GroundingDinoImageProcessor",)), - ("groupvit", ("CLIPImageProcessor",)), + ("groupvit", ("CLIPImageProcessor", "CLIPImageProcessorFast")), ("hiera", ("BitImageProcessor",)), ("idefics", ("IdeficsImageProcessor",)), ("idefics2", ("Idefics2ImageProcessor",)), ("idefics3", ("Idefics3ImageProcessor",)), ("ijepa", ("ViTImageProcessor", "ViTImageProcessorFast")), ("imagegpt", ("ImageGPTImageProcessor",)), - ("instructblip", ("BlipImageProcessor",)), + ("instructblip", ("BlipImageProcessor", "BlipImageProcessorFast")), ("instructblipvideo", ("InstructBlipVideoImageProcessor",)), - ("kosmos-2", ("CLIPImageProcessor",)), + ("kosmos-2", ("CLIPImageProcessor", "CLIPImageProcessorFast")), ("layoutlmv2", ("LayoutLMv2ImageProcessor",)), ("layoutlmv3", ("LayoutLMv3ImageProcessor",)), ("levit", ("LevitImageProcessor",)), - ("llava", ("CLIPImageProcessor",)), - ("llava_next", ("LlavaNextImageProcessor",)), + ("llava", ("CLIPImageProcessor", "CLIPImageProcessorFast")), + ("llava_next", ("LlavaNextImageProcessor", "LlavaNextImageProcessorFast")), ("llava_next_video", ("LlavaNextVideoImageProcessor",)), - ("llava_onevision", ("LlavaOnevisionImageProcessor",)), + ("llava_onevision", ("LlavaOnevisionImageProcessor", "LlavaOnevisionImageProcessorFast")), ("mask2former", ("Mask2FormerImageProcessor",)), ("maskformer", ("MaskFormerImageProcessor",)), ("mgp-str", ("ViTImageProcessor", "ViTImageProcessorFast")), @@ -118,7 +118,7 @@ ("oneformer", ("OneFormerImageProcessor",)), ("owlv2", ("Owlv2ImageProcessor",)), ("owlvit", ("OwlViTImageProcessor",)), - ("paligemma", 
("SiglipImageProcessor",)), + ("paligemma", ("SiglipImageProcessor", "SiglipImageProcessorFast")), ("perceiver", ("PerceiverImageProcessor",)), ("pix2struct", ("Pix2StructImageProcessor",)), ("pixtral", ("PixtralImageProcessor", "PixtralImageProcessorFast")), @@ -126,13 +126,13 @@ ("pvt", ("PvtImageProcessor",)), ("pvt_v2", ("PvtImageProcessor",)), ("qwen2_vl", ("Qwen2VLImageProcessor",)), - ("regnet", ("ConvNextImageProcessor",)), - ("resnet", ("ConvNextImageProcessor",)), + ("regnet", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")), + ("resnet", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")), ("rt_detr", ("RTDetrImageProcessor", "RTDetrImageProcessorFast")), ("sam", ("SamImageProcessor",)), ("segformer", ("SegformerImageProcessor",)), ("seggpt", ("SegGptImageProcessor",)), - ("siglip", ("SiglipImageProcessor",)), + ("siglip", ("SiglipImageProcessor", "SiglipImageProcessorFast")), ("swiftformer", ("ViTImageProcessor", "ViTImageProcessorFast")), ("swin", ("ViTImageProcessor", "ViTImageProcessorFast")), ("swin2sr", ("Swin2SRImageProcessor",)), @@ -144,16 +144,16 @@ ("tvp", ("TvpImageProcessor",)), ("udop", ("LayoutLMv3ImageProcessor",)), ("upernet", ("SegformerImageProcessor",)), - ("van", ("ConvNextImageProcessor",)), + ("van", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")), ("videomae", ("VideoMAEImageProcessor",)), ("vilt", ("ViltImageProcessor",)), - ("vipllava", ("CLIPImageProcessor",)), + ("vipllava", ("CLIPImageProcessor", "CLIPImageProcessorFast")), ("vit", ("ViTImageProcessor", "ViTImageProcessorFast")), ("vit_hybrid", ("ViTHybridImageProcessor",)), ("vit_mae", ("ViTImageProcessor", "ViTImageProcessorFast")), ("vit_msn", ("ViTImageProcessor", "ViTImageProcessorFast")), ("vitmatte", ("VitMatteImageProcessor",)), - ("xclip", ("CLIPImageProcessor",)), + ("xclip", ("CLIPImageProcessor", "CLIPImageProcessorFast")), ("yolos", ("YolosImageProcessor",)), ("zoedepth", ("ZoeDepthImageProcessor",)), ] diff --git a/src/transformers/models/blip/__init__.py b/src/transformers/models/blip/__init__.py index 5443a3f6747aaa..1102af75d1164a 100644 --- a/src/transformers/models/blip/__init__.py +++ b/src/transformers/models/blip/__init__.py @@ -20,6 +20,7 @@ if TYPE_CHECKING: from .configuration_blip import * from .image_processing_blip import * + from .image_processing_blip_fast import * from .modeling_blip import * from .modeling_tf_blip import * from .processing_blip import * diff --git a/src/transformers/models/blip/image_processing_blip_fast.py b/src/transformers/models/blip/image_processing_blip_fast.py new file mode 100644 index 00000000000000..b333df70f0fa30 --- /dev/null +++ b/src/transformers/models/blip/image_processing_blip_fast.py @@ -0,0 +1,74 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Fast Image processor class for BLIP.""" + +from ...image_processing_utils_fast import BaseImageProcessorFast +from ...image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD, PILImageResampling + + +class BlipImageProcessorFast(BaseImageProcessorFast): + r""" + Constructs a fast BLIP image processor. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the + `do_resize` parameter in the `preprocess` method. + size (`dict`, *optional*, defaults to `{"height": 384, "width": 384}`): + Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess` + method. + resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`): + Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`. Can be + overridden by the `resample` parameter in the `preprocess` method. + do_center_crop (`bool`, *optional*, defaults to `True`): + Whether to center crop the image to the specified `crop_size`. Can be overridden by `do_center_crop` in the + `preprocess` method. + crop_size (`Dict[str, int]` *optional*, defaults to 224): + Size of the output image after applying `center_crop`. Can be overridden by `crop_size` in the `preprocess` + method. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the + `do_rescale` parameter in the `preprocess` method. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. Only has an effect if `do_rescale` is set to `True`. Can be + overridden by the `rescale_factor` parameter in the `preprocess` method. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess` + method. Can be overridden by the `do_normalize` parameter in the `preprocess` method. + image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`): + Mean to use if normalizing the image. This is a float or list of floats the length of the number of + channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. Can be + overridden by the `image_mean` parameter in the `preprocess` method. + image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`): + Standard deviation to use if normalizing the image. This is a float or list of floats the length of the + number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. + Can be overridden by the `image_std` parameter in the `preprocess` method. + do_convert_rgb (`bool`, *optional*, defaults to `True`): + Whether to convert the image to RGB. 
+ """ + + # To be checked against the slow image processor + # None values left after checking can be removed + resample = PILImageResampling.BICUBIC + image_mean = OPENAI_CLIP_MEAN + image_std = OPENAI_CLIP_STD + size = {"height": 384, "width": 384} + do_resize = True + do_rescale = True + do_normalize = True + do_convert_rgb = True + + +__all__ = ["BlipImageProcessorFast"] diff --git a/src/transformers/models/clip/__init__.py b/src/transformers/models/clip/__init__.py index f2c43e0b51d63b..18a4db32e9943d 100644 --- a/src/transformers/models/clip/__init__.py +++ b/src/transformers/models/clip/__init__.py @@ -21,6 +21,7 @@ from .configuration_clip import * from .feature_extraction_clip import * from .image_processing_clip import * + from .image_processing_clip_fast import * from .modeling_clip import * from .modeling_flax_clip import * from .modeling_tf_clip import * diff --git a/src/transformers/models/clip/image_processing_clip_fast.py b/src/transformers/models/clip/image_processing_clip_fast.py new file mode 100644 index 00000000000000..2b0447edd6f8b6 --- /dev/null +++ b/src/transformers/models/clip/image_processing_clip_fast.py @@ -0,0 +1,75 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fast Image processor class for CLIP.""" + +from ...image_processing_utils_fast import BaseImageProcessorFast +from ...image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD, PILImageResampling + + +class CLIPImageProcessorFast(BaseImageProcessorFast): + r""" + Constructs a fast CLIP image processor. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by + `do_resize` in the `preprocess` method. + size (`Dict[str, int]` *optional*, defaults to `{"shortest_edge": 224}`): + Size of the image after resizing. The shortest edge of the image is resized to size["shortest_edge"], with + the longest edge resized to keep the input aspect ratio. Can be overridden by `size` in the `preprocess` + method. + resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`): + Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method. + do_center_crop (`bool`, *optional*, defaults to `True`): + Whether to center crop the image to the specified `crop_size`. Can be overridden by `do_center_crop` in the + `preprocess` method. + crop_size (`Dict[str, int]` *optional*, defaults to 224): + Size of the output image after applying `center_crop`. Can be overridden by `crop_size` in the `preprocess` + method. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in + the `preprocess` method. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. 
Can be overridden by `rescale_factor` in the `preprocess` + method. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image. Can be overridden by `do_normalize` in the `preprocess` method. + image_mean (`float` or `List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`): + Mean to use if normalizing the image. This is a float or list of floats the length of the number of + channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. + image_std (`float` or `List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`): + Standard deviation to use if normalizing the image. This is a float or list of floats the length of the + number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. + Can be overridden by the `image_std` parameter in the `preprocess` method. + do_convert_rgb (`bool`, *optional*, defaults to `True`): + Whether to convert the image to RGB. + """ + + # To be checked against the slow image processor + # None values left after checking can be removed + resample = PILImageResampling.BICUBIC + image_mean = OPENAI_CLIP_MEAN + image_std = OPENAI_CLIP_STD + size = {"shortest_edge": 224} + default_to_square = False + crop_size = {"height": 224, "width": 224} + do_resize = True + do_center_crop = True + do_rescale = True + do_normalize = True + do_convert_rgb = True + + +__all__ = ["CLIPImageProcessorFast"] diff --git a/src/transformers/models/convnext/image_processing_convnext_fast.py b/src/transformers/models/convnext/image_processing_convnext_fast.py new file mode 100644 index 00000000000000..0db037ce8377ce --- /dev/null +++ b/src/transformers/models/convnext/image_processing_convnext_fast.py @@ -0,0 +1,321 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fast Image processor class for ConvNeXT.""" + +from typing import Dict, List, Optional, Union + +from ...image_processing_utils import BatchFeature, get_size_dict +from ...image_processing_utils_fast import BaseImageProcessorFast, group_images_by_shape, reorder_images +from ...image_transforms import get_resize_output_image_size +from ...image_utils import ( + IMAGENET_STANDARD_MEAN, + IMAGENET_STANDARD_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + validate_kwargs, +) +from ...utils import ( + TensorType, + is_torch_available, + is_torchvision_available, + is_torchvision_v2_available, +) + + +if is_torch_available(): + import torch + +if is_torchvision_available(): + if is_torchvision_v2_available(): + from torchvision.transforms.v2 import functional as F + else: + from torchvision.transforms import functional as F + + +class ConvNextImageProcessorFast(BaseImageProcessorFast): + r""" + Constructs a fast ConvNeXT image processor. 
+ + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Controls whether to resize the image's (height, width) dimensions to the specified `size`. Can be overriden + by `do_resize` in the `preprocess` method. + size (`Dict[str, int]` *optional*, defaults to `{"shortest_edge": 384}`): + Resolution of the output image after `resize` is applied. If `size["shortest_edge"]` >= 384, the image is + resized to `(size["shortest_edge"], size["shortest_edge"])`. Otherwise, the smaller edge of the image will + be matched to `int(size["shortest_edge"]/crop_pct)`, after which the image is cropped to + `(size["shortest_edge"], size["shortest_edge"])`. Only has an effect if `do_resize` is set to `True`. Can + be overriden by `size` in the `preprocess` method. + crop_pct (`float` *optional*, defaults to 224 / 256): + Percentage of the image to crop. Only has an effect if `do_resize` is `True` and size < 384. Can be + overriden by `crop_pct` in the `preprocess` method. + resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`): + Resampling filter to use if resizing the image. Can be overriden by `resample` in the `preprocess` method. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. Can be overriden by `do_rescale` in + the `preprocess` method. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. Can be overriden by `rescale_factor` in the `preprocess` + method. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess` + method. + image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`): + Mean to use if normalizing the image. This is a float or list of floats the length of the number of + channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. + image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`): + Standard deviation to use if normalizing the image. This is a float or list of floats the length of the + number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. + """ + + resample = PILImageResampling.BILINEAR + image_mean = IMAGENET_STANDARD_MEAN + image_std = IMAGENET_STANDARD_STD + size = {"shortest_edge": 384} + default_to_square = False + do_resize = True + do_rescale = True + do_normalize = True + + def __init__( + self, + do_resize: bool = None, + size: Dict[str, int] = None, + crop_pct: float = None, + resample: PILImageResampling = None, + do_rescale: bool = True, + rescale_factor: Union[int, float] = 1 / 255, + do_normalize: bool = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + **kwargs, + ) -> None: + super().__init__( + do_resize=do_resize, + size=size, + resample=resample, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + **kwargs, + ) + self.crop_pct = crop_pct if crop_pct is not None else 224 / 256 + + def resize( + self, + image: "torch.Tensor", + size: Dict[str, int], + crop_pct: float, + interpolation: PILImageResampling = PILImageResampling.BICUBIC, + **kwargs, + ) -> "torch.Tensor": + """ + Resize an image. + + Args: + image (`torch.Tensor`): + Image to resize. 
+ size (`Dict[str, int]`): + Dictionary of the form `{"shortest_edge": int}`, specifying the size of the output image. If + `size["shortest_edge"]` >= 384 image is resized to `(size["shortest_edge"], size["shortest_edge"])`. + Otherwise, the smaller edge of the image will be matched to `int(size["shortest_edge"] / crop_pct)`, + after which the image is cropped to `(size["shortest_edge"], size["shortest_edge"])`. + crop_pct (`float`): + Percentage of the image to crop. Only has an effect if size < 384. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`): + Resampling filter to use when resizing the image. + + Returns: + `torch.Tensor`: Resized image. + """ + if not size.shortest_edge: + raise ValueError(f"Size dictionary must contain 'shortest_edge' key. Got {size.keys()}") + shortest_edge = size["shortest_edge"] + + if shortest_edge < 384: + # maintain same ratio, resizing shortest edge to shortest_edge/crop_pct + resize_shortest_edge = int(shortest_edge / crop_pct) + resize_size = get_resize_output_image_size( + image, size=resize_shortest_edge, default_to_square=False, input_data_format=ChannelDimension.FIRST + ) + image = F.resize( + image, + resize_size, + interpolation=interpolation, + **kwargs, + ) + # then crop to (shortest_edge, shortest_edge) + return F.center_crop( + image, + (shortest_edge, shortest_edge), + **kwargs, + ) + else: + # warping (no cropping) when evaluated at 384 or larger + return F.resize( + image, + (shortest_edge, shortest_edge), + interpolation=interpolation, + **kwargs, + ) + + def preprocess( + self, + images: ImageInput, + do_resize: bool = None, + size: Dict[str, int] = None, + crop_pct: float = None, + resample: Optional[Union["PILImageResampling", "F.InterpolationMode"]] = None, + do_center_crop: bool = None, + crop_size: int = None, + do_rescale: bool = None, + rescale_factor: float = None, + do_normalize: bool = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_convert_rgb: bool = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + device: Optional["torch.device"] = None, + **kwargs, + ) -> BatchFeature: + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`Dict[str, int]`, *optional*, defaults to `self.size`): + Size of the output image after `resize` has been applied. If `size["shortest_edge"]` >= 384, the image + is resized to `(size["shortest_edge"], size["shortest_edge"])`. Otherwise, the smaller edge of the + image will be matched to `int(size["shortest_edge"]/ crop_pct)`, after which the image is cropped to + `(size["shortest_edge"], size["shortest_edge"])`. Only has an effect if `do_resize` is set to `True`. + crop_pct (`float`, *optional*, defaults to `self.crop_pct`): + Percentage of the image to crop if size < 384. + resample (`PILImageResampling` or `InterpolationMode`, *optional*, defaults to self.resample): + Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only + has an effect if `do_resize` is set to `True`. 
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): + Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`. + image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to + `True`. + do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): + Whether to convert the image to RGB. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Default to `"pt"` for PyTorch tensors if unset. + Fast image processors only support PyTorch tensors. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + device (`torch.device`, *optional*): + The device to process the images on. If unset, the device is inferred from the input images. 
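To make the `crop_pct` behaviour concrete, a tiny worked example with the ConvNeXT defaults (illustrative only): with `crop_pct = 224 / 256` and `size = {"shortest_edge": 224}`, the shortest edge is first resized to `224 / (224 / 256) = 256` and the result is center-cropped to 224x224; at `shortest_edge >= 384` the image is instead warped straight to `(384, 384)` with no crop.

```python
crop_pct = 224 / 256
shortest_edge = 224                                   # < 384, so resize-then-crop applies
resize_shortest_edge = int(shortest_edge / crop_pct)  # intermediate shortest edge before the crop
assert resize_shortest_edge == 256                    # then center crop back down to 224 x 224
```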
+ """ + validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self.valid_extra_kwargs) + + do_resize = do_resize if do_resize is not None else self.do_resize + size = size if size is not None else self.size + default_to_square = kwargs.pop( + "default_to_square", self.default_to_square if self.default_to_square is not None else True + ) + size = get_size_dict(size=size, default_to_square=default_to_square) if size is not None else None + crop_pct = crop_pct if crop_pct is not None else self.crop_pct + resample = resample if resample is not None else self.resample + do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop + crop_size = crop_size if crop_size is not None else self.crop_size + crop_size = get_size_dict(crop_size, param_name="crop_size") if crop_size is not None else None + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb + return_tensors = "pt" if return_tensors is None else return_tensors + + images = self._prepare_input_images( + images=images, + do_convert_rgb=do_convert_rgb, + device=device, + input_data_format=input_data_format, + ) + + image_mean, image_std, size, crop_size, interpolation = self._prepare_process_arguments( + do_resize=do_resize, + size=size, + resample=resample, + crop_size=crop_size, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + return_tensors=return_tensors, + data_format=data_format, + device=images[0].device, + **kwargs, + ) + + # Group images by size for batched resizing + grouped_images, grouped_images_index = group_images_by_shape(images) + resized_images_grouped = {} + for shape, stacked_images in grouped_images.items(): + if do_resize: + stacked_images = self.resize( + image=stacked_images, size=size, crop_pct=crop_pct, interpolation=interpolation + ) + resized_images_grouped[shape] = stacked_images + resized_images = reorder_images(resized_images_grouped, grouped_images_index) + + # Group images by size for further processing + # Needed in case do_resize is False, or resize returns images with different sizes + grouped_images, grouped_images_index = group_images_by_shape(resized_images) + processed_images_grouped = {} + for shape, stacked_images in grouped_images.items(): + if do_center_crop: + stacked_images = self.center_crop(stacked_images, crop_size) + # Fused rescale and normalize + stacked_images = self.rescale_and_normalize( + stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std + ) + processed_images_grouped[shape] = stacked_images + + processed_images = reorder_images(processed_images_grouped, grouped_images_index) + processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images + + return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors) + + +__all__ = ["ConvNextImageProcessorFast"] diff --git a/src/transformers/models/deformable_detr/image_processing_deformable_detr_fast.py b/src/transformers/models/deformable_detr/image_processing_deformable_detr_fast.py index 0a2fbc14ee94f8..4b64d84d191880 
100644 --- a/src/transformers/models/deformable_detr/image_processing_deformable_detr_fast.py +++ b/src/transformers/models/deformable_detr/image_processing_deformable_detr_fast.py @@ -381,27 +381,6 @@ def __init__( self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD self.do_pad = do_pad self.pad_size = pad_size - self._valid_processor_keys = [ - "images", - "annotations", - "return_segmentation_masks", - "masks_path", - "do_resize", - "size", - "resample", - "do_rescale", - "rescale_factor", - "do_normalize", - "do_convert_annotations", - "image_mean", - "image_std", - "do_pad", - "pad_size", - "format", - "return_tensors", - "data_format", - "input_data_format", - ] @classmethod # Copied from transformers.models.detr.image_processing_detr_fast.DetrImageProcessorFast.from_dict with Detr->DeformableDetr @@ -695,6 +674,7 @@ def preprocess( data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST, input_data_format: Optional[Union[str, ChannelDimension]] = None, pad_size: Optional[Dict[str, int]] = None, + device: Optional["torch.device"] = None, **kwargs, ) -> BatchFeature: """ @@ -770,6 +750,8 @@ def preprocess( The size `{"height": int, "width" int}` to pad the images to. Must be larger than any image size provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest height and width in the batch. + device (`torch.device`, *optional*): + The device to process the images on. If unset, the device is inferred from the input images. """ if "pad_and_return_pixel_mask" in kwargs: logger.warning_once( @@ -799,7 +781,6 @@ def preprocess( do_pad = self.do_pad if do_pad is None else do_pad pad_size = self.pad_size if pad_size is None else pad_size format = self.format if format is None else format - device = kwargs.pop("device", None) # Make hashable for cache size = SizeDict(**size) @@ -811,7 +792,7 @@ def preprocess( if image_type not in [ImageType.PIL, ImageType.TORCH, ImageType.NUMPY]: raise ValueError(f"Unsupported input image type {image_type}") - validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys) + validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self.valid_extra_kwargs) self._validate_input_arguments( do_rescale=do_rescale, diff --git a/src/transformers/models/deit/image_processing_deit_fast.py b/src/transformers/models/deit/image_processing_deit_fast.py new file mode 100644 index 00000000000000..2e541f31eb7cfb --- /dev/null +++ b/src/transformers/models/deit/image_processing_deit_fast.py @@ -0,0 +1,70 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fast Image processor class for DeiT.""" + +from ...image_processing_utils_fast import BaseImageProcessorFast +from ...image_utils import IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD, PILImageResampling + + +class DeiTImageProcessorFast(BaseImageProcessorFast): + r""" + Constructs a fast DeiT image processor. 
+ + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by + `do_resize` in `preprocess`. + size (`Dict[str, int]` *optional*, defaults to `{"height": 256, "width": 256}`): + Size of the image after `resize`. Can be overridden by `size` in `preprocess`. + resample (`PILImageResampling` filter, *optional*, defaults to `Resampling.BICUBIC`): + Resampling filter to use if resizing the image. Can be overridden by `resample` in `preprocess`. + do_center_crop (`bool`, *optional*, defaults to `True`): + Whether to center crop the image. If the input size is smaller than `crop_size` along any edge, the image + is padded with 0's and then center cropped. Can be overridden by `do_center_crop` in `preprocess`. + crop_size (`Dict[str, int]`, *optional*, defaults to `{"height": 224, "width": 224}`): + Desired output size when applying center-cropping. Can be overridden by `crop_size` in `preprocess`. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale` + parameter in the `preprocess` method. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the + `preprocess` method. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess` + method. + image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`): + Mean to use if normalizing the image. This is a float or list of floats the length of the number of + channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. + image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`): + Standard deviation to use if normalizing the image. This is a float or list of floats the length of the + number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. + do_convert_rgb (`bool`, *optional*): + Whether to convert the image to RGB. 
+ """ + + # To be checked against the slow image processor + # None values left after checking can be removed + resample = PILImageResampling.BICUBIC + image_mean = IMAGENET_STANDARD_MEAN + image_std = IMAGENET_STANDARD_STD + size = {"height": 256, "width": 256} + crop_size = {"height": 224, "width": 224} + do_resize = True + do_center_crop = True + do_rescale = True + do_normalize = True + + +__all__ = ["DeiTImageProcessorFast"] diff --git a/src/transformers/models/detr/image_processing_detr_fast.py b/src/transformers/models/detr/image_processing_detr_fast.py index f010ffe272294f..4119d97199da61 100644 --- a/src/transformers/models/detr/image_processing_detr_fast.py +++ b/src/transformers/models/detr/image_processing_detr_fast.py @@ -390,27 +390,6 @@ def __init__( self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD self.do_pad = do_pad self.pad_size = pad_size - self._valid_processor_keys = [ - "images", - "annotations", - "return_segmentation_masks", - "masks_path", - "do_resize", - "size", - "resample", - "do_rescale", - "rescale_factor", - "do_normalize", - "do_convert_annotations", - "image_mean", - "image_std", - "do_pad", - "pad_size", - "format", - "return_tensors", - "data_format", - "input_data_format", - ] @classmethod def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs): @@ -695,6 +674,7 @@ def preprocess( data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST, input_data_format: Optional[Union[str, ChannelDimension]] = None, pad_size: Optional[Dict[str, int]] = None, + device: Optional["torch.device"] = None, **kwargs, ) -> BatchFeature: """ @@ -770,6 +750,8 @@ def preprocess( The size `{"height": int, "width" int}` to pad the images to. Must be larger than any image size provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest height and width in the batch. + device (`torch.device`, *optional*): + The device to process the images on. If unset, the device is inferred from the input images. """ if "pad_and_return_pixel_mask" in kwargs: logger.warning_once( @@ -799,7 +781,6 @@ def preprocess( do_pad = self.do_pad if do_pad is None else do_pad pad_size = self.pad_size if pad_size is None else pad_size format = self.format if format is None else format - device = kwargs.pop("device", None) # Make hashable for cache size = SizeDict(**size) @@ -811,7 +792,7 @@ def preprocess( if image_type not in [ImageType.PIL, ImageType.TORCH, ImageType.NUMPY]: raise ValueError(f"Unsupported input image type {image_type}") - validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys) + validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self.valid_extra_kwargs) self._validate_input_arguments( do_rescale=do_rescale, diff --git a/src/transformers/models/llava_next/image_processing_llava_next.py b/src/transformers/models/llava_next/image_processing_llava_next.py index e43d95b80a0ada..ce52acbeaf36ae 100644 --- a/src/transformers/models/llava_next/image_processing_llava_next.py +++ b/src/transformers/models/llava_next/image_processing_llava_next.py @@ -37,7 +37,7 @@ get_image_size, infer_channel_dimension_format, is_scaled_image, - is_valid_image, + make_batched_images, make_list_of_images, to_numpy_array, valid_images, @@ -53,29 +53,6 @@ from PIL import Image -def make_batched_images(images) -> List[List[ImageInput]]: - """ - Accepts images in list or nested list format, and makes a list of images for preprocessing. 
- - Args: - images (`Union[List[List[ImageInput]], List[ImageInput], ImageInput]`): - The input image. - - Returns: - list: A list of images. - """ - if isinstance(images, (list, tuple)) and isinstance(images[0], (list, tuple)) and is_valid_image(images[0][0]): - return [img for img_list in images for img in img_list] - - elif isinstance(images, (list, tuple)) and is_valid_image(images[0]): - return images - - elif is_valid_image(images): - return [images] - - raise ValueError(f"Could not make batched video from {images}") - - def divide_to_patches(image: np.array, patch_size: int, input_data_format) -> List[np.array]: """ Divides an image into patches of a specified size. diff --git a/src/transformers/models/llava_next/image_processing_llava_next_fast.py b/src/transformers/models/llava_next/image_processing_llava_next_fast.py new file mode 100644 index 00000000000000..f8521dad1f695b --- /dev/null +++ b/src/transformers/models/llava_next/image_processing_llava_next_fast.py @@ -0,0 +1,448 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fast Image processor class for LLaVa-NeXT.""" + +from typing import Dict, List, Optional, Union + +from ...image_processing_utils import BatchFeature, get_patch_output_size, get_size_dict, select_best_resolution +from ...image_processing_utils_fast import ( + BaseImageProcessorFast, + divide_to_patches, + group_images_by_shape, + reorder_images, +) +from ...image_utils import ( + OPENAI_CLIP_MEAN, + OPENAI_CLIP_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + get_image_size, + make_batched_images, + validate_kwargs, +) +from ...utils import ( + TensorType, + is_torch_available, + is_torchvision_available, + is_torchvision_v2_available, +) + + +if is_torch_available(): + import torch + +if is_torchvision_available(): + if is_torchvision_v2_available(): + from torchvision.transforms.v2 import functional as F + else: + from torchvision.transforms import functional as F + + +class LlavaNextImageProcessorFast(BaseImageProcessorFast): + r""" + Constructs a fast LLaVa-NeXT image processor. Based on [`CLIPImageProcessor`] with incorporation of additional techniques + for processing high resolution images as explained in the [LLaVa paper](https://arxiv.org/abs/2310.03744). + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by + `do_resize` in the `preprocess` method. + size (`Dict[str, int]` *optional*, defaults to `{"shortest_edge": 224}`): + Size of the image after resizing. The shortest edge of the image is resized to size["shortest_edge"], with + the longest edge resized to keep the input aspect ratio. Can be overridden by `size` in the `preprocess` + method. 
+        image_grid_pinpoints (`List`, *optional*, defaults to `[[672, 336], [336, 672], [672, 672], [336, 1008], [1008, 336]]`):
+            A list of possible resolutions to use for processing high resolution images. The best resolution is selected
+            based on the original size of the image. Can be overridden by `image_grid_pinpoints` in the `preprocess`
+            method.
+        resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
+            Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
+        do_center_crop (`bool`, *optional*, defaults to `True`):
+            Whether to center crop the image to the specified `crop_size`. Can be overridden by `do_center_crop` in the
+            `preprocess` method.
+        crop_size (`Dict[str, int]`, *optional*, defaults to 224):
+            Size of the output image after applying `center_crop`. Can be overridden by `crop_size` in the `preprocess`
+            method.
+        do_rescale (`bool`, *optional*, defaults to `True`):
+            Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
+            the `preprocess` method.
+        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+            Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
+            method.
+        do_normalize (`bool`, *optional*, defaults to `True`):
+            Whether to normalize the image. Can be overridden by `do_normalize` in the `preprocess` method.
+        image_mean (`float` or `List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`):
+            Mean to use if normalizing the image. This is a float or list of floats the length of the number of
+            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+        image_std (`float` or `List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`):
+            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
+            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
+        do_pad (`bool`, *optional*, defaults to `True`):
+            Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest
+            number of patches in the batch. Padding will be applied to the bottom and right with zeros.
+        do_convert_rgb (`bool`, *optional*, defaults to `True`):
+            Whether to convert the image to RGB.
+ """ + + # To be checked against the slow image processor + # None values left after checking can be removed + resample = PILImageResampling.BICUBIC + image_mean = OPENAI_CLIP_MEAN + image_std = OPENAI_CLIP_STD + size = {"shortest_edge": 224} + default_to_square = False + crop_size = {"height": 224, "width": 224} + do_resize = True + do_center_crop = True + do_rescale = True + do_normalize = True + do_convert_rgb = True + do_pad = True + image_grid_pinpoints = [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]] + + def __init__( + self, + do_resize: bool = None, + size: Dict[str, int] = None, + image_grid_pinpoints: List = None, + resample: Optional[Union["PILImageResampling", "F.InterpolationMode"]] = None, + do_center_crop: bool = None, + crop_size: Dict[str, int] = None, + do_rescale: bool = None, + rescale_factor: Union[int, float] = 1 / 255, + do_normalize: bool = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_pad: Optional[bool] = None, + do_convert_rgb: bool = None, + **kwargs, + ) -> None: + super().__init__( + do_resize=do_resize, + size=size, + resample=resample, + do_center_crop=do_center_crop, + crop_size=crop_size, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_convert_rgb=do_convert_rgb, + **kwargs, + ) + self.image_grid_pinpoints = ( + image_grid_pinpoints if image_grid_pinpoints is not None else self.image_grid_pinpoints + ) + self.do_pad = do_pad if do_pad is not None else self.do_pad + + def _prepare_images_structure( + self, + images: ImageInput, + ) -> ImageInput: + """ + Prepare the images structure for processing. + + Args: + images (`ImageInput`): + The input images to process. + + Returns: + `ImageInput`: The images with a valid nesting. + """ + return make_batched_images(images) + + def _resize_for_patching( + self, + image: "torch.Tensor", + target_resolution: tuple, + interpolation: "F.InterpolationMode", + input_data_format: ChannelDimension, + ) -> "torch.Tensor": + """ + Resizes an image to a target resolution while maintaining aspect ratio. + + Args: + image ("torch.Tensor"): + The input image. + target_resolution (tuple): + The target resolution (height, width) of the image. + interpolation (`InterpolationMode`): + Resampling filter to use if resizing the image. + input_data_format (`ChannelDimension` or `str`): + The channel dimension format of the input image. + + Returns: + "torch.Tensor": The resized and padded image. + """ + new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format) + + # Resize the image + resized_image = F.resize(image, (new_height, new_width), interpolation=interpolation) + + return resized_image + + def _pad_for_patching( + self, image: "torch.Tensor", target_resolution: tuple, input_data_format: ChannelDimension + ) -> "torch.Tensor": + """ + Pad an image to a target resolution while maintaining aspect ratio. 
+ """ + target_height, target_width = target_resolution + new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format) + + paste_x = (target_width - new_width) // 2 + paste_y = (target_height - new_height) // 2 + + padded_image = F.pad(image, padding=[paste_x, paste_y, paste_x, paste_y]) + + return padded_image + + def _get_image_patches( + self, + image: "torch.Tensor", + grid_pinpoints, + size: tuple, + patch_size: int, + interpolation: "F.InterpolationMode", + ) -> List["torch.Tensor"]: + """ + Process an image with variable resolutions by dividing it into patches. + + Args: + image ("torch.Tensor"): + The input image to be processed. + grid_pinpoints (List): + A string representation of a list of possible resolutions. + size (`tuple`): + Size to resize the original image to. + patch_size (`int`): + Size of the patches to divide the image into. + interpolation (`"InterpolationMode"`): + Resampling filter to use if resizing the image. + + Returns: + List["torch.Tensor"]: A list of NumPy arrays containing the processed image patches. + """ + if not isinstance(grid_pinpoints, list): + raise TypeError("grid_pinpoints must be a list of possible resolutions.") + + possible_resolutions = grid_pinpoints + + image_size = get_image_size(image, channel_dim=ChannelDimension.FIRST) + best_resolution = select_best_resolution(image_size, possible_resolutions) + resized_image = self._resize_for_patching( + image, best_resolution, interpolation=interpolation, input_data_format=ChannelDimension.FIRST + ) + padded_image = self._pad_for_patching(resized_image, best_resolution, input_data_format=ChannelDimension.FIRST) + patches = divide_to_patches(padded_image, patch_size=patch_size) + resized_original_image = F.resize(image, size=size, interpolation=interpolation) + + image_patches = [resized_original_image] + patches + + return image_patches + + def _pad_for_batching( + self, + pixel_values: List["torch.Tensor"], + ) -> List["torch.Tensor"]: + """ + Pads images on the `num_of_patches` dimension with zeros to form a batch of same number of patches. + + Args: + pixel_values (`List[torch.Tensor]`): + An array of pixel values of each images of shape (`batch_size`, `num_patches`, `image_in_3D`) + + Returns: + List[`torch.Tensor`]: The padded images. + """ + max_patch = max(len(x) for x in pixel_values) + pixel_values = [ + torch.nn.functional.pad(image, pad=[0, 0, 0, 0, 0, 0, 0, max_patch - image.shape[0]]) + for image in pixel_values + ] + + return pixel_values + + def preprocess( + self, + images: ImageInput, + do_resize: bool = None, + size: Dict[str, int] = None, + image_grid_pinpoints: List = None, + resample: Optional[Union["PILImageResampling", "F.InterpolationMode"]] = None, + do_center_crop: bool = None, + crop_size: int = None, + do_rescale: bool = None, + rescale_factor: float = None, + do_normalize: bool = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_pad: Optional[bool] = None, + do_convert_rgb: bool = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + device: Optional["torch.device"] = None, + **kwargs, + ) -> BatchFeature: + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. 
If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`Dict[str, int]`, *optional*, defaults to `self.size`): + Describes the maximum input dimensions to the model. + resample (`PILImageResampling` or `InterpolationMode`, *optional*, defaults to self.resample): + Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only + has an effect if `do_resize` is set to `True`. + do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`): + Whether to center crop the image. + crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`): + Size of the output image after applying `center_crop`. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): + Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`. + image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to + `True`. + do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): + Whether to convert the image to RGB. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Default to `"pt"` for PyTorch tensors if unset. + Fast image processors only support PyTorch tensors. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + device (`torch.device`, *optional*): + The device to process the images on. If unset, the device is inferred from the input images. 
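Because `device` is now an explicit argument here (and in the DETR-family and Pixtral fast processors touched above), preprocessing can be pushed to the GPU directly. A hedged usage sketch, assuming a device string is accepted wherever a `torch.device` is; the checkpoint name is illustrative:

```python
import torch
from PIL import Image

from transformers import LlavaNextImageProcessorFast

processor = LlavaNextImageProcessorFast.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
image = Image.new("RGB", (800, 600))

device = "cuda" if torch.cuda.is_available() else "cpu"
inputs = processor(images=image, device=device, return_tensors="pt")
print(inputs.pixel_values.device)  # resize/crop/normalize all ran on `device`
```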
+ """ + validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self.valid_extra_kwargs) + + do_resize = do_resize if do_resize is not None else self.do_resize + size = size if size is not None else self.size + default_to_square = kwargs.pop( + "default_to_square", self.default_to_square if self.default_to_square is not None else True + ) + size = get_size_dict(size=size, default_to_square=default_to_square) if size is not None else None + image_grid_pinpoints = image_grid_pinpoints if image_grid_pinpoints is not None else self.image_grid_pinpoints + resample = resample if resample is not None else self.resample + do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop + crop_size = crop_size if crop_size is not None else self.crop_size + crop_size = get_size_dict(crop_size, param_name="crop_size") if crop_size is not None else None + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + do_pad = do_pad if do_pad is not None else self.do_pad + do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb + + images = self._prepare_input_images( + images=images, + do_convert_rgb=do_convert_rgb, + input_data_format=input_data_format, + device=device, + ) + + image_mean, image_std, size, crop_size, interpolation = self._prepare_process_arguments( + do_resize=do_resize, + size=size, + resample=resample, + crop_size=crop_size, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + return_tensors=return_tensors, + data_format=data_format, + device=images[0].device, + **kwargs, + ) + + processed_images = [] + image_sizes = [] + for image in images: + size_tuple = ( + (size.height, size.width) if size.height and size.width else (size.shortest_edge, size.shortest_edge) + ) + patch_size = ( + crop_size.height + if crop_size is not None and crop_size.height + else size.height + if size.height + else size.shortest_edge + ) + image_patches = self._get_image_patches( + image, + image_grid_pinpoints, + size=size_tuple, + patch_size=patch_size, + interpolation=interpolation, + ) + + # Group images by size for batched processing + processed_image_patches_grouped = {} + grouped_image_patches, grouped_image_patches_index = group_images_by_shape(image_patches) + for shape, stacked_image_patches in grouped_image_patches.items(): + if do_resize: + stacked_image_patches = self.resize( + image=stacked_image_patches, + size=size, + interpolation=interpolation, + ) + if do_center_crop: + stacked_image_patches = self.center_crop(stacked_image_patches, crop_size) + # Fused rescale and normalize + stacked_image_patches = self.rescale_and_normalize( + stacked_image_patches, do_rescale, rescale_factor, do_normalize, image_mean, image_std + ) + processed_image_patches_grouped[shape] = stacked_image_patches + processed_image_patches = reorder_images(processed_image_patches_grouped, grouped_image_patches_index) + processed_image_patches = ( + torch.stack(processed_image_patches, dim=0) if return_tensors else processed_image_patches + ) + processed_images.append(processed_image_patches) + image_sizes.append(get_image_size(image, input_data_format)) + + 
if do_pad: + processed_images = self._pad_for_batching(processed_images) + processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images + return BatchFeature( + data={"pixel_values": processed_images, "image_sizes": image_sizes}, tensor_type=return_tensors + ) + + +__all__ = ["LlavaNextImageProcessorFast"] diff --git a/src/transformers/models/llava_onevision/image_processing_llava_onevision.py b/src/transformers/models/llava_onevision/image_processing_llava_onevision.py index cde9a643e3efc4..bb62eeec524d83 100644 --- a/src/transformers/models/llava_onevision/image_processing_llava_onevision.py +++ b/src/transformers/models/llava_onevision/image_processing_llava_onevision.py @@ -36,7 +36,7 @@ get_image_size, infer_channel_dimension_format, is_scaled_image, - is_valid_image, + make_batched_images, to_numpy_array, valid_images, validate_preprocess_arguments, @@ -51,30 +51,6 @@ from PIL import Image -# Copied from transformers.models.llava_next.image_processing_llava_next.make_batched_images -def make_batched_images(images) -> List[List[ImageInput]]: - """ - Accepts images in list or nested list format, and makes a list of images for preprocessing. - - Args: - images (`Union[List[List[ImageInput]], List[ImageInput], ImageInput]`): - The input image. - - Returns: - list: A list of images. - """ - if isinstance(images, (list, tuple)) and isinstance(images[0], (list, tuple)) and is_valid_image(images[0][0]): - return [img for img_list in images for img in img_list] - - elif isinstance(images, (list, tuple)) and is_valid_image(images[0]): - return images - - elif is_valid_image(images): - return [images] - - raise ValueError(f"Could not make batched video from {images}") - - # Copied from transformers.models.llava_next.image_processing_llava_next.divide_to_patches def divide_to_patches(image: np.array, patch_size: int, input_data_format) -> List[np.array]: """ @@ -143,7 +119,7 @@ def _get_patch_output_size(image, target_resolution, input_data_format): class LlavaOnevisionImageProcessor(BaseImageProcessor): r""" - Constructs a LLaVa-Onevisino-Video video processor. Based on [`SiglipImageProcessor`] with incorporation of processing each video frame. + Constructs a LLaVa-Onevision image processor. Based on [`SiglipImageProcessor`] with incorporation of processing each video frame. Args: do_resize (`bool`, *optional*, defaults to `True`): diff --git a/src/transformers/models/llava_onevision/image_processing_llava_onevision_fast.py b/src/transformers/models/llava_onevision/image_processing_llava_onevision_fast.py new file mode 100644 index 00000000000000..7e62a4a0ab0f36 --- /dev/null +++ b/src/transformers/models/llava_onevision/image_processing_llava_onevision_fast.py @@ -0,0 +1,539 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
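Before the LLaVa-Onevision counterpart below: with `do_pad=True` (the default), the LLaVa-NeXT fast processor pads every image's patch dimension up to the largest patch count in the batch, so mixed resolutions still stack into a single tensor. A hedged sketch of the resulting output structure (checkpoint name illustrative):

```python
from PIL import Image

from transformers import LlavaNextImageProcessorFast

processor = LlavaNextImageProcessorFast.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
images = [Image.new("RGB", (1000, 600)), Image.new("RGB", (300, 300))]

out = processor(images=images, return_tensors="pt")
print(out.pixel_values.shape)  # (2, max_num_patches, 3, 336, 336) for this checkpoint config
print(out.image_sizes)         # the original (height, width) of each input
```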
+"""Fast Image processor class for LLaVa-Onevision.""" + +import functools +from typing import Dict, List, Optional, Union + +from ...image_processing_utils import BatchFeature, get_patch_output_size, get_size_dict, select_best_resolution +from ...image_processing_utils_fast import ( + BaseImageProcessorFast, + SizeDict, + divide_to_patches, +) +from ...image_transforms import CenterCrop, GroupByShape, Normalize, ReorderImages, Rescale, Resize +from ...image_utils import ( + OPENAI_CLIP_MEAN, + OPENAI_CLIP_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + get_image_size, + make_batched_images, + validate_kwargs, +) +from ...utils import ( + TensorType, + is_torch_available, + is_torchvision_available, + is_torchvision_v2_available, +) + + +if is_torch_available(): + import torch + from torch.nn import Sequential + +if is_torchvision_available(): + if is_torchvision_v2_available(): + from torchvision.transforms.v2 import functional as F + else: + from torchvision.transforms import functional as F + + +class LlavaOnevisionImageProcessorFast(BaseImageProcessorFast): + r""" + Constructs a fast LLaVa-Onevision image processor. Based on [`SiglipImageProcessor`] with incorporation of processing each video frame. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by + `do_resize` in the `preprocess` method. + size (`Dict[str, int]` *optional*, defaults to `{"shortest_edge": 224}`): + Size of the image after resizing. The shortest edge of the image is resized to size["shortest_edge"], with + the longest edge resized to keep the input aspect ratio. Can be overridden by `size` in the `preprocess` + method. + image_grid_pinpoints (`List` *optional*, defaults to `[[672, 336], [336, 672], [672, 672], [336, 1008], [1008, 336]]`): + A list of possible resolutions to use for processing high resolution images. The best resolution is selected + based on the original size of the image. Can be overridden by `image_grid_pinpoints` in the `preprocess` + method. Not used for processinf videos. + resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`): + Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method. + do_center_crop (`bool`, *optional*, defaults to `True`): + Whether to center crop the image to the specified `crop_size`. Can be overridden by `do_center_crop` in the + `preprocess` method. + crop_size (`Dict[str, int]` *optional*, defaults to 224): + Size of the output image after applying `center_crop`. Can be overridden by `crop_size` in the `preprocess` + method. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in + the `preprocess` method. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess` + method. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image. Can be overridden by `do_normalize` in the `preprocess` method. + image_mean (`float` or `List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`): + Mean to use if normalizing the image. This is a float or list of floats the length of the number of + channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. 
+ image_std (`float` or `List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`): + Standard deviation to use if normalizing the image. This is a float or list of floats the length of the + number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. + Can be overridden by the `image_std` parameter in the `preprocess` method. + do_pad (`bool`, *optional*, defaults to `True`): + Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest + number of patches in the batch. Padding will be applied to the bottom and right with zeros. + do_convert_rgb (`bool`, *optional*, defaults to `True`): + Whether to convert the image to RGB. + """ + + # To be checked against the slow image processor + # None values left after checking can be removed + resample = PILImageResampling.BICUBIC + image_mean = OPENAI_CLIP_MEAN + image_std = OPENAI_CLIP_STD + size = {"height": 384, "width": 384} + default_to_square = False + do_resize = True + do_rescale = True + do_normalize = True + do_convert_rgb = True + do_pad = True + image_grid_pinpoints = [ + [384, 384], + [384, 768], + [384, 1152], + [384, 1536], + [384, 1920], + [384, 2304], + [768, 384], + [768, 768], + [768, 1152], + [768, 1536], + [768, 1920], + [768, 2304], + [1152, 384], + [1152, 768], + [1152, 1152], + [1152, 1536], + [1152, 1920], + [1152, 2304], + [1536, 384], + [1536, 768], + [1536, 1152], + [1536, 1536], + [1536, 1920], + [1536, 2304], + [1920, 384], + [1920, 768], + [1920, 1152], + [1920, 1536], + [1920, 1920], + [1920, 2304], + [2304, 384], + [2304, 768], + [2304, 1152], + [2304, 1536], + [2304, 1920], + [2304, 2304], + ] + + # Copied from transformers.models.llava_next.image_processing_llava_next_fast.LlavaNextImageProcessorFast.__init__ + def __init__( + self, + do_resize: bool = None, + size: Dict[str, int] = None, + image_grid_pinpoints: List = None, + resample: Optional[Union["PILImageResampling", "F.InterpolationMode"]] = None, + do_center_crop: bool = None, + crop_size: Dict[str, int] = None, + do_rescale: bool = None, + rescale_factor: Union[int, float] = 1 / 255, + do_normalize: bool = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_pad: Optional[bool] = None, + do_convert_rgb: bool = None, + **kwargs, + ) -> None: + super().__init__( + do_resize=do_resize, + size=size, + resample=resample, + do_center_crop=do_center_crop, + crop_size=crop_size, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_convert_rgb=do_convert_rgb, + **kwargs, + ) + self.image_grid_pinpoints = ( + image_grid_pinpoints if image_grid_pinpoints is not None else self.image_grid_pinpoints + ) + self.do_pad = do_pad if do_pad is not None else self.do_pad + + # Copied from transformers.models.llava_next.image_processing_llava_next_fast.LlavaNextImageProcessorFast._prepare_images_structure + def _prepare_images_structure( + self, + images: ImageInput, + ) -> ImageInput: + """ + Prepare the images structure for processing. + + Args: + images (`ImageInput`): + The input images to process. + + Returns: + `ImageInput`: The images with a valid nesting. 
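The 36 grid pinpoints listed above are simply every `(height, width)` pair of multiples of 384 from 384 up to 2304; a one-line reconstruction, for illustration only:

```python
pinpoints = [[h, w] for h in range(384, 2305, 384) for w in range(384, 2305, 384)]
assert len(pinpoints) == 36  # same pairs, same row-major order as the class attribute above
```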
+ """ + return make_batched_images(images) + + # Copied from transformers.models.llava_next.image_processing_llava_next_fast.LlavaNextImageProcessorFast._resize_for_patching + def _resize_for_patching( + self, + image: "torch.Tensor", + target_resolution: tuple, + interpolation: "F.InterpolationMode", + input_data_format: ChannelDimension, + ) -> "torch.Tensor": + """ + Resizes an image to a target resolution while maintaining aspect ratio. + + Args: + image ("torch.Tensor"): + The input image. + target_resolution (tuple): + The target resolution (height, width) of the image. + interpolation (`InterpolationMode`): + Resampling filter to use if resizing the image. + input_data_format (`ChannelDimension` or `str`): + The channel dimension format of the input image. + + Returns: + "torch.Tensor": The resized and padded image. + """ + new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format) + + # Resize the image + resized_image = F.resize(image, (new_height, new_width), interpolation=interpolation) + + return resized_image + + # Copied from transformers.models.llava_next.image_processing_llava_next_fast.LlavaNextImageProcessorFast._pad_for_patching + def _pad_for_patching( + self, image: "torch.Tensor", target_resolution: tuple, input_data_format: ChannelDimension + ) -> "torch.Tensor": + """ + Pad an image to a target resolution while maintaining aspect ratio. + """ + target_height, target_width = target_resolution + new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format) + + paste_x = (target_width - new_width) // 2 + paste_y = (target_height - new_height) // 2 + + padded_image = F.pad(image, padding=[paste_x, paste_y, paste_x, paste_y]) + + return padded_image + + # Copied from transformers.models.llava_next.image_processing_llava_next_fast.LlavaNextImageProcessorFast._get_image_patches + def _get_image_patches( + self, + image: "torch.Tensor", + grid_pinpoints, + size: tuple, + patch_size: int, + interpolation: "F.InterpolationMode", + ) -> List["torch.Tensor"]: + """ + Process an image with variable resolutions by dividing it into patches. + + Args: + image ("torch.Tensor"): + The input image to be processed. + grid_pinpoints (List): + A string representation of a list of possible resolutions. + size (`tuple`): + Size to resize the original image to. + patch_size (`int`): + Size of the patches to divide the image into. + interpolation (`"InterpolationMode"`): + Resampling filter to use if resizing the image. + + Returns: + List["torch.Tensor"]: A list of NumPy arrays containing the processed image patches. 
+ """ + if not isinstance(grid_pinpoints, list): + raise TypeError("grid_pinpoints must be a list of possible resolutions.") + + possible_resolutions = grid_pinpoints + + image_size = get_image_size(image, channel_dim=ChannelDimension.FIRST) + best_resolution = select_best_resolution(image_size, possible_resolutions) + resized_image = self._resize_for_patching( + image, best_resolution, interpolation=interpolation, input_data_format=ChannelDimension.FIRST + ) + padded_image = self._pad_for_patching(resized_image, best_resolution, input_data_format=ChannelDimension.FIRST) + patches = divide_to_patches(padded_image, patch_size=patch_size) + resized_original_image = F.resize(image, size=size, interpolation=interpolation) + + image_patches = [resized_original_image] + patches + + return image_patches + + # Copied from transformers.models.llava_next.image_processing_llava_next_fast.LlavaNextImageProcessorFast._pad_for_batching + def _pad_for_batching( + self, + pixel_values: List["torch.Tensor"], + ) -> List["torch.Tensor"]: + """ + Pads images on the `num_of_patches` dimension with zeros to form a batch of same number of patches. + + Args: + pixel_values (`List[torch.Tensor]`): + An array of pixel values of each images of shape (`batch_size`, `num_patches`, `image_in_3D`) + + Returns: + List[`torch.Tensor`]: The padded images. + """ + max_patch = max(len(x) for x in pixel_values) + pixel_values = [ + torch.nn.functional.pad(image, pad=[0, 0, 0, 0, 0, 0, 0, max_patch - image.shape[0]]) + for image in pixel_values + ] + + return pixel_values + + def _build_transforms( + self, + do_resize: bool, + size: SizeDict, + interpolation: "F.InterpolationMode", + do_center_crop: bool, + crop_size: int, + do_rescale: bool, + rescale_factor: float, + do_normalize: bool, + image_mean: Union[float, List[float]], + image_std: Union[float, List[float]], + ) -> "Sequential": + """ + Given the input settings build the image transforms using a `Sequential` module. 
+ """ + transforms = [] + + transforms.append(GroupByShape()) + if do_resize: + transforms.append(Resize(size, interpolation=interpolation)) + # Since the size was changed, we need to group the images by shape again + transforms.append(ReorderImages()) + transforms.append(GroupByShape()) + if do_center_crop: + transforms.append(CenterCrop(crop_size)) + # Since the size was changed, we need to group the images by shape again + transforms.append(ReorderImages()) + transforms.append(GroupByShape()) + # We can combine rescale and normalize into a single operation for speed + if do_rescale and do_normalize: + # image_mean and image_std have already been adjusted for rescaling + transforms.append(Normalize(image_mean, image_std)) + elif do_rescale: + transforms.append(Rescale(rescale_factor=rescale_factor)) + elif do_normalize: + transforms.append(Normalize(image_mean, image_std)) + + if isinstance(transforms[-1], GroupByShape): + # No added transforms, so we can remove the last GroupByShape + transforms.pop() + else: + # We necessarily have grouped images, so we need to reorder them back to the original order + transforms.append(ReorderImages()) + + return Sequential(*transforms) + + @functools.lru_cache(maxsize=1) + def get_transforms(self, **kwargs) -> "Sequential": + return self._build_transforms(**kwargs) + + def preprocess( + self, + images: ImageInput, + do_resize: bool = None, + size: Dict[str, int] = None, + image_grid_pinpoints: List = None, + resample: Optional[Union["PILImageResampling", "F.InterpolationMode"]] = None, + do_center_crop: bool = None, + crop_size: int = None, + do_rescale: bool = None, + rescale_factor: float = None, + do_normalize: bool = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_pad: Optional[bool] = None, + do_convert_rgb: bool = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + device: Optional["torch.device"] = None, + **kwargs, + ) -> BatchFeature: + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`Dict[str, int]`, *optional*, defaults to `self.size`): + Describes the maximum input dimensions to the model. + resample (`PILImageResampling` or `InterpolationMode`, *optional*, defaults to self.resample): + Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only + has an effect if `do_resize` is set to `True`. + do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`): + Whether to center crop the image. + crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`): + Size of the output image after applying `center_crop`. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. 
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): + Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`. + image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to + `True`. + do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): + Whether to convert the image to RGB. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Default to `"pt"` for PyTorch tensors if unset. + Fast image processors only support PyTorch tensors. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + device (`torch.device`, *optional*): + The device to process the images on. If unset, the device is inferred from the input images. + """ + validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self.valid_extra_kwargs) + + do_resize = do_resize if do_resize is not None else self.do_resize + size = size if size is not None else self.size + default_to_square = kwargs.pop( + "default_to_square", self.default_to_square if self.default_to_square is not None else True + ) + size = get_size_dict(size=size, default_to_square=default_to_square) if size is not None else None + image_grid_pinpoints = image_grid_pinpoints if image_grid_pinpoints is not None else self.image_grid_pinpoints + resample = resample if resample is not None else self.resample + do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop + crop_size = crop_size if crop_size is not None else self.crop_size + crop_size = get_size_dict(crop_size, param_name="crop_size") if crop_size is not None else None + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + do_pad = do_pad if do_pad is not None else self.do_pad + do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb + + images = self._prepare_input_images( + images=images, + do_convert_rgb=do_convert_rgb, + input_data_format=input_data_format, + device=device, + ) + + image_mean, image_std, size, crop_size, interpolation = self._prepare_process_arguments( + do_resize=do_resize, + size=size, + resample=resample, + crop_size=crop_size, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + 
image_mean=image_mean, + image_std=image_std, + return_tensors=return_tensors, + data_format=data_format, + device=images[0].device, + **kwargs, + ) + + patches_transforms = self.get_transforms( + do_resize=do_resize, + size=size, + interpolation=interpolation, + do_center_crop=do_center_crop, + crop_size=crop_size, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + ) + + processed_images = [] + image_sizes = [] + for image in images: + size_tuple = ( + (size.height, size.width) if size.height and size.width else (size.shortest_edge, size.shortest_edge) + ) + patch_size = ( + crop_size.height + if crop_size is not None and crop_size.height + else size.height + if size.height + else size.shortest_edge + ) + image_patches = self._get_image_patches( + image, + image_grid_pinpoints, + size=size_tuple, + patch_size=patch_size, + interpolation=interpolation, + ) + + # apply torchvision transforms to patches + processed_image_patches = patches_transforms(image_patches) + + processed_image_patches = ( + torch.stack(processed_image_patches, dim=0) if return_tensors else processed_image_patches + ) + processed_images.append(processed_image_patches) + image_sizes.append(get_image_size(image, input_data_format)) + + if do_pad: + processed_images = self._pad_for_batching(processed_images) + processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images + return BatchFeature( + data={"pixel_values": processed_images, "image_sizes": image_sizes}, tensor_type=return_tensors + ) + + +__all__ = ["LlavaOnevisionImageProcessorFast"] diff --git a/src/transformers/models/paligemma/processing_paligemma.py b/src/transformers/models/paligemma/processing_paligemma.py index f2d0afed9461b1..93460c24444950 100644 --- a/src/transformers/models/paligemma/processing_paligemma.py +++ b/src/transformers/models/paligemma/processing_paligemma.py @@ -19,7 +19,7 @@ from typing import List, Optional, Union from ...feature_extraction_utils import BatchFeature -from ...image_utils import ImageInput, is_valid_image +from ...image_utils import ImageInput, is_valid_image, make_batched_images from ...processing_utils import ( ImagesKwargs, ProcessingKwargs, @@ -99,30 +99,6 @@ def build_string_from_input(prompt, bos_token, image_seq_len, image_token, num_i return f"{image_token * image_seq_len * num_images}{bos_token}{prompt}\n" -# Copied from transformers.models.llava_next.image_processing_llava_next.make_batched_images -def make_batched_images(images) -> List[List[ImageInput]]: - """ - Accepts images in list or nested list format, and makes a list of images for preprocessing. - - Args: - images (`Union[List[List[ImageInput]], List[ImageInput], ImageInput]`): - The input image. - - Returns: - list: A list of images. - """ - if isinstance(images, (list, tuple)) and isinstance(images[0], (list, tuple)) and is_valid_image(images[0][0]): - return [img for img_list in images for img in img_list] - - elif isinstance(images, (list, tuple)) and is_valid_image(images[0]): - return images - - elif is_valid_image(images): - return [images] - - raise ValueError(f"Could not make batched video from {images}") - - class PaliGemmaProcessor(ProcessorMixin): r""" Constructs a PaliGemma processor which wraps a PaliGemma image processor and a PaliGemma tokenizer into a single processor. 
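With `LlavaOnevisionImageProcessorFast` complete (it reuses the LLaVa-NeXT patching helpers but routes the per-patch ops through the cached `Sequential` pipeline), here is a hedged parity sketch against the slow processor; a solid-color image keeps interpolation differences out of the comparison:

```python
from PIL import Image

from transformers import LlavaOnevisionImageProcessor, LlavaOnevisionImageProcessorFast

image = Image.new("RGB", (500, 700), color=(10, 200, 30))

slow = LlavaOnevisionImageProcessor()
fast = LlavaOnevisionImageProcessorFast()  # requires torch + torchvision

out_slow = slow(images=image, return_tensors="pt")
out_fast = fast(images=image, return_tensors="pt")

print(out_slow.pixel_values.shape, out_fast.pixel_values.shape)                       # expected to match
print((out_slow.pixel_values.float() - out_fast.pixel_values.float()).abs().max())    # expected to be ~0
```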
diff --git a/src/transformers/models/pixtral/image_processing_pixtral_fast.py b/src/transformers/models/pixtral/image_processing_pixtral_fast.py index 082e255c8435b5..8b3bf5fb442d9a 100644 --- a/src/transformers/models/pixtral/image_processing_pixtral_fast.py +++ b/src/transformers/models/pixtral/image_processing_pixtral_fast.py @@ -125,22 +125,6 @@ def __init__( self.image_mean = image_mean if image_mean is not None else [0.48145466, 0.4578275, 0.40821073] self.image_std = image_std if image_std is not None else [0.26862954, 0.26130258, 0.27577711] self.do_convert_rgb = do_convert_rgb - self._valid_processor_keys = [ - "images", - "do_resize", - "size", - "patch_size", - "resample", - "do_rescale", - "rescale_factor", - "do_normalize", - "image_mean", - "image_std", - "do_convert_rgb", - "return_tensors", - "data_format", - "input_data_format", - ] def resize( self, @@ -205,6 +189,7 @@ def preprocess( return_tensors: Optional[Union[str, TensorType]] = None, data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, input_data_format: Optional[Union[str, ChannelDimension]] = None, + device: Optional["torch.device"] = None, **kwargs, ) -> BatchMixFeature: """ @@ -254,6 +239,8 @@ def preprocess( - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + device (`torch.device`, *optional*): + The device to process the images on. If unset, the device is inferred from the input images. """ patch_size = patch_size if patch_size is not None else self.patch_size patch_size = get_size_dict(patch_size, default_to_square=True) @@ -267,9 +254,8 @@ def preprocess( image_mean = image_mean if image_mean is not None else self.image_mean image_std = image_std if image_std is not None else self.image_std do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb - device = kwargs.pop("device", None) - validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys) + validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self.valid_extra_kwargs) images_list = make_list_of_images(images) image_type = get_image_type(images_list[0][0]) @@ -311,8 +297,8 @@ def preprocess( if do_rescale and do_normalize: # fused rescale and normalize - new_mean = torch.tensor(image_mean, device=images_list[0][0].device) * (1.0 / rescale_factor) - new_std = torch.tensor(image_std, device=images_list[0][0].device) * (1.0 / rescale_factor) + image_mean = torch.tensor(image_mean, device=images_list[0][0].device) * (1.0 / rescale_factor) + image_std = torch.tensor(image_std, device=images_list[0][0].device) * (1.0 / rescale_factor) batch_images = [] batch_image_sizes = [] @@ -333,13 +319,10 @@ def preprocess( interpolation=interpolation, ) - if do_rescale and do_normalize: - # fused rescale and normalize - image = F.normalize(image.to(dtype=torch.float32), new_mean, new_std) - elif do_rescale: - image = image * rescale_factor - elif do_normalize: - image = F.normalize(image, image_mean, image_std) + # Fused rescale and normalize + image = self.rescale_and_normalize( + image, do_rescale, rescale_factor, do_normalize, image_mean, image_std + ) images.append(image) image_sizes.append(get_image_size(image, input_data_format)) diff --git a/src/transformers/models/rt_detr/image_processing_rt_detr_fast.py 
b/src/transformers/models/rt_detr/image_processing_rt_detr_fast.py index 5d8d0f58328a0b..6c2954de60b260 100644 --- a/src/transformers/models/rt_detr/image_processing_rt_detr_fast.py +++ b/src/transformers/models/rt_detr/image_processing_rt_detr_fast.py @@ -494,6 +494,7 @@ def preprocess( data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST, input_data_format: Optional[Union[str, ChannelDimension]] = None, pad_size: Optional[Dict[str, int]] = None, + device: Optional["torch.device"] = None, **kwargs, ) -> BatchFeature: """ @@ -569,6 +570,8 @@ def preprocess( The size `{"height": int, "width" int}` to pad the images to. Must be larger than any image size provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest height and width in the batch. + device (`torch.device`, *optional*): + The device to process the images on. If unset, the device is inferred from the input images. """ do_resize = self.do_resize if do_resize is None else do_resize size = self.size if size is None else size @@ -586,7 +589,6 @@ def preprocess( pad_size = self.pad_size if pad_size is None else pad_size format = self.format if format is None else format return_tensors = "pt" if return_tensors is None else return_tensors - device = kwargs.pop("device", None) # Make hashable for cache size = SizeDict(**size) diff --git a/src/transformers/models/siglip/image_processing_siglip_fast.py b/src/transformers/models/siglip/image_processing_siglip_fast.py new file mode 100644 index 00000000000000..cb9c32f5dec3a6 --- /dev/null +++ b/src/transformers/models/siglip/image_processing_siglip_fast.py @@ -0,0 +1,69 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fast Image processor class for SigLIP.""" + +from ...image_processing_utils_fast import BaseImageProcessorFast +from ...image_utils import IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD, PILImageResampling + + +class SiglipImageProcessorFast(BaseImageProcessorFast): + r""" + Constructs a fast SigLIP image processor. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by + `do_resize` in the `preprocess` method. + size (`Dict[str, int]` *optional*, defaults to `{"height": 224, "width": 224}`): + Size of the image after resizing. Can be overridden by `size` in the `preprocess` method. + resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`): + Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method. + do_center_crop (`bool`, *optional*, defaults to `True`): + Whether to center crop the image to the specified `crop_size`. Can be overridden by `do_center_crop` in the + `preprocess` method. + crop_size (`Dict[str, int]` *optional*, defaults to 224): + Size of the output image after applying `center_crop`. Can be overridden by `crop_size` in the `preprocess` + method. 
+ do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in + the `preprocess` method. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess` + method. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image by the specified mean and standard deviation. Can be overridden by + `do_normalize` in the `preprocess` method. + image_mean (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`): + Mean to use if normalizing the image. This is a float or list of floats the length of the number of + channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. + image_std (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`): + Standard deviation to use if normalizing the image. This is a float or list of floats the length of the + number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. + Can be overridden by the `image_std` parameter in the `preprocess` method. + do_convert_rgb (`bool`, *optional*, defaults to `True`): + Whether to convert the image to RGB. + """ + + resample = PILImageResampling.BICUBIC + image_mean = IMAGENET_STANDARD_MEAN + image_std = IMAGENET_STANDARD_STD + size = {"height": 224, "width": 224} + default_to_square = False + do_resize = True + do_rescale = True + do_normalize = True + + +__all__ = ["SiglipImageProcessorFast"] diff --git a/src/transformers/models/vit/image_processing_vit_fast.py b/src/transformers/models/vit/image_processing_vit_fast.py index 5abf6cf10aa48e..113b48ee223fca 100644 --- a/src/transformers/models/vit/image_processing_vit_fast.py +++ b/src/transformers/models/vit/image_processing_vit_fast.py @@ -17,39 +17,46 @@ import functools from typing import Dict, List, Optional, Union -from ...image_processing_base import BatchFeature -from ...image_processing_utils import get_size_dict -from ...image_processing_utils_fast import BaseImageProcessorFast, SizeDict -from ...image_transforms import FusedRescaleNormalize, NumpyToTensor, Rescale, convert_to_rgb +from ...image_processing_utils import BatchFeature, get_size_dict +from ...image_processing_utils_fast import ( + BaseImageProcessorFast, + SizeDict, +) +from ...image_transforms import GroupByShape, Normalize, ReorderImages, Rescale, Resize from ...image_utils import ( IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD, ChannelDimension, ImageInput, - ImageType, PILImageResampling, - get_image_type, - make_list_of_images, - pil_torch_interpolation_mapping, + validate_kwargs, +) +from ...utils import ( + TensorType, + is_torch_available, + is_torchvision_available, + is_torchvision_v2_available, + logging, ) -from ...utils import TensorType, logging -from ...utils.import_utils import is_torch_available, is_torchvision_available - - -logger = logging.get_logger(__name__) if is_torch_available(): import torch - + from torch.nn import Sequential if is_torchvision_available(): - from torchvision.transforms import Compose, Normalize, PILToTensor, Resize + if is_torchvision_v2_available(): + from torchvision.transforms.v2 import functional as F + else: + from torchvision.transforms import functional as F + + +logger = logging.get_logger(__name__) class ViTImageProcessorFast(BaseImageProcessorFast): r""" - Constructs a ViT image processor. 
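As with the other additions, the `SiglipImageProcessorFast` defined above should become reachable through the auto class once an entry is added to the image-processor auto mapping (not shown in this diff). A hedged sketch, checkpoint name illustrative:

```python
from transformers import AutoImageProcessor

processor = AutoImageProcessor.from_pretrained("google/siglip-base-patch16-224", use_fast=True)
print(type(processor).__name__)  # expected: SiglipImageProcessorFast
```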
+ Constructs a fast ViT image processor. Args: do_resize (`bool`, *optional*, defaults to `True`): @@ -61,6 +68,12 @@ class ViTImageProcessorFast(BaseImageProcessorFast): resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`): Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the `preprocess` method. + do_center_crop (`bool`, *optional*, defaults to `True`): + Whether to center crop the image to the specified `crop_size`. Can be overridden by `do_center_crop` in the + `preprocess` method. + crop_size (`Dict[str, int]` *optional*, defaults to 224): + Size of the output image after applying `center_crop`. Can be overridden by `crop_size` in the `preprocess` + method. do_rescale (`bool`, *optional*, defaults to `True`): Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale` parameter in the `preprocess` method. @@ -80,132 +93,78 @@ class ViTImageProcessorFast(BaseImageProcessorFast): Whether to convert the image to RGB. """ - model_input_names = ["pixel_values"] - _transform_params = [ - "do_resize", - "do_rescale", - "do_normalize", - "size", - "resample", - "rescale_factor", - "image_mean", - "image_std", - "image_type", - ] - - def __init__( - self, - do_resize: bool = True, - size: Optional[Dict[str, int]] = None, - resample: PILImageResampling = PILImageResampling.BILINEAR, - do_rescale: bool = True, - rescale_factor: Union[int, float] = 1 / 255, - do_normalize: bool = True, - image_mean: Optional[Union[float, List[float]]] = None, - image_std: Optional[Union[float, List[float]]] = None, - do_convert_rgb: Optional[bool] = None, - **kwargs, - ) -> None: - super().__init__(**kwargs) - size = size if size is not None else {"height": 224, "width": 224} - size = get_size_dict(size) - self.do_resize = do_resize - self.do_rescale = do_rescale - self.do_normalize = do_normalize - self.size = size - self.resample = resample - self.rescale_factor = rescale_factor - self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN - self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD - self.do_convert_rgb = do_convert_rgb + resample = PILImageResampling.BILINEAR + image_mean = IMAGENET_STANDARD_MEAN + image_std = IMAGENET_STANDARD_STD + size = {"height": 224, "width": 224} + do_resize = True + do_rescale = True + do_normalize = True def _build_transforms( self, do_resize: bool, - size: Dict[str, int], - resample: PILImageResampling, + size: SizeDict, + interpolation: "F.InterpolationMode", do_rescale: bool, rescale_factor: float, do_normalize: bool, image_mean: Union[float, List[float]], image_std: Union[float, List[float]], - image_type: ImageType, - ) -> "Compose": + ) -> "Sequential": """ - Given the input settings build the image transforms using `torchvision.transforms.Compose`. + Given the input settings build the image transforms using a `Sequential` module. 
""" transforms = [] - # All PIL and numpy values need to be converted to a torch tensor - # to keep cross compatibility with slow image processors - if image_type == ImageType.PIL: - transforms.append(PILToTensor()) - - elif image_type == ImageType.NUMPY: - transforms.append(NumpyToTensor()) - + transforms.append(GroupByShape()) if do_resize: - transforms.append( - Resize((size["height"], size["width"]), interpolation=pil_torch_interpolation_mapping[resample]) - ) - + transforms.append(Resize(size, interpolation=interpolation)) + # Since the size was changed, we need to group the images by shape again + transforms.append(ReorderImages()) + transforms.append(GroupByShape()) # We can combine rescale and normalize into a single operation for speed if do_rescale and do_normalize: - transforms.append(FusedRescaleNormalize(image_mean, image_std, rescale_factor=rescale_factor)) + # image_mean and image_std have already been adjusted for rescaling + transforms.append(Normalize(image_mean, image_std)) elif do_rescale: transforms.append(Rescale(rescale_factor=rescale_factor)) elif do_normalize: transforms.append(Normalize(image_mean, image_std)) - return Compose(transforms) - - @functools.lru_cache(maxsize=1) - def _validate_input_arguments( - self, - return_tensors: Union[str, TensorType], - do_resize: bool, - size: Dict[str, int], - resample: PILImageResampling, - do_rescale: bool, - rescale_factor: float, - do_normalize: bool, - image_mean: Union[float, List[float]], - image_std: Union[float, List[float]], - data_format: Union[str, ChannelDimension], - image_type: ImageType, - ): - if return_tensors != "pt": - raise ValueError("Only returning PyTorch tensors is currently supported.") - - if data_format != ChannelDimension.FIRST: - raise ValueError("Only channel first data format is currently supported.") + if isinstance(transforms[-1], GroupByShape): + # No added transforms, so we can remove the last GroupByShape + transforms.pop() + else: + # We necessarily have grouped images, so we need to reorder them back to the original order + transforms.append(ReorderImages()) - if do_resize and None in (size, resample): - raise ValueError("Size and resample must be specified if do_resize is True.") + return Sequential(*transforms) - if do_rescale and rescale_factor is None: - raise ValueError("Rescale factor must be specified if do_rescale is True.") - - if do_normalize and None in (image_mean, image_std): - raise ValueError("Image mean and standard deviation must be specified if do_normalize is True.") + @functools.lru_cache(maxsize=1) + def get_transforms(self, **kwargs) -> "Sequential": + return self._build_transforms(**kwargs) def preprocess( self, images: ImageInput, - do_resize: Optional[bool] = None, + do_resize: bool = None, size: Dict[str, int] = None, - resample: PILImageResampling = None, - do_rescale: Optional[bool] = None, - rescale_factor: Optional[float] = None, - do_normalize: Optional[bool] = None, + resample: Optional[Union["PILImageResampling", "F.InterpolationMode"]] = None, + do_center_crop: bool = None, + crop_size: int = None, + do_rescale: bool = None, + rescale_factor: float = None, + do_normalize: bool = None, image_mean: Optional[Union[float, List[float]]] = None, image_std: Optional[Union[float, List[float]]] = None, - return_tensors: Optional[Union[str, TensorType]] = "pt", - data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST, + do_convert_rgb: bool = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: Optional[ChannelDimension] = 
ChannelDimension.FIRST, input_data_format: Optional[Union[str, ChannelDimension]] = None, - do_convert_rgb: Optional[bool] = None, + device: Optional["torch.device"] = None, **kwargs, - ): + ) -> BatchFeature: """ Preprocess an image or batch of images. @@ -216,60 +175,75 @@ def preprocess( do_resize (`bool`, *optional*, defaults to `self.do_resize`): Whether to resize the image. size (`Dict[str, int]`, *optional*, defaults to `self.size`): - Dictionary in the format `{"height": h, "width": w}` specifying the size of the output image after - resizing. - resample (`PILImageResampling` filter, *optional*, defaults to `self.resample`): - `PILImageResampling` filter to use if resizing the image e.g. `PILImageResampling.BILINEAR`. Only has - an effect if `do_resize` is set to `True`. + Describes the maximum input dimensions to the model. + resample (`PILImageResampling` or `InterpolationMode`, *optional*, defaults to self.resample): + Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only + has an effect if `do_resize` is set to `True`. + do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`): + Whether to center crop the image. + crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`): + Size of the output image after applying `center_crop`. do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): - Whether to rescale the image values between [0 - 1]. + Whether to rescale the image. rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): Rescale factor to rescale the image by if `do_rescale` is set to `True`. do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): Whether to normalize the image. image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): - Image mean to use if `do_normalize` is set to `True`. + Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`. image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): - Image standard deviation to use if `do_normalize` is set to `True`. + Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to + `True`. + do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): + Whether to convert the image to RGB. return_tensors (`str` or `TensorType`, *optional*): - The type of tensors to return. Only "pt" is supported + The type of tensors to return. Default to `"pt"` for PyTorch tensors if unset. + Fast image processors only support PyTorch tensors. data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): - The channel dimension format for the output image. The following formats are currently supported: + The channel dimension format for the output image. Can be one of: - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. input_data_format (`ChannelDimension` or `str`, *optional*): The channel dimension format for the input image. If unset, the channel dimension format is inferred from the input image. Can be one of: - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. 
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format. - do_convert_rgb (`bool`, *optional*): - Whether to convert the image to RGB. + device (`torch.device`, *optional*): + The device to process the images on. If unset, the device is inferred from the input images. """ + validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self.valid_extra_kwargs) + do_resize = do_resize if do_resize is not None else self.do_resize - do_rescale = do_rescale if do_rescale is not None else self.do_rescale - do_normalize = do_normalize if do_normalize is not None else self.do_normalize + size = size if size is not None else self.size + default_to_square = kwargs.pop( + "default_to_square", self.default_to_square if self.default_to_square is not None else True + ) + size = get_size_dict(size=size, default_to_square=default_to_square) if size is not None else None resample = resample if resample is not None else self.resample + do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop + crop_size = crop_size if crop_size is not None else self.crop_size + crop_size = get_size_dict(crop_size, param_name="crop_size") if crop_size is not None else None + do_rescale = do_rescale if do_rescale is not None else self.do_rescale rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_normalize = do_normalize if do_normalize is not None else self.do_normalize image_mean = image_mean if image_mean is not None else self.image_mean image_std = image_std if image_std is not None else self.image_std - size = size if size is not None else self.size do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb - return_tensors = "pt" if return_tensors is None else return_tensors - # Make hashable for cache - size = SizeDict(**size) - image_mean = tuple(image_mean) if isinstance(image_mean, list) else image_mean - image_std = tuple(image_std) if isinstance(image_std, list) else image_std - - images = make_list_of_images(images) - image_type = get_image_type(images[0]) - if image_type not in [ImageType.PIL, ImageType.TORCH, ImageType.NUMPY]: - raise ValueError(f"Unsupported input image type {image_type}") + images = self._prepare_input_images( + images=images, + do_convert_rgb=do_convert_rgb, + input_data_format=input_data_format, + device=device, + ) - self._validate_input_arguments( + image_mean, image_std, size, crop_size, interpolation = self._prepare_process_arguments( do_resize=do_resize, size=size, resample=resample, + crop_size=crop_size, do_rescale=do_rescale, rescale_factor=rescale_factor, do_normalize=do_normalize, @@ -277,27 +251,21 @@ def preprocess( image_std=image_std, return_tensors=return_tensors, data_format=data_format, - image_type=image_type, + device=images[0].device, ) - if do_convert_rgb: - images = [convert_to_rgb(image) for image in images] - transforms = self.get_transforms( do_resize=do_resize, do_rescale=do_rescale, do_normalize=do_normalize, size=size, - resample=resample, + interpolation=interpolation, rescale_factor=rescale_factor, image_mean=image_mean, image_std=image_std, - image_type=image_type, ) - transformed_images = [transforms(image) for image in images] - - data = {"pixel_values": torch.stack(transformed_images, dim=0)} - return BatchFeature(data, tensor_type=return_tensors) + transformed_images = transforms(images) + return BatchFeature(data={"pixel_values": torch.stack(transformed_images, dim=0)}, tensor_type=return_tensors) __all__ = ["ViTImageProcessorFast"] diff 
--git a/src/transformers/processing_utils.py b/src/transformers/processing_utils.py index c2646300367033..a2352b7db2629f 100644 --- a/src/transformers/processing_utils.py +++ b/src/transformers/processing_utils.py @@ -171,6 +171,8 @@ class methods and docstrings. The channel dimension format for the output image. input_data_format (`ChannelDimension` or `str`, *optional*): The channel dimension format for the input image. + device (`str`, *optional*): + The device to use for processing (e.g. "cpu", "cuda"), only relevant for fast image processing. """ do_resize: Optional[bool] @@ -188,6 +190,7 @@ class methods and docstrings. do_center_crop: Optional[bool] data_format: Optional[ChannelDimension] input_data_format: Optional[Union[str, ChannelDimension]] + device: Optional[str] class VideosKwargs(TypedDict, total=False): diff --git a/src/transformers/utils/dummy_torchvision_objects.py b/src/transformers/utils/dummy_torchvision_objects.py index 747f75386490fc..690c5a6af65db4 100644 --- a/src/transformers/utils/dummy_torchvision_objects.py +++ b/src/transformers/utils/dummy_torchvision_objects.py @@ -9,6 +9,27 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torchvision"]) +class BlipImageProcessorFast(metaclass=DummyObject): + _backends = ["torchvision"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torchvision"]) + + +class CLIPImageProcessorFast(metaclass=DummyObject): + _backends = ["torchvision"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torchvision"]) + + +class ConvNextImageProcessorFast(metaclass=DummyObject): + _backends = ["torchvision"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torchvision"]) + + class DeformableDetrImageProcessorFast(metaclass=DummyObject): _backends = ["torchvision"] @@ -16,6 +37,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torchvision"]) +class DeiTImageProcessorFast(metaclass=DummyObject): + _backends = ["torchvision"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torchvision"]) + + class DetrImageProcessorFast(metaclass=DummyObject): _backends = ["torchvision"] @@ -23,6 +51,20 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torchvision"]) +class LlavaNextImageProcessorFast(metaclass=DummyObject): + _backends = ["torchvision"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torchvision"]) + + +class LlavaOnevisionImageProcessorFast(metaclass=DummyObject): + _backends = ["torchvision"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torchvision"]) + + class PixtralImageProcessorFast(metaclass=DummyObject): _backends = ["torchvision"] @@ -37,6 +79,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torchvision"]) +class SiglipImageProcessorFast(metaclass=DummyObject): + _backends = ["torchvision"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torchvision"]) + + class ViTImageProcessorFast(metaclass=DummyObject): _backends = ["torchvision"] diff --git a/tests/models/blip/test_image_processing_blip.py b/tests/models/blip/test_image_processing_blip.py index d745f3420a61e2..12ef340108178c 100644 --- a/tests/models/blip/test_image_processing_blip.py +++ b/tests/models/blip/test_image_processing_blip.py @@ -17,7 +17,7 @@ import unittest from transformers.testing_utils import require_torch, require_vision -from transformers.utils import is_vision_available +from transformers.utils import is_torchvision_available, is_vision_available 
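Aside (illustrative only, not part of the diff): the dummy registrations above follow the usual `DummyObject`/`requires_backends` pattern, so every new `*ImageProcessorFast` symbol stays importable without torchvision but fails loudly when instantiated. A minimal sketch of that behaviour, assuming torchvision may or may not be installed:

```python
from transformers import BlipImageProcessorFast
from transformers.utils import is_torchvision_available

if is_torchvision_available():
    # Real fast processor backed by torchvision transforms.
    processor = BlipImageProcessorFast()
else:
    try:
        BlipImageProcessorFast()  # dummy object: raises and names the missing backend
    except ImportError as err:
        print(err)  # message points to installing torchvision
```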
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs @@ -25,6 +25,9 @@ if is_vision_available(): from transformers import BlipImageProcessor + if is_torchvision_available(): + from transformers import BlipImageProcessorFast + class BlipImageProcessingTester(unittest.TestCase): def __init__( @@ -89,6 +92,7 @@ def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=F @require_vision class BlipImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): image_processing_class = BlipImageProcessor if is_vision_available() else None + fast_image_processing_class = BlipImageProcessorFast if is_torchvision_available() else None def setUp(self): super().setUp() @@ -99,50 +103,36 @@ def image_processor_dict(self): return self.image_processor_tester.prepare_image_processor_dict() def test_image_processor_properties(self): - image_processor = self.image_processing_class(**self.image_processor_dict) - self.assertTrue(hasattr(image_processor, "do_resize")) - self.assertTrue(hasattr(image_processor, "size")) - self.assertTrue(hasattr(image_processor, "do_normalize")) - self.assertTrue(hasattr(image_processor, "image_mean")) - self.assertTrue(hasattr(image_processor, "image_std")) - self.assertTrue(hasattr(image_processor, "do_convert_rgb")) + for image_processing_class in self.image_processor_list: + image_processor = image_processing_class(**self.image_processor_dict) + self.assertTrue(hasattr(image_processor, "do_resize")) + self.assertTrue(hasattr(image_processor, "size")) + self.assertTrue(hasattr(image_processor, "do_normalize")) + self.assertTrue(hasattr(image_processor, "image_mean")) + self.assertTrue(hasattr(image_processor, "image_std")) + self.assertTrue(hasattr(image_processor, "do_convert_rgb")) @require_torch @require_vision class BlipImageProcessingTestFourChannels(ImageProcessingTestMixin, unittest.TestCase): image_processing_class = BlipImageProcessor if is_vision_available() else None + fast_image_processing_class = BlipImageProcessorFast if is_torchvision_available() else None def setUp(self): super().setUp() - self.image_processor_tester = BlipImageProcessingTester(self, num_channels=4) - self.expected_encoded_image_num_channels = 3 + self.image_processor_tester = BlipImageProcessingTester(self) @property def image_processor_dict(self): return self.image_processor_tester.prepare_image_processor_dict() def test_image_processor_properties(self): - image_processor = self.image_processing_class(**self.image_processor_dict) - self.assertTrue(hasattr(image_processor, "do_resize")) - self.assertTrue(hasattr(image_processor, "size")) - self.assertTrue(hasattr(image_processor, "do_normalize")) - self.assertTrue(hasattr(image_processor, "image_mean")) - self.assertTrue(hasattr(image_processor, "image_std")) - self.assertTrue(hasattr(image_processor, "do_convert_rgb")) - - @unittest.skip(reason="BlipImageProcessor does not support 4 channels yet") # FIXME Amy - def test_call_numpy(self): - return super().test_call_numpy() - - @unittest.skip(reason="BlipImageProcessor does not support 4 channels yet") # FIXME Amy - def test_call_pytorch(self): - return super().test_call_torch() - - @unittest.skip(reason="BLIP doesn't treat 4 channel PIL and numpy consistently yet") # FIXME Amy - def test_call_pil(self): - pass - - @unittest.skip(reason="BLIP doesn't treat 4 channel PIL and numpy consistently yet") # FIXME Amy - def test_call_numpy_4_channels(self): - pass + for image_processing_class in self.image_processor_list: + image_processor 
= image_processing_class(**self.image_processor_dict) + self.assertTrue(hasattr(image_processor, "do_resize")) + self.assertTrue(hasattr(image_processor, "size")) + self.assertTrue(hasattr(image_processor, "do_normalize")) + self.assertTrue(hasattr(image_processor, "image_mean")) + self.assertTrue(hasattr(image_processor, "image_std")) + self.assertTrue(hasattr(image_processor, "do_convert_rgb")) diff --git a/tests/models/clip/test_image_processing_clip.py b/tests/models/clip/test_image_processing_clip.py index ef4fdc819b2c4e..7387ede6ed1656 100644 --- a/tests/models/clip/test_image_processing_clip.py +++ b/tests/models/clip/test_image_processing_clip.py @@ -17,7 +17,7 @@ import unittest from transformers.testing_utils import require_torch, require_vision -from transformers.utils import is_vision_available +from transformers.utils import is_torchvision_available, is_vision_available from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs @@ -25,6 +25,9 @@ if is_vision_available(): from transformers import CLIPImageProcessor + if is_torchvision_available(): + from transformers import CLIPImageProcessorFast + class CLIPImageProcessingTester: def __init__( @@ -44,6 +47,7 @@ def __init__( image_std=[0.26862954, 0.26130258, 0.27577711], do_convert_rgb=True, ): + super().__init__() size = size if size is not None else {"shortest_edge": 20} crop_size = crop_size if crop_size is not None else {"height": 18, "width": 18} self.parent = parent @@ -92,6 +96,7 @@ def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=F @require_vision class CLIPImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): image_processing_class = CLIPImageProcessor if is_vision_available() else None + fast_image_processing_class = CLIPImageProcessorFast if is_torchvision_available() else None def setUp(self): super().setUp() @@ -102,21 +107,23 @@ def image_processor_dict(self): return self.image_processor_tester.prepare_image_processor_dict() def test_image_processor_properties(self): - image_processing = self.image_processing_class(**self.image_processor_dict) - self.assertTrue(hasattr(image_processing, "do_resize")) - self.assertTrue(hasattr(image_processing, "size")) - self.assertTrue(hasattr(image_processing, "do_center_crop")) - self.assertTrue(hasattr(image_processing, "center_crop")) - self.assertTrue(hasattr(image_processing, "do_normalize")) - self.assertTrue(hasattr(image_processing, "image_mean")) - self.assertTrue(hasattr(image_processing, "image_std")) - self.assertTrue(hasattr(image_processing, "do_convert_rgb")) + for image_processing_class in self.image_processor_list: + image_processing = image_processing_class(**self.image_processor_dict) + self.assertTrue(hasattr(image_processing, "do_resize")) + self.assertTrue(hasattr(image_processing, "size")) + self.assertTrue(hasattr(image_processing, "do_center_crop")) + self.assertTrue(hasattr(image_processing, "center_crop")) + self.assertTrue(hasattr(image_processing, "do_normalize")) + self.assertTrue(hasattr(image_processing, "image_mean")) + self.assertTrue(hasattr(image_processing, "image_std")) + self.assertTrue(hasattr(image_processing, "do_convert_rgb")) def test_image_processor_from_dict_with_kwargs(self): - image_processor = self.image_processing_class.from_dict(self.image_processor_dict) - self.assertEqual(image_processor.size, {"shortest_edge": 20}) - self.assertEqual(image_processor.crop_size, {"height": 18, "width": 18}) - - image_processor = 
self.image_processing_class.from_dict(self.image_processor_dict, size=42, crop_size=84) - self.assertEqual(image_processor.size, {"shortest_edge": 42}) - self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84}) + for image_processing_class in self.image_processor_list: + image_processor = image_processing_class.from_dict(self.image_processor_dict) + self.assertEqual(image_processor.size, {"shortest_edge": 20}) + self.assertEqual(image_processor.crop_size, {"height": 18, "width": 18}) + + image_processor = image_processing_class.from_dict(self.image_processor_dict, size=42, crop_size=84) + self.assertEqual(image_processor.size, {"shortest_edge": 42}) + self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84}) diff --git a/tests/models/convnext/test_image_processing_convnext.py b/tests/models/convnext/test_image_processing_convnext.py index 14a6b3e8e1aabc..513dff048b4b75 100644 --- a/tests/models/convnext/test_image_processing_convnext.py +++ b/tests/models/convnext/test_image_processing_convnext.py @@ -17,7 +17,7 @@ import unittest from transformers.testing_utils import require_torch, require_vision -from transformers.utils import is_vision_available +from transformers.utils import is_torchvision_available, is_vision_available from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs @@ -25,6 +25,9 @@ if is_vision_available(): from transformers import ConvNextImageProcessor + if is_torchvision_available(): + from transformers import ConvNextImageProcessorFast + class ConvNextImageProcessingTester(unittest.TestCase): def __init__( @@ -86,6 +89,7 @@ def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=F @require_vision class ConvNextImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): image_processing_class = ConvNextImageProcessor if is_vision_available() else None + fast_image_processing_class = ConvNextImageProcessorFast if is_torchvision_available() else None def setUp(self): super().setUp() @@ -96,17 +100,25 @@ def image_processor_dict(self): return self.image_processor_tester.prepare_image_processor_dict() def test_image_processor_properties(self): - image_processing = self.image_processing_class(**self.image_processor_dict) - self.assertTrue(hasattr(image_processing, "do_resize")) - self.assertTrue(hasattr(image_processing, "size")) - self.assertTrue(hasattr(image_processing, "crop_pct")) - self.assertTrue(hasattr(image_processing, "do_normalize")) - self.assertTrue(hasattr(image_processing, "image_mean")) - self.assertTrue(hasattr(image_processing, "image_std")) + for image_processing_class in self.image_processor_list: + image_processing = image_processing_class(**self.image_processor_dict) + self.assertTrue(hasattr(image_processing, "do_resize")) + self.assertTrue(hasattr(image_processing, "size")) + self.assertTrue(hasattr(image_processing, "crop_pct")) + self.assertTrue(hasattr(image_processing, "do_normalize")) + self.assertTrue(hasattr(image_processing, "image_mean")) + self.assertTrue(hasattr(image_processing, "image_std")) def test_image_processor_from_dict_with_kwargs(self): - image_processor = self.image_processing_class.from_dict(self.image_processor_dict) - self.assertEqual(image_processor.size, {"shortest_edge": 20}) - - image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42) - self.assertEqual(image_processor.size, {"shortest_edge": 42}) + for image_processing_class in self.image_processor_list: + image_processor = 
image_processing_class.from_dict(self.image_processor_dict) + self.assertEqual(image_processor.size, {"shortest_edge": 20}) + + image_processor = image_processing_class.from_dict(self.image_processor_dict, size=42) + self.assertEqual(image_processor.size, {"shortest_edge": 42}) + + @unittest.skip( + "Skipping as ConvNextImageProcessor uses center_crop and center_crop functions are not equivalent for fast and slow processors" + ) + def test_slow_fast_equivalence_batched(self): + pass diff --git a/tests/models/deit/test_image_processing_deit.py b/tests/models/deit/test_image_processing_deit.py index 7792ac10e057e6..80e03545e6e6b2 100644 --- a/tests/models/deit/test_image_processing_deit.py +++ b/tests/models/deit/test_image_processing_deit.py @@ -17,7 +17,7 @@ import unittest from transformers.testing_utils import require_torch, require_vision -from transformers.utils import is_vision_available +from transformers.utils import is_torchvision_available, is_vision_available from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs @@ -25,6 +25,9 @@ if is_vision_available(): from transformers import DeiTImageProcessor + if is_torchvision_available(): + from transformers import DeiTImageProcessorFast + class DeiTImageProcessingTester(unittest.TestCase): def __init__( @@ -91,6 +94,7 @@ def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=F @require_vision class DeiTImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): image_processing_class = DeiTImageProcessor if is_vision_available() else None + fast_image_processing_class = DeiTImageProcessorFast if is_torchvision_available() else None test_cast_dtype = True def setUp(self): @@ -102,20 +106,22 @@ def image_processor_dict(self): return self.image_processor_tester.prepare_image_processor_dict() def test_image_processor_properties(self): - image_processing = self.image_processing_class(**self.image_processor_dict) - self.assertTrue(hasattr(image_processing, "do_resize")) - self.assertTrue(hasattr(image_processing, "size")) - self.assertTrue(hasattr(image_processing, "do_center_crop")) - self.assertTrue(hasattr(image_processing, "center_crop")) - self.assertTrue(hasattr(image_processing, "do_normalize")) - self.assertTrue(hasattr(image_processing, "image_mean")) - self.assertTrue(hasattr(image_processing, "image_std")) + for image_processing_class in self.image_processor_list: + image_processing = image_processing_class(**self.image_processor_dict) + self.assertTrue(hasattr(image_processing, "do_resize")) + self.assertTrue(hasattr(image_processing, "size")) + self.assertTrue(hasattr(image_processing, "do_center_crop")) + self.assertTrue(hasattr(image_processing, "center_crop")) + self.assertTrue(hasattr(image_processing, "do_normalize")) + self.assertTrue(hasattr(image_processing, "image_mean")) + self.assertTrue(hasattr(image_processing, "image_std")) def test_image_processor_from_dict_with_kwargs(self): - image_processor = self.image_processing_class.from_dict(self.image_processor_dict) - self.assertEqual(image_processor.size, {"height": 20, "width": 20}) - self.assertEqual(image_processor.crop_size, {"height": 18, "width": 18}) - - image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42, crop_size=84) - self.assertEqual(image_processor.size, {"height": 42, "width": 42}) - self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84}) + for image_processing_class in self.image_processor_list: + image_processor = 
image_processing_class.from_dict(self.image_processor_dict) + self.assertEqual(image_processor.size, {"height": 20, "width": 20}) + self.assertEqual(image_processor.crop_size, {"height": 18, "width": 18}) + + image_processor = image_processing_class.from_dict(self.image_processor_dict, size=42, crop_size=84) + self.assertEqual(image_processor.size, {"height": 42, "width": 42}) + self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84}) diff --git a/tests/models/llava_next/test_image_processing_llava_next.py b/tests/models/llava_next/test_image_processing_llava_next.py index 4b3f5e0dd3ff42..957a5c3abd4818 100644 --- a/tests/models/llava_next/test_image_processing_llava_next.py +++ b/tests/models/llava_next/test_image_processing_llava_next.py @@ -20,7 +20,7 @@ from transformers.image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD from transformers.models.llava_next.image_processing_llava_next import select_best_resolution from transformers.testing_utils import require_torch, require_vision -from transformers.utils import is_torch_available, is_vision_available +from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs @@ -33,6 +33,9 @@ from transformers import LlavaNextImageProcessor + if is_torchvision_available(): + from transformers import LlavaNextImageProcessorFast + class LlavaNextImageProcessingTester: def __init__( @@ -52,6 +55,7 @@ def __init__( image_std=OPENAI_CLIP_STD, do_convert_rgb=True, ): + super().__init__() size = size if size is not None else {"shortest_edge": 20} crop_size = crop_size if crop_size is not None else {"height": 18, "width": 18} self.parent = parent @@ -102,6 +106,7 @@ def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=F @require_vision class LlavaNextImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): image_processing_class = LlavaNextImageProcessor if is_vision_available() else None + fast_image_processing_class = LlavaNextImageProcessorFast if is_torchvision_available() else None # Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTest.setUp with CLIP->LlavaNext def setUp(self): @@ -114,26 +119,28 @@ def image_processor_dict(self): return self.image_processor_tester.prepare_image_processor_dict() def test_image_processor_properties(self): - image_processing = self.image_processing_class(**self.image_processor_dict) - self.assertTrue(hasattr(image_processing, "do_resize")) - self.assertTrue(hasattr(image_processing, "size")) - self.assertTrue(hasattr(image_processing, "do_center_crop")) - self.assertTrue(hasattr(image_processing, "center_crop")) - self.assertTrue(hasattr(image_processing, "do_normalize")) - self.assertTrue(hasattr(image_processing, "image_mean")) - self.assertTrue(hasattr(image_processing, "image_std")) - self.assertTrue(hasattr(image_processing, "do_convert_rgb")) - self.assertTrue(hasattr(image_processing, "image_grid_pinpoints")) + for image_processing_class in self.image_processor_list: + image_processing = image_processing_class(**self.image_processor_dict) + self.assertTrue(hasattr(image_processing, "do_resize")) + self.assertTrue(hasattr(image_processing, "size")) + self.assertTrue(hasattr(image_processing, "do_center_crop")) + self.assertTrue(hasattr(image_processing, "center_crop")) + self.assertTrue(hasattr(image_processing, "do_normalize")) + self.assertTrue(hasattr(image_processing, "image_mean")) + 
self.assertTrue(hasattr(image_processing, "image_std")) + self.assertTrue(hasattr(image_processing, "do_convert_rgb")) + self.assertTrue(hasattr(image_processing, "image_grid_pinpoints")) # Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTest.test_image_processor_from_dict_with_kwargs def test_image_processor_from_dict_with_kwargs(self): - image_processor = self.image_processing_class.from_dict(self.image_processor_dict) - self.assertEqual(image_processor.size, {"shortest_edge": 20}) - self.assertEqual(image_processor.crop_size, {"height": 18, "width": 18}) + for image_processing_class in self.image_processor_list: + image_processor = image_processing_class.from_dict(self.image_processor_dict) + self.assertEqual(image_processor.size, {"shortest_edge": 20}) + self.assertEqual(image_processor.crop_size, {"height": 18, "width": 18}) - image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42, crop_size=84) - self.assertEqual(image_processor.size, {"shortest_edge": 42}) - self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84}) + image_processor = image_processing_class.from_dict(self.image_processor_dict, size=42, crop_size=84) + self.assertEqual(image_processor.size, {"shortest_edge": 42}) + self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84}) def test_select_best_resolution(self): possible_resolutions = [[672, 336], [336, 672], [672, 672], [336, 1008], [1008, 336]] @@ -143,59 +150,62 @@ def test_select_best_resolution(self): self.assertEqual(best_resolution, (672, 336)) def test_call_pil(self): - # Initialize image_processing - image_processing = self.image_processing_class(**self.image_processor_dict) - # create random PIL images - image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True) - for image in image_inputs: - self.assertIsInstance(image, Image.Image) - - # Test not batched input - encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values - expected_output_image_shape = (1, 1445, 3, 18, 18) - self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) - - # Test batched - encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values - expected_output_image_shape = (7, 1445, 3, 18, 18) - self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) + for image_processing_class in self.image_processor_list: + # Initialize image_processing + image_processing = image_processing_class(**self.image_processor_dict) + # create random PIL images + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True) + for image in image_inputs: + self.assertIsInstance(image, Image.Image) + + # Test not batched input + encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values + expected_output_image_shape = (1, 1445, 3, 18, 18) + self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) + + # Test batched + encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values + expected_output_image_shape = (7, 1445, 3, 18, 18) + self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) def test_call_numpy(self): - # Initialize image_processing - image_processing = self.image_processing_class(**self.image_processor_dict) - # create random numpy tensors - image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, numpify=True) - for image in image_inputs: - 
self.assertIsInstance(image, np.ndarray) - - # Test not batched input - encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values - expected_output_image_shape = (1, 1445, 3, 18, 18) - self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) - - # Test batched - encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values - expected_output_image_shape = (7, 1445, 3, 18, 18) - self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) + for image_processing_class in self.image_processor_list: + # Initialize image_processing + image_processing = image_processing_class(**self.image_processor_dict) + # create random numpy tensors + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, numpify=True) + for image in image_inputs: + self.assertIsInstance(image, np.ndarray) + + # Test not batched input + encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values + expected_output_image_shape = (1, 1445, 3, 18, 18) + self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) + + # Test batched + encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values + expected_output_image_shape = (7, 1445, 3, 18, 18) + self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) def test_call_pytorch(self): - # Initialize image_processing - image_processing = self.image_processing_class(**self.image_processor_dict) - # create random PyTorch tensors - image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, torchify=True) + for image_processing_class in self.image_processor_list: + # Initialize image_processing + image_processing = image_processing_class(**self.image_processor_dict) + # create random PyTorch tensors + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, torchify=True) - for image in image_inputs: - self.assertIsInstance(image, torch.Tensor) + for image in image_inputs: + self.assertIsInstance(image, torch.Tensor) - # Test not batched input - encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values - expected_output_image_shape = (1, 1445, 3, 18, 18) - self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) + # Test not batched input + encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values + expected_output_image_shape = (1, 1445, 3, 18, 18) + self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) - # Test batched - encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values - expected_output_image_shape = (7, 1445, 3, 18, 18) - self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) + # Test batched + encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values + expected_output_image_shape = (7, 1445, 3, 18, 18) + self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) @unittest.skip( reason="LlavaNextImageProcessor doesn't treat 4 channel PIL and numpy consistently yet" @@ -204,19 +214,20 @@ def test_call_numpy_4_channels(self): pass def test_nested_input(self): - image_processing = self.image_processing_class(**self.image_processor_dict) - image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True) - - # Test batched as a list of images - encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values - 
expected_output_image_shape = (7, 1445, 3, 18, 18)
-        self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
-
-        # Test batched as a nested list of images, where each sublist is one batch
-        image_inputs_nested = [image_inputs[:3], image_inputs[3:]]
-        encoded_images_nested = image_processing(image_inputs_nested, return_tensors="pt").pixel_values
-        expected_output_image_shape = (7, 1445, 3, 18, 18)
-        self.assertEqual(tuple(encoded_images_nested.shape), expected_output_image_shape)
-
-        # Image processor should return same pixel values, independently of ipnut format
-        self.assertTrue((encoded_images_nested == encoded_images).all())
+        for image_processing_class in self.image_processor_list:
+            image_processing = image_processing_class(**self.image_processor_dict)
+            image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True)
+
+            # Test batched as a list of images
+            encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
+            expected_output_image_shape = (7, 1445, 3, 18, 18)
+            self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
+
+            # Test batched as a nested list of images, where each sublist is one batch
+            image_inputs_nested = [image_inputs[:3], image_inputs[3:]]
+            encoded_images_nested = image_processing(image_inputs_nested, return_tensors="pt").pixel_values
+            expected_output_image_shape = (7, 1445, 3, 18, 18)
+            self.assertEqual(tuple(encoded_images_nested.shape), expected_output_image_shape)
+
+            # Image processor should return same pixel values, independently of input format
+            self.assertTrue((encoded_images_nested == encoded_images).all())
diff --git a/tests/models/llava_next_video/test_image_processing_llava_next_video.py b/tests/models/llava_next_video/test_image_processing_llava_next_video.py
index 385475c262f197..6a0cadc93c213b 100644
--- a/tests/models/llava_next_video/test_image_processing_llava_next_video.py
+++ b/tests/models/llava_next_video/test_image_processing_llava_next_video.py
@@ -151,13 +151,14 @@ def test_image_processor_properties(self):
 
     # Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTest.test_image_processor_from_dict_with_kwargs
     def test_image_processor_from_dict_with_kwargs(self):
-        image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
-        self.assertEqual(image_processor.size, {"shortest_edge": 20})
-        self.assertEqual(image_processor.crop_size, {"height": 18, "width": 18})
-
-        image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42, crop_size=84)
-        self.assertEqual(image_processor.size, {"shortest_edge": 42})
-        self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84})
+        for image_processing_class in self.image_processor_list:
+            image_processor = image_processing_class.from_dict(self.image_processor_dict)
+            self.assertEqual(image_processor.size, {"shortest_edge": 20})
+            self.assertEqual(image_processor.crop_size, {"height": 18, "width": 18})
+
+            image_processor = image_processing_class.from_dict(self.image_processor_dict, size=42, crop_size=84)
+            self.assertEqual(image_processor.size, {"shortest_edge": 42})
+            self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84})
 
     def test_call_pil(self):
         # Initialize image_processing
diff --git a/tests/models/llava_onevision/test_image_processing_llava_onevision.py b/tests/models/llava_onevision/test_image_processing_llava_onevision.py
index f392f2b8956d4b..3fbd358f972d75 100644
---
a/tests/models/llava_onevision/test_image_processing_llava_onevision.py +++ b/tests/models/llava_onevision/test_image_processing_llava_onevision.py @@ -19,7 +19,7 @@ from transformers.image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD from transformers.testing_utils import require_torch, require_vision -from transformers.utils import is_torch_available, is_vision_available +from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs @@ -30,7 +30,10 @@ if is_vision_available(): from PIL import Image - from transformers import LlavaOnevisionImageProcessor, LlavaOnevisionVideoProcessor + from transformers import LlavaOnevisionImageProcessor + + if is_torchvision_available(): + from transformers import LlavaOnevisionImageProcessorFast, LlavaOnevisionVideoProcessor class LlavaOnevisionImageProcessingTester: @@ -49,6 +52,7 @@ def __init__( image_std=OPENAI_CLIP_STD, do_convert_rgb=True, ): + super().__init__() size = size if size is not None else {"height": 20, "width": 20} self.parent = parent self.batch_size = batch_size @@ -121,6 +125,7 @@ def prepare_video_inputs(self, equal_resolution=False, numpify=False, torchify=F @require_vision class LlavaOnevisionImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): image_processing_class = LlavaOnevisionImageProcessor if is_vision_available() else None + fast_image_processing_class = LlavaOnevisionImageProcessorFast if is_torchvision_available() else None video_processing_class = LlavaOnevisionVideoProcessor if is_vision_available() else None # Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTest.setUp with CLIP->LlavaOnevision @@ -134,14 +139,15 @@ def image_processor_dict(self): return self.image_processor_tester.prepare_image_processor_dict() def test_image_processor_properties(self): - image_processing = self.image_processing_class(**self.image_processor_dict) - self.assertTrue(hasattr(image_processing, "do_resize")) - self.assertTrue(hasattr(image_processing, "size")) - self.assertTrue(hasattr(image_processing, "do_normalize")) - self.assertTrue(hasattr(image_processing, "image_mean")) - self.assertTrue(hasattr(image_processing, "image_std")) - self.assertTrue(hasattr(image_processing, "do_convert_rgb")) - self.assertTrue(hasattr(image_processing, "image_grid_pinpoints")) + for image_processing_class in self.image_processor_list: + image_processing = image_processing_class(**self.image_processor_dict) + self.assertTrue(hasattr(image_processing, "do_resize")) + self.assertTrue(hasattr(image_processing, "size")) + self.assertTrue(hasattr(image_processing, "do_normalize")) + self.assertTrue(hasattr(image_processing, "image_mean")) + self.assertTrue(hasattr(image_processing, "image_std")) + self.assertTrue(hasattr(image_processing, "do_convert_rgb")) + self.assertTrue(hasattr(image_processing, "image_grid_pinpoints")) def test_video_processor_properties(self): image_processing = self.video_processing_class(**self.image_processor_dict) @@ -153,66 +159,70 @@ def test_video_processor_properties(self): self.assertTrue(hasattr(image_processing, "do_convert_rgb")) def test_image_processor_from_dict_with_kwargs(self): - image_processor = self.image_processing_class.from_dict(self.image_processor_dict) - self.assertEqual(image_processor.size, {"height": 20, "width": 20}) + for image_processing_class in self.image_processor_list: + image_processor = 
image_processing_class.from_dict(self.image_processor_dict) + self.assertEqual(image_processor.size, {"height": 20, "width": 20}) - image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42) - self.assertEqual(image_processor.size, {"shortest_edge": 42}) + image_processor = image_processing_class.from_dict(self.image_processor_dict, size=42) + self.assertEqual(image_processor.size, {"shortest_edge": 42}) def test_call_pil(self): - # Initialize image_processing - image_processing = self.image_processing_class(**self.image_processor_dict) - # create random PIL images - image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True) - for image in image_inputs: - self.assertIsInstance(image, Image.Image) - - # Test not batched input - encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values - expected_output_image_shape = (1, 1522, 3, 20, 20) - self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) - - # Test batched - encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values - expected_output_image_shape = (7, 1522, 3, 20, 20) - self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) + for image_processing_class in self.image_processor_list: + # Initialize image_processing + image_processing = image_processing_class(**self.image_processor_dict) + # create random PIL images + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True) + for image in image_inputs: + self.assertIsInstance(image, Image.Image) + + # Test not batched input + encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values + expected_output_image_shape = (1, 1522, 3, 20, 20) + self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) + + # Test batched + encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values + expected_output_image_shape = (7, 1522, 3, 20, 20) + self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) def test_call_numpy(self): - # Initialize image_processing - image_processing = self.image_processing_class(**self.image_processor_dict) - # create random numpy tensors - image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, numpify=True) - for image in image_inputs: - self.assertIsInstance(image, np.ndarray) - - # Test not batched input - encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values - expected_output_image_shape = (1, 1522, 3, 20, 20) - self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) - - # Test batched - encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values - expected_output_image_shape = (7, 1522, 3, 20, 20) - self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) + for image_processing_class in self.image_processor_list: + # Initialize image_processing + image_processing = image_processing_class(**self.image_processor_dict) + # create random numpy tensors + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, numpify=True) + for image in image_inputs: + self.assertIsInstance(image, np.ndarray) + + # Test not batched input + encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values + expected_output_image_shape = (1, 1522, 3, 20, 20) + self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) + + # Test batched + 
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values + expected_output_image_shape = (7, 1522, 3, 20, 20) + self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) def test_call_pytorch(self): - # Initialize image_processing - image_processing = self.image_processing_class(**self.image_processor_dict) - # create random PyTorch tensors - image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, torchify=True) + for image_processing_class in self.image_processor_list: + # Initialize image_processing + image_processing = image_processing_class(**self.image_processor_dict) + # create random PyTorch tensors + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, torchify=True) - for image in image_inputs: - self.assertIsInstance(image, torch.Tensor) + for image in image_inputs: + self.assertIsInstance(image, torch.Tensor) - # Test not batched input - encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values - expected_output_image_shape = (1, 1522, 3, 20, 20) - self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) + # Test not batched input + encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values + expected_output_image_shape = (1, 1522, 3, 20, 20) + self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) - # Test batched - encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values - expected_output_image_shape = (7, 1522, 3, 20, 20) - self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) + # Test batched + encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values + expected_output_image_shape = (7, 1522, 3, 20, 20) + self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) @unittest.skip( reason="LlavaOnevisionImageProcessor doesn't treat 4 channel PIL and numpy consistently yet" @@ -221,22 +231,23 @@ def test_call_numpy_4_channels(self): pass def test_nested_input(self): - image_processing = self.image_processing_class(**self.image_processor_dict) - image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True) + for image_processing_class in self.image_processor_list: + image_processing = image_processing_class(**self.image_processor_dict) + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True) - # Test batched as a list of images - encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values - expected_output_image_shape = (7, 1522, 3, 20, 20) - self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) + # Test batched as a list of images + encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values + expected_output_image_shape = (7, 1522, 3, 20, 20) + self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) - # Test batched as a nested list of images, where each sublist is one batch - image_inputs_nested = [image_inputs[:3], image_inputs[3:]] - encoded_images_nested = image_processing(image_inputs_nested, return_tensors="pt").pixel_values - expected_output_image_shape = (7, 1522, 3, 20, 20) - self.assertEqual(tuple(encoded_images_nested.shape), expected_output_image_shape) + # Test batched as a nested list of images, where each sublist is one batch + image_inputs_nested = [image_inputs[:3], image_inputs[3:]] + encoded_images_nested = 
image_processing(image_inputs_nested, return_tensors="pt").pixel_values + expected_output_image_shape = (7, 1522, 3, 20, 20) + self.assertEqual(tuple(encoded_images_nested.shape), expected_output_image_shape) - # Image processor should return same pixel values, independently of input format - self.assertTrue((encoded_images_nested == encoded_images).all()) + # Image processor should return same pixel values, independently of input format + self.assertTrue((encoded_images_nested == encoded_images).all()) def test_call_pil_video(self): # Initialize image_processing @@ -289,3 +300,9 @@ def test_call_pytorch_video(self): encoded_videos = video_processing(video_inputs, return_tensors="pt").pixel_values_videos expected_output_video_shape = (7, 8, 3, 20, 20) self.assertEqual(tuple(encoded_videos.shape), expected_output_video_shape) + + @unittest.skip( + reason="LlavaOnevisionImageProcessorFast doesn't compile (infinitely) when using class transforms" + ) # FIXME yoni + def test_can_compile_fast_image_processor(self): + pass diff --git a/tests/models/pixtral/test_image_processing_pixtral.py b/tests/models/pixtral/test_image_processing_pixtral.py index 1377b676917f47..a0931a5ca157fc 100644 --- a/tests/models/pixtral/test_image_processing_pixtral.py +++ b/tests/models/pixtral/test_image_processing_pixtral.py @@ -281,7 +281,40 @@ def test_slow_fast_equivalence(self): encoding_slow = image_processor_slow(dummy_image, return_tensors="pt") encoding_fast = image_processor_fast(dummy_image, return_tensors="pt") - self.assertTrue(torch.allclose(encoding_slow.pixel_values[0][0], encoding_fast.pixel_values[0][0], atol=1e-2)) + self.assertTrue(torch.allclose(encoding_slow.pixel_values[0][0], encoding_fast.pixel_values[0][0], atol=1e-1)) + self.assertLessEqual( + torch.mean(torch.abs(encoding_slow.pixel_values[0][0] - encoding_fast.pixel_values[0][0])).item(), 1e-3 + ) + + @require_vision + @require_torch + def test_slow_fast_equivalence_batched(self): + dummy_images = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True) + + if not self.test_slow_image_processor or not self.test_fast_image_processor: + self.skipTest(reason="Skipping slow/fast equivalence test") + + if self.image_processing_class is None or self.fast_image_processing_class is None: + self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined") + + if hasattr(self.image_processor_tester, "do_center_crop") and self.image_processor_tester.do_center_crop: + self.skipTest( + reason="Skipping as do_center_crop is True and center_crop functions are not equivalent for fast and slow processors" + ) + + image_processor_slow = self.image_processing_class(**self.image_processor_dict) + image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict) + + encoding_slow = image_processor_slow(dummy_images, return_tensors="pt") + encoding_fast = image_processor_fast(dummy_images, return_tensors="pt") + + for i in range(len(encoding_slow.pixel_values)): + self.assertTrue( + torch.allclose(encoding_slow.pixel_values[i][0], encoding_fast.pixel_values[i][0], atol=1e-1) + ) + self.assertLessEqual( + torch.mean(torch.abs(encoding_slow.pixel_values[i][0] - encoding_fast.pixel_values[i][0])).item(), 1e-3 + ) @slow @require_torch_gpu diff --git a/tests/models/siglip/test_image_processing_siglip.py b/tests/models/siglip/test_image_processing_siglip.py index 02bf6d78c8d415..a32e8af0c570a3 100644 --- a/tests/models/siglip/test_image_processing_siglip.py +++ 
diff --git a/tests/models/siglip/test_image_processing_siglip.py b/tests/models/siglip/test_image_processing_siglip.py
index 02bf6d78c8d415..a32e8af0c570a3 100644
--- a/tests/models/siglip/test_image_processing_siglip.py
+++ b/tests/models/siglip/test_image_processing_siglip.py
@@ -17,7 +17,7 @@
 import unittest

 from transformers.testing_utils import require_torch, require_vision
-from transformers.utils import is_vision_available
+from transformers.utils import is_torchvision_available, is_vision_available

 from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs

@@ -25,6 +25,9 @@
 if is_vision_available():
     from transformers import SiglipImageProcessor

+    if is_torchvision_available():
+        from transformers import SiglipImageProcessorFast
+

 class SiglipImageProcessingTester(unittest.TestCase):
     def __init__(
@@ -90,6 +93,7 @@ def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=F
 # Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTest with CLIP->Siglip
 class SiglipImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
     image_processing_class = SiglipImageProcessor if is_vision_available() else None
+    fast_image_processing_class = SiglipImageProcessorFast if is_torchvision_available() else None

     def setUp(self):
         super().setUp()
@@ -101,25 +105,27 @@ def image_processor_dict(self):

     # Ignore copy
     def test_image_processor_properties(self):
-        image_processing = self.image_processing_class(**self.image_processor_dict)
-        self.assertTrue(hasattr(image_processing, "do_resize"))
-        self.assertTrue(hasattr(image_processing, "size"))
-        self.assertTrue(hasattr(image_processing, "resample"))
-        self.assertTrue(hasattr(image_processing, "do_rescale"))
-        self.assertTrue(hasattr(image_processing, "rescale_factor"))
-        self.assertTrue(hasattr(image_processing, "do_normalize"))
-        self.assertTrue(hasattr(image_processing, "image_mean"))
-        self.assertTrue(hasattr(image_processing, "image_std"))
+        for image_processing_class in self.image_processor_list:
+            image_processing = image_processing_class(**self.image_processor_dict)
+            self.assertTrue(hasattr(image_processing, "do_resize"))
+            self.assertTrue(hasattr(image_processing, "size"))
+            self.assertTrue(hasattr(image_processing, "resample"))
+            self.assertTrue(hasattr(image_processing, "do_rescale"))
+            self.assertTrue(hasattr(image_processing, "rescale_factor"))
+            self.assertTrue(hasattr(image_processing, "do_normalize"))
+            self.assertTrue(hasattr(image_processing, "image_mean"))
+            self.assertTrue(hasattr(image_processing, "image_std"))

     # Ignore copy
     def test_image_processor_from_dict_with_kwargs(self):
-        image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
-        self.assertEqual(image_processor.size, {"height": 18, "width": 18})
-
-        image_processor = self.image_processing_class.from_dict(
-            self.image_processor_dict, size={"height": 84, "width": 84}
-        )
-        self.assertEqual(image_processor.size, {"height": 84, "width": 84})
+        for image_processing_class in self.image_processor_list:
+            image_processor = image_processing_class.from_dict(self.image_processor_dict)
+            self.assertEqual(image_processor.size, {"height": 18, "width": 18})
+
+            image_processor = image_processing_class.from_dict(
+                self.image_processor_dict, size={"height": 84, "width": 84}
+            )
+            self.assertEqual(image_processor.size, {"height": 84, "width": 84})

     @unittest.skip(reason="not supported")
     # Ignore copy
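# Illustrative sketch of the guarded-import pattern added to the siglip test above:
# the fast processor is only imported when torchvision is installed, and the class
# attribute falls back to None so the shared test mixin can skip it gracefully.
# `is_vision_available` and `is_torchvision_available` are real transformers
# utilities; the snippet assumes a transformers build that already ships
# SiglipImageProcessorFast (the class added by this patch).
from transformers.utils import is_torchvision_available, is_vision_available

SlowProcessor = None
FastProcessor = None
if is_vision_available():
    from transformers import SiglipImageProcessor as SlowProcessor

    if is_torchvision_available():
        from transformers import SiglipImageProcessorFast as FastProcessor

# Either entry may be None depending on the installed extras.
print("slow:", SlowProcessor, "fast:", FastProcessor)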
diff --git a/tests/models/video_llava/test_image_processing_video_llava.py b/tests/models/video_llava/test_image_processing_video_llava.py
index b666c20ab848dd..e161e7b2db4a3b 100644
--- a/tests/models/video_llava/test_image_processing_video_llava.py
+++ b/tests/models/video_llava/test_image_processing_video_llava.py
@@ -153,13 +153,14 @@ def test_image_processor_properties(self):

     # Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTest.test_image_processor_from_dict_with_kwargs
     def test_image_processor_from_dict_with_kwargs(self):
-        image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
-        self.assertEqual(image_processor.size, {"shortest_edge": 20})
-        self.assertEqual(image_processor.crop_size, {"height": 18, "width": 18})
-
-        image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42, crop_size=84)
-        self.assertEqual(image_processor.size, {"shortest_edge": 42})
-        self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84})
+        for image_processing_class in self.image_processor_list:
+            image_processor = image_processing_class.from_dict(self.image_processor_dict)
+            self.assertEqual(image_processor.size, {"shortest_edge": 20})
+            self.assertEqual(image_processor.crop_size, {"height": 18, "width": 18})
+
+            image_processor = image_processing_class.from_dict(self.image_processor_dict, size=42, crop_size=84)
+            self.assertEqual(image_processor.size, {"shortest_edge": 42})
+            self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84})

     def test_call_pil(self):
         # Initialize image_processing
diff --git a/tests/models/vit/test_image_processing_vit.py b/tests/models/vit/test_image_processing_vit.py
index 5a94b4bb6e1270..dd80fa5d18b55c 100644
--- a/tests/models/vit/test_image_processing_vit.py
+++ b/tests/models/vit/test_image_processing_vit.py
@@ -25,8 +25,8 @@
 if is_vision_available():
     from transformers import ViTImageProcessor

-if is_torchvision_available():
-    from transformers import ViTImageProcessorFast
+    if is_torchvision_available():
+        from transformers import ViTImageProcessorFast


 class ViTImageProcessingTester(unittest.TestCase):
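# Minimal sketch of what the `from_dict(..., size=42, crop_size=84)` calls in the
# video_llava test above rely on: keyword arguments passed alongside the serialized
# dict override the corresponding dict entries before the processor is instantiated.
# `ToyImageProcessor` is a placeholder that only mimics the transformers convention
# of normalising plain ints into size dicts; it is not the real BaseImageProcessor.
class ToyImageProcessor:
    def __init__(self, size=None, crop_size=None, **kwargs):
        self.size = {"shortest_edge": size} if isinstance(size, int) else size
        self.crop_size = {"height": crop_size, "width": crop_size} if isinstance(crop_size, int) else crop_size

    @classmethod
    def from_dict(cls, config: dict, **kwargs):
        merged = {**config, **kwargs}  # explicit kwargs win over the serialized dict
        return cls(**merged)


processor = ToyImageProcessor.from_dict({"size": 20, "crop_size": 18}, size=42, crop_size=84)
assert processor.size == {"shortest_edge": 42}
assert processor.crop_size == {"height": 84, "width": 84}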
diff --git a/tests/test_image_processing_common.py b/tests/test_image_processing_common.py
index 1cb92174df1d8a..0d46007dbd8c4c 100644
--- a/tests/test_image_processing_common.py
+++ b/tests/test_image_processing_common.py
@@ -165,23 +165,50 @@ def setUp(self):
     @require_vision
     @require_torch
     def test_slow_fast_equivalence(self):
+        if not self.test_slow_image_processor or not self.test_fast_image_processor:
+            self.skipTest(reason="Skipping slow/fast equivalence test")
+
+        if self.image_processing_class is None or self.fast_image_processing_class is None:
+            self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined")
+
         dummy_image = Image.open(
             requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw
         )
+        image_processor_slow = self.image_processing_class(**self.image_processor_dict)
+        image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict)
+
+        encoding_slow = image_processor_slow(dummy_image, return_tensors="pt")
+        encoding_fast = image_processor_fast(dummy_image, return_tensors="pt")

+        self.assertTrue(torch.allclose(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-1))
+        self.assertLessEqual(
+            torch.mean(torch.abs(encoding_slow.pixel_values - encoding_fast.pixel_values)).item(), 1e-3
+        )
+
+    @require_vision
+    @require_torch
+    def test_slow_fast_equivalence_batched(self):
         if not self.test_slow_image_processor or not self.test_fast_image_processor:
             self.skipTest(reason="Skipping slow/fast equivalence test")

         if self.image_processing_class is None or self.fast_image_processing_class is None:
             self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined")

+        if hasattr(self.image_processor_tester, "do_center_crop") and self.image_processor_tester.do_center_crop:
+            self.skipTest(
+                reason="Skipping as do_center_crop is True and center_crop functions are not equivalent for fast and slow processors"
+            )
+
+        dummy_images = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)
         image_processor_slow = self.image_processing_class(**self.image_processor_dict)
         image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict)

-        encoding_slow = image_processor_slow(dummy_image, return_tensors="pt")
-        encoding_fast = image_processor_fast(dummy_image, return_tensors="pt")
+        encoding_slow = image_processor_slow(dummy_images, return_tensors="pt")
+        encoding_fast = image_processor_fast(dummy_images, return_tensors="pt")

-        self.assertTrue(torch.allclose(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-2))
+        self.assertTrue(torch.allclose(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-1))
+        self.assertLessEqual(
+            torch.mean(torch.abs(encoding_slow.pixel_values - encoding_fast.pixel_values)).item(), 1e-3
+        )

     @require_vision
     @require_torch
@@ -193,6 +220,8 @@ def test_fast_is_faster_than_slow(self):
             self.skipTest(reason="Skipping speed test as one of the image processors is not defined")

         def measure_time(image_processor, image):
+            # Warmup
+            _ = image_processor(image, return_tensors="pt")
             start = time.time()
             _ = image_processor(image, return_tensors="pt")
             return time.time() - start
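# Hedged sketch of the warmup-then-measure pattern added to measure_time above: the
# first call pays one-off costs (lazy imports, kernel/JIT warmup, cache fills), so it
# is executed once and discarded before the timed call. `work` here is a generic
# stand-in for the image processor call, not a transformers API.
import time


def measure_time(work, *args, **kwargs) -> float:
    _ = work(*args, **kwargs)  # warmup call, deliberately not timed
    start = time.time()
    _ = work(*args, **kwargs)
    return time.time() - start


print(f"timed call took {measure_time(sum, range(1_000_000)):.4f}s")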
@@ -268,8 +297,31 @@ def test_save_load_fast_slow(self):
             image_processor_fast_1.save_pretrained(tmpdirname)
             image_processor_slow_1 = self.image_processing_class.from_pretrained(tmpdirname)

-        self.assertEqual(image_processor_slow_0.to_dict(), image_processor_slow_1.to_dict())
-        self.assertEqual(image_processor_fast_0.to_dict(), image_processor_fast_1.to_dict())
+        dict_slow_0 = image_processor_slow_0.to_dict()
+        dict_slow_1 = image_processor_slow_1.to_dict()
+        difference = {
+            key: dict_slow_0.get(key) if key in dict_slow_0 else dict_slow_1.get(key)
+            for key in set(dict_slow_0) ^ set(dict_slow_1)
+        }
+        dict_slow_0 = {key: dict_slow_0[key] for key in set(dict_slow_0) & set(dict_slow_1)}
+        dict_slow_1 = {key: dict_slow_1[key] for key in set(dict_slow_0) & set(dict_slow_1)}
+        # check that all additional keys are None
+        self.assertTrue(all(value is None for value in difference.values()))
+        # check that the remaining keys are the same
+        self.assertEqual(dict_slow_0, dict_slow_1)
+
+        dict_fast_0 = image_processor_fast_0.to_dict()
+        dict_fast_1 = image_processor_fast_1.to_dict()
+        difference = {
+            key: dict_fast_0.get(key) if key in dict_fast_0 else dict_fast_1.get(key)
+            for key in set(dict_fast_0) ^ set(dict_fast_1)
+        }
+        dict_fast_0 = {key: dict_fast_0[key] for key in set(dict_fast_0) & set(dict_fast_1)}
+        dict_fast_1 = {key: dict_fast_1[key] for key in set(dict_fast_0) & set(dict_fast_1)}
+        # check that all additional keys are None
+        self.assertTrue(all(value is None for value in difference.values()))
+        # check that the remaining keys are the same
+        self.assertEqual(dict_fast_0, dict_fast_1)

     def test_save_load_fast_slow_auto(self):
         "Test that we can load a fast image processor from a slow one and vice-versa using AutoImageProcessor."
@@ -291,8 +343,31 @@ def test_save_load_fast_slow_auto(self):
             image_processor_fast_1.save_pretrained(tmpdirname)
             image_processor_slow_1 = AutoImageProcessor.from_pretrained(tmpdirname, use_fast=False)

-        self.assertEqual(image_processor_slow_0.to_dict(), image_processor_slow_1.to_dict())
-        self.assertEqual(image_processor_fast_0.to_dict(), image_processor_fast_1.to_dict())
+        dict_slow_0 = image_processor_slow_0.to_dict()
+        dict_slow_1 = image_processor_slow_1.to_dict()
+        difference = {
+            key: dict_slow_0.get(key) if key in dict_slow_0 else dict_slow_1.get(key)
+            for key in set(dict_slow_0) ^ set(dict_slow_1)
+        }
+        dict_slow_0 = {key: dict_slow_0[key] for key in set(dict_slow_0) & set(dict_slow_1)}
+        dict_slow_1 = {key: dict_slow_1[key] for key in set(dict_slow_0) & set(dict_slow_1)}
+        # check that all additional keys are None
+        self.assertTrue(all(value is None for value in difference.values()))
+        # check that the remaining keys are the same
+        self.assertEqual(dict_slow_0, dict_slow_1)
+
+        dict_fast_0 = image_processor_fast_0.to_dict()
+        dict_fast_1 = image_processor_fast_1.to_dict()
+        difference = {
+            key: dict_fast_0.get(key) if key in dict_fast_0 else dict_fast_1.get(key)
+            for key in set(dict_fast_0) ^ set(dict_fast_1)
+        }
+        dict_fast_0 = {key: dict_fast_0[key] for key in set(dict_fast_0) & set(dict_fast_1)}
+        dict_fast_1 = {key: dict_fast_1[key] for key in set(dict_fast_0) & set(dict_fast_1)}
+        # check that all additional keys are None
+        self.assertTrue(all(value is None for value in difference.values()))
+        # check that the remaining keys are the same
+        self.assertEqual(dict_fast_0, dict_fast_1)

     def test_init_without_params(self):
         for image_processing_class in self.image_processor_list: