diff --git a/notebooks/openvino/sentence_transformer_quantization.ipynb b/notebooks/openvino/sentence_transformer_quantization.ipynb index 714544aa9a..067e1c9913 100644 --- a/notebooks/openvino/sentence_transformer_quantization.ipynb +++ b/notebooks/openvino/sentence_transformer_quantization.ipynb @@ -170,9 +170,11 @@ ], "source": [ "from functools import partial\n", - "import datasets\n", + "\n", "from transformers import AutoTokenizer\n", - "from optimum.intel import OVModelForFeatureExtraction, OVQuantizer, OVQuantizationConfig, OVConfig\n", + "\n", + "from optimum.intel import OVConfig, OVModelForFeatureExtraction, OVQuantizationConfig, OVQuantizer\n", + "\n", "\n", "MODEL_ID = \"sentence-transformers/all-MiniLM-L6-v2\"\n", "base_model_path = \"all-MiniLM-L6-v2\"\n", @@ -187,6 +189,7 @@ "\n", "quantizer = OVQuantizer.from_pretrained(model)\n", "\n", + "\n", "def preprocess_function(examples, tokenizer):\n", " return tokenizer(examples[\"sentence\"], padding=\"max_length\", max_length=384, truncation=True)\n", "\n", @@ -225,9 +228,9 @@ "metadata": {}, "outputs": [], "source": [ - "from transformers import Pipeline\n", - "import torch.nn.functional as F\n", "import torch\n", + "import torch.nn.functional as F\n", + "from transformers import Pipeline\n", "\n", "\n", "# copied from the model card \"sentence-transformers/all-MiniLM-L6-v2\"\n", @@ -296,6 +299,7 @@ "from datasets import load_dataset\n", "from evaluate import load\n", "\n", + "\n", "eval_dataset = load_dataset(\"glue\", \"stsb\", split=\"validation\")\n", "metric = load(\"glue\", \"stsb\")" ] diff --git a/optimum/exporters/openvino/model_configs.py b/optimum/exporters/openvino/model_configs.py index d9c0165d98..4ec92a6302 100644 --- a/optimum/exporters/openvino/model_configs.py +++ b/optimum/exporters/openvino/model_configs.py @@ -89,6 +89,8 @@ Phi3ModelPatcher, Phi3VisionImageEmbeddingsPatcher, QwenModelPatcher, + Qwen2VLLanguageModelPatcher, + Qwen2VLVisionEmbMergerPatcher, RotaryEmbPatcher, UpdateCausalMaskModelPatcher, XverseModelPatcher, @@ -106,9 +108,13 @@ def init_model_configs(): "transformers", "LlavaNextForConditionalGeneration", ) - TasksManager._TRANSFORMERS_TASKS_TO_MODEL_LOADERS[ - "image-text-to-text" - ] = TasksManager._TRANSFORMERS_TASKS_TO_MODEL_LOADERS["text-generation"] + TasksManager._CUSTOM_CLASSES[("pt", "qwen2-vl", "image-text-to-text")] = ( + "transformers", + "Qwen2VLForConditionalGeneration", + ) + TasksManager._TRANSFORMERS_TASKS_TO_MODEL_LOADERS["image-text-to-text"] = ( + TasksManager._TRANSFORMERS_TASKS_TO_MODEL_LOADERS["text-generation"] + ) supported_model_types = [ "_SUPPORTED_MODEL_TYPE", @@ -1288,18 +1294,26 @@ def patch_model_for_export( class LMInputEmbedsConfigHelper(TextDecoderWithPositionIdsOnnxConfig): - def __init__(self, export_config): + def __init__(self, export_config, patcher_cls=None, dummy_input_generator=None, inputs_update=None): self.orig_export_config = export_config + if dummy_input_generator is not None: + export_config.DUMMY_INPUT_GENERATOR_CLASSES = ( + dummy_input_generator, + ) + export_config.DUMMY_INPUT_GENERATOR_CLASSES self.DUMMY_INPUT_GENERATOR_CLASSES = export_config.DUMMY_INPUT_GENERATOR_CLASSES self.DEFAULT_ONNX_OPSET = export_config.DEFAULT_ONNX_OPSET self.DUMMY_PKV_GENERATOR_CLASS = export_config.DUMMY_PKV_GENERATOR_CLASS self._config = export_config._config self._normalized_config = export_config._normalized_config self.use_past = export_config.use_past + self.patcher_cls = patcher_cls + self.input_info_upd = inputs_update def patch_model_for_export( self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None ) -> "ModelPatcher": + if self.patcher_cls is not None: + return self.patcher_cls(self, model, model_kwargs=model_kwargs) # Refer to DecoderModelPatcher. return self.orig_export_config.patch_model_for_export(model, model_kwargs=model_kwargs) @@ -1312,6 +1326,8 @@ def inputs(self) -> Dict[str, Dict[int, str]]: orig_inputs = self.orig_export_config.inputs input_ids_config = orig_inputs.pop("input_ids") orig_inputs["inputs_embeds"] = input_ids_config + if self.input_info_upd is not None: + orig_inputs.update(self.input_info_upd) return orig_inputs def generate_dummy_inputs(self, framework: str = "pt", **kwargs): @@ -1383,9 +1399,22 @@ def get_vlm_text_embeddings_config(model_type, model_config, int_dtype, float_dt return export_config -def get_vlm_text_generation_config(model_type, model_config, int_dtype, float_dtype): +def get_vlm_text_generation_config( + model_type, + model_config, + int_dtype, + float_dtype, + model_patcher=None, + dummy_input_generator=None, + inputs_update=None, +): internal_export_config = get_vlm_internal_text_generation_config(model_type, model_config, int_dtype, float_dtype) - export_config = LMInputEmbedsConfigHelper(internal_export_config) + export_config = LMInputEmbedsConfigHelper( + internal_export_config, + patcher_cls=model_patcher, + dummy_input_generator=dummy_input_generator, + inputs_update=inputs_update, + ) export_config._normalized_config = internal_export_config._normalized_config return export_config @@ -1820,9 +1849,11 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int img_ids_height = self.height // 2 img_ids_width = self.width // 2 return self.random_int_tensor( - [self.batch_size, img_ids_height * img_ids_width, 3] - if is_diffusers_version("<", "0.31.0") - else [img_ids_height * img_ids_width, 3], + ( + [self.batch_size, img_ids_height * img_ids_width, 3] + if is_diffusers_version("<", "0.31.0") + else [img_ids_height * img_ids_width, 3] + ), min_value=0, max_value=min(img_ids_height, img_ids_width), framework=framework, @@ -2259,3 +2290,218 @@ def patch_model_for_export( if self._behavior == Phi3VisionConfigBehavior.VISION_EMBEDDINGS: return Phi3VisionImageEmbeddingsPatcher(self, model, model_kwargs) return super().patch_model_for_export(model, model_kwargs) + + +class DummyQwen2VLLMInputGenerator(DummyTextInputGenerator): + def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): + generated_input = super().generate(input_name, framework, int_dtype, float_dtype) + if input_name == "position_ids": + return generated_input.unsqueeze(0).expand(3, -1, -1) + return generated_input + + +class DummyQwen2VLVisionEMbedInputGenerator(DummyVisionInputGenerator): + SUPPORTED_INPUT_NAMES = ("hidden_states",) + + def __init__( + self, + task: str, + normalized_config: NormalizedVisionConfig, + batch_size: int = 1, + num_channels: int = DEFAULT_DUMMY_SHAPES["num_channels"], + width: int = 420, + height: int = 420, + **kwargs, + ): + self.batch_size = batch_size + self.height = height + self.width = width + self.num_channels = num_channels + self.temporal_patch_size = normalized_config.config.temporal_patch_size + self.patch_size = normalized_config.config.patch_size + + def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): + grid_h, grid_w = self.height // self.patch_size, self.width // self.patch_size + grid_t = self.batch_size + shape = [ + grid_t * grid_h * grid_w, + self.num_channels * self.temporal_patch_size * self.patch_size * self.patch_size, + ] + return self.random_float_tensor(shape, framework=framework, dtype=float_dtype) + + +class DummyQwen2VLVisionEmbedMergerInputGenerator(DummyVisionInputGenerator): + SUPPORTED_INPUT_NAMES = ("hidden_states", "attention_mask", "rotary_pos_emb") + + def __init__( + self, + task: str, + normalized_config: NormalizedVisionConfig, + batch_size: int = 1, + num_channels: int = DEFAULT_DUMMY_SHAPES["num_channels"], + width: int = 420, + height: int = 420, + **kwargs, + ): + self.batch_size = batch_size + self.height = height + self.width = width + self.num_channels = num_channels + self.temporal_patch_size = normalized_config.config.temporal_patch_size + self.patch_size = normalized_config.config.patch_size + self.embed_dim = normalized_config.config.embed_dim + self.num_heads = normalized_config.config.num_heads + + def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): + grid_h, grid_w = self.height // self.patch_size, self.width // self.patch_size + grid_t = self.batch_size + + if input_name == "hidden_states": + return self.random_float_tensor( + [grid_t * grid_h * grid_w, self.embed_dim], framework=framework, dtype=float_dtype + ) + + if input_name == "attention_mask": + return self.random_mask_tensor( + [1, grid_t * grid_h * grid_w, grid_t * grid_h * grid_w], framework=framework, dtype=float_dtype + ) + + if input_name == "rotary_pos_emb": + dim = self.embed_dim // self.num_heads // 2 + return self.random_float_tensor([grid_h * grid_t * grid_w, dim], framework=framework, dtype=float_dtype) + + +class Qwen2VLConfigBehavior(str, enum.Enum): + LANGUAGE = "language" + VISION_EMBEDDINGS = "vision_embeddings" + VISION_EMBEDDINGS_MERGER = "vision_embeddings_merger" + TEXT_EMBEDDINGS = "text_embeddings" + + +@register_in_tasks_manager("qwen2-vl", *["image-text-to-text"], library_name="transformers") +class Qwen2VLOpenVINOConfig(OnnxConfig): + SUPPORTED_BEHAVIORS = [model_type.value for model_type in Qwen2VLConfigBehavior] + NORMALIZED_CONFIG_CLASS = NormalizedVisionConfig + DUMMY_INPUT_GENERATOR_CLASSES = (DummyQwen2VLVisionEMbedInputGenerator,) + MIN_TRANSFORMERS_VERSION = version.parse("4.45.0") + + def __init__( + self, + config: "PretrainedConfig", + task: str = "feature-extraction", + int_dtype: str = "int64", + float_dtype: str = "fp32", + behavior: Qwen2VLConfigBehavior = Qwen2VLConfigBehavior.VISION_EMBEDDINGS, + preprocessors: Optional[List[Any]] = None, + ): + super().__init__( + config=config, + task=task, + int_dtype=int_dtype, + float_dtype=float_dtype, + preprocessors=preprocessors, + ) + self._behavior = behavior + self._orig_config = config + if self._behavior == Qwen2VLConfigBehavior.VISION_EMBEDDINGS and hasattr(config, "vision_config"): + self._config = config.vision_config + self._normalized_config = self.NORMALIZED_CONFIG_CLASS(self._config) + self.DUMMY_INPUT_GENERATOR_CLASSES = (DummyQwen2VLVisionEMbedInputGenerator,) + if self._behavior == Qwen2VLConfigBehavior.VISION_EMBEDDINGS_MERGER and hasattr(config, "vision_config"): + self._config = config.vision_config + self._normalized_config = self.NORMALIZED_CONFIG_CLASS(self._config) + self.DUMMY_INPUT_GENERATOR_CLASSES = (DummyQwen2VLVisionEmbedMergerInputGenerator,) + + @staticmethod + def get_model_for_behavior(model, behavior: Union[str, Qwen2VLConfigBehavior]): + if isinstance(behavior, str) and not isinstance(behavior, Qwen2VLConfigBehavior): + behavior = Qwen2VLConfigBehavior(behavior) + + if behavior == Qwen2VLConfigBehavior.LANGUAGE: + return model + + if behavior == Qwen2VLConfigBehavior.VISION_EMBEDDINGS: + vision_embeddings = model.visual.patch_embed + vision_embeddings.config = model.config.vision_config + return vision_embeddings + + if behavior == Qwen2VLConfigBehavior.VISION_EMBEDDINGS_MERGER: + vision_emb_merger = model.visual + vision_emb_merger.config = model.config.vision_config + return vision_emb_merger + + if behavior == Qwen2VLConfigBehavior.TEXT_EMBEDDINGS: + text_embedding = model.model.embed_tokens + text_embedding.config = model.config + return text_embedding + + def with_behavior( + self, + behavior: Union[str, Qwen2VLConfigBehavior], + ): + """ + Creates a config for different behaviour. + Args: + behavior ([`ConfigBehavior`]): + The behavior to use for the new instance. + """ + if isinstance(behavior, str) and not isinstance(behavior, Qwen2VLConfigBehavior): + behavior = Qwen2VLConfigBehavior(behavior) + + if behavior == Qwen2VLConfigBehavior.TEXT_EMBEDDINGS: + return get_vlm_text_embeddings_config("qwen2", self._orig_config, self.int_dtype, self.float_dtype) + + if behavior == Qwen2VLConfigBehavior.LANGUAGE: + return get_vlm_text_generation_config( + "qwen2", + self._orig_config, + self.int_dtype, + self.float_dtype, + model_patcher=Qwen2VLLanguageModelPatcher, + dummy_input_generator=DummyQwen2VLLMInputGenerator, + inputs_update={"position_ids": {1: "batch_size", 2: "sequence_length"}}, + ) + + if behavior == Qwen2VLConfigBehavior.VISION_EMBEDDINGS: + return self.__class__( + self._orig_config, + task=self.task, + int_dtype=self.int_dtype, + float_dtype=self.float_dtype, + behavior=behavior, + preprocessors=self._preprocessors, + ) + if behavior == Qwen2VLConfigBehavior.VISION_EMBEDDINGS_MERGER: + return self.__class__( + self._orig_config, + task=self.task, + int_dtype=self.int_dtype, + float_dtype=self.float_dtype, + behavior=behavior, + preprocessors=self._preprocessors, + ) + + def patch_model_for_export( + self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None + ): + model_kwargs = model_kwargs or {} + if self._behavior == Qwen2VLConfigBehavior.VISION_EMBEDDINGS_MERGER: + return Qwen2VLVisionEmbMergerPatcher(self, model, model_kwargs) + return super().patch_model_for_export(model, model_kwargs) + + @property + def inputs(self) -> Dict[str, Dict[int, str]]: + if self._behavior == Phi3VisionConfigBehavior.VISION_EMBEDDINGS: + return {"hidden_states": {0: "patch_thw_grid", 1: "patch_temporal_channels"}} + if self._behavior == Qwen2VLConfigBehavior.VISION_EMBEDDINGS_MERGER: + return { + "hidden_states": {0: "sequence_length"}, + "attention_mask": {1: "sequence_length", 2: "sequence_length"}, + "rotary_pos_emb": {0: "sequence_length"}, + } + + @property + def outputs(self) -> Dict[str, Dict[int, str]]: + if self._behavior in [Qwen2VLConfigBehavior.VISION_EMBEDDINGS, Qwen2VLConfigBehavior.VISION_EMBEDDINGS_MERGER]: + return {"last_hidden_state": {0: "seq_len"}} + return {} diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py index c71cbfe003..af830849eb 100644 --- a/optimum/exporters/openvino/model_patcher.py +++ b/optimum/exporters/openvino/model_patcher.py @@ -18,9 +18,11 @@ import math import types from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union +from optimum.exporters.onnx.base import OnnxConfig import torch import torch.nn.functional as F +from transformers import PreTrainedModel, TFPreTrainedModel from transformers.modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling from transformers.utils import is_tf_available @@ -421,9 +423,9 @@ def _llama_gemma_update_causal_mask_legacy(self, attention_mask, input_tensor, c offset = 0 mask_shape = attention_mask.shape mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype - causal_mask[ - : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3] - ] = mask_slice + causal_mask[: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]] = ( + mask_slice + ) if ( self.config._attn_implementation == "sdpa" @@ -2058,9 +2060,9 @@ def _dbrx_update_causal_mask_legacy( offset = 0 mask_shape = attention_mask.shape mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype - causal_mask[ - : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3] - ] = mask_slice + causal_mask[: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]] = ( + mask_slice + ) if ( self.config._attn_implementation == "sdpa" @@ -3378,3 +3380,103 @@ def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) for block in self._model.model.layers: block.self_attn.forward = block.self_attn._orig_forward + + +class Qwen2VLLanguageModelPatcher(DecoderModelPatcher): + def __init__( + self, + config: OnnxConfig, + model: PreTrainedModel | TFPreTrainedModel, + model_kwargs: Dict[str, Any] | None = None, + ): + + model.__orig_forward = model.forward + + def forward_wrap( + self, + attention_mask, + position_ids=None, + past_key_values=None, + inputs_embeds=None, + input_ids=None, + ): + from transformers.cache_utils import DynamicCache + + new_past_key_values = DynamicCache.from_legacy_cache(past_key_values) + result = self.__orig_forward( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=new_past_key_values, + inputs_embeds=inputs_embeds, + ) + if past_key_values is not None: + result["past_key_values"] = result["past_key_values"].to_legacy_cache() + return result + + model.forward = types.MethodType(forward_wrap, model) + super().__init__(config, model, model_kwargs) + + def __exit__(self, exc_type, exc_value, traceback): + super().__exit__(exc_type, exc_value, traceback) + self._model.forward = self._model.__orig_forward + + +class Qwen2VLVisionEmbMergerPatcher(ModelPatcher): + def __init__( + self, + config: OnnxConfig, + model: PreTrainedModel | TFPreTrainedModel, + model_kwargs: Dict[str, Any] | None = None, + ): + model.__orig_forward = model.forward + + def image_embed_forward( + self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, rotary_pos_emb: torch.Tensor + ) -> torch.Tensor: + for blk in self.blocks: + hidden_states = blk(hidden_states, attention_mask=attention_mask, rotary_pos_emb=rotary_pos_emb) + return self.merger(hidden_states) + + model.forward = types.MethodType(image_embed_forward, model) + super().__init__(config, model, model_kwargs) + + def __enter__(self): + def sdpa_attn_forward( + self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, rotary_pos_emb: torch.Tensor = None + ) -> torch.Tensor: + from transformers.models.qwen2_vl.modeling_qwen2_vl import apply_rotary_pos_emb_vision + + seq_length = hidden_states.shape[0] + q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) + q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0) + k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0) + + q = q.transpose(0, 1) + k = k.transpose(0, 1) + v = v.transpose(0, 1) + attn_output = torch.nn.functional.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0) + attn_output = attn_output.transpose(0, 1) + attn_output = attn_output.reshape(seq_length, -1) + attn_output = self.proj(attn_output) + return attn_output + + def block_forward(self, hidden_states, attention_mask, rotary_pos_emb) -> torch.Tensor: + hidden_states = hidden_states + self.attn( + self.norm1(hidden_states), attention_mask=attention_mask, rotary_pos_emb=rotary_pos_emb + ) + hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) + return hidden_states + + for block in self._model.blocks: + block._orig_forward = block.forward + block.forward = types.MethodType(block_forward, block) + block.attn._orig_forward = block.attn.forward + block.attn.forward = types.MethodType(sdpa_attn_forward, block.attn) + + def __exit__(self, exc_type, exc_value, traceback): + super().__exit__(exc_type, exc_value, traceback) + self._model.forward = self._model.__orig_forward + for block in self._model.blocks: + block.forward = block._orig_forward + block.attn.forward = block.attn._orig_forward diff --git a/optimum/exporters/openvino/utils.py b/optimum/exporters/openvino/utils.py index 302fee5b65..5242db1d1a 100644 --- a/optimum/exporters/openvino/utils.py +++ b/optimum/exporters/openvino/utils.py @@ -216,7 +216,7 @@ def get_submodels(model): return custom_export, fn_get_submodels -MULTI_MODAL_TEXT_GENERATION_MODELS = ["llava", "llava-next", "llava-qwen2", "internvl-chat", "minicpmv", "phi3-v"] +MULTI_MODAL_TEXT_GENERATION_MODELS = ["llava", "llava-next", "llava-qwen2", "internvl-chat", "minicpmv", "phi3-v", "qwen2-vl"] def save_config(config, save_dir): diff --git a/optimum/intel/openvino/modeling_visual_language.py b/optimum/intel/openvino/modeling_visual_language.py index aa9b23a0ee..7ca3a0cf15 100644 --- a/optimum/intel/openvino/modeling_visual_language.py +++ b/optimum/intel/openvino/modeling_visual_language.py @@ -1,10 +1,11 @@ import copy +from dataclasses import dataclass import logging import os import warnings from abc import abstractmethod from pathlib import Path -from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union +from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union, Any import numpy as np import openvino as ov @@ -22,6 +23,7 @@ PreTrainedTokenizer, ) from transformers.modeling_outputs import BaseModelOutputWithPooling +from transformers.utils import ModelOutput from ...exporters.openvino import main_export from ...exporters.openvino.stateful import ensure_stateful_is_available, model_has_input_output_name @@ -216,7 +218,12 @@ def __init__(self, model: ov.Model, parent_model: OVBaseModel) -> None: key.get_any_name() for key in self.model.outputs[2:] if "hidden_states" in key.get_any_name() ] self.input_names = {key.get_any_name(): idx for idx, key in enumerate(self.model.inputs)} - self._main_input = "images" if model_has_input_output_name(self.model, "images") else "pixel_values" + if model_has_input_output_name(self.model, "images"): + self._main_input = "images" + elif model_has_input_output_name(self.model, "hidden_states"): + self._main_input = "hidden_states" + else: + self._main_input = "pixel_values" def forward(self, pixel_values, **kwargs): self._compile() @@ -269,6 +276,7 @@ def forward(self, img_features): "language_model": OVModelWithEmbedForCausalLM, "vision_embeddings": OVVisionEmbedding, "vision_projection": OVVisionProjection, + "vision_embeddings_merger": OVVisionEmbedding, } @@ -676,6 +684,10 @@ def forward( position_ids=None, image_bound=None, tgt_sizes=None, + pixel_values_videos=None, + image_grid_thw=None, + video_grid_thw=None, + rope_deltas=None, **kwargs, ): inputs_embeds, attention_mask, position_ids = self.get_multimodal_embeddings( @@ -687,6 +699,10 @@ def forward( past_key_values=past_key_values, image_bound=image_bound, tgt_sizes=tgt_sizes, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + rope_deltas=rope_deltas, **kwargs, ) return self.language_model.forward( @@ -781,6 +797,9 @@ def prepare_inputs_for_generation( "image_sizes": image_sizes, "image_bound": kwargs.get("image_bound"), "tgt_sizes": kwargs.get("tgt_sizes"), + "pixel_values_videos": kwargs.get("pixel_values_videos"), + "image_grid_thw": kwargs.get("image_grid_thw"), + "video_grid_thw": kwargs.get("video_grid_thw"), } ) return model_inputs @@ -2036,6 +2055,393 @@ def preprocess_inputs( return inputs +@dataclass +class QWen2VLModelOutputWithPast(ModelOutput): + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + rope_deltas: Optional[torch.FloatTensor] = None + + +class _OVQwen2VLForCausalLM(OVModelForVisualCausalLM): + additional_parts = ["vision_embeddings_merger"] + + def __init__( + self, + language_model: ov.Model, + text_embeddings: ov.Model, + vision_embeddings: ov.Model, + config: PretrainedConfig = None, + device: str = "CPU", + dynamic_shapes: bool = True, + ov_config: Optional[Dict[str, str]] = None, + model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, + quantization_config: Union[OVWeightQuantizationConfig, Dict] = None, + **kwargs, + ): + super().__init__( + language_model=language_model, + text_embeddings=text_embeddings, + vision_embeddings=vision_embeddings, + config=config, + device=device, + dynamic_shapes=dynamic_shapes, + ov_config=ov_config, + model_save_dir=model_save_dir, + quantization_config=quantization_config, + **kwargs, + ) + try: + from transformers.models.qwen2_vl.modeling_qwen2_vl import VisionRotaryEmbedding + + self._rotary_pos_emb = VisionRotaryEmbedding( + self.config.vision_config.embed_dim // self.config.vision_config.num_heads // 2 + ) + except ImportError: + raise ValueError( + f"Initialization model for {self.config.model_type} required at least transformers >= 4.45" + ) + + def get_rope_index( + self, + input_ids: torch.LongTensor, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Calculate the 3D rope index based on image and video's temporal, height and width in LLM. + """ + spatial_merge_size = self.config.vision_config.spatial_merge_size + image_token_id = self.config.image_token_id + video_token_id = self.config.video_token_id + vision_start_token_id = self.config.vision_start_token_id + mrope_position_deltas = [] + if image_grid_thw is not None or video_grid_thw is not None: + total_input_ids = input_ids + position_ids = torch.ones( + 3, input_ids.shape[0], input_ids.shape[1], dtype=input_ids.dtype, device=input_ids.device + ) + image_index, video_index = 0, 0 + for i, input_ids in enumerate(total_input_ids): + if attention_mask is not None: + input_ids = input_ids[attention_mask[i] == 1] + image_nums, video_nums = 0, 0 + vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1) + vision_tokens = input_ids[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + input_tokens = input_ids.tolist() + llm_pos_ids_list: list = [] + st = 0 + remain_images, remain_videos = image_nums, video_nums + for _ in range(image_nums + video_nums): + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + image_index += 1 + remain_images -= 1 + ed = ed_image + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + video_index += 1 + remain_videos -= 1 + ed = ed_video + llm_grid_t, llm_grid_h, llm_grid_w = ( + t.item(), + h.item() // spatial_merge_size, + w.item() // spatial_merge_size, + ) + text_len = ed - st + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten() + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() + llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device) + mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i])) + mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) + return position_ids, mrope_position_deltas + else: + if attention_mask is not None: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(input_ids.device) + max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0] + mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] + else: + position_ids = ( + torch.arange(input_ids.shape[1], device=input_ids.device) + .view(1, 1, -1) + .expand(3, input_ids.shape[0], -1) + ) + mrope_position_deltas = torch.zeros( + [input_ids.shape[0], 1], + device=input_ids.device, + dtype=input_ids.dtype, + ) + + return position_ids, mrope_position_deltas + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + pixel_values=None, + pixel_values_videos=None, + image_grid_thw=None, + video_grid_thw=None, + **kwargs, + ): + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + if past_key_values is not None: + if inputs_embeds is not None: # Exception 1 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + + rope_deltas = kwargs.get("rope_deltas", None) + if attention_mask is not None and position_ids is None: + if cache_position is None or (cache_position is not None and cache_position[0] == 0): + position_ids, rope_deltas = self.get_rope_index( + input_ids, image_grid_thw, video_grid_thw, attention_mask + ) + else: + batch_size, seq_length = input_ids.shape + delta = ( + cache_position[0] + rope_deltas if cache_position is not None and rope_deltas is not None else 0 + ) + position_ids = torch.arange(seq_length, device=input_ids.device) + position_ids = position_ids.view(1, -1).expand(batch_size, -1) + position_ids = position_ids.add(delta) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) + + if cache_position[0] != 0: + pixel_values = None + pixel_values_videos = None + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and cache_position[0] == 0: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + "pixel_values": pixel_values, + "pixel_values_videos": pixel_values_videos, + "image_grid_thw": image_grid_thw, + "video_grid_thw": video_grid_thw, + "rope_deltas": rope_deltas, + } + ) + return model_inputs + + def _update_model_kwargs_for_generation( + self, + outputs: ModelOutput, + model_kwargs: Dict[str, Any], + is_encoder_decoder: bool = False, + num_new_tokens: int = 1, + ) -> Dict[str, Any]: + model_kwargs = super()._update_model_kwargs_for_generation( + outputs=outputs, + model_kwargs=model_kwargs, + is_encoder_decoder=is_encoder_decoder, + num_new_tokens=num_new_tokens, + ) + + if getattr(outputs, "rope_deltas", None) is not None: + model_kwargs["rope_deltas"] = outputs.rope_deltas + + return model_kwargs + + def get_vision_embeddings(self, pixel_values, grid_thw, **kwargs): + hidden_states = self.vision_embeddings(pixel_values)[0] + rotary_pos_emb = self.rot_pos_emb(grid_thw) + cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( + dim=0, dtype=torch.int32 + ) + cu_seqlens = torch.nn.functional.pad(cu_seqlens, (1, 0), value=0) + attention_mask = torch.zeros((1, hidden_states.shape[0], hidden_states.shape[0]), dtype=torch.bool) + causal_mask = torch.zeros_like(attention_mask, dtype=torch.float32) + for i in range(1, len(cu_seqlens)): + attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True + + causal_mask.masked_fill_(torch.logical_not(attention_mask), float("-inf")) + + res = self.vision_embeddings_merger( + pixel_values=hidden_states, attention_mask=causal_mask, rotary_pos_emb=rotary_pos_emb + )[0] + return res + + def rot_pos_emb(self, grid_thw): + pos_ids = [] + for t, h, w in grid_thw: + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + hpos_ids = hpos_ids.reshape( + h // self.config.vision_config.spatial_merge_size, + self.config.vision_config.spatial_merge_size, + w // self.config.vision_config.spatial_merge_size, + self.config.vision_config.spatial_merge_size, + ) + hpos_ids = hpos_ids.permute(0, 2, 1, 3) + hpos_ids = hpos_ids.flatten() + + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + wpos_ids = wpos_ids.reshape( + h // self.config.vision_config.spatial_merge_size, + self.config.vision_config.spatial_merge_size, + w // self.config.vision_config.spatial_merge_size, + self.config.vision_config.spatial_merge_size, + ) + wpos_ids = wpos_ids.permute(0, 2, 1, 3) + wpos_ids = wpos_ids.flatten() + pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + pos_ids = torch.cat(pos_ids, dim=0) + max_grid_size = grid_thw[:, 1:].max() + rotary_pos_emb_full = self._rotary_pos_emb(max_grid_size) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + return rotary_pos_emb + + def get_multimodal_embeddings( + self, + input_ids, + pixel_values=None, + attention_mask=None, + position_ids=None, + pixel_values_videos=None, + image_grid_thw=None, + video_grid_thw=None, + **kwargs, + ): + + inputs_embeds = torch.from_numpy(self.get_text_embeddings(input_ids)) + if pixel_values is not None and input_ids.shape[1] != 1: + image_embeds = torch.from_numpy(self.get_vision_embeddings(pixel_values, image_grid_thw)) + image_mask = input_ids == self.config.image_token_id + inputs_embeds[image_mask] = image_embeds + if pixel_values_videos is not None and input_ids.shape[1] != 1: + pixel_values_videos = pixel_values_videos + video_embeds = torch.from_numpy(self.get_vision_embeddings(pixel_values_videos, video_grid_thw)) + video_mask = input_ids == self.config.video_token_id + inputs_embeds[video_mask] = video_embeds + return inputs_embeds, attention_mask, position_ids + + def forward( + self, + input_ids, + pixel_values=None, + past_key_values=None, + inputs_embeds=None, + image_sizes=None, + attention_mask=None, + position_ids=None, + image_bound=None, + tgt_sizes=None, + pixel_values_videos=None, + image_grid_thw=None, + video_grid_thw=None, + rope_deltas=None, + **kwargs, + ): + result = super().forward( + input_ids, + pixel_values, + past_key_values, + inputs_embeds, + image_sizes, + attention_mask, + position_ids, + image_bound, + tgt_sizes, + pixel_values_videos, + image_grid_thw, + video_grid_thw, + rope_deltas, + **kwargs, + ) + final_result = QWen2VLModelOutputWithPast( + logits=result.logits, past_key_values=result.past_key_values, rope_deltas=rope_deltas + ) + return final_result + + @staticmethod + def preprocess_inputs( + text: str, + image: Optional["Image"] = None, + processor: Optional[AutoImageProcessor] = None, + tokenizer: Optional[PreTrainedTokenizer] = None, + config: Optional[PretrainedConfig] = None, + ): + if processor is None: + raise ValueError("Processor is required.") + if image is not None: + conversation = [ + { + "role": "user", + "content": [ + { + "type": "image", + }, + {"type": "text", "text": text}, + ], + } + ] + else: + conversation = [ + { + "role": "user", + "content": [ + {"type": "text", "text": text}, + ], + } + ] + text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True) + + inputs = processor(images=image, text=text_prompt, return_tensors="pt") + return inputs + + MODEL_TYPE_TO_CLS_MAPPING = { "llava": _OVLlavaForCausalLM, "llava_next": _OVLlavaNextForCausalLM, @@ -2043,4 +2449,5 @@ def preprocess_inputs( "llava-qwen2": _OVNanoLlavaForCausalLM, "phi3_v": _OVPhi3VisionForCausalLM, "internvl_chat": _OVInternVLForCausalLM, + "qwen2_vl": _OVQwen2VLForCausalLM, }