
Improve how we set devices
xenova committed Mar 11, 2024
1 parent 034b959 commit d1b1059
Showing 6 changed files with 49 additions and 35 deletions.
30 changes: 17 additions & 13 deletions src/backends/onnx.js
@@ -29,11 +29,20 @@ export { Tensor } from 'onnxruntime-common';
const WEBGPU_AVAILABLE = typeof navigator !== 'undefined' && 'gpu' in navigator;
const USE_ONNXRUNTIME_NODE = typeof process !== 'undefined' && process?.release?.name === 'node';

+const supportedExecutionProviders = [];
+let defaultExecutionProviders;
let ONNX;
if (USE_ONNXRUNTIME_NODE) {
    ONNX = ONNX_NODE.default ?? ONNX_NODE;
+   supportedExecutionProviders.push('cpu');
+   defaultExecutionProviders = ['cpu'];
} else {
    ONNX = ONNX_WEB;
+   if (WEBGPU_AVAILABLE) {
+       supportedExecutionProviders.push('webgpu');
+   }
+   supportedExecutionProviders.push('wasm');
+   defaultExecutionProviders = ['wasm'];
}

// @ts-ignore
@@ -42,23 +51,18 @@ const InferenceSession = ONNX.InferenceSession;
/**
* Create an ONNX inference session, with fallback support if an operation is not supported.
* @param {Uint8Array} buffer The ONNX model buffer.
* @param {Object} session_options ONNX inference session options.
+* @param {import("../utils/devices.js").DeviceType} [device=null] (Optional) The device to run the inference on.
* @returns {Promise<Object>} The ONNX inference session.
*/
-export async function createInferenceSession(buffer, session_options) {
-    let executionProviders;
-    if (USE_ONNXRUNTIME_NODE) {
-        executionProviders = ['cpu'];
-    } else if (env.experimental.useWebGPU) {
-        // Only use the WebGPU version if the user enables the experimental flag.
-        if (WEBGPU_AVAILABLE) {
-            executionProviders = ['webgpu', 'wasm'];
+export async function createInferenceSession(buffer, session_options, device = null) {
+    let executionProviders = defaultExecutionProviders;
+    if (device) { // User has specified a device
+        if (supportedExecutionProviders.includes(device)) {
+            executionProviders = [device];
        } else {
-            console.warn('`env.experimental.useWebGPU = true` but WebGPU is not available in this environment. Using WASM as the execution provider.');
-            executionProviders = ['wasm'];
+            throw new Error(`Unsupported device: "${device}". Should be one of: ${supportedExecutionProviders.join(', ')}.`)
        }
-    } else {
-        executionProviders = ['wasm'];
    }

// NOTE: Important to create a clone, since ORT modifies the object.
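For context (not part of the diff), here is a rough sketch of how the reworked provider selection behaves. The import path, model URL, and buffer are placeholders; only the `device` handling shown in the hunk above is taken from the commit.

```js
// Illustrative sketch only: `modelBuffer` is assumed to hold a valid ONNX model.
import { createInferenceSession } from './backends/onnx.js';

const modelBuffer = new Uint8Array(await (await fetch('model.onnx')).arrayBuffer());

// No device given: use the environment default
// (['cpu'] under Node.js, ['wasm'] in the browser).
const defaultSession = await createInferenceSession(modelBuffer, {});

// Explicit device: used directly when it is a supported execution provider,
// e.g. 'webgpu' in browsers where `navigator.gpu` exists.
const gpuSession = await createInferenceSession(modelBuffer, {}, 'webgpu');

// Unsupported device: now throws instead of silently falling back to WASM
// as the old `env.experimental.useWebGPU` path did.
await createInferenceSession(modelBuffer, {}, 'cuda')
    .catch(err => console.error(err.message));
```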
6 changes: 0 additions & 6 deletions src/env.js
@@ -84,12 +84,6 @@ export const env = {
tfjs: {},
},

-    /////////////////// Experimental settings ///////////////////
-    experimental: {
-        // Whether to use the experimental WebGPU backend for ONNX.js.
-        useWebGPU: false,
-    },

__dirname,
version: VERSION,

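With the experimental flag removed, WebGPU is now opted into per model via the `device` option introduced in the changes below. A minimal before/after sketch, with the model name purely illustrative:

```js
import { AutoModel } from '@xenova/transformers';

// Before this commit (removed):
//   env.experimental.useWebGPU = true;
//   const model = await AutoModel.from_pretrained('Xenova/bert-base-uncased');

// After this commit: request the device per model.
const model = await AutoModel.from_pretrained('Xenova/bert-base-uncased', {
    device: 'webgpu', // falls back to the environment default when omitted
});
```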
41 changes: 26 additions & 15 deletions src/models.js
@@ -6,10 +6,10 @@
*
* ```javascript
* import { AutoModel, AutoTokenizer } from '@xenova/transformers';
*
* let tokenizer = await AutoTokenizer.from_pretrained('Xenova/bert-base-uncased');
* let model = await AutoModel.from_pretrained('Xenova/bert-base-uncased');
*
* let inputs = await tokenizer('I love transformers!');
* let { logits } = await model(inputs);
* // Tensor {
@@ -28,7 +28,7 @@
*
* let tokenizer = await AutoTokenizer.from_pretrained('Xenova/t5-small');
* let model = await AutoModelForSeq2SeqLM.from_pretrained('Xenova/t5-small');
*
* let { input_ids } = await tokenizer('translate English to German: I love transformers!');
* let outputs = await model.generate(input_ids);
* let decoded = tokenizer.decode(outputs[0], { skip_special_tokens: true });
@@ -118,22 +118,29 @@ async function constructSession(pretrained_model_name_or_path, fileName, options
const modelFileName = `onnx/${fileName}${options.quantized ? '_quantized' : ''}.onnx`;
const buffer = await getModelFile(pretrained_model_name_or_path, modelFileName, true, options);

let session_options = options.session_options || {};
const session_options = options.session_options ?? {};

// handle onnx external data files
// TODO: parse external data from config/options
// if (session_options.externalData !== undefined) {
// for (let i = 0; i < session_options.externalData.length; i++) {
// const ext = session_options.externalData[i];
// // if the external data is a string, fetch the file and replace the string with its content
// if (typeof ext.data === "string") {
// const ext_buffer = await getModelFile(pretrained_model_name_or_path, ext.data, true, options);
// ext.data = ext_buffer;
// }
if (session_options.externalData !== undefined) {
for (let i = 0; i < session_options.externalData.length; i++) {
const ext = session_options.externalData[i];
// if the external data is a string, fetch the file and replace the string with its content
if (typeof ext.data === "string") {
const ext_buffer = await getModelFile(pretrained_model_name_or_path, ext.data, true, options);
ext.data = ext_buffer;
}
}
}

// TODO: Add support for preferredOutputLocation
// if (options.device == "webgpu") {
// for (let i = 0; i < config.layers; ++i) {
// options.session_options.preferredOutputLocation[`present.${i}.key`] = 'gpu-buffer';
// options.session_options.preferredOutputLocation[`present.${i}.value`] = 'gpu-buffer';
// }
// }

return await createInferenceSession(buffer, session_options);
return await createInferenceSession(buffer, session_options, options.device);
}
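The uncommented external-data handling means callers can now point `session_options.externalData` entries at files in the model repository. A hedged sketch follows; the repo and file names are hypothetical, and the entry shape follows onnxruntime's `externalData` option as used by the loop above:

```js
import { AutoModel } from '@xenova/transformers';

// Hypothetical repo whose ONNX graph stores weights in an external data file.
const model = await AutoModel.from_pretrained('your-org/your-model', {
    session_options: {
        externalData: [{
            // Path as referenced from inside the .onnx graph.
            path: 'model.onnx_data',
            // A string is now fetched from the model repo and replaced with
            // its buffer before the inference session is created.
            data: 'onnx/model.onnx_data',
        }],
    },
});
```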

/**
@@ -198,7 +205,7 @@ async function sessionRun(session, inputs) {
try {
// pass the original ort tensor
const ortFeed = Object.fromEntries(Object.entries(checkedInputs).map(([k, v]) => [k, v.ort_tensor]));
let output = await session.run(ortFeed);
output = replaceTensors(output);
for (const [name, t] of Object.entries(checkedInputs)) {
// if we use gpu buffers for kv_caches, we own them and need to dispose()
@@ -741,6 +748,7 @@ export class PreTrainedModel extends Callable {
local_files_only = false,
revision = 'main',
model_file_name = null,
device = null,
session_options = {},
} = {}) {

@@ -752,6 +760,7 @@
local_files_only,
revision,
model_file_name,
device,
session_options,
}

@@ -5448,6 +5457,7 @@ export class PretrainedMixin {
local_files_only = false,
revision = 'main',
model_file_name = null,
device = null,
session_options = {},
} = {}) {

@@ -5459,6 +5469,7 @@
local_files_only,
revision,
model_file_name,
device,
session_options,
}
config = await AutoConfig.from_pretrained(pretrained_model_name_or_path, options);
3 changes: 2 additions & 1 deletion src/pipelines.js
@@ -3019,8 +3019,8 @@ export async function pipeline(
        cache_dir = null,
        local_files_only = false,
        revision = 'main',
+       device = null,
        session_options = {},
-       // TODO: device option
    } = {}
)
// Helper method to construct pipeline
@@ -3048,6 +3048,7 @@
cache_dir,
local_files_only,
revision,
device,
session_options,
}

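With the TODO resolved, the `device` option now flows from `pipeline()` into the model's pretrained options. An illustrative call, where the task, model, and inputs are examples rather than part of the commit:

```js
import { pipeline } from '@xenova/transformers';

const extractor = await pipeline('feature-extraction', 'Xenova/all-MiniLM-L6-v2', {
    device: 'webgpu',    // forwarded to from_pretrained alongside session_options
    session_options: {}, // runtime-specific tuning still available
});
const output = await extractor('Hello world', { pooling: 'mean', normalize: true });
```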
3 changes: 3 additions & 0 deletions src/utils/devices.js
@@ -0,0 +1,3 @@
/**
* @typedef {'cpu'|'gpu'|'wasm'|'webgpu'|null} DeviceType
*/
1 change: 1 addition & 0 deletions src/utils/hub.js
@@ -28,6 +28,7 @@ import { dispatchCallback } from './core.js';
* @typedef {Object} ModelSpecificPretrainedOptions Options for loading a pretrained model.
* @property {boolean?} [quantized=true] Whether to load the 8-bit quantized version of the model (only applicable when loading model files).
* @property {string} [model_file_name=null] If specified, load the model with this name (excluding the .onnx suffix). Currently only valid for encoder- or decoder-only models.
* @property {import("./devices.js").DeviceType} [device=null] The device to run the model on. If not specified, the device will be chosen from the environment settings.
* @property {Object} [session_options] (Optional) User-specified session options passed to the runtime. If not provided, suitable defaults will be chosen.
*/

