diff --git a/src/transformers/image_transforms.py b/src/transformers/image_transforms.py index 4fef6012012f36..81e8d9185623aa 100644 --- a/src/transformers/image_transforms.py +++ b/src/transformers/image_transforms.py @@ -123,11 +123,11 @@ def rescale( if not isinstance(image, np.ndarray): raise TypeError(f"Input image must be of type np.ndarray, got {type(image)}") - rescaled_image = image * scale + rescaled_image = image.astype(np.float64) * scale # Numpy type promotion has changed, so always upcast first if data_format is not None: rescaled_image = to_channel_dimension_format(rescaled_image, data_format, input_data_format) - rescaled_image = rescaled_image.astype(dtype) + rescaled_image = rescaled_image.astype(dtype) # Finally downcast to the desired dtype at the end return rescaled_image