diff --git a/.github/workflows/delete-doc-comment-trigger.yml b/.github/workflows/delete-doc-comment-trigger.yml deleted file mode 100644 index 5e39e2539..000000000 --- a/.github/workflows/delete-doc-comment-trigger.yml +++ /dev/null @@ -1,12 +0,0 @@ -name: Delete doc comment trigger - -on: - pull_request: - types: [ closed ] - - -jobs: - delete: - uses: huggingface/doc-builder/.github/workflows/delete_doc_comment_trigger.yml@main - with: - pr_number: ${{ github.event.number }} \ No newline at end of file diff --git a/.github/workflows/delete-doc-comment.yml b/.github/workflows/delete-doc-comment.yml deleted file mode 100644 index 8604019d7..000000000 --- a/.github/workflows/delete-doc-comment.yml +++ /dev/null @@ -1,14 +0,0 @@ -name: Delete doc comment - -on: - workflow_run: - workflows: ["Delete doc comment trigger"] - types: - - completed - - -jobs: - delete: - uses: huggingface/doc-builder/.github/workflows/delete_doc_comment.yml@main - secrets: - comment_bot_token: ${{ secrets.COMMENT_BOT_TOKEN }} \ No newline at end of file diff --git a/README.md b/README.md index f15b73826..2ad41220e 100644 --- a/README.md +++ b/README.md @@ -15,10 +15,13 @@ NPM - Downloads + NPM Downloads + + + jsDelivr Hits - License + License Documentation @@ -98,7 +101,7 @@ npm i @xenova/transformers Alternatively, you can use it in vanilla JS, without any bundler, by using a CDN or static hosting. For example, using [ES Modules](https://developer.mozilla.org/en-US/docs/Web/JavaScript/Guide/Modules), you can import the library with: ```html ``` @@ -130,7 +133,7 @@ Want to jump straight in? Get started with one of our sample applications/templa -By default, Transformers.js uses [hosted pretrained models](https://huggingface.co/models?library=transformers.js) and [precompiled WASM binaries](https://cdn.jsdelivr.net/npm/@xenova/transformers@2.8.0/dist/), which should work out-of-the-box. You can customize this as follows: +By default, Transformers.js uses [hosted pretrained models](https://huggingface.co/models?library=transformers.js) and [precompiled WASM binaries](https://cdn.jsdelivr.net/npm/@xenova/transformers@2.10.0/dist/), which should work out-of-the-box. You can customize this as follows: ### Settings @@ -211,7 +214,7 @@ You can refine your search by selecting the task you're interested in (e.g., [te | Task | ID | Description | Supported? | |--------------------------|----|-------------|------------| -| [Depth Estimation](https://huggingface.co/tasks/depth-estimation) | `depth-estimation` | Predicting the depth of objects present in an image. | ❌ | +| [Depth Estimation](https://huggingface.co/tasks/depth-estimation) | `depth-estimation` | Predicting the depth of objects present in an image. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.DepthEstimationPipeline)
[(models)](https://huggingface.co/models?pipeline_tag=depth-estimation&library=transformers.js) | | [Image Classification](https://huggingface.co/tasks/image-classification) | `image-classification` | Assigning a label or class to an entire image. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.ImageClassificationPipeline)
[(models)](https://huggingface.co/models?pipeline_tag=image-classification&library=transformers.js) | | [Image Segmentation](https://huggingface.co/tasks/image-segmentation) | `image-segmentation` | Divides an image into segments where each pixel is mapped to an object. This task has multiple variants such as instance segmentation, panoptic segmentation and semantic segmentation. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.ImageSegmentationPipeline)
[(models)](https://huggingface.co/models?pipeline_tag=image-segmentation&library=transformers.js) | | [Image-to-Image](https://huggingface.co/tasks/image-to-image) | `image-to-image` | Transforming a source image to match the characteristics of a target image or a target image domain. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.ImageToImagePipeline)
[(models)](https://huggingface.co/models?pipeline_tag=image-to-image&library=transformers.js) | @@ -247,7 +250,9 @@ You can refine your search by selecting the task you're interested in (e.g., [te | [Image-to-Text](https://huggingface.co/tasks/image-to-text) | `image-to-text` | Output text from a given image. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.ImageToTextPipeline)
[(models)](https://huggingface.co/models?pipeline_tag=image-to-text&library=transformers.js) | | [Text-to-Image](https://huggingface.co/tasks/text-to-image) | `text-to-image` | Generates images from input text. | ❌ | | [Visual Question Answering](https://huggingface.co/tasks/visual-question-answering) | `visual-question-answering` | Answering open-ended questions based on an image. | ❌ | +| [Zero-Shot Audio Classification](https://huggingface.co/learn/audio-course/chapter4/classification_models#zero-shot-audio-classification) | `zero-shot-audio-classification` | Classifying audios into classes that are unseen during training. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.ZeroShotAudioClassificationPipeline)
[(models)](https://huggingface.co/models?other=zero-shot-audio-classification&library=transformers.js) | | [Zero-Shot Image Classification](https://huggingface.co/tasks/zero-shot-image-classification) | `zero-shot-image-classification` | Classifying images into classes that are unseen during training. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.ZeroShotImageClassificationPipeline)
[(models)](https://huggingface.co/models?pipeline_tag=zero-shot-image-classification&library=transformers.js) | +| [Zero-Shot Object Detection](https://huggingface.co/tasks/zero-shot-object-detection) | `zero-shot-object-detection` | Identify objects of classes that are unseen during training. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.ZeroShotObjectDetectionPipeline)
[(models)](https://huggingface.co/models?other=zero-shot-object-detection&library=transformers.js) | #### Reinforcement Learning @@ -261,6 +266,7 @@ You can refine your search by selecting the task you're interested in (e.g., [te ### Models 1. **[ALBERT](https://huggingface.co/docs/transformers/model_doc/albert)** (from Google Research and the Toyota Technological Institute at Chicago) released with the paper [ALBERT: A Lite BERT for Self-supervised Learning of Language Representations](https://arxiv.org/abs/1909.11942), by Zhenzhong Lan, Mingda Chen, Sebastian Goodman, Kevin Gimpel, Piyush Sharma, Radu Soricut. +1. **[Audio Spectrogram Transformer](https://huggingface.co/docs/transformers/model_doc/audio-spectrogram-transformer)** (from MIT) released with the paper [AST: Audio Spectrogram Transformer](https://arxiv.org/abs/2104.01778) by Yuan Gong, Yu-An Chung, James Glass. 1. **[BART](https://huggingface.co/docs/transformers/model_doc/bart)** (from Facebook) released with the paper [BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension](https://arxiv.org/abs/1910.13461) by Mike Lewis, Yinhan Liu, Naman Goyal, Marjan Ghazvininejad, Abdelrahman Mohamed, Omer Levy, Ves Stoyanov and Luke Zettlemoyer. 1. **[BEiT](https://huggingface.co/docs/transformers/model_doc/beit)** (from Microsoft) released with the paper [BEiT: BERT Pre-Training of Image Transformers](https://arxiv.org/abs/2106.08254) by Hangbo Bao, Li Dong, Furu Wei. 1. **[BERT](https://huggingface.co/docs/transformers/model_doc/bert)** (from Google) released with the paper [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805) by Jacob Devlin, Ming-Wei Chang, Kenton Lee and Kristina Toutanova. @@ -268,17 +274,22 @@ You can refine your search by selecting the task you're interested in (e.g., [te 1. **[BlenderbotSmall](https://huggingface.co/docs/transformers/model_doc/blenderbot-small)** (from Facebook) released with the paper [Recipes for building an open-domain chatbot](https://arxiv.org/abs/2004.13637) by Stephen Roller, Emily Dinan, Naman Goyal, Da Ju, Mary Williamson, Yinhan Liu, Jing Xu, Myle Ott, Kurt Shuster, Eric M. Smith, Y-Lan Boureau, Jason Weston. 1. **[BLOOM](https://huggingface.co/docs/transformers/model_doc/bloom)** (from BigScience workshop) released by the [BigScience Workshop](https://bigscience.huggingface.co/). 1. **[CamemBERT](https://huggingface.co/docs/transformers/model_doc/camembert)** (from Inria/Facebook/Sorbonne) released with the paper [CamemBERT: a Tasty French Language Model](https://arxiv.org/abs/1911.03894) by Louis Martin*, Benjamin Muller*, Pedro Javier Ortiz Suárez*, Yoann Dupont, Laurent Romary, Éric Villemonte de la Clergerie, Djamé Seddah and Benoît Sagot. +1. **[CLAP](https://huggingface.co/docs/transformers/model_doc/clap)** (from LAION-AI) released with the paper [Large-scale Contrastive Language-Audio Pretraining with Feature Fusion and Keyword-to-Caption Augmentation](https://arxiv.org/abs/2211.06687) by Yusong Wu, Ke Chen, Tianyu Zhang, Yuchen Hui, Taylor Berg-Kirkpatrick, Shlomo Dubnov. 1. **[CLIP](https://huggingface.co/docs/transformers/model_doc/clip)** (from OpenAI) released with the paper [Learning Transferable Visual Models From Natural Language Supervision](https://arxiv.org/abs/2103.00020) by Alec Radford, Jong Wook Kim, Chris Hallacy, Aditya Ramesh, Gabriel Goh, Sandhini Agarwal, Girish Sastry, Amanda Askell, Pamela Mishkin, Jack Clark, Gretchen Krueger, Ilya Sutskever. 1. **[CodeGen](https://huggingface.co/docs/transformers/model_doc/codegen)** (from Salesforce) released with the paper [A Conversational Paradigm for Program Synthesis](https://arxiv.org/abs/2203.13474) by Erik Nijkamp, Bo Pang, Hiroaki Hayashi, Lifu Tu, Huan Wang, Yingbo Zhou, Silvio Savarese, Caiming Xiong. 1. **[CodeLlama](https://huggingface.co/docs/transformers/model_doc/llama_code)** (from MetaAI) released with the paper [Code Llama: Open Foundation Models for Code](https://ai.meta.com/research/publications/code-llama-open-foundation-models-for-code/) by Baptiste Rozière, Jonas Gehring, Fabian Gloeckle, Sten Sootla, Itai Gat, Xiaoqing Ellen Tan, Yossi Adi, Jingyu Liu, Tal Remez, Jérémy Rapin, Artyom Kozhevnikov, Ivan Evtimov, Joanna Bitton, Manish Bhatt, Cristian Canton Ferrer, Aaron Grattafiori, Wenhan Xiong, Alexandre Défossez, Jade Copet, Faisal Azhar, Hugo Touvron, Louis Martin, Nicolas Usunier, Thomas Scialom, Gabriel Synnaeve. +1. **[ConvNeXT](https://huggingface.co/docs/transformers/model_doc/convnext)** (from Facebook AI) released with the paper [A ConvNet for the 2020s](https://arxiv.org/abs/2201.03545) by Zhuang Liu, Hanzi Mao, Chao-Yuan Wu, Christoph Feichtenhofer, Trevor Darrell, Saining Xie. +1. **[ConvNeXTV2](https://huggingface.co/docs/transformers/model_doc/convnextv2)** (from Facebook AI) released with the paper [ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders](https://arxiv.org/abs/2301.00808) by Sanghyun Woo, Shoubhik Debnath, Ronghang Hu, Xinlei Chen, Zhuang Liu, In So Kweon, Saining Xie. 1. **[DeBERTa](https://huggingface.co/docs/transformers/model_doc/deberta)** (from Microsoft) released with the paper [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen. 1. **[DeBERTa-v2](https://huggingface.co/docs/transformers/model_doc/deberta-v2)** (from Microsoft) released with the paper [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen. 1. **[DeiT](https://huggingface.co/docs/transformers/model_doc/deit)** (from Facebook) released with the paper [Training data-efficient image transformers & distillation through attention](https://arxiv.org/abs/2012.12877) by Hugo Touvron, Matthieu Cord, Matthijs Douze, Francisco Massa, Alexandre Sablayrolles, Hervé Jégou. 1. **[DETR](https://huggingface.co/docs/transformers/model_doc/detr)** (from Facebook) released with the paper [End-to-End Object Detection with Transformers](https://arxiv.org/abs/2005.12872) by Nicolas Carion, Francisco Massa, Gabriel Synnaeve, Nicolas Usunier, Alexander Kirillov, Sergey Zagoruyko. 1. **[DistilBERT](https://huggingface.co/docs/transformers/model_doc/distilbert)** (from HuggingFace), released together with the paper [DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter](https://arxiv.org/abs/1910.01108) by Victor Sanh, Lysandre Debut and Thomas Wolf. The same method has been applied to compress GPT2 into [DistilGPT2](https://github.com/huggingface/transformers/tree/main/examples/research_projects/distillation), RoBERTa into [DistilRoBERTa](https://github.com/huggingface/transformers/tree/main/examples/research_projects/distillation), Multilingual BERT into [DistilmBERT](https://github.com/huggingface/transformers/tree/main/examples/research_projects/distillation) and a German version of DistilBERT. 1. **[Donut](https://huggingface.co/docs/transformers/model_doc/donut)** (from NAVER), released together with the paper [OCR-free Document Understanding Transformer](https://arxiv.org/abs/2111.15664) by Geewook Kim, Teakgyu Hong, Moonbin Yim, Jeongyeon Nam, Jinyoung Park, Jinyeong Yim, Wonseok Hwang, Sangdoo Yun, Dongyoon Han, Seunghyun Park. +1. **[DPT](https://huggingface.co/docs/transformers/master/model_doc/dpt)** (from Intel Labs) released with the paper [Vision Transformers for Dense Prediction](https://arxiv.org/abs/2103.13413) by René Ranftl, Alexey Bochkovskiy, Vladlen Koltun. 1. **[Falcon](https://huggingface.co/docs/transformers/model_doc/falcon)** (from Technology Innovation Institute) by Almazrouei, Ebtesam and Alobeidli, Hamza and Alshamsi, Abdulaziz and Cappelli, Alessandro and Cojocaru, Ruxandra and Debbah, Merouane and Goffinet, Etienne and Heslow, Daniel and Launay, Julien and Malartic, Quentin and Noune, Badreddine and Pannier, Baptiste and Penedo, Guilherme. 1. **[FLAN-T5](https://huggingface.co/docs/transformers/model_doc/flan-t5)** (from Google AI) released in the repository [google-research/t5x](https://github.com/google-research/t5x/blob/main/docs/models.md#flan-t5-checkpoints) by Hyung Won Chung, Le Hou, Shayne Longpre, Barret Zoph, Yi Tay, William Fedus, Eric Li, Xuezhi Wang, Mostafa Dehghani, Siddhartha Brahma, Albert Webson, Shixiang Shane Gu, Zhuyun Dai, Mirac Suzgun, Xinyun Chen, Aakanksha Chowdhery, Sharan Narang, Gaurav Mishra, Adams Yu, Vincent Zhao, Yanping Huang, Andrew Dai, Hongkun Yu, Slav Petrov, Ed H. Chi, Jeff Dean, Jacob Devlin, Adam Roberts, Denny Zhou, Quoc V. Le, and Jason Wei +1. **[GLPN](https://huggingface.co/docs/transformers/model_doc/glpn)** (from KAIST) released with the paper [Global-Local Path Networks for Monocular Depth Estimation with Vertical CutDepth](https://arxiv.org/abs/2201.07436) by Doyeon Kim, Woonghyun Ga, Pyungwhan Ahn, Donggyu Joo, Sehwan Chun, Junmo Kim. 1. **[GPT Neo](https://huggingface.co/docs/transformers/model_doc/gpt_neo)** (from EleutherAI) released in the repository [EleutherAI/gpt-neo](https://github.com/EleutherAI/gpt-neo) by Sid Black, Stella Biderman, Leo Gao, Phil Wang and Connor Leahy. 1. **[GPT NeoX](https://huggingface.co/docs/transformers/model_doc/gpt_neox)** (from EleutherAI) released with the paper [GPT-NeoX-20B: An Open-Source Autoregressive Language Model](https://arxiv.org/abs/2204.06745) by Sid Black, Stella Biderman, Eric Hallahan, Quentin Anthony, Leo Gao, Laurence Golding, Horace He, Connor Leahy, Kyle McDonell, Jason Phang, Michael Pieler, USVSN Sai Prashanth, Shivanshu Purohit, Laria Reynolds, Jonathan Tow, Ben Wang, Samuel Weinbach 1. **[GPT-2](https://huggingface.co/docs/transformers/model_doc/gpt2)** (from OpenAI) released with the paper [Language Models are Unsupervised Multitask Learners](https://blog.openai.com/better-language-models/) by Alec Radford*, Jeffrey Wu*, Rewon Child, David Luan, Dario Amodei** and Ilya Sutskever**. @@ -300,7 +311,9 @@ You can refine your search by selecting the task you're interested in (e.g., [te 1. **[MPT](https://huggingface.co/docs/transformers/model_doc/mpt)** (from MosaiML) released with the repository [llm-foundry](https://github.com/mosaicml/llm-foundry/) by the MosaicML NLP Team. 1. **[MT5](https://huggingface.co/docs/transformers/model_doc/mt5)** (from Google AI) released with the paper [mT5: A massively multilingual pre-trained text-to-text transformer](https://arxiv.org/abs/2010.11934) by Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel. 1. **[NLLB](https://huggingface.co/docs/transformers/model_doc/nllb)** (from Meta) released with the paper [No Language Left Behind: Scaling Human-Centered Machine Translation](https://arxiv.org/abs/2207.04672) by the NLLB team. +1. **[Nougat](https://huggingface.co/docs/transformers/model_doc/nougat)** (from Meta AI) released with the paper [Nougat: Neural Optical Understanding for Academic Documents](https://arxiv.org/abs/2308.13418) by Lukas Blecher, Guillem Cucurull, Thomas Scialom, Robert Stojnic. 1. **[OPT](https://huggingface.co/docs/transformers/master/model_doc/opt)** (from Meta AI) released with the paper [OPT: Open Pre-trained Transformer Language Models](https://arxiv.org/abs/2205.01068) by Susan Zhang, Stephen Roller, Naman Goyal, Mikel Artetxe, Moya Chen, Shuohui Chen et al. +1. **[OWL-ViT](https://huggingface.co/docs/transformers/model_doc/owlvit)** (from Google AI) released with the paper [Simple Open-Vocabulary Object Detection with Vision Transformers](https://arxiv.org/abs/2205.06230) by Matthias Minderer, Alexey Gritsenko, Austin Stone, Maxim Neumann, Dirk Weissenborn, Alexey Dosovitskiy, Aravindh Mahendran, Anurag Arnab, Mostafa Dehghani, Zhuoran Shen, Xiao Wang, Xiaohua Zhai, Thomas Kipf, and Neil Houlsby. 1. **[ResNet](https://huggingface.co/docs/transformers/model_doc/resnet)** (from Microsoft Research) released with the paper [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) by Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. 1. **[RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta)** (from Facebook), released together with the paper [RoBERTa: A Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov. 1. **[SpeechT5](https://huggingface.co/docs/transformers/model_doc/speecht5)** (from Microsoft Research) released with the paper [SpeechT5: Unified-Modal Encoder-Decoder Pre-Training for Spoken Language Processing](https://arxiv.org/abs/2110.07205) by Junyi Ao, Rui Wang, Long Zhou, Chengyi Wang, Shuo Ren, Yu Wu, Shujie Liu, Tom Ko, Qing Li, Yu Zhang, Zhihua Wei, Yao Qian, Jinyu Li, Furu Wei. diff --git a/docs/scripts/build_readme.py b/docs/scripts/build_readme.py index e3beaca86..82c51dad7 100644 --- a/docs/scripts/build_readme.py +++ b/docs/scripts/build_readme.py @@ -17,10 +17,13 @@ NPM
- Downloads + NPM Downloads + + + jsDelivr Hits - License + License Documentation diff --git a/docs/snippets/2_installation.snippet b/docs/snippets/2_installation.snippet index f06edaa68..8d5e433d2 100644 --- a/docs/snippets/2_installation.snippet +++ b/docs/snippets/2_installation.snippet @@ -7,6 +7,6 @@ npm i @xenova/transformers Alternatively, you can use it in vanilla JS, without any bundler, by using a CDN or static hosting. For example, using [ES Modules](https://developer.mozilla.org/en-US/docs/Web/JavaScript/Guide/Modules), you can import the library with: ```html ``` diff --git a/docs/snippets/4_custom-usage.snippet b/docs/snippets/4_custom-usage.snippet index d9522edc2..3367b2685 100644 --- a/docs/snippets/4_custom-usage.snippet +++ b/docs/snippets/4_custom-usage.snippet @@ -1,6 +1,6 @@ -By default, Transformers.js uses [hosted pretrained models](https://huggingface.co/models?library=transformers.js) and [precompiled WASM binaries](https://cdn.jsdelivr.net/npm/@xenova/transformers@2.8.0/dist/), which should work out-of-the-box. You can customize this as follows: +By default, Transformers.js uses [hosted pretrained models](https://huggingface.co/models?library=transformers.js) and [precompiled WASM binaries](https://cdn.jsdelivr.net/npm/@xenova/transformers@2.10.0/dist/), which should work out-of-the-box. You can customize this as follows: ### Settings diff --git a/docs/snippets/5_supported-tasks.snippet b/docs/snippets/5_supported-tasks.snippet index dee075808..838026092 100644 --- a/docs/snippets/5_supported-tasks.snippet +++ b/docs/snippets/5_supported-tasks.snippet @@ -22,7 +22,7 @@ | Task | ID | Description | Supported? | |--------------------------|----|-------------|------------| -| [Depth Estimation](https://huggingface.co/tasks/depth-estimation) | `depth-estimation` | Predicting the depth of objects present in an image. | ❌ | +| [Depth Estimation](https://huggingface.co/tasks/depth-estimation) | `depth-estimation` | Predicting the depth of objects present in an image. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.DepthEstimationPipeline)
[(models)](https://huggingface.co/models?pipeline_tag=depth-estimation&library=transformers.js) | | [Image Classification](https://huggingface.co/tasks/image-classification) | `image-classification` | Assigning a label or class to an entire image. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.ImageClassificationPipeline)
[(models)](https://huggingface.co/models?pipeline_tag=image-classification&library=transformers.js) | | [Image Segmentation](https://huggingface.co/tasks/image-segmentation) | `image-segmentation` | Divides an image into segments where each pixel is mapped to an object. This task has multiple variants such as instance segmentation, panoptic segmentation and semantic segmentation. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.ImageSegmentationPipeline)
[(models)](https://huggingface.co/models?pipeline_tag=image-segmentation&library=transformers.js) | | [Image-to-Image](https://huggingface.co/tasks/image-to-image) | `image-to-image` | Transforming a source image to match the characteristics of a target image or a target image domain. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.ImageToImagePipeline)
[(models)](https://huggingface.co/models?pipeline_tag=image-to-image&library=transformers.js) | @@ -58,7 +58,9 @@ | [Image-to-Text](https://huggingface.co/tasks/image-to-text) | `image-to-text` | Output text from a given image. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.ImageToTextPipeline)
[(models)](https://huggingface.co/models?pipeline_tag=image-to-text&library=transformers.js) | | [Text-to-Image](https://huggingface.co/tasks/text-to-image) | `text-to-image` | Generates images from input text. | ❌ | | [Visual Question Answering](https://huggingface.co/tasks/visual-question-answering) | `visual-question-answering` | Answering open-ended questions based on an image. | ❌ | +| [Zero-Shot Audio Classification](https://huggingface.co/learn/audio-course/chapter4/classification_models#zero-shot-audio-classification) | `zero-shot-audio-classification` | Classifying audios into classes that are unseen during training. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.ZeroShotAudioClassificationPipeline)
[(models)](https://huggingface.co/models?other=zero-shot-audio-classification&library=transformers.js) | | [Zero-Shot Image Classification](https://huggingface.co/tasks/zero-shot-image-classification) | `zero-shot-image-classification` | Classifying images into classes that are unseen during training. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.ZeroShotImageClassificationPipeline)
[(models)](https://huggingface.co/models?pipeline_tag=zero-shot-image-classification&library=transformers.js) | +| [Zero-Shot Object Detection](https://huggingface.co/tasks/zero-shot-object-detection) | `zero-shot-object-detection` | Identify objects of classes that are unseen during training. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.ZeroShotObjectDetectionPipeline)
[(models)](https://huggingface.co/models?other=zero-shot-object-detection&library=transformers.js) | #### Reinforcement Learning diff --git a/docs/snippets/6_supported-models.snippet b/docs/snippets/6_supported-models.snippet index 42c12bd2a..4dfc00e2d 100644 --- a/docs/snippets/6_supported-models.snippet +++ b/docs/snippets/6_supported-models.snippet @@ -2,6 +2,7 @@ ### Models 1. **[ALBERT](https://huggingface.co/docs/transformers/model_doc/albert)** (from Google Research and the Toyota Technological Institute at Chicago) released with the paper [ALBERT: A Lite BERT for Self-supervised Learning of Language Representations](https://arxiv.org/abs/1909.11942), by Zhenzhong Lan, Mingda Chen, Sebastian Goodman, Kevin Gimpel, Piyush Sharma, Radu Soricut. +1. **[Audio Spectrogram Transformer](https://huggingface.co/docs/transformers/model_doc/audio-spectrogram-transformer)** (from MIT) released with the paper [AST: Audio Spectrogram Transformer](https://arxiv.org/abs/2104.01778) by Yuan Gong, Yu-An Chung, James Glass. 1. **[BART](https://huggingface.co/docs/transformers/model_doc/bart)** (from Facebook) released with the paper [BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension](https://arxiv.org/abs/1910.13461) by Mike Lewis, Yinhan Liu, Naman Goyal, Marjan Ghazvininejad, Abdelrahman Mohamed, Omer Levy, Ves Stoyanov and Luke Zettlemoyer. 1. **[BEiT](https://huggingface.co/docs/transformers/model_doc/beit)** (from Microsoft) released with the paper [BEiT: BERT Pre-Training of Image Transformers](https://arxiv.org/abs/2106.08254) by Hangbo Bao, Li Dong, Furu Wei. 1. **[BERT](https://huggingface.co/docs/transformers/model_doc/bert)** (from Google) released with the paper [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805) by Jacob Devlin, Ming-Wei Chang, Kenton Lee and Kristina Toutanova. @@ -9,17 +10,22 @@ 1. **[BlenderbotSmall](https://huggingface.co/docs/transformers/model_doc/blenderbot-small)** (from Facebook) released with the paper [Recipes for building an open-domain chatbot](https://arxiv.org/abs/2004.13637) by Stephen Roller, Emily Dinan, Naman Goyal, Da Ju, Mary Williamson, Yinhan Liu, Jing Xu, Myle Ott, Kurt Shuster, Eric M. Smith, Y-Lan Boureau, Jason Weston. 1. **[BLOOM](https://huggingface.co/docs/transformers/model_doc/bloom)** (from BigScience workshop) released by the [BigScience Workshop](https://bigscience.huggingface.co/). 1. **[CamemBERT](https://huggingface.co/docs/transformers/model_doc/camembert)** (from Inria/Facebook/Sorbonne) released with the paper [CamemBERT: a Tasty French Language Model](https://arxiv.org/abs/1911.03894) by Louis Martin*, Benjamin Muller*, Pedro Javier Ortiz Suárez*, Yoann Dupont, Laurent Romary, Éric Villemonte de la Clergerie, Djamé Seddah and Benoît Sagot. +1. **[CLAP](https://huggingface.co/docs/transformers/model_doc/clap)** (from LAION-AI) released with the paper [Large-scale Contrastive Language-Audio Pretraining with Feature Fusion and Keyword-to-Caption Augmentation](https://arxiv.org/abs/2211.06687) by Yusong Wu, Ke Chen, Tianyu Zhang, Yuchen Hui, Taylor Berg-Kirkpatrick, Shlomo Dubnov. 1. **[CLIP](https://huggingface.co/docs/transformers/model_doc/clip)** (from OpenAI) released with the paper [Learning Transferable Visual Models From Natural Language Supervision](https://arxiv.org/abs/2103.00020) by Alec Radford, Jong Wook Kim, Chris Hallacy, Aditya Ramesh, Gabriel Goh, Sandhini Agarwal, Girish Sastry, Amanda Askell, Pamela Mishkin, Jack Clark, Gretchen Krueger, Ilya Sutskever. 1. **[CodeGen](https://huggingface.co/docs/transformers/model_doc/codegen)** (from Salesforce) released with the paper [A Conversational Paradigm for Program Synthesis](https://arxiv.org/abs/2203.13474) by Erik Nijkamp, Bo Pang, Hiroaki Hayashi, Lifu Tu, Huan Wang, Yingbo Zhou, Silvio Savarese, Caiming Xiong. 1. **[CodeLlama](https://huggingface.co/docs/transformers/model_doc/llama_code)** (from MetaAI) released with the paper [Code Llama: Open Foundation Models for Code](https://ai.meta.com/research/publications/code-llama-open-foundation-models-for-code/) by Baptiste Rozière, Jonas Gehring, Fabian Gloeckle, Sten Sootla, Itai Gat, Xiaoqing Ellen Tan, Yossi Adi, Jingyu Liu, Tal Remez, Jérémy Rapin, Artyom Kozhevnikov, Ivan Evtimov, Joanna Bitton, Manish Bhatt, Cristian Canton Ferrer, Aaron Grattafiori, Wenhan Xiong, Alexandre Défossez, Jade Copet, Faisal Azhar, Hugo Touvron, Louis Martin, Nicolas Usunier, Thomas Scialom, Gabriel Synnaeve. +1. **[ConvNeXT](https://huggingface.co/docs/transformers/model_doc/convnext)** (from Facebook AI) released with the paper [A ConvNet for the 2020s](https://arxiv.org/abs/2201.03545) by Zhuang Liu, Hanzi Mao, Chao-Yuan Wu, Christoph Feichtenhofer, Trevor Darrell, Saining Xie. +1. **[ConvNeXTV2](https://huggingface.co/docs/transformers/model_doc/convnextv2)** (from Facebook AI) released with the paper [ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders](https://arxiv.org/abs/2301.00808) by Sanghyun Woo, Shoubhik Debnath, Ronghang Hu, Xinlei Chen, Zhuang Liu, In So Kweon, Saining Xie. 1. **[DeBERTa](https://huggingface.co/docs/transformers/model_doc/deberta)** (from Microsoft) released with the paper [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen. 1. **[DeBERTa-v2](https://huggingface.co/docs/transformers/model_doc/deberta-v2)** (from Microsoft) released with the paper [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen. 1. **[DeiT](https://huggingface.co/docs/transformers/model_doc/deit)** (from Facebook) released with the paper [Training data-efficient image transformers & distillation through attention](https://arxiv.org/abs/2012.12877) by Hugo Touvron, Matthieu Cord, Matthijs Douze, Francisco Massa, Alexandre Sablayrolles, Hervé Jégou. 1. **[DETR](https://huggingface.co/docs/transformers/model_doc/detr)** (from Facebook) released with the paper [End-to-End Object Detection with Transformers](https://arxiv.org/abs/2005.12872) by Nicolas Carion, Francisco Massa, Gabriel Synnaeve, Nicolas Usunier, Alexander Kirillov, Sergey Zagoruyko. 1. **[DistilBERT](https://huggingface.co/docs/transformers/model_doc/distilbert)** (from HuggingFace), released together with the paper [DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter](https://arxiv.org/abs/1910.01108) by Victor Sanh, Lysandre Debut and Thomas Wolf. The same method has been applied to compress GPT2 into [DistilGPT2](https://github.com/huggingface/transformers/tree/main/examples/research_projects/distillation), RoBERTa into [DistilRoBERTa](https://github.com/huggingface/transformers/tree/main/examples/research_projects/distillation), Multilingual BERT into [DistilmBERT](https://github.com/huggingface/transformers/tree/main/examples/research_projects/distillation) and a German version of DistilBERT. 1. **[Donut](https://huggingface.co/docs/transformers/model_doc/donut)** (from NAVER), released together with the paper [OCR-free Document Understanding Transformer](https://arxiv.org/abs/2111.15664) by Geewook Kim, Teakgyu Hong, Moonbin Yim, Jeongyeon Nam, Jinyoung Park, Jinyeong Yim, Wonseok Hwang, Sangdoo Yun, Dongyoon Han, Seunghyun Park. +1. **[DPT](https://huggingface.co/docs/transformers/master/model_doc/dpt)** (from Intel Labs) released with the paper [Vision Transformers for Dense Prediction](https://arxiv.org/abs/2103.13413) by René Ranftl, Alexey Bochkovskiy, Vladlen Koltun. 1. **[Falcon](https://huggingface.co/docs/transformers/model_doc/falcon)** (from Technology Innovation Institute) by Almazrouei, Ebtesam and Alobeidli, Hamza and Alshamsi, Abdulaziz and Cappelli, Alessandro and Cojocaru, Ruxandra and Debbah, Merouane and Goffinet, Etienne and Heslow, Daniel and Launay, Julien and Malartic, Quentin and Noune, Badreddine and Pannier, Baptiste and Penedo, Guilherme. 1. **[FLAN-T5](https://huggingface.co/docs/transformers/model_doc/flan-t5)** (from Google AI) released in the repository [google-research/t5x](https://github.com/google-research/t5x/blob/main/docs/models.md#flan-t5-checkpoints) by Hyung Won Chung, Le Hou, Shayne Longpre, Barret Zoph, Yi Tay, William Fedus, Eric Li, Xuezhi Wang, Mostafa Dehghani, Siddhartha Brahma, Albert Webson, Shixiang Shane Gu, Zhuyun Dai, Mirac Suzgun, Xinyun Chen, Aakanksha Chowdhery, Sharan Narang, Gaurav Mishra, Adams Yu, Vincent Zhao, Yanping Huang, Andrew Dai, Hongkun Yu, Slav Petrov, Ed H. Chi, Jeff Dean, Jacob Devlin, Adam Roberts, Denny Zhou, Quoc V. Le, and Jason Wei +1. **[GLPN](https://huggingface.co/docs/transformers/model_doc/glpn)** (from KAIST) released with the paper [Global-Local Path Networks for Monocular Depth Estimation with Vertical CutDepth](https://arxiv.org/abs/2201.07436) by Doyeon Kim, Woonghyun Ga, Pyungwhan Ahn, Donggyu Joo, Sehwan Chun, Junmo Kim. 1. **[GPT Neo](https://huggingface.co/docs/transformers/model_doc/gpt_neo)** (from EleutherAI) released in the repository [EleutherAI/gpt-neo](https://github.com/EleutherAI/gpt-neo) by Sid Black, Stella Biderman, Leo Gao, Phil Wang and Connor Leahy. 1. **[GPT NeoX](https://huggingface.co/docs/transformers/model_doc/gpt_neox)** (from EleutherAI) released with the paper [GPT-NeoX-20B: An Open-Source Autoregressive Language Model](https://arxiv.org/abs/2204.06745) by Sid Black, Stella Biderman, Eric Hallahan, Quentin Anthony, Leo Gao, Laurence Golding, Horace He, Connor Leahy, Kyle McDonell, Jason Phang, Michael Pieler, USVSN Sai Prashanth, Shivanshu Purohit, Laria Reynolds, Jonathan Tow, Ben Wang, Samuel Weinbach 1. **[GPT-2](https://huggingface.co/docs/transformers/model_doc/gpt2)** (from OpenAI) released with the paper [Language Models are Unsupervised Multitask Learners](https://blog.openai.com/better-language-models/) by Alec Radford*, Jeffrey Wu*, Rewon Child, David Luan, Dario Amodei** and Ilya Sutskever**. @@ -41,7 +47,9 @@ 1. **[MPT](https://huggingface.co/docs/transformers/model_doc/mpt)** (from MosaiML) released with the repository [llm-foundry](https://github.com/mosaicml/llm-foundry/) by the MosaicML NLP Team. 1. **[MT5](https://huggingface.co/docs/transformers/model_doc/mt5)** (from Google AI) released with the paper [mT5: A massively multilingual pre-trained text-to-text transformer](https://arxiv.org/abs/2010.11934) by Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel. 1. **[NLLB](https://huggingface.co/docs/transformers/model_doc/nllb)** (from Meta) released with the paper [No Language Left Behind: Scaling Human-Centered Machine Translation](https://arxiv.org/abs/2207.04672) by the NLLB team. +1. **[Nougat](https://huggingface.co/docs/transformers/model_doc/nougat)** (from Meta AI) released with the paper [Nougat: Neural Optical Understanding for Academic Documents](https://arxiv.org/abs/2308.13418) by Lukas Blecher, Guillem Cucurull, Thomas Scialom, Robert Stojnic. 1. **[OPT](https://huggingface.co/docs/transformers/master/model_doc/opt)** (from Meta AI) released with the paper [OPT: Open Pre-trained Transformer Language Models](https://arxiv.org/abs/2205.01068) by Susan Zhang, Stephen Roller, Naman Goyal, Mikel Artetxe, Moya Chen, Shuohui Chen et al. +1. **[OWL-ViT](https://huggingface.co/docs/transformers/model_doc/owlvit)** (from Google AI) released with the paper [Simple Open-Vocabulary Object Detection with Vision Transformers](https://arxiv.org/abs/2205.06230) by Matthias Minderer, Alexey Gritsenko, Austin Stone, Maxim Neumann, Dirk Weissenborn, Alexey Dosovitskiy, Aravindh Mahendran, Anurag Arnab, Mostafa Dehghani, Zhuoran Shen, Xiao Wang, Xiaohua Zhai, Thomas Kipf, and Neil Houlsby. 1. **[ResNet](https://huggingface.co/docs/transformers/model_doc/resnet)** (from Microsoft Research) released with the paper [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) by Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. 1. **[RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta)** (from Facebook), released together with the paper [RoBERTa: A Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov. 1. **[SpeechT5](https://huggingface.co/docs/transformers/model_doc/speecht5)** (from Microsoft Research) released with the paper [SpeechT5: Unified-Modal Encoder-Decoder Pre-Training for Spoken Language Processing](https://arxiv.org/abs/2110.07205) by Junyi Ao, Rui Wang, Long Zhou, Chengyi Wang, Shuo Ren, Yu Wu, Shujie Liu, Tom Ko, Qing Li, Yu Zhang, Zhihua Wei, Yao Qian, Jinyu Li, Furu Wei. diff --git a/examples/next-client/package-lock.json b/examples/next-client/package-lock.json index ea29ca18c..309387dca 100644 --- a/examples/next-client/package-lock.json +++ b/examples/next-client/package-lock.json @@ -4019,9 +4019,9 @@ } }, "node_modules/sharp": { - "version": "0.32.4", - "resolved": "https://registry.npmjs.org/sharp/-/sharp-0.32.4.tgz", - "integrity": "sha512-exUnZewqVZC6UXqXuQ8fyJJv0M968feBi04jb9GcUHrWtkRoAKnbJt8IfwT4NJs7FskArbJ14JAFGVuooszoGg==", + "version": "0.32.6", + "resolved": "https://registry.npmjs.org/sharp/-/sharp-0.32.6.tgz", + "integrity": "sha512-KyLTWwgcR9Oe4d9HwCwNM2l7+J0dUQwn/yf7S0EnTtb0eVS4RxO0eUSvxPtzT4F3SY+C4K6fqdv/DO27sJ/v/w==", "hasInstallScript": true, "dependencies": { "color": "^4.2.3", diff --git a/examples/next-server/package-lock.json b/examples/next-server/package-lock.json index 25f40f30b..e7861f920 100644 --- a/examples/next-server/package-lock.json +++ b/examples/next-server/package-lock.json @@ -4019,9 +4019,9 @@ } }, "node_modules/sharp": { - "version": "0.32.4", - "resolved": "https://registry.npmjs.org/sharp/-/sharp-0.32.4.tgz", - "integrity": "sha512-exUnZewqVZC6UXqXuQ8fyJJv0M968feBi04jb9GcUHrWtkRoAKnbJt8IfwT4NJs7FskArbJ14JAFGVuooszoGg==", + "version": "0.32.6", + "resolved": "https://registry.npmjs.org/sharp/-/sharp-0.32.6.tgz", + "integrity": "sha512-KyLTWwgcR9Oe4d9HwCwNM2l7+J0dUQwn/yf7S0EnTtb0eVS4RxO0eUSvxPtzT4F3SY+C4K6fqdv/DO27sJ/v/w==", "hasInstallScript": true, "dependencies": { "color": "^4.2.3", diff --git a/examples/semantic-image-search-client/package-lock.json b/examples/semantic-image-search-client/package-lock.json index c99b280f2..7c06d25f3 100644 --- a/examples/semantic-image-search-client/package-lock.json +++ b/examples/semantic-image-search-client/package-lock.json @@ -4073,9 +4073,9 @@ } }, "node_modules/sharp": { - "version": "0.32.4", - "resolved": "https://registry.npmjs.org/sharp/-/sharp-0.32.4.tgz", - "integrity": "sha512-exUnZewqVZC6UXqXuQ8fyJJv0M968feBi04jb9GcUHrWtkRoAKnbJt8IfwT4NJs7FskArbJ14JAFGVuooszoGg==", + "version": "0.32.6", + "resolved": "https://registry.npmjs.org/sharp/-/sharp-0.32.6.tgz", + "integrity": "sha512-KyLTWwgcR9Oe4d9HwCwNM2l7+J0dUQwn/yf7S0EnTtb0eVS4RxO0eUSvxPtzT4F3SY+C4K6fqdv/DO27sJ/v/w==", "hasInstallScript": true, "dependencies": { "color": "^4.2.3", diff --git a/examples/semantic-image-search/package-lock.json b/examples/semantic-image-search/package-lock.json index f714d8b13..fbc41c984 100644 --- a/examples/semantic-image-search/package-lock.json +++ b/examples/semantic-image-search/package-lock.json @@ -4256,9 +4256,9 @@ } }, "node_modules/sharp": { - "version": "0.32.4", - "resolved": "https://registry.npmjs.org/sharp/-/sharp-0.32.4.tgz", - "integrity": "sha512-exUnZewqVZC6UXqXuQ8fyJJv0M968feBi04jb9GcUHrWtkRoAKnbJt8IfwT4NJs7FskArbJ14JAFGVuooszoGg==", + "version": "0.32.6", + "resolved": "https://registry.npmjs.org/sharp/-/sharp-0.32.6.tgz", + "integrity": "sha512-KyLTWwgcR9Oe4d9HwCwNM2l7+J0dUQwn/yf7S0EnTtb0eVS4RxO0eUSvxPtzT4F3SY+C4K6fqdv/DO27sJ/v/w==", "hasInstallScript": true, "dependencies": { "color": "^4.2.3", diff --git a/package-lock.json b/package-lock.json index 3c86aafe7..c2bdd41bf 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "@xenova/transformers", - "version": "2.8.0", + "version": "2.10.0", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "@xenova/transformers", - "version": "2.8.0", + "version": "2.10.0", "license": "Apache-2.0", "dependencies": { "onnxruntime-web": "1.14.0", @@ -2017,6 +2017,11 @@ "integrity": "sha512-hNfzcOV8W4NdualtqBFPyVO+54DSJuZGY9qT4pRroB6S9e3iiido2ISIC5h9R2sPJ8H3FHCIiEnsv1lPXO3KtQ==", "dev": true }, + "node_modules/b4a": { + "version": "1.6.4", + "resolved": "https://registry.npmjs.org/b4a/-/b4a-1.6.4.tgz", + "integrity": "sha512-fpWrvyVHEKyeEvbKZTVOeZF3VSKKWtJxFIxX/jaVPf+cLbGUSitjb49pHLqPV2BUNNZ0LcoeEGfE/YCpyDYHIw==" + }, "node_modules/babel-jest": { "version": "29.6.1", "resolved": "https://registry.npmjs.org/babel-jest/-/babel-jest-29.6.1.tgz", @@ -3014,9 +3019,9 @@ } }, "node_modules/detect-libc": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/detect-libc/-/detect-libc-2.0.1.tgz", - "integrity": "sha512-463v3ZeIrcWtdgIg6vI6XUncguvr2TnGl4SzDXinkt9mSLpBJKXT3mW6xT3VQdDN11+WVs29pgvivTc4Lp8v+w==", + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/detect-libc/-/detect-libc-2.0.2.tgz", + "integrity": "sha512-UX6sGumvvqSaXgdKGUsgZWqcUyIXZ/vZTrlRT/iobiKhGL0zL4d3osHj3uqllWJK+i+sixDS/3COVEOFbupFyw==", "engines": { "node": ">=8" } @@ -3415,6 +3420,11 @@ "integrity": "sha512-f3qQ9oQy9j2AhBe/H9VC91wLmKBCCU/gDOnKNAYG5hswO7BLKj09Hc5HYNz9cGI++xlpDCIgDaitVs03ATR84Q==", "dev": true }, + "node_modules/fast-fifo": { + "version": "1.3.2", + "resolved": "https://registry.npmjs.org/fast-fifo/-/fast-fifo-1.3.2.tgz", + "integrity": "sha512-/d9sfos4yxzpwkDkuN7k2SqFKtYNmCTzgfEpz82x34IM9/zc8KGxQoXg1liNC/izpRM/MBdt44Nmx41ZWqk+FQ==" + }, "node_modules/fast-glob": { "version": "3.2.12", "resolved": "https://registry.npmjs.org/fast-glob/-/fast-glob-3.2.12.tgz", @@ -5616,9 +5626,9 @@ } }, "node_modules/node-addon-api": { - "version": "6.0.0", - "resolved": "https://registry.npmjs.org/node-addon-api/-/node-addon-api-6.0.0.tgz", - "integrity": "sha512-GyHvgPvUXBvAkXa0YvYnhilSB1A+FRYMpIVggKzPZqdaZfevZOuzfWzyvgzOwRLHBeo/MMswmJFsrNF4Nw1pmA==" + "version": "6.1.0", + "resolved": "https://registry.npmjs.org/node-addon-api/-/node-addon-api-6.1.0.tgz", + "integrity": "sha512-+eawOlIgy680F0kBzPUNFhMZGtJ1YmqM6l4+Crf4IkImjYrO/mqPwRMh352g23uIaQKFItcQ64I7KMaJxHgAVA==" }, "node_modules/node-forge": { "version": "1.3.1", @@ -6150,6 +6160,11 @@ } ] }, + "node_modules/queue-tick": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/queue-tick/-/queue-tick-1.0.1.tgz", + "integrity": "sha512-kJt5qhMxoszgU/62PLP1CJytzd2NKetjSRnyuj31fDd3Rlcz3fzlFdFLD1SItunPwyqEOkca6GbV612BWfaBag==" + }, "node_modules/randombytes": { "version": "2.1.0", "resolved": "https://registry.npmjs.org/randombytes/-/randombytes-2.1.0.tgz", @@ -6689,18 +6704,18 @@ } }, "node_modules/sharp": { - "version": "0.32.0", - "resolved": "https://registry.npmjs.org/sharp/-/sharp-0.32.0.tgz", - "integrity": "sha512-yLAypVcqj1toSAqRSwbs86nEzfyZVDYqjuUX8grhFpeij0DDNagKJXELS/auegDBRDg1XBtELdOGfo2X1cCpeA==", + "version": "0.32.6", + "resolved": "https://registry.npmjs.org/sharp/-/sharp-0.32.6.tgz", + "integrity": "sha512-KyLTWwgcR9Oe4d9HwCwNM2l7+J0dUQwn/yf7S0EnTtb0eVS4RxO0eUSvxPtzT4F3SY+C4K6fqdv/DO27sJ/v/w==", "hasInstallScript": true, "dependencies": { "color": "^4.2.3", - "detect-libc": "^2.0.1", - "node-addon-api": "^6.0.0", + "detect-libc": "^2.0.2", + "node-addon-api": "^6.1.0", "prebuild-install": "^7.1.1", - "semver": "^7.3.8", + "semver": "^7.5.4", "simple-get": "^4.0.1", - "tar-fs": "^2.1.1", + "tar-fs": "^3.0.4", "tunnel-agent": "^0.6.0" }, "engines": { @@ -6710,6 +6725,26 @@ "url": "https://opencollective.com/libvips" } }, + "node_modules/sharp/node_modules/tar-fs": { + "version": "3.0.4", + "resolved": "https://registry.npmjs.org/tar-fs/-/tar-fs-3.0.4.tgz", + "integrity": "sha512-5AFQU8b9qLfZCX9zp2duONhPmZv0hGYiBPJsyUdqMjzq/mqVpy/rEUSeHk1+YitmxugaptgBh5oDGU3VsAJq4w==", + "dependencies": { + "mkdirp-classic": "^0.5.2", + "pump": "^3.0.0", + "tar-stream": "^3.1.5" + } + }, + "node_modules/sharp/node_modules/tar-stream": { + "version": "3.1.6", + "resolved": "https://registry.npmjs.org/tar-stream/-/tar-stream-3.1.6.tgz", + "integrity": "sha512-B/UyjYwPpMBv+PaFSWAmtYjwdrlEaZQEhMIBFNC5oEG8lpiW8XjcSdmEaClj28ArfKScKHs2nshz3k2le6crsg==", + "dependencies": { + "b4a": "^1.6.4", + "fast-fifo": "^1.2.0", + "streamx": "^2.15.0" + } + }, "node_modules/shebang-command": { "version": "2.0.0", "resolved": "https://registry.npmjs.org/shebang-command/-/shebang-command-2.0.0.tgz", @@ -7026,6 +7061,15 @@ "node": ">=0.10.0" } }, + "node_modules/streamx": { + "version": "2.15.5", + "resolved": "https://registry.npmjs.org/streamx/-/streamx-2.15.5.tgz", + "integrity": "sha512-9thPGMkKC2GctCzyCUjME3yR03x2xNo0GPKGkRw2UMYN+gqWa9uqpyNWhmsNCutU5zHmkUum0LsCRQTXUgUCAg==", + "dependencies": { + "fast-fifo": "^1.1.0", + "queue-tick": "^1.0.1" + } + }, "node_modules/string_decoder": { "version": "1.3.0", "resolved": "https://registry.npmjs.org/string_decoder/-/string_decoder-1.3.0.tgz", diff --git a/package.json b/package.json index ce18ca9f3..131af0308 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "@xenova/transformers", - "version": "2.8.0", + "version": "2.10.0", "description": "State-of-the-art Machine Learning for the web. Run 🤗 Transformers directly in your browser, with no need for a server!", "main": "./src/transformers.js", "types": "./types/transformers.d.ts", @@ -10,7 +10,7 @@ "dev": "webpack serve --no-client-overlay", "build": "webpack && npm run typegen", "generate-tests": "python -m tests.generate_tests", - "test": "node --experimental-vm-modules node_modules/jest/bin/jest.js --verbose", + "test": "node --experimental-vm-modules node_modules/jest/bin/jest.js --verbose --maxConcurrency 1", "readme": "python ./docs/scripts/build_readme.py", "docs-api": "node ./docs/scripts/generate.js", "docs-preview": "doc-builder preview transformers.js ./docs/source/ --not_python_module", diff --git a/scripts/convert.py b/scripts/convert.py index ffc999bfa..9055397e8 100644 --- a/scripts/convert.py +++ b/scripts/convert.py @@ -84,7 +84,7 @@ 'vision-encoder-decoder': { 'per_channel': False, 'reduce_range': False, - } + }, } MODELS_WITHOUT_TOKENIZERS = [ @@ -326,6 +326,11 @@ def main(): with open(os.path.join(output_model_folder, 'tokenizer.json'), 'w', encoding='utf-8') as fp: json.dump(tokenizer_json, fp, indent=4) + elif config.model_type == 'owlvit': + # Override default batch size to 1, needed because non-maximum suppression is performed for exporting. + # For more information, see https://github.com/huggingface/optimum/blob/e3b7efb1257c011db907ef40ab340e795cc5684c/optimum/exporters/onnx/model_configs.py#L1028-L1032 + export_kwargs['batch_size'] = 1 + else: pass # TODO @@ -348,6 +353,25 @@ def main(): device=conv_args.device, ) + # TODO: Enable once https://github.com/huggingface/optimum/pull/1552 is merged + # elif config.model_type == 'clap' and conv_args.split_modalities: + # # Handle special case for exporting text and audio models separately + # from .extra.clap import ClapTextModelWithProjectionOnnxConfig, ClapAudioModelWithProjectionOnnxConfig + # from transformers.models.clap import ClapTextModelWithProjection, ClapAudioModelWithProjection + + # text_model = ClapTextModelWithProjection.from_pretrained(model_id) + # audio_model = ClapAudioModelWithProjection.from_pretrained(model_id) + + # export_models( + # models_and_onnx_configs={ + # "text_model": (text_model, ClapTextModelWithProjectionOnnxConfig(text_model.config)), + # "audio_model": (audio_model, ClapAudioModelWithProjectionOnnxConfig(audio_model.config)), + # }, + # output_dir=output_model_folder, + # opset=conv_args.opset, + # device=conv_args.device, + # ) + else: main_export(**export_kwargs) diff --git a/scripts/extra/clap.py b/scripts/extra/clap.py new file mode 100644 index 000000000..cd71dcad5 --- /dev/null +++ b/scripts/extra/clap.py @@ -0,0 +1,40 @@ +# TODO: Enable once https://github.com/huggingface/optimum/pull/1552 is merged + +# # Support exporting vision and text models separately: +# # Adapted from https://github.com/huggingface/optimum/issues/1186#issuecomment-1637641760 + +# from optimum.exporters.onnx.model_configs import CLAPTextWithProjectionOnnxConfig, AudioOnnxConfig +# from optimum.utils.normalized_config import NormalizedAudioConfig +# from optimum.utils.input_generators import DummyAudioInputGenerator +# from typing import Dict + + +# class ClapAudioModelWithProjectionOnnxConfig(AudioOnnxConfig): +# NORMALIZED_CONFIG_CLASS = NormalizedAudioConfig +# DUMMY_INPUT_GENERATOR_CLASSES = (DummyAudioInputGenerator, ) + +# @property +# def inputs(self) -> Dict[str, Dict[int, str]]: +# return { +# "input_features": {0: "audio_batch_size", 1: "num_channels", 2: "height", 3: "width"}, # As described in modeling_clap.py +# } + +# @property +# def outputs(self) -> Dict[str, Dict[int, str]]: +# return { +# "audio_embeds": {0: "batch_size"}, +# } + +# class ClapTextModelWithProjectionOnnxConfig(CLAPTextWithProjectionOnnxConfig): +# @property +# def outputs(self) -> Dict[str, Dict[int, str]]: +# return { +# "text_embeds": {0: "batch_size"}, +# } + +# def generate_dummy_inputs(self, framework: str = "pt", **kwargs): +# dummy_inputs = super().generate_dummy_inputs(framework=framework, **kwargs) +# if framework == "pt": +# import torch +# dummy_inputs["input_ids"] = dummy_inputs["input_ids"].to(dtype=torch.int64) +# return dummy_inputs diff --git a/scripts/supported_models.py b/scripts/supported_models.py index 7838686d0..26a46cbbd 100644 --- a/scripts/supported_models.py +++ b/scripts/supported_models.py @@ -16,6 +16,15 @@ 'sentence-transformers/paraphrase-albert-base-v2', ], }, + 'audio-spectrogram-transformer': { + # Audio classification + 'audio-classification': { + 'MIT/ast-finetuned-audioset-10-10-0.4593', + 'MIT/ast-finetuned-audioset-16-16-0.442', + 'MIT/ast-finetuned-speech-commands-v2', + 'mtg-upf/discogs-maest-30s-pw-73e-ts', + } + }, 'bart': { # Summarization 'summarization': [ @@ -175,6 +184,17 @@ 'openai/clip-vit-large-patch14-336', ] }, + 'clap': { + # Zero-shot audio classification and feature extraction + # (with and without `--split_modalities`) + 'zero-shot-audio-classification': { + 'laion/clap-htsat-unfused', + # TODO add 'laion/clap-htsat-fused', + 'laion/larger_clap_general', + 'laion/larger_clap_music_and_speech', + # 'Xenova/tiny-random-ClapModel', + } + }, 'codegen': { # Text generation 'text-generation': [ @@ -183,6 +203,49 @@ 'Salesforce/codegen-350M-nl', ], }, + 'convnext': { + # Image classification + 'image-classification': [ + 'facebook/convnext-tiny-224', + 'facebook/convnext-small-224', + 'facebook/convnext-base-224', + 'facebook/convnext-base-224-22k', + 'facebook/convnext-base-224-22k-1k', + 'facebook/convnext-base-384', + 'facebook/convnext-base-384-22k-1k', + 'facebook/convnext-large-224', + 'facebook/convnext-large-224-22k', + 'facebook/convnext-large-224-22k-1k', + 'facebook/convnext-large-384', + 'facebook/convnext-large-384-22k-1k', + 'facebook/convnext-xlarge-224-22k', + 'facebook/convnext-xlarge-224-22k-1k', + 'facebook/convnext-xlarge-384-22k-1k', + ], + }, + 'convnextv2': { + # Image classification + 'image-classification': [ + 'facebook/convnextv2-atto-1k-224', + 'facebook/convnextv2-femto-1k-224', + 'facebook/convnextv2-pico-1k-224', + 'facebook/convnextv2-tiny-1k-224', + 'facebook/convnextv2-tiny-22k-384', + 'facebook/convnextv2-tiny-22k-224', + 'facebook/convnextv2-nano-1k-224', + 'facebook/convnextv2-nano-22k-384', + 'facebook/convnextv2-base-22k-224', + 'facebook/convnextv2-base-1k-224', + 'facebook/convnextv2-base-22k-384', + 'facebook/convnextv2-large-22k-224', + 'facebook/convnextv2-large-1k-224', + 'facebook/convnextv2-large-22k-384', + # 'facebook/convnextv2-huge-22k-512', + # 'facebook/convnextv2-huge-1k-224', + # 'facebook/convnextv2-huge-22k-384', + # 'facebook/convnextv2-nano-22k-224', + ], + }, 'deberta': { # Zero-shot classification 'zero-shot-classification': [ @@ -265,7 +328,7 @@ 'distilbert-base-cased', ], }, - 'donut': { # NOTE: also a `vision-encoder-decoder` + 'donut': { # NOTE: also a `vision-encoder-decoder` # Image-to-text 'image-to-text': [ 'naver-clova-ix/donut-base-finetuned-cord-v2', @@ -277,6 +340,13 @@ 'naver-clova-ix/donut-base-finetuned-docvqa', ], }, + 'dpt': { + # Depth estimation + 'depth-estimation': [ + 'Intel/dpt-hybrid-midas', + 'Intel/dpt-large', + ], + }, 'falcon': { # Text generation 'text-generation': [ @@ -284,6 +354,13 @@ 'fxmarty/really-tiny-falcon-testing', ] }, + 'glpn': { + # Depth estimation + 'depth-estimation': [ + 'vinvino02/glpn-kitti', + 'vinvino02/glpn-nyu', + ], + }, 'gpt_neo': { # Text generation 'text-generation': [ @@ -454,6 +531,13 @@ 'google/mt5-base', ], }, + 'nougat': { + # Image-to-text + 'image-to-text': [ + 'facebook/nougat-small', + 'facebook/nougat-base', + ], + }, 'opt': { # Text generation 'text-generation': [ @@ -464,6 +548,15 @@ 'PygmalionAI/pygmalion-350m', ] }, + 'owlvit': { + # Object detection (Zero-shot object detection) + # NOTE: Exported with --batch_size 1 + 'zero-shot-object-detection': [ + 'google/owlvit-base-patch32', + 'google/owlvit-base-patch16', + 'google/owlvit-large-patch14', + ], + }, 'resnet': { # Image classification 'image-classification': [ @@ -503,7 +596,7 @@ # 'facebook/sam-vit-large', # 'facebook/sam-vit-huge', # ], - + 'speecht5': { # Text-to-audio/Text-to-speech 'text-to-audio': [ diff --git a/src/backends/onnx.js b/src/backends/onnx.js index a06beb0e3..0bee3dce7 100644 --- a/src/backends/onnx.js +++ b/src/backends/onnx.js @@ -21,7 +21,7 @@ import * as ONNX_NODE from 'onnxruntime-node'; import * as ONNX_WEB from 'onnxruntime-web'; -/** @type {module} The ONNX runtime module. */ +/** @type {import('onnxruntime-web')} The ONNX runtime module. */ export let ONNX; export const executionProviders = [ diff --git a/src/env.js b/src/env.js index a12688933..16b19aaa5 100644 --- a/src/env.js +++ b/src/env.js @@ -29,7 +29,7 @@ import url from 'url'; import { ONNX } from './backends/onnx.js'; const { env: onnx_env } = ONNX; -const VERSION = '2.8.0'; +const VERSION = '2.10.0'; // Check if various APIs are available (depends on environment) const WEB_CACHE_AVAILABLE = typeof self !== 'undefined' && 'caches' in self; diff --git a/src/models.js b/src/models.js index b0a82cee0..1c55d8877 100644 --- a/src/models.js +++ b/src/models.js @@ -42,6 +42,10 @@ import { AutoConfig, } from './configs.js'; +import { + add_token_types, +} from './tokenizers.js'; + import { Callable, isIntegralNumber, @@ -64,6 +68,7 @@ import { WhisperTimeStampLogitsProcessor, NoRepeatNGramLogitsProcessor, RepetitionPenaltyLogitsProcessor, + NoBadWordsLogitsProcessor, MinLengthLogitsProcessor, MinNewTokensLengthLogitsProcessor, @@ -82,7 +87,9 @@ import { import { executionProviders, ONNX } from './backends/onnx.js'; import { medianFilter } from './transformers.js'; -const { InferenceSession, Tensor: ONNXTensor } = ONNX; +const { InferenceSession, Tensor: ONNXTensor, env } = ONNX; + +/** @typedef {import('onnxruntime-web').InferenceSession} InferenceSession */ ////////////////////////////////////////////////// // Model types: used internally @@ -142,21 +149,31 @@ async function constructSession(pretrained_model_name_or_path, fileName, options /** * Validate model inputs * @param {InferenceSession} session The InferenceSession object that will be run. - * @param {Object} inputs The inputs to check. - * @returns {Promise} A Promise that resolves to the checked inputs. + * @param {Record} inputs The inputs to check. + * @returns {Record} The checked inputs. * @throws {Error} If any inputs are missing. * @private */ -async function validateInputs(session, inputs) { - // NOTE: Only create a shallow copy - const checkedInputs = {}; +function validateInputs(session, inputs) { + /** + * NOTE: Create either a shallow or deep copy based on `onnx.wasm.proxy` + * @type {Record} + */ + const checkedInputs = Object.create(null); const missingInputs = []; - for (let inputName of session.inputNames) { - if (inputs[inputName] === undefined) { + for (const inputName of session.inputNames) { + const tensor = inputs[inputName]; + // Rare case where one of the model's input names corresponds to a built-in + // object name (e.g., toString), which would cause a simple (!tensor) check to fail, + // because it's not undefined but a function. + if (!(tensor instanceof Tensor)) { missingInputs.push(inputName); - } else { - checkedInputs[inputName] = inputs[inputName]; + continue; } + // NOTE: When `env.wasm.proxy is true` the tensor is moved across the Worker + // boundary, transferring ownership to the worker and invalidating the tensor. + // So, in this case, we simply sacrifice a clone for it. + checkedInputs[inputName] = env.wasm.proxy ? tensor.clone() : tensor; } if (missingInputs.length > 0) { throw new Error( @@ -187,7 +204,7 @@ async function validateInputs(session, inputs) { * @private */ async function sessionRun(session, inputs) { - const checkedInputs = await validateInputs(session, inputs); + const checkedInputs = validateInputs(session, inputs); try { let output = await session.run(checkedInputs); output = replaceTensors(output); @@ -488,10 +505,15 @@ function seq2seqUpdatebeam(beam, newTokenId) { * @private */ async function encoderForward(self, model_inputs) { - let encoderFeeds = {}; - for (let key of self.session.inputNames) { + const encoderFeeds = Object.create(null); + for (const key of self.session.inputNames) { encoderFeeds[key] = model_inputs[key]; } + if (self.session.inputNames.includes('token_type_ids') && !encoderFeeds.token_type_ids) { + // Assign default `token_type_ids` to the `encoderFeeds` if the model expects it, + // but they weren't created by the tokenizer. + add_token_types(encoderFeeds); + } return await sessionRun(self.session, encoderFeeds); } @@ -836,9 +858,9 @@ export class PreTrainedModel extends Callable { // } // } - // if (generation_config.bad_words_ids !== null) { - // processors.push(new NoBadWordsLogitsProcessor(generation_config.bad_words_ids, generation_config.eos_token_id)); - // } + if (generation_config.bad_words_ids !== null) { + processors.push(new NoBadWordsLogitsProcessor(generation_config.bad_words_ids, generation_config.eos_token_id)); + } if (generation_config.min_length !== null && generation_config.eos_token_id !== null && generation_config.min_length > 0) { processors.push(new MinLengthLogitsProcessor(generation_config.min_length, generation_config.eos_token_id)); @@ -2442,6 +2464,22 @@ export class XLMRobertaForQuestionAnswering extends XLMRobertaPreTrainedModel { } ////////////////////////////////////////////////// +////////////////////////////////////////////////// +// Audio Spectrogram Transformer (AST) models +export class ASTPreTrainedModel extends PreTrainedModel { }; + +/** + * The bare AST Model transformer outputting raw hidden-states without any specific head on top. + */ +export class ASTModel extends ASTPreTrainedModel { } + +/** + * Audio Spectrogram Transformer model with an audio classification head on top + * (a linear layer on top of the pooled output) e.g. for datasets like AudioSet, Speech Commands v2. + */ +export class ASTForAudioClassification extends ASTPreTrainedModel {} +////////////////////////////////////////////////// + ////////////////////////////////////////////////// // Whisper models export class WhisperPreTrainedModel extends PreTrainedModel { }; @@ -3035,9 +3073,9 @@ export class LlamaPreTrainedModel extends PreTrainedModel { // config doesn't contain pad_token_id, so we assume it is the eos_token_id this.config.pad_token_id = this.config.eos_token_id - this.num_heads = this.config.num_attention_heads + this.num_heads = this.config.num_key_value_heads ?? this.config.num_attention_heads this.num_layers = this.config.num_hidden_layers - this.dim_kv = this.config.hidden_size / this.num_heads; + this.dim_kv = this.config.hidden_size / this.config.num_attention_heads } } /** @@ -3179,6 +3217,12 @@ export class MobileViTForImageClassification extends MobileViTPreTrainedModel { ////////////////////////////////////////////////// +////////////////////////////////////////////////// +export class OwlViTPreTrainedModel extends PreTrainedModel { } +export class OwlViTModel extends OwlViTPreTrainedModel { } +export class OwlViTForObjectDetection extends OwlViTPreTrainedModel { } +////////////////////////////////////////////////// + ////////////////////////////////////////////////// // Beit Models export class BeitPreTrainedModel extends PreTrainedModel { } @@ -3343,6 +3387,100 @@ export class Swin2SRModel extends Swin2SRPreTrainedModel { } export class Swin2SRForImageSuperResolution extends Swin2SRPreTrainedModel { } ////////////////////////////////////////////////// +////////////////////////////////////////////////// +export class DPTPreTrainedModel extends PreTrainedModel { } + +/** + * The bare DPT Model transformer outputting raw hidden-states without any specific head on top. + */ +export class DPTModel extends DPTPreTrainedModel { } + +/** + * DPT Model with a depth estimation head on top (consisting of 3 convolutional layers) e.g. for KITTI, NYUv2. + * + * **Example:** Depth estimation w/ `Xenova/dpt-hybrid-midas`. + * ```javascript + * import { DPTForDepthEstimation, AutoProcessor, RawImage, interpolate, max } from '@xenova/transformers'; + * + * // Load model and processor + * const model_id = 'Xenova/dpt-hybrid-midas'; + * const model = await DPTForDepthEstimation.from_pretrained(model_id); + * const processor = await AutoProcessor.from_pretrained(model_id); + * + * // Load image from URL + * const url = 'http://images.cocodataset.org/val2017/000000039769.jpg'; + * const image = await RawImage.fromURL(url); + * + * // Prepare image for the model + * const inputs = await processor(image); + * + * // Run model + * const { predicted_depth } = await model(inputs); + * + * // Interpolate to original size + * const prediction = interpolate(predicted_depth, image.size.reverse(), 'bilinear', false); + * + * // Visualize the prediction + * const formatted = prediction.mul_(255 / max(prediction.data)[0]).to('uint8'); + * const depth = RawImage.fromTensor(formatted); + * // RawImage { + * // data: Uint8Array(307200) [ 85, 85, 84, ... ], + * // width: 640, + * // height: 480, + * // channels: 1 + * // } + * ``` + */ +export class DPTForDepthEstimation extends DPTPreTrainedModel { } +////////////////////////////////////////////////// + +////////////////////////////////////////////////// +export class GLPNPreTrainedModel extends PreTrainedModel { } + +/** + * The bare GLPN encoder (Mix-Transformer) outputting raw hidden-states without any specific head on top. + */ +export class GLPNModel extends GLPNPreTrainedModel { } + +/** + * GLPN Model transformer with a lightweight depth estimation head on top e.g. for KITTI, NYUv2. + * + * **Example:** Depth estimation w/ `Xenova/glpn-kitti`. + * ```javascript + * import { GLPNForDepthEstimation, AutoProcessor, RawImage, interpolate, max } from '@xenova/transformers'; + * + * // Load model and processor + * const model_id = 'Xenova/glpn-kitti'; + * const model = await GLPNForDepthEstimation.from_pretrained(model_id); + * const processor = await AutoProcessor.from_pretrained(model_id); + * + * // Load image from URL + * const url = 'http://images.cocodataset.org/val2017/000000039769.jpg'; + * const image = await RawImage.fromURL(url); + * + * // Prepare image for the model + * const inputs = await processor(image); + * + * // Run model + * const { predicted_depth } = await model(inputs); + * + * // Interpolate to original size + * const prediction = interpolate(predicted_depth, image.size.reverse(), 'bilinear', false); + * + * // Visualize the prediction + * const formatted = prediction.mul_(255 / max(prediction.data)[0]).to('uint8'); + * const depth = RawImage.fromTensor(formatted); + * // RawImage { + * // data: Uint8Array(307200) [ 207, 169, 154, ... ], + * // width: 640, + * // height: 480, + * // channels: 1 + * // } + * ``` + */ +export class GLPNForDepthEstimation extends GLPNPreTrainedModel { } +////////////////////////////////////////////////// + ////////////////////////////////////////////////// export class DonutSwinPreTrainedModel extends PreTrainedModel { } @@ -3423,6 +3561,50 @@ export class DonutSwinPreTrainedModel extends PreTrainedModel { } export class DonutSwinModel extends DonutSwinPreTrainedModel { } ////////////////////////////////////////////////// + +////////////////////////////////////////////////// +export class ConvNextPreTrainedModel extends PreTrainedModel { } + +/** + * The bare ConvNext model outputting raw features without any specific head on top. + */ +export class ConvNextModel extends ConvNextPreTrainedModel { } + +/** + * ConvNext Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for ImageNet. + */ +export class ConvNextForImageClassification extends ConvNextPreTrainedModel { + /** + * @param {any} model_inputs + */ + async _call(model_inputs) { + return new SequenceClassifierOutput(await super._call(model_inputs)); + } +} +////////////////////////////////////////////////// + + +////////////////////////////////////////////////// +export class ConvNextV2PreTrainedModel extends PreTrainedModel { } + +/** + * The bare ConvNextV2 model outputting raw features without any specific head on top. + */ +export class ConvNextV2Model extends ConvNextV2PreTrainedModel { } + +/** + * ConvNextV2 Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for ImageNet. + */ +export class ConvNextV2ForImageClassification extends ConvNextV2PreTrainedModel { + /** + * @param {any} model_inputs + */ + async _call(model_inputs) { + return new SequenceClassifierOutput(await super._call(model_inputs)); + } +} +////////////////////////////////////////////////// + ////////////////////////////////////////////////// export class YolosPreTrainedModel extends PreTrainedModel { } export class YolosModel extends YolosPreTrainedModel { } @@ -3900,6 +4082,85 @@ export class FalconForCausalLM extends FalconPreTrainedModel { } ////////////////////////////////////////////////// +////////////////////////////////////////////////// +// CLAP models +export class ClapPreTrainedModel extends PreTrainedModel { } + +export class ClapModel extends ClapPreTrainedModel { } + +/** + * CLAP Text Model with a projection layer on top (a linear layer on top of the pooled output). + * + * **Example:** Compute text embeddings with `ClapTextModelWithProjection`. + * + * ```javascript + * import { AutoTokenizer, ClapTextModelWithProjection } from '@xenova/transformers'; + * + * // Load tokenizer and text model + * const tokenizer = await AutoTokenizer.from_pretrained('Xenova/clap-htsat-unfused'); + * const text_model = await ClapTextModelWithProjection.from_pretrained('Xenova/clap-htsat-unfused'); + * + * // Run tokenization + * const texts = ['a sound of a cat', 'a sound of a dog']; + * const text_inputs = tokenizer(texts, { padding: true, truncation: true }); + * + * // Compute embeddings + * const { text_embeds } = await text_model(text_inputs); + * // Tensor { + * // dims: [ 2, 512 ], + * // type: 'float32', + * // data: Float32Array(1024) [ ... ], + * // size: 1024 + * // } + * ``` + */ +export class ClapTextModelWithProjection extends ClapPreTrainedModel { + + /** @type {PreTrainedModel.from_pretrained} */ + static async from_pretrained(pretrained_model_name_or_path, options = {}) { + // Update default model file name if not provided + options.model_file_name ??= 'text_model'; + return super.from_pretrained(pretrained_model_name_or_path, options); + } +} + +/** + * CLAP Audio Model with a projection layer on top (a linear layer on top of the pooled output). + * + * **Example:** Compute audio embeddings with `ClapAudioModelWithProjection`. + * + * ```javascript + * import { AutoProcessor, ClapAudioModelWithProjection, read_audio } from '@xenova/transformers'; + * + * // Load processor and audio model + * const processor = await AutoProcessor.from_pretrained('Xenova/clap-htsat-unfused'); + * const audio_model = await ClapAudioModelWithProjection.from_pretrained('Xenova/clap-htsat-unfused'); + * + * // Read audio and run processor + * const audio = await read_audio('https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/cat_meow.wav'); + * const audio_inputs = await processor(audio); + * + * // Compute embeddings + * const { audio_embeds } = await audio_model(audio_inputs); + * // Tensor { + * // dims: [ 1, 512 ], + * // type: 'float32', + * // data: Float32Array(512) [ ... ], + * // size: 512 + * // } + * ``` + */ +export class ClapAudioModelWithProjection extends ClapPreTrainedModel { + /** @type {PreTrainedModel.from_pretrained} */ + static async from_pretrained(pretrained_model_name_or_path, options = {}) { + // Update default model file name if not provided + options.model_file_name ??= 'audio_model'; + return super.from_pretrained(pretrained_model_name_or_path, options); + } +} +////////////////////////////////////////////////// + + ////////////////////////////////////////////////// // AutoModels, used to simplify construction of PreTrainedModels // (uses config to instantiate correct class) @@ -3980,22 +4241,29 @@ const MODEL_MAPPING_NAMES_ENCODER_ONLY = new Map([ ['roberta', ['RobertaModel', RobertaModel]], ['xlm', ['XLMModel', XLMModel]], ['xlm-roberta', ['XLMRobertaModel', XLMRobertaModel]], + ['clap', ['ClapModel', ClapModel]], ['clip', ['CLIPModel', CLIPModel]], ['mobilebert', ['MobileBertModel', MobileBertModel]], ['squeezebert', ['SqueezeBertModel', SqueezeBertModel]], ['wav2vec2', ['Wav2Vec2Model', Wav2Vec2Model]], ['wavlm', ['WavLMModel', WavLMModel]], + ['audio-spectrogram-transformer', ['ASTModel', ASTModel]], ['detr', ['DetrModel', DetrModel]], ['vit', ['ViTModel', ViTModel]], ['mobilevit', ['MobileViTModel', MobileViTModel]], + ['owlvit', ['OwlViTModel', OwlViTModel]], ['beit', ['BeitModel', BeitModel]], ['deit', ['DeiTModel', DeiTModel]], + ['convnext', ['ConvNextModel', ConvNextModel]], + ['convnextv2', ['ConvNextV2Model', ConvNextV2Model]], ['resnet', ['ResNetModel', ResNetModel]], ['swin', ['SwinModel', SwinModel]], ['swin2sr', ['Swin2SRModel', Swin2SRModel]], ['donut-swin', ['DonutSwinModel', DonutSwinModel]], ['yolos', ['YolosModel', YolosModel]], + ['dpt', ['DPTModel', DPTModel]], + ['glpn', ['GLPNModel', GLPNModel]], ['hifigan', ['SpeechT5HifiGan', SpeechT5HifiGan]], @@ -4141,6 +4409,8 @@ const MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = new Map([ ['mobilevit', ['MobileViTForImageClassification', MobileViTForImageClassification]], ['beit', ['BeitForImageClassification', BeitForImageClassification]], ['deit', ['DeiTForImageClassification', DeiTForImageClassification]], + ['convnext', ['ConvNextForImageClassification', ConvNextForImageClassification]], + ['convnextv2', ['ConvNextV2ForImageClassification', ConvNextV2ForImageClassification]], ['resnet', ['ResNetForImageClassification', ResNetForImageClassification]], ['swin', ['SwinForImageClassification', SwinForImageClassification]], ]); @@ -4150,6 +4420,10 @@ const MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES = new Map([ ['yolos', ['YolosForObjectDetection', YolosForObjectDetection]], ]); +const MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES = new Map([ + ['owlvit', ['OwlViTForObjectDetection', OwlViTForObjectDetection]], +]); + const MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES = new Map([ ['detr', ['DetrForSegmentation', DetrForSegmentation]], ]); @@ -4166,12 +4440,20 @@ const MODEL_FOR_CTC_MAPPING_NAMES = new Map([ const MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = new Map([ ['wav2vec2', ['Wav2Vec2ForSequenceClassification', Wav2Vec2ForSequenceClassification]], ['wavlm', ['WavLMForSequenceClassification', WavLMForSequenceClassification]], -]); + ['audio-spectrogram-transformer', ['ASTForAudioClassification', ASTForAudioClassification]], +]); + + const MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES = new Map([ ['swin2sr', ['Swin2SRForImageSuperResolution', Swin2SRForImageSuperResolution]], ]) +const MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES = new Map([ + ['dpt', ['DPTForDepthEstimation', DPTForDepthEstimation]], + ['glpn', ['GLPNForDepthEstimation', GLPNForDepthEstimation]], +]) + const MODEL_CLASS_TYPE_MAPPING = [ [MODEL_MAPPING_NAMES_ENCODER_ONLY, MODEL_TYPES.EncoderOnly], @@ -4188,7 +4470,9 @@ const MODEL_CLASS_TYPE_MAPPING = [ [MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly], [MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly], [MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES, MODEL_TYPES.EncoderOnly], + [MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly], [MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly], + [MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly], [MODEL_FOR_MASK_GENERATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly], [MODEL_FOR_CTC_MAPPING_NAMES, MODEL_TYPES.EncoderOnly], [MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly], @@ -4207,6 +4491,9 @@ for (const [mappings, type] of MODEL_CLASS_TYPE_MAPPING) { const CUSTOM_MAPPING = [ ['CLIPTextModelWithProjection', CLIPTextModelWithProjection, MODEL_TYPES.EncoderOnly], ['CLIPVisionModelWithProjection', CLIPVisionModelWithProjection, MODEL_TYPES.EncoderOnly], + + ['ClapTextModelWithProjection', ClapTextModelWithProjection, MODEL_TYPES.EncoderOnly], + ['ClapAudioModelWithProjection', ClapAudioModelWithProjection, MODEL_TYPES.EncoderOnly], ] for (const [name, model, type] of CUSTOM_MAPPING) { MODEL_TYPE_MAPPING.set(name, type); @@ -4359,6 +4646,11 @@ export class AutoModelForObjectDetection extends PretrainedMixin { static MODEL_CLASS_MAPPINGS = [MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES]; } +export class AutoModelForZeroShotObjectDetection extends PretrainedMixin { + static MODEL_CLASS_MAPPINGS = [MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES]; +} + + /** * Helper class which is used to instantiate pretrained object detection models with the `from_pretrained` function. * The chosen model class is determined by the type specified in the model config. @@ -4386,6 +4678,10 @@ export class AutoModelForImageToImage extends PretrainedMixin { static MODEL_CLASS_MAPPINGS = [MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES]; } +export class AutoModelForDepthEstimation extends PretrainedMixin { + static MODEL_CLASS_MAPPINGS = [MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES]; +} + ////////////////////////////////////////////////// ////////////////////////////////////////////////// diff --git a/src/pipelines.js b/src/pipelines.js index d61ba25dc..7705cc190 100644 --- a/src/pipelines.js +++ b/src/pipelines.js @@ -33,8 +33,10 @@ import { AutoModelForImageClassification, AutoModelForImageSegmentation, AutoModelForObjectDetection, + AutoModelForZeroShotObjectDetection, AutoModelForDocumentQuestionAnswering, AutoModelForImageToImage, + AutoModelForDepthEstimation, // AutoModelForTextToWaveform, PreTrainedModel, } from './models.js'; @@ -50,6 +52,7 @@ import { dispatchCallback, pop, product, + get_bounding_box, } from './utils/core.js'; import { softmax, @@ -63,6 +66,7 @@ import { import { Tensor, mean_pooling, + interpolate, } from './utils/tensor.js'; import { RawImage } from './utils/image.js'; @@ -950,7 +954,7 @@ export class FeatureExtractionPipeline extends Pipeline { * Audio classification pipeline using any `AutoModelForAudioClassification`. * This pipeline predicts the class of a raw waveform or an audio file. * - * **Example:** Perform audio classification. + * **Example:** Perform audio classification with `Xenova/wav2vec2-large-xlsr-53-gender-recognition-librispeech`. * ```javascript * let url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/jfk.wav'; * let classifier = await pipeline('audio-classification', 'Xenova/wav2vec2-large-xlsr-53-gender-recognition-librispeech'); @@ -960,6 +964,19 @@ export class FeatureExtractionPipeline extends Pipeline { * // { label: 'female', score: 0.001845747814513743 } * // ] * ``` + * + * **Example:** Perform audio classification with `Xenova/ast-finetuned-audioset-10-10-0.4593` and return top 4 results. + * ```javascript + * let url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/cat_meow.wav'; + * let classifier = await pipeline('audio-classification', 'Xenova/ast-finetuned-audioset-10-10-0.4593'); + * let output = await classifier(url, { topk: 4 }); + * // [ + * // { label: 'Meow', score: 0.5617874264717102 }, + * // { label: 'Cat', score: 0.22365376353263855 }, + * // { label: 'Domestic animals, pets', score: 0.1141069084405899 }, + * // { label: 'Animal', score: 0.08985692262649536 }, + * // ] + * ``` */ export class AudioClassificationPipeline extends Pipeline { @@ -1035,6 +1052,105 @@ export class AudioClassificationPipeline extends Pipeline { } } +/** + * Zero shot audio classification pipeline using `ClapModel`. This pipeline predicts the class of an audio when you + * provide an audio and a set of `candidate_labels`. + * + * **Example**: Perform zero-shot audio classification with `Xenova/clap-htsat-unfused`. + * ```javascript + * let classifier = await pipeline('zero-shot-audio-classification', 'Xenova/clap-htsat-unfused'); + * let audio = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/dog_barking.wav'; + * let candidate_labels = ['dog', 'vaccum cleaner']; + * let scores = await classifier(audio, candidate_labels); + * // [ + * // { score: 0.9993992447853088, label: 'dog' }, + * // { score: 0.0006007603369653225, label: 'vaccum cleaner' } + * // ] + * ``` + */ +export class ZeroShotAudioClassificationPipeline extends Pipeline { + + /** + * Create a new ZeroShotAudioClassificationPipeline. + * @param {Object} options An object containing the following properties: + * @param {string} [options.task] The task of the pipeline. Useful for specifying subtasks. + * @param {PreTrainedModel} [options.model] The model to use. + * @param {PreTrainedTokenizer} [options.tokenizer] The tokenizer to use. + * @param {Processor} [options.processor] The processor to use. + */ + constructor(options) { + super(options); + } + + /** + * Preprocesses the input audio for the ZeroShotAudioClassificationPipeline. + * @param {any} audio The audio to be preprocessed. + * @param {number} sampling_rate The sampling rate of the audio. + * @returns {Promise} A promise that resolves to the preprocessed audio data. + * @private + */ + async _preprocess(audio, sampling_rate) { + if (isString(audio)) { + audio = await read_audio(audio, sampling_rate); + } + + return audio; + } + + /** + * Assign labels to the audio(s) passed as inputs. + * @param {Array} audios The input audios. + * @param {string[]} candidate_labels The candidate labels for this audio + * @param {Object} options The options for the classification. + * @param {string} [options.hypothesis_template] The sentence used in cunjunction with *candidate_labels* to attempt + * the audio classification by replacing the placeholder with the candidate_labels. + * Then likelihood is estimated by using logits_per_audio. + * @returns {Promise} + */ + async _call(audios, candidate_labels, { + hypothesis_template = "This is a sound of {}." + } = {}) { + const single = !Array.isArray(audios); + if (single) { + // @ts-ignore + audios = [audios]; + } + + // Insert label into hypothesis template + const texts = candidate_labels.map( + x => hypothesis_template.replace('{}', x) + ); + + // Run tokenization + const text_inputs = this.tokenizer(texts, { + padding: true, + truncation: true, + }); + + const sampling_rate = this.processor.feature_extractor.config.sampling_rate; + + const toReturn = []; + for (let audio of audios) { + audio = await this._preprocess(audio, sampling_rate) + + const audio_inputs = await this.processor(audio); + + // Run model with both text and audio inputs + const output = await this.model({ ...text_inputs, ...audio_inputs }); + + // Compute softmax per audio + const probs = softmax(output.logits_per_audio.data); + + toReturn.push([...probs].map((x, i) => { + return { + score: x, + label: candidate_labels[i] + } + })); + } + return !single ? toReturn : toReturn[0]; + } +} /** * Pipeline that aims at extracting spoken text contained within some audio. @@ -1751,28 +1867,148 @@ export class ObjectDetectionPipeline extends Pipeline { return { score: batch.scores[i], label: id2label[batch.classes[i]], - box: this._get_bounding_box(box, !percentage), + box: get_bounding_box(box, !percentage), } }) }) return isBatched ? result : result[0]; } +} + +/** + * Zero-shot object detection pipeline. This pipeline predicts bounding boxes of + * objects when you provide an image and a set of `candidate_labels`. + * + * **Example:** Zero-shot object detection w/ `Xenova/clip-vit-base-patch32`. + * ```javascript + * let url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/astronaut.png'; + * let candidate_labels = ['human face', 'rocket', 'helmet', 'american flag']; + * let detector = await pipeline('zero-shot-object-detection', 'Xenova/owlvit-base-patch32'); + * let output = await detector(url, candidate_labels); + * // [ + * // { + * // score: 0.24392342567443848, + * // label: 'human face', + * // box: { xmin: 180, ymin: 67, xmax: 274, ymax: 175 } + * // }, + * // { + * // score: 0.15129457414150238, + * // label: 'american flag', + * // box: { xmin: 0, ymin: 4, xmax: 106, ymax: 513 } + * // }, + * // { + * // score: 0.13649864494800568, + * // label: 'helmet', + * // box: { xmin: 277, ymin: 337, xmax: 511, ymax: 511 } + * // }, + * // { + * // score: 0.10262022167444229, + * // label: 'rocket', + * // box: { xmin: 352, ymin: -1, xmax: 463, ymax: 287 } + * // } + * // ] + * ``` + * + * **Example:** Zero-shot object detection w/ `Xenova/clip-vit-base-patch32` (returning top 4 matches and setting a threshold). + * ```javascript + * let url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/beach.png'; + * let candidate_labels = ['hat', 'book', 'sunglasses', 'camera']; + * let detector = await pipeline('zero-shot-object-detection', 'Xenova/owlvit-base-patch32'); + * let output = await detector(url, candidate_labels, { topk: 4, threshold: 0.05 }); + * // [ + * // { + * // score: 0.1606510728597641, + * // label: 'sunglasses', + * // box: { xmin: 347, ymin: 229, xmax: 429, ymax: 264 } + * // }, + * // { + * // score: 0.08935828506946564, + * // label: 'hat', + * // box: { xmin: 38, ymin: 174, xmax: 258, ymax: 364 } + * // }, + * // { + * // score: 0.08530698716640472, + * // label: 'camera', + * // box: { xmin: 187, ymin: 350, xmax: 260, ymax: 411 } + * // }, + * // { + * // score: 0.08349756896495819, + * // label: 'book', + * // box: { xmin: 261, ymin: 280, xmax: 494, ymax: 425 } + * // } + * // ] + * ``` + */ +export class ZeroShotObjectDetectionPipeline extends Pipeline { /** - * Helper function to convert list [xmin, xmax, ymin, ymax] into object { "xmin": xmin, ... } - * @param {number[]} box The bounding box as a list. - * @param {boolean} asInteger Whether to cast to integers. - * @returns {Object} The bounding box as an object. - * @private + * Create a new ZeroShotObjectDetectionPipeline. + * @param {Object} options An object containing the following properties: + * @param {string} [options.task] The task of the pipeline. Useful for specifying subtasks. + * @param {PreTrainedModel} [options.model] The model to use. + * @param {PreTrainedTokenizer} [options.tokenizer] The tokenizer to use. + * @param {Processor} [options.processor] The processor to use. */ - _get_bounding_box(box, asInteger) { - if (asInteger) { - box = box.map(x => x | 0); + constructor(options) { + super(options); + } + + /** + * Detect objects (bounding boxes & classes) in the image(s) passed as inputs. + * @param {Array} images The input images. + * @param {string[]} candidate_labels What the model should recognize in the image. + * @param {Object} options The options for the classification. + * @param {number} [options.threshold] The probability necessary to make a prediction. + * @param {number} [options.topk] The number of top predictions that will be returned by the pipeline. + * If the provided number is `null` or higher than the number of predictions available, it will default + * to the number of predictions. + * @param {boolean} [options.percentage=false] Whether to return the boxes coordinates in percentage (true) or in pixels (false). + * @returns {Promise} An array of classifications for each input image or a single classification object if only one input image is provided. + */ + async _call(images, candidate_labels, { + threshold = 0.1, + topk = null, + percentage = false, + } = {}) { + const isBatched = Array.isArray(images); + images = await prepareImages(images); + + // Run tokenization + const text_inputs = this.tokenizer(candidate_labels, { + padding: true, + truncation: true + }); + + // Run processor + const model_inputs = await this.processor(images); + + // Since non-maximum suppression is performed for exporting, we need to + // process each image separately. For more information, see: + // https://github.com/huggingface/optimum/blob/e3b7efb1257c011db907ef40ab340e795cc5684c/optimum/exporters/onnx/model_configs.py#L1028-L1032 + const toReturn = []; + for (let i = 0; i < images.length; ++i) { + const image = images[i]; + const imageSize = [[image.height, image.width]]; + const pixel_values = model_inputs.pixel_values[i].unsqueeze_(0); + + // Run model with both text and pixel inputs + const output = await this.model({ ...text_inputs, pixel_values }); + + // @ts-ignore + const processed = this.processor.feature_extractor.post_process_object_detection(output, threshold, imageSize, true)[0]; + let result = processed.boxes.map((box, i) => ({ + score: processed.scores[i], + label: candidate_labels[processed.classes[i]], + box: get_bounding_box(box, !percentage), + })).sort((a, b) => b.score - a.score); + if (topk !== null) { + result = result.slice(0, topk); + } + toReturn.push(result) } - const [xmin, ymin, xmax, ymax] = box; - return { xmin, ymin, xmax, ymax }; + return isBatched ? toReturn : toReturn[0]; } } @@ -1984,6 +2220,56 @@ export class ImageToImagePipeline extends Pipeline { } } +/** + * Depth estimation pipeline using any `AutoModelForDepthEstimation`. This pipeline predicts the depth of an image. + * + * **Example:** Depth estimation w/ `Xenova/dpt-hybrid-midas` + * ```javascript + * let url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/cats.jpg'; + * let depth_estimator = await pipeline('depth-estimation', 'Xenova/dpt-hybrid-midas'); + * let out = await depth_estimator(url); + * // { + * // predicted_depth: Tensor { + * // dims: [ 384, 384 ], + * // type: 'float32', + * // data: Float32Array(147456) [ 542.859130859375, 545.2833862304688, 546.1649169921875, ... ], + * // size: 147456 + * // }, + * // depth: RawImage { + * // data: Uint8Array(307200) [ 86, 86, 86, ... ], + * // width: 640, + * // height: 480, + * // channels: 1 + * // } + * // } + * ``` + */ +export class DepthEstimationPipeline extends Pipeline { + /** + * Predicts the depth for the image(s) passed as inputs. + * @param {any} images The images to compute depth for. + * @returns {Promise} An image or a list of images containing result(s). + */ + async _call(images) { + images = await prepareImages(images); + + const inputs = await this.processor(images); + const { predicted_depth } = await this.model(inputs); + + const toReturn = []; + for (let i = 0; i < images.length; ++i) { + const prediction = interpolate(predicted_depth[i], images[i].size.reverse(), 'bilinear', false); + const formatted = prediction.mul_(255 / max(prediction.data)[0]).to('uint8'); + toReturn.push({ + predicted_depth: predicted_depth[i], + depth: RawImage.fromTensor(formatted), + }); + } + + return toReturn.length > 1 ? toReturn : toReturn[0]; + } +} + const SUPPORTED_TASKS = { "text-classification": { "tokenizer": AutoTokenizer, @@ -2096,6 +2382,18 @@ const SUPPORTED_TASKS = { }, "type": "audio", }, + "zero-shot-audio-classification": { + "tokenizer": AutoTokenizer, + "pipeline": ZeroShotAudioClassificationPipeline, + "model": AutoModel, + "processor": AutoProcessor, + "default": { + // TODO: replace with original + // "model": "laion/clap-htsat-fused", + "model": "Xenova/clap-htsat-unfused", + }, + "type": "multimodal", + }, "automatic-speech-recognition": { "tokenizer": AutoTokenizer, "pipeline": AutomaticSpeechRecognitionPipeline, @@ -2185,6 +2483,18 @@ const SUPPORTED_TASKS = { }, "type": "multimodal", }, + "zero-shot-object-detection": { + "tokenizer": AutoTokenizer, + "pipeline": ZeroShotObjectDetectionPipeline, + "model": AutoModelForZeroShotObjectDetection, + "processor": AutoProcessor, + "default": { + // TODO: replace with original + // "model": "google/owlvit-base-patch32", + "model": "Xenova/owlvit-base-patch32", + }, + "type": "multimodal", + }, "document-question-answering": { "tokenizer": AutoTokenizer, "pipeline": DocumentQuestionAnsweringPipeline, @@ -2209,6 +2519,18 @@ const SUPPORTED_TASKS = { }, "type": "image", }, + "depth-estimation": { + // no tokenizer + "pipeline": DepthEstimationPipeline, + "model": AutoModelForDepthEstimation, + "processor": AutoProcessor, + "default": { + // TODO: replace with original + // "model": "Intel/dpt-large", + "model": "Xenova/dpt-large", + }, + "type": "image", + }, // This task serves as a useful interface for dealing with sentence-transformers (https://huggingface.co/sentence-transformers). "feature-extraction": { @@ -2242,6 +2564,7 @@ const TASK_ALIASES = { * @param {string} task The task defining which pipeline will be returned. Currently accepted tasks are: * - `"audio-classification"`: will return a `AudioClassificationPipeline`. * - `"automatic-speech-recognition"`: will return a `AutomaticSpeechRecognitionPipeline`. + * - `"depth-estimation"`: will return a `DepthEstimationPipeline`. * - `"document-question-answering"`: will return a `DocumentQuestionAnsweringPipeline`. * - `"feature-extraction"`: will return a `FeatureExtractionPipeline`. * - `"fill-mask"`: will return a `FillMaskPipeline`. @@ -2258,7 +2581,9 @@ const TASK_ALIASES = { * - `"translation"`: will return a `TranslationPipeline`. * - `"translation_xx_to_yy"`: will return a `TranslationPipeline`. * - `"zero-shot-classification"`: will return a `ZeroShotClassificationPipeline`. + * - `"zero-shot-audio-classification"`: will return a `ZeroShotAudioClassificationPipeline`. * - `"zero-shot-image-classification"`: will return a `ZeroShotImageClassificationPipeline`. + * - `"zero-shot-object-detection"`: will return a `ZeroShotObjectDetectionPipeline`. * @param {string} [model=null] The name of the pre-trained model to use. If not specified, the default model for the task will be used. * @param {import('./utils/hub.js').PretrainedOptions} [options] Optional parameters for the pipeline. * @returns {Promise} A Pipeline object for the specified task. diff --git a/src/processors.js b/src/processors.js index 7691cb413..a6f1351d1 100644 --- a/src/processors.js +++ b/src/processors.js @@ -30,16 +30,20 @@ import { } from './utils/hub.js'; import { + min, max, softmax, - FFT, } from './utils/maths.js'; import { Tensor, transpose, cat, interpolate } from './utils/tensor.js'; import { RawImage } from './utils/image.js'; -import { getMelFilters } from './utils/audio.js'; +import { + window_function, + spectrogram, + mel_filter_bank, +} from './utils/audio.js'; // Helper functions @@ -64,9 +68,13 @@ function center_to_corners_format([centerX, centerY, width, height]) { * @param {Object} outputs The outputs of the model that must be post-processed * @param {Tensor} outputs.logits The logits * @param {Tensor} outputs.pred_boxes The predicted boxes. + * @param {number} [threshold=0.5] The threshold to use for the scores. + * @param {number[][]} [target_sizes=null] The sizes of the original images. + * @param {boolean} [is_zero_shot=false] Whether zero-shot object detection was performed. * @return {Object[]} An array of objects containing the post-processed outputs. + * @private */ -function post_process_object_detection(outputs, threshold = 0.5, target_sizes = null) { +function post_process_object_detection(outputs, threshold = 0.5, target_sizes = null, is_zero_shot = false) { const out_logits = outputs.logits; const out_bbox = outputs.pred_boxes; const [batch_size, num_boxes, num_classes] = out_logits.dims; @@ -88,19 +96,33 @@ function post_process_object_detection(outputs, threshold = 0.5, target_sizes = for (let j = 0; j < num_boxes; ++j) { let logit = logits[j]; - // Get most probable class - let maxIndex = max(logit.data)[1]; + let indices = []; + let probs; + if (is_zero_shot) { + // Get indices of classes with high enough probability + probs = logit.sigmoid().data; + for (let k = 0; k < probs.length; ++k) { + if (probs[k] > threshold) { + indices.push(k); + } + } - if (maxIndex === num_classes - 1) { - // This is the background class, skip it - continue; + } else { + // Get most probable class + let maxIndex = max(logit.data)[1]; + + if (maxIndex === num_classes - 1) { + // This is the background class, skip it + continue; + } + indices.push(maxIndex); + + // Compute softmax over classes + probs = softmax(logit.data); } - // Compute softmax over classes - let probs = softmax(logit.data); + for (const index of indices) { - let score = probs[maxIndex]; - if (score > threshold) { // Some class has a high enough probability /** @type {number[]} */ let box = bbox[j].data; @@ -112,8 +134,8 @@ function post_process_object_detection(outputs, threshold = 0.5, target_sizes = } info.boxes.push(box); - info.classes.push(maxIndex); - info.scores.push(score); + info.classes.push(index); + info.scores.push(probs[index]); } } toReturn.push(info); @@ -127,6 +149,21 @@ function post_process_object_detection(outputs, threshold = 0.5, target_sizes = * @typedef {[height: number, width: number]} HeightWidth */ +/** + * Helper function to validate audio inputs. + * @param {any} audio The audio data. + * @param {string} feature_extractor The name of the feature extractor. + * @private + */ +function validate_audio_inputs(audio, feature_extractor) { + if (!(audio instanceof Float32Array || audio instanceof Float64Array)) { + throw new Error( + `${feature_extractor} expects input to be a Float32Array or a Float64Array, but got ${audio?.constructor?.name ?? typeof audio} instead.` + + `If using the feature extractor directly, remember to use \`read_audio(url, sampling_rate)\` to obtain the raw audio data of the file/url.` + ) + } +} + /** * Base class for feature extractors. * @@ -185,10 +222,12 @@ export class ImageFeatureExtractor extends FeatureExtractor { this.do_resize = this.config.do_resize; this.do_thumbnail = this.config.do_thumbnail; this.size = this.config.size; + this.size_divisor = this.config.size_divisor; this.do_center_crop = this.config.do_center_crop; this.crop_size = this.config.crop_size; this.do_convert_rgb = this.config.do_convert_rgb ?? true; + this.do_crop_margin = this.config.do_crop_margin; this.pad_size = this.config.pad_size; this.do_pad = this.config.do_pad; @@ -231,6 +270,44 @@ export class ImageFeatureExtractor extends FeatureExtractor { } + /** + * Crops the margin of the image. Gray pixels are considered margin (i.e., pixels with a value below the threshold). + * @param {RawImage} image The image to be cropped. + * @param {number} gray_threshold Value below which pixels are considered to be gray. + * @returns {Promise} The cropped image. + */ + async crop_margin(image, gray_threshold = 200) { + + const gray_image = image.clone().grayscale(); + + const minValue = min(gray_image.data)[0]; + const maxValue = max(gray_image.data)[0]; + const diff = maxValue - minValue; + + if (diff === 0) { + return image; + } + + const threshold = gray_threshold / 255; + + let x_min = gray_image.width, y_min = gray_image.height, x_max = 0, y_max = 0; + for (let j = 0; j < gray_image.height; ++j) { + const row = j * gray_image.width; + for (let i = 0; i < gray_image.width; ++i) { + if ((gray_image.data[row + i] - minValue) / diff < threshold) { + // We have a non-zero pixel, so we update the min/max values accordingly + x_min = Math.min(x_min, i); + y_min = Math.min(y_min, j); + x_max = Math.max(x_max, i); + y_max = Math.max(y_max, j); + } + } + } + + image = await image.crop([x_min, y_min, x_max, y_max]); + return image; + } + /** * Pad the image by a certain amount. * @param {Float32Array} pixelData The pixel data to pad. @@ -261,7 +338,12 @@ export class ImageFeatureExtractor extends FeatureExtractor { // Only add padding if there is a difference in size if (paddedImageWidth !== imageWidth || paddedImageHeight !== imageHeight) { const paddedPixelData = new Float32Array(paddedImageWidth * paddedImageHeight * imageChannels); - if (constant_values !== 0) { + if (Array.isArray(constant_values)) { + // Fill with constant values, cycling through the array + for (let i = 0; i < paddedPixelData.length; ++i) { + paddedPixelData[i] = constant_values[i % imageChannels]; + } + } else if (constant_values !== 0) { paddedPixelData.fill(constant_values); } @@ -329,15 +411,21 @@ export class ImageFeatureExtractor extends FeatureExtractor { */ async preprocess(image) { - // First, convert image to RGB if specified in config. - if (this.do_convert_rgb) { - image = image.rgb(); + if (this.do_crop_margin) { + // NOTE: Specific to nougat processors. This is done before resizing, + // and can be interpreted as a pre-preprocessing step. + image = await this.crop_margin(image); } const srcWidth = image.width; // original width const srcHeight = image.height; // original height - // Next, resize all images + // Convert image to RGB if specified in config. + if (this.do_convert_rgb) { + image = image.rgb(); + } + + // Resize all images if (this.do_resize) { // TODO: // For efficiency reasons, it might be best to merge the resize and center crop operations into one. @@ -358,7 +446,7 @@ export class ImageFeatureExtractor extends FeatureExtractor { shortest_edge = this.size; longest_edge = this.config.max_size ?? shortest_edge; - } else { + } else if (this.size !== undefined) { // Extract known properties from `this.size` shortest_edge = this.size.shortest_edge; longest_edge = this.size.longest_edge; @@ -391,11 +479,20 @@ export class ImageFeatureExtractor extends FeatureExtractor { resample: this.resample, }); - } else if (this.size.width !== undefined && this.size.height !== undefined) { + } else if (this.size !== undefined && this.size.width !== undefined && this.size.height !== undefined) { // If `width` and `height` are set, resize to those dimensions image = await image.resize(this.size.width, this.size.height, { resample: this.resample, }); + + } else if (this.size_divisor !== undefined) { + // Rounds the height and width down to the closest multiple of size_divisor + const newWidth = Math.floor(srcWidth / this.size_divisor) * this.size_divisor; + const newHeight = Math.floor(srcHeight / this.size_divisor) * this.size_divisor; + image = await image.resize(newWidth, newHeight, { + resample: this.resample, + }); + } else { throw new Error(`Could not resize image due to unsupported \`this.size\` option in config: ${JSON.stringify(this.size)}`); } @@ -509,24 +606,48 @@ export class ImageFeatureExtractor extends FeatureExtractor { } +export class DPTFeatureExtractor extends ImageFeatureExtractor { } +export class GLPNFeatureExtractor extends ImageFeatureExtractor { } +export class CLIPFeatureExtractor extends ImageFeatureExtractor { } export class ConvNextFeatureExtractor extends ImageFeatureExtractor { } +export class ConvNextImageProcessor extends ConvNextFeatureExtractor { } // NOTE extends ConvNextFeatureExtractor export class ViTFeatureExtractor extends ImageFeatureExtractor { } export class MobileViTFeatureExtractor extends ImageFeatureExtractor { } +export class OwlViTFeatureExtractor extends ImageFeatureExtractor { + /** @type {post_process_object_detection} */ + post_process_object_detection(...args) { + return post_process_object_detection(...args); + } +} export class DeiTFeatureExtractor extends ImageFeatureExtractor { } export class BeitFeatureExtractor extends ImageFeatureExtractor { } export class DonutFeatureExtractor extends ImageFeatureExtractor { pad_image(pixelData, imgDims, padSize, options = {}) { + const [imageWidth, imageHeight, imageChannels] = imgDims; + + let image_mean = this.image_mean; + if (!Array.isArray(this.image_mean)) { + image_mean = new Array(imageChannels).fill(image_mean); + } + + let image_std = this.image_std; + if (!Array.isArray(this.image_std)) { + image_std = new Array(imageChannels).fill(image_mean); + } + + const constant_values = image_mean.map((x, i) => - x / this.image_std[i]); + return super.pad_image(pixelData, imgDims, padSize, { center: true, - // Since normalization is done after padding, we need to pad with -1. - // NOTE: This only works if `image_mean = 0.5` and `image_std = 0.5`. + // Since normalization is done after padding, we need to use certain constant values to ensure the same behaviour is observed. // For more information, see https://github.com/huggingface/transformers/blob/main/src/transformers/models/donut/image_processing_donut.py#L433-L451 - constant_values: -1, + constant_values: constant_values, ...options, }); } } +export class NougatImageProcessor extends DonutFeatureExtractor { } // NOTE extends DonutFeatureExtractor /** * @typedef {object} DetrFeatureExtractorResultProps @@ -1024,232 +1145,24 @@ export class Swin2SRImageProcessor extends ImageFeatureExtractor { } } + export class WhisperFeatureExtractor extends FeatureExtractor { constructor(config) { super(config); // Prefer given `mel_filters` from preprocessor_config.json, or calculate them if they don't exist. - this.config.mel_filters ??= getMelFilters(this.config.sampling_rate, this.config.n_fft, this.config.feature_size); - } - - - /** - * Pads an array with a reflected version of itself on both ends. - * @param {Float32Array} array The array to pad. - * @param {number} left The amount of padding to add to the left. - * @param {number} right The amount of padding to add to the right. - * @returns {Float32Array} The padded array. - */ - padReflect(array, left, right) { - const padded = new Float32Array(array.length + left + right); - const w = array.length - 1; - - for (let i = 0; i < array.length; ++i) { - padded[left + i] = array[i]; - } - - for (let i = 1; i <= left; ++i) { - padded[left - i] = array[calculateReflectOffset(i, w)]; - } - - for (let i = 1; i <= right; ++i) { - padded[w + left + i] = array[calculateReflectOffset(w - i, w)]; - } - - return padded; - } - - /** - * Calculates the complex Short-Time Fourier Transform (STFT) of the given framed signal. - * - * @param {number[][]} frames A 2D array representing the signal frames. - * @param {number[]} window A 1D array representing the window to be applied to the frames. - * @returns {Object} An object with the following properties: - * - data: A 1D array representing the complex STFT of the signal. - * - dims: An array representing the dimensions of the STFT data, i.e. [num_frames, num_fft_bins]. - */ - stft(frames, window) { - // Calculates the complex Short-Time Fourier Transform (STFT) of the given framed signal. - // - // NOTE: Since the window width is not a power of 2, we must - // perform Fast Fourier Transform with chirp-z transform: - // https://math.stackexchange.com/questions/77118/non-power-of-2-ffts/77156#77156 - - // Helper variables - const fft_size = this.config.n_fft; - const a = 2 * (fft_size - 1); - const b = 2 * (2 * fft_size - 1); - const nextP2 = 2 ** (Math.ceil(Math.log2(b))) - const num_fft_bins = fft_size + 2; - - // Preallocate array to store output - // double since we store complex numbers - const data = new Float32Array(num_fft_bins * frames.length); - - // Define buffers - // Compute chirp for transform - const chirp = new Float32Array(b); - const ichirp = new Float32Array(nextP2); - const buffer1 = new Float32Array(nextP2); - const buffer2 = new Float32Array(nextP2); - const outBuffer = new Float32Array(nextP2); - const outBuffer2 = new Float32Array(nextP2); - const outBuffer3 = new Float32Array(nextP2); - - // Compute complex exponentiation - const theta = -2 * Math.PI / fft_size; - const baseR = Math.cos(theta); - const baseI = Math.sin(theta); - - // Precompute helper for chirp-z transform - for (let i = 0; i < b >> 1; ++i) { - // Compute complex power: - const e = (i + 1 - fft_size) ** 2 / 2.0; - - // Compute the modulus and argument of the result - const result_mod = Math.sqrt(baseR ** 2 + baseI ** 2) ** e; - const result_arg = e * Math.atan2(baseI, baseR); - - // Convert the result back to rectangular form - // and assign to chirp and ichirp - let i2 = 2 * i; - chirp[i2] = result_mod * Math.cos(result_arg); - chirp[i2 + 1] = result_mod * Math.sin(result_arg); - - // conjugate - ichirp[i2] = chirp[i2]; - ichirp[i2 + 1] = - chirp[i2 + 1]; - } - const slicedChirp = chirp.subarray(a, b); - - // create object to perform Fast Fourier Transforms - // with `nextP2` complex numbers - const f = new FFT(nextP2 >> 1); - // TODO: decide between Float32Array and Float64Array - f.transform(outBuffer, ichirp); - - for (let i = 0; i < frames.length; ++i) { - const frame = frames[i]; - - for (let j = 0; j < slicedChirp.length; j += 2) { - const j2 = j + 1 - const j3 = j >> 1; - - const a_real = frame[j3] * window[j3]; - buffer1[j] = a_real * slicedChirp[j]; - buffer1[j2] = a_real * slicedChirp[j2]; - } - // TODO: decide between Float32Array and Float64Array - f.transform(outBuffer2, buffer1); - - for (let j = 0; j < outBuffer.length; j += 2) { - const j2 = j + 1; - - buffer2[j] = outBuffer2[j] * outBuffer[j] - outBuffer2[j2] * outBuffer[j2] - buffer2[j2] = outBuffer2[j] * outBuffer[j2] + outBuffer2[j2] * outBuffer[j] - } - // TODO: decide between Float32Array and Float64Array - f.inverseTransform(outBuffer3, buffer2) - - const offset = i * num_fft_bins; - for (let j = 0; j < num_fft_bins; j += 2) { - const a_real = outBuffer3[j + a]; - const a_imag = outBuffer3[j + a + 1]; - const b_real = slicedChirp[j]; - const b_imag = slicedChirp[j + 1]; - - // TODO write as transpose - const o1 = offset + j; - data[o1] = a_real * b_real - a_imag * b_imag - data[o1 + 1] = a_real * b_imag + a_imag * b_real - } - } - - return { - data: data, - dims: [frames.length, num_fft_bins] // [3001, 402] - }; - } - - /** - * Creates an array of frames from a given waveform. - * - * @param {Float32Array} waveform The waveform to create frames from. - * @param {boolean} [center=true] Whether to center the frames on their corresponding positions in the waveform. Defaults to true. - * @returns {Array} An array of frames. - */ - fram_wave(waveform, center = true) { - const frames = []; - const half_window = Math.floor((this.config.n_fft - 1) / 2) + 1; - const waveformLength = waveform.length; - - for (let i = 0; i < waveformLength + 1; i += this.config.hop_length) { - - let frame; - if (center) { - - let frameStart = i > half_window ? i - half_window : 0; - let frameEnd = - i < waveformLength - half_window - ? i + half_window - : waveformLength; - - frame = waveform.subarray(frameStart, frameEnd) - - if (frameStart === 0) { - frame = this.padReflect( - frame, - -i + half_window, - 0 - ) - - } else if (frameEnd === waveformLength) { - frame = this.padReflect( - frame, - 0, - i - waveformLength + half_window - ) - } - - } else { - frame = new Float32Array(this.config.n_fft); - const frameArray = waveform.subarray(i, i + this.config.n_fft); - - if (frameArray.length < this.config.n_fft) { - frame.set(frameArray); - frame.fill(0, frameArray.length, this.config.n_fft) - } else { - frame = frameArray; - } - - } - frames.push(frame); - } + this.config.mel_filters ??= mel_filter_bank( + Math.floor(1 + this.config.n_fft / 2), // num_frequency_bins + this.config.feature_size, // num_mel_filters + 0.0, // min_frequency + 8000.0, // max_frequency + this.config.sampling_rate, // sampling_rate + "slaney", // norm + "slaney", // mel_scale + ); - return frames; - } - - /** - * Generates a Hanning window of length M. - * - * @param {number} M The length of the Hanning window to generate. - * @returns {*} The generated Hanning window. - */ - hanning(M) { - if (M < 1) { - return []; - } - if (M === 1) { - return [1]; - } - const denom = M - 1; - const cos_vals = new Float32Array(denom); - for (let i = 0; i < denom; ++i) { - const n = 2 * i - M + 1; - cos_vals[i] = 0.5 + 0.5 * Math.cos(Math.PI * n / denom); - } - return cos_vals; + this.window = window_function(this.config.n_fft, 'hann'); } /** @@ -1258,80 +1171,28 @@ export class WhisperFeatureExtractor extends FeatureExtractor { * @returns {{data: Float32Array, dims: number[]}} An object containing the log-Mel spectrogram data as a Float32Array and its dimensions as an array of numbers. */ _extract_fbank_features(waveform) { - // Compute the log-Mel spectrogram of the provided audio - - const buffer = new Float32Array(this.config.n_samples); - buffer.set(waveform) - - const window = this.hanning(this.config.n_fft + 1) - const frames = this.fram_wave(buffer) - - const stft = this.stft(frames, window) - - const stftData = stft.data; - const d1 = stft.dims[0] - 1; // Ignore last row - const d2 = stft.dims[1] >> 1; // Only need to store real numbers now - - // compute magnitudes - // NOTE: Unlike the original implementation, we do not - // transpose since we perform matrix multiplication later - const magnitudes = new Float32Array(d1 * d2); - for (let i = 0; i < d1; ++i) { - for (let j = 0; j < d2; ++j) { - // let outOffset = (j * d1 + i); // transpose - let outOffset = i * d2 + j; - let inOffset = outOffset << 1; // * 2 since complex - let magnitude = stftData[inOffset] ** 2 + stftData[inOffset + 1] ** 2 - magnitudes[outOffset] = magnitude; + const { data, dims } = spectrogram( + waveform, + this.window, // window + this.config.n_fft, // frame_length + this.config.hop_length, // hop_length + { + power: 2.0, + mel_filters: this.config.mel_filters, + log_mel: 'log10', + + // Custom + max_num_frames: this.config.nb_max_frames, // 3000 } - } - - const mel_filters = this.config.mel_filters; - const num_mel_filters = mel_filters.length; - - const mel_spec = new Float32Array(num_mel_filters * d1); - let mIndex = 0; - - // Perform matrix muliplication: - // mel_spec = filters @ magnitudes - // - filters.shape=(80, 201) - // - magnitudes.shape=(201, 3000) - // - mel_spec.shape=(80, 3000) - for (let i = 0; i < num_mel_filters; ++i) { - const mel_filter = mel_filters[i]; - - for (let j = 0; j < d1; ++j) { - let sum = 0; - - // perform dot product - for (let k = 0; k < d2; ++k) { - sum += mel_filter[k] * magnitudes[j * d2 + k]; - } - - mel_spec[mIndex++] = sum; - } - } - - const a_min = 1e-10; - const log_spec = new Float32Array(mel_spec.length); + ) - let maxLogSpec = 0; - for (let i = 0; i < mel_spec.length; ++i) { - const clipped = Math.max(a_min, mel_spec[i]); - const log10 = Math.log10(clipped); - log_spec[i] = log10; - maxLogSpec = Math.max(log10, maxLogSpec) - } + const maxValue = max(data)[0]; - for (let i = 0; i < log_spec.length; ++i) { - log_spec[i] = Math.max(log_spec[i], maxLogSpec - 8); - log_spec[i] = (log_spec[i] + 4) / 4; + for (let i = 0; i < data.length; ++i) { + data[i] = (Math.max(data[i], maxValue - 8.0) + 4.0) / 4.0; } - return { - data: log_spec, - dims: [num_mel_filters, d1] - }; + return { data, dims }; } /** @@ -1340,29 +1201,28 @@ export class WhisperFeatureExtractor extends FeatureExtractor { * @returns {Promise<{ input_features: Tensor }>} A Promise resolving to an object containing the extracted input features as a Tensor. */ async _call(audio) { - if (!(audio instanceof Float32Array || audio instanceof Float64Array)) { - throw new Error( - // @ts-ignore - `WhisperFeatureExtractor expects input to be a Float32Array or a Float64Array, but got ${audio?.constructor?.name ?? typeof audio} instead.` + - `If using the feature extractor directly, remember to use \`read_audio(url, sampling_rate)\` to obtain the raw audio data of the file/url.` - ) - } + validate_audio_inputs(audio, 'WhisperFeatureExtractor'); + let waveform; if (audio.length > this.config.n_samples) { console.warn( "Attempting to extract features for audio longer than 30 seconds. " + "If using a pipeline to extract transcript from a long audio clip, " + "remember to specify `chunk_length_s` and/or `stride_length_s`." ); + waveform = audio.slice(0, this.config.n_samples); + } else { + // pad with zeros + waveform = new Float32Array(this.config.n_samples); + waveform.set(audio); } - let waveform = audio.slice(0, this.config.n_samples); - let features = this._extract_fbank_features(waveform); + const { data, dims } = this._extract_fbank_features(waveform); return { input_features: new Tensor('float32', - features.data, - [1, ...features.dims] + data, + [1, ...dims] ) }; } @@ -1388,14 +1248,8 @@ export class Wav2Vec2FeatureExtractor extends FeatureExtractor { * @returns {Promise<{ input_values: Tensor; attention_mask: Tensor }>} A Promise resolving to an object containing the extracted input features and attention mask as Tensors. */ async _call(audio) { - // TODO: remove duplication - if (!(audio instanceof Float32Array || audio instanceof Float64Array)) { - throw new Error( - // @ts-ignore - `Wav2Vec2FeatureExtractor expects input to be a Float32Array or a Float64Array, but got ${audio?.constructor?.name ?? typeof audio} instead.` + - `If using the feature extractor directly, remember to use \`read_audio(url, sampling_rate)\` to obtain the raw audio data of the file/url.` - ) - } + validate_audio_inputs(audio, 'Wav2Vec2FeatureExtractor'); + if (audio instanceof Float64Array) { audio = new Float32Array(audio); } @@ -1416,6 +1270,260 @@ export class Wav2Vec2FeatureExtractor extends FeatureExtractor { } } +export class ASTFeatureExtractor extends FeatureExtractor { + + + constructor(config) { + super(config); + + const sampling_rate = this.config.sampling_rate; + const mel_filters = mel_filter_bank( + 256, // num_frequency_bins + this.config.num_mel_bins, // num_mel_filters + 20, // min_frequency + Math.floor(sampling_rate / 2), // max_frequency + sampling_rate, // sampling_rate + null, // norm + "kaldi", // mel_scale + true, // triangularize_in_mel_space + ); + + // Do padding: + for (let i = 0; i < mel_filters.length; ++i) { + mel_filters[i].push(0); + } + this.mel_filters = mel_filters; + + this.window = window_function(400, 'hann', { + periodic: false, + }) + + this.mean = this.config.mean; + this.std = this.config.std; + } + + /** + * Computes the log-Mel spectrogram of the provided audio waveform. + * @param {Float32Array|Float64Array} waveform The audio waveform to process. + * @param {number} max_length The maximum number of frames to return. + * @returns {{data: Float32Array, dims: number[]}} An object containing the log-Mel spectrogram data as a Float32Array and its dimensions as an array of numbers. + */ + _extract_fbank_features(waveform, max_length) { + // NOTE: We don't pad/truncate since that is passed in as `max_num_frames` + return spectrogram( + waveform, + this.window, // window + 400, // frame_length + 160, // hop_length + { + fft_length: 512, + power: 2.0, + center: false, + preemphasis: 0.97, + mel_filters: this.mel_filters, + log_mel: 'log', + mel_floor: 1.192092955078125e-07, + remove_dc_offset: true, + + // Custom + max_num_frames: max_length, + transpose: true, + } + ) + } + + + /** + * Asynchronously extracts features from a given audio using the provided configuration. + * @param {Float32Array|Float64Array} audio The audio data as a Float32Array/Float64Array. + * @returns {Promise<{ input_values: Tensor }>} A Promise resolving to an object containing the extracted input features as a Tensor. + */ + async _call(audio) { + validate_audio_inputs(audio, 'ASTFeatureExtractor'); + + const features = this._extract_fbank_features(audio, this.config.max_length); + if (this.config.do_normalize) { + // Normalize the input audio spectrogram to have mean=0, std=0.5 + const denom = this.std * 2; + for (let i = 0; i < features.data.length; ++i) { + features.data[i] = (features.data[i] - this.mean) / denom; + } + } + + return { + input_values: new Tensor('float32', + features.data, + [1, ...features.dims] + ) + }; + } +} + +export class ClapFeatureExtractor extends FeatureExtractor { + + constructor(config) { + super(config); + + this.mel_filters = mel_filter_bank( + this.config.nb_frequency_bins, // num_frequency_bins + this.config.feature_size, // num_mel_filters + this.config.frequency_min, // min_frequency + this.config.frequency_max, // max_frequency + this.config.sampling_rate, // sampling_rate + null, // norm + "htk", // mel_scale + ); + + this.mel_filters_slaney = mel_filter_bank( + this.config.nb_frequency_bins, // num_frequency_bins + this.config.feature_size, // num_mel_filters + this.config.frequency_min, // min_frequency + this.config.frequency_max, // max_frequency + this.config.sampling_rate, // sampling_rate + "slaney", // norm + "slaney", // mel_scale + ); + + this.window = window_function(this.config.fft_window_size, 'hann') + + } + + + /** + * Extracts the mel spectrogram and prepares it for the mode based on the `truncation` and `padding` arguments. + * + * Four different path are possible: + * - `truncation="fusion"` and the length of the waveform is greater than the max length: the mel spectrogram + * will be computed on the entire audio. 3 random crops and a dowsampled version of the full mel spectrogram + * are then stacked together. They will later be used for `feature_fusion`. + * - `truncation="rand_trunc"` and the length of the waveform is smaller than the max length: the audio is + * padded based on `padding`. + * - `truncation="fusion"` and the length of the waveform is smaller than the max length: the audio is padded + * based on `padding`, and is repeated `4` times. + * - `truncation="rand_trunc"` and the length of the waveform is greater than the max length: the mel + * spectrogram will be computed on a random crop of the waveform. + * + * @param {Float32Array|Float64Array} waveform The input waveform. + * @param {number} max_length The maximum length of the waveform. + * @param {string} truncation The truncation strategy to use. + * @param {string} padding The padding strategy to use. + * @returns {{ data: Float32Array; dims: number[]; longer: boolean; }} An object containing the mel spectrogram data as a Float32Array, its dimensions as an array of numbers, and a boolean indicating whether the waveform was longer than the max length. + */ + _get_input_mel(waveform, max_length, truncation, padding) { + + /** @type {{ data: Float32Array; dims: number[]}} */ + let input_mel; + let longer = false; + const diff = waveform.length - max_length; + if (diff > 0) { + if (truncation === 'rand_trunc') { + longer = true; + const idx = Math.floor(Math.random() * (diff + 1)); + waveform = waveform.subarray(idx, idx + max_length); + + input_mel = this._extract_fbank_features(waveform, this.mel_filters_slaney, this.config.nb_max_samples); + input_mel.dims = [1, ...input_mel.dims]; // "unsqueeze" + } else { + // TODO implement fusion strategy + throw new Error(`Truncation strategy "${truncation}" not implemented`) + } + } else { + if (diff < 0) { + let padded = new Float64Array(max_length); // already padded with zeros + padded.set(waveform); + + if (padding === 'repeat') { + for (let i = waveform.length; i < max_length; i += waveform.length) { + padded.set(waveform.subarray(0, Math.min(waveform.length, max_length - i)), i); + } + } else if (padding === 'repeatpad') { + for (let i = waveform.length; i < -diff; i += waveform.length) { + padded.set(waveform, i); + } + } + waveform = padded; + } + + if (truncation === 'fusion') { + throw new Error(`Truncation strategy "${truncation}" not implemented`) + } + + input_mel = this._extract_fbank_features(waveform, this.mel_filters_slaney, this.config.nb_max_samples); + input_mel.dims = [1, ...input_mel.dims]; // "unsqueeze" + } + + return { + ...input_mel, + longer, + } + } + + /** + * Compute the log-mel spectrogram of the provided `waveform` using the Hann window. + * In CLAP, two different filter banks are used depending on the truncation pattern: + * - `self.mel_filters`: they correspond to the default parameters of `torchaudio` which can be obtained from + * calling `torchaudio.transforms.MelSpectrogram().mel_scale.fb`. These filters are used when `truncation` + * is set to `"fusion"`. + * - `self.mel_filteres_slaney` : they correspond to the default parameters of `librosa` which used + * `librosa.filters.mel` when computing the mel spectrogram. These filters were only used in the original + * implementation when the truncation mode is not `"fusion"`. + * + * @param {Float32Array|Float64Array} waveform The audio waveform to process. + * @param {number[][]} mel_filters The mel filters to use. + * @param {number} [max_length=null] The maximum number of frames to return. + * @returns {{data: Float32Array, dims: number[]}} An object containing the log-Mel spectrogram data as a Float32Array and its dimensions as an array of numbers. + */ + _extract_fbank_features(waveform, mel_filters, max_length = null) { + // NOTE: We don't pad/truncate since that is passed in as `max_num_frames` + return spectrogram( + waveform, + this.window, // window + this.config.fft_window_size, // frame_length + this.config.hop_length, // hop_length + { + power: 2.0, + mel_filters, + log_mel: 'dB', + + // Custom + max_num_frames: max_length, + do_pad: false, + transpose: true, + } + ) + } + + + /** + * Asynchronously extracts features from a given audio using the provided configuration. + * @param {Float32Array|Float64Array} audio The audio data as a Float32Array/Float64Array. + * @returns {Promise<{ input_features: Tensor }>} A Promise resolving to an object containing the extracted input features as a Tensor. + */ + async _call(audio, { + max_length = null, + } = {}) { + validate_audio_inputs(audio, 'ClapFeatureExtractor'); + + // convert to mel spectrogram, truncate and pad if needed. + const padded_inputs = this._get_input_mel( + audio, + max_length ?? this.config.nb_max_samples, + this.config.truncation, + this.config.padding, + ); + + + return { + input_features: new Tensor('float32', + padded_inputs.data, + [1, ...padded_inputs.dims] + ) + }; + } +} + + + export class SpeechT5FeatureExtractor extends FeatureExtractor { } /** @@ -1501,6 +1609,8 @@ export class SpeechT5Processor extends Processor { } } +export class OwlViTProcessor extends Processor { } + ////////////////////////////////////////////////// /** @@ -1538,17 +1648,25 @@ export class AutoProcessor { WhisperFeatureExtractor, ViTFeatureExtractor, MobileViTFeatureExtractor, + OwlViTFeatureExtractor, + CLIPFeatureExtractor, ConvNextFeatureExtractor, + ConvNextImageProcessor, + DPTFeatureExtractor, + GLPNFeatureExtractor, BeitFeatureExtractor, DeiTFeatureExtractor, DetrFeatureExtractor, YolosFeatureExtractor, DonutFeatureExtractor, + NougatImageProcessor, SamImageProcessor, Swin2SRImageProcessor, Wav2Vec2FeatureExtractor, SpeechT5FeatureExtractor, + ASTFeatureExtractor, + ClapFeatureExtractor, } static PROCESSOR_CLASS_MAPPING = { @@ -1556,6 +1674,7 @@ export class AutoProcessor { Wav2Vec2ProcessorWithLM, SamProcessor, SpeechT5Processor, + OwlViTProcessor, } /** diff --git a/src/tokenizers.js b/src/tokenizers.js index 1f7bd9fcd..c0cf7afb8 100644 --- a/src/tokenizers.js +++ b/src/tokenizers.js @@ -56,20 +56,56 @@ async function loadTokenizer(pretrained_model_name_or_path, options) { return info; } + +/** + * Helper function to split a string on a regex, but keep the delimiters. + * This is required, because the JavaScript `.split()` method does not keep the delimiters, + * and wrapping in a capturing group causes issues with existing capturing groups (due to nesting). + * @param {string} text The text to split. + * @param {RegExp} regex The regex to split on. + * @returns {string[]} The split string. + */ +function regexSplit(text, regex) { + const result = []; + let prev = 0; + for (const match of text.matchAll(regex)) { + const fullMatch = match[0]; + if (prev < match.index) { + result.push(text.slice(prev, match.index)); + } + if (fullMatch.length > 0) { + result.push(fullMatch); + } + prev = match.index + fullMatch.length; + } + if (prev < text.length) { + result.push(text.slice(prev)); + } + return result; +} + + /** * Helper method to construct a pattern from a config object. * @param {Object} pattern The pattern object. - * @param {boolean} invert Whether to invert the pattern (only applicable for Regex patterns). - * @returns {RegExp|string|null} The compiled pattern. + * @param {boolean} invert Whether to invert the pattern. + * @returns {RegExp|null} The compiled pattern. */ function createPattern(pattern, invert = true) { if (pattern.Regex !== undefined) { - // NOTE: if invert is true, we wrap the pattern in a group so that it is kept when performing .split() - return new RegExp(invert ? pattern.Regex : `(${pattern.Regex})`, 'gu'); + // In certain cases, the pattern may contain unnecessary escape sequences (e.g., \# or \& or \~). + // i.e., valid in Python (where the patterns are exported from) but invalid in JavaScript (where the patterns are parsed). + // This isn't an issue when creating the regex w/o the 'u' flag, but it is when the 'u' flag is used. + // For this reason, it is necessary to remove these backslashes before creating the regex. + // See https://stackoverflow.com/a/63007777/13989043 for more information + const regex = pattern.Regex.replace(/\\([#&~])/g, '$1'); // TODO: add more characters to this list if necessary + return new RegExp(regex, 'gu'); } else if (pattern.String !== undefined) { - return pattern.String; + const escaped = escapeRegExp(pattern.String); + // NOTE: if invert is true, we wrap the pattern in a group so that it is kept when performing .split() + return new RegExp(invert ? escaped : `(${escaped})`, 'gu'); } else { console.warn('Unknown pattern type:', pattern) @@ -86,6 +122,26 @@ function objectToMap(obj) { return new Map(Object.entries(obj)); } +/** + * Helper function to convert a tensor to a list before decoding. + * @param {Tensor} tensor The tensor to convert. + * @returns {number[]} The tensor as a list. + */ +function prepareTensorForDecode(tensor) { + const dims = tensor.dims; + switch (dims.length) { + case 1: + return tensor.tolist(); + case 2: + if (dims[0] !== 1) { + throw new Error('Unable to decode tensor with `batch size !== 1`. Use `tokenizer.batch_decode(...)` for batched inputs.'); + } + return tensor.tolist()[0]; + default: + throw new Error(`Expected tensor to have 1-2 dimensions, got ${dims.length}.`) + } +} + /** * Clean up a list of simple English tokenization artifacts like spaces before punctuations and abbreviated forms * @param {string} text The text to clean up. @@ -270,6 +326,7 @@ class WordPieceTokenizer extends TokenizerModel { * @param {Object} config.vocab A mapping of tokens to ids. * @param {string} config.unk_token The unknown token string. * @param {string} config.continuing_subword_prefix The prefix to use for continuing subwords. + * @param {number} [config.max_input_chars_per_word=100] The maximum number of characters per word. */ constructor(config) { super(config); @@ -291,6 +348,12 @@ class WordPieceTokenizer extends TokenizerModel { */ this.unk_token = config.unk_token; + /** + * The maximum number of characters allowed per word. + * @type {number} + */ + this.max_input_chars_per_word = config.max_input_chars_per_word ?? 100; + /** * An array of tokens. * @type {string[]} @@ -310,10 +373,10 @@ class WordPieceTokenizer extends TokenizerModel { let outputTokens = []; for (let token of tokens) { let chars = [...token]; - // TODO add - // if len(chars) > self.max_input_chars_per_word: - // output_tokens.append(self.unk_token) - // continue + if (chars.length > this.max_input_chars_per_word) { + outputTokens.push(this.unk_token); + continue; + } let isUnknown = false; let start = 0; @@ -806,6 +869,8 @@ class Normalizer extends Callable { return new Replace(config); case 'NFC': return new NFC(config); + case 'NFKC': + return new NFKC(config); case 'NFKD': return new NFKD(config); case 'Strip': @@ -881,6 +946,21 @@ class NFC extends Normalizer { } } +/** + * NFKC Normalizer. + * @extends Normalizer + */ +class NFKC extends Normalizer { + /** + * Normalize text using NFKC normalization. + * @param {string} text The text to be normalized. + * @returns {string} The normalized text. + */ + normalize(text) { + text = text.normalize('NFKC') + return text; + } +} /** * NFKD Normalizer. * @extends Normalizer @@ -1292,7 +1372,7 @@ class SplitPreTokenizer extends PreTokenizer { if (this.config.invert) { return text.match(this.pattern) || []; } else { - return text.split(this.pattern).filter(x => x); + return regexSplit(text, this.pattern); } } } @@ -2183,6 +2263,9 @@ export class PreTrainedTokenizer extends Callable { this.sep_token = this.getToken(tokenizerConfig, 'sep_token'); this.sep_token_id = this.model.tokens_to_ids.get(this.sep_token); + this.unk_token = this.getToken(tokenizerConfig, 'unk_token'); + this.unk_token_id = this.model.tokens_to_ids.get(this.unk_token); + this.model_max_length = tokenizerConfig.model_max_length; /** @type {boolean} Whether or not to strip the text when tokenizing (removing excess spaces before and after the string). */ @@ -2493,18 +2576,21 @@ export class PreTrainedTokenizer extends Callable { /** * Decode a batch of tokenized sequences. - * @param {number[][]} batch List of tokenized input sequences. + * @param {number[][]|Tensor} batch List/Tensor of tokenized input sequences. * @param {Object} decode_args (Optional) Object with decoding arguments. * @returns {string[]} List of decoded sequences. */ batch_decode(batch, decode_args = {}) { + if (batch instanceof Tensor) { + batch = batch.tolist(); + } return batch.map(x => this.decode(x, decode_args)); } /** * Decodes a sequence of token IDs back to a string. * - * @param {number[]} token_ids List of token IDs to decode. + * @param {number[]|Tensor} token_ids List/Tensor of token IDs to decode. * @param {Object} [decode_args={}] * @param {boolean} [decode_args.skip_special_tokens=false] If true, special tokens are removed from the output string. * @param {boolean} [decode_args.clean_up_tokenization_spaces=true] If true, spaces before punctuations and abbreviated forms are removed. @@ -2516,6 +2602,10 @@ export class PreTrainedTokenizer extends Callable { token_ids, decode_args = {}, ) { + if (token_ids instanceof Tensor) { + token_ids = prepareTensorForDecode(token_ids); + } + if (!Array.isArray(token_ids) || token_ids.length === 0 || !isIntegralNumber(token_ids[0])) { throw Error("token_ids must be a non-empty array of integers."); } @@ -2571,7 +2661,7 @@ export class PreTrainedTokenizer extends Callable { * @param {Object} inputs An object containing the input ids and attention mask. * @returns {Object} The prepared inputs object. */ -function add_token_types(inputs) { +export function add_token_types(inputs) { // TODO ensure correctness when token pair is present if (inputs.input_ids instanceof Tensor) { inputs.token_type_ids = new Tensor( @@ -3395,6 +3485,9 @@ export class WhisperTokenizer extends PreTrainedTokenizer { let text; // @ts-ignore if (decode_args && decode_args.decode_with_timestamps) { + if (token_ids instanceof Tensor) { + token_ids = prepareTensorForDecode(token_ids); + } text = this.decodeWithTimestamps(token_ids, decode_args); } else { text = super.decode(token_ids, decode_args); @@ -3749,6 +3842,8 @@ export class BlenderbotSmallTokenizer extends PreTrainedTokenizer { } export class SpeechT5Tokenizer extends PreTrainedTokenizer { } +export class NougatTokenizer extends PreTrainedTokenizer { } + /** * Helper class which is used to instantiate pretrained tokenizers with the `from_pretrained` function. * The chosen tokenizer class is determined by the type specified in the tokenizer config. @@ -3791,6 +3886,7 @@ export class AutoTokenizer { BlenderbotTokenizer, BlenderbotSmallTokenizer, SpeechT5Tokenizer, + NougatTokenizer, // Base case: PreTrainedTokenizer, diff --git a/src/utils/audio.js b/src/utils/audio.js index f9f6216c0..082870de8 100644 --- a/src/utils/audio.js +++ b/src/utils/audio.js @@ -10,7 +10,11 @@ import { getFile, } from './hub.js'; -import { rfftfreq } from './maths.js'; +import { FFT, max } from './maths.js'; +import { + calculateReflectOffset, +} from './core.js'; + /** * Helper function to read audio from a path/URL. @@ -57,8 +61,8 @@ export async function read_audio(url, sampling_rate) { // audio at all, this scaling factor may not be needed. const SCALING_FACTOR = Math.sqrt(2); - let left = decoded.getChannelData(0); - let right = decoded.getChannelData(1); + const left = decoded.getChannelData(0); + const right = decoded.getChannelData(1); audio = new Float32Array(left.length); for (let i = 0; i < decoded.length; ++i) { @@ -74,69 +78,587 @@ export async function read_audio(url, sampling_rate) { } /** - * Creates a frequency bin conversion matrix used to obtain a mel spectrogram. - * @param {number} sr Sample rate of the audio waveform. - * @param {number} n_fft Number of frequencies used to compute the spectrogram (should be the same as in `stft`). - * @param {number} n_mels Number of mel filters to generate. - * @returns {number[][]} Projection matrix to go from a spectrogram to a mel spectrogram. + * Generates a Hanning window of length M. + * + * @param {number} M The length of the Hanning window to generate. + * @returns {Float64Array} The generated Hanning window. + */ +export function hanning(M) { + if (M < 1) { + return new Float64Array(); + } + if (M === 1) { + return new Float64Array([1]); + } + const denom = M - 1; + const factor = Math.PI / denom; + const cos_vals = new Float64Array(M); + for (let i = 0; i < M; ++i) { + const n = 2 * i - denom; + cos_vals[i] = 0.5 + 0.5 * Math.cos(factor * n); + } + return cos_vals; +} + +const HERTZ_TO_MEL_MAPPING = { + "htk": (/** @type {number} */ freq) => 2595.0 * Math.log10(1.0 + (freq / 700.0)), + "kaldi": (/** @type {number} */ freq) => 1127.0 * Math.log(1.0 + (freq / 700.0)), + "slaney": (/** @type {number} */ freq, min_log_hertz = 1000.0, min_log_mel = 15.0, logstep = 27.0 / Math.log(6.4)) => + freq >= min_log_hertz + ? min_log_mel + Math.log(freq / min_log_hertz) * logstep + : 3.0 * freq / 200.0, +} + +/** + * @template {Float32Array|Float64Array|number} T + * @param {T} freq + * @param {string} [mel_scale] + * @returns {T} + */ +function hertz_to_mel(freq, mel_scale = "htk") { + const fn = HERTZ_TO_MEL_MAPPING[mel_scale]; + if (!fn) { + throw new Error('mel_scale should be one of "htk", "slaney" or "kaldi".'); + } + + return typeof freq === 'number' ? fn(freq) : freq.map(x => fn(x)); +} + +const MEL_TO_HERTZ_MAPPING = { + "htk": (/** @type {number} */ mels) => 700.0 * (10.0 ** (mels / 2595.0) - 1.0), + "kaldi": (/** @type {number} */ mels) => 700.0 * (Math.exp(mels / 1127.0) - 1.0), + "slaney": (/** @type {number} */ mels, min_log_hertz = 1000.0, min_log_mel = 15.0, logstep = Math.log(6.4) / 27.0) => mels >= min_log_mel + ? min_log_hertz * Math.exp(logstep * (mels - min_log_mel)) + : 200.0 * mels / 3.0, +} + +/** + * @template {Float32Array|Float64Array|number} T + * @param {T} mels + * @param {string} [mel_scale] + * @returns {T} */ -export function getMelFilters(sr, n_fft, n_mels = 128) { - n_mels = Math.floor(n_mels); +function mel_to_hertz(mels, mel_scale = "htk") { + const fn = MEL_TO_HERTZ_MAPPING[mel_scale]; + if (!fn) { + throw new Error('mel_scale should be one of "htk", "slaney" or "kaldi".'); + } - // Initialize the weights - const mel_size = Math.floor(1 + n_fft / 2); - const weights = new Array(n_mels); + return typeof mels === 'number' ? fn(mels) : mels.map(x => fn(x)); +} - // Center freqs of each FFT bin - const fftfreqs = rfftfreq(n_fft, 1 / sr); +/** +* Creates a triangular filter bank. +* +* Adapted from torchaudio and librosa. +* +* @param {Float64Array} fft_freqs Discrete frequencies of the FFT bins in Hz, of shape `(num_frequency_bins,)`. +* @param {Float64Array} filter_freqs Center frequencies of the triangular filters to create, in Hz, of shape `(num_mel_filters,)`. +* @returns {number[][]} of shape `(num_frequency_bins, num_mel_filters)`. +*/ +function _create_triangular_filter_bank(fft_freqs, filter_freqs) { + const filter_diff = Float64Array.from( + { length: filter_freqs.length - 1 }, + (_, i) => filter_freqs[i + 1] - filter_freqs[i] + ); - // 'Center freqs' of mel bands - uniformly spaced between limits - const min_mel = 0.0; - const max_mel = 45.245640471924965; - const mel_range = max_mel - min_mel; - const mel_scale = mel_range / (n_mels + 1); + const slopes = Array.from({ + length: fft_freqs.length + }, () => new Array(filter_freqs.length)); - // Fill in the linear scale - const f_min = 0.0; - const f_sp = 200.0 / 3; - const freqs = new Array(n_mels + 2); + for (let j = 0; j < fft_freqs.length; ++j) { + const slope = slopes[j]; + for (let i = 0; i < filter_freqs.length; ++i) { + slope[i] = filter_freqs[i] - fft_freqs[j]; + } + } - // And now the nonlinear scale - const min_log_hz = 1000.0; // beginning of log region (Hz) - const min_log_mel = (min_log_hz - f_min) / f_sp; // same (Mels) - const logstep = Math.log(6.4) / 27.0; // step size for log region + const numFreqs = filter_freqs.length - 2; + const ret = Array.from({ length: numFreqs }, () => new Array(fft_freqs.length)); - const ramps = new Array(freqs.length); - for (let i = 0; i < freqs.length; ++i) { - const mel = i * mel_scale + min_mel; - if (mel >= min_log_mel) { - freqs[i] = min_log_hz * Math.exp(logstep * (mel - min_log_mel)); - } else { - freqs[i] = f_min + f_sp * mel; + for (let j = 0; j < fft_freqs.length; ++j) { // 201 + const slope = slopes[j]; + for (let i = 0; i < numFreqs; ++i) { // 80 + const down = -slope[i] / filter_diff[i]; + const up = slope[i + 2] / filter_diff[i + 1]; + ret[i][j] = Math.max(0, Math.min(down, up)); } - ramps[i] = fftfreqs.map(k => freqs[i] - k); } + return ret; +} - const fdiffinv = freqs.slice(1).map((v, i) => 1 / (v - freqs[i])); +/** + * Return evenly spaced numbers over a specified interval. + * @param {number} start The starting value of the sequence. + * @param {number} end The end value of the sequence. + * @param {number} num Number of samples to generate. + * @returns `num` evenly spaced samples, calculated over the interval `[start, stop]`. + */ +function linspace(start, end, num) { + const step = (end - start) / (num - 1); + return Float64Array.from({ length: num }, (_, i) => start + step * i); +} + +/** + * Creates a frequency bin conversion matrix used to obtain a mel spectrogram. This is called a *mel filter bank*, and + * various implementation exist, which differ in the number of filters, the shape of the filters, the way the filters + * are spaced, the bandwidth of the filters, and the manner in which the spectrum is warped. The goal of these + * features is to approximate the non-linear human perception of the variation in pitch with respect to the frequency. + * @param {number} num_frequency_bins Number of frequencies used to compute the spectrogram (should be the same as in `stft`). + * @param {number} num_mel_filters Number of mel filters to generate. + * @param {number} min_frequency Lowest frequency of interest in Hz. + * @param {number} max_frequency Highest frequency of interest in Hz. This should not exceed `sampling_rate / 2`. + * @param {number} sampling_rate Sample rate of the audio waveform. + * @param {string} [norm] If `"slaney"`, divide the triangular mel weights by the width of the mel band (area normalization). + * @param {string} [mel_scale] The mel frequency scale to use, `"htk"` or `"slaney"`. + * @param {boolean} [triangularize_in_mel_space] If this option is enabled, the triangular filter is applied in mel space rather than frequency space. + * This should be set to `true` in order to get the same results as `torchaudio` when computing mel filters. + * @returns {number[][]} Triangular filter bank matrix, which is a 2D array of shape (`num_frequency_bins`, `num_mel_filters`). + * This is a projection matrix to go from a spectrogram to a mel spectrogram. + */ +export function mel_filter_bank( + num_frequency_bins, + num_mel_filters, + min_frequency, + max_frequency, + sampling_rate, + norm = null, + mel_scale = "htk", + triangularize_in_mel_space = false, +) { + if (norm !== null && norm !== "slaney") { + throw new Error('norm must be one of null or "slaney"'); + } - for (let i = 0; i < weights.length; ++i) { - weights[i] = new Array(mel_size); + const mel_min = hertz_to_mel(min_frequency, mel_scale); + const mel_max = hertz_to_mel(max_frequency, mel_scale); + const mel_freqs = linspace(mel_min, mel_max, num_mel_filters + 2); - const a = fdiffinv[i]; - const b = fdiffinv[i + 1]; - const c = ramps[i]; - const d = ramps[i + 2]; + let filter_freqs = mel_to_hertz(mel_freqs, mel_scale); + let fft_freqs; // frequencies of FFT bins in Hz + if (triangularize_in_mel_space) { + const fft_bin_width = sampling_rate / (num_frequency_bins * 2); + fft_freqs = hertz_to_mel(Float64Array.from({ length: num_frequency_bins }, (_, i) => i * fft_bin_width), mel_scale); + filter_freqs = mel_freqs; + } else { + fft_freqs = linspace(0, Math.floor(sampling_rate / 2), num_frequency_bins); + } + + const mel_filters = _create_triangular_filter_bank(fft_freqs, filter_freqs); + + if (norm !== null && norm === "slaney") { // Slaney-style mel is scaled to be approx constant energy per channel - const enorm = 2.0 / (freqs[i + 2] - freqs[i]); + for (let i = 0; i < num_mel_filters; ++i) { + const filter = mel_filters[i]; + const enorm = 2.0 / (filter_freqs[i + 2] - filter_freqs[i]); + for (let j = 0; j < num_frequency_bins; ++j) { + // Apply this enorm to all frequency bins + filter[j] *= enorm; + } + } + } + + // TODO warn if there is a zero row + + return mel_filters; + +} + +/** + * @template {Float32Array|Float64Array} T + * Pads an array with a reflected version of itself on both ends. + * @param {T} array The array to pad. + * @param {number} left The amount of padding to add to the left. + * @param {number} right The amount of padding to add to the right. + * @returns {T} The padded array. + */ +function padReflect(array, left, right) { + // @ts-ignore + const padded = new array.constructor(array.length + left + right); + const w = array.length - 1; + + for (let i = 0; i < array.length; ++i) { + padded[left + i] = array[i]; + } + + for (let i = 1; i <= left; ++i) { + padded[left - i] = array[calculateReflectOffset(i, w)]; + } + + for (let i = 1; i <= right; ++i) { + padded[w + left + i] = array[calculateReflectOffset(w - i, w)]; + } + + return padded; +} + +/** + * Helper function to compute `amplitude_to_db` and `power_to_db`. + * @template {Float32Array|Float64Array} T + * @param {T} spectrogram + * @param {number} factor + * @param {number} reference + * @param {number} min_value + * @param {number} db_range + * @returns {T} + */ +function _db_conversion_helper(spectrogram, factor, reference, min_value, db_range) { + if (reference <= 0) { + throw new Error('reference must be greater than zero'); + } + + if (min_value <= 0) { + throw new Error('min_value must be greater than zero'); + } + + reference = Math.max(min_value, reference); + + const logReference = Math.log10(reference); + for (let i = 0; i < spectrogram.length; ++i) { + spectrogram[i] = factor * Math.log10(Math.max(min_value, spectrogram[i]) - logReference) + } + + if (db_range !== null) { + if (db_range <= 0) { + throw new Error('db_range must be greater than zero'); + } + const maxValue = max(spectrogram)[0] - db_range; + for (let i = 0; i < spectrogram.length; ++i) { + spectrogram[i] = Math.max(spectrogram[i], maxValue); + } + } + + return spectrogram; +} + +/** + * Converts an amplitude spectrogram to the decibel scale. This computes `20 * log10(spectrogram / reference)`, + * using basic logarithm properties for numerical stability. NOTE: Operates in-place. + * + * The motivation behind applying the log function on the (mel) spectrogram is that humans do not hear loudness on a + * linear scale. Generally to double the perceived volume of a sound we need to put 8 times as much energy into it. + * This means that large variations in energy may not sound all that different if the sound is loud to begin with. + * This compression operation makes the (mel) spectrogram features match more closely what humans actually hear. + * + * @template {Float32Array|Float64Array} T + * @param {T} spectrogram The input amplitude (mel) spectrogram. + * @param {number} [reference=1.0] Sets the input spectrogram value that corresponds to 0 dB. + * For example, use `np.max(spectrogram)` to set the loudest part to 0 dB. Must be greater than zero. + * @param {number} [min_value=1e-5] The spectrogram will be clipped to this minimum value before conversion to decibels, + * to avoid taking `log(0)`. The default of `1e-5` corresponds to a minimum of -100 dB. Must be greater than zero. + * @param {number} [db_range=null] Sets the maximum dynamic range in decibels. For example, if `db_range = 80`, the + * difference between the peak value and the smallest value will never be more than 80 dB. Must be greater than zero. + * @returns {T} The modified spectrogram in decibels. + */ +function amplitude_to_db(spectrogram, reference = 1.0, min_value = 1e-5, db_range = null) { + return _db_conversion_helper(spectrogram, 20.0, reference, min_value, db_range); +} + +/** + * Converts a power spectrogram to the decibel scale. This computes `10 * log10(spectrogram / reference)`, + * using basic logarithm properties for numerical stability. NOTE: Operates in-place. + * + * The motivation behind applying the log function on the (mel) spectrogram is that humans do not hear loudness on a + * linear scale. Generally to double the perceived volume of a sound we need to put 8 times as much energy into it. + * This means that large variations in energy may not sound all that different if the sound is loud to begin with. + * This compression operation makes the (mel) spectrogram features match more closely what humans actually hear. + * + * Based on the implementation of `librosa.power_to_db`. + * + * @template {Float32Array|Float64Array} T + * @param {T} spectrogram The input power (mel) spectrogram. Note that a power spectrogram has the amplitudes squared! + * @param {number} [reference=1.0] Sets the input spectrogram value that corresponds to 0 dB. + * For example, use `np.max(spectrogram)` to set the loudest part to 0 dB. Must be greater than zero. + * @param {number} [min_value=1e-10] The spectrogram will be clipped to this minimum value before conversion to decibels, + * to avoid taking `log(0)`. The default of `1e-10` corresponds to a minimum of -100 dB. Must be greater than zero. + * @param {number} [db_range=null] Sets the maximum dynamic range in decibels. For example, if `db_range = 80`, the + * difference between the peak value and the smallest value will never be more than 80 dB. Must be greater than zero. + * @returns {T} The modified spectrogram in decibels. + */ +function power_to_db(spectrogram, reference = 1.0, min_value = 1e-10, db_range = null) { + return _db_conversion_helper(spectrogram, 10.0, reference, min_value, db_range); +} + +/** + * Calculates a spectrogram over one waveform using the Short-Time Fourier Transform. + * + * This function can create the following kinds of spectrograms: + * - amplitude spectrogram (`power = 1.0`) + * - power spectrogram (`power = 2.0`) + * - complex-valued spectrogram (`power = None`) + * - log spectrogram (use `log_mel` argument) + * - mel spectrogram (provide `mel_filters`) + * - log-mel spectrogram (provide `mel_filters` and `log_mel`) + * + * In this implementation, the window is assumed to be zero-padded to have the same size as the analysis frame. + * A padded window can be obtained from `window_function()`. The FFT input buffer may be larger than the analysis frame, + * typically the next power of two. + * + * @param {Float32Array|Float64Array} waveform The input waveform of shape `(length,)`. This must be a single real-valued, mono waveform. + * @param {Float32Array|Float64Array} window The windowing function to apply of shape `(frame_length,)`, including zero-padding if necessary. The actual window length may be + * shorter than `frame_length`, but we're assuming the array has already been zero-padded. + * @param {number} frame_length The length of the analysis frames in samples (a.k.a., `fft_length`). + * @param {number} hop_length The stride between successive analysis frames in samples. + * @param {Object} options + * @param {number} [options.fft_length=null] The size of the FFT buffer in samples. This determines how many frequency bins the spectrogram will have. + * For optimal speed, this should be a power of two. If `null`, uses `frame_length`. + * @param {number} [options.power=1.0] If 1.0, returns the amplitude spectrogram. If 2.0, returns the power spectrogram. If `null`, returns complex numbers. + * @param {boolean} [options.center=true] Whether to pad the waveform so that frame `t` is centered around time `t * hop_length`. If `false`, frame + * `t` will start at time `t * hop_length`. + * @param {string} [options.pad_mode="reflect"] Padding mode used when `center` is `true`. Possible values are: `"constant"` (pad with zeros), + * `"edge"` (pad with edge values), `"reflect"` (pads with mirrored values). + * @param {boolean} [options.onesided=true] If `true`, only computes the positive frequencies and returns a spectrogram containing `fft_length // 2 + 1` + * frequency bins. If `false`, also computes the negative frequencies and returns `fft_length` frequency bins. + * @param {number} [options.preemphasis=null] Coefficient for a low-pass filter that applies pre-emphasis before the DFT. + * @param {number[][]} [options.mel_filters=null] The mel filter bank of shape `(num_freq_bins, num_mel_filters)`. + * If supplied, applies this filter bank to create a mel spectrogram. + * @param {number} [options.mel_floor=1e-10] Minimum value of mel frequency banks. + * @param {string} [options.log_mel=null] How to convert the spectrogram to log scale. Possible options are: + * `null` (don't convert), `"log"` (take the natural logarithm) `"log10"` (take the base-10 logarithm), `"dB"` (convert to decibels). + * Can only be used when `power` is not `null`. + * @param {number} [options.reference=1.0] Sets the input spectrogram value that corresponds to 0 dB. For example, use `max(spectrogram)[0]` to set + * the loudest part to 0 dB. Must be greater than zero. + * @param {number} [options.min_value=1e-10] The spectrogram will be clipped to this minimum value before conversion to decibels, to avoid taking `log(0)`. + * For a power spectrogram, the default of `1e-10` corresponds to a minimum of -100 dB. For an amplitude spectrogram, the value `1e-5` corresponds to -100 dB. + * Must be greater than zero. + * @param {number} [options.db_range=null] Sets the maximum dynamic range in decibels. For example, if `db_range = 80`, the difference between the + * peak value and the smallest value will never be more than 80 dB. Must be greater than zero. + * @param {boolean} [options.remove_dc_offset=null] Subtract mean from waveform on each frame, applied before pre-emphasis. This should be set to `true` in + * order to get the same results as `torchaudio.compliance.kaldi.fbank` when computing mel filters. + * @param {number} [options.max_num_frames=null] If provided, limits the number of frames to compute to this value. + * @param {boolean} [options.do_pad=true] If `true`, pads the output spectrogram to have `max_num_frames` frames. + * @param {boolean} [options.transpose=false] If `true`, the returned spectrogram will have shape `(num_frames, num_frequency_bins/num_mel_filters)`. If `false`, the returned spectrogram will have shape `(num_frequency_bins/num_mel_filters, num_frames)`. + * @returns {{data: Float32Array, dims: number[]}} Spectrogram of shape `(num_frequency_bins, length)` (regular spectrogram) or shape `(num_mel_filters, length)` (mel spectrogram). + */ +export function spectrogram( + waveform, + window, + frame_length, + hop_length, + { + fft_length = null, + power = 1.0, + center = true, + pad_mode = "reflect", + onesided = true, + preemphasis = null, + mel_filters = null, + mel_floor = 1e-10, + log_mel = null, + reference = 1.0, + min_value = 1e-10, + db_range = null, + remove_dc_offset = null, + + // Custom parameters for efficiency reasons + max_num_frames = null, + do_pad = true, + transpose = false, + } = {} +) { + const window_length = window.length; + if (fft_length === null) { + fft_length = frame_length; + } + if (frame_length > fft_length) { + throw Error(`frame_length (${frame_length}) may not be larger than fft_length (${fft_length})`) + } + + if (window_length !== frame_length) { + throw new Error(`Length of the window (${window_length}) must equal frame_length (${frame_length})`); + } + + if (hop_length <= 0) { + throw new Error("hop_length must be greater than zero"); + } - for (let j = 0; j < weights[i].length; ++j) { - // lower and upper slopes for all bins - const lower = -c[j] * a; - const upper = d[j] * b; - weights[i][j] = Math.max(0, Math.min(lower, upper)) * enorm; + if (center) { + if (pad_mode !== 'reflect') { + throw new Error(`pad_mode="${pad_mode}" not implemented yet.`) } + const half_window = Math.floor((fft_length - 1) / 2) + 1; + waveform = padReflect(waveform, half_window, half_window); + } + + // split waveform into frames of frame_length size + const num_frames = Math.floor(1 + Math.floor((waveform.length - frame_length) / hop_length)) + + const num_frequency_bins = onesided ? Math.floor(fft_length / 2) + 1 : fft_length + + let d1 = num_frames; + let d1Max = num_frames; + + // If maximum number of frames is provided, we must either pad or truncate + if (max_num_frames !== null) { + if (max_num_frames > num_frames) { // input is too short, so we pad + if (do_pad) { + d1Max = max_num_frames; + } + } else { // input is too long, so we truncate + d1Max = d1 = max_num_frames; + } + } + + // Preallocate arrays to store output. + const fft = new FFT(fft_length); + const inputBuffer = new Float64Array(fft_length); + const outputBuffer = new Float64Array(fft.outputBufferSize); + const magnitudes = new Array(d1); + + for (let i = 0; i < d1; ++i) { + // Populate buffer with waveform data + const offset = i * hop_length; + for (let j = 0; j < frame_length; ++j) { + inputBuffer[j] = waveform[offset + j]; + } + + if (remove_dc_offset) { + let sum = 0; + for (let j = 0; j < frame_length; ++j) { + sum += inputBuffer[j]; + } + const mean = sum / frame_length; + for (let j = 0; j < frame_length; ++j) { + inputBuffer[j] -= mean; + } + } + + if (preemphasis !== null) { + // Done in reverse to avoid copies and distructive modification + for (let j = frame_length - 1; j >= 1; --j) { + inputBuffer[j] -= preemphasis * inputBuffer[j - 1]; + } + inputBuffer[0] *= 1 - preemphasis; + } + + for (let j = 0; j < window.length; ++j) { + inputBuffer[j] *= window[j]; + } + + fft.realTransform(outputBuffer, inputBuffer); + + // compute magnitudes + const row = new Array(num_frequency_bins); + for (let j = 0; j < row.length; ++j) { + const j2 = j << 1; + row[j] = outputBuffer[j2] ** 2 + outputBuffer[j2 + 1] ** 2; + } + magnitudes[i] = row; + } + + // TODO what should happen if power is None? + // https://github.com/huggingface/transformers/issues/27772 + if (power !== null && power !== 2) { + // slight optimization to not sqrt + const pow = 2 / power; // we use 2 since we already squared + for (let i = 0; i < magnitudes.length; ++i) { + const magnitude = magnitudes[i]; + for (let j = 0; j < magnitude.length; ++j) { + magnitude[j] **= pow; + } + } + } + + // TODO: What if `mel_filters` is null? + const num_mel_filters = mel_filters.length; + + // Only here do we create Float32Array + const mel_spec = new Float32Array(num_mel_filters * d1Max); + + // Perform matrix muliplication: + // mel_spec = mel_filters @ magnitudes.T + // - mel_filters.shape=(80, 201) + // - magnitudes.shape=(3000, 201) => - magnitudes.T.shape=(201, 3000) + // - mel_spec.shape=(80, 3000) + const dims = transpose ? [d1Max, num_mel_filters] : [num_mel_filters, d1Max]; + for (let i = 0; i < num_mel_filters; ++i) { // num melfilters (e.g., 80) + const filter = mel_filters[i]; + for (let j = 0; j < d1; ++j) { // num frames (e.g., 3000) + const magnitude = magnitudes[j]; + + let sum = 0; + for (let k = 0; k < num_frequency_bins; ++k) { // num frequency bins (e.g., 201) + sum += filter[k] * magnitude[k]; + } + + mel_spec[ + transpose + ? j * num_mel_filters + i + : i * d1 + j + ] = Math.max(mel_floor, sum); + } + } + + if (power !== null && log_mel !== null) { + const o = Math.min(mel_spec.length, d1 * num_mel_filters); + switch (log_mel) { + case 'log': + for (let i = 0; i < o; ++i) { + mel_spec[i] = Math.log(mel_spec[i]); + } + break; + case 'log10': + for (let i = 0; i < o; ++i) { + mel_spec[i] = Math.log10(mel_spec[i]); + } + break; + case 'dB': + if (power === 1.0) { + // NOTE: operates in-place + amplitude_to_db(mel_spec, reference, min_value, db_range); + } else if (power === 2.0) { + power_to_db(mel_spec, reference, min_value, db_range); + } else { + throw new Error(`Cannot use log_mel option '${log_mel}' with power ${power}`) + } + break; + default: + throw new Error(`log_mel must be one of null, 'log', 'log10' or 'dB'. Got '${log_mel}'`); + } + } + + return { data: mel_spec, dims }; +} + +/** + * Returns an array containing the specified window. + * @param {number} window_length The length of the window in samples. + * @param {string} name The name of the window function. + * @param {Object} options Additional options. + * @param {boolean} [options.periodic=true] Whether the window is periodic or symmetric. + * @param {number} [options.frame_length=null] The length of the analysis frames in samples. + * Provide a value for `frame_length` if the window is smaller than the frame length, so that it will be zero-padded. + * @param {boolean} [options.center=true] Whether to center the window inside the FFT buffer. Only used when `frame_length` is provided. + * @returns {Float64Array} The window of shape `(window_length,)` or `(frame_length,)`. + */ +export function window_function(window_length, name, { + periodic = true, + frame_length = null, + center = true, +} = {}) { + const length = periodic ? window_length + 1 : window_length; + let window; + switch (name) { + case 'boxcar': + window = new Float64Array(length).fill(1.0); + break; + case 'hann': + case 'hann_window': + window = hanning(length); + break; + default: + throw new Error(`Unknown window type ${name}.`); + } + if (periodic) { + window = window.subarray(0, window_length); + } + if (frame_length === null) { + return window; + } + if (window_length > frame_length) { + throw new Error(`Length of the window (${window_length}) may not be larger than frame_length (${frame_length})`); } - return weights; + return window; } diff --git a/src/utils/core.js b/src/utils/core.js index 7de13625d..9ab11144c 100644 --- a/src/utils/core.js +++ b/src/utils/core.js @@ -184,3 +184,18 @@ export function product(...a) { export function calculateReflectOffset(i, w) { return Math.abs((i + w) % (2 * w) - w); } + +/** + * Helper function to convert list [xmin, xmax, ymin, ymax] into object { "xmin": xmin, ... } + * @param {number[]} box The bounding box as a list. + * @param {boolean} asInteger Whether to cast to integers. + * @returns {Object} The bounding box as an object. + */ +export function get_bounding_box(box, asInteger) { + if (asInteger) { + box = box.map(x => x | 0); + } + const [xmin, ymin, xmax, ymax] = box; + + return { xmin, ymin, xmax, ymax }; +} diff --git a/src/utils/generation.js b/src/utils/generation.js index 588b1a562..c6df20cf5 100644 --- a/src/utils/generation.js +++ b/src/utils/generation.js @@ -491,6 +491,49 @@ export class MinNewTokensLengthLogitsProcessor extends LogitsProcessor { } } +export class NoBadWordsLogitsProcessor extends LogitsProcessor { + /** + * Create a `NoBadWordsLogitsProcessor`. + * @param {number[][]} bad_words_ids List of list of token ids that are not allowed to be generated. + * @param {number|number[]} eos_token_id The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. + */ + constructor(bad_words_ids, eos_token_id) { + super(); + this.bad_words_ids = bad_words_ids; + this.eos_token_id = Array.isArray(eos_token_id) ? eos_token_id : [eos_token_id]; + } + + /** + * Apply logit processor. + * @param {Array} input_ids The input IDs. + * @param {Object} logits The logits. + * @returns {Object} The processed logits. + */ + _call(input_ids, logits) { + + for (const bad_word_ids of this.bad_words_ids) { + // Whether to modify the logits of the last token in the bad word id sequence + let mark = true; + + // For each bad word in the list, if the current sequence of input ids ends with this sequence (excluding the last), + // then we set the logits of the last bad word id to -Infinity. + for (let i = 1; i <= bad_word_ids.length - 1 && bad_word_ids.length < input_ids.length; ++i) { + + if (bad_word_ids.at(-i - 1) !== input_ids.at(-i)) { + // We have found a mismatch + mark = false; + break; + } + } + if (mark) { + logits.data[bad_word_ids.at(-1)] = -Infinity; + } + } + + return logits + } +} + /** * Class that holds a configuration for a generation task. */ diff --git a/src/utils/image.js b/src/utils/image.js index cb9303572..96c1d8227 100644 --- a/src/utils/image.js +++ b/src/utils/image.js @@ -16,6 +16,7 @@ import { env } from '../env.js'; import sharp from 'sharp'; const BROWSER_ENV = typeof self !== 'undefined'; +const WEBWORKER_ENV = BROWSER_ENV && self.constructor.name === 'DedicatedWorkerGlobalScope'; let createCanvasFunction; let ImageDataClass; @@ -90,6 +91,10 @@ export class RawImage { this.channels = channels; } + get size() { + return [this.width, this.height]; + } + /** * Helper method for reading an image from a variety of input types. * @param {RawImage|string|URL} input @@ -387,6 +392,58 @@ export class RawImage { } } + async crop([x_min, y_min, x_max, y_max]) { + // Ensure crop bounds are within the image + x_min = Math.max(x_min, 0); + y_min = Math.max(y_min, 0); + x_max = Math.min(x_max, this.width - 1); + y_max = Math.min(y_max, this.height - 1); + + // Do nothing if the crop is the entire image + if (x_min === 0 && y_min === 0 && x_max === this.width - 1 && y_max === this.height - 1) { + return this; + } + + const crop_width = x_max - x_min + 1; + const crop_height = y_max - y_min + 1; + + if (BROWSER_ENV) { + // Store number of channels before resizing + const numChannels = this.channels; + + // Create canvas object for this image + const canvas = this.toCanvas(); + + // Create a new canvas of the desired size. This is needed since if the + // image is too small, we need to pad it with black pixels. + const ctx = createCanvasFunction(crop_width, crop_height).getContext('2d'); + + // Draw image to context, cropping in the process + ctx.drawImage(canvas, + x_min, y_min, crop_width, crop_height, + 0, 0, crop_width, crop_height + ); + + // Create image from the resized data + const resizedImage = new RawImage(ctx.getImageData(0, 0, crop_width, crop_height).data, crop_width, crop_height, 4); + + // Convert back so that image has the same number of channels as before + return resizedImage.convert(numChannels); + + } else { + // Create sharp image from raw data + const img = this.toSharp().extract({ + left: x_min, + top: y_min, + width: crop_width, + height: crop_height, + }); + + return await loadImageFunction(img); + } + + } + async center_crop(crop_width, crop_height) { // If the image is already the desired size, return it if (this.width === crop_width && this.height === crop_height) { @@ -502,6 +559,15 @@ export class RawImage { } } + async toBlob(type = 'image/png', quality = 1) { + if (!BROWSER_ENV) { + throw new Error('toBlob() is only supported in browser environments.') + } + + const canvas = this.toCanvas(); + return await canvas.convertToBlob({ type, quality }); + } + toCanvas() { if (!BROWSER_ENV) { throw new Error('toCanvas() is only supported in browser environments.') @@ -575,17 +641,21 @@ export class RawImage { * Save the image to the given path. * @param {string} path The path to save the image to. */ - save(path) { + async save(path) { if (BROWSER_ENV) { + if (WEBWORKER_ENV) { + throw new Error('Unable to save an image from a Web Worker.') + } + const extension = path.split('.').pop().toLowerCase(); const mime = CONTENT_TYPE_MAP.get(extension) ?? 'image/png'; - // Convert image to canvas - const canvas = this.toCanvas(); + // Convert image to Blob + const blob = await this.toBlob(mime); // Convert the canvas content to a data URL - const dataURL = canvas.toDataURL(mime); + const dataURL = URL.createObjectURL(blob); // Create an anchor element with the data URL as the href attribute const downloadLink = document.createElement('a'); @@ -605,7 +675,7 @@ export class RawImage { } else { const img = this.toSharp(); - img.toFile(path); + return await img.toFile(path); } } diff --git a/src/utils/maths.js b/src/utils/maths.js index 8bf5c7e95..eb4ec3482 100644 --- a/src/utils/maths.js +++ b/src/utils/maths.js @@ -89,8 +89,8 @@ export function interpolate_data(input, [in_channels, in_height, in_width], [out /** * Helper method to transpose a `AnyTypedArray` directly - * @param {T} array * @template {AnyTypedArray} T + * @param {T} array * @param {number[]} dims * @param {number[]} axes * @returns {[T, number[]]} The transposed array and the new shape. @@ -232,7 +232,7 @@ export function magnitude(arr) { /** * Returns the value and index of the minimum element in an array. - * @param {number[]} arr array of numbers. + * @param {number[]|TypedArray} arr array of numbers. * @returns {number[]} the value and index of the minimum element, of the form: [valueOfMin, indexOfMin] * @throws {Error} If array is empty. */ @@ -252,7 +252,7 @@ export function min(arr) { /** * Returns the value and index of the maximum element in an array. - * @param {number[]} arr array of numbers. + * @param {number[]|TypedArray} arr array of numbers. * @returns {number[]} the value and index of the maximum element, of the form: [valueOfMax, indexOfMax] * @throws {Error} If array is empty. */ @@ -269,48 +269,31 @@ export function max(arr) { return [max, indexOfMax]; } -/** - * Return the Discrete Fourier Transform sample frequencies. - * - * Code adapted from https://github.com/numpy/numpy/blob/25908cacd19915bf3ddd659c28be28a41bd97a54/numpy/fft/helper.py#L173-L221 - * Original Python doc: https://numpy.org/doc/stable/reference/generated/numpy.fft.rfftfreq.html - * @example - * rfftfreq(400, 1 / 16000) // (201) [0, 40, 80, 120, 160, 200, ..., 8000] - * @param {number} n Window length - * @param {number} [d = 1.0] Sample spacing (inverse of the sampling rate). Defaults to 1. - * @throws {TypeError} If n is not an integer. - * @returns {number[]} Array of length `Math.floor(n / 2) + 1;` containing the sample frequencies. - */ -export function rfftfreq(n, d = 1.0) { - if (!Number.isInteger(n)) { - throw new TypeError(`n should be an integer, but ${n} given.`); - } - const val = 1.0 / (n * d); - const len = Math.floor(n / 2) + 1; - const results = new Array(len); - for (let i = 0; i < len; ++i) { - results[i] = i * val; - } - return results; +function isPowerOfTwo(number) { + // Check if the number is greater than 0 and has only one bit set to 1 + return (number > 0) && ((number & (number - 1)) === 0); } /** - * FFT class provides functionality for performing Fast Fourier Transform on arrays + * Implementation of Radix-4 FFT. + * + * P2FFT class provides functionality for performing Fast Fourier Transform on arrays + * which are a power of two in length. * Code adapted from https://www.npmjs.com/package/fft.js */ -export class FFT { +class P2FFT { /** - * @param {number} size The size of the input array. Must be a power of two and bigger than 1. - * @throws {Error} FFT size must be a power of two and bigger than 1. + * @param {number} size The size of the input array. Must be a power of two larger than 1. + * @throws {Error} FFT size must be a power of two larger than 1. */ constructor(size) { this.size = size | 0; // convert to a 32-bit signed integer - if (this.size <= 1 || (this.size & (this.size - 1)) !== 0) - throw new Error('FFT size must be a power of two and bigger than 1'); + if (this.size <= 1 || !isPowerOfTwo(this.size)) + throw new Error('FFT size must be a power of two larger than 1'); this._csize = size << 1; - this.table = new Float32Array(this.size * 2); + this.table = new Float64Array(this.size * 2); for (let i = 0; i < this.table.length; i += 2) { const angle = Math.PI * i / this.size; this.table[i] = Math.cos(angle); @@ -341,16 +324,16 @@ export class FFT { /** * Create a complex number array with size `2 * size` * - * @returns {Float32Array} A complex number array with size `2 * size` + * @returns {Float64Array} A complex number array with size `2 * size` */ createComplexArray() { - return new Float32Array(this._csize); + return new Float64Array(this._csize); } /** - * Converts a complex number representation stored in a Float32Array to an array of real numbers. + * Converts a complex number representation stored in a Float64Array to an array of real numbers. * - * @param {Float32Array} complex The complex number representation to be converted. + * @param {Float64Array} complex The complex number representation to be converted. * @param {number[]} [storage] An optional array to store the result in. * @returns {number[]} An array of real numbers representing the input complex number representation. */ @@ -363,9 +346,9 @@ export class FFT { /** * Convert a real-valued input array to a complex-valued output array. - * @param {Float32Array} input The real-valued input array. - * @param {Float32Array} [storage] Optional buffer to store the output array. - * @returns {Float32Array} The complex-valued output array. + * @param {Float64Array} input The real-valued input array. + * @param {Float64Array} [storage] Optional buffer to store the output array. + * @returns {Float64Array} The complex-valued output array. */ toComplexArray(input, storage) { const res = storage || this.createComplexArray(); @@ -378,7 +361,7 @@ export class FFT { /** * Completes the spectrum by adding its mirrored negative frequency components. - * @param {Float32Array} spectrum The input spectrum. + * @param {Float64Array} spectrum The input spectrum. * @returns {void} */ completeSpectrum(spectrum) { @@ -393,8 +376,8 @@ export class FFT { /** * Performs a Fast Fourier Transform (FFT) on the given input data and stores the result in the output buffer. * - * @param {Float32Array} out The output buffer to store the result. - * @param {Float32Array} data The input data to transform. + * @param {Float64Array} out The output buffer to store the result. + * @param {Float64Array} data The input data to transform. * * @throws {Error} Input and output buffers must be different. * @@ -412,8 +395,8 @@ export class FFT { * The input buffer must contain real values only, while the output buffer will contain complex values. The input and * output buffers must be different. * - * @param {Float32Array} out The output buffer. - * @param {Float32Array} data The input buffer containing real values. + * @param {Float64Array} out The output buffer. + * @param {Float64Array} data The input buffer containing real values. * * @throws {Error} If the input and output buffers are the same. */ @@ -429,8 +412,8 @@ export class FFT { * The `out` array must be a different buffer than the `data` array. The `out` array will contain the * result of the transformation. The `data` array will not be modified. * - * @param {Float32Array} out The output buffer for the transformed data. - * @param {Float32Array} data The input data to transform. + * @param {Float64Array} out The output buffer for the transformed data. + * @param {Float64Array} data The input data to transform. * @throws {Error} If `out` and `data` refer to the same buffer. * @returns {void} */ @@ -446,8 +429,8 @@ export class FFT { /** * Performs a radix-4 implementation of a discrete Fourier transform on a given set of data. * - * @param {Float32Array} out The output buffer for the transformed data. - * @param {Float32Array} data The input buffer of data to be transformed. + * @param {Float64Array} out The output buffer for the transformed data. + * @param {Float64Array} data The input buffer of data to be transformed. * @param {number} inv A scaling factor to apply to the transform. * @returns {void} */ @@ -463,7 +446,7 @@ export class FFT { let outOff; let t; - let bitrev = this._bitrev; + const bitrev = this._bitrev; if (len === 4) { for (outOff = 0, t = 0; outOff < size; outOff += len, ++t) { const off = bitrev[t]; @@ -480,12 +463,12 @@ export class FFT { // Loop through steps in decreasing order for (step >>= 2; step >= 2; step >>= 2) { len = (size / step) << 1; - let quarterLen = len >>> 2; + const quarterLen = len >>> 2; // Loop through offsets in the data for (outOff = 0; outOff < size; outOff += len) { // Full case - let limit = outOff + quarterLen; + const limit = outOff + quarterLen - 1; for (let i = outOff, k = 0; i < limit; i += 2, k += step) { const A = i; const B = A + quarterLen; @@ -544,8 +527,8 @@ export class FFT { /** * Performs a radix-2 implementation of a discrete Fourier transform on a given set of data. * - * @param {Float32Array} data The input buffer of data to be transformed. - * @param {Float32Array} out The output buffer for the transformed data. + * @param {Float64Array} data The input buffer of data to be transformed. + * @param {Float64Array} out The output buffer for the transformed data. * @param {number} outOff The offset at which to write the output data. * @param {number} off The offset at which to begin reading the input data. * @param {number} step The step size for indexing the input data. @@ -569,8 +552,8 @@ export class FFT { /** * Performs radix-4 transformation on input data of length 8 * - * @param {Float32Array} data Input data array of length 8 - * @param {Float32Array} out Output data array of length 8 + * @param {Float64Array} data Input data array of length 8 + * @param {Float64Array} out Output data array of length 8 * @param {number} outOff Index of output array to start writing from * @param {number} off Index of input array to start reading from * @param {number} step Step size between elements in input array @@ -617,8 +600,8 @@ export class FFT { /** * Real input radix-4 implementation - * @param {Float32Array} out Output array for the transformed data - * @param {Float32Array} data Input array of real data to be transformed + * @param {Float64Array} out Output array for the transformed data + * @param {Float64Array} data Input array of real data to be transformed * @param {number} inv The scale factor used to normalize the inverse transform */ _realTransform4(out, data, inv) { @@ -630,9 +613,9 @@ export class FFT { let step = 1 << width; let len = (size / step) << 1; - var outOff; - var t; - var bitrev = this._bitrev; + let outOff; + let t; + const bitrev = this._bitrev; if (len === 4) { for (outOff = 0, t = 0; outOff < size; outOff += len, ++t) { const off = bitrev[t]; @@ -646,17 +629,18 @@ export class FFT { } } + // TODO: Optimize once https://github.com/indutny/fft.js/issues/25 is fixed // Loop through steps in decreasing order for (step >>= 2; step >= 2; step >>= 2) { len = (size / step) << 1; - const halfLen = len >>> 1; - const quarterLen = halfLen >>> 1; - const hquarterLen = quarterLen >>> 1; + const quarterLen = len >>> 2; // Loop through offsets in the data for (outOff = 0; outOff < size; outOff += len) { - for (let i = 0, k = 0; i <= hquarterLen; i += 2, k += step) { - const A = outOff + i; + // Full case + const limit = outOff + quarterLen - 1; + for (let i = outOff, k = 0; i < limit; i += 2, k += step) { + const A = i; const B = A + quarterLen; const C = B + quarterLen; const D = C + quarterLen; @@ -701,25 +685,10 @@ export class FFT { out[A + 1] = T0i + T2i; out[B] = T1r + T3i; out[B + 1] = T1i - T3r; - - // Output final middle point - if (i === 0) { - out[C] = T0r - T2r; - out[C + 1] = T0i - T2i; - continue; - } - - // Do not overwrite ourselves - if (i === hquarterLen) - continue; - - const SA = outOff + quarterLen - i; - const SB = outOff + halfLen - i; - - out[SA] = T1r + -inv * T3i; - out[SA + 1] = -T1i - inv * T3r; - out[SB] = T0r + -inv * T2r; - out[SB + 1] = -T0i + inv * T2i; + out[C] = T0r - T2r; + out[C + 1] = T0i - T2i; + out[D] = T1r - T3i; + out[D + 1] = T1i + T3r; } } } @@ -728,8 +697,8 @@ export class FFT { /** * Performs a single real input radix-2 transformation on the provided data * - * @param {Float32Array} data The input data array - * @param {Float32Array} out The output data array + * @param {Float64Array} data The input data array + * @param {Float64Array} out The output data array * @param {number} outOff The output offset * @param {number} off The input offset * @param {number} step The step @@ -753,8 +722,8 @@ export class FFT { * Computes a single real-valued transform using radix-4 algorithm. * This method is only called for len=8. * - * @param {Float32Array} data The input data array. - * @param {Float32Array} out The output data array. + * @param {Float64Array} data The input data array. + * @param {Float64Array} out The output data array. * @param {number} outOff The offset into the output array. * @param {number} off The offset into the input array. * @param {number} step The step size for the input array. @@ -790,6 +759,148 @@ export class FFT { } } +/** + * NP2FFT class provides functionality for performing Fast Fourier Transform on arrays + * which are not a power of two in length. In such cases, the chirp-z transform is used. + * + * For more information, see: https://math.stackexchange.com/questions/77118/non-power-of-2-ffts/77156#77156 + */ +class NP2FFT { + + /** + * Constructs a new NP2FFT object. + * @param {number} fft_length The length of the FFT + */ + constructor(fft_length) { + // Helper variables + const a = 2 * (fft_length - 1); + const b = 2 * (2 * fft_length - 1); + const nextP2 = 2 ** (Math.ceil(Math.log2(b))) + this.bufferSize = nextP2; + this._a = a; + + // Define buffers + // Compute chirp for transform + const chirp = new Float64Array(b); + const ichirp = new Float64Array(nextP2); + this._chirpBuffer = new Float64Array(nextP2); + this._buffer1 = new Float64Array(nextP2); + this._buffer2 = new Float64Array(nextP2); + this._outBuffer1 = new Float64Array(nextP2); + this._outBuffer2 = new Float64Array(nextP2); + + // Compute complex exponentiation + const theta = -2 * Math.PI / fft_length; + const baseR = Math.cos(theta); + const baseI = Math.sin(theta); + + // Precompute helper for chirp-z transform + for (let i = 0; i < b >> 1; ++i) { + // Compute complex power: + const e = (i + 1 - fft_length) ** 2 / 2.0; + + // Compute the modulus and argument of the result + const result_mod = Math.sqrt(baseR ** 2 + baseI ** 2) ** e; + const result_arg = e * Math.atan2(baseI, baseR); + + // Convert the result back to rectangular form + // and assign to chirp and ichirp + const i2 = 2 * i; + chirp[i2] = result_mod * Math.cos(result_arg); + chirp[i2 + 1] = result_mod * Math.sin(result_arg); + + // conjugate + ichirp[i2] = chirp[i2]; + ichirp[i2 + 1] = - chirp[i2 + 1]; + } + this._slicedChirpBuffer = chirp.subarray(a, b); + + // create object to perform Fast Fourier Transforms + // with `nextP2` complex numbers + this._f = new P2FFT(nextP2 >> 1); + this._f.transform(this._chirpBuffer, ichirp); + } + + _transform(output, input, real) { + const ib1 = this._buffer1; + const ib2 = this._buffer2; + const ob2 = this._outBuffer1; + const ob3 = this._outBuffer2; + const cb = this._chirpBuffer; + const sb = this._slicedChirpBuffer; + const a = this._a; + + if (real) { + // Real multiplication + for (let j = 0; j < sb.length; j += 2) { + const j2 = j + 1 + const j3 = j >> 1; + + const a_real = input[j3]; + ib1[j] = a_real * sb[j]; + ib1[j2] = a_real * sb[j2]; + } + } else { + // Complex multiplication + for (let j = 0; j < sb.length; j += 2) { + const j2 = j + 1 + ib1[j] = input[j] * sb[j] - input[j2] * sb[j2]; + ib1[j2] = input[j] * sb[j2] + input[j2] * sb[j]; + } + } + this._f.transform(ob2, ib1); + + for (let j = 0; j < cb.length; j += 2) { + const j2 = j + 1; + + ib2[j] = ob2[j] * cb[j] - ob2[j2] * cb[j2]; + ib2[j2] = ob2[j] * cb[j2] + ob2[j2] * cb[j]; + } + this._f.inverseTransform(ob3, ib2); + + for (let j = 0; j < ob3.length; j += 2) { + const a_real = ob3[j + a]; + const a_imag = ob3[j + a + 1]; + const b_real = sb[j]; + const b_imag = sb[j + 1]; + + output[j] = a_real * b_real - a_imag * b_imag; + output[j + 1] = a_real * b_imag + a_imag * b_real; + } + } + + transform(output, input) { + this._transform(output, input, false); + } + + realTransform(output, input) { + this._transform(output, input, true); + } +} + +export class FFT { + constructor(fft_length) { + this.fft_length = fft_length; + this.isPowerOfTwo = isPowerOfTwo(fft_length); + if (this.isPowerOfTwo) { + this.fft = new P2FFT(fft_length); + this.outputBufferSize = 2 * fft_length; + } else { + this.fft = new NP2FFT(fft_length); + this.outputBufferSize = this.fft.bufferSize; + } + } + + realTransform(out, input) { + this.fft.realTransform(out, input); + } + + transform(out, input) { + this.fft.transform(out, input); + } +} + + /** * Performs median filter on the provided data. Padding is done by mirroring the data. * @param {AnyTypedArray} data The input array diff --git a/src/utils/tensor.js b/src/utils/tensor.js index b575b5136..3cf165936 100644 --- a/src/utils/tensor.js +++ b/src/utils/tensor.js @@ -34,7 +34,6 @@ const DataTypeMap = new Map([ * @typedef {import('./maths.js').AnyTypedArray | any[]} DataArray */ -/** @type {Object} */ const ONNXTensor = ONNX.Tensor; export class Tensor extends ONNXTensor { diff --git a/tests/generate_tests.py b/tests/generate_tests.py index 40c41b00e..8c44258a8 100644 --- a/tests/generate_tests.py +++ b/tests/generate_tests.py @@ -5,6 +5,7 @@ import os from transformers import AutoTokenizer, AutoConfig +import numpy as np from scripts.supported_models import SUPPORTED_MODELS @@ -205,6 +206,37 @@ def generate_config_tests(): return results +ARRAY_SIZES = sorted(set([2 ** i for i in range(1, 10)]) \ + | set([3 ** i for i in range(1, 8)]) \ + | set([5 ** i for i in range(1, 6)]) \ + | set([7 ** i for i in range(1, 4)])) + + +def serialize_complex_array(arr): + return [float(x) for y in arr for x in [y.real, y.imag]] + + +def serialize_real_array(arr): + return arr.tolist() + + +def generate_fft_tests(): + np.random.seed(0) + tests = {} + for complex in [False, True]: + serialize_fn = serialize_complex_array if complex else serialize_real_array + for size in ARRAY_SIZES: + arr = np.random.randn(size).astype(np.complex64 if complex else np.float64) + if complex: + arr += np.random.randn(size) * 1j + tests[f"fft_{size}_{'complex' if complex else 'real'}"] = { + "complex": complex, + "input": serialize_fn(arr), + "output": serialize_complex_array(np.fft.fft(arr)), + } + return tests + + def main(): # TODO add option to cache generated data + force build tests @@ -220,6 +252,9 @@ def main(): with open(os.path.join(data_dir, "config_tests.json"), "w", encoding="utf-8") as fp: json.dump(config_tests, fp) - + fft_tests = generate_fft_tests() + with open(os.path.join(data_dir, "fft_tests.json"), "w", encoding="utf-8") as fp: + json.dump(fft_tests, fp) + if __name__ == "__main__": main() diff --git a/tests/generation.test.js b/tests/generation.test.js index 2effaf94d..eb6b87f49 100644 --- a/tests/generation.test.js +++ b/tests/generation.test.js @@ -9,8 +9,8 @@ describe('Generation parameters', () => { // List all models which will be tested const models = [ - 'Xenova/LaMini-Flan-T5-77M', // encoder-decoder - 'Xenova/LaMini-GPT-124M', // decoder-only + 'MBZUAI/LaMini-Flan-T5-77M', // encoder-decoder + 'MBZUAI/LaMini-GPT-124M', // decoder-only ]; // encoder-decoder model diff --git a/tests/init.js b/tests/init.js index 6eb3c9a12..b01fe1000 100644 --- a/tests/init.js +++ b/tests/init.js @@ -9,6 +9,12 @@ import { onnxruntimeBackend } from "onnxruntime-node/dist/backend"; import ONNX_COMMON from "onnxruntime-common"; export function init() { + // In rare cases (specifically when running unit tests with GitHub actions), possibly due to + // a large number of concurrent executions, onnxruntime might fallback to use the WASM backend. + // In this case, we set the number of threads to 1 to avoid errors like: + // - `TypeError: The worker script or module filename must be an absolute path or a relative path starting with './' or '../'. Received "blob:nodedata:..."` + ONNX_COMMON.env.wasm.numThreads = 1; + // A workaround to define a new backend for onnxruntime, which // will not throw an error when running tests with jest. // For more information, see: https://github.com/jestjs/jest/issues/11864#issuecomment-1261468011 diff --git a/tests/maths.test.js b/tests/maths.test.js index 0dc389d6d..9a7d3dc3c 100644 --- a/tests/maths.test.js +++ b/tests/maths.test.js @@ -1,7 +1,29 @@ import { compare } from './test_utils.js'; -import { medianFilter } from '../src/utils/maths.js'; +import { getFile } from '../src/utils/hub.js'; +import { FFT, medianFilter } from '../src/utils/maths.js'; + + +const fft = (arr, complex = false) => { + let output; + let fft; + if (complex) { + fft = new FFT(arr.length / 2); + output = new Float64Array(fft.outputBufferSize); + fft.transform(output, arr); + } else { + fft = new FFT(arr.length); + output = new Float64Array(fft.outputBufferSize); + fft.realTransform(output, arr); + } + if (!fft.isPowerOfTwo) { + output = output.slice(0, complex ? arr.length : 2 * arr.length); + } + return output; +} + +const fftTestsData = await (await getFile('./tests/data/fft_tests.json')).json() describe('Mathematical operations', () => { @@ -11,8 +33,8 @@ describe('Mathematical operations', () => { it('should compute median filter', async () => { const t1 = new Float32Array([5, 12, 2, 6, 3, 10, 9, 1, 4, 8, 11, 7]); const window = 3; - - const target = new Float32Array([12, 5, 6, 3, 6, 9, 9, 4, 4, 8, 8, 11]); + + const target = new Float32Array([12, 5, 6, 3, 6, 9, 9, 4, 4, 8, 8, 11]); const output = medianFilter(t1, window); compare(output, target, 1e-3); @@ -22,4 +44,83 @@ describe('Mathematical operations', () => { // TODO add tests for errors }); + describe('FFT', () => { + // Should match output of numpy fft + it('should compute real FFT for power of two', () => { + { // size = 4 + // np.fft.fft([1,2,3,4]) == array([10.+0.j, -2.+2.j, -2.+0.j, -2.-2.j]) + const input = new Float32Array([1, 2, 3, 4]); + const target = new Float32Array([10, 0, -2, 2, -2, 0, -2, -2]); + + const output = fft(input); + compare(output, target, 1e-3); + } + + { // size = 16 + // np.fft.fft([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]) + // == array([136. +0.j , -8.+40.21871594j, -8.+19.3137085j , + // -8.+11.9728461j , -8. +8.j , -8. +5.3454291j , + // -8. +3.3137085j , -8. +1.59129894j, -8. +0.j , + // -8. -1.59129894j, -8. -3.3137085j , -8. -5.3454291j , + // -8. -8.j , -8.-11.9728461j , -8.-19.3137085j , + // -8.-40.21871594j]) + const input = new Float32Array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]); + const target = new Float32Array([136.0, 0.0, -8.0, 40.218715937006785, -8.0, 19.31370849898476, -8.0, 11.972846101323912, -8.0, 8.0, -8.0, 5.345429103354389, -8.0, 3.313708498984761, -8.0, 1.5912989390372658, -8.0, 0.0, -8.0, -1.5912989390372658, -8.0, -3.313708498984761, -8.0, -5.345429103354389, -8.0, -8.0, -8.0, -11.972846101323912, -8.0, -19.31370849898476, -8.0, -40.218715937006785]); + + const output = fft(input); + compare(output, target, 1e-3); + } + }); + + it('should compute real FFT for non-power of two', () => { + { // size = 3 + // np.fft.fft([1,2,3]) == array([ 6. +0.j, -1.5+0.8660254j, -1.5-0.8660254j]) + const input = new Float32Array([1, 2, 3]); + const target = new Float32Array([6, 0, -1.5, 0.8660254, -1.5, -0.8660254]); + + const output = fft(input); + compare(output, target, 1e-3); + } + }); + + it('should compute complex FFT for non-power of two', () => { + { // size = 3 + // np.fft.fft([1+3j,2-2j,3+1j]) == array([ 6. +2.j, -4.09807621+4.3660254j, 1.09807621+2.6339746j]) + const input = new Float32Array([1, 3, 2, -2, 3, 1]); + const target = new Float32Array([6, 2, -4.09807621, 4.3660254, 1.09807621, 2.6339746]); + + const output = fft(input, true); + compare(output, target, 1e-3); + } + }); + + it('should compute complex FFT for power of two', () => { + { // size = 4 + // np.fft.fft([1+4j, 2-3j,3+2j, 4-1j]) == array([10. +2.j, -4. +4.j, -2.+10.j, 0. +0.j]) + const input = new Float32Array([1, 4, 2, -3, 3, 2, 4, -1]); + const target = new Float32Array([10, 2, -4, 4, -2, 10, 0, 0]); + + const output = fft(input, true); + compare(output, target, 1e-3); + } + }); + }) + + describe('FFT (dynamic)', () => { + // Should match output of numpy fft + for (const [name, test] of Object.entries(fftTestsData)) { + // if (test.input.length > 5) continue; + it(name, () => { + const output = fft(test.input, test.complex); + + if (output.map((v, i) => Math.abs(v - test.output[i])).some(v => v > 1e-4)) { + console.log('input', test.input) + console.log('output', output) + console.log('target', test.output) + } + compare(output, test.output, 1e-4); + + }); + } + }); }); diff --git a/tests/pipelines.test.js b/tests/pipelines.test.js index d8f5e56c0..50d3aecd7 100644 --- a/tests/pipelines.test.js +++ b/tests/pipelines.test.js @@ -1326,6 +1326,105 @@ describe('Pipelines', () => { }, MAX_TEST_EXECUTION_TIME); }); + describe('Zero-shot object detection', () => { + + // List all models which will be tested + const models = [ + 'google/owlvit-base-patch32', + ]; + + it(models[0], async () => { + let detector = await pipeline('zero-shot-object-detection', m(models[0])); + + + // single (default) + { + let url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/astronaut.png'; + let candidate_labels = ['human face', 'rocket', 'helmet', 'american flag']; + + let output = await detector(url, candidate_labels); + + // let expected = [ + // { + // score: 0.24392342567443848, + // label: 'human face', + // box: { xmin: 180, ymin: 67, xmax: 274, ymax: 175 } + // }, + // { + // score: 0.15129457414150238, + // label: 'american flag', + // box: { xmin: 0, ymin: 4, xmax: 106, ymax: 513 } + // }, + // { + // score: 0.13649864494800568, + // label: 'helmet', + // box: { xmin: 277, ymin: 337, xmax: 511, ymax: 511 } + // }, + // { + // score: 0.10262022167444229, + // label: 'rocket', + // box: { xmin: 352, ymin: -1, xmax: 463, ymax: 287 } + // } + // ] + + expect(output.length).toBeGreaterThan(0); + for (let cls of output) { + expect(typeof cls.score).toBe('number'); + expect(typeof cls.label).toBe('string'); + for (let key of ['xmin', 'ymin', 'xmax', 'ymax']) { + expect(typeof cls.box[key]).toBe('number'); + } + } + } + + // topk + threshold + percentage + { + let url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/beach.png'; + let candidate_labels = ['hat', 'book', 'sunglasses', 'camera']; + + let output = await detector(url, candidate_labels, { + topk: 4, + threshold: 0.05, + percentage: true, + }); + + // let expected = [ + // { + // score: 0.1606510728597641, + // label: 'sunglasses', + // box: { xmin: 347, ymin: 229, xmax: 429, ymax: 264 } + // }, + // { + // score: 0.08935828506946564, + // label: 'hat', + // box: { xmin: 38, ymin: 174, xmax: 258, ymax: 364 } + // }, + // { + // score: 0.08530698716640472, + // label: 'camera', + // box: { xmin: 187, ymin: 350, xmax: 260, ymax: 411 } + // }, + // { + // score: 0.08349756896495819, + // label: 'book', + // box: { xmin: 261, ymin: 280, xmax: 494, ymax: 425 } + // } + // ] + + expect(output.length).toBeGreaterThan(0); + for (let cls of output) { + expect(typeof cls.score).toBe('number'); + expect(typeof cls.label).toBe('string'); + for (let key of ['xmin', 'ymin', 'xmax', 'ymax']) { + expect(typeof cls.box[key]).toBe('number'); + } + } + } + + await detector.dispose(); + }, MAX_TEST_EXECUTION_TIME); + }); + describe('Image-to-image', () => { // List all models which will be tested @@ -1364,6 +1463,47 @@ describe('Pipelines', () => { }, MAX_TEST_EXECUTION_TIME); }); + + describe('Depth estimation', () => { + + // List all models which will be tested + const models = [ + 'Intel/dpt-hybrid-midas', + ]; + + it(models[0], async () => { + let depth_estimator = await pipeline('depth-estimation', m(models[0])); + + let url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/cats.jpg'; + + // single + { + let { predicted_depth, depth } = await depth_estimator(url); + compare(predicted_depth.dims, [384, 384]); + expect(depth.width).toEqual(640); + expect(depth.height).toEqual(480); + expect(depth.channels).toEqual(1); + expect(depth.data).toHaveLength(307200); + } + + // batched + { + let outputs = await depth_estimator([url, url]); + expect(outputs).toHaveLength(2); + for (let output of outputs) { + let { predicted_depth, depth } = output; + compare(predicted_depth.dims, [384, 384]); + expect(depth.width).toEqual(640); + expect(depth.height).toEqual(480); + expect(depth.channels).toEqual(1); + expect(depth.data).toHaveLength(307200); + } + } + + await depth_estimator.dispose(); + }, MAX_TEST_EXECUTION_TIME); + }); + describe('Document question answering', () => { // List all models which will be tested diff --git a/tests/processors.test.js b/tests/processors.test.js index fe594613e..c6703f14e 100644 --- a/tests/processors.test.js +++ b/tests/processors.test.js @@ -38,6 +38,11 @@ describe('Processors', () => { beit: 'microsoft/beit-base-patch16-224-pt22k-ft22k', detr: 'facebook/detr-resnet-50', yolos: 'hustvl/yolos-small-300', + dpt: 'Intel/dpt-hybrid-midas', + glpn: 'vinvino02/glpn-kitti', + nougat: 'facebook/nougat-small', + owlvit: 'google/owlvit-base-patch32', + clip: 'openai/clip-vit-base-patch16', } const TEST_IMAGES = { @@ -45,6 +50,8 @@ describe('Processors', () => { checkerboard_8x8: 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/checkerboard_8x8.png', receipt: 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/receipt.png', tiger: 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/tiger.jpg', + paper: 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/nougat_paper.png', + cats: 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/cats.jpg', // grayscale image skateboard: 'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/ml-web-games/skateboard.png', @@ -87,6 +94,7 @@ describe('Processors', () => { // DonutProcessor/DonutFeatureExtractor // - tests thumbnail resizing (do_thumbnail=true, size=[960, 1280]) + // - tests padding after normalization (image_mean=image_std=0.5) it(MODELS['donut-swin'], async () => { const processor = await AutoProcessor.from_pretrained(m(MODELS['donut-swin'])) @@ -171,7 +179,7 @@ describe('Processors', () => { it(MODELS.deit, async () => { const processor = await AutoProcessor.from_pretrained(m(MODELS.deit)) - { // Tests grayscale image + { const image = await load_image(TEST_IMAGES.tiger); const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image); @@ -187,7 +195,7 @@ describe('Processors', () => { it(MODELS.beit, async () => { const processor = await AutoProcessor.from_pretrained(m(MODELS.beit)) - { // Tests grayscale image + { const image = await load_image(TEST_IMAGES.tiger); const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image); @@ -204,7 +212,7 @@ describe('Processors', () => { it(MODELS.detr, async () => { const processor = await AutoProcessor.from_pretrained(m(MODELS.detr)) - { // Tests grayscale image + { const image = await load_image(TEST_IMAGES.tiger); const { pixel_values, original_sizes, reshaped_input_sizes, pixel_mask } = await processor(image); @@ -225,7 +233,7 @@ describe('Processors', () => { it(MODELS.yolos, async () => { const processor = await AutoProcessor.from_pretrained(m(MODELS.yolos)) - { // Tests grayscale image + { const image = await load_image(TEST_IMAGES.tiger); const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image); @@ -236,5 +244,196 @@ describe('Processors', () => { compare(reshaped_input_sizes, [[888, 1333]]); } }, MAX_TEST_EXECUTION_TIME); + + // DPTFeatureExtractor + it(MODELS.dpt, async () => { + const processor = await AutoProcessor.from_pretrained(m(MODELS.dpt)) + + { // Tests grayscale image + const image = await load_image(TEST_IMAGES.cats); + const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image); + + compare(pixel_values.dims, [1, 3, 384, 384]); + compare(avg(pixel_values.data), 0.0372855559389454); + + compare(original_sizes, [[480, 640]]); + compare(reshaped_input_sizes, [[384, 384]]); + } + }, MAX_TEST_EXECUTION_TIME); + + // GLPNForDepthEstimation + // - tests `size_divisor` and no size (size_divisor=32) + it(MODELS.glpn, async () => { + const processor = await AutoProcessor.from_pretrained(m(MODELS.glpn)) + + { + const image = await load_image(TEST_IMAGES.cats); + const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image); + compare(pixel_values.dims, [1, 3, 480, 640]); + compare(avg(pixel_values.data), 0.5186172404123327); + + compare(original_sizes, [[480, 640]]); + compare(reshaped_input_sizes, [[480, 640]]); + } + + { // Tests input which is not a multiple of 32 ([408, 612] -> [384, 608]) + const image = await load_image(TEST_IMAGES.tiger); + const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image); + + compare(pixel_values.dims, [1, 3, 384, 608]); + compare(avg(pixel_values.data), 0.38628831535989555); + + compare(original_sizes, [[408, 612]]); + compare(reshaped_input_sizes, [[384, 608]]); + } + }, MAX_TEST_EXECUTION_TIME); + + // NougatImageProcessor + // - tests padding after normalization (image_mean != 0.5, image_std != 0.5) + it(MODELS.nougat, async () => { + const processor = await AutoProcessor.from_pretrained(m(MODELS.nougat)) + + { + const image = await load_image(TEST_IMAGES.paper); + const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image); + + compare(pixel_values.dims, [1, 3, 896, 672]); + compare(avg(pixel_values.data), 1.8447155005897355); + + compare(original_sizes, [[850, 685]]); + compare(reshaped_input_sizes, [[833, 672]]); + } + }, MAX_TEST_EXECUTION_TIME); + + // OwlViTFeatureExtractor + it(MODELS.owlvit, async () => { + const processor = await AutoProcessor.from_pretrained(m(MODELS.owlvit)) + { + const image = await load_image(TEST_IMAGES.cats); + const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image); + + compare(pixel_values.dims, [1, 3, 768, 768]); + compare(avg(pixel_values.data), 0.250620447910435); + + compare(original_sizes, [[480, 640]]); + compare(reshaped_input_sizes, [[768, 768]]); + } + }); + + // CLIPFeatureExtractor + // - tests center crop (do_center_crop=true, crop_size=224) + it(MODELS.clip, async () => { + const processor = await AutoProcessor.from_pretrained(m(MODELS.clip)) + + { + const image = await load_image(TEST_IMAGES.tiger); + const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image); + + compare(pixel_values.dims, [1, 3, 224, 224]); + compare(avg(pixel_values.data), -0.06678297738282096); + + compare(original_sizes, [[408, 612]]); + compare(reshaped_input_sizes, [[224, 224]]); + } + }, MAX_TEST_EXECUTION_TIME); + }); + + describe('Audio processors', () => { + const audioPromise = new Promise(async (resolve) => { + const url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/mlk.npy'; + const buffer = await (await fetch(url)).arrayBuffer(); + const audio = Float32Array.from(new Float64Array(buffer)); + resolve(audio); + }); + + it('WhisperFeatureExtractor', async () => { + const audio = await audioPromise; + const processor = await AutoProcessor.from_pretrained('Xenova/whisper-tiny.en'); + const { input_features } = await processor(audio); + compare(input_features.dims, [1, 80, 3000]); + expect(avg(input_features.data)).toBeCloseTo(-0.2813588131551941); + expect(input_features.data[0]).toBeCloseTo(0.33168578147888184); + expect(input_features.data[1]).toBeCloseTo(0.30986475944519043); + expect(input_features.data[81]).toBeCloseTo(0.10727232694625854); + expect(input_features.data[3001]).toBeCloseTo(0.2555035352706909); + }, MAX_TEST_EXECUTION_TIME); + + it('ASTFeatureExtractor', async () => { + const audio = await audioPromise; + const processor = await AutoProcessor.from_pretrained('Xenova/ast-finetuned-audioset-10-10-0.4593'); + { // truncation + const { input_values } = await processor(audio); + compare(input_values.dims, [1, 1024, 128]); + + expect(avg(input_values.data)).toBeCloseTo(-0.04054912979309085); + expect(input_values.data[0]).toBeCloseTo(-0.5662586092948914); + expect(input_values.data[1]).toBeCloseTo(-1.0300861597061157); + expect(input_values.data[129]).toBeCloseTo(-1.084834098815918); + expect(input_values.data[1025]).toBeCloseTo(-1.1204065084457397); + } + { // padding + const { input_values } = await processor(audio.slice(0, 1000)); + compare(input_values.dims, [1, 1024, 128]); // [1, 4, 128] -> (padded to) -> [1, 1024, 128] + + expect(avg(input_values.data)).toBeCloseTo(0.4647964835166931); + expect(input_values.data[0]).toBeCloseTo(-0.5662586092948914); + expect(input_values.data[1]).toBeCloseTo(-1.0300861597061157); + expect(input_values.data[129]).toBeCloseTo(-1.084834098815918); + + // padded values + expect(input_values.data[1025]).toBeCloseTo(0.46703237295150757); + expect(input_values.data[2049]).toBeCloseTo(0.46703237295150757); + expect(input_values.data[10000]).toBeCloseTo(0.46703237295150757); + } + }, MAX_TEST_EXECUTION_TIME); + + it('ClapFeatureExtractor', async () => { + const audio = await audioPromise; + const processor = await AutoProcessor.from_pretrained('Xenova/clap-htsat-unfused'); + { // truncation + // Since truncation uses a random strategy, we override + // Math.random to ensure that the test is deterministic + const originalRandom = Math.random; + Math.random = () => 0.5; + + let long_audio = new Float32Array(500000); + long_audio.set(audio); + long_audio.set(audio, long_audio.length - audio.length); + + const { input_features } = await processor(long_audio); + compare(input_features.dims, [1, 1, 1001, 64]); + + expect(avg(input_features.data)).toBeCloseTo(-37.94569396972656); + expect(input_features.data[0]).toBeCloseTo(-53.32647705078125); + expect(input_features.data[1]).toBeCloseTo(-47.76755142211914); + expect(input_features.data[65]).toBeCloseTo(-36.32261276245117); + expect(input_features.data[1002]).toBeCloseTo(-28.0314884185791); + expect(input_features.data[10000]).toBeCloseTo(-21.905902862548828); + expect(input_features.data[60000]).toBeCloseTo(-14.877863883972168); + expect(input_features.data[64062]).toBeCloseTo(-37.9784049987793); + expect(input_features.data[64063]).toBeCloseTo(-37.73963928222656); + + // Reset Math.random + Math.random = originalRandom; + } + { // padding + const { input_features } = await processor(audio); + compare(input_features.dims, [1, 1, 1001, 64]); + + expect(avg(input_features.data)).toBeCloseTo(-34.99049377441406); + expect(input_features.data[0]).toBeCloseTo(-21.32573890686035); + expect(input_features.data[1]).toBeCloseTo(-26.168411254882812); + expect(input_features.data[65]).toBeCloseTo(-29.716018676757812); + expect(input_features.data[1002]).toBeCloseTo(-32.16273498535156); + expect(input_features.data[10000]).toBeCloseTo(-19.9283390045166); + + // padded values + expect(input_features.data[60000]).toBeCloseTo(-100.0); + expect(input_features.data[64062]).toBeCloseTo(-100.0); + expect(input_features.data[64063]).toBeCloseTo(-100.0); + } + + + }, MAX_TEST_EXECUTION_TIME); }); }); diff --git a/tests/tensor.test.js b/tests/tensor.test.js index c2dc3374c..0d328a9ef 100644 --- a/tests/tensor.test.js +++ b/tests/tensor.test.js @@ -1,9 +1,6 @@ -import { AutoProcessor, Tensor } from '../src/transformers.js'; - -import { MAX_TEST_EXECUTION_TIME, m } from './init.js'; +import { Tensor } from '../src/transformers.js'; import { compare } from './test_utils.js'; - import { cat, mean, stack } from '../src/utils/tensor.js'; describe('Tensor operations', () => { diff --git a/tests/tokenizers.test.js b/tests/tokenizers.test.js index 0f9e1a57c..2d4dfe683 100644 --- a/tests/tokenizers.test.js +++ b/tests/tokenizers.test.js @@ -3,6 +3,7 @@ import { AutoTokenizer } from '../src/transformers.js'; import { getFile } from '../src/utils/hub.js'; import { m, MAX_TEST_EXECUTION_TIME } from './init.js'; +import { compare } from './test_utils.js'; // Load test data generated by the python tests // TODO do this dynamically? @@ -41,10 +42,42 @@ describe('Tokenizers', () => { describe('Edge cases', () => { it('should not crash when encoding a very long string', async () => { - let tokenizer = await AutoTokenizer.from_pretrained('t5-small'); + let tokenizer = await AutoTokenizer.from_pretrained('Xenova/t5-small'); let text = String.prototype.repeat.call('Hello world! ', 50000); - let encoded = await tokenizer(text); + let encoded = tokenizer(text); expect(encoded.input_ids.data.length).toBeGreaterThan(100000); }, MAX_TEST_EXECUTION_TIME); + + it('should not take too long', async () => { + let tokenizer = await AutoTokenizer.from_pretrained('Xenova/all-MiniLM-L6-v2'); + + let text = String.prototype.repeat.call('a', 50000); + let token_ids = tokenizer.encode(text); + compare(token_ids, [101, 100, 102]) + }, 5000); // NOTE: 5 seconds +}); + +describe('Extra decoding tests', () => { + it('should be able to decode the output of encode', async () => { + let tokenizer = await AutoTokenizer.from_pretrained('Xenova/bert-base-uncased'); + + let text = 'hello world!'; + + // Ensure all the following outputs are the same: + // 1. Tensor of ids: allow decoding of 1D or 2D tensors. + let encodedTensor = tokenizer(text); + let decoded1 = tokenizer.decode(encodedTensor.input_ids, { skip_special_tokens: true }); + let decoded2 = tokenizer.batch_decode(encodedTensor.input_ids, { skip_special_tokens: true })[0]; + expect(decoded1).toEqual(text); + expect(decoded2).toEqual(text); + + // 2. List of ids + let encodedList = tokenizer(text, { return_tensor: false }); + let decoded3 = tokenizer.decode(encodedList.input_ids, { skip_special_tokens: true }); + let decoded4 = tokenizer.batch_decode([encodedList.input_ids], { skip_special_tokens: true })[0]; + expect(decoded3).toEqual(text); + expect(decoded4).toEqual(text); + + }, MAX_TEST_EXECUTION_TIME); }); diff --git a/tests/utils.test.js b/tests/utils.test.js index d1fdd6704..504d96817 100644 --- a/tests/utils.test.js +++ b/tests/utils.test.js @@ -1,8 +1,8 @@ import { AutoProcessor } from '../src/transformers.js'; -import { getMelFilters } from '../src/utils/audio.js'; +import { mel_filter_bank } from '../src/utils/audio.js'; -import { MAX_TEST_EXECUTION_TIME, m } from './init.js'; +import { MAX_TEST_EXECUTION_TIME } from './init.js'; describe('Utilities', () => { @@ -11,28 +11,32 @@ describe('Utilities', () => { it('should calculate MEL filters', async () => { // NOTE: Uses official HF implementation as reference: - let processor = await AutoProcessor.from_pretrained('openai/whisper-tiny.en'); - - let config = processor.feature_extractor.config; - - let maxdiff = 0; + const processor = await AutoProcessor.from_pretrained('openai/whisper-tiny.en'); + const config = processor.feature_extractor.config; // True MEL filters - let original_mel_filters = config.mel_filters; + const original_mel_filters = config.mel_filters; // Calculated MEL filters - let calculated_mel_filters = getMelFilters(config.sampling_rate, config.n_fft, config.feature_size); - - for (let i = 0; i < original_mel_filters.length; ++i) { - for (let j = 0; j < original_mel_filters[i].length; ++j) { - const expected = original_mel_filters[i][j]; - const calculated = calculated_mel_filters[i][j]; - - const diff = Math.abs(expected - calculated); - maxdiff = Math.max(maxdiff, diff); - } - } - + const calculated_mel_filters = mel_filter_bank( + Math.floor(1 + config.n_fft / 2), // num_frequency_bins + config.feature_size, // num_mel_filters + 0.0, // min_frequency + 8000.0, // max_frequency + config.sampling_rate, // sampling_rate + "slaney", // norm + "slaney", // mel_scale + ); + + const original = original_mel_filters.flat(); + const calculated = calculated_mel_filters.flat(); + + // Compute max difference + const maxdiff = original.reduce((maxdiff, _, i) => { + const diff = Math.abs(original[i] - calculated[i]); + return Math.max(maxdiff, diff); + }, -Infinity); + expect(maxdiff).toBeGreaterThanOrEqual(0); expect(maxdiff).toBeLessThan(1e-6); }, MAX_TEST_EXECUTION_TIME);