From 38a3bf6dab2265d9f0c2f613064535863194e6b9 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Mon, 29 Jul 2024 13:14:44 +0200 Subject: [PATCH] Only check fp16 support for webgpu device WASM/CPU works with fp16, but WebGPU is device-dependent --- src/models.js | 6 +++--- src/utils/dtypes.js | 8 +++----- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/src/models.js b/src/models.js index d99953f0a..e98f35e59 100644 --- a/src/models.js +++ b/src/models.js @@ -53,7 +53,7 @@ import { DATA_TYPES, DEFAULT_DEVICE_DTYPE_MAPPING, DEFAULT_DTYPE_SUFFIX_MAPPING, - isFp16Supported, + isWebGpuFp16Supported, } from './utils/dtypes.js'; import { @@ -175,8 +175,8 @@ async function getSession(pretrained_model_name_or_path, fileName, options) { if (!DEFAULT_DTYPE_SUFFIX_MAPPING.hasOwnProperty(dtype)) { throw new Error(`Invalid dtype: ${dtype}. Should be one of: ${Object.keys(DATA_TYPES).join(', ')}`); - } else if (dtype === DATA_TYPES.fp16 && !(await isFp16Supported())) { - throw new Error(`The device does not support fp16.`); + } else if (dtype === DATA_TYPES.fp16 && device === 'webgpu' && !(await isWebGpuFp16Supported())) { + throw new Error(`The device (${device}) does not support fp16.`); } // Construct the model file name diff --git a/src/utils/dtypes.js b/src/utils/dtypes.js index 24b0765b3..13f00c538 100644 --- a/src/utils/dtypes.js +++ b/src/utils/dtypes.js @@ -7,17 +7,15 @@ import { DEVICE_TYPES } from "./devices.js"; // For more information, see https://github.com/microsoft/onnxruntime/pull/19857#issuecomment-1999984753 /** - * Checks if fp16 support is available in the current environment. + * Checks if WebGPU fp16 support is available in the current environment. */ -export const isFp16Supported = (function () { +export const isWebGpuFp16Supported = (function () { /** @type {boolean} */ let cachedResult; return async function () { if (cachedResult === undefined) { - if (apis.IS_NODE_ENV) { - cachedResult = true; - } else if (!apis.IS_WEBGPU_AVAILABLE) { + if (!apis.IS_WEBGPU_AVAILABLE) { cachedResult = false; } else { try {