[WIP] Support for qwen2vl models
xenova committed Nov 22, 2024
1 parent 29de4b0 commit 54073c8
Showing 2 changed files with 99 additions and 10 deletions.
1 change: 1 addition & 0 deletions src/configs.js
@@ -102,6 +102,7 @@ function getNormalizedConfig(config) {
case 'mistral':
case 'starcoder2':
case 'qwen2':
case 'qwen2_vl':
mapping['num_heads'] = 'num_key_value_heads';
mapping['num_layers'] = 'num_hidden_layers';
mapping['hidden_size'] = 'hidden_size';
108 changes: 98 additions & 10 deletions src/models.js
@@ -581,11 +581,11 @@ async function imageTextToTextForward(self, {

if (!inputs_embeds) {
// 1. Extract the input embeddings
inputs_embeds = await self.encode_text({ input_ids });
inputs_embeds = await self.encode_text({ input_ids, ...kwargs });

// 2. Possibly, merge text and images
if (pixel_values && input_ids.dims[1] !== 1) {
const image_features = await self.encode_image({ pixel_values });
const image_features = await self.encode_image({ pixel_values, ...kwargs });

({ inputs_embeds, attention_mask } = self._merge_input_ids_with_image_features({
image_features,
@@ -606,6 +606,16 @@
}
}

if (!position_ids) {

if (self.config.model_type === 'qwen2_vl') {
// Special case for qwen2_vl models
// @ts-ignore
const { image_grid_thw, video_grid_thw } = kwargs;
[position_ids] = self.get_rope_index(input_ids, image_grid_thw, video_grid_thw, attention_mask)
}
}

const outputs = await decoderForward(self, {
inputs_embeds,
past_key_values,
@@ -617,15 +627,20 @@ async function imageTextToTextForward(self, {
return outputs;
}

/**
* If the model supports providing position_ids, we create position_ids on the fly for batch generation,
* by computing the cumulative sum of the attention mask along the sequence length dimension.
*
* Equivalent to:
* ```python
* position_ids = attention_mask.long().cumsum(-1) - 1
* position_ids.masked_fill_(attention_mask == 0, 1)
* if past_key_values:
* position_ids = position_ids[:, -input_ids.shape[1] :]
* ```
*/
function createPositionIds(model_inputs, past_key_values = null) {
// If the model supports providing position_ids, we create position_ids on the fly for batch generation,
// by computing the cumulative sum of the attention mask along the sequence length dimension.
//
// Equivalent to:
// position_ids = attention_mask.long().cumsum(-1) - 1
// position_ids.masked_fill_(attention_mask == 0, 1)
// if past_key_values:
// position_ids = position_ids[:, -input_ids.shape[1] :]

const { input_ids, inputs_embeds, attention_mask } = model_inputs;
const [bz, seq_len] = attention_mask.dims;
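
The remainder of createPositionIds is collapsed in this view. A plain-array sketch of the cumulative-sum logic described in the JSDoc above (illustrative only; the actual function operates on the library's Tensor objects):

```js
// Sketch only: position ids from the attention mask, mirroring the Python snippet above.
// Padding positions (mask === 0) are filled with 1, matching masked_fill_.
function computePositionIds(attention_mask, num_new_tokens = null) {
    const position_ids = attention_mask.map(row => {
        let pos = -1;
        return row.map(m => (m === 0 ? 1 : ++pos));
    });
    // With past_key_values, keep only the positions of the newly fed tokens.
    return num_new_tokens === null
        ? position_ids
        : position_ids.map(row => row.slice(row.length - num_new_tokens));
}

// Example with a left-padded batch:
// computePositionIds([[0, 1, 1, 1], [1, 1, 1, 1]])
//   -> [[1, 0, 1, 2], [0, 1, 2, 3]]
```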

@@ -3968,6 +3983,78 @@ export class Qwen2Model extends Qwen2PreTrainedModel { }
export class Qwen2ForCausalLM extends Qwen2PreTrainedModel { }
//////////////////////////////////////////////////

export class Qwen2VLPreTrainedModel extends PreTrainedModel {
forward_params = [
// Text inputs
'input_ids',
'attention_mask',
'position_ids',
'past_key_values',

// Vision inputs
'pixel_values',
'image_grid_thw',
];
}
export class Qwen2VLForConditionalGeneration extends Qwen2VLPreTrainedModel {

/**
* Calculate the 3D RoPE index based on the temporal, height and width dimensions of images and videos in the LLM.
*
* Explanation:
* Each embedding sequence contains vision embeddings and text embeddings, or text embeddings only.
*
* For a pure text embedding sequence, the rotary position embedding is no different from that of modern LLMs.
* Examples:
* input_ids: [T T T T T], here T is for text.
* temporal position_ids: [0, 1, 2, 3, 4]
* height position_ids: [0, 1, 2, 3, 4]
* width position_ids: [0, 1, 2, 3, 4]
*
* For a sequence with both vision and text embeddings, we calculate a 3D rotary position embedding for the vision part
* and a 1D rotary position embedding for the text part.
* Examples:
* Assume we have a video input with 3 temporal patches, 2 height patches and 2 width patches.
* input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision.
* vision temporal position_ids: [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]
* vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1]
* vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
* text temporal position_ids: [3, 4, 5, 6, 7]
* text height position_ids: [3, 4, 5, 6, 7]
* text width position_ids: [3, 4, 5, 6, 7]
* Here we calculate the text start position_ids as the max vision position_ids plus 1.
*
* @param {Tensor} input_ids Indices of input sequence tokens in the vocabulary. Tensor of shape `(batch_size, sequence_length)`.
* @param {Tensor} image_grid_thw (Optional) The temporal, height and width of feature shape of each image in LLM. Tensor of shape `(num_images, 3)`.
* @param {Tensor} video_grid_thw (Optional) The temporal, height and width of feature shape of each video in LLM. Tensor of shape `(num_videos, 3)`.
* @param {Tensor} attention_mask (Optional) Mask to avoid performing attention on padding token indices. Tensor of shape `(batch_size, sequence_length)`. Mask values selected in `[0, 1]`:
* - 1 for tokens that are **not masked**,
* - 0 for tokens that are **masked**.
* @returns {[Tensor, Tensor]} [position_ids, mrope_position_deltas] with:
* - position_ids: Tensor of shape `(3, batch_size, sequence_length)`.
* - mrope_position_deltas: Tensor of shape `(batch_size)`.
*/
get_rope_index(input_ids, image_grid_thw, video_grid_thw, attention_mask) {
throw new Error('Not yet implemented');
}


async encode_image({ pixel_values, image_grid_thw }) {
const features = (await sessionRun(this.sessions['vision_encoder'], { pixel_values, grid_thw: image_grid_thw })).image_features;
return features;
}

_merge_input_ids_with_image_features({
inputs_embeds,
image_features,
input_ids,
attention_mask,
}) {
throw new Error('Not yet implemented');
return { inputs_embeds, attention_mask }
}
}
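
Both get_rope_index and _merge_input_ids_with_image_features are stubs in this WIP commit. As a rough sketch of the pure-text branch described in the JSDoc above, the same cumulative-sum positions are simply broadcast across the three rotary axes (plain arrays for illustration; the vision branch, which assigns temporal/height/width positions to image and video patches, is omitted):

```js
// Illustrative sketch only: the text-only branch of get_rope_index, on plain arrays.
// For a sequence without vision tokens, all three rotary axes share the same 1D positions.
function getTextOnlyRopeIndex(attention_mask /* number[][] of 0s and 1s */) {
    // position = cumsum(mask) - 1, with padding positions filled with 1
    const positions_2d = attention_mask.map(row => {
        let pos = -1;
        return row.map(m => (m === 0 ? 1 : ++pos));
    });

    // Shape (3, batch_size, sequence_length): temporal, height and width axes are identical.
    const position_ids = [positions_2d, positions_2d, positions_2d];

    // Shape (batch_size): how far the maximum position is ahead of the sequence length.
    const mrope_position_deltas = positions_2d.map(
        (row, i) => Math.max(...row) + 1 - attention_mask[i].length,
    );

    return [position_ids, mrope_position_deltas];
}
```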


//////////////////////////////////////////////////
// Phi models
@@ -6518,6 +6605,7 @@ const MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = new Map([
['llava_onevision', ['LlavaOnevisionForConditionalGeneration', LlavaOnevisionForConditionalGeneration]],
['moondream1', ['Moondream1ForConditionalGeneration', Moondream1ForConditionalGeneration]],
['florence2', ['Florence2ForConditionalGeneration', Florence2ForConditionalGeneration]],
['qwen2_vl', ['Qwen2VLForConditionalGeneration', Qwen2VLForConditionalGeneration]],
]);

const MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES = new Map([

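With the stubs above filled in, usage would presumably follow the usual transformers.js pattern for image-text-to-text models. A sketch under those assumptions (the model id is a hypothetical ONNX export, the image URL is a placeholder, and the exact processor call for Qwen2-VL may differ since this commit is still WIP):

```js
import { AutoProcessor, Qwen2VLForConditionalGeneration, RawImage } from "@huggingface/transformers";

// Hypothetical ONNX export; no such repository is guaranteed to exist yet.
const model_id = "onnx-community/Qwen2-VL-2B-Instruct";
const processor = await AutoProcessor.from_pretrained(model_id);
const model = await Qwen2VLForConditionalGeneration.from_pretrained(model_id);

// Prepare a single image + text prompt.
const image = await (await RawImage.read("https://example.com/cat.png")).resize(448, 448);
const conversation = [
  { role: "user", content: [{ type: "image" }, { type: "text", text: "Describe this image." }] },
];
const text = processor.apply_chat_template(conversation, { add_generation_prompt: true });
const inputs = await processor(text, image);

// Generate and decode only the newly produced tokens.
const output_ids = await model.generate({ ...inputs, max_new_tokens: 128 });
const answer = processor.batch_decode(
  output_ids.slice(null, [inputs.input_ids.dims.at(-1), null]),
  { skip_special_tokens: true },
);
console.log(answer[0]);
```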