diff --git a/src/processors.js b/src/processors.js index 6191daa40..4713a6ae2 100644 --- a/src/processors.js +++ b/src/processors.js @@ -33,6 +33,7 @@ import { min, max, softmax, + bankers_round, } from './utils/maths.js'; @@ -174,14 +175,15 @@ function validate_audio_inputs(audio, feature_extractor) { * @private */ function constraint_to_multiple_of(val, multiple, minVal = 0, maxVal = null) { - let x = Math.round(val / multiple) * multiple; + const a = val / multiple; + let x = bankers_round(a) * multiple; if (maxVal !== null && x > maxVal) { - x = Math.floor(val / multiple) * multiple; + x = Math.floor(a) * multiple; } if (x < minVal) { - x = Math.ceil(val / multiple) * multiple; + x = Math.ceil(a) * multiple; } return x; @@ -513,8 +515,8 @@ export class ImageFeatureExtractor extends FeatureExtractor { if (this.config.keep_aspect_ratio && this.config.ensure_multiple_of) { // determine new height and width - let scale_height = size.height / srcHeight; - let scale_width = size.width / srcWidth; + let scale_height = newHeight / srcHeight; + let scale_width = newWidth / srcWidth; // scale as little as possible if (Math.abs(1 - scale_width) < Math.abs(1 - scale_height)) { @@ -765,9 +767,9 @@ export class SegformerFeatureExtractor extends ImageFeatureExtractor { return toReturn; } } -export class DPTImageProcessor extends ImageFeatureExtractor { } -export class BitImageProcessor extends ImageFeatureExtractor { } export class DPTFeatureExtractor extends ImageFeatureExtractor { } +export class DPTImageProcessor extends DPTFeatureExtractor { } // NOTE: extends DPTFeatureExtractor +export class BitImageProcessor extends ImageFeatureExtractor { } export class GLPNFeatureExtractor extends ImageFeatureExtractor { } export class CLIPFeatureExtractor extends ImageFeatureExtractor { } export class ChineseCLIPFeatureExtractor extends ImageFeatureExtractor { } diff --git a/src/utils/maths.js b/src/utils/maths.js index e33392973..264b69fc7 100644 --- a/src/utils/maths.js +++ b/src/utils/maths.js @@ -952,3 +952,17 @@ export function round(num, decimals) { const pow = Math.pow(10, decimals); return Math.round(num * pow) / pow; } + +/** + * Helper function to round a number to the nearest integer, with ties rounded to the nearest even number. + * Also known as "bankers' rounding". This is the default rounding mode in python. For example: + * 1.5 rounds to 2 and 2.5 rounds to 2. + * + * @param {number} x The number to round + * @returns {number} The rounded number + */ +export function bankers_round(x) { + const r = Math.round(x); + const br = Math.abs(x) % 1 === 0.5 ? (r % 2 === 0 ? r : r - 1) : r; + return br; +} diff --git a/tests/maths.test.js b/tests/maths.test.js index 9a7d3dc3c..480af5de3 100644 --- a/tests/maths.test.js +++ b/tests/maths.test.js @@ -2,7 +2,7 @@ import { compare } from './test_utils.js'; import { getFile } from '../src/utils/hub.js'; -import { FFT, medianFilter } from '../src/utils/maths.js'; +import { FFT, medianFilter, bankers_round } from '../src/utils/maths.js'; const fft = (arr, complex = false) => { @@ -27,6 +27,19 @@ const fftTestsData = await (await getFile('./tests/data/fft_tests.json')).json() describe('Mathematical operations', () => { + describe('bankers rounding', () => { + it('should round up to nearest even', () => { + expect(bankers_round(-0.5)).toBe(0); + expect(bankers_round(1.5)).toBe(2); + expect(bankers_round(19.5)).toBe(20); + }); + it('should round down to nearest even', () => { + expect(bankers_round(-1.5)).toBe(-2); + expect(bankers_round(2.5)).toBe(2); + expect(bankers_round(18.5)).toBe(18); + }); + }); + describe('median filtering', () => { diff --git a/tests/processors.test.js b/tests/processors.test.js index e972daf78..c9ab33982 100644 --- a/tests/processors.test.js +++ b/tests/processors.test.js @@ -52,6 +52,7 @@ describe('Processors', () => { pattern_3x3: 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/pattern_3x3.png', pattern_3x5: 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/pattern_3x5.png', checkerboard_8x8: 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/checkerboard_8x8.png', + checkerboard_64x32: 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/checkerboard_64x32.png', receipt: 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/receipt.png', tiger: 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/tiger.jpg', paper: 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/nougat_paper.png', @@ -433,6 +434,7 @@ describe('Processors', () => { // DPTImageProcessor // - tests ensure_multiple_of // - tests keep_aspect_ratio + // - tests bankers rounding it(MODELS.dpt_2, async () => { const processor = await AutoProcessor.from_pretrained(m(MODELS.dpt_2)) @@ -446,6 +448,18 @@ describe('Processors', () => { compare(original_sizes, [[480, 640]]); compare(reshaped_input_sizes, [[518, 686]]); } + + { + const image = await load_image(TEST_IMAGES.checkerboard_64x32); + const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image); + + // NOTE: without bankers rounding, this would be [1, 3, 266, 518] + compare(pixel_values.dims, [1, 3, 252, 518]); + compare(avg(pixel_values.data), 0.2267402559518814); + + compare(original_sizes, [[32, 64]]); + compare(reshaped_input_sizes, [[252, 518]]); + } }, MAX_TEST_EXECUTION_TIME); // EfficientNetImageProcessor