diff --git a/tests/test_image_processing_common.py b/tests/test_image_processing_common.py index d929997ee87369..e9b9467f580a2a 100644 --- a/tests/test_image_processing_common.py +++ b/tests/test_image_processing_common.py @@ -19,6 +19,8 @@ import pathlib import tempfile +import requests + from transformers import AutoImageProcessor, BatchFeature from transformers.image_utils import AnnotationFormat, AnnotionFormat from transformers.testing_utils import check_json_file_has_correct_format, require_torch, require_vision @@ -146,6 +148,55 @@ def setUp(self): self.image_processor_list = image_processor_list + @require_vision + @require_torch + def test_slow_fast_equivalence(self): + dummy_image = Image.open( + requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw + ) + + if not self.test_slow_image_processor or not self.test_fast_image_processor: + self.skipTest("Skipping slow/fast equivalence test") + + if self.image_processing_class is None or self.fast_image_processing_class is None: + self.skipTest("Skipping slow/fast equivalence test as one of the image processors is not defined") + + image_processor_slow = self.image_processing_class(**self.image_processor_dict) + image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict) + + encoding_slow = image_processor_slow(dummy_image, return_tensors="pt") + encoding_fast = image_processor_fast(dummy_image, return_tensors="pt") + + self.assertTrue(torch.allclose(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-3)) + + @require_vision + @require_torch + def test_fast_is_faster_than_slow(self): + import time + + def measure_time(self, image_processor, dummy_image): + start = time.time() + _ = image_processor(dummy_image, return_tensors="pt") + return time.time() - start + + dummy_image = Image.open( + requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw + ) + + if not self.test_slow_image_processor or not self.test_fast_image_processor: + self.skipTest("Skipping speed test") + + if self.image_processing_class is None or self.fast_image_processing_class is None: + self.skipTest("Skipping speed test as one of the image processors is not defined") + + image_processor_slow = self.image_processing_class(**self.image_processor_dict) + image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict) + + slow_time = self.measure_time(image_processor_slow, dummy_image) + fast_time = self.measure_time(image_processor_fast, dummy_image) + + self.assertLessEqual(fast_time, slow_time) + def test_image_processor_to_json_string(self): for image_processing_class in self.image_processor_list: image_processor = image_processing_class(**self.image_processor_dict)