Skip to content

Commit

Permalink
Create RawImage.toTensor helper function
Browse files Browse the repository at this point in the history
  • Loading branch information
xenova committed Mar 19, 2024
1 parent 63e13ed commit 85f1b1f
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 6 deletions.
6 changes: 1 addition & 5 deletions src/models.js
Original file line number Diff line number Diff line change
Expand Up @@ -3779,11 +3779,7 @@ export class VitMattePreTrainedModel extends PreTrainedModel { }
* import { Tensor, cat } from '@xenova/transformers';
*
* // Visualize predicted alpha matte
* const imageTensor = new Tensor(
* 'uint8',
* new Uint8Array(image.data),
* [image.height, image.width, image.channels]
* ).transpose(2, 0, 1);
* const imageTensor = image.toTensor();
*
* // Convert float (0-1) alpha matte to uint8 (0-255)
* const alphaChannel = alphas
Expand Down
20 changes: 19 additions & 1 deletion src/utils/image.js
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import { getFile } from './hub.js';
import { env } from '../env.js';
import { Tensor } from './tensor.js';

// Will be empty (or not used) if running in browser or web-worker
import sharp from 'sharp';
Expand Down Expand Up @@ -166,7 +167,7 @@ export class RawImage {

/**
* Helper method to create a new Image from a tensor
* @param {import('./tensor.js').Tensor} tensor
* @param {Tensor} tensor
*/
static fromTensor(tensor, channel_format = 'CHW') {
if (tensor.dims.length !== 3) {
Expand Down Expand Up @@ -586,6 +587,23 @@ export class RawImage {
return await canvas.convertToBlob({ type, quality });
}

toTensor(channel_format = 'CHW') {
let tensor = new Tensor(
'uint8',
new Uint8Array(this.data),
[this.height, this.width, this.channels]
);

if (channel_format === 'HWC') {
// Do nothing
} else if (channel_format === 'CHW') { // hwc -> chw
tensor = tensor.permute(2, 0, 1);
} else {
throw new Error(`Unsupported channel format: ${channel_format}`);
}
return tensor;
}

toCanvas() {
if (!BROWSER_ENV) {
throw new Error('toCanvas() is only supported in browser environments.')
Expand Down

0 comments on commit 85f1b1f

Please sign in to comment.