Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor per-model unit testing #1083

Merged
merged 14 commits into from
Dec 11, 2024
80 changes: 50 additions & 30 deletions src/models.js
Original file line number Diff line number Diff line change
Expand Up @@ -3666,9 +3666,11 @@ export class CLIPModel extends CLIPPreTrainedModel { }
export class CLIPTextModel extends CLIPPreTrainedModel {
/** @type {typeof PreTrainedModel.from_pretrained} */
static async from_pretrained(pretrained_model_name_or_path, options = {}) {
// Update default model file name if not provided
options.model_file_name ??= 'text_model';
return super.from_pretrained(pretrained_model_name_or_path, options);
return super.from_pretrained(pretrained_model_name_or_path, {
// Update default model file name if not provided
model_file_name: 'text_model',
...options,
});
}
}

Expand Down Expand Up @@ -3701,9 +3703,11 @@ export class CLIPTextModel extends CLIPPreTrainedModel {
export class CLIPTextModelWithProjection extends CLIPPreTrainedModel {
/** @type {typeof PreTrainedModel.from_pretrained} */
static async from_pretrained(pretrained_model_name_or_path, options = {}) {
// Update default model file name if not provided
options.model_file_name ??= 'text_model';
return super.from_pretrained(pretrained_model_name_or_path, options);
return super.from_pretrained(pretrained_model_name_or_path, {
// Update default model file name if not provided
model_file_name: 'text_model',
...options,
});
}
}

Expand All @@ -3713,9 +3717,11 @@ export class CLIPTextModelWithProjection extends CLIPPreTrainedModel {
export class CLIPVisionModel extends CLIPPreTrainedModel {
/** @type {typeof PreTrainedModel.from_pretrained} */
static async from_pretrained(pretrained_model_name_or_path, options = {}) {
// Update default model file name if not provided
options.model_file_name ??= 'vision_model';
return super.from_pretrained(pretrained_model_name_or_path, options);
return super.from_pretrained(pretrained_model_name_or_path, {
// Update default model file name if not provided
model_file_name: 'vision_model',
...options,
});
}
}

Expand Down Expand Up @@ -3748,9 +3754,11 @@ export class CLIPVisionModel extends CLIPPreTrainedModel {
export class CLIPVisionModelWithProjection extends CLIPPreTrainedModel {
/** @type {typeof PreTrainedModel.from_pretrained} */
static async from_pretrained(pretrained_model_name_or_path, options = {}) {
// Update default model file name if not provided
options.model_file_name ??= 'vision_model';
return super.from_pretrained(pretrained_model_name_or_path, options);
return super.from_pretrained(pretrained_model_name_or_path, {
// Update default model file name if not provided
model_file_name: 'vision_model',
...options,
});
}
}
//////////////////////////////////////////////////
Expand Down Expand Up @@ -3834,9 +3842,11 @@ export class SiglipModel extends SiglipPreTrainedModel { }
export class SiglipTextModel extends SiglipPreTrainedModel {
/** @type {typeof PreTrainedModel.from_pretrained} */
static async from_pretrained(pretrained_model_name_or_path, options = {}) {
// Update default model file name if not provided
options.model_file_name ??= 'text_model';
return super.from_pretrained(pretrained_model_name_or_path, options);
return super.from_pretrained(pretrained_model_name_or_path, {
// Update default model file name if not provided
model_file_name: 'text_model',
...options,
});
}
}

Expand Down Expand Up @@ -3869,9 +3879,11 @@ export class SiglipTextModel extends SiglipPreTrainedModel {
export class SiglipVisionModel extends CLIPPreTrainedModel {
/** @type {typeof PreTrainedModel.from_pretrained} */
static async from_pretrained(pretrained_model_name_or_path, options = {}) {
// Update default model file name if not provided
options.model_file_name ??= 'vision_model';
return super.from_pretrained(pretrained_model_name_or_path, options);
return super.from_pretrained(pretrained_model_name_or_path, {
// Update default model file name if not provided
model_file_name: 'vision_model',
...options,
});
}
}
//////////////////////////////////////////////////
Expand Down Expand Up @@ -3926,18 +3938,22 @@ export class JinaCLIPModel extends JinaCLIPPreTrainedModel {
export class JinaCLIPTextModel extends JinaCLIPPreTrainedModel {
/** @type {typeof PreTrainedModel.from_pretrained} */
static async from_pretrained(pretrained_model_name_or_path, options = {}) {
// Update default model file name if not provided
options.model_file_name ??= 'text_model';
return super.from_pretrained(pretrained_model_name_or_path, options);
return super.from_pretrained(pretrained_model_name_or_path, {
// Update default model file name if not provided
model_file_name: 'text_model',
...options,
});
}
}

export class JinaCLIPVisionModel extends JinaCLIPPreTrainedModel {
/** @type {typeof PreTrainedModel.from_pretrained} */
static async from_pretrained(pretrained_model_name_or_path, options = {}) {
// Update default model file name if not provided
options.model_file_name ??= 'vision_model';
return super.from_pretrained(pretrained_model_name_or_path, options);
return super.from_pretrained(pretrained_model_name_or_path, {
// Update default model file name if not provided
model_file_name: 'vision_model',
...options,
});
}
}
//////////////////////////////////////////////////
Expand Down Expand Up @@ -6159,9 +6175,11 @@ export class ClapModel extends ClapPreTrainedModel { }
export class ClapTextModelWithProjection extends ClapPreTrainedModel {
/** @type {typeof PreTrainedModel.from_pretrained} */
static async from_pretrained(pretrained_model_name_or_path, options = {}) {
// Update default model file name if not provided
options.model_file_name ??= 'text_model';
return super.from_pretrained(pretrained_model_name_or_path, options);
return super.from_pretrained(pretrained_model_name_or_path, {
// Update default model file name if not provided
model_file_name: 'text_model',
...options,
});
}
}

Expand Down Expand Up @@ -6194,9 +6212,11 @@ export class ClapTextModelWithProjection extends ClapPreTrainedModel {
export class ClapAudioModelWithProjection extends ClapPreTrainedModel {
/** @type {typeof PreTrainedModel.from_pretrained} */
static async from_pretrained(pretrained_model_name_or_path, options = {}) {
// Update default model file name if not provided
options.model_file_name ??= 'audio_model';
return super.from_pretrained(pretrained_model_name_or_path, options);
return super.from_pretrained(pretrained_model_name_or_path, {
// Update default model file name if not provided
model_file_name: 'audio_model',
...options,
});
}
}
//////////////////////////////////////////////////
Expand Down
15 changes: 14 additions & 1 deletion src/utils/tensor.js
Original file line number Diff line number Diff line change
Expand Up @@ -772,8 +772,21 @@ export class Tensor {
if (!DataTypeMap.hasOwnProperty(type)) {
throw new Error(`Unsupported type: ${type}`);
}

// Handle special cases where a mapping function is needed (e.g., where one type is a bigint and the other is a number)
let map_fn;
const is_source_bigint = ['int64', 'uint64'].includes(this.type);
const is_dest_bigint = ['int64', 'uint64'].includes(type);
if (is_source_bigint && !is_dest_bigint) {
// TypeError: Cannot convert a BigInt value to a number
map_fn = Number;
} else if (!is_source_bigint && is_dest_bigint) {
// TypeError: Cannot convert [x] to a BigInt
map_fn = BigInt;
}

// @ts-ignore
return new Tensor(type, DataTypeMap[type].from(this.data), this.dims);
return new Tensor(type, DataTypeMap[type].from(this.data, map_fn), this.dims);
}
}

Expand Down
43 changes: 43 additions & 0 deletions tests/asset_cache.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import { RawImage } from "../src/transformers.js";

const BASE_URL = "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/";
const TEST_IMAGES = Object.freeze({
white_image: BASE_URL + "white-image.png",
pattern_3x3: BASE_URL + "pattern_3x3.png",
pattern_3x5: BASE_URL + "pattern_3x5.png",
checkerboard_8x8: BASE_URL + "checkerboard_8x8.png",
checkerboard_64x32: BASE_URL + "checkerboard_64x32.png",
gradient_1280x640: BASE_URL + "gradient_1280x640.png",
receipt: BASE_URL + "receipt.png",
tiger: BASE_URL + "tiger.jpg",
paper: BASE_URL + "nougat_paper.png",
cats: BASE_URL + "cats.jpg",

// grayscale image
skateboard: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/ml-web-games/skateboard.png",

vitmatte_image: BASE_URL + "vitmatte_image.png",
vitmatte_trimap: BASE_URL + "vitmatte_trimap.png",

beetle: BASE_URL + "beetle.png",
book_cover: BASE_URL + "book-cover.png",
});

/** @type {Map<string, RawImage>} */
const IMAGE_CACHE = new Map();
const load_image = async (url) => {
const cached = IMAGE_CACHE.get(url);
if (cached) {
return cached;
}
const image = await RawImage.fromURL(url);
IMAGE_CACHE.set(url, image);
return image;
};

/**
* Load a cached image.
* @param {keyof typeof TEST_IMAGES} name The name of the image to load.
* @returns {Promise<RawImage>} The loaded image.
*/
export const load_cached_image = (name) => load_image(TEST_IMAGES[name]);
58 changes: 58 additions & 0 deletions tests/init.js
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,66 @@ export function init() {
registerBackend("test", onnxruntimeBackend, Number.POSITIVE_INFINITY);
}

export const MAX_PROCESSOR_LOAD_TIME = 10_000; // 10 seconds
export const MAX_MODEL_LOAD_TIME = 15_000; // 15 seconds
export const MAX_TEST_EXECUTION_TIME = 60_000; // 60 seconds
export const MAX_MODEL_DISPOSE_TIME = 1_000; // 1 second

export const MAX_TEST_TIME = MAX_MODEL_LOAD_TIME + MAX_TEST_EXECUTION_TIME + MAX_MODEL_DISPOSE_TIME;

export const DEFAULT_MODEL_OPTIONS = {
dtype: "fp32",
};

expect.extend({
toBeCloseToNested(received, expected, numDigits = 2) {
const compare = (received, expected, path = "") => {
if (typeof received === "number" && typeof expected === "number" && !Number.isInteger(received) && !Number.isInteger(expected)) {
const pass = Math.abs(received - expected) < Math.pow(10, -numDigits);
return {
pass,
message: () => (pass ? `✓ At path '${path}': expected ${received} not to be close to ${expected} with tolerance of ${numDigits} decimal places` : `✗ At path '${path}': expected ${received} to be close to ${expected} with tolerance of ${numDigits} decimal places`),
};
} else if (Array.isArray(received) && Array.isArray(expected)) {
if (received.length !== expected.length) {
return {
pass: false,
message: () => `✗ At path '${path}': array lengths differ. Received length ${received.length}, expected length ${expected.length}`,
};
}
for (let i = 0; i < received.length; i++) {
const result = compare(received[i], expected[i], `${path}[${i}]`);
if (!result.pass) return result;
}
} else if (typeof received === "object" && typeof expected === "object" && received !== null && expected !== null) {
const receivedKeys = Object.keys(received);
const expectedKeys = Object.keys(expected);
if (receivedKeys.length !== expectedKeys.length) {
return {
pass: false,
message: () => `✗ At path '${path}': object keys length differ. Received keys: ${JSON.stringify(receivedKeys)}, expected keys: ${JSON.stringify(expectedKeys)}`,
};
}
for (const key of receivedKeys) {
if (!expected.hasOwnProperty(key)) {
return {
pass: false,
message: () => `✗ At path '${path}': key '${key}' found in received but not in expected`,
};
}
const result = compare(received[key], expected[key], `${path}.${key}`);
if (!result.pass) return result;
}
} else {
const pass = received === expected;
return {
pass,
message: () => (pass ? `✓ At path '${path}': expected ${JSON.stringify(received)} not to equal ${JSON.stringify(expected)}` : `✗ At path '${path}': expected ${JSON.stringify(received)} to equal ${JSON.stringify(expected)}`),
};
}
return { pass: true };
};

return compare(received, expected);
},
});
Loading
Loading