diff --git a/src/models.js b/src/models.js index 1a9b021f1..9c65ce5d8 100644 --- a/src/models.js +++ b/src/models.js @@ -3779,11 +3779,7 @@ export class VitMattePreTrainedModel extends PreTrainedModel { } * import { Tensor, cat } from '@xenova/transformers'; * * // Visualize predicted alpha matte - * const imageTensor = new Tensor( - * 'uint8', - * new Uint8Array(image.data), - * [image.height, image.width, image.channels] - * ).transpose(2, 0, 1); + * const imageTensor = image.toTensor(); * * // Convert float (0-1) alpha matte to uint8 (0-255) * const alphaChannel = alphas diff --git a/src/utils/image.js b/src/utils/image.js index 2d12cb876..1ee77d900 100644 --- a/src/utils/image.js +++ b/src/utils/image.js @@ -10,6 +10,7 @@ import { getFile } from './hub.js'; import { env } from '../env.js'; +import { Tensor } from './tensor.js'; // Will be empty (or not used) if running in browser or web-worker import sharp from 'sharp'; @@ -166,7 +167,7 @@ export class RawImage { /** * Helper method to create a new Image from a tensor - * @param {import('./tensor.js').Tensor} tensor + * @param {Tensor} tensor */ static fromTensor(tensor, channel_format = 'CHW') { if (tensor.dims.length !== 3) { @@ -586,6 +587,23 @@ export class RawImage { return await canvas.convertToBlob({ type, quality }); } + toTensor(channel_format = 'CHW') { + let tensor = new Tensor( + 'uint8', + new Uint8Array(this.data), + [this.height, this.width, this.channels] + ); + + if (channel_format === 'HWC') { + // Do nothing + } else if (channel_format === 'CHW') { // hwc -> chw + tensor = tensor.permute(2, 0, 1); + } else { + throw new Error(`Unsupported channel format: ${channel_format}`); + } + return tensor; + } + toCanvas() { if (!BROWSER_ENV) { throw new Error('toCanvas() is only supported in browser environments.')