From 54073c8d32c846469116639c8a2742facbe4148c Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Fri, 22 Nov 2024 10:02:22 +0000 Subject: [PATCH] [WIP] Support for qwen2vl models --- src/configs.js | 1 + src/models.js | 108 ++++++++++++++++++++++++++++++++++++++++++++----- 2 files changed, 99 insertions(+), 10 deletions(-) diff --git a/src/configs.js b/src/configs.js index 459d2bc86..9cd7a0b54 100644 --- a/src/configs.js +++ b/src/configs.js @@ -102,6 +102,7 @@ function getNormalizedConfig(config) { case 'mistral': case 'starcoder2': case 'qwen2': + case 'qwen2_vl': mapping['num_heads'] = 'num_key_value_heads'; mapping['num_layers'] = 'num_hidden_layers'; mapping['hidden_size'] = 'hidden_size'; diff --git a/src/models.js b/src/models.js index ba294fec9..4438f4ca5 100644 --- a/src/models.js +++ b/src/models.js @@ -581,11 +581,11 @@ async function imageTextToTextForward(self, { if (!inputs_embeds) { // 1. Extract the input embeddings - inputs_embeds = await self.encode_text({ input_ids }); + inputs_embeds = await self.encode_text({ input_ids, ...kwargs }); // 2. Possibly, merge text and images if (pixel_values && input_ids.dims[1] !== 1) { - const image_features = await self.encode_image({ pixel_values }); + const image_features = await self.encode_image({ pixel_values, ...kwargs }); ({ inputs_embeds, attention_mask } = self._merge_input_ids_with_image_features({ image_features, @@ -606,6 +606,16 @@ async function imageTextToTextForward(self, { } } + if (!position_ids) { + + if (self.config.model_type === 'qwen2_vl') { + // Special case for qwen2_vl models + // @ts-ignore + const { image_grid_thw, video_grid_thw } = kwargs; + [position_ids] = self.get_rope_index(input_ids, image_grid_thw, video_grid_thw, attention_mask) + } + } + const outputs = await decoderForward(self, { inputs_embeds, past_key_values, @@ -617,15 +627,20 @@ async function imageTextToTextForward(self, { return outputs; } +/** + * If the model supports providing position_ids, we create position_ids on the fly for batch generation, + * by computing the cumulative sum of the attention mask along the sequence length dimension. + * + * Equivalent to: + * ```python + * position_ids = attention_mask.long().cumsum(-1) - 1 + * position_ids.masked_fill_(attention_mask == 0, 1) + * if past_key_values: + * position_ids = position_ids[:, -input_ids.shape[1] :] + * ``` + */ function createPositionIds(model_inputs, past_key_values = null) { - // If the model supports providing position_ids, we create position_ids on the fly for batch generation, - // by computing the cumulative sum of the attention mask along the sequence length dimension. - // - // Equivalent to: - // position_ids = attention_mask.long().cumsum(-1) - 1 - // position_ids.masked_fill_(attention_mask == 0, 1) - // if past_key_values: - // position_ids = position_ids[:, -input_ids.shape[1] :] + const { input_ids, inputs_embeds, attention_mask } = model_inputs; const [bz, seq_len] = attention_mask.dims; @@ -3968,6 +3983,78 @@ export class Qwen2Model extends Qwen2PreTrainedModel { } export class Qwen2ForCausalLM extends Qwen2PreTrainedModel { } ////////////////////////////////////////////////// +export class Qwen2VLPreTrainedModel extends PreTrainedModel { + forward_params = [ + // Text inputs + 'input_ids', + 'attention_mask', + 'position_ids', + 'past_key_values', + + // Vision inputs + 'pixel_values', + 'image_grid_thw', + ]; +} +export class Qwen2VLForConditionalGeneration extends Qwen2VLPreTrainedModel { + + /** + * Calculate the 3D rope index based on image and video's temporal, height and width in LLM. + * + * Explanation: + * Each embedding sequence contains vision embedding and text embedding or just contains text embedding. + * + * For pure text embedding sequence, the rotary position embedding has no difference with mordern LLMs. + * Examples: + * input_ids: [T T T T T], here T is for text. + * temporal position_ids: [0, 1, 2, 3, 4] + * height position_ids: [0, 1, 2, 3, 4] + * width position_ids: [0, 1, 2, 3, 4] + * + * For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part + * and 1D rotary position embeddin for text part. + * Examples: + * Assume we have a video input with 3 temporal patches, 2 height patches and 2 width patches. + * input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision. + * vision temporal position_ids: [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2] + * vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1] + * vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1] + * text temporal position_ids: [3, 4, 5, 6, 7] + * text height position_ids: [3, 4, 5, 6, 7] + * text width position_ids: [3, 4, 5, 6, 7] + * Here we calculate the text start position_ids as the max vision position_ids plus 1. + * + * @param {Tensor} input_ids Indices of input sequence tokens in the vocabulary. Tensor of shape `(batch_size, sequence_length)`. + * @param {Tensor} image_grid_thw (Optional) The temporal, height and width of feature shape of each image in LLM. Tensor of shape `(num_images, 3)`. + * @param {Tensor} video_grid_thw (Optional) The temporal, height and width of feature shape of each video in LLM. Tensor of shape `(num_videos, 3)`. + * @param {Tensor} attention_mask (Optional) Mask to avoid performing attention on padding token indices. Tensor of shape `(batch_size, sequence_length)`. Mask values selected in `[0, 1]`: + * - 1 for tokens that are **not masked**, + * - 0 for tokens that are **masked**. + * @returns {[Tensor, Tensor]} [position_ids, mrope_position_deltas] with: + * - position_ids: Tensor of shape `(3, batch_size, sequence_length)`. + * - mrope_position_deltas: Tensor of shape `(batch_size)`. + */ + get_rope_index(input_ids, image_grid_thw, video_grid_thw, attention_mask) { + throw new Error('Not yet implemented'); + } + + + async encode_image({ pixel_values, image_grid_thw }) { + const features = (await sessionRun(this.sessions['vision_encoder'], { pixel_values, grid_thw: image_grid_thw })).image_features; + return features; + } + + _merge_input_ids_with_image_features({ + inputs_embeds, + image_features, + input_ids, + attention_mask, + }) { + throw new Error('Not yet implemented'); + return { inputs_embeds, attention_mask } + } +} + ////////////////////////////////////////////////// // Phi models @@ -6518,6 +6605,7 @@ const MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = new Map([ ['llava_onevision', ['LlavaOnevisionForConditionalGeneration', LlavaOnevisionForConditionalGeneration]], ['moondream1', ['Moondream1ForConditionalGeneration', Moondream1ForConditionalGeneration]], ['florence2', ['Florence2ForConditionalGeneration', Florence2ForConditionalGeneration]], + ['qwen2-vl', ['Qwen2VLForConditionalGeneration', Qwen2VLForConditionalGeneration]], ]); const MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES = new Map([