[WIP] Add support for idefics3 (SmolVLM)
xenova committed Nov 29, 2024
1 parent 2c92943 commit 3a02a52
Showing 9 changed files with 501 additions and 45 deletions.
1 change: 1 addition & 0 deletions src/configs.js
@@ -69,6 +69,7 @@ function getNormalizedConfig(config) {
case 'paligemma':
case 'florence2':
case 'llava_onevision':
case 'idefics3':
init_normalized_config = getNormalizedConfig(config.text_config);
break;
case 'moondream1':
117 changes: 87 additions & 30 deletions src/models.js
@@ -546,6 +546,41 @@ async function decoderForward(self, model_inputs, is_encoder_decoder = false) {
}



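/**
 * Scatters image features into the text embedding sequence wherever the image placeholder
 * token appears. For every row of `input_ids`, each position equal to `image_token_id` is
 * overwritten, in order, with the next row of `image_features`, so the number of placeholder
 * tokens must equal `image_features.dims[0]`. Shared by the Idefics3 and Qwen2-VL merge
 * implementations below.
 */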
function default_merge_input_ids_with_image_features({
image_token_id,
inputs_embeds,
image_features,
input_ids,
attention_mask,
}) {
console.log('input_ids', input_ids)
const image_tokens = input_ids.tolist().map(ids =>
ids.reduce((acc, x, idx) => {
if (x == image_token_id) acc.push(idx);
return acc;
}, [])
);
console.log('image_tokens', image_tokens)
const n_image_tokens = image_tokens.reduce((acc, x) => acc + x.length, 0);
const n_image_features = image_features.dims[0];
if (n_image_tokens !== n_image_features) {
throw new Error(`Image features and image tokens do not match: tokens: ${n_image_tokens}, features ${n_image_features}`);
}

// Equivalent to performing a masked_scatter
let img = 0;
for (let i = 0; i < image_tokens.length; ++i) {
const tokens = image_tokens[i];
const embeds = inputs_embeds[i];
for (let j = 0; j < tokens.length; ++j) {
embeds[tokens[j]].data.set(image_features[img++].data)
}
}
return { inputs_embeds, attention_mask }
}


/**
* Forward pass of an image-text-to-text model.
* @param {Object} self The image-text-to-text model.
Expand Down Expand Up @@ -582,11 +617,15 @@ async function imageTextToTextForward(self, {

if (!inputs_embeds) {
// 1. Extract the input embeddings
console.log('before encode_text');
inputs_embeds = await self.encode_text({ input_ids, ...kwargs });
console.log('after encode_text', inputs_embeds.dims);

// 2. Possibly, merge text and images
if (pixel_values && input_ids.dims[1] !== 1) {
console.log('before encode_image');
const image_features = await self.encode_image({ pixel_values, ...kwargs });
console.log('after encode_image');

({ inputs_embeds, attention_mask } = self._merge_input_ids_with_image_features({
image_features,
Expand Down Expand Up @@ -3304,8 +3343,8 @@ export class VisionEncoderDecoderModel extends PreTrainedModel {
export class LlavaPreTrainedModel extends PreTrainedModel {
forward_params = [
'input_ids',
- 'pixel_values',
'attention_mask',
+ 'pixel_values',
'position_ids',
'past_key_values',
];
Expand Down Expand Up @@ -3487,6 +3526,46 @@ export class Florence2ForConditionalGeneration extends Florence2PreTrainedModel
return decoder_outputs;
}
}


//////////////////////////////////////////////////
// Idefics3 Models
export class Idefics3PreTrainedModel extends PreTrainedModel {
forward_params = [
'input_ids',
'attention_mask',
'pixel_values',
'pixel_attention_mask',
'position_ids',
'past_key_values',
];
}

/**
* The Idefics3 model, which consists of a vision backbone and a language model.
*/
export class Idefics3ForConditionalGeneration extends Idefics3PreTrainedModel {

async encode_image({ pixel_values, pixel_attention_mask }) {
const features = (await sessionRun(this.sessions['vision_encoder'], { pixel_values, pixel_attention_mask })).image_features;
return features;
}

_merge_input_ids_with_image_features(kwargs) {
const vision_hidden_size = kwargs.image_features.dims.at(-1);
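// Flatten the vision encoder output to (total image tokens, hidden size), so that each row
// lines up with one image placeholder token (`image_token_id`) in `input_ids`.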
const reshaped_image_hidden_states = kwargs.image_features.view(-1, vision_hidden_size);

return default_merge_input_ids_with_image_features({
// @ts-ignore
image_token_id: this.config.image_token_id,
...kwargs,
image_features: reshaped_image_hidden_states,
})
}
}
//////////////////////////////////////////////////

//////////////////////////////////////////////////
export class CLIPPreTrainedModel extends PreTrainedModel { }

/**
Expand Down Expand Up @@ -4280,36 +4359,12 @@ export class Qwen2VLForConditionalGeneration extends Qwen2VLPreTrainedModel {
return features;
}

- _merge_input_ids_with_image_features({
- inputs_embeds,
- image_features,
- input_ids,
- attention_mask,
- }) {
- // @ts-ignore
- const { image_token_id } = this.config;
- const image_tokens = input_ids.tolist().map(ids =>
- ids.reduce((acc, x, idx) => {
- if (x == image_token_id) acc.push(idx);
- return acc;
- }, [])
- );
- const n_image_tokens = image_tokens.reduce((acc, x) => acc + x.length, 0);
- const n_image_features = image_features.dims[0];
- if (n_image_tokens !== n_image_features) {
- throw new Error(`Image features and image tokens do not match: tokens: ${n_image_tokens}, features ${n_image_features}`);
- }
-
- // Equivalent to performing a masked_scatter
- let img = 0;
- for (let i = 0; i < image_tokens.length; ++i) {
- const tokens = image_tokens[i];
- const embeds = inputs_embeds[i];
- for (let j = 0; j < tokens.length; ++j) {
- embeds[tokens[j]].data.set(image_features[img++].data)
- }
- }
- return { inputs_embeds, attention_mask }
+ _merge_input_ids_with_image_features(kwargs) {
+ return default_merge_input_ids_with_image_features({
+ // @ts-ignore
+ image_token_id: this.config.image_token_id,
+ ...kwargs
+ })
}

prepare_inputs_for_generation(input_ids, model_inputs, generation_config) {
Expand Down Expand Up @@ -6914,6 +6969,7 @@ const MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = new Map([

const MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = new Map([
['vision-encoder-decoder', ['VisionEncoderDecoderModel', VisionEncoderDecoderModel]],
['idefics3', ['Idefics3ForConditionalGeneration', Idefics3ForConditionalGeneration]],
]);

const MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = new Map([
@@ -6922,6 +6978,7 @@ const MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = new Map([
['moondream1', ['Moondream1ForConditionalGeneration', Moondream1ForConditionalGeneration]],
['florence2', ['Florence2ForConditionalGeneration', Florence2ForConditionalGeneration]],
['qwen2-vl', ['Qwen2VLForConditionalGeneration', Qwen2VLForConditionalGeneration]],
['idefics3', ['Idefics3ForConditionalGeneration', Idefics3ForConditionalGeneration]],
]);

const MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES = new Map([
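
With these mappings registered, the new model should be reachable through the usual transformers.js Auto classes once a converted checkpoint and a matching Idefics3 processor exist. A rough, untested sketch of the intended usage follows; the model id, image URL, prompt format, and processor call signature are assumptions (the Idefics3 processor is not part of this commit), and the decode step follows the pattern of the library's existing vision-language examples.

import { AutoProcessor, AutoModelForVision2Seq, RawImage } from '@huggingface/transformers';

// Hypothetical model id; converted SmolVLM/Idefics3 ONNX weights are not provided by this commit.
const model_id = 'HuggingFaceTB/SmolVLM-Instruct';
const processor = await AutoProcessor.from_pretrained(model_id);
const model = await AutoModelForVision2Seq.from_pretrained(model_id);

// Any image URL works here; this one is purely illustrative.
const image = await RawImage.fromURL('https://example.com/image.jpg');

// Assumed prompt format: one image placeholder per image, to be expanded by the (future) processor.
const inputs = await processor('<image> Describe this image.', image);

const generated_ids = await model.generate({ ...inputs, max_new_tokens: 128 });
const output = processor.batch_decode(generated_ids, { skip_special_tokens: true });
console.log(output[0]);
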
169 changes: 169 additions & 0 deletions src/models/idefics3/image_processing_idefics3.js
@@ -0,0 +1,169 @@


import {
ImageProcessor,
} from "../../base/image_processors_utils.js";
import { cat, full, interpolate_4d } from "../../utils/tensor.js";

export class Idefics3ImageProcessor extends ImageProcessor {
constructor(config) {
super(config);

this.do_image_splitting = config.do_image_splitting ?? true;
this.max_image_size = config.max_image_size;
}

/**
* Calculate the size to resize an image to so that both dimensions are multiples of `vision_encoder_max_size`, while roughly preserving the aspect ratio.
* @param {import('../../utils/tensor.js').Tensor} pixel_values Tensor of the image to resize.
* @param {number} vision_encoder_max_size Maximum edge length accepted by the vision encoder. If the resized image is larger than this size,
* it will later be split into patches of this size, and the original image, resized to this size, will be appended as a global view.
*/
get_resize_for_vision_encoder(pixel_values, vision_encoder_max_size) {
let [height, width] = pixel_values.dims.slice(-2);

const aspect_ratio = width / height;
if (width >= height) {
width = Math.ceil(width / vision_encoder_max_size) * vision_encoder_max_size;
height = Math.floor(width / aspect_ratio);
height = Math.ceil(height / vision_encoder_max_size) * vision_encoder_max_size;
} else {
height = Math.ceil(height / vision_encoder_max_size) * vision_encoder_max_size;
width = Math.floor(height * aspect_ratio);
width = Math.ceil(width / vision_encoder_max_size) * vision_encoder_max_size;
}
return { height, width };
}
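// Illustrative numbers for get_resize_for_vision_encoder (values chosen only for the example):
// a 1000x600 image with vision_encoder_max_size = 364 gives
//   width  = ceil(1000 / 364) * 364 = 1092
//   height = ceil(floor(1092 / (1000 / 600)) / 364) * 364 = ceil(655 / 364) * 364 = 728
// i.e. both edges are rounded up to multiples of 364 while roughly preserving the aspect ratio.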

// /** @param {RawImage|RawImage[]|RawImage[][]} images */
async _call(images, {
do_image_splitting = null,
return_row_col_info = false,
} = {}) {
// TODO: support 2D RawImages
if (!Array.isArray(images)) {
images = [images];
}

let images_list = await Promise.all(images.map(x => this.preprocess(x)));

// Original sizes of images
const original_sizes = images_list.map(x => x.original_size);

// Reshaped sizes of images, before padding or cropping
const reshaped_input_sizes = images_list.map(x => x.reshaped_input_size);

// Convert images to 4D tensors for easier processing
images_list.forEach(x => x.pixel_values.unsqueeze_(0));

let pixel_values;
let images_list_rows = [];
let images_list_cols = [];

const { longest_edge } = this.max_image_size;

if (do_image_splitting ?? this.do_image_splitting) {
let image_rows = new Array(images_list.length);
let image_cols = new Array(images_list.length);

// We first resize both height and width of each image to the nearest max_image_size multiple, disregarding the aspect ratio
images_list = await Promise.all(images_list.map(async (x, i) => {
const new_size = this.get_resize_for_vision_encoder(x.pixel_values, longest_edge);

const resized = await interpolate_4d(x.pixel_values, {
size: [new_size.height, new_size.width],
});

const { frames, num_splits_h, num_splits_w } = await this.split_image(resized, this.max_image_size);
image_rows[i] = num_splits_h;
image_cols[i] = num_splits_w;
return cat(frames, 0);
}));

images_list_rows.push(image_rows);
images_list_cols.push(image_cols);
} else {
/** @type {[number, number]} */
const size = [longest_edge, longest_edge];
images_list = await Promise.all(
images_list.map(x => interpolate_4d(x.pixel_values, { size }))
);

images_list_rows.push(new Array(images_list.length).fill(0));
images_list_cols.push(new Array(images_list.length).fill(0));
}

// Stack pixel values
// TODO: support 2D image inputs
pixel_values = cat(images_list, 0);
pixel_values.unsqueeze_(0);

// TODO: Improve pixel_attention_mask
const [b, n, c, h, w] = pixel_values.dims;
const pixel_attention_mask = full([b, n, h, w], true);

return {
pixel_values,
pixel_attention_mask,

original_sizes,
reshaped_input_sizes,
...(
return_row_col_info
? { rows: images_list_rows, cols: images_list_cols }
: {}
),
}
}

async split_image(pixel_values, { longest_edge }) {
const max_height = longest_edge;
const max_width = longest_edge;

const frames = [];

const [height, width] = pixel_values.dims.slice(-2);

let num_splits_h = 0, num_splits_w = 0;

if (height > max_height || width > max_width) {
// Calculate the number of splits
num_splits_h = Math.ceil(height / max_height);
num_splits_w = Math.ceil(width / max_width);

// Calculate the optimal width and height for the sub-images
const optimal_height = Math.ceil(height / num_splits_h);
const optimal_width = Math.ceil(width / num_splits_w);

// Iterate through each row and column
for (let r = 0; r < num_splits_h; r++) {
for (let c = 0; c < num_splits_w; c++) {
// Calculate the starting point of the crop
const start_x = c * optimal_width;
const start_y = r * optimal_height;

// Calculate the ending point of the crop
const end_x = Math.min(start_x + optimal_width, width);
const end_y = Math.min(start_y + optimal_height, height);

// Crop the image
frames.push(pixel_values.slice(null, null, [start_y, end_y], [start_x, end_x]));
}
}

// Resize the global image to match max dimensions for memory efficiency
const global_image_height = max_height;
const global_image_width = max_width;

if (height !== global_image_height || width !== global_image_width) {
pixel_values = await interpolate_4d(pixel_values, {
size: [global_image_height, global_image_width],
})
}
}

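// The full image is always appended as the final frame; when it was split above, it has been
// resized to the max size to serve as a global view.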
frames.push(pixel_values);

return { frames, num_splits_h, num_splits_w };
}
}
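
Putting the two steps together, here is a sketch of the shapes this processor should produce for a single 3-channel image. The numbers are illustrative only: max_image_size.longest_edge = 364 is assumed for the example, while the real value comes from the checkpoint's preprocessor_config.json.

// Illustrative only: a 1000x600 RGB input with max_image_size.longest_edge = 364.
const resized = { height: 728, width: 1092 };          // from get_resize_for_vision_encoder
const num_splits_h = Math.ceil(resized.height / 364);  // 2
const num_splits_w = Math.ceil(resized.width / 364);   // 3
const num_frames = num_splits_h * num_splits_w + 1;    // 6 crops + 1 resized global image = 7
// Expected outputs of _call:
//   pixel_values:         [1, 7, 3, 364, 364]
//   pixel_attention_mask: [1, 7, 364, 364]
//   rows / cols (with return_row_col_info): [[2]] / [[3]]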