Skip to content

Commit

Permalink
Add depth-estimation w/ DPT and GLPN (#389)
Browse files Browse the repository at this point in the history
* Add `size` getter to `RawImage`

* Add `DPTFeatureExtractor`

* Add depth-estimation w/ DPT models

* Add GLPN models for depth estimation

* Add missing import in example

* Add `DPTFeatureExtractor` processor test

* Add unit test for GLPN processor

* Add support for `GLPNFeatureExtractor`

Uses `size_divisor` to determine resize width and height

* Add `GLPNForDepthEstimation` example code

* Add DPT to list of supported models

* Add GLPN to list of supported models

* Add `DepthEstimationPipeline`

* Add listed support for depth estimation pipeline

* Add depth estimation pipeline unit tests

* Fix formatting

* Update `pipeline` JSDoc

* Fix typo from merge
  • Loading branch information
xenova authored Nov 20, 2023
1 parent 5ddc472 commit 83dfa47
Show file tree
Hide file tree
Showing 10 changed files with 293 additions and 5 deletions.
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ You can refine your search by selecting the task you're interested in (e.g., [te

| Task | ID | Description | Supported? |
|--------------------------|----|-------------|------------|
| [Depth Estimation](https://huggingface.co/tasks/depth-estimation) | `depth-estimation` | Predicting the depth of objects present in an image. | |
| [Depth Estimation](https://huggingface.co/tasks/depth-estimation) | `depth-estimation` | Predicting the depth of objects present in an image. | [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.DepthEstimationPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=depth-estimation&library=transformers.js) |
| [Image Classification](https://huggingface.co/tasks/image-classification) | `image-classification` | Assigning a label or class to an entire image. |[(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.ImageClassificationPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=image-classification&library=transformers.js) |
| [Image Segmentation](https://huggingface.co/tasks/image-segmentation) | `image-segmentation` | Divides an image into segments where each pixel is mapped to an object. This task has multiple variants such as instance segmentation, panoptic segmentation and semantic segmentation. |[(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.ImageSegmentationPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=image-segmentation&library=transformers.js) |
| [Image-to-Image](https://huggingface.co/tasks/image-to-image) | `image-to-image` | Transforming a source image to match the characteristics of a target image or a target image domain. |[(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.ImageToImagePipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=image-to-image&library=transformers.js) |
Expand Down Expand Up @@ -277,8 +277,10 @@ You can refine your search by selecting the task you're interested in (e.g., [te
1. **[DETR](https://huggingface.co/docs/transformers/model_doc/detr)** (from Facebook) released with the paper [End-to-End Object Detection with Transformers](https://arxiv.org/abs/2005.12872) by Nicolas Carion, Francisco Massa, Gabriel Synnaeve, Nicolas Usunier, Alexander Kirillov, Sergey Zagoruyko.
1. **[DistilBERT](https://huggingface.co/docs/transformers/model_doc/distilbert)** (from HuggingFace), released together with the paper [DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter](https://arxiv.org/abs/1910.01108) by Victor Sanh, Lysandre Debut and Thomas Wolf. The same method has been applied to compress GPT2 into [DistilGPT2](https://github.com/huggingface/transformers/tree/main/examples/research_projects/distillation), RoBERTa into [DistilRoBERTa](https://github.com/huggingface/transformers/tree/main/examples/research_projects/distillation), Multilingual BERT into [DistilmBERT](https://github.com/huggingface/transformers/tree/main/examples/research_projects/distillation) and a German version of DistilBERT.
1. **[Donut](https://huggingface.co/docs/transformers/model_doc/donut)** (from NAVER), released together with the paper [OCR-free Document Understanding Transformer](https://arxiv.org/abs/2111.15664) by Geewook Kim, Teakgyu Hong, Moonbin Yim, Jeongyeon Nam, Jinyoung Park, Jinyeong Yim, Wonseok Hwang, Sangdoo Yun, Dongyoon Han, Seunghyun Park.
1. **[DPT](https://huggingface.co/docs/transformers/master/model_doc/dpt)** (from Intel Labs) released with the paper [Vision Transformers for Dense Prediction](https://arxiv.org/abs/2103.13413) by René Ranftl, Alexey Bochkovskiy, Vladlen Koltun.
1. **[Falcon](https://huggingface.co/docs/transformers/model_doc/falcon)** (from Technology Innovation Institute) by Almazrouei, Ebtesam and Alobeidli, Hamza and Alshamsi, Abdulaziz and Cappelli, Alessandro and Cojocaru, Ruxandra and Debbah, Merouane and Goffinet, Etienne and Heslow, Daniel and Launay, Julien and Malartic, Quentin and Noune, Badreddine and Pannier, Baptiste and Penedo, Guilherme.
1. **[FLAN-T5](https://huggingface.co/docs/transformers/model_doc/flan-t5)** (from Google AI) released in the repository [google-research/t5x](https://github.com/google-research/t5x/blob/main/docs/models.md#flan-t5-checkpoints) by Hyung Won Chung, Le Hou, Shayne Longpre, Barret Zoph, Yi Tay, William Fedus, Eric Li, Xuezhi Wang, Mostafa Dehghani, Siddhartha Brahma, Albert Webson, Shixiang Shane Gu, Zhuyun Dai, Mirac Suzgun, Xinyun Chen, Aakanksha Chowdhery, Sharan Narang, Gaurav Mishra, Adams Yu, Vincent Zhao, Yanping Huang, Andrew Dai, Hongkun Yu, Slav Petrov, Ed H. Chi, Jeff Dean, Jacob Devlin, Adam Roberts, Denny Zhou, Quoc V. Le, and Jason Wei
1. **[GLPN](https://huggingface.co/docs/transformers/model_doc/glpn)** (from KAIST) released with the paper [Global-Local Path Networks for Monocular Depth Estimation with Vertical CutDepth](https://arxiv.org/abs/2201.07436) by Doyeon Kim, Woonghyun Ga, Pyungwhan Ahn, Donggyu Joo, Sehwan Chun, Junmo Kim.
1. **[GPT Neo](https://huggingface.co/docs/transformers/model_doc/gpt_neo)** (from EleutherAI) released in the repository [EleutherAI/gpt-neo](https://github.com/EleutherAI/gpt-neo) by Sid Black, Stella Biderman, Leo Gao, Phil Wang and Connor Leahy.
1. **[GPT NeoX](https://huggingface.co/docs/transformers/model_doc/gpt_neox)** (from EleutherAI) released with the paper [GPT-NeoX-20B: An Open-Source Autoregressive Language Model](https://arxiv.org/abs/2204.06745) by Sid Black, Stella Biderman, Eric Hallahan, Quentin Anthony, Leo Gao, Laurence Golding, Horace He, Connor Leahy, Kyle McDonell, Jason Phang, Michael Pieler, USVSN Sai Prashanth, Shivanshu Purohit, Laria Reynolds, Jonathan Tow, Ben Wang, Samuel Weinbach
1. **[GPT-2](https://huggingface.co/docs/transformers/model_doc/gpt2)** (from OpenAI) released with the paper [Language Models are Unsupervised Multitask Learners](https://blog.openai.com/better-language-models/) by Alec Radford*, Jeffrey Wu*, Rewon Child, David Luan, Dario Amodei** and Ilya Sutskever**.
Expand Down
2 changes: 1 addition & 1 deletion docs/snippets/5_supported-tasks.snippet
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

| Task | ID | Description | Supported? |
|--------------------------|----|-------------|------------|
| [Depth Estimation](https://huggingface.co/tasks/depth-estimation) | `depth-estimation` | Predicting the depth of objects present in an image. | |
| [Depth Estimation](https://huggingface.co/tasks/depth-estimation) | `depth-estimation` | Predicting the depth of objects present in an image. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.DepthEstimationPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=depth-estimation&library=transformers.js) |
| [Image Classification](https://huggingface.co/tasks/image-classification) | `image-classification` | Assigning a label or class to an entire image. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.ImageClassificationPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=image-classification&library=transformers.js) |
| [Image Segmentation](https://huggingface.co/tasks/image-segmentation) | `image-segmentation` | Divides an image into segments where each pixel is mapped to an object. This task has multiple variants such as instance segmentation, panoptic segmentation and semantic segmentation. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.ImageSegmentationPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=image-segmentation&library=transformers.js) |
| [Image-to-Image](https://huggingface.co/tasks/image-to-image) | `image-to-image` | Transforming a source image to match the characteristics of a target image or a target image domain. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.ImageToImagePipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=image-to-image&library=transformers.js) |
Expand Down
2 changes: 2 additions & 0 deletions docs/snippets/6_supported-models.snippet
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@
1. **[DETR](https://huggingface.co/docs/transformers/model_doc/detr)** (from Facebook) released with the paper [End-to-End Object Detection with Transformers](https://arxiv.org/abs/2005.12872) by Nicolas Carion, Francisco Massa, Gabriel Synnaeve, Nicolas Usunier, Alexander Kirillov, Sergey Zagoruyko.
1. **[DistilBERT](https://huggingface.co/docs/transformers/model_doc/distilbert)** (from HuggingFace), released together with the paper [DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter](https://arxiv.org/abs/1910.01108) by Victor Sanh, Lysandre Debut and Thomas Wolf. The same method has been applied to compress GPT2 into [DistilGPT2](https://github.com/huggingface/transformers/tree/main/examples/research_projects/distillation), RoBERTa into [DistilRoBERTa](https://github.com/huggingface/transformers/tree/main/examples/research_projects/distillation), Multilingual BERT into [DistilmBERT](https://github.com/huggingface/transformers/tree/main/examples/research_projects/distillation) and a German version of DistilBERT.
1. **[Donut](https://huggingface.co/docs/transformers/model_doc/donut)** (from NAVER), released together with the paper [OCR-free Document Understanding Transformer](https://arxiv.org/abs/2111.15664) by Geewook Kim, Teakgyu Hong, Moonbin Yim, Jeongyeon Nam, Jinyoung Park, Jinyeong Yim, Wonseok Hwang, Sangdoo Yun, Dongyoon Han, Seunghyun Park.
1. **[DPT](https://huggingface.co/docs/transformers/master/model_doc/dpt)** (from Intel Labs) released with the paper [Vision Transformers for Dense Prediction](https://arxiv.org/abs/2103.13413) by René Ranftl, Alexey Bochkovskiy, Vladlen Koltun.
1. **[Falcon](https://huggingface.co/docs/transformers/model_doc/falcon)** (from Technology Innovation Institute) by Almazrouei, Ebtesam and Alobeidli, Hamza and Alshamsi, Abdulaziz and Cappelli, Alessandro and Cojocaru, Ruxandra and Debbah, Merouane and Goffinet, Etienne and Heslow, Daniel and Launay, Julien and Malartic, Quentin and Noune, Badreddine and Pannier, Baptiste and Penedo, Guilherme.
1. **[FLAN-T5](https://huggingface.co/docs/transformers/model_doc/flan-t5)** (from Google AI) released in the repository [google-research/t5x](https://github.com/google-research/t5x/blob/main/docs/models.md#flan-t5-checkpoints) by Hyung Won Chung, Le Hou, Shayne Longpre, Barret Zoph, Yi Tay, William Fedus, Eric Li, Xuezhi Wang, Mostafa Dehghani, Siddhartha Brahma, Albert Webson, Shixiang Shane Gu, Zhuyun Dai, Mirac Suzgun, Xinyun Chen, Aakanksha Chowdhery, Sharan Narang, Gaurav Mishra, Adams Yu, Vincent Zhao, Yanping Huang, Andrew Dai, Hongkun Yu, Slav Petrov, Ed H. Chi, Jeff Dean, Jacob Devlin, Adam Roberts, Denny Zhou, Quoc V. Le, and Jason Wei
1. **[GLPN](https://huggingface.co/docs/transformers/model_doc/glpn)** (from KAIST) released with the paper [Global-Local Path Networks for Monocular Depth Estimation with Vertical CutDepth](https://arxiv.org/abs/2201.07436) by Doyeon Kim, Woonghyun Ga, Pyungwhan Ahn, Donggyu Joo, Sehwan Chun, Junmo Kim.
1. **[GPT Neo](https://huggingface.co/docs/transformers/model_doc/gpt_neo)** (from EleutherAI) released in the repository [EleutherAI/gpt-neo](https://github.com/EleutherAI/gpt-neo) by Sid Black, Stella Biderman, Leo Gao, Phil Wang and Connor Leahy.
1. **[GPT NeoX](https://huggingface.co/docs/transformers/model_doc/gpt_neox)** (from EleutherAI) released with the paper [GPT-NeoX-20B: An Open-Source Autoregressive Language Model](https://arxiv.org/abs/2204.06745) by Sid Black, Stella Biderman, Eric Hallahan, Quentin Anthony, Leo Gao, Laurence Golding, Horace He, Connor Leahy, Kyle McDonell, Jason Phang, Michael Pieler, USVSN Sai Prashanth, Shivanshu Purohit, Laria Reynolds, Jonathan Tow, Ben Wang, Samuel Weinbach
1. **[GPT-2](https://huggingface.co/docs/transformers/model_doc/gpt2)** (from OpenAI) released with the paper [Language Models are Unsupervised Multitask Learners](https://blog.openai.com/better-language-models/) by Alec Radford*, Jeffrey Wu*, Rewon Child, David Luan, Dario Amodei** and Ilya Sutskever**.
Expand Down
10 changes: 10 additions & 0 deletions scripts/supported_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,11 +208,21 @@
# Document Question Answering
'naver-clova-ix/donut-base-finetuned-docvqa',
],
'dpt': [
# Depth estimation
'Intel/dpt-hybrid-midas',
'Intel/dpt-large',
],
'falcon': [
# Text generation
'Rocketknight1/tiny-random-falcon-7b',
'fxmarty/really-tiny-falcon-testing',
],
'glpn': [
# Depth estimation
'vinvino02/glpn-kitti',
'vinvino02/glpn-nyu',
],
'gpt_neo': [
# Text generation
'EleutherAI/gpt-neo-125M',
Expand Down
106 changes: 106 additions & 0 deletions src/models.js
Original file line number Diff line number Diff line change
Expand Up @@ -3371,6 +3371,100 @@ export class Swin2SRModel extends Swin2SRPreTrainedModel { }
export class Swin2SRForImageSuperResolution extends Swin2SRPreTrainedModel { }
//////////////////////////////////////////////////

//////////////////////////////////////////////////
export class DPTPreTrainedModel extends PreTrainedModel { }

/**
* The bare DPT Model transformer outputting raw hidden-states without any specific head on top.
*/
export class DPTModel extends DPTPreTrainedModel { }

/**
* DPT Model with a depth estimation head on top (consisting of 3 convolutional layers) e.g. for KITTI, NYUv2.
*
* **Example:** Depth estimation w/ `Xenova/dpt-hybrid-midas`.
* ```javascript
* import { DPTForDepthEstimation, AutoProcessor, RawImage, interpolate, max } from '@xenova/transformers';
*
* // Load model and processor
* const model_id = 'Xenova/dpt-hybrid-midas';
* const model = await DPTForDepthEstimation.from_pretrained(model_id);
* const processor = await AutoProcessor.from_pretrained(model_id);
*
* // Load image from URL
* const url = 'http://images.cocodataset.org/val2017/000000039769.jpg';
* const image = await RawImage.fromURL(url);
*
* // Prepare image for the model
* const inputs = await processor(image);
*
* // Run model
* const { predicted_depth } = await model(inputs);
*
* // Interpolate to original size
* const prediction = interpolate(predicted_depth, image.size.reverse(), 'bilinear', false);
*
* // Visualize the prediction
* const formatted = prediction.mul_(255 / max(prediction.data)[0]).to('uint8');
* const depth = RawImage.fromTensor(formatted);
* // RawImage {
* // data: Uint8Array(307200) [ 85, 85, 84, ... ],
* // width: 640,
* // height: 480,
* // channels: 1
* // }
* ```
*/
export class DPTForDepthEstimation extends DPTPreTrainedModel { }
//////////////////////////////////////////////////

//////////////////////////////////////////////////
export class GLPNPreTrainedModel extends PreTrainedModel { }

/**
* The bare GLPN encoder (Mix-Transformer) outputting raw hidden-states without any specific head on top.
*/
export class GLPNModel extends GLPNPreTrainedModel { }

/**
* GLPN Model transformer with a lightweight depth estimation head on top e.g. for KITTI, NYUv2.
*
* **Example:** Depth estimation w/ `Xenova/glpn-kitti`.
* ```javascript
* import { GLPNForDepthEstimation, AutoProcessor, RawImage, interpolate, max } from '@xenova/transformers';
*
* // Load model and processor
* const model_id = 'Xenova/glpn-kitti';
* const model = await GLPNForDepthEstimation.from_pretrained(model_id);
* const processor = await AutoProcessor.from_pretrained(model_id);
*
* // Load image from URL
* const url = 'http://images.cocodataset.org/val2017/000000039769.jpg';
* const image = await RawImage.fromURL(url);
*
* // Prepare image for the model
* const inputs = await processor(image);
*
* // Run model
* const { predicted_depth } = await model(inputs);
*
* // Interpolate to original size
* const prediction = interpolate(predicted_depth, image.size.reverse(), 'bilinear', false);
*
* // Visualize the prediction
* const formatted = prediction.mul_(255 / max(prediction.data)[0]).to('uint8');
* const depth = RawImage.fromTensor(formatted);
* // RawImage {
* // data: Uint8Array(307200) [ 207, 169, 154, ... ],
* // width: 640,
* // height: 480,
* // channels: 1
* // }
* ```
*/
export class GLPNForDepthEstimation extends GLPNPreTrainedModel { }
//////////////////////////////////////////////////

//////////////////////////////////////////////////
export class DonutSwinPreTrainedModel extends PreTrainedModel { }

Expand Down Expand Up @@ -4025,6 +4119,8 @@ const MODEL_MAPPING_NAMES_ENCODER_ONLY = new Map([
['swin2sr', ['Swin2SRModel', Swin2SRModel]],
['donut-swin', ['DonutSwinModel', DonutSwinModel]],
['yolos', ['YolosModel', YolosModel]],
['dpt', ['DPTModel', DPTModel]],
['glpn', ['GLPNModel', GLPNModel]],

['hifigan', ['SpeechT5HifiGan', SpeechT5HifiGan]],

Expand Down Expand Up @@ -4205,6 +4301,11 @@ const MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES = new Map([
['swin2sr', ['Swin2SRForImageSuperResolution', Swin2SRForImageSuperResolution]],
])

const MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES = new Map([
['dpt', ['DPTForDepthEstimation', DPTForDepthEstimation]],
['glpn', ['GLPNForDepthEstimation', GLPNForDepthEstimation]],
])


const MODEL_CLASS_TYPE_MAPPING = [
[MODEL_MAPPING_NAMES_ENCODER_ONLY, MODEL_TYPES.EncoderOnly],
Expand All @@ -4221,6 +4322,7 @@ const MODEL_CLASS_TYPE_MAPPING = [
[MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
[MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
[MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
[MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
[MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
[MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
[MODEL_FOR_MASK_GENERATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
Expand Down Expand Up @@ -4425,6 +4527,10 @@ export class AutoModelForImageToImage extends PretrainedMixin {
static MODEL_CLASS_MAPPINGS = [MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES];
}

export class AutoModelForDepthEstimation extends PretrainedMixin {
static MODEL_CLASS_MAPPINGS = [MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES];
}

//////////////////////////////////////////////////

//////////////////////////////////////////////////
Expand Down
Loading

0 comments on commit 83dfa47

Please sign in to comment.