Improve web/node split
xenova committed Mar 11, 2024
1 parent 69e5f55 · commit 8b7c2ce
Showing 3 changed files with 55 additions and 77 deletions.
103 changes: 44 additions & 59 deletions src/backends/onnx.js
@@ -22,36 +22,23 @@ import { env, RUNNING_LOCALLY } from '../env.js';
 // NOTE: Import order matters here. We need to import `onnxruntime-node` before `onnxruntime-web`.
 // In either case, we select the default export if it exists, otherwise we use the named export.
 import * as ONNX_NODE from 'onnxruntime-node';
-import * as ONNX_WEB from 'onnxruntime-web';
+import * as ONNX_WEB from 'onnxruntime-web/webgpu';
 
 export { Tensor } from 'onnxruntime-common';
 
-let ONNX;
-
 const WEBGPU_AVAILABLE = typeof navigator !== 'undefined' && 'gpu' in navigator;
-const USE_ONNXRUNTIME_NODE = typeof process !== 'undefined' && process?.release?.name === 'node'
 
-const ONNX_MODULES = new Map();
+const USE_ONNXRUNTIME_NODE = typeof process !== 'undefined' && process?.release?.name === 'node';
+
+let ONNX;
 if (USE_ONNXRUNTIME_NODE) {
     ONNX = ONNX_NODE.default ?? ONNX_NODE;
-    ONNX_MODULES.set('node', ONNX);
 } else {
-    // @ts-ignore
-    ONNX = ONNX_WEB.default ?? ONNX_WEB;
-    ONNX_MODULES.set('web', ONNX);
-
-    // Running in a browser-environment
-    // TODO: Check if 1.17.1 fixes this issue.
-    // SIMD for WebAssembly does not operate correctly in some recent versions of iOS (16.4.x).
-    // As a temporary fix, we disable it for now.
-    // For more information, see: https://github.com/microsoft/onnxruntime/issues/15644
-    const isIOS = typeof navigator !== 'undefined' && /iP(hone|od|ad).+16_4.+AppleWebKit/.test(navigator.userAgent);
-    if (isIOS) {
-        ONNX.env.wasm.simd = false;
-    }
+    ONNX = ONNX_WEB;
 }
 
+// @ts-ignore
+const InferenceSession = ONNX.InferenceSession;
+
 /**
  * Create an ONNX inference session, with fallback support if an operation is not supported.
  * @param {Uint8Array} buffer The ONNX model buffer.
@@ -60,33 +47,18 @@ if (USE_ONNXRUNTIME_NODE) {
  */
 export async function createInferenceSession(buffer, session_options) {
     let executionProviders;
-    let InferenceSession;
     if (USE_ONNXRUNTIME_NODE) {
-        const ONNX_NODE = ONNX_MODULES.get('node');
-        InferenceSession = ONNX_NODE.InferenceSession;
         executionProviders = ['cpu'];
-        Object.assign(ONNX_NODE.env, env.backends.onnx);
-
-    } else if (WEBGPU_AVAILABLE && env.experimental.useWebGPU) {
-        // Only import the WebGPU version if the user enables the experimental flag.
-        let ONNX_WEBGPU = ONNX_MODULES.get('webgpu');
-        if (ONNX_WEBGPU === undefined) {
-            ONNX_WEBGPU = await import('onnxruntime-web/webgpu');
-            ONNX_MODULES.set('webgpu', ONNX_WEBGPU)
+    } else if (env.experimental.useWebGPU) {
+        // Only use the WebGPU version if the user enables the experimental flag.
+        if (WEBGPU_AVAILABLE) {
+            executionProviders = ['webgpu', 'wasm'];
+        } else {
+            console.warn('`env.experimental.useWebGPU = true` but WebGPU is not available in this environment. Using WASM as the execution provider.');
+            executionProviders = ['wasm'];
         }
-
-        InferenceSession = ONNX_WEBGPU.InferenceSession;
-
-        // If WebGPU is available and the user enables the experimental flag,
-        // try to use the WebGPU execution provider.
-        executionProviders = ['webgpu', 'wasm'];
-        Object.assign(ONNX_WEBGPU.env, env.backends.onnx);
-
     } else {
-        const ONNX_WEB = ONNX_MODULES.get('web');
-        InferenceSession = ONNX_WEB.InferenceSession;
         executionProviders = ['wasm'];
-        Object.assign(ONNX_WEB.env, env.backends.onnx);
     }
 
// NOTE: Important to create a clone, since ORT modifies the object.
@@ -104,32 +76,45 @@ export async function createInferenceSession(buffer, session_options) {
  * @returns {boolean} Whether the object is an ONNX tensor.
  */
 export function isONNXTensor(x) {
-    for (const module of ONNX_MODULES.values()) {
-        if (x instanceof module.Tensor) {
-            return true;
-        }
-    }
-    return false;
+    return x instanceof ONNX.Tensor;
 }
 
-/**
- * Check if ONNX's WASM backend is being proxied.
- * @returns {boolean} Whether ONNX's WASM backend is being proxied.
- */
-export function isONNXProxy() {
-    // TODO: Update this when allowing non-WASM backends.
-    return ONNX.env.wasm.proxy;
-}
-if (ONNX?.env?.wasm) {
+// @ts-ignore
+const ONNX_ENV = ONNX?.env;
+if (ONNX_ENV?.wasm) {
+    // Initialize wasm backend with suitable default settings.
+
     // Set path to wasm files. This is needed when running in a web worker.
     // https://onnxruntime.ai/docs/api/js/interfaces/Env.WebAssemblyFlags.html#wasmPaths
     // We use remote wasm files by default to make it easier for newer users.
    // In practice, users should probably self-host the necessary .wasm files.
-    ONNX.env.wasm.wasmPaths = RUNNING_LOCALLY
+    ONNX_ENV.wasm.wasmPaths = RUNNING_LOCALLY
         ? path.join(env.__dirname, '/dist/')
         : `https://cdn.jsdelivr.net/npm/@xenova/transformers@${env.version}/dist/`;
 
+    // Proxy the WASM backend to prevent the UI from freezing
+    ONNX_ENV.wasm.proxy = true;
+    // ONNX_ENV.wasm.numThreads = 1; // TODO is this needed?
+
+    // Running in a browser-environment
+    // TODO: Check if 1.17.1 fixes this issue.
+    // SIMD for WebAssembly does not operate correctly in some recent versions of iOS (16.4.x).
+    // As a temporary fix, we disable it for now.
+    // For more information, see: https://github.com/microsoft/onnxruntime/issues/15644
+    const isIOS = typeof navigator !== 'undefined' && /iP(hone|od|ad).+16_4.+AppleWebKit/.test(navigator.userAgent);
+    if (isIOS) {
+        ONNX_ENV.wasm.simd = false;
+    }
 }
 
+/**
+ * Check if ONNX's WASM backend is being proxied.
+ * @returns {boolean} Whether ONNX's WASM backend is being proxied.
+ */
+export function isONNXProxy() {
+    // TODO: Update this when allowing non-WASM backends.
+    return ONNX_ENV?.wasm?.proxy;
+}
+
 // Expose ONNX environment variables to `env.backends.onnx`
-env.backends.onnx = ONNX.env;
+env.backends.onnx = ONNX_ENV;
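
From the consumer's side, the selection logic above is driven entirely through `env`. A minimal usage sketch, assuming the package exposes `env` and `pipeline` as `@xenova/transformers` does, that `env.experimental.useWebGPU` is the flag referenced in this diff, and that the model id and self-hosted path below are placeholders:

import { env, pipeline } from '@xenova/transformers';

// Opt in to the experimental WebGPU execution provider referenced above.
// When WebGPU is unavailable, createInferenceSession() falls back to 'wasm'
// and logs the console.warn shown in the diff.
env.experimental.useWebGPU = true;

// Self-host the .wasm binaries instead of loading them from the jsDelivr CDN
// (the path below is a hypothetical location on your own server).
env.backends.onnx.wasm.wasmPaths = '/static/onnx/';

// The WASM backend is proxied by default to keep the UI responsive;
// opt out if you would rather run it on the main thread.
env.backends.onnx.wasm.proxy = false;

const extractor = await pipeline('feature-extraction', 'Xenova/all-MiniLM-L6-v2');
const output = await extractor('Hello world', { pooling: 'mean', normalize: true });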
4 changes: 2 additions & 2 deletions src/configs.js
@@ -97,10 +97,10 @@ export class PretrainedConfig {
  * Helper class which is used to instantiate pretrained configs with the `from_pretrained` function.
  *
  * @example
- * let config = await AutoConfig.from_pretrained('bert-base-uncased');
+ * const config = await AutoConfig.from_pretrained('Xenova/bert-base-uncased');
  */
 export class AutoConfig {
-    /** @type {PretrainedConfig.from_pretrained} */
+    /** @type {typeof PretrainedConfig.from_pretrained} */
     static async from_pretrained(...args) {
         return PretrainedConfig.from_pretrained(...args);
     }
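The second change here swaps a plain JSDoc type reference for `typeof`. In JSDoc/TypeScript, `PretrainedConfig.from_pretrained` in a type position is looked up as a type member (which does not exist), whereas `typeof` reuses the type of the static method's value, so the wrapper inherits its parameter and return types. A stripped-down sketch of the pattern, with hypothetical class names:

class Base {
    /**
     * @param {string} name The model id to load.
     * @returns {Promise<Base>}
     */
    static async from_pretrained(name) {
        return new Base();
    }
}

class Auto {
    // `typeof Base.from_pretrained` copies the static method's signature,
    // so callers of Auto.from_pretrained get the same checking and hints.
    /** @type {typeof Base.from_pretrained} */
    static async from_pretrained(...args) {
        return Base.from_pretrained(...args);
    }
}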
25 changes: 9 additions & 16 deletions webpack.config.js
Expand Up @@ -19,10 +19,14 @@ function buildConfig({
name = '',
suffix = '.js',
type = 'module', // 'module' | 'commonjs'
dynamicImportMode = undefined, // 'eager' | undefined
ignoreModules = [], // 'eager' | undefined
} = {}) {
const outputModule = type === 'module';

const alias = Object.fromEntries(ignoreModules.map((module) => {
return [module, false];
}));

return {
mode: 'development',
devtool: 'source-map',
@@ -60,13 +64,7 @@
         experiments: {
             outputModule,
         },
-        // module: {
-        //     parser: {
-        //         javascript: {
-        //             dynamicImportMode,
-        //         }
-        //     }
-        // },
+        resolve: { alias },
 
         // Development server
         devServer: {
@@ -84,13 +82,8 @@ export default [
         type: 'module',
     }),
     buildConfig({
-        name: '.webgpu',
-        type: 'module',
-        dynamicImportMode: 'eager',
+        suffix: '.cjs',
+        type: 'commonjs',
+        ignoreModules: ['onnxruntime-web', 'onnxruntime-web/webgpu'],
     }),
-    // TODO:
-    // buildConfig({
-    //     suffix: '.cjs',
-    //     type: 'commonjs',
-    // }),
 ];
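
For readers unfamiliar with the trick in `buildConfig`: webpack 5 accepts `false` as a `resolve.alias` target, which tells webpack to ignore that module and bundle an empty stub instead, so the new `.cjs` build ships without the browser-only ONNX runtimes. A sketch of what the option expands to for the CommonJS target above (module names taken from the diff):

// What the ignoreModules option assembles for the .cjs build:
const alias = Object.fromEntries(
    ['onnxruntime-web', 'onnxruntime-web/webgpu'].map((module) => [module, false])
);

// ...which is equivalent to writing the webpack option by hand:
// resolve: {
//     alias: {
//         'onnxruntime-web': false,
//         'onnxruntime-web/webgpu': false,
//     },
// },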
