Add support for Segment Anything Model (#510)
* Update SamModel

* Make `AutoModel.from_pretrained` work with SamModel

* Add listed support for SAM (Segment Anything Model)

* Update types of `calculateDimensions`

* Throw error if reading image from tensor with dims.length != 3

* Make SamProcessor input points optional

* Fix type errors

* `let` -> `const`

* `cat` -> `stack`

* Expose `reshape_input_points` in `SamProcessor`

* Add `input_labels` input parameter for SAM

* Add `input_labels` to sam processor

* Update SAM unit tests

* Remove TODOs

* Update JSDoc
xenova authored Jan 10, 2024
1 parent 4d1d4d3 commit cdcbfc1
Showing 8 changed files with 275 additions and 64 deletions.
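
To orient the reader before the per-file diffs: the sketch below combines the `Xenova/sam-vit-base` example from the new JSDoc in `src/models.js` (further down) with the `input_labels` input added by this commit. It is a minimal, illustrative sketch rather than an official recipe: the second prompt point and its label are made up, the label tensor's shape is derived the same way `SamModel.forward` derives it when labels are omitted (assuming both points survive preprocessing, so exactly two labels are needed), and `Tensor` is assumed to be importable from the package entry point.

```javascript
import { SamModel, AutoProcessor, RawImage, Tensor } from '@xenova/transformers';

const model = await SamModel.from_pretrained('Xenova/sam-vit-base');
const processor = await AutoProcessor.from_pretrained('Xenova/sam-vit-base');

const img_url = 'https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png';
const raw_image = await RawImage.read(img_url);

// Two prompt points: the first is intended to lie on the object of interest,
// the second is not (both coordinates are illustrative).
const input_points = [[[450, 600], [250, 100]]];
const inputs = await processor(raw_image, input_points);

// Label the points explicitly (1 = part of the object, 0 = not part of the object;
// see the SamModelInputs docs added in src/models.js). The shape is derived exactly
// as SamModel.forward derives it when labels are omitted.
const input_labels = new Tensor(
    'int64',
    new BigInt64Array([1n, 0n]),
    inputs.input_points.dims.slice(0, -1),
);

const outputs = await model({ ...inputs, input_labels });
const masks = await processor.post_process_masks(
    outputs.pred_masks, inputs.original_sizes, inputs.reshaped_input_sizes,
);
console.log(masks, outputs.iou_scores);
```
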
1 change: 1 addition & 0 deletions README.md
@@ -328,6 +328,7 @@ You can refine your search by selecting the task you're interested in (e.g., [te
1. **[RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta)** (from Facebook), released together with the paper [RoBERTa: A Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov.
1. **[RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer)** (from ZhuiyiTechnology), released together with the paper [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/abs/2104.09864) by Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu.
1. **[SegFormer](https://huggingface.co/docs/transformers/model_doc/segformer)** (from NVIDIA) released with the paper [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) by Enze Xie, Wenhai Wang, Zhiding Yu, Anima Anandkumar, Jose M. Alvarez, Ping Luo.
1. **[Segment Anything](https://huggingface.co/docs/transformers/model_doc/sam)** (from Meta AI) released with the paper [Segment Anything](https://arxiv.org/pdf/2304.02643v1.pdf) by Alexander Kirillov, Eric Mintun, Nikhila Ravi, Hanzi Mao, Chloe Rolland, Laura Gustafson, Tete Xiao, Spencer Whitehead, Alex Berg, Wan-Yen Lo, Piotr Dollar, Ross Girshick.
1. **[SigLIP](https://huggingface.co/docs/transformers/main/model_doc/siglip)** (from Google AI) released with the paper [Sigmoid Loss for Language Image Pre-Training](https://arxiv.org/abs/2303.15343) by Xiaohua Zhai, Basil Mustafa, Alexander Kolesnikov, Lucas Beyer.
1. **[SpeechT5](https://huggingface.co/docs/transformers/model_doc/speecht5)** (from Microsoft Research) released with the paper [SpeechT5: Unified-Modal Encoder-Decoder Pre-Training for Spoken Language Processing](https://arxiv.org/abs/2110.07205) by Junyi Ao, Rui Wang, Long Zhou, Chengyi Wang, Shuo Ren, Yu Wu, Shujie Liu, Tom Ko, Qing Li, Yu Zhang, Zhihua Wei, Yao Qian, Jinyu Li, Furu Wei.
1. **[SqueezeBERT](https://huggingface.co/docs/transformers/model_doc/squeezebert)** (from Berkeley) released with the paper [SqueezeBERT: What can computer vision teach NLP about efficient neural networks?](https://arxiv.org/abs/2006.11316) by Forrest N. Iandola, Albert E. Shaw, Ravi Krishna, and Kurt W. Keutzer.
1 change: 1 addition & 0 deletions docs/snippets/6_supported-models.snippet
@@ -63,6 +63,7 @@
1. **[RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta)** (from Facebook), released together with the paper [RoBERTa: A Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov.
1. **[RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer)** (from ZhuiyiTechnology), released together with the paper [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/abs/2104.09864) by Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu.
1. **[SegFormer](https://huggingface.co/docs/transformers/model_doc/segformer)** (from NVIDIA) released with the paper [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) by Enze Xie, Wenhai Wang, Zhiding Yu, Anima Anandkumar, Jose M. Alvarez, Ping Luo.
1. **[Segment Anything](https://huggingface.co/docs/transformers/model_doc/sam)** (from Meta AI) released with the paper [Segment Anything](https://arxiv.org/pdf/2304.02643v1.pdf) by Alexander Kirillov, Eric Mintun, Nikhila Ravi, Hanzi Mao, Chloe Rolland, Laura Gustafson, Tete Xiao, Spencer Whitehead, Alex Berg, Wan-Yen Lo, Piotr Dollar, Ross Girshick.
1. **[SigLIP](https://huggingface.co/docs/transformers/main/model_doc/siglip)** (from Google AI) released with the paper [Sigmoid Loss for Language Image Pre-Training](https://arxiv.org/abs/2303.15343) by Xiaohua Zhai, Basil Mustafa, Alexander Kolesnikov, Lucas Beyer.
1. **[SpeechT5](https://huggingface.co/docs/transformers/model_doc/speecht5)** (from Microsoft Research) released with the paper [SpeechT5: Unified-Modal Encoder-Decoder Pre-Training for Spoken Language Processing](https://arxiv.org/abs/2110.07205) by Junyi Ao, Rui Wang, Long Zhou, Chengyi Wang, Shuo Ren, Yu Wu, Shujie Liu, Tom Ko, Qing Li, Yu Zhang, Zhihua Wei, Yao Qian, Jinyu Li, Furu Wei.
1. **[SqueezeBERT](https://huggingface.co/docs/transformers/model_doc/squeezebert)** (from Berkeley) released with the paper [SqueezeBERT: What can computer vision teach NLP about efficient neural networks?](https://arxiv.org/abs/2006.11316) by Forrest N. Iandola, Albert E. Shaw, Ravi Krishna, and Kurt W. Keutzer.
19 changes: 14 additions & 5 deletions scripts/supported_models.py
@@ -745,11 +745,20 @@
'distilroberta-base',
],
},
# 'sam': [
# 'facebook/sam-vit-base',
# 'facebook/sam-vit-large',
# 'facebook/sam-vit-huge',
# ],
'sam': {
# Mask generation
'mask-generation': [
# SAM
'facebook/sam-vit-base',
'facebook/sam-vit-large',
'facebook/sam-vit-huge',
'wanglab/medsam-vit-base',

# SlimSAM
'nielsr/slimsam-50-uniform',
'nielsr/slimsam-77-uniform',
],
},
'segformer': {
# Image segmentation
'image-segmentation': [
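
Since this commit also makes `AutoModel.from_pretrained` work with `SamModel` (see the commit message above and the `MODEL_CLASS_MAPPINGS` change in `src/models.js` below), checkpoints like those listed here can be resolved through the generic auto class. A minimal sketch, assuming a web-ready conversion such as `Xenova/sam-vit-base` (the checkpoint used in the JSDoc below); converted ONNX weights are not guaranteed to exist for every ID in the list above.

```javascript
import { AutoModel, AutoProcessor } from '@xenova/transformers';

// AutoModel now resolves the 'sam' architecture to SamModel
// (see the MODEL_CLASS_MAPPINGS change in src/models.js below).
const model = await AutoModel.from_pretrained('Xenova/sam-vit-base');
const processor = await AutoProcessor.from_pretrained('Xenova/sam-vit-base');

console.log(model.constructor.name); // expected: 'SamModel'
```
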
143 changes: 135 additions & 8 deletions src/models.js
@@ -95,6 +95,7 @@ const MODEL_TYPES = {
Seq2Seq: 2,
Vision2Seq: 3,
DecoderOnly: 4,
MaskGeneration: 5,
}
//////////////////////////////////////////////////

@@ -771,6 +772,13 @@ export class PreTrainedModel extends Callable {
getModelJSON(pretrained_model_name_or_path, 'generation_config.json', false, options),
]);

} else if (modelType === MODEL_TYPES.MaskGeneration) {
info = await Promise.all([
AutoConfig.from_pretrained(pretrained_model_name_or_path, options),
constructSession(pretrained_model_name_or_path, 'vision_encoder', options),
constructSession(pretrained_model_name_or_path, 'prompt_encoder_mask_decoder', options),
]);

} else if (modelType === MODEL_TYPES.EncoderDecoder) {
info = await Promise.all([
AutoConfig.from_pretrained(pretrained_model_name_or_path, options),
@@ -4242,12 +4250,130 @@ export class YolosObjectDetectionOutput extends ModelOutput {

//////////////////////////////////////////////////
export class SamPreTrainedModel extends PreTrainedModel { }

/**
* Segment Anything Model (SAM) for generating segmentation masks, given an input image
* and optional 2D input points and bounding boxes.
*
* **Example:** Perform mask generation w/ `Xenova/sam-vit-base`.
* ```javascript
* import { SamModel, AutoProcessor, RawImage } from '@xenova/transformers';
*
* const model = await SamModel.from_pretrained('Xenova/sam-vit-base');
* const processor = await AutoProcessor.from_pretrained('Xenova/sam-vit-base');
*
* const img_url = 'https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png';
* const raw_image = await RawImage.read(img_url);
* const input_points = [[[450, 600]]] // 2D localization of a window
*
* const inputs = await processor(raw_image, input_points);
* const outputs = await model(inputs);
*
* const masks = await processor.post_process_masks(outputs.pred_masks, inputs.original_sizes, inputs.reshaped_input_sizes);
* // [
* // Tensor {
* // dims: [ 1, 3, 1764, 2646 ],
* // type: 'bool',
* // data: Uint8Array(14002632) [ ... ],
* // size: 14002632
* // }
* // ]
* const scores = outputs.iou_scores;
* // Tensor {
* // dims: [ 1, 1, 3 ],
* // type: 'float32',
* // data: Float32Array(3) [
* // 0.8892380595207214,
* // 0.9311248064041138,
* // 0.983696699142456
* // ],
* // size: 3
* // }
* ```
*/
export class SamModel extends SamPreTrainedModel {
/**
* @param {Object} model_inputs
* @param {Tensor} model_inputs.pixel_values Pixel values as a Tensor with shape `(batch_size, num_channels, height, width)`.
* @param {Tensor} model_inputs.input_points Input 2D spatial points with shape `(batch_size, num_points, 2)`. This is used by the prompt encoder to encode the prompt.
* @todo Add support for `input_labels`, `input_boxes`, `input_masks`, and `image_embeddings`.
* Creates a new instance of the `SamModel` class.
* @param {Object} config The configuration object specifying the hyperparameters and other model settings.
* @param {Object} vision_encoder The ONNX session containing the vision encoder model.
* @param {any} prompt_encoder_mask_decoder The ONNX session containing the prompt encoder and mask decoder model.
*/
constructor(config, vision_encoder, prompt_encoder_mask_decoder) {
super(config, vision_encoder);
this.prompt_encoder_mask_decoder = prompt_encoder_mask_decoder;
}

/**
* Compute image embeddings and positional image embeddings, given the pixel values of an image.
* @param {Object} model_inputs Object containing the model inputs.
* @param {Tensor} model_inputs.pixel_values Pixel values obtained using a `SamProcessor`.
* @returns {Promise<{ image_embeddings: Tensor, image_positional_embeddings: Tensor }>} The image embeddings and positional image embeddings.
*/
async get_image_embeddings({ pixel_values }) {
// in:
// - pixel_values: tensor.float32[batch_size,3,1024,1024]
//
// out:
// - image_embeddings: tensor.float32[batch_size,256,64,64]
// - image_positional_embeddings: tensor.float32[batch_size,256,64,64]
return await encoderForward(this, { pixel_values })
}

/**
* @typedef {Object} SamModelInputs Object containing the model inputs.
* @property {Tensor} pixel_values Pixel values as a Tensor with shape `(batch_size, num_channels, height, width)`.
* These can be obtained using a `SamProcessor`.
* @property {Tensor} input_points Input 2D spatial points with shape `(batch_size, point_batch_size, num_points, 2)`.
* This is used by the prompt encoder to encode the prompt.
* @property {Tensor} [input_labels] Input labels for the points, as a Tensor of shape `(batch_size, point_batch_size, num_points)`.
* This is used by the prompt encoder to encode the prompt. There are 4 types of labels:
* - `1`: the point is a point that contains the object of interest
* - `0`: the point is a point that does not contain the object of interest
* - `-1`: the point corresponds to the background
* - `-10`: the point is a padding point, thus should be ignored by the prompt encoder
* @property {Tensor} [image_embeddings] Image embeddings used by the mask decoder.
* @property {Tensor} [image_positional_embeddings] Image positional embeddings used by the mask decoder.
*/

/**
* @param {SamModelInputs} model_inputs Object containing the model inputs.
* @returns {Promise<Object>} The output of the model.
*/
async forward(model_inputs) {
if (!model_inputs.image_embeddings || !model_inputs.image_positional_embeddings) {
// Compute the image embeddings if they are missing
model_inputs = {
...model_inputs,
...(await this.get_image_embeddings(model_inputs))
}
}

if (!model_inputs.input_labels) {
// Set default input labels if they are missing
const shape = model_inputs.input_points.dims.slice(0, -1);
const numElements = shape.reduce((a, b) => a * b, 1);
model_inputs.input_labels = new Tensor(
'int64',
new BigInt64Array(numElements).fill(1n),
shape
);
}

// Returns:
// - iou_scores: tensor.float32[batch_size,point_batch_size,3]
// - pred_masks: tensor.float32[batch_size,point_batch_size,3,256,256]
return await sessionRun(this.prompt_encoder_mask_decoder, {
input_points: model_inputs.input_points,
input_labels: model_inputs.input_labels,
image_embeddings: model_inputs.image_embeddings,
image_positional_embeddings: model_inputs.image_positional_embeddings,
});
}

/**
* Runs the model with the provided inputs
* @param {Object} model_inputs Model inputs
* @returns {Promise<SamImageSegmentationOutput>} Object containing segmentation outputs
*/
async _call(model_inputs) {
return new SamImageSegmentationOutput(await super._call(model_inputs));
@@ -5049,7 +5175,6 @@ const MODEL_MAPPING_NAMES_ENCODER_ONLY = new Map([

['hifigan', ['SpeechT5HifiGan', SpeechT5HifiGan]],

['sam', ['SamModel', SamModel]], // TODO change to encoder-decoder when model is split correctly
]);

const MODEL_MAPPING_NAMES_ENCODER_DECODER = new Map([
@@ -5290,7 +5415,7 @@ const MODEL_CLASS_TYPE_MAPPING = [
[MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
[MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
[MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
[MODEL_FOR_MASK_GENERATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
[MODEL_FOR_MASK_GENERATION_MAPPING_NAMES, MODEL_TYPES.MaskGeneration],
[MODEL_FOR_CTC_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
[MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
[MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES, MODEL_TYPES.Seq2Seq],
@@ -5329,7 +5454,9 @@ for (const [name, model, type] of CUSTOM_MAPPING) {
* let model = await AutoModel.from_pretrained('bert-base-uncased');
*/
export class AutoModel extends PretrainedMixin {
static MODEL_CLASS_MAPPINGS = [MODEL_MAPPING_NAMES_ENCODER_ONLY, MODEL_MAPPING_NAMES_ENCODER_DECODER, MODEL_MAPPING_NAMES_DECODER_ONLY];
/** @type {Map<string, Object>[]} */
// @ts-ignore
static MODEL_CLASS_MAPPINGS = MODEL_CLASS_TYPE_MAPPING.map(x => x[0]);
static BASE_IF_FAIL = true;
}

@@ -5493,7 +5620,7 @@ export class AutoModelForZeroShotObjectDetection extends PretrainedMixin {


/**
* Helper class which is used to instantiate pretrained object detection models with the `from_pretrained` function.
* Helper class which is used to instantiate pretrained mask generation models with the `from_pretrained` function.
* The chosen model class is determined by the type specified in the model config.
*
* @example
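
Because SAM is now loaded as two ONNX sessions (`vision_encoder` and `prompt_encoder_mask_decoder`, via the new `MODEL_TYPES.MaskGeneration` path above), the expensive image encoding can be computed once with `get_image_embeddings` and reused across prompts, since `forward` skips re-encoding when `image_embeddings` are already present. A hedged sketch of that pattern; the checkpoint and prompt coordinates are illustrative, and calling the processor once per prompt is assumed to be an acceptable way to reshape the points.

```javascript
import { SamModel, AutoProcessor, RawImage } from '@xenova/transformers';

const model = await SamModel.from_pretrained('Xenova/sam-vit-base');
const processor = await AutoProcessor.from_pretrained('Xenova/sam-vit-base');

const raw_image = await RawImage.read(
    'https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png'
);

// Encode the image once (input points are now optional for the processor).
const { pixel_values } = await processor(raw_image);
const embeddings = await model.get_image_embeddings({ pixel_values });

// Decode masks for several prompts without re-running the vision encoder:
// forward() skips get_image_embeddings when image_embeddings are provided.
// The processor is still called per prompt to reshape the points; re-preprocessing
// the image here is redundant but harmless for a sketch.
for (const point of [[450, 600], [1350, 650]]) {
    const inputs = await processor(raw_image, [[point]]);
    const outputs = await model({ ...inputs, ...embeddings });
    const masks = await processor.post_process_masks(
        outputs.pred_masks, inputs.original_sizes, inputs.reshaped_input_sizes,
    );
    console.log(masks[0].dims, outputs.iou_scores.data);
}
```
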