/**
 * Check whether a width/height argument should be treated as "not provided".
 * @param {*} x The dimension value to inspect.
 * @returns {boolean} `true` when `x` is `null`, `undefined` or the sentinel `-1`; `false` for anything else.
 */
export function isNullishDimension(x) {
    // Fold `null`/`undefined` into the `-1` sentinel so that a single strict
    // comparison covers all three "missing dimension" spellings.
    const normalized = x ?? -1;
    return normalized === -1;
}
- * - * These functions and classes are only used internally, + * @file Helper module for image processing. + * + * These functions and classes are only used internally, * meaning an end-user shouldn't need to access anything here. - * + * * @module utils/image */ +import { isNullishDimension } from './core.js'; import { getFile } from './hub.js'; import { env } from '../env.js'; import { Tensor } from './tensor.js'; @@ -91,7 +92,7 @@ export class RawImage { this.channels = channels; } - /** + /** * Returns the size of the image (width, height). * @returns {[number, number]} The size of the image (width, height). */ @@ -101,9 +102,9 @@ export class RawImage { /** * Helper method for reading an image from a variety of input types. - * @param {RawImage|string|URL} input + * @param {RawImage|string|URL} input * @returns The image object. - * + * * **Example:** Read image from a URL. * ```javascript * let image = await RawImage.read('https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/football-match.jpg'); @@ -181,7 +182,7 @@ export class RawImage { /** * Helper method to create a new Image from a tensor - * @param {Tensor} tensor + * @param {Tensor} tensor */ static fromTensor(tensor, channel_format = 'CHW') { if (tensor.dims.length !== 3) { @@ -306,8 +307,8 @@ export class RawImage { /** * Resize the image to the given dimensions. This method uses the canvas API to perform the resizing. - * @param {number} width The width of the new image. - * @param {number} height The height of the new image. + * @param {number} width The width of the new image. `null` or `-1` will preserve the aspect ratio. + * @param {number} height The height of the new image. `null` or `-1` will preserve the aspect ratio. * @param {Object} options Additional options for resizing. * @param {0|1|2|3|4|5|string} [options.resample] The resampling method to use. * @returns {Promise} `this` to support chaining. 
@@ -324,6 +325,20 @@ export class RawImage { // Ensure resample method is a string let resampleMethod = RESAMPLING_MAPPING[resample] ?? resample; + // Calculate width / height to maintain aspect ratio, in the event that + // the user passed a null value in. + // This allows users to pass in something like `resize(320, null)` to + // resize to 320 width, but maintain aspect ratio. + const nullish_width = isNullishDimension(width); + const nullish_height = isNullishDimension(height); + if (nullish_width && nullish_height) { + return this; + } else if (nullish_width) { + width = (height / this.height) * this.width; + } else if (nullish_height) { + height = (width / this.width) * this.height; + } + if (BROWSER_ENV) { // TODO use `resample` in browser environment @@ -360,7 +375,7 @@ export class RawImage { case 'nearest': case 'bilinear': case 'bicubic': - // Perform resizing using affine transform. + // Perform resizing using affine transform. // This matches how the python Pillow library does it. img = img.affine([width / this.width, 0, 0, height / this.height], { interpolator: resampleMethod @@ -373,7 +388,7 @@ export class RawImage { img = img.resize({ width, height, fit: 'fill', - kernel: 'lanczos3', // PIL Lanczos uses a kernel size of 3 + kernel: 'lanczos3', // PIL Lanczos uses a kernel size of 3 }); break; @@ -452,7 +467,7 @@ export class RawImage { // Create canvas object for this image const canvas = this.toCanvas(); - // Create a new canvas of the desired size. This is needed since if the + // Create a new canvas of the desired size. This is needed since if the // image is too small, we need to pad it with black pixels. const ctx = createCanvasFunction(crop_width, crop_height).getContext('2d'); @@ -500,7 +515,7 @@ export class RawImage { // Create canvas object for this image const canvas = this.toCanvas(); - // Create a new canvas of the desired size. This is needed since if the + // Create a new canvas of the desired size. 
// Tests for `RawImage` loading and resizing, including the aspect-ratio
// preserving behaviour triggered by a nullish (`null`, `undefined` or `-1`)
// width or height.
describe("Image utilities", () => {
  // Shared 300x200 fixture, downloaded once for the whole suite.
  let image;
  beforeAll(async () => {
    image = await RawImage.fromURL("https://picsum.photos/300/200");
  });

  it("Read image from URL", async () => {
    expect(image.width).toBe(300);
    expect(image.height).toBe(200);
    expect(image.channels).toBe(3);
  });

  it("Can resize image", async () => {
    const resized = await image.resize(150, 100);
    expect(resized.width).toBe(150);
    expect(resized.height).toBe(100);
  });

  it("Can resize with aspect ratio", async () => {
    // Nullish height: derived from the requested width (150 / 300 * 200 = 100).
    const resized = await image.resize(150, null);
    expect(resized.width).toBe(150);
    expect(resized.height).toBe(100);
  });

  it("Can resize with aspect ratio when width is nullish", async () => {
    // Nullish width: derived from the requested height (100 / 200 * 300 = 150).
    const resized = await image.resize(null, 100);
    expect(resized.width).toBe(150);
    expect(resized.height).toBe(100);
  });

  it("Treats -1 and undefined like null", async () => {
    const from_minus_one_height = await image.resize(150, -1);
    expect(from_minus_one_height.width).toBe(150);
    expect(from_minus_one_height.height).toBe(100);

    const from_minus_one_width = await image.resize(-1, 100);
    expect(from_minus_one_width.width).toBe(150);
    expect(from_minus_one_width.height).toBe(100);

    const from_undefined_height = await image.resize(150, undefined);
    expect(from_undefined_height.width).toBe(150);
    expect(from_undefined_height.height).toBe(100);
  });

  it("Returns original image if width and height are null", async () => {
    const resized = await image.resize(null, null);
    // When both dimensions are nullish, `resize` returns `this` untouched.
    expect(resized).toBe(image);
    expect(resized.width).toBe(300);
    expect(resized.height).toBe(200);
  });

  it("Returns original image if width and height are -1", async () => {
    const resized = await image.resize(-1, -1);
    expect(resized).toBe(image);
    expect(resized.width).toBe(300);
    expect(resized.height).toBe(200);
  });
});