diff --git a/src/backends/onnx.js b/src/backends/onnx.js index de89da037..f4c9431fd 100644 --- a/src/backends/onnx.js +++ b/src/backends/onnx.js @@ -59,11 +59,11 @@ const supportedDevices = []; /** @type {ONNXExecutionProviders[]} */ let defaultDevices; let ONNX; -const ORT_SYMBOL = Symbol.for('onnxruntime'); -if (ORT_SYMBOL in globalThis) { - // If the JS runtime exposes their own ONNX runtime, use it - ONNX = globalThis[ORT_SYMBOL]; +if (apis.IS_EXPOSED_RUNTIME_ENV) { + // If the JS runtime exposes their own ONNX runtime, use it + ONNX = globalThis[apis.EXPOSED_RUNTIME_SYMBOL]; + defaultDevices = ['auto']; } else if (apis.IS_NODE_ENV) { ONNX = ONNX_NODE.default ?? ONNX_NODE; diff --git a/src/env.js b/src/env.js index 3fd7cdbdd..cc4e21a1f 100644 --- a/src/env.js +++ b/src/env.js @@ -35,6 +35,9 @@ const IS_WEB_CACHE_AVAILABLE = IS_BROWSER_ENV && 'caches' in self; const IS_WEBGPU_AVAILABLE = typeof navigator !== 'undefined' && 'gpu' in navigator; const IS_WEBNN_AVAILABLE = typeof navigator !== 'undefined' && 'ml' in navigator; +const EXPOSED_RUNTIME_SYMBOL = Symbol.for('onnxruntime'); +const IS_EXPOSED_RUNTIME_ENV = EXPOSED_RUNTIME_SYMBOL in globalThis; + const IS_PROCESS_AVAILABLE = typeof process !== 'undefined'; const IS_NODE_ENV = IS_PROCESS_AVAILABLE && process?.release?.name === 'node'; const IS_FS_AVAILABLE = !isEmpty(fs); @@ -59,6 +62,12 @@ export const apis = Object.freeze({ /** Whether the WebNN API is available */ IS_WEBNN_AVAILABLE, + /** Symbol from JS environment that exposes their own ONNX runtime */ + EXPOSED_RUNTIME_SYMBOL, + + /** Whether we are running in a JS environment that exposes their own ONNX runtime */ + IS_EXPOSED_RUNTIME_ENV, + /** Whether the Node.js process API is available */ IS_PROCESS_AVAILABLE, diff --git a/src/models.js b/src/models.js index 4d22e947a..e807ff3a5 100644 --- a/src/models.js +++ b/src/models.js @@ -166,7 +166,10 @@ async function getSession(pretrained_model_name_or_path, fileName, options) { // If the device is not specified, we use the default (supported) execution providers. const selectedDevice = /** @type {import("./utils/devices.js").DeviceType} */( - device ?? (apis.IS_NODE_ENV ? 'cpu' : 'wasm') + device ?? ( + apis.IS_EXPOSED_RUNTIME_ENV ? 'auto' : ( + apis.IS_NODE_ENV ? 'cpu' : 'wasm' + )) ); const executionProviders = deviceToExecutionProviders(selectedDevice);