From 9bfbc82676e9327798652dbd1412b7784d9a4578 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 2 Jan 2025 11:00:02 +0000 Subject: [PATCH] Implement merged processor for llava-next Signed-off-by: DarkLight1337 --- tests/multimodal/test_processing.py | 1 + .../vllm_add_dummy_model/my_llava.py | 4 +- vllm/model_executor/models/clip.py | 25 ++ vllm/model_executor/models/fuyu.py | 6 +- vllm/model_executor/models/llava.py | 335 +++++++++++------- vllm/model_executor/models/llava_next.py | 321 ++++++----------- vllm/model_executor/models/pixtral.py | 66 +++- vllm/model_executor/models/siglip.py | 25 ++ vllm/model_executor/models/utils.py | 2 +- vllm/model_executor/models/vision.py | 52 +++ 10 files changed, 483 insertions(+), 354 deletions(-) create mode 100644 vllm/model_executor/models/vision.py diff --git a/tests/multimodal/test_processing.py b/tests/multimodal/test_processing.py index 9573351b4dff1..6cb0c266a0fe3 100644 --- a/tests/multimodal/test_processing.py +++ b/tests/multimodal/test_processing.py @@ -631,6 +631,7 @@ def _test_processing_cache_correctness( ("facebook/chameleon-7b", {"image": False}), ("adept/fuyu-8b", {"image": False}), ("llava-hf/llava-1.5-7b-hf", {"image": True}), + ("llava-hf/llava-v1.6-mistral-7b-hf", {"image": True}), ("TIGER-Lab/Mantis-8B-siglip-llama3", {"image": True}), ("mistral-community/pixtral-12b", {"image": True}), ("Qwen/Qwen2-VL-2B-Instruct", {"image": True, "video": True}), diff --git a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py index 0d90635093ac7..06dfebbb95527 100644 --- a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py +++ b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py @@ -3,13 +3,11 @@ import torch from vllm.model_executor.models.llava import (LlavaForConditionalGeneration, - LlavaMultiModalProcessor, - get_max_llava_image_tokens) + LlavaMultiModalProcessor) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens) @MULTIMODAL_REGISTRY.register_processor(LlavaMultiModalProcessor) class MyLlava(LlavaForConditionalGeneration): diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py index a5300dfd986f3..0188452054b8c 100644 --- a/vllm/model_executor/models/clip.py +++ b/vllm/model_executor/models/clip.py @@ -24,6 +24,8 @@ resolve_visual_encoder_outputs) from vllm.sequence import SequenceData +from .vision import VisionEncoderInfo + def get_clip_patch_grid_length(*, image_size: int, patch_size: int) -> int: assert image_size % patch_size == 0 @@ -149,6 +151,29 @@ def input_processor_for_clip( multi_modal_placeholders={"image": ranges}) +class CLIPEncoderInfo(VisionEncoderInfo[CLIPVisionConfig]): + + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + ) -> int: + return get_clip_image_feature_size(self.vision_config) + + def get_max_image_tokens(self) -> int: + return get_max_clip_image_tokens(self.vision_config) + + def get_num_patches(self) -> int: + return get_clip_patch_grid_length( + image_size=self.vision_config.image_size, + patch_size=self.vision_config.patch_size, + ) + + def get_image_size(self) -> int: + return self.vision_config.image_size + + # Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa class CLIPVisionEmbeddings(nn.Module): 
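The clip.py hunk above adds CLIPEncoderInfo on top of the existing CLIP helpers. A rough usage sketch follows (not part of the diff); the config values, a CLIP ViT-L/14 tower at 336 px, are assumed for illustration only.

    from transformers import CLIPVisionConfig

    from vllm.model_executor.models.clip import CLIPEncoderInfo

    # Assumed config values for a CLIP ViT-L/14 tower at 336 px.
    vision_config = CLIPVisionConfig(image_size=336, patch_size=14)
    info = CLIPEncoderInfo(vision_config)

    info.get_image_size()   # 336
    info.get_num_patches()  # patch grid length per side: 336 // 14 == 24

    # CLIP is a fixed-resolution encoder, so the token count does not
    # depend on the input image dimensions.
    info.get_num_image_tokens(image_width=800, image_height=600)

The LLaVA-family processors below consume this interface instead of branching on the vision config type, which is what allows the llava.py and llava_next.py changes to drop their per-backbone special cases.
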
diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index 7fb8c5d1ab09c..3680d01725238 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -76,7 +76,7 @@ def _get_image_target_size(self) -> ImageSize: return ImageSize(width=target_size["width"], height=target_size["height"]) - def _get_image_grid_size( + def _get_image_feature_grid_size( self, *, image_width: int, @@ -99,7 +99,7 @@ def _get_image_grid_size( def get_mm_max_tokens_per_item(self) -> Mapping[str, int]: target_width, target_height = self._get_image_target_size() - max_ncols, max_nrows = self._get_image_grid_size( + max_ncols, max_nrows = self._get_image_feature_grid_size( image_width=target_width, image_height=target_height, ) @@ -172,7 +172,7 @@ def get_replacement_fuyu(item_idx: int): images = mm_items.get_items("image", ImageProcessorItems) image_size = images.get_image_size(item_idx) - ncols, nrows = self._get_image_grid_size( + ncols, nrows = self._get_image_feature_grid_size( image_width=image_size.width, image_height=image_size.height, ) diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 808e61edb6fb4..f96086635843b 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -1,6 +1,7 @@ +from abc import abstractmethod from functools import cached_property -from typing import (Iterable, List, Literal, Mapping, Optional, Protocol, Set, - Tuple, TypedDict, Union) +from typing import (Final, Iterable, List, Literal, Mapping, Optional, + Protocol, Set, Tuple, TypedDict, Union) import torch import torch.nn as nn @@ -12,7 +13,6 @@ from vllm.attention import AttentionMetadata from vllm.config import VllmConfig -from vllm.inputs import InputContext from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) @@ -23,23 +23,22 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalInputsV2, MultiModalKwargs, NestedTensors) -from vllm.multimodal.parse import ImageProcessorItems +from vllm.multimodal.parse import ImageProcessorItems, ImageSize from vllm.multimodal.processing import (BaseMultiModalProcessor, - MultiModalDataItems, ProcessorInputs, - PromptReplacement, + InputProcessingContext, + MultiModalDataItems, ProcessingCache, + ProcessorInputs, PromptReplacement, full_groupby_modality) from vllm.sequence import IntermediateTensors -from .clip import (CLIPVisionModel, dummy_image_for_clip, - get_max_clip_image_tokens) +from .clip import CLIPVisionModel from .interfaces import SupportsMultiModal, SupportsPP -from .pixtral import (PixtralHFVisionModel, dummy_image_for_pixtral_hf, - get_max_pixtral_hf_image_tokens, - get_pixtral_hf_image_feature_size) -from .siglip import (SiglipVisionModel, dummy_image_for_siglip, - get_max_siglip_image_tokens) +from .pixtral import (PixtralHFVisionModel, + get_pixtral_hf_image_feature_grid_size) +from .siglip import SiglipVisionModel from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, maybe_prefix, merge_multimodal_embeddings) +from .vision import vision_encoder_info class LlavaImagePixelInputs(TypedDict): @@ -94,39 +93,163 @@ def forward(self, image_features: torch.Tensor) -> torch.Tensor: return hidden_states -def get_max_llava_image_tokens(ctx: InputContext): - hf_config = ctx.get_hf_config(LlavaConfig) - vision_config = hf_config.vision_config +class LlavaLikeConfig(Protocol): + vision_config: 
Final[PretrainedConfig] + vision_feature_select_strategy: Final[str] + vision_feature_layer: Final[Union[int, List[int]]] - if isinstance(vision_config, CLIPVisionConfig): - num_image_tokens = get_max_clip_image_tokens(vision_config) - elif isinstance(vision_config, SiglipVisionConfig): - num_image_tokens = get_max_siglip_image_tokens(vision_config) - elif isinstance(vision_config, PixtralVisionConfig): - num_image_tokens = get_max_pixtral_hf_image_tokens(vision_config) - else: - msg = f"Unsupported vision config: {type(vision_config)}" - raise NotImplementedError(msg) - strategy = hf_config.vision_feature_select_strategy - if strategy == "default": - return num_image_tokens - 1 - elif strategy == "full": - return num_image_tokens - else: - raise ValueError(f"Unexpected select feature strategy: {strategy}") +class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor): + + def __init__(self, + ctx: InputProcessingContext, + *, + cache: Optional[ProcessingCache] = None, + enable_sanity_checks: bool = True) -> None: + super().__init__(ctx, + cache=cache, + enable_sanity_checks=enable_sanity_checks) + vision_config = self._get_hf_config().vision_config + self._vision_encoder_info = vision_encoder_info(vision_config) -class LlavaMultiModalProcessor(BaseMultiModalProcessor): + @abstractmethod + def _get_hf_config(self) -> LlavaLikeConfig: + raise NotImplementedError def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": None} + def _apply_feature_select_strategy( + self, + strategy: str, + encoder_num_image_tokens: int, + ) -> int: + if strategy == "default": + return encoder_num_image_tokens - 1 + if strategy == "full": + return encoder_num_image_tokens + + msg = f"Unexpected feature select strategy: {strategy!r}" + raise NotImplementedError(msg) + + def _get_max_image_tokens(self) -> int: + hf_config = self._get_hf_config() + + return self._apply_feature_select_strategy( + hf_config.vision_feature_select_strategy, + self._vision_encoder_info.get_max_image_tokens(), + ) + def get_mm_max_tokens_per_item(self) -> Mapping[str, int]: - return {"image": get_max_llava_image_tokens(self.ctx)} + return {"image": self._get_max_image_tokens()} + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict( + pixel_values=MultiModalFieldConfig.batched("image"), + image_embeds=MultiModalFieldConfig.batched("image"), + ) - def _get_hf_processor(self) -> Union[LlavaProcessor, PixtralProcessor]: - return self.ctx.get_hf_processor((LlavaProcessor, PixtralProcessor)) + def _get_dummy_image_size(self) -> ImageSize: + image_size = self._vision_encoder_info.get_image_size() + return ImageSize(image_size, image_size) + + @abstractmethod + def _get_image_token(self) -> str: + raise NotImplementedError + + def _get_dummy_mm_inputs( + self, + mm_counts: Mapping[str, int], + ) -> ProcessorInputs: + num_images = mm_counts.get("image", 0) + + image_token = self._get_image_token() + target_width, target_height = self._get_dummy_image_size() + + mm_data = { + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images) + } + + return ProcessorInputs( + prompt_text=image_token * num_images, + mm_data=mm_data, + ) + + +class LlavaMultiModalProcessor(BaseLlavaMultiModalProcessor): + + def _get_hf_config(self) -> LlavaConfig: + return self.ctx.get_hf_config(LlavaConfig) + + def _get_hf_processor(self) -> LlavaProcessor: + return 
self.ctx.get_hf_processor(LlavaProcessor) + + def _get_image_token(self) -> str: + return self._get_hf_processor().image_token + + def _get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + ) -> int: + hf_config = self._get_hf_config() + + return self._apply_feature_select_strategy( + hf_config.vision_feature_select_strategy, + self._vision_encoder_info.get_num_image_tokens( + image_width=image_width, + image_height=image_height, + ), + ) + + def _get_prompt_replacements( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> list[PromptReplacement]: + hf_config = self._get_hf_config() + image_token_id = hf_config.image_token_index + + def get_replacement(item_idx: int): + images = mm_items.get_items("image", ImageProcessorItems) + image_size = images.get_image_size(item_idx) + + num_image_tokens = self._get_num_image_tokens( + image_width=image_size.width, + image_height=image_size.height, + ) + + return [image_token_id] * num_image_tokens + + return [ + PromptReplacement( + modality="image", + target=[image_token_id], + replacement=get_replacement, + ), + ] + + +class PixtralHFMultiModalProcessor(BaseLlavaMultiModalProcessor): + + def _get_hf_config(self) -> LlavaConfig: + return self.ctx.get_hf_config(LlavaConfig) + + def _get_hf_processor(self) -> PixtralProcessor: + return self.ctx.get_hf_processor(PixtralProcessor) + + def _get_image_token(self) -> str: + return self._get_hf_processor().image_token def _call_hf_processor( self, @@ -140,119 +263,82 @@ def _call_hf_processor( mm_kwargs=mm_kwargs, ) - # NOTE: pixel_values=None for MLlavaProcessor pixel_values = processed_outputs.get("pixel_values") if pixel_values is not None: images = mm_data["images"] assert isinstance(images, list) - if isinstance(self._get_hf_processor(), PixtralProcessor): - # Original output: (1, num_images, C, H, W) - # New output: (num_images, C, H, W) - assert (isinstance(pixel_values, list) - and len(pixel_values) == 1) - assert (isinstance(pixel_values[0], list) - and len(pixel_values[0]) == len(images)) + # Original output: (1, num_images, C, H, W) + # New output: (num_images, C, H, W) + assert (isinstance(pixel_values, list) and len(pixel_values) == 1) + assert (isinstance(pixel_values[0], list) + and len(pixel_values[0]) == len(images)) - processed_outputs["pixel_values"] = pixel_values[0] + processed_outputs["pixel_values"] = pixel_values[0] return processed_outputs - def _get_mm_fields_config( - self, - hf_inputs: BatchFeature, - hf_processor_mm_kwargs: Mapping[str, object], - ) -> Mapping[str, MultiModalFieldConfig]: - return dict( - pixel_values=MultiModalFieldConfig.batched("image"), - image_embeds=MultiModalFieldConfig.batched("image"), - ) - def _get_prompt_replacements( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, ) -> list[PromptReplacement]: - hf_config = self.ctx.get_hf_config(LlavaConfig) + hf_config = self._get_hf_config() image_token_id = hf_config.image_token_index processor = self._get_hf_processor() - if isinstance(processor, PixtralProcessor): - image_token = processor.image_token - image_break_token = processor.image_break_token - image_end_token = processor.image_end_token - - vision_config = hf_config.vision_config - assert isinstance(vision_config, PixtralVisionConfig) - - def get_replacement_pixtral(item_idx: int): - images = mm_items.get_items("image", ImageProcessorItems) - image_size = 
images.get_image_size(item_idx) - - ( - num_width_tokens, - num_height_tokens, - ) = get_pixtral_hf_image_feature_size( - vision_config, - image_width=image_size.width, - image_height=image_size.height, - ) - - tokens = ([image_token] * num_width_tokens + - [image_break_token]) * num_height_tokens - tokens[-1] = image_end_token - - return "".join(tokens) - - return [ - PromptReplacement( - modality="image", - target=[image_token_id], - replacement=get_replacement_pixtral, - ), - ] - - max_image_tokens = get_max_llava_image_tokens(self.ctx) + image_token = processor.image_token + image_break_token = processor.image_break_token + image_end_token = processor.image_end_token + + vision_config = hf_config.vision_config + assert isinstance(vision_config, PixtralVisionConfig) + + def get_replacement(item_idx: int): + images = mm_items.get_items("image", ImageProcessorItems) + image_size = images.get_image_size(item_idx) + + ncols, nrows = get_pixtral_hf_image_feature_grid_size( + vision_config, + image_width=image_size.width, + image_height=image_size.height, + ) + + tokens = ([image_token] * ncols + [image_break_token]) * nrows + tokens[-1] = image_end_token + + return "".join(tokens) return [ PromptReplacement( modality="image", target=[image_token_id], - replacement=[image_token_id] * max_image_tokens, - ) + replacement=get_replacement, + ), ] - def _get_dummy_mm_inputs( - self, - mm_counts: Mapping[str, int], - ) -> ProcessorInputs: - hf_config = self.ctx.get_hf_config(LlavaConfig) - vision_config = hf_config.vision_config - num_images = mm_counts.get("image", 0) - - if isinstance(vision_config, CLIPVisionConfig): - data = dummy_image_for_clip(vision_config, num_images) - elif isinstance(vision_config, SiglipVisionConfig): - data = dummy_image_for_siglip(vision_config, num_images) - elif isinstance(vision_config, PixtralVisionConfig): - data = dummy_image_for_pixtral_hf(vision_config, num_images) - else: - msg = f"Unsupported vision config: {type(vision_config)}" - raise NotImplementedError(msg) - hf_processor = self._get_hf_processor() - image_token = hf_processor.image_token +def _build_llava_or_pixtral_hf_processor( + ctx: InputProcessingContext, + *, + cache: Optional[ProcessingCache] = None, + enable_sanity_checks: bool = True, +) -> BaseLlavaMultiModalProcessor: + hf_config = ctx.get_hf_config(LlavaConfig) - return ProcessorInputs( - prompt_text=image_token * num_images, - mm_data=data, + if isinstance(hf_config.vision_config, PixtralVisionConfig): + return PixtralHFMultiModalProcessor( + ctx, + cache=cache, + enable_sanity_checks=enable_sanity_checks, ) - -class LlavaLikeConfig(Protocol): - vision_config: PretrainedConfig - vision_feature_layer: Union[int, List[int]] + return LlavaMultiModalProcessor( + ctx, + cache=cache, + enable_sanity_checks=enable_sanity_checks, + ) def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int: @@ -330,7 +416,7 @@ def init_vision_tower_for_llava( raise NotImplementedError(msg) -@MULTIMODAL_REGISTRY.register_processor(LlavaMultiModalProcessor) +@MULTIMODAL_REGISTRY.register_processor(_build_llava_or_pixtral_hf_processor) class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): # BitandBytes specific attributes bitsandbytes_stacked_params_mapping = { @@ -596,7 +682,8 @@ def apply( ) -> MultiModalInputsV2: hf_config = self.ctx.get_hf_config(LlavaConfig) image_token_id = hf_config.image_token_index - max_image_tokens = get_max_llava_image_tokens(self.ctx) + + num_image_tokens = self._get_num_image_tokens() result = 
super().apply(prompt_text, mm_data, hf_processor_mm_kwargs) @@ -609,14 +696,14 @@ def apply( def get_replacement_mantis(item_idx: int): return "".join([ f"(image {item_idx+1}: ", # 7 tokens - "" * max_image_tokens, + "" * num_image_tokens, ")", # 3 tokens ]) mantis_repls = self._bind_prompt_replacements([ PromptReplacement( modality="image", - target=[image_token_id] * max_image_tokens, + target=[image_token_id] * num_image_tokens, replacement=get_replacement_mantis, ) ]) diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index 5e70c11363c83..24debd1cbf3fe 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -4,31 +4,25 @@ import torch import torch.nn as nn -from PIL import Image -from transformers import CLIPVisionConfig, LlavaNextConfig, SiglipVisionConfig +from transformers import BatchFeature, LlavaNextConfig, LlavaNextProcessor from transformers.models.llava_next.modeling_llava_next import ( get_anyres_image_grid_shape, unpad_image) from typing_extensions import NotRequired from vllm.attention import AttentionMetadata from vllm.config import VllmConfig -from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, - InputContext) from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import NestedTensors +from vllm.multimodal.inputs import MultiModalFieldConfig, NestedTensors +from vllm.multimodal.parse import ImageSize from vllm.sequence import IntermediateTensors -from vllm.utils import is_list_of -from .clip import (CLIPVisionModel, dummy_image_for_clip, - dummy_seq_data_for_clip, get_clip_image_feature_size, - get_clip_patch_grid_length, input_processor_for_clip) +from .clip import CLIPVisionModel from .interfaces import SupportsMultiModal, SupportsPP -from .llava import LlavaMultiModalProjector, init_vision_tower_for_llava -from .siglip import (SiglipVisionModel, dummy_image_for_siglip, - dummy_seq_data_for_siglip, get_siglip_image_feature_size, - get_siglip_patch_grid_length, input_processor_for_siglip) +from .llava import (LlavaMultiModalProcessor, LlavaMultiModalProjector, + init_vision_tower_for_llava) +from .siglip import SiglipVisionModel from .utils import (AutoWeightsLoader, embed_multimodal, flatten_bn, init_vllm_registered_model, maybe_prefix) @@ -65,218 +59,127 @@ class LlavaNextImageEmbeddingInputs(TypedDict): LlavaNextImageEmbeddingInputs] -# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L79 -def _get_llava_next_num_unpadded_features( - original_height: int, - original_width: int, - npatches: int, - num_patch_height: int, - num_patch_width: int, -) -> Tuple[int, int]: - current_height = npatches * num_patch_height - current_width = npatches * num_patch_width - - original_aspect_ratio = original_width / original_height - current_aspect_ratio = current_width / current_height - - if original_aspect_ratio > current_aspect_ratio: - scale_factor = current_width / original_width - new_height = int(original_height * scale_factor) - padding = (current_height - new_height) // 2 - current_height -= 2 * padding - else: - scale_factor = current_height / original_height - new_width = int(original_width * scale_factor) - padding = (current_width - new_width) // 2 - current_width -= 2 * padding - - unpadded_features = current_height 
* current_width - newline_features = current_height - return (unpadded_features, newline_features) - - -# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L106 -def get_llava_next_image_feature_size( - hf_config: LlavaNextConfig, - *, - input_height: int, - input_width: int, -) -> int: - vision_config = hf_config.vision_config - - if isinstance(vision_config, CLIPVisionConfig): - num_patches = get_clip_patch_grid_length( - image_size=vision_config.image_size, - patch_size=vision_config.patch_size, - ) - base_feature_size = get_clip_image_feature_size(vision_config) - elif isinstance(vision_config, SiglipVisionConfig): - num_patches = get_siglip_patch_grid_length( - image_size=vision_config.image_size, - patch_size=vision_config.patch_size, - ) - base_feature_size = get_siglip_image_feature_size(vision_config) - else: - msg = f"Unsupported vision config: {type(vision_config)}" - raise NotImplementedError(msg) - - strategy = hf_config.vision_feature_select_strategy - if strategy == "default": - base_feature_size -= 1 - elif strategy == "full": - pass - else: - raise ValueError(f"Unexpected select feature strategy: {strategy}") +class LlavaNextMultiModalProcessor(LlavaMultiModalProcessor): - num_patch_height, num_patch_width = get_anyres_image_grid_shape( - image_size=(input_height, input_width), - grid_pinpoints=hf_config.image_grid_pinpoints, - patch_size=vision_config.image_size, - ) - - ( - unpadded_feature_size, - newline_feature_size, - ) = _get_llava_next_num_unpadded_features(input_height, input_width, - num_patches, num_patch_height, - num_patch_width) - - return unpadded_feature_size + newline_feature_size + base_feature_size - - -def get_max_llava_next_image_tokens(ctx: InputContext): - """Compute the max feature size for all possible image grid pinpoints.""" - return _get_pinpoint_with_largest_features(ctx)[0] - - -def _get_pinpoint_with_largest_features( - ctx: InputContext) -> Tuple[int, Tuple[int, int]]: - """Get the grid pinpoint with the largest features & its feature size.""" - hf_config = ctx.get_hf_config(LlavaNextConfig) - largest_feature_size = 0 - largest_feature_pinpoint = None - for (height, width) in hf_config.image_grid_pinpoints: - feat_size = get_llava_next_image_feature_size( - hf_config, - input_height=height, - input_width=width, - ) - if feat_size > largest_feature_size: - largest_feature_size = feat_size - largest_feature_pinpoint = (height, width) - if not largest_feature_size or largest_feature_pinpoint is None: - raise ValueError("Cannot have a largest feature size of 0!") - return largest_feature_size, largest_feature_pinpoint - - -def dummy_data_for_llava_next(ctx: InputContext, seq_len: int, - mm_counts: Mapping[str, int]): - hf_config = ctx.get_hf_config(LlavaNextConfig) - vision_config = hf_config.vision_config - num_images = mm_counts["image"] - - image_feature_size, pinpoint = _get_pinpoint_with_largest_features(ctx) - max_feat_height, max_feat_width = pinpoint - - if isinstance(vision_config, CLIPVisionConfig): - seq_data, ranges = dummy_seq_data_for_clip( - vision_config, - seq_len, - num_images, - image_token_id=hf_config.image_token_index, - image_feature_size_override=image_feature_size, - ) + def _get_hf_config(self) -> LlavaNextConfig: + return self.ctx.get_hf_config(LlavaNextConfig) + + def _get_hf_processor(self) -> LlavaNextProcessor: + return self.ctx.get_hf_processor(LlavaNextProcessor) - mm_data = dummy_image_for_clip( - vision_config, - num_images, - 
image_width_override=max_feat_width, - image_height_override=max_feat_height, + def _get_image_token(self) -> str: + return self._get_hf_processor().image_token + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict( + pixel_values=MultiModalFieldConfig.batched("image"), + image_sizes=MultiModalFieldConfig.batched("image"), + image_embeds=MultiModalFieldConfig.batched("image"), ) - return DummyData(seq_data, mm_data, ranges) - elif isinstance(vision_config, SiglipVisionConfig): - seq_data, ranges = dummy_seq_data_for_siglip( - vision_config, - seq_len, - num_images, - image_token_id=hf_config.image_token_index, - image_feature_size_override=image_feature_size, + def _get_max_image_tokens(self) -> int: + largest_feature_size, _ = self._get_pinpoint_with_most_features() + return largest_feature_size + + def _get_dummy_image_size(self) -> ImageSize: + _, pinpoint = self._get_pinpoint_with_most_features() + return pinpoint + + # Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L106 + def _get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + ) -> int: + hf_config = self._get_hf_config() + + base_feature_size = self._apply_feature_select_strategy( + hf_config.vision_feature_select_strategy, + self._vision_encoder_info.get_num_image_tokens( + image_width=image_width, + image_height=image_height, + ), ) + num_patches = self._vision_encoder_info.get_num_patches() - mm_data = dummy_image_for_siglip( - vision_config, - num_images, - image_width_override=max_feat_width, - image_height_override=max_feat_height, + num_patch_height, num_patch_width = get_anyres_image_grid_shape( + image_size=(image_height, image_width), + grid_pinpoints=hf_config.image_grid_pinpoints, + patch_size=self._vision_encoder_info.get_image_size(), ) - return DummyData(seq_data, mm_data, ranges) + ( + unpadded_feature_size, + newline_feature_size, + ) = self._get_num_unpadded_features( + original_height=image_height, + original_width=image_width, + npatches=num_patches, + num_patch_height=num_patch_height, + num_patch_width=num_patch_width, + ) - msg = f"Unsupported vision config: {type(vision_config)}" - raise NotImplementedError(msg) + return unpadded_feature_size + newline_feature_size + base_feature_size + # Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L79 + def _get_num_unpadded_features( + self, + *, + original_height: int, + original_width: int, + npatches: int, + num_patch_height: int, + num_patch_width: int, + ) -> tuple[int, int]: + current_height = npatches * num_patch_height + current_width = npatches * num_patch_width + + original_aspect_ratio = original_width / original_height + current_aspect_ratio = current_width / current_height + + if original_aspect_ratio > current_aspect_ratio: + scale_factor = current_width / original_width + new_height = int(original_height * scale_factor) + padding = (current_height - new_height) // 2 + current_height -= 2 * padding + else: + scale_factor = current_height / original_height + new_width = int(original_width * scale_factor) + padding = (current_width - new_width) // 2 + current_width -= 2 * padding -def input_processor_for_llava_next(ctx: InputContext, - inputs: DecoderOnlyInputs): - multi_modal_data = inputs.get("multi_modal_data") - if multi_modal_data is None or 
"image" not in multi_modal_data: - return inputs + unpadded_features = current_height * current_width + newline_features = current_height + return (unpadded_features, newline_features) - model_config = ctx.model_config - hf_config = ctx.get_hf_config(LlavaNextConfig) - vision_config = hf_config.vision_config + def _get_pinpoint_with_most_features(self) -> tuple[int, ImageSize]: + """ + Get the grid pinpoint with the most features and + the corresponding feature size. + """ + hf_config = self._get_hf_config() - image_data = multi_modal_data["image"] - if isinstance(image_data, Image.Image): - width, height = image_data.size + largest_feature_size, largest_feature_pinpoint = 0, None + for (height, width) in hf_config.image_grid_pinpoints: + feat_size = self._get_num_image_tokens(image_width=width, + image_height=height) + if feat_size > largest_feature_size: + largest_feature_size = feat_size + largest_feature_pinpoint = ImageSize(width=width, + height=height) - image_feature_size = get_llava_next_image_feature_size( - hf_config, - input_height=height, - input_width=width, - ) - elif is_list_of(image_data, Image.Image): - image_feature_size = [ - get_llava_next_image_feature_size(hf_config, - input_height=img.height, - input_width=img.width) - for img in image_data - ] - elif isinstance(image_data, torch.Tensor): - num_images, image_feature_size, hidden_size = image_data.shape - elif is_list_of(image_data, torch.Tensor): - image_feature_size = [item.shape[1] for item in image_data] - else: - raise TypeError(f"Invalid image type: {type(image_data)}") - - vision_config = hf_config.vision_config - - if isinstance(vision_config, CLIPVisionConfig): - return input_processor_for_clip( - model_config, - vision_config, - inputs, - image_token_id=hf_config.image_token_index, - image_feature_size_override=image_feature_size, - ) - elif isinstance(vision_config, SiglipVisionConfig): - return input_processor_for_siglip( - model_config, - vision_config, - inputs, - image_token_id=hf_config.image_token_index, - image_feature_size_override=image_feature_size, - ) + if largest_feature_size == 0 or largest_feature_pinpoint is None: + raise ValueError("Cannot have a largest feature size of 0!") - msg = f"Unsupported vision config: {type(vision_config)}" - raise NotImplementedError(msg) + return largest_feature_size, largest_feature_pinpoint -@MULTIMODAL_REGISTRY.register_image_input_mapper() -@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_next_image_tokens) -@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next) -@INPUT_REGISTRY.register_input_processor(input_processor_for_llava_next) +@MULTIMODAL_REGISTRY.register_processor(LlavaNextMultiModalProcessor) class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): @@ -507,7 +410,7 @@ def _merge_image_patch_embeddings(self, image_size: torch.Tensor, def _process_image_pixels( self, inputs: LlavaNextImagePixelInputs, - ) -> Union[torch.Tensor, List[torch.Tensor]]: + ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]: assert self.vision_tower is not None pixel_values = inputs["data"] diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index 2bce13792a88d..d7233bd6028ed 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -38,6 +38,7 @@ from .interfaces import SupportsMultiModal, SupportsPP from .utils import (init_vllm_registered_model, maybe_prefix, merge_multimodal_embeddings) +from .vision import VisionEncoderInfo try: from xformers 
import ops as xops @@ -697,10 +698,18 @@ def get_pixtral_hf_patch_grid_length(*, image_size: int, return image_size // patch_size -def get_pixtral_hf_num_patches(*, image_size: int, patch_size: int) -> int: - grid_length = get_pixtral_hf_patch_grid_length(image_size=image_size, - patch_size=patch_size) - return grid_length * grid_length +def get_pixtral_hf_image_feature_size( + *, + image_size: int, + patch_size: int, +) -> int: + grid_length = get_pixtral_hf_patch_grid_length( + image_size=image_size, + patch_size=patch_size, + ) + + # Consider the image_break_token + return (grid_length + 1) * grid_length def get_max_pixtral_hf_image_tokens(hf_config: PixtralVisionConfig) -> int: @@ -730,13 +739,16 @@ def dummy_image_for_pixtral_hf( return {"image": image if num_images == 1 else [image] * num_images} -def get_pixtral_hf_image_feature_size(hf_config: PixtralVisionConfig, - image_width: int, - image_height: int) -> Tuple[int, int]: - # Adapted from transformers.models.pixtral.image_processing_pixtral.get_resize_output_image_size # noqa: E501 - # https://github.com/huggingface/transformers/blob/2bd4d5897dc73e8b172832070a6f9e567a0df017/src/transformers/models/pixtral/image_processing_pixtral.py#L180 # noqa: E501 - max_width, max_height = hf_config.image_size, hf_config.image_size - patch_width, patch_height = hf_config.patch_size, hf_config.patch_size +# Adapted from transformers.models.pixtral.image_processing_pixtral.get_resize_output_image_size # noqa: E501 +# https://github.com/huggingface/transformers/blob/2bd4d5897dc73e8b172832070a6f9e567a0df017/src/transformers/models/pixtral/image_processing_pixtral.py#L180 +def get_pixtral_hf_image_feature_grid_size( + hf_config: PixtralVisionConfig, + *, + image_width: int, + image_height: int, +) -> tuple[int, int]: + max_width = max_height = hf_config.image_size + patch_width = patch_height = hf_config.patch_size ratio = max(image_width / max_width, image_height / max_height) @@ -744,12 +756,38 @@ def get_pixtral_hf_image_feature_size(hf_config: PixtralVisionConfig, image_width = int(math.ceil(image_width / ratio)) image_height = int(math.ceil(image_height / ratio)) - num_height_tokens, num_width_tokens = _get_pixtral_hf_num_image_tokens( + nrows, ncols = _get_pixtral_hf_num_image_tokens( (image_height, image_width), (patch_height, patch_width), - ) + ) # type: ignore + + return ncols, nrows + + +class PixtralHFEncoderInfo(VisionEncoderInfo[PixtralVisionConfig]): + + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + ) -> int: + return get_pixtral_hf_image_feature_size( + image_size=self.vision_config.image_size, + patch_size=self.get_image_size(), + ) + + def get_max_image_tokens(self) -> int: + return get_max_pixtral_hf_image_tokens(self.vision_config) + + def get_num_patches(self) -> int: + return get_pixtral_hf_patch_grid_length( + image_size=self.vision_config.image_size, + patch_size=self.vision_config.patch_size, + ) - return num_width_tokens, num_height_tokens + def get_image_size(self) -> int: + return self.vision_config.image_size class PixtralHFMLP(nn.Module): diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py index 6fb9e2cc4584f..115eaaac900e0 100644 --- a/vllm/model_executor/models/siglip.py +++ b/vllm/model_executor/models/siglip.py @@ -28,6 +28,8 @@ resolve_visual_encoder_outputs) from vllm.sequence import SequenceData +from .vision import VisionEncoderInfo + def get_siglip_patch_grid_length(*, image_size: int, patch_size: int) -> int: # Since interpolation is 
applied, the image size need not be divisible @@ -156,6 +158,29 @@ def input_processor_for_siglip( multi_modal_placeholders={"image": ranges}) +class SiglipEncoderInfo(VisionEncoderInfo[SiglipVisionConfig]): + + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + ) -> int: + return get_siglip_image_feature_size(self.vision_config) + + def get_max_image_tokens(self) -> int: + return get_max_siglip_image_tokens(self.vision_config) + + def get_num_patches(self) -> int: + return get_siglip_patch_grid_length( + image_size=self.vision_config.image_size, + patch_size=self.vision_config.patch_size, + ) + + def get_image_size(self) -> int: + return self.vision_config.image_size + + # Adapted from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L249 # noqa class SiglipVisionEmbeddings(nn.Module): diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 269b66806adf4..31017f16d3c97 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -373,7 +373,7 @@ def embed_multimodal( input_ids: torch.Tensor, multimodal_token_id: int, get_text_embeds: Callable[[torch.Tensor], torch.Tensor], - multimodal_embeds: Union[torch.Tensor, List[torch.Tensor]], + multimodal_embeds: NestedTensors, ) -> torch.Tensor: """ Embed token IDs and multimodal inputs and combine their embeddings. diff --git a/vllm/model_executor/models/vision.py b/vllm/model_executor/models/vision.py new file mode 100644 index 0000000000000..65a773480d2a1 --- /dev/null +++ b/vllm/model_executor/models/vision.py @@ -0,0 +1,52 @@ +from abc import ABC, abstractmethod +from typing import Generic, TypeVar + +from transformers import PretrainedConfig + +_C = TypeVar("_C", bound=PretrainedConfig) + + +class VisionEncoderInfo(ABC, Generic[_C]): + + def __init__(self, vision_config: _C) -> None: + super().__init__() + + self.vision_config = vision_config + + @abstractmethod + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + ) -> int: + raise NotImplementedError + + @abstractmethod + def get_max_image_tokens(self) -> int: + raise NotImplementedError + + @abstractmethod + def get_num_patches(self) -> int: + raise NotImplementedError + + @abstractmethod + def get_image_size(self) -> int: + raise NotImplementedError + + +def vision_encoder_info(vision_config: PretrainedConfig) -> VisionEncoderInfo: + # Avoid circular imports + from .clip import CLIPEncoderInfo, CLIPVisionConfig + from .pixtral import PixtralHFEncoderInfo, PixtralVisionConfig + from .siglip import SiglipEncoderInfo, SiglipVisionConfig + + if isinstance(vision_config, CLIPVisionConfig): + return CLIPEncoderInfo(vision_config) + if isinstance(vision_config, PixtralVisionConfig): + return PixtralHFEncoderInfo(vision_config) + if isinstance(vision_config, SiglipVisionConfig): + return SiglipEncoderInfo(vision_config) + + msg = f"Unsupported vision config: {type(vision_config)}" + raise NotImplementedError(msg)
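
The new vision.py module gives LLaVA, LLaVA-NeXT and Pixtral-HF a single entry point for querying the vision tower. A rough sketch of how the factory is expected to be used follows (not part of the diff); the SigLIP config values are assumed for illustration only.

    from transformers import SiglipVisionConfig

    from vllm.model_executor.models.vision import vision_encoder_info

    # Assumed SigLIP-base-style values (224 px images, 16 px patches).
    config = SiglipVisionConfig(image_size=224, patch_size=16)

    info = vision_encoder_info(config)  # dispatches to SiglipEncoderInfo
    info.get_num_patches()              # 224 // 16 == 14 per side
    info.get_max_image_tokens()         # per-image token budget for profiling

An unrecognized config type still raises NotImplementedError, so supporting a new vision backbone only requires a VisionEncoderInfo subclass plus a branch in this factory.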