Skip to content

Commit

Permalink
Add bankers rounding test case
Browse files Browse the repository at this point in the history
  • Loading branch information
xenova committed Mar 20, 2024
1 parent 85f1b1f commit 1b7c458
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 8 deletions.
16 changes: 9 additions & 7 deletions src/processors.js
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import {
min,
max,
softmax,
bankers_round,
} from './utils/maths.js';


Expand Down Expand Up @@ -174,14 +175,15 @@ function validate_audio_inputs(audio, feature_extractor) {
* @private
*/
function constraint_to_multiple_of(val, multiple, minVal = 0, maxVal = null) {
let x = Math.round(val / multiple) * multiple;
const a = val / multiple;
let x = bankers_round(a) * multiple;

if (maxVal !== null && x > maxVal) {
x = Math.floor(val / multiple) * multiple;
x = Math.floor(a) * multiple;
}

if (x < minVal) {
x = Math.ceil(val / multiple) * multiple;
x = Math.ceil(a) * multiple;
}

return x;
Expand Down Expand Up @@ -513,8 +515,8 @@ export class ImageFeatureExtractor extends FeatureExtractor {
if (this.config.keep_aspect_ratio && this.config.ensure_multiple_of) {

// determine new height and width
let scale_height = size.height / srcHeight;
let scale_width = size.width / srcWidth;
let scale_height = newHeight / srcHeight;
let scale_width = newWidth / srcWidth;

// scale as little as possible
if (Math.abs(1 - scale_width) < Math.abs(1 - scale_height)) {
Expand Down Expand Up @@ -765,9 +767,9 @@ export class SegformerFeatureExtractor extends ImageFeatureExtractor {
return toReturn;
}
}
export class DPTImageProcessor extends ImageFeatureExtractor { }
export class BitImageProcessor extends ImageFeatureExtractor { }
// DPTFeatureExtractor declares no members of its own; everything is inherited from ImageFeatureExtractor.
export class DPTFeatureExtractor extends ImageFeatureExtractor { }
// DPTImageProcessor is an empty subclass of DPTFeatureExtractor, so both names share one implementation.
export class DPTImageProcessor extends DPTFeatureExtractor { } // NOTE: extends DPTFeatureExtractor
// The declarations below are likewise empty subclasses of ImageFeatureExtractor.
export class BitImageProcessor extends ImageFeatureExtractor { }
export class GLPNFeatureExtractor extends ImageFeatureExtractor { }
export class CLIPFeatureExtractor extends ImageFeatureExtractor { }
export class ChineseCLIPFeatureExtractor extends ImageFeatureExtractor { }
Expand Down
14 changes: 14 additions & 0 deletions src/utils/maths.js
Original file line number Diff line number Diff line change
Expand Up @@ -952,3 +952,17 @@ export function round(num, decimals) {
const pow = Math.pow(10, decimals);
return Math.round(num * pow) / pow;
}

/**
 * Round a number to the nearest integer, resolving exact ties (a fractional part of
 * exactly 0.5) towards the nearest even integer. Also known as "bankers' rounding".
 * This is the default rounding mode in Python. For example:
 * 1.5 rounds to 2, 2.5 rounds to 2, -1.5 rounds to -2 and -0.5 rounds to 0.
 *
 * Results that are zero are always returned as positive zero (never `-0`), so that
 * exact comparisons such as `Object.is(bankers_round(-0.5), 0)` hold.
 * `NaN` and `±Infinity` are returned unchanged.
 *
 * @param {number} x The number to round
 * @returns {number} The rounded number
 */
export function bankers_round(x) {
    const r = Math.round(x);

    // `Math.round` resolves ties towards +Infinity, so on a tie an odd result is
    // always one above the even neighbour; step down by one in that case.
    const is_tie = Math.abs(x) % 1 === 0.5;
    const br = is_tie && r % 2 !== 0 ? r - 1 : r;

    // `Math.round` yields -0 for inputs in [-0.5, 0); normalise it to +0.
    return br === 0 ? 0 : br;
}
15 changes: 14 additions & 1 deletion tests/maths.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import { compare } from './test_utils.js';

import { getFile } from '../src/utils/hub.js';
import { FFT, medianFilter } from '../src/utils/maths.js';
import { FFT, medianFilter, bankers_round } from '../src/utils/maths.js';


const fft = (arr, complex = false) => {
Expand All @@ -27,6 +27,19 @@ const fftTestsData = await (await getFile('./tests/data/fft_tests.json')).json()

describe('Mathematical operations', () => {

// Checks bankers_round on exact .5 ties only; every expected value is an even integer.
describe('bankers rounding', () => {
// Ties whose even neighbour is greater than the input (-0.5 -> 0, 1.5 -> 2, 19.5 -> 20).
it('should round up to nearest even', () => {
// NOTE(review): toBe compares with Object.is (presumably jest), so this requires +0, not -0 — confirm.
expect(bankers_round(-0.5)).toBe(0);
expect(bankers_round(1.5)).toBe(2);
expect(bankers_round(19.5)).toBe(20);
});
// Ties whose even neighbour is less than the input (-1.5 -> -2, 2.5 -> 2, 18.5 -> 18).
it('should round down to nearest even', () => {
expect(bankers_round(-1.5)).toBe(-2);
expect(bankers_round(2.5)).toBe(2);
expect(bankers_round(18.5)).toBe(18);
});
});

describe('median filtering', () => {


Expand Down
14 changes: 14 additions & 0 deletions tests/processors.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ describe('Processors', () => {
pattern_3x3: 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/pattern_3x3.png',
pattern_3x5: 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/pattern_3x5.png',
checkerboard_8x8: 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/checkerboard_8x8.png',
checkerboard_64x32: 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/checkerboard_64x32.png',
receipt: 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/receipt.png',
tiger: 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/tiger.jpg',
paper: 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/nougat_paper.png',
Expand Down Expand Up @@ -433,6 +434,7 @@ describe('Processors', () => {
// DPTImageProcessor
// - tests ensure_multiple_of
// - tests keep_aspect_ratio
// - tests bankers rounding
it(MODELS.dpt_2, async () => {
const processor = await AutoProcessor.from_pretrained(m(MODELS.dpt_2))

Expand All @@ -446,6 +448,18 @@ describe('Processors', () => {
compare(original_sizes, [[480, 640]]);
compare(reshaped_input_sizes, [[518, 686]]);
}

{
const image = await load_image(TEST_IMAGES.checkerboard_64x32);
const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image);

// NOTE: without bankers rounding, this would be [1, 3, 266, 518]
compare(pixel_values.dims, [1, 3, 252, 518]);
compare(avg(pixel_values.data), 0.2267402559518814);

compare(original_sizes, [[32, 64]]);
compare(reshaped_input_sizes, [[252, 518]]);
}
}, MAX_TEST_EXECUTION_TIME);

// EfficientNetImageProcessor
Expand Down

0 comments on commit 1b7c458

Please sign in to comment.