Fall back to slow image processor in ImageProcessingAuto when no fast processor available #34785

Merged · 7 commits · Dec 15, 2024
21 changes: 9 additions & 12 deletions docs/source/en/main_classes/image_processor.md
@@ -27,6 +27,7 @@ from transformers import AutoImageProcessor

processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50", use_fast=True)
```
Note that `use_fast` will be set to `True` by default in a future release.

When using a fast image processor, you can also set the `device` argument to specify the device on which the processing should be done. By default, the processing is done on the same device as the inputs if the inputs are tensors, or on the CPU otherwise.
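
Below is a minimal sketch of the `device` argument in use, assuming a CUDA device is available and using a local image file as a stand-in for real inputs:

```python
from PIL import Image
from transformers import AutoImageProcessor

processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50", use_fast=True)
images = [Image.open("image.jpg")]  # any PIL images, numpy arrays, or torch tensors

# Processing runs on the requested device; the outputs are torch tensors placed on it.
images_processed = processor(images, return_tensors="pt", device="cuda")
print(images_processed["pixel_values"].device)  # cuda:0
```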

@@ -42,21 +43,17 @@ images_processed = processor(images, return_tensors="pt", device="cuda")
Here are some speed comparisons between the base and fast image processors for the `DETR` and `RT-DETR` models, and how they impact overall inference time:

<div class="flex">
<div>
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/benchmark_results_full_pipeline_detr_fast_padded.png" />
</div>
<div>
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/benchmark_results_full_pipeline_detr_fast_batched_compiled.png" />
</div>
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/benchmark_results_full_pipeline_detr_fast_padded.png" />
</div>
<div class="flex">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/benchmark_results_full_pipeline_detr_fast_batched_compiled.png" />
</div>

<div class="flex">
<div>
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/benchmark_results_full_pipeline_rt_detr_fast_single.png" />
</div>
<div>
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/benchmark_results_full_pipeline_rt_detr_fast_batched.png" />
</div>
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/benchmark_results_full_pipeline_rt_detr_fast_single.png" />
</div>
<div class="flex">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/benchmark_results_full_pipeline_rt_detr_fast_batched.png" />
</div>

These benchmarks were run on an [AWS EC2 g5.2xlarge instance](https://aws.amazon.com/ec2/instance-types/g5/), utilizing an NVIDIA A10G Tensor Core GPU.
60 changes: 45 additions & 15 deletions src/transformers/models/auto/image_processing_auto.py

@@ -175,7 +175,7 @@
IMAGE_PROCESSOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, IMAGE_PROCESSOR_MAPPING_NAMES)


def image_processor_class_from_name(class_name: str):
def get_image_processor_class_from_name(class_name: str):
if class_name == "BaseImageProcessorFast":
return BaseImageProcessorFast

@@ -368,7 +368,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
identifier allowed by git.
use_fast (`bool`, *optional*, defaults to `False`):
Use a fast torchvision-based image processor if it is supported for a given model.
If a fast tokenizer is not available for a given model, a normal numpy-based image processor
If a fast image processor is not available for a given model, a normal numpy-based image processor
is returned instead.
return_unused_kwargs (`bool`, *optional*, defaults to `False`):
If `False`, then this function returns just the final image processor object. If `True`, then this
@@ -416,6 +416,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
kwargs["token"] = use_auth_token

config = kwargs.pop("config", None)
# TODO: @yoni, change in v4.48 (use_fast set to True by default)
use_fast = kwargs.pop("use_fast", None)
trust_remote_code = kwargs.pop("trust_remote_code", None)
kwargs["_from_auto"] = True
@@ -451,42 +452,71 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
if not is_timm_config_dict(config_dict):
raise initial_exception

image_processor_class = config_dict.get("image_processor_type", None)
image_processor_type = config_dict.get("image_processor_type", None)
image_processor_auto_map = None
if "AutoImageProcessor" in config_dict.get("auto_map", {}):
image_processor_auto_map = config_dict["auto_map"]["AutoImageProcessor"]

# If we still don't have the image processor class, check if we're loading from a previous feature extractor config
# and if so, infer the image processor class from there.
if image_processor_class is None and image_processor_auto_map is None:
if image_processor_type is None and image_processor_auto_map is None:
feature_extractor_class = config_dict.pop("feature_extractor_type", None)
if feature_extractor_class is not None:
image_processor_class = feature_extractor_class.replace("FeatureExtractor", "ImageProcessor")
image_processor_type = feature_extractor_class.replace("FeatureExtractor", "ImageProcessor")
if "AutoFeatureExtractor" in config_dict.get("auto_map", {}):
feature_extractor_auto_map = config_dict["auto_map"]["AutoFeatureExtractor"]
image_processor_auto_map = feature_extractor_auto_map.replace("FeatureExtractor", "ImageProcessor")

# If we don't find the image processor class in the image processor config, let's try the model config.
if image_processor_class is None and image_processor_auto_map is None:
if image_processor_type is None and image_processor_auto_map is None:
if not isinstance(config, PretrainedConfig):
config = AutoConfig.from_pretrained(
pretrained_model_name_or_path,
trust_remote_code=trust_remote_code,
**kwargs,
)
# It could be in `config.image_processor_type`
image_processor_class = getattr(config, "image_processor_type", None)
image_processor_type = getattr(config, "image_processor_type", None)
if hasattr(config, "auto_map") and "AutoImageProcessor" in config.auto_map:
image_processor_auto_map = config.auto_map["AutoImageProcessor"]

if image_processor_class is not None:
# Update class name to reflect the use_fast option. If class is not found, None is returned.
if use_fast is not None:
if use_fast and not image_processor_class.endswith("Fast"):
image_processor_class += "Fast"
elif not use_fast and image_processor_class.endswith("Fast"):
image_processor_class = image_processor_class[:-4]
image_processor_class = image_processor_class_from_name(image_processor_class)
image_processor_class = None
# TODO: @yoni, change logic in v4.48 (when use_fast set to True by default)
if image_processor_type is not None:
# if use_fast is not set and the processor was saved with a fast processor, we use it, otherwise we use the slow processor.
if use_fast is None:
use_fast = image_processor_type.endswith("Fast")
if not use_fast:
logger.warning_once(
"Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. "
"`use_fast=True` will be the default behavior in v4.48, even if the model was saved with a slow processor. "
"This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`."
)
# Update class name to reflect the use_fast option. If class is not found, we fall back to the slow version.
if use_fast and not is_torchvision_available():
logger.warning_once(
"Using `use_fast=True` but `torchvision` is not available. Falling back to the slow image processor."
)
use_fast = False
if use_fast:
if not image_processor_type.endswith("Fast"):
image_processor_type += "Fast"
for _, image_processors in IMAGE_PROCESSOR_MAPPING_NAMES.items():
if image_processor_type in image_processors:
break
else:
image_processor_type = image_processor_type[:-4]
use_fast = False
logger.warning_once(
"`use_fast` is set to `True` but the image processor class does not have a fast version. "
" Falling back to the slow version."
)
image_processor_class = get_image_processor_class_from_name(image_processor_type)
else:
image_processor_type = (
image_processor_type[:-4] if image_processor_type.endswith("Fast") else image_processor_type
)
image_processor_class = get_image_processor_class_from_name(image_processor_type)

has_remote_code = image_processor_auto_map is not None
has_local_code = image_processor_class is not None or type(config) in IMAGE_PROCESSOR_MAPPING
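
Taken together, the new selection logic means callers no longer need to know whether a fast class exists for their model. A rough sketch of the resulting behavior, under the assumption that torchvision may or may not be installed (checkpoint name reused from the docs for illustration):

```python
from transformers import AutoImageProcessor

# Checkpoint saved with a slow processor, use_fast left unset:
# the slow processor is returned and a warning notes that v4.48 will default to fast.
slow_processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50")

# use_fast=True but torchvision is missing, or no *Fast class is registered for the model:
# instead of raising, a warning is logged and the slow (numpy-based) processor is returned.
maybe_fast_processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50", use_fast=True)
```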
1 change: 1 addition & 0 deletions src/transformers/models/vit/image_processing_vit_fast.py

@@ -254,6 +254,7 @@ def preprocess(
image_std = image_std if image_std is not None else self.image_std
size = size if size is not None else self.size
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
return_tensors = "pt" if return_tensors is None else return_tensors
Member Author:
Otherwise, the fast ViT image processor will crash with the default behavior (when `return_tensors` is not specified). This becomes a bigger problem once fast image processors are used by default.
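
A rough sketch of the behavior this default gives (checkpoint name and shapes are illustrative, not taken from the diff):

```python
from PIL import Image
from transformers import ViTImageProcessorFast

processor = ViTImageProcessorFast.from_pretrained("google/vit-base-patch16-224")
image = Image.new("RGB", (224, 224))

# No return_tensors passed: with this change the fast processor defaults to "pt",
# returning torch tensors instead of hitting the previous failure path.
outputs = processor(image)
print(outputs["pixel_values"].shape)  # torch.Size([1, 3, 224, 224])
```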

Collaborator:
Is this a kind of breaking change?

(I am also curious: what if a user is using a TF/Flax model while their environment has torch/torchvision installed? Will the fast image processor be used by default and return torch tensors by default?)

Member Author:
Oh I see, yes this might be a problem. In general, I'm not too sure why it was decided to constrain fast image processors to output only torch tensors. Would be glad to know if there was a reason for that, otherwise it might be something we would want to reconsider.

Collaborator:
TBH I am not even sure TF is used for image models! Fine by me!

# Make hashable for cache
size = SizeDict(**size)
image_mean = tuple(image_mean) if isinstance(image_mean, list) else image_mean
1 change: 1 addition & 0 deletions tests/models/auto/test_image_processing_auto.py

@@ -140,6 +140,7 @@ def test_image_processor_not_found(self):
def test_use_fast_selection(self):
checkpoint = "hf-internal-testing/tiny-random-vit"

# TODO: @yoni, change in v4.48 (when use_fast set to True by default)
# Slow image processor is selected by default
image_processor = AutoImageProcessor.from_pretrained(checkpoint)
self.assertIsInstance(image_processor, ViTImageProcessor)
4 changes: 3 additions & 1 deletion tests/models/detr/test_image_processing_detr.py

@@ -19,7 +19,7 @@

import numpy as np

from transformers.testing_utils import require_torch, require_torch_gpu, require_vision, slow
from transformers.testing_utils import require_torch, require_torch_gpu, require_torchvision, require_vision, slow
from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available

from ...test_image_processing_common import AnnotationFormatTestMixin, ImageProcessingTestMixin, prepare_image_inputs
@@ -669,6 +669,7 @@ def test_longest_edge_shortest_edge_resizing_strategy(self):

@slow
@require_torch_gpu
@require_torchvision
def test_fast_processor_equivalence_cpu_gpu_coco_detection_annotations(self):
# prepare image and target
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
@@ -724,6 +725,7 @@ def test_fast_processor_equivalence_cpu_gpu_coco_detection_annotations(self):

@slow
@require_torch_gpu
@require_torchvision
def test_fast_processor_equivalence_cpu_gpu_coco_panoptic_annotations(self):
# prepare image, target and masks_path
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
3 changes: 2 additions & 1 deletion tests/models/rt_detr/test_image_processing_rt_detr.py

@@ -16,7 +16,7 @@

import requests

from transformers.testing_utils import require_torch, require_torch_gpu, require_vision, slow
from transformers.testing_utils import require_torch, require_torch_gpu, require_torchvision, require_vision, slow
from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available

from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
@@ -374,6 +374,7 @@ def test_batched_coco_detection_annotations(self):

@slow
@require_torch_gpu
@require_torchvision
# Copied from tests.models.detr.test_image_processing_detr.DetrImageProcessingTest.test_fast_processor_equivalence_cpu_gpu_coco_detection_annotations
def test_fast_processor_equivalence_cpu_gpu_coco_detection_annotations(self):
# prepare image and target
@@ -21,13 +21,13 @@
from transformers import BertTokenizerFast
from transformers.models.bert.tokenization_bert import VOCAB_FILES_NAMES, BertTokenizer
from transformers.testing_utils import require_tokenizers, require_vision
from transformers.utils import IMAGE_PROCESSOR_NAME, is_vision_available
from transformers.utils import IMAGE_PROCESSOR_NAME, is_torchvision_available, is_vision_available

from ...test_processing_common import ProcessorTesterMixin


if is_vision_available():
from transformers import VisionTextDualEncoderProcessor, ViTImageProcessor
from transformers import VisionTextDualEncoderProcessor, ViTImageProcessor, ViTImageProcessorFast


@require_tokenizers
@@ -63,6 +63,8 @@ def get_tokenizer(self, **kwargs):
return BertTokenizer.from_pretrained(self.tmpdirname, **kwargs)

def get_image_processor(self, **kwargs):
if is_torchvision_available():
return ViTImageProcessorFast.from_pretrained(self.tmpdirname, **kwargs)
return ViTImageProcessor.from_pretrained(self.tmpdirname, **kwargs)

def tearDown(self):
@@ -81,7 +83,7 @@ def test_save_load_pretrained_default(self):
self.assertIsInstance(processor.tokenizer, (BertTokenizer, BertTokenizerFast))

self.assertEqual(processor.image_processor.to_json_string(), image_processor.to_json_string())
self.assertIsInstance(processor.image_processor, ViTImageProcessor)
self.assertIsInstance(processor.image_processor, (ViTImageProcessor, ViTImageProcessorFast))

def test_save_load_pretrained_additional_features(self):
processor = VisionTextDualEncoderProcessor(
@@ -100,7 +102,7 @@ def test_save_load_pretrained_additional_features(self):
self.assertIsInstance(processor.tokenizer, (BertTokenizer, BertTokenizerFast))

self.assertEqual(processor.image_processor.to_json_string(), image_processor_add_kwargs.to_json_string())
self.assertIsInstance(processor.image_processor, ViTImageProcessor)
self.assertIsInstance(processor.image_processor, (ViTImageProcessor, ViTImageProcessorFast))

def test_image_processor(self):
image_processor = self.get_image_processor()
@@ -110,8 +112,8 @@ def test_image_processor(self):

image_input = self.prepare_image_inputs()

input_feat_extract = image_processor(image_input, return_tensors="np")
input_processor = processor(images=image_input, return_tensors="np")
input_feat_extract = image_processor(image_input, return_tensors="pt")
input_processor = processor(images=image_input, return_tensors="pt")

for key in input_feat_extract.keys():
self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2)
5 changes: 3 additions & 2 deletions tests/test_image_processing_common.py

@@ -228,14 +228,15 @@ def test_image_processor_from_and_save_pretrained(self):
self.assertEqual(image_processor_second.to_dict(), image_processor_first.to_dict())

def test_image_processor_save_load_with_autoimageprocessor(self):
for image_processing_class in self.image_processor_list:
for i, image_processing_class in enumerate(self.image_processor_list):
image_processor_first = image_processing_class(**self.image_processor_dict)

with tempfile.TemporaryDirectory() as tmpdirname:
saved_file = image_processor_first.save_pretrained(tmpdirname)[0]
check_json_file_has_correct_format(saved_file)

image_processor_second = AutoImageProcessor.from_pretrained(tmpdirname)
use_fast = i == 1
image_processor_second = AutoImageProcessor.from_pretrained(tmpdirname, use_fast=use_fast)

self.assertEqual(image_processor_second.to_dict(), image_processor_first.to_dict())
