From d1b1059c7c6af09ec4592f1e4ed4c0bcc21786be Mon Sep 17 00:00:00 2001
From: Joshua Lochner
Date: Mon, 11 Mar 2024 03:03:20 +0200
Subject: [PATCH] Improve how we set devices

---
 src/backends/onnx.js | 30 +++++++++++++++++-------------
 src/env.js           |  6 ------
 src/models.js        | 41 ++++++++++++++++++++++++++---------------
 src/pipelines.js     |  3 ++-
 src/utils/devices.js |  3 +++
 src/utils/hub.js     |  1 +
 6 files changed, 49 insertions(+), 35 deletions(-)
 create mode 100644 src/utils/devices.js

diff --git a/src/backends/onnx.js b/src/backends/onnx.js
index 36ad8ae52..d5917d888 100644
--- a/src/backends/onnx.js
+++ b/src/backends/onnx.js
@@ -29,11 +29,20 @@ export { Tensor } from 'onnxruntime-common';
 const WEBGPU_AVAILABLE = typeof navigator !== 'undefined' && 'gpu' in navigator;
 const USE_ONNXRUNTIME_NODE = typeof process !== 'undefined' && process?.release?.name === 'node';
 
+const supportedExecutionProviders = [];
+let defaultExecutionProviders;
 let ONNX;
 if (USE_ONNXRUNTIME_NODE) {
     ONNX = ONNX_NODE.default ?? ONNX_NODE;
+    supportedExecutionProviders.push('cpu');
+    defaultExecutionProviders = ['cpu'];
 } else {
     ONNX = ONNX_WEB;
+    if (WEBGPU_AVAILABLE) {
+        supportedExecutionProviders.push('webgpu');
+    }
+    supportedExecutionProviders.push('wasm');
+    defaultExecutionProviders = ['wasm'];
 }
 
 // @ts-ignore
@@ -42,23 +51,18 @@ const InferenceSession = ONNX.InferenceSession;
 /**
  * Create an ONNX inference session, with fallback support if an operation is not supported.
  * @param {Uint8Array} buffer The ONNX model buffer.
- * @param {Object} session_options ONNX inference session options. 
+ * @param {Object} session_options ONNX inference session options.
+ * @param {import("../utils/devices.js").DeviceType} [device=null] (Optional) The device to run the inference on.
  * @returns {Promise} The ONNX inference session.
  */
-export async function createInferenceSession(buffer, session_options) {
-    let executionProviders;
-    if (USE_ONNXRUNTIME_NODE) {
-        executionProviders = ['cpu'];
-    } else if (env.experimental.useWebGPU) {
-        // Only use the WebGPU version if the user enables the experimental flag.
-        if (WEBGPU_AVAILABLE) {
-            executionProviders = ['webgpu', 'wasm'];
+export async function createInferenceSession(buffer, session_options, device = null) {
+    let executionProviders = defaultExecutionProviders;
+    if (device) { // User has specified a device
+        if (supportedExecutionProviders.includes(device)) {
+            executionProviders = [device];
         } else {
-            console.warn('`env.experimental.useWebGPU = true` but WebGPU is not available in this environment. Using WASM as the execution provider.');
-            executionProviders = ['wasm'];
+            throw new Error(`Unsupported device: "${device}". Should be one of: ${supportedExecutionProviders.join(', ')}.`)
         }
-    } else {
-        executionProviders = ['wasm'];
     }
 
     // NOTE: Important to create a clone, since ORT modifies the object.
diff --git a/src/env.js b/src/env.js
index 3b2724b84..e4a736457 100644
--- a/src/env.js
+++ b/src/env.js
@@ -84,12 +84,6 @@ export const env = {
         tfjs: {},
     },
 
-    /////////////////// Experimental settings ///////////////////
-    experimental: {
-        // Whether to use the experimental WebGPU backend for ONNX.js. 
-        useWebGPU: false,
-    },
-
     __dirname,
     version: VERSION,
 
diff --git a/src/models.js b/src/models.js
index e935f3111..4219ca9ea 100644
--- a/src/models.js
+++ b/src/models.js
@@ -6,10 +6,10 @@
  *
  * ```javascript
  * import { AutoModel, AutoTokenizer } from '@xenova/transformers';
- * 
+ *
  * let tokenizer = await AutoTokenizer.from_pretrained('Xenova/bert-base-uncased');
  * let model = await AutoModel.from_pretrained('Xenova/bert-base-uncased');
- * 
+ *
  * let inputs = await tokenizer('I love transformers!');
  * let { logits } = await model(inputs);
  * // Tensor {
@@ -28,7 +28,7 @@
  *
  * let tokenizer = await AutoTokenizer.from_pretrained('Xenova/t5-small');
  * let model = await AutoModelForSeq2SeqLM.from_pretrained('Xenova/t5-small');
- * 
+ *
  * let { input_ids } = await tokenizer('translate English to German: I love transformers!');
  * let outputs = await model.generate(input_ids);
  * let decoded = tokenizer.decode(outputs[0], { skip_special_tokens: true });
@@ -118,22 +118,29 @@ async function constructSession(pretrained_model_name_or_path, fileName, options
     const modelFileName = `onnx/${fileName}${options.quantized ? '_quantized' : ''}.onnx`;
     const buffer = await getModelFile(pretrained_model_name_or_path, modelFileName, true, options);
 
-    let session_options = options.session_options || {};
+    const session_options = options.session_options ?? {};
 
     // handle onnx external data files
-    // TODO: parse external data from config/options
-    // if (session_options.externalData !== undefined) {
-    //     for (let i = 0; i < session_options.externalData.length; i++) {
-    //         const ext = session_options.externalData[i];
-    //         // if the external data is a string, fetch the file and replace the string with its content
-    //         if (typeof ext.data === "string") {
-    //             const ext_buffer = await getModelFile(pretrained_model_name_or_path, ext.data, true, options);
-    //             ext.data = ext_buffer;
-    //         }
+    if (session_options.externalData !== undefined) {
+        for (let i = 0; i < session_options.externalData.length; i++) {
+            const ext = session_options.externalData[i];
+            // if the external data is a string, fetch the file and replace the string with its content
+            if (typeof ext.data === "string") {
+                const ext_buffer = await getModelFile(pretrained_model_name_or_path, ext.data, true, options);
+                ext.data = ext_buffer;
+            }
+        }
+    }
+
+    // TODO: Add support for preferredOutputLocation
+    // if (options.device == "webgpu") {
+    //     for (let i = 0; i < config.layers; ++i) {
+    //         options.session_options.preferredOutputLocation[`present.${i}.key`] = 'gpu-buffer';
+    //         options.session_options.preferredOutputLocation[`present.${i}.value`] = 'gpu-buffer';
     //     }
     // }
 
-    return await createInferenceSession(buffer, session_options);
+    return await createInferenceSession(buffer, session_options, options.device);
 }
 
 /**
@@ -198,7 +205,7 @@ async function sessionRun(session, inputs) {
     try {
         // pass the original ort tensor
         const ortFeed = Object.fromEntries(Object.entries(checkedInputs).map(([k, v]) => [k, v.ort_tensor]));
-        let output = await session.run(ortFeed); 
+        let output = await session.run(ortFeed);
         output = replaceTensors(output);
         for (const [name, t] of Object.entries(checkedInputs)) {
             // if we use gpu buffers for kv_caches, we own them and need to dispose()
@@ -741,6 +748,7 @@ export class PreTrainedModel extends Callable {
         local_files_only = false,
         revision = 'main',
         model_file_name = null,
+        device = null,
         session_options = {},
     } = {}) {
 
@@ -752,6 +760,7 @@ export class PreTrainedModel extends Callable {
             local_files_only,
             revision,
             model_file_name,
+            device,
             session_options,
         }
 
@@ -5448,6 +5457,7 @@ export class PretrainedMixin {
         local_files_only = false,
         revision = 'main',
         model_file_name = null,
+        device = null,
         session_options = {},
     } = {}) {
 
@@ -5459,6 +5469,7 @@ export class PretrainedMixin {
             local_files_only,
             revision,
             model_file_name,
+            device,
             session_options,
         }
         config = await AutoConfig.from_pretrained(pretrained_model_name_or_path, options);
diff --git a/src/pipelines.js b/src/pipelines.js
index 7b73393e3..13475c460 100755
--- a/src/pipelines.js
+++ b/src/pipelines.js
@@ -3019,8 +3019,8 @@ export async function pipeline(
         cache_dir = null,
         local_files_only = false,
         revision = 'main',
+        device= null,
         session_options = {},
-        // TODO: device option
     } = {}
 ) {
     // Helper method to construct pipeline
@@ -3048,6 +3048,7 @@ export async function pipeline(
         cache_dir,
         local_files_only,
         revision,
+        device,
         session_options,
     }
 
diff --git a/src/utils/devices.js b/src/utils/devices.js
new file mode 100644
index 000000000..8a0a83dca
--- /dev/null
+++ b/src/utils/devices.js
@@ -0,0 +1,3 @@
+/**
+ * @typedef {'cpu'|'gpu'|'wasm'|'webgpu'|null} DeviceType
+ */
diff --git a/src/utils/hub.js b/src/utils/hub.js
index 34bd07433..4062d4826 100755
--- a/src/utils/hub.js
+++ b/src/utils/hub.js
@@ -28,6 +28,7 @@ import { dispatchCallback } from './core.js';
 * @typedef {Object} ModelSpecificPretrainedOptions Options for loading a pretrained model.
 * @property {boolean?} [quantized=true] Whether to load the 8-bit quantized version of the model (only applicable when loading model files).
 * @property {string} [model_file_name=null] If specified, load the model with this name (excluding the .onnx suffix). Currently only valid for encoder- or decoder-only models.
+* @property {import("./devices.js").DeviceType} [device=null] The device to run the model on. If not specified, the device will be chosen from the environment settings.
 * @property {Object} [session_options] (Optional) User-specified session options passed to the runtime. If not provided, suitable defaults will be chosen.
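
For reference, a minimal usage sketch of the new `device` option introduced by this patch (not part of the diff itself; the task, model ID, and pooling options below are illustrative examples). Per the changes above, Node.js registers only `'cpu'`, browsers offer `'webgpu'` when `navigator.gpu` exists plus `'wasm'`, and an unsupported value now throws instead of silently falling back to WASM:

```javascript
import { pipeline } from '@xenova/transformers';

// Select the execution provider via the new `device` option.
// Omitting it uses the environment default ('cpu' in Node.js, 'wasm' in the browser).
const extractor = await pipeline('feature-extraction', 'Xenova/all-MiniLM-L6-v2', {
    device: 'webgpu', // or 'wasm' / 'cpu'
});

const output = await extractor('I love transformers!', { pooling: 'mean', normalize: true });
console.log(output);
```

The same option is threaded through `PreTrainedModel.from_pretrained` and `PretrainedMixin.from_pretrained` down to `createInferenceSession`, replacing the removed `env.experimental.useWebGPU` flag.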