Skip to content

Commit

Permalink
smol improvements to support more flexible usage (#34857)
Browse files: browse the repository at this point in the history
* smol improvements to support more flexible usage

* ruff
  • Loading branch information
andimarafioti authored Nov 22, 2024
1 parent 42b36d7 commit 861758e
Showing 1 changed file with 16 additions and 24 deletions.
40 changes: 16 additions & 24 deletions src/transformers/models/idefics3/image_processing_idefics3.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@


logger = logging.get_logger(__name__)
MAX_IMAGE_SIZE = 4096 # 4k resolution as absolute maximum


if is_vision_available():
Expand Down Expand Up @@ -116,7 +117,6 @@ def _resize_output_size_scale_below_upper_bound(
def get_resize_output_image_size(
image,
resolution_max_side: int,
max_image_size: int = 1820,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> Tuple[int, int]:
"""
Expand All @@ -126,24 +126,18 @@ def get_resize_output_image_size(
Image to resize.
resolution_max_side (`int`):
The longest edge of the image will be resized to this value. The shortest edge will be resized to keep the
input aspect ratio, with a lower bound of `min_image_size`.
max_image_size (`int`, *optional*, defaults to 1820):
Maximum image resolution. If the image is larger than this size, the longest edge will be resized to this
value, with the shortest edge resized to keep the input aspect ratio, with a lower bound of `min_image_size`.
input aspect ratio.
input_data_format (`ChannelDimension` or `str`):
The channel dimension format of the input image.
Returns:
The output size of the image after resizing.
"""
if resolution_max_side > max_image_size:
raise ValueError("`resolution_max_side` cannot be larger than `max_image_size`")

height, width = get_image_size(image, channel_dim=input_data_format)

# Find the output size, when rescaling the longest edge to max_len and preserving the aspect ratio
height, width = _resize_output_size_rescale_to_max_len(height, width, max_len=resolution_max_side)
# Find the output size when scaling the image to be below the max_image_size
height, width = _resize_output_size_scale_below_upper_bound(height, width, max_len=max_image_size)
# Find the output size when scaling the image to be below the MAX_IMAGE_SIZE
height, width = _resize_output_size_scale_below_upper_bound(height, width, max_len=MAX_IMAGE_SIZE)
return height, width


Expand Down Expand Up @@ -251,7 +245,7 @@ def convert_to_rgb(
data_format = input_data_format if data_format is None else data_format

mode = "P" if palette is not None else None
image = to_pil_image(image, image_mode=mode)
image = to_pil_image(image, image_mode=mode, input_data_format=input_data_format)
if image.mode == "P" and palette is not None:
image.putpalette(palette)

Expand Down Expand Up @@ -404,7 +398,7 @@ def resize(
image_mode = None
if image.ndim == 2 or image.shape[-1] == 1:
image_mode = "P"
image = to_pil_image(image, image_mode=image_mode)
image = to_pil_image(image, image_mode=image_mode, input_data_format=input_data_format)

resized_image = image.resize((size[1], size[0]), resample=resample)
resized_image = np.array(resized_image)
Expand Down Expand Up @@ -754,6 +748,16 @@ def preprocess(
# All transformations expect numpy arrays.
images_list = [[to_numpy_array(image) for image in images] for images in images_list]

# Extra channel dimension for grayscale images
if input_data_format in [ChannelDimension.LAST, None]:
images_list = [
[np.expand_dims(img, axis=-1) if img.ndim == 2 else img for img in images] for images in images_list
]
elif input_data_format == ChannelDimension.FIRST:
images_list = [
[np.expand_dims(img, axis=0) if img.ndim == 2 else img for img in images] for images in images_list
]

if is_scaled_image(images_list[0][0]) and do_rescale:
logger.warning_once(
"It looks like you are trying to rescale already rescaled images. If the input"
Expand All @@ -764,18 +768,6 @@ def preprocess(
if input_data_format is None:
input_data_format = infer_channel_dimension_format(images_list[0][0], num_channels=(1, 3, 4))

# Extra channel dimension for grayscale images
if input_data_format == ChannelDimension.LAST:
images_list = [
[np.expand_dims(img, axis=-1) if img.ndim == 2 else img for img in images] for images in images_list
]
elif input_data_format == ChannelDimension.FIRST:
images_list = [
[np.expand_dims(img, axis=0) if img.ndim == 2 else img for img in images] for images in images_list
]
else:
raise ValueError(f"Invalid channel dimension format {input_data_format}.")

if do_resize:
images_list = [
[
Expand Down

0 comments on commit 861758e

Please sign in to comment.