From 73a99ba0afe5e1630881b240ac1d58cf08ed9890 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Thu, 9 Nov 2023 17:57:32 +0200 Subject: [PATCH] Add image-to-image task w/ Swin2SR (for super-resolution) (#381) * Add `Swin2SRImageProcessor` * Add `RawImage.fromTensor` helper function * Add clamp tensor function * Add support for `.to` data type conversion * Add `round` tensor function * Add support for `mul` tensor function * Fix image padding * Only perform padding if it will affect size * Create basic processors unit test suite * Add SamProcessor test case * Move `CONTENT_TYPE_MAP` outside `RawImage` class * Perform reflective padding for swin2sr models * Add swin2sr models for image super-resolution * Add listed support for Swin2SR models * Add image-to-image pipeline * Add listed support for image-to-image task * Add image-to-image unit tests * Add `add` tensor functions * Generalize `pad_image` helper function * Add more unit tests for image processors * Fix typo --- README.md | 3 +- docs/snippets/5_supported-tasks.snippet | 2 +- docs/snippets/6_supported-models.snippet | 1 + scripts/supported_models.py | 11 ++ src/models.js | 53 +++++ src/pipelines.js | 51 +++++ src/processors.js | 160 ++++++++++++--- src/utils/core.js | 10 + src/utils/image.js | 37 ++-- src/utils/tensor.js | 112 +++++++++++ tests/pipelines.test.js | 38 ++++ tests/processors.test.js | 240 +++++++++++++++++++++++ 12 files changed, 679 insertions(+), 39 deletions(-) create mode 100644 tests/processors.test.js diff --git a/README.md b/README.md index 7ee4e5a09..56af52aec 100644 --- a/README.md +++ b/README.md @@ -210,7 +210,7 @@ You can refine your search by selecting the task you're interested in (e.g., [te | [Depth Estimation](https://huggingface.co/tasks/depth-estimation) | `depth-estimation` | Predicting the depth of objects present in an image. | ❌ | | [Image Classification](https://huggingface.co/tasks/image-classification) | `image-classification` | Assigning a label or class to an entire image. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.ImageClassificationPipeline)
[(models)](https://huggingface.co/models?pipeline_tag=image-classification&library=transformers.js) | | [Image Segmentation](https://huggingface.co/tasks/image-segmentation) | `image-segmentation` | Divides an image into segments where each pixel is mapped to an object. This task has multiple variants such as instance segmentation, panoptic segmentation and semantic segmentation. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.ImageSegmentationPipeline)
[(models)](https://huggingface.co/models?pipeline_tag=image-segmentation&library=transformers.js) | -| [Image-to-Image](https://huggingface.co/tasks/image-to-image) | `image-to-image` | Transforming a source image to match the characteristics of a target image or a target image domain. | ❌ | +| [Image-to-Image](https://huggingface.co/tasks/image-to-image) | `image-to-image` | Transforming a source image to match the characteristics of a target image or a target image domain. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.ImageToImagePipeline)
[(models)](https://huggingface.co/models?pipeline_tag=image-to-image&library=transformers.js) | | [Mask Generation](https://huggingface.co/tasks/mask-generation) | `mask-generation` | Generate masks for the objects in an image. | ❌ | | [Object Detection](https://huggingface.co/tasks/object-detection) | `object-detection` | Identify objects of certain defined classes within an image. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.ObjectDetectionPipeline)
[(models)](https://huggingface.co/models?pipeline_tag=object-detection&library=transformers.js) | | [Video Classification](https://huggingface.co/tasks/video-classification) | n/a | Assigning a label or class to an entire video. | ❌ | @@ -302,6 +302,7 @@ You can refine your search by selecting the task you're interested in (e.g., [te 1. **[SpeechT5](https://huggingface.co/docs/transformers/model_doc/speecht5)** (from Microsoft Research) released with the paper [SpeechT5: Unified-Modal Encoder-Decoder Pre-Training for Spoken Language Processing](https://arxiv.org/abs/2110.07205) by Junyi Ao, Rui Wang, Long Zhou, Chengyi Wang, Shuo Ren, Yu Wu, Shujie Liu, Tom Ko, Qing Li, Yu Zhang, Zhihua Wei, Yao Qian, Jinyu Li, Furu Wei. 1. **[SqueezeBERT](https://huggingface.co/docs/transformers/model_doc/squeezebert)** (from Berkeley) released with the paper [SqueezeBERT: What can computer vision teach NLP about efficient neural networks?](https://arxiv.org/abs/2006.11316) by Forrest N. Iandola, Albert E. Shaw, Ravi Krishna, and Kurt W. Keutzer. 1. **[Swin Transformer](https://huggingface.co/docs/transformers/model_doc/swin)** (from Microsoft) released with the paper [Swin Transformer: Hierarchical Vision Transformer using Shifted Windows](https://arxiv.org/abs/2103.14030) by Ze Liu, Yutong Lin, Yue Cao, Han Hu, Yixuan Wei, Zheng Zhang, Stephen Lin, Baining Guo. +1. **[Swin2SR](https://huggingface.co/docs/transformers/model_doc/swin2sr)** (from University of Würzburg) released with the paper [Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration](https://arxiv.org/abs/2209.11345) by Marcos V. Conde, Ui-Jin Choi, Maxime Burchi, Radu Timofte. 1. **[T5](https://huggingface.co/docs/transformers/model_doc/t5)** (from Google AI) released with the paper [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu. 1. **[T5v1.1](https://huggingface.co/docs/transformers/model_doc/t5v1.1)** (from Google AI) released in the repository [google-research/text-to-text-transfer-transformer](https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md#t511) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu. 1. **[TrOCR](https://huggingface.co/docs/transformers/model_doc/trocr)** (from Microsoft), released together with the paper [TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models](https://arxiv.org/abs/2109.10282) by Minghao Li, Tengchao Lv, Lei Cui, Yijuan Lu, Dinei Florencio, Cha Zhang, Zhoujun Li, Furu Wei. diff --git a/docs/snippets/5_supported-tasks.snippet b/docs/snippets/5_supported-tasks.snippet index 5a2fff4d1..dee075808 100644 --- a/docs/snippets/5_supported-tasks.snippet +++ b/docs/snippets/5_supported-tasks.snippet @@ -25,7 +25,7 @@ | [Depth Estimation](https://huggingface.co/tasks/depth-estimation) | `depth-estimation` | Predicting the depth of objects present in an image. | ❌ | | [Image Classification](https://huggingface.co/tasks/image-classification) | `image-classification` | Assigning a label or class to an entire image. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.ImageClassificationPipeline)
[(models)](https://huggingface.co/models?pipeline_tag=image-classification&library=transformers.js) | | [Image Segmentation](https://huggingface.co/tasks/image-segmentation) | `image-segmentation` | Divides an image into segments where each pixel is mapped to an object. This task has multiple variants such as instance segmentation, panoptic segmentation and semantic segmentation. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.ImageSegmentationPipeline)
[(models)](https://huggingface.co/models?pipeline_tag=image-segmentation&library=transformers.js) | -| [Image-to-Image](https://huggingface.co/tasks/image-to-image) | `image-to-image` | Transforming a source image to match the characteristics of a target image or a target image domain. | ❌ | +| [Image-to-Image](https://huggingface.co/tasks/image-to-image) | `image-to-image` | Transforming a source image to match the characteristics of a target image or a target image domain. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.ImageToImagePipeline)
[(models)](https://huggingface.co/models?pipeline_tag=image-to-image&library=transformers.js) | | [Mask Generation](https://huggingface.co/tasks/mask-generation) | `mask-generation` | Generate masks for the objects in an image. | ❌ | | [Object Detection](https://huggingface.co/tasks/object-detection) | `object-detection` | Identify objects of certain defined classes within an image. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.ObjectDetectionPipeline)
[(models)](https://huggingface.co/models?pipeline_tag=object-detection&library=transformers.js) | | [Video Classification](https://huggingface.co/tasks/video-classification) | n/a | Assigning a label or class to an entire video. | ❌ | diff --git a/docs/snippets/6_supported-models.snippet b/docs/snippets/6_supported-models.snippet index 3b7315410..42c12bd2a 100644 --- a/docs/snippets/6_supported-models.snippet +++ b/docs/snippets/6_supported-models.snippet @@ -47,6 +47,7 @@ 1. **[SpeechT5](https://huggingface.co/docs/transformers/model_doc/speecht5)** (from Microsoft Research) released with the paper [SpeechT5: Unified-Modal Encoder-Decoder Pre-Training for Spoken Language Processing](https://arxiv.org/abs/2110.07205) by Junyi Ao, Rui Wang, Long Zhou, Chengyi Wang, Shuo Ren, Yu Wu, Shujie Liu, Tom Ko, Qing Li, Yu Zhang, Zhihua Wei, Yao Qian, Jinyu Li, Furu Wei. 1. **[SqueezeBERT](https://huggingface.co/docs/transformers/model_doc/squeezebert)** (from Berkeley) released with the paper [SqueezeBERT: What can computer vision teach NLP about efficient neural networks?](https://arxiv.org/abs/2006.11316) by Forrest N. Iandola, Albert E. Shaw, Ravi Krishna, and Kurt W. Keutzer. 1. **[Swin Transformer](https://huggingface.co/docs/transformers/model_doc/swin)** (from Microsoft) released with the paper [Swin Transformer: Hierarchical Vision Transformer using Shifted Windows](https://arxiv.org/abs/2103.14030) by Ze Liu, Yutong Lin, Yue Cao, Han Hu, Yixuan Wei, Zheng Zhang, Stephen Lin, Baining Guo. +1. **[Swin2SR](https://huggingface.co/docs/transformers/model_doc/swin2sr)** (from University of Würzburg) released with the paper [Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration](https://arxiv.org/abs/2209.11345) by Marcos V. Conde, Ui-Jin Choi, Maxime Burchi, Radu Timofte. 1. **[T5](https://huggingface.co/docs/transformers/model_doc/t5)** (from Google AI) released with the paper [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu. 1. **[T5v1.1](https://huggingface.co/docs/transformers/model_doc/t5v1.1)** (from Google AI) released in the repository [google-research/text-to-text-transfer-transformer](https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md#t511) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu. 1. **[TrOCR](https://huggingface.co/docs/transformers/model_doc/trocr)** (from Microsoft), released together with the paper [TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models](https://arxiv.org/abs/2109.10282) by Minghao Li, Tengchao Lv, Lei Cui, Yijuan Lu, Dinei Florencio, Cha Zhang, Zhoujun Li, Furu Wei. 
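The processor changes further down in this patch pad each input image up to the next multiple of the model's window size and fill the new border by reflection. As a standalone illustration (plain JavaScript, not part of the diff — `padTo` is an ad-hoc name used only here), the two index formulas those hunks rely on behave as follows:

```javascript
// Round a dimension up to the next multiple of the pad/window size
// (a no-op when it is already a multiple), as done by Swin2SRImageProcessor.
const padTo = (size, padSize) => size + (padSize - size % padSize) % padSize;
console.log(padTo(3, 8)); // 8  -> a 3x3 input is padded to 8x8
console.log(padTo(8, 8)); // 8  -> already a multiple, no extra padding

// calculateReflectOffset maps an index that runs past the image back into
// [0, w], where w is the last valid index (height - 1 or width - 1),
// producing the reflective ('symmetric' mode) padding used for Swin2SR inputs.
const calculateReflectOffset = (i, w) => Math.abs((i + w) % (2 * w) - w);
console.log([0, 1, 2, 3, 4, 5].map(i => calculateReflectOffset(i, 2)));
// [ 0, 1, 2, 1, 0, 1 ]
```

This is why the unit tests below expect the 3×3 test pattern to leave the Swin2SR processor as an 8×8 image, and the x2 super-resolution pipeline to return a 16×16 result.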
diff --git a/scripts/supported_models.py b/scripts/supported_models.py index f59122ae9..79d0282c5 100644 --- a/scripts/supported_models.py +++ b/scripts/supported_models.py @@ -406,6 +406,17 @@ 'microsoft/swin-large-patch4-window7-224-in22k', 'microsoft/swin-large-patch4-window12-384', ], + 'swin2sr': [ + # Image-to-image (Super-resolution) + 'caidas/swin2SR-classical-sr-x2-64', + 'caidas/swin2SR-realworld-sr-x4-64-bsrgan-psnr', + 'caidas/swin2SR-classical-sr-x4-64', + 'caidas/swin2SR-compressed-sr-x4-48', + 'caidas/swin2SR-lightweight-x2-64', + + # Feature extraction + 'hf-tiny-model-private/tiny-random-Swin2SRModel', + ], 't5': [ # Text-to-text (Translation/Summarization) 't5-small', diff --git a/src/models.js b/src/models.js index b2e918103..b0a82cee0 100644 --- a/src/models.js +++ b/src/models.js @@ -3300,6 +3300,49 @@ export class SwinForImageClassification extends SwinPreTrainedModel { } ////////////////////////////////////////////////// +////////////////////////////////////////////////// +export class Swin2SRPreTrainedModel extends PreTrainedModel { } + +/** + * The bare Swin2SR Model transformer outputting raw hidden-states without any specific head on top. + */ +export class Swin2SRModel extends Swin2SRPreTrainedModel { } + +/** + * Swin2SR Model transformer with an upsampler head on top for image super resolution and restoration. + * + * **Example:** Super-resolution w/ `Xenova/swin2SR-classical-sr-x2-64`. + * + * ```javascript + * import { AutoProcessor, Swin2SRForImageSuperResolution, RawImage } from '@xenova/transformers'; + * + * // Load processor and model + * const model_id = 'Xenova/swin2SR-classical-sr-x2-64'; + * const processor = await AutoProcessor.from_pretrained(model_id); + * const model = await Swin2SRForImageSuperResolution.from_pretrained(model_id); + * + * // Prepare model inputs + * const url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/butterfly.jpg'; + * const image = await RawImage.fromURL(url); + * const inputs = await processor(image); + * + * // Run model + * const outputs = await model(inputs); + * + * // Convert Tensor to RawImage + * const output = outputs.reconstruction.squeeze().clamp_(0, 1).mul_(255).round_().to('uint8'); + * const outputImage = RawImage.fromTensor(output); + * // RawImage { + * // data: Uint8Array(786432) [ 41, 31, 24, ... 
], + * // width: 512, + * // height: 512, + * // channels: 3 + * // } + * ``` + */ +export class Swin2SRForImageSuperResolution extends Swin2SRPreTrainedModel { } +////////////////////////////////////////////////// + ////////////////////////////////////////////////// export class DonutSwinPreTrainedModel extends PreTrainedModel { } @@ -3950,6 +3993,7 @@ const MODEL_MAPPING_NAMES_ENCODER_ONLY = new Map([ ['deit', ['DeiTModel', DeiTModel]], ['resnet', ['ResNetModel', ResNetModel]], ['swin', ['SwinModel', SwinModel]], + ['swin2sr', ['Swin2SRModel', Swin2SRModel]], ['donut-swin', ['DonutSwinModel', DonutSwinModel]], ['yolos', ['YolosModel', YolosModel]], @@ -4124,6 +4168,10 @@ const MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = new Map([ ['wavlm', ['WavLMForSequenceClassification', WavLMForSequenceClassification]], ]); +const MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES = new Map([ + ['swin2sr', ['Swin2SRForImageSuperResolution', Swin2SRForImageSuperResolution]], +]) + const MODEL_CLASS_TYPE_MAPPING = [ [MODEL_MAPPING_NAMES_ENCODER_ONLY, MODEL_TYPES.EncoderOnly], @@ -4139,6 +4187,7 @@ const MODEL_CLASS_TYPE_MAPPING = [ [MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES, MODEL_TYPES.Vision2Seq], [MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly], [MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly], + [MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES, MODEL_TYPES.EncoderOnly], [MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly], [MODEL_FOR_MASK_GENERATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly], [MODEL_FOR_CTC_MAPPING_NAMES, MODEL_TYPES.EncoderOnly], @@ -4333,6 +4382,10 @@ export class AutoModelForDocumentQuestionAnswering extends PretrainedMixin { static MODEL_CLASS_MAPPINGS = [MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES]; } +export class AutoModelForImageToImage extends PretrainedMixin { + static MODEL_CLASS_MAPPINGS = [MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES]; +} + ////////////////////////////////////////////////// ////////////////////////////////////////////////// diff --git a/src/pipelines.js b/src/pipelines.js index 8842b6337..0c2e0b3a3 100644 --- a/src/pipelines.js +++ b/src/pipelines.js @@ -34,6 +34,7 @@ import { AutoModelForImageSegmentation, AutoModelForObjectDetection, AutoModelForDocumentQuestionAnswering, + AutoModelForImageToImage, // AutoModelForTextToWaveform, PreTrainedModel, } from './models.js'; @@ -1947,6 +1948,44 @@ export class TextToAudioPipeline extends Pipeline { } } +/** + * Image to Image pipeline using any `AutoModelForImageToImage`. This pipeline generates an image based on a previous image input. + * + * **Example:** Super-resolution w/ `Xenova/swin2SR-classical-sr-x2-64` + * ```javascript + * let url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/butterfly.jpg'; + * let upscaler = await pipeline('image-to-image', 'Xenova/swin2SR-classical-sr-x2-64'); + * let output = await upscaler(url); + * // RawImage { + * // data: Uint8Array(786432) [ 41, 31, 24, 43, ... ], + * // width: 512, + * // height: 512, + * // channels: 3 + * // } + * ``` + */ +export class ImageToImagePipeline extends Pipeline { + /** + * Transform the image(s) passed as inputs. + * @param {any} images The images to transform. + * @returns {Promise} An image or a list of images containing result(s). 
+ */ + async _call(images) { + images = await prepareImages(images); + + let inputs = await this.processor(images); + let outputs = await this.model(inputs); + + let toReturn = []; + for (let batch of outputs.reconstruction) { + const output = batch.squeeze().clamp_(0, 1).mul_(255).round_().to('uint8'); + toReturn.push(RawImage.fromTensor(output)); + } + + return toReturn.length > 1 ? toReturn : toReturn[0]; + } +} + const SUPPORTED_TASKS = { "text-classification": { "tokenizer": AutoTokenizer, @@ -2160,6 +2199,18 @@ const SUPPORTED_TASKS = { }, "type": "multimodal", }, + "image-to-image": { + // no tokenizer + "pipeline": ImageToImagePipeline, + "model": AutoModelForImageToImage, + "processor": AutoProcessor, + "default": { + // TODO: replace with original + // "model": "caidas/swin2SR-classical-sr-x2-64", + "model": "Xenova/swin2SR-classical-sr-x2-64", + }, + "type": "image", + }, // This task serves as a useful interface for dealing with sentence-transformers (https://huggingface.co/sentence-transformers). "feature-extraction": { diff --git a/src/processors.js b/src/processors.js index 0770261c1..7691cb413 100644 --- a/src/processors.js +++ b/src/processors.js @@ -22,6 +22,7 @@ import { Callable, calculateDimensions, + calculateReflectOffset, } from './utils/core.js'; import { @@ -229,6 +230,90 @@ export class ImageFeatureExtractor extends FeatureExtractor { return await image.resize(width, height, { resample }); } + + /** + * Pad the image by a certain amount. + * @param {Float32Array} pixelData The pixel data to pad. + * @param {number[]} imgDims The dimensions of the image. + * @param {{width:number; height:number}|number} padSize The dimensions of the padded image. + * @param {Object} options The options for padding. + * @param {'constant'|'symmetric'} [options.mode='constant'] The type of padding to add. + * @param {boolean} [options.center=false] Whether to center the image. + * @param {number} [options.constant_values=0] The constant value to use for padding. + * @returns {[Float32Array, number[]]} The padded pixel data and image dimensions. + */ + pad_image(pixelData, imgDims, padSize, { + mode = 'constant', + center = false, + constant_values = 0, + } = {}) { + const [imageWidth, imageHeight, imageChannels] = imgDims; + + let paddedImageWidth, paddedImageHeight; + if (typeof padSize === 'number') { + paddedImageWidth = padSize; + paddedImageHeight = padSize; + } else { + paddedImageWidth = padSize.width; + paddedImageHeight = padSize.height; + } + + // Only add padding if there is a difference in size + if (paddedImageWidth !== imageWidth || paddedImageHeight !== imageHeight) { + const paddedPixelData = new Float32Array(paddedImageWidth * paddedImageHeight * imageChannels); + if (constant_values !== 0) { + paddedPixelData.fill(constant_values); + } + + const [left, top] = center + ? 
[Math.floor((paddedImageWidth - imageWidth) / 2), Math.floor((paddedImageHeight - imageHeight) / 2)] + : [0, 0]; + + // Copy the original image into the padded image + for (let i = 0; i < imageHeight; ++i) { + const a = (i + top) * paddedImageWidth; + const b = i * imageWidth; + for (let j = 0; j < imageWidth; ++j) { + const c = (a + j + left) * imageChannels; + const d = (b + j) * imageChannels; + for (let k = 0; k < imageChannels; ++k) { + paddedPixelData[c + k] = pixelData[d + k]; + } + } + } + + if (mode === 'symmetric') { + if (center) { + throw new Error('`center` padding is not supported when `mode` is set to `symmetric`.'); + // TODO: Implement this + } + const h1 = imageHeight - 1; + const w1 = imageWidth - 1; + for (let i = 0; i < paddedImageHeight; ++i) { + const a = i * paddedImageWidth; + const b = calculateReflectOffset(i, h1) * imageWidth; + + for (let j = 0; j < paddedImageWidth; ++j) { + if (i < imageHeight && j < imageWidth) continue; // Do not overwrite original image + const c = (a + j) * imageChannels; + const d = (b + calculateReflectOffset(j, w1)) * imageChannels; + + // Copy channel-wise + for (let k = 0; k < imageChannels; ++k) { + paddedPixelData[c + k] = pixelData[d + k]; + } + } + } + } + + + // Update pixel data and image dimensions + pixelData = paddedPixelData; + imgDims = [paddedImageHeight, paddedImageWidth, imageChannels] + } + return [pixelData, imgDims]; + } + /** * @typedef {object} PreprocessedImage * @property {HeightWidth} original_size The original size of the image. @@ -339,17 +424,8 @@ export class ImageFeatureExtractor extends FeatureExtractor { /** @type {HeightWidth} */ let reshaped_input_size = [image.height, image.width]; - // TODO is it okay to pad before rescaling/normalizing? - if (this.do_pad && this.pad_size) { - let left = 0; - let right = this.pad_size.width - image.width; - let top = 0; - let bottom = this.pad_size.height - image.height; - - image = await image.pad([left, right, top, bottom]); - } - - const pixelData = Float32Array.from(image.data); + let pixelData = Float32Array.from(image.data); + let imgDims = [image.height, image.width, image.channels]; if (this.do_rescale) { for (let i = 0; i < pixelData.length; ++i) { @@ -379,10 +455,17 @@ export class ImageFeatureExtractor extends FeatureExtractor { } } + // do padding after rescaling/normalizing + if (this.do_pad && this.pad_size) { + const padded = this.pad_image(pixelData, [image.width, image.height, image.channels], this.pad_size); + [pixelData, imgDims] = padded; // Update pixel data and image dimensions + } + + // Create HWC tensor + const img = new Tensor('float32', pixelData, imgDims); + // convert to channel dimension format: - let imgDims = [image.height, image.width, image.channels]; - let img = new Tensor('float32', pixelData, imgDims); - let transposed = transpose(img, [2, 0, 1]); // hwc -> chw + const transposed = transpose(img, [2, 0, 1]); // hwc -> chw return { original_size: [srcHeight, srcWidth], @@ -431,7 +514,19 @@ export class ViTFeatureExtractor extends ImageFeatureExtractor { } export class MobileViTFeatureExtractor extends ImageFeatureExtractor { } export class DeiTFeatureExtractor extends ImageFeatureExtractor { } export class BeitFeatureExtractor extends ImageFeatureExtractor { } -export class DonutFeatureExtractor extends ImageFeatureExtractor { } +export class DonutFeatureExtractor extends ImageFeatureExtractor { + pad_image(pixelData, imgDims, padSize, options = {}) { + return super.pad_image(pixelData, imgDims, padSize, { + center: true, + + // 
Since normalization is done after padding, we need to pad with -1. + // NOTE: This only works if `image_mean = 0.5` and `image_std = 0.5`. + // For more information, see https://github.com/huggingface/transformers/blob/main/src/transformers/models/donut/image_processing_donut.py#L433-L451 + constant_values: -1, + ...options, + }); + } +} /** * @typedef {object} DetrFeatureExtractorResultProps @@ -908,6 +1003,26 @@ export class SamImageProcessor extends ImageFeatureExtractor { } } +export class Swin2SRImageProcessor extends ImageFeatureExtractor { + pad_image(pixelData, imgDims, padSize, options = {}) { + // NOTE: In this case, `padSize` represents the size of the sliding window for the local attention. + // In other words, the image is padded so that its width and height are multiples of `padSize`. + const [imageWidth, imageHeight, imageChannels] = imgDims; + + return super.pad_image(pixelData, imgDims, { + // NOTE: For Swin2SR models, the original python implementation adds padding even when the image's width/height is already + // a multiple of `pad_size`. However, this is most likely a bug (PR: https://github.com/mv-lab/swin2sr/pull/19). + // For this reason, we only add padding when the image's width/height is not a multiple of `pad_size`. + width: imageWidth + (padSize - imageWidth % padSize) % padSize, + height: imageHeight + (padSize - imageHeight % padSize) % padSize, + }, { + mode: 'symmetric', + center: false, + constant_values: -1, + ...options, + }) + } +} export class WhisperFeatureExtractor extends FeatureExtractor { @@ -917,15 +1032,7 @@ export class WhisperFeatureExtractor extends FeatureExtractor { // Prefer given `mel_filters` from preprocessor_config.json, or calculate them if they don't exist. this.config.mel_filters ??= getMelFilters(this.config.sampling_rate, this.config.n_fft, this.config.feature_size); } - /** - * Calculates the index offset for a given index and window size. - * @param {number} i The index. - * @param {number} w The window size. - * @returns {number} The index offset. - */ - calcOffset(i, w) { - return Math.abs((i + w) % (2 * w) - w); - } + /** * Pads an array with a reflected version of itself on both ends. @@ -943,11 +1050,11 @@ export class WhisperFeatureExtractor extends FeatureExtractor { } for (let i = 1; i <= left; ++i) { - padded[left - i] = array[this.calcOffset(i, w)]; + padded[left - i] = array[calculateReflectOffset(i, w)]; } for (let i = 1; i <= right; ++i) { - padded[w + left + i] = array[this.calcOffset(w - i, w)]; + padded[w + left + i] = array[calculateReflectOffset(w - i, w)]; } return padded; @@ -1439,6 +1546,7 @@ export class AutoProcessor { DonutFeatureExtractor, SamImageProcessor, + Swin2SRImageProcessor, Wav2Vec2FeatureExtractor, SpeechT5FeatureExtractor, } diff --git a/src/utils/core.js b/src/utils/core.js index 718ca4f0e..7de13625d 100644 --- a/src/utils/core.js +++ b/src/utils/core.js @@ -174,3 +174,13 @@ export function product(...a) { // Adapted from https://stackoverflow.com/a/43053803 return a.reduce((a, b) => a.flatMap(d => b.map(e => [d, e]))); } + +/** + * Calculates the index offset for a given index and window size. + * @param {number} i The index. + * @param {number} w The window size. + * @returns {number} The index offset. 
+ */ +export function calculateReflectOffset(i, w) { + return Math.abs((i + w) % (2 * w) - w); +} diff --git a/src/utils/image.js b/src/utils/image.js index 0016de8dc..cb9303572 100644 --- a/src/utils/image.js +++ b/src/utils/image.js @@ -64,17 +64,17 @@ const RESAMPLING_MAPPING = { 5: 'hamming', } -export class RawImage { +/** + * Mapping from file extensions to MIME types. + */ +const CONTENT_TYPE_MAP = new Map([ + ['png', 'image/png'], + ['jpg', 'image/jpeg'], + ['jpeg', 'image/jpeg'], + ['gif', 'image/gif'], +]); - /** - * Mapping from file extensions to MIME types. - */ - _CONTENT_TYPE_MAP = { - 'png': 'image/png', - 'jpg': 'image/jpeg', - 'jpeg': 'image/jpeg', - 'gif': 'image/gif', - } +export class RawImage { /** * Create a new `RawImage` object. @@ -156,6 +156,21 @@ export class RawImage { } } + /** + * Helper method to create a new Image from a tensor + * @param {import('./tensor.js').Tensor} tensor + */ + static fromTensor(tensor, channel_format = 'CHW') { + if (channel_format === 'CHW') { + tensor = tensor.transpose(1, 2, 0); + } else if (channel_format === 'HWC') { + // Do nothing + } else { + throw new Error(`Unsupported channel format: ${channel_format}`); + } + return new RawImage(tensor.data, tensor.dims[1], tensor.dims[0], tensor.dims[2]); + } + /** * Convert the image to grayscale format. * @returns {RawImage} `this` to support chaining. @@ -564,7 +579,7 @@ export class RawImage { if (BROWSER_ENV) { const extension = path.split('.').pop().toLowerCase(); - const mime = this._CONTENT_TYPE_MAP[extension] ?? 'image/png'; + const mime = CONTENT_TYPE_MAP.get(extension) ?? 'image/png'; // Convert image to canvas const canvas = this.toCanvas(); diff --git a/src/utils/tensor.js b/src/utils/tensor.js index f5c6dff83..b575b5136 100644 --- a/src/utils/tensor.js +++ b/src/utils/tensor.js @@ -15,6 +15,21 @@ import { } from './maths.js'; +// @ts-ignore +const DataTypeMap = new Map([ + ['bool', Uint8Array], + ['float32', Float32Array], + ['float64', Float64Array], + ['string', Array], // string[] + ['int8', Int8Array], + ['uint8', Uint8Array], + ['int16', Int16Array], + ['uint16', Uint16Array], + ['int32', Int32Array], + ['uint32', Uint32Array], + ['int64', BigInt64Array], +]) + /** * @typedef {import('./maths.js').AnyTypedArray | any[]} DataArray */ @@ -160,6 +175,48 @@ export class Tensor extends ONNXTensor { return this; } + /** + * Return a new Tensor with every element multiplied by a constant. + * @param {number} val The value to multiply by. + * @returns {Tensor} The new tensor. + */ + mul(val) { + return this.clone().mul_(val); + } + + /** + * Multiply the tensor by a constant in place. + * @param {number} val The value to multiply by. + * @returns {Tensor} Returns `this`. + */ + mul_(val) { + for (let i = 0; i < this.data.length; ++i) { + this.data[i] *= val; + } + return this; + } + + + /** + * Return a new Tensor with every element added by a constant. + * @param {number} val The value to add by. + * @returns {Tensor} The new tensor. + */ + add(val) { + return this.clone().add_(val); + } + + /** + * Add the tensor by a constant in place. + * @param {number} val The value to add by. + * @returns {Tensor} Returns `this`. 
+ */ + add_(val) { + for (let i = 0; i < this.data.length; ++i) { + this.data[i] += val; + } + return this; + } clone() { return new Tensor(this.type, this.data.slice(), this.dims.slice()); } @@ -482,6 +539,61 @@ export class Tensor extends ONNXTensor { neg() { return this.clone().neg_(); } + + /** + * In-place version of @see {@link Tensor.clamp} + */ + clamp_(min, max) { + for (let i = 0; i < this.data.length; ++i) { + this.data[i] = Math.min(Math.max(this.data[i], min), max); + } + return this; + } + + /** + * Clamps all elements in input into the range [ min, max ] + * @param {number} min lower-bound of the range to be clamped to + * @param {number} max upper-bound of the range to be clamped to + * @returns the output tensor. + */ + clamp(min, max) { + return this.clone().clamp_(min, max); + } + + /** + * In-place version of @see {@link Tensor.round} + */ + round_() { + for (let i = 0; i < this.data.length; ++i) { + this.data[i] = Math.round(this.data[i]); + } + return this; + } + + /** + * Rounds elements of input to the nearest integer. + * @returns the output tensor. + */ + round() { + return this.clone().round_(); + } + + /** + * Performs Tensor dtype conversion. + * @param {'bool'|'float32'|'float64'|'string'|'int8'|'uint8'|'int16'|'uint16'|'int32'|'uint32'|'int64'} type + * @returns {Tensor} The converted tensor. + */ + to(type) { + // If the self Tensor already has the correct dtype, then self is returned. + if (this.type === type) return this; + + // Otherwise, the returned tensor is a copy of self with the desired dtype. + const ArrayConstructor = DataTypeMap.get(type); + if (!ArrayConstructor) { + throw new Error(`Unsupported type: ${type}`); + } + return new Tensor(type, ArrayConstructor.from(this.data), this.dims); + } } /** diff --git a/tests/pipelines.test.js b/tests/pipelines.test.js index f42f1a861..d8f5e56c0 100644 --- a/tests/pipelines.test.js +++ b/tests/pipelines.test.js @@ -1326,6 +1326,44 @@ describe('Pipelines', () => { }, MAX_TEST_EXECUTION_TIME); }); + describe('Image-to-image', () => { + + // List all models which will be tested + const models = [ + 'caidas/swin2SR-classical-sr-x2-64', + ]; + + it(models[0], async () => { + let upscaler = await pipeline('image-to-image', m(models[0])); + + // Input is 3x3 => padded to 8x8 => upscaled to 16x16 + let url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/pattern_3x3.png'; + + // single + { + let outputs = await upscaler(url); + expect(outputs.width).toEqual(16); + expect(outputs.height).toEqual(16); + expect(outputs.channels).toEqual(3); + expect(outputs.data).toHaveLength(768); + } + + // batched + { + let outputs = await upscaler([url, url]); + expect(outputs).toHaveLength(2); + for (let output of outputs) { + expect(output.width).toEqual(16); + expect(output.height).toEqual(16); + expect(output.channels).toEqual(3); + expect(output.data).toHaveLength(768); + } + } + + await upscaler.dispose(); + }, MAX_TEST_EXECUTION_TIME); + }); + describe('Document question answering', () => { // List all models which will be tested diff --git a/tests/processors.test.js b/tests/processors.test.js new file mode 100644 index 000000000..fe594613e --- /dev/null +++ b/tests/processors.test.js @@ -0,0 +1,240 @@ + +import { env, AutoProcessor, RawImage } from '../src/transformers.js'; +import { m, MAX_TEST_EXECUTION_TIME } from './init.js'; +import { compare } from './test_utils.js'; + +// Initialise the testing environment +env.allowLocalModels = false; +env.useFSCache = false; + +const avg = (array) => { 
+ return Number(array.reduce((a, b) => a + b, array instanceof BigInt64Array ? 0n : 0)) / array.length; +} + +describe('Processors', () => { + + describe('Image processors', () => { + + const IMAGE_CACHE = new Map(); + const load_image = async (url) => { + const cached = IMAGE_CACHE.get(url); + if (cached) { + return cached; + } + const image = await RawImage.fromURL(url); + IMAGE_CACHE.set(url, image); + return image; + } + + const MODELS = { + swin2sr: 'caidas/swin2SR-classical-sr-x2-64', + sam: 'facebook/sam-vit-base', + 'donut-swin': 'naver-clova-ix/donut-base-finetuned-cord-v2', + resnet: 'microsoft/resnet-50', + vit: 'google/vit-base-patch16-224', + mobilevit: 'apple/mobilevit-small', + mobilevit_2: 'Xenova/quickdraw-mobilevit-small', + deit: 'facebook/deit-tiny-distilled-patch16-224', + beit: 'microsoft/beit-base-patch16-224-pt22k-ft22k', + detr: 'facebook/detr-resnet-50', + yolos: 'hustvl/yolos-small-300', + } + + const TEST_IMAGES = { + pattern_3x3: 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/pattern_3x3.png', + checkerboard_8x8: 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/checkerboard_8x8.png', + receipt: 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/receipt.png', + tiger: 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/tiger.jpg', + + // grayscale image + skateboard: 'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/ml-web-games/skateboard.png', + } + + // Swin2SRImageProcessor + // - tests when padding is a number (do_pad=true, pad_size=8) + it(MODELS.swin2sr, async () => { + const processor = await AutoProcessor.from_pretrained(m(MODELS.swin2sr)) + + { // Pad to multiple of 8 (3x3 -> 8x8) + const image = await load_image(TEST_IMAGES.pattern_3x3); + const { pixel_values } = await processor(image); + + compare(pixel_values.dims, [1, 3, 8, 8]); + compare(avg(pixel_values.data), 0.5458333368102709); + } + + { // Do not pad if already a multiple of 8 (8x8 -> 8x8) + const image = await load_image(TEST_IMAGES.checkerboard_8x8); + const { pixel_values } = await processor(image); + compare(pixel_values.dims, [1, 3, 8, 8]); + compare(avg(pixel_values.data), 0.5); + } + }, MAX_TEST_EXECUTION_TIME); + + // SamProcessor/SamImageProcessor + // - tests normal padding (do_pad=true, pad_size={"height":1024,"width":1024}) + // - In addition to the image, pass in a list of points + it(MODELS.sam, async () => { + const processor = await AutoProcessor.from_pretrained(m(MODELS.sam)) + + { // Basic test + const image = await load_image(TEST_IMAGES.pattern_3x3); + const { pixel_values } = await processor(image, [[[0, 0]]]); + compare(pixel_values.dims, [1, 3, 1024, 1024]); + compare(avg(pixel_values.data), -0.4505715670146813); + } + }, MAX_TEST_EXECUTION_TIME); + + // DonutProcessor/DonutFeatureExtractor + // - tests thumbnail resizing (do_thumbnail=true, size=[960, 1280]) + it(MODELS['donut-swin'], async () => { + const processor = await AutoProcessor.from_pretrained(m(MODELS['donut-swin'])) + + { + const image = await load_image(TEST_IMAGES.receipt); + const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image); + + compare(pixel_values.dims, [1, 3, 1280, 960]); + compare(avg(pixel_values.data), 0.1229388610053704); + + compare(original_sizes, [[864, 576]]); + compare(reshaped_input_sizes, [[1280, 853]]); + } + }, MAX_TEST_EXECUTION_TIME); + + // ConvNextFeatureExtractor + it(MODELS.resnet, async () => { + const processor = 
await AutoProcessor.from_pretrained(m(MODELS.resnet)) + + { + const image = await load_image(TEST_IMAGES.tiger); + const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image); + + compare(pixel_values.dims, [1, 3, 224, 336]); + compare(avg(pixel_values.data), -0.27736667280600913); + + compare(original_sizes, [[408, 612]]); + compare(reshaped_input_sizes, [[224, 336]]); + } + }, MAX_TEST_EXECUTION_TIME); + + // ViTFeatureExtractor + it(MODELS.vit, async () => { + const processor = await AutoProcessor.from_pretrained(m(MODELS.vit)) + + { + const image = await load_image(TEST_IMAGES.tiger); + const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image); + + compare(pixel_values.dims, [1, 3, 224, 224]); + compare(avg(pixel_values.data), -0.22706867939852762); + + compare(original_sizes, [[408, 612]]); + compare(reshaped_input_sizes, [[224, 224]]); + } + }, MAX_TEST_EXECUTION_TIME); + + // MobileViTFeatureExtractor + it(MODELS.mobilevit, async () => { + const processor = await AutoProcessor.from_pretrained(m(MODELS.mobilevit)) + + { + const image = await load_image(TEST_IMAGES.tiger); + const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image); + + compare(pixel_values.dims, [1, 3, 256, 256]); + compare(avg(pixel_values.data), 0.4599160496887033); + + compare(original_sizes, [[408, 612]]); + compare(reshaped_input_sizes, [[256, 256]]); + } + }, MAX_TEST_EXECUTION_TIME); + + // MobileViTFeatureExtractor + // - tests not converting to rgb (do_convert_rgb=false) + it(MODELS.mobilevit_2, async () => { + const processor = await AutoProcessor.from_pretrained(m(MODELS.mobilevit_2)) + + { // Tests grayscale image + const image = await load_image(TEST_IMAGES.skateboard); + const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image); + + compare(pixel_values.dims, [1, 1, 28, 28]); + compare(avg(pixel_values.data), 0.08558923671585128); + + compare(original_sizes, [[28, 28]]); + compare(reshaped_input_sizes, [[28, 28]]); + } + }, MAX_TEST_EXECUTION_TIME); + + // DeiTFeatureExtractor + it(MODELS.deit, async () => { + const processor = await AutoProcessor.from_pretrained(m(MODELS.deit)) + + { // Tests grayscale image + const image = await load_image(TEST_IMAGES.tiger); + const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image); + + compare(pixel_values.dims, [1, 3, 224, 224]); + compare(avg(pixel_values.data), -0.2760336682859463); + + compare(original_sizes, [[408, 612]]); + compare(reshaped_input_sizes, [[224, 224]]); + } + }, MAX_TEST_EXECUTION_TIME); + + // BeitFeatureExtractor + it(MODELS.beit, async () => { + const processor = await AutoProcessor.from_pretrained(m(MODELS.beit)) + + { // Tests grayscale image + const image = await load_image(TEST_IMAGES.tiger); + const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image); + + compare(pixel_values.dims, [1, 3, 224, 224]); + compare(avg(pixel_values.data), -0.22706867939852762); + + compare(original_sizes, [[408, 612]]); + compare(reshaped_input_sizes, [[224, 224]]); + } + }, MAX_TEST_EXECUTION_TIME); + + + // DetrFeatureExtractor + it(MODELS.detr, async () => { + const processor = await AutoProcessor.from_pretrained(m(MODELS.detr)) + + { // Tests grayscale image + const image = await load_image(TEST_IMAGES.tiger); + const { pixel_values, original_sizes, reshaped_input_sizes, pixel_mask } = await processor(image); + + compare(pixel_values.dims, [1, 3, 888, 1333]); + 
compare(avg(pixel_values.data), -0.27840224131001773); + + compare(original_sizes, [[408, 612]]); + compare(reshaped_input_sizes, [[888, 1333]]); + + compare(pixel_mask.dims, [1, 64, 64]); + compare(avg(pixel_mask.data), 1); + + } + }, MAX_TEST_EXECUTION_TIME); + + + // YolosFeatureExtractor + it(MODELS.yolos, async () => { + const processor = await AutoProcessor.from_pretrained(m(MODELS.yolos)) + + { // Tests grayscale image + const image = await load_image(TEST_IMAGES.tiger); + const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image); + + compare(pixel_values.dims, [1, 3, 888, 1333]); + compare(avg(pixel_values.data), -0.27840224131001773); + + compare(original_sizes, [[408, 612]]); + compare(reshaped_input_sizes, [[888, 1333]]); + } + }, MAX_TEST_EXECUTION_TIME); + }); +});
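Beyond the pipeline itself, this patch adds several element-wise `Tensor` helpers (`mul`/`mul_`, `add`/`add_`, `clamp`/`clamp_`, `round`/`round_`, `to`) and `RawImage.fromTensor`. A minimal sketch of how they compose for post-processing, assuming `Tensor` is re-exported from the package entry point alongside `RawImage`:

```javascript
// NOTE: the `Tensor` export is an assumption; `RawImage` is imported the same
// way in the doc examples added above.
import { Tensor, RawImage } from '@xenova/transformers';

// A tiny 1x1 RGB "image" stored channel-first (CHW), with float values roughly in [0, 1].
const chw = new Tensor('float32', new Float32Array([0.2, 0.5, 1.2]), [3, 1, 1]);

// Scale to [0, 255], clamp out-of-range values, round to integers, and convert the dtype,
// mirroring the chain ImageToImagePipeline applies to the model's `reconstruction` output.
const u8 = chw.mul(255).clamp_(0, 255).round_().to('uint8');

// Wrap the CHW tensor as a RawImage (width 1, height 1, 3 channels).
const image = RawImage.fromTensor(u8);
console.log(image.data); // Uint8Array(3) [ 51, 128, 255 ]
```

The same chain appears in `ImageToImagePipeline._call` and in the `Swin2SRForImageSuperResolution` doc example, just with the clamp applied before the scaling (`clamp_(0, 1).mul_(255)`), which is equivalent.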