Skip to content

Commit

Permalink
Fix LLaVA-NeXT handling of non-square images
Browse files Browse the repository at this point in the history
We could get shape mismatches with non-square images, resulting in an
exception that crashed the backend.

When post-processing an image, features corresponding to padding are
removed when padding was needed. This is also reflected in the calculation
of the number of image tokens to get the correct number of slots.
However, there was a mismatch between the post-processing and the slot
calculation. The image post-processing could exclude fewer padding features
due to rounding. This change updates the image token calculation to
correspond to the image postprocessing.

Fixes #1777.

While investigating this, I found another issue where the upstream code
contains a bug that swaps the height and width dimensions after computing
the image grid shape. Since the models were also trained with this bug,
we should reproduce the same bug to ensure that we are generating the
same features.
  • Loading branch information
danieldk committed Jun 20, 2024
1 parent 9ce4552 commit e7b1d5e
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 14 deletions.
10 changes: 7 additions & 3 deletions router/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,12 @@ fn get_unpadded_features(
let current_aspect_ratio: f64 = current_width as f64 / current_height as f64;
let (current_height, current_width) = if aspect_ratio > current_aspect_ratio {
let new_height = (height * current_width) / width;
(new_height, current_width)
let padding = (current_height - new_height) / 2;
(current_height - (2 * padding), current_width)
} else {
let new_width = (width * current_height) / height;
(current_height, new_width)
let padding = (current_width - new_width) / 2;
(current_height, current_width - (2 * padding))
};

let unpadded_features = current_height * current_width;
Expand All @@ -88,7 +90,9 @@ impl LlavaNext {
let patch_size = self.vision_config.patch_size;
assert!(image_size % patch_size == 0);
let npatches = image_size / patch_size;
let (num_patch_height, num_patch_width) =
// Dimensions are intentionally swapped to be bug-compatible with
// upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59
let (num_patch_width, num_patch_height) =
get_anyres_image_grid_shape(height, width, &self.image_grid_pinpoints, image_size);

let (unpadded_features, newline_features) =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,15 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
Args:
image_size (`tuple`):
The size of the input image in the format (width, height).
The size of the input image in the format (height, width).
grid_pinpoints (`List`):
A list containing possible resolutions. Each item in the list should be a tuple or list
of the form `(height, width)`.
patch_size (`int`):
The size of each image patch.
Returns:
tuple: The shape of the image patch grid in the format (width, height).
tuple: The shape of the image patch grid in the format (height, width).
"""
if not isinstance(grid_pinpoints, list):
raise ValueError("grid_pinpoints should be a list of tuples or lists")
Expand Down Expand Up @@ -229,7 +229,10 @@ def forward(
raise ValueError(
"The number of patches is not consistent with the image size."
)
num_patch_height, num_patch_width = get_anyres_image_grid_shape(

# Dimensions are intentionally swapped to be bug-compatible with
# upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59
num_patch_width, num_patch_height = get_anyres_image_grid_shape(
image_sizes[image_idx],
self.config.image_grid_pinpoints,
self.config.vision_config.image_size,
Expand Down
25 changes: 17 additions & 8 deletions server/text_generation_server/models/vlm_causal_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
Args:
image_size (`tuple`):
The size of the input image in the format (width, height).
The size of the input image in the format (height, width).
grid_pinpoints (`List`):
A list containing possible resolutions. Each item in the list should be a tuple or list
of the form `(height, width)`.
Expand Down Expand Up @@ -64,19 +64,26 @@ def image_text_replacement(processor, image_input, config, image_id) -> str:


def get_unpadded_features(
height: int, width: int, npatches: int, num_patch_height: int, num_patch_width: int
original_height: int,
original_width: int,
npatches: int,
num_patch_height: int,
num_patch_width: int,
) -> Tuple[int, int]:
current_height = npatches * num_patch_height
current_width = npatches * num_patch_width

aspect_ratio: float = width / height
aspect_ratio: float = original_width / original_height
current_aspect_ratio: float = current_width / current_height

if aspect_ratio > current_aspect_ratio:
new_height = (height * current_width) // width
current_height = new_height
new_height = (original_height * current_width) // original_width
padding = (current_height - new_height) // 2
current_height = current_height - (2 * padding)
else:
new_width = (width * current_height) // height
current_width = new_width
new_width = (original_width * current_height) // original_height
padding = (current_width - new_width) // 2
current_width = current_width - (2 * padding)

unpadded_features = current_height * current_width
newline_features = current_height
Expand All @@ -95,7 +102,9 @@ def get_number_of_features(height: int, width: int, config) -> int:

npatches = image_size // patch_size

num_patch_height, num_patch_width = get_anyres_image_grid_shape(
# Dimensions are intentionally swapped to be bug-compatible with
# upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59
num_patch_width, num_patch_height = get_anyres_image_grid_shape(
[height, width],
image_grid_pinpoints,
image_size,
Expand Down

0 comments on commit e7b1d5e

Please sign in to comment.