Skip to content

Commit

Permalink
Only check fp16 support for webgpu device
Browse files Browse the repository at this point in the history
WASM/CPU works with fp16, but WebGPU is device-dependent
  • Loading branch information
xenova committed Jul 29, 2024
1 parent c6aeb4b commit 38a3bf6
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 8 deletions.
6 changes: 3 additions & 3 deletions src/models.js
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ import {
DATA_TYPES,
DEFAULT_DEVICE_DTYPE_MAPPING,
DEFAULT_DTYPE_SUFFIX_MAPPING,
isFp16Supported,
isWebGpuFp16Supported,
} from './utils/dtypes.js';

import {
Expand Down Expand Up @@ -175,8 +175,8 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {

if (!DEFAULT_DTYPE_SUFFIX_MAPPING.hasOwnProperty(dtype)) {
throw new Error(`Invalid dtype: ${dtype}. Should be one of: ${Object.keys(DATA_TYPES).join(', ')}`);
} else if (dtype === DATA_TYPES.fp16 && !(await isFp16Supported())) {
throw new Error(`The device does not support fp16.`);
} else if (dtype === DATA_TYPES.fp16 && device === 'webgpu' && !(await isWebGpuFp16Supported())) {
throw new Error(`The device (${device}) does not support fp16.`);
}

// Construct the model file name
Expand Down
8 changes: 3 additions & 5 deletions src/utils/dtypes.js
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,15 @@ import { DEVICE_TYPES } from "./devices.js";
// For more information, see https://github.com/microsoft/onnxruntime/pull/19857#issuecomment-1999984753

/**
* Checks if fp16 support is available in the current environment.
* Checks if WebGPU fp16 support is available in the current environment.
*/
export const isFp16Supported = (function () {
export const isWebGpuFp16Supported = (function () {
/** @type {boolean} */
let cachedResult;

return async function () {
if (cachedResult === undefined) {
if (apis.IS_NODE_ENV) {
cachedResult = true;
} else if (!apis.IS_WEBGPU_AVAILABLE) {
if (!apis.IS_WEBGPU_AVAILABLE) {
cachedResult = false;
} else {
try {
Expand Down

0 comments on commit 38a3bf6

Please sign in to comment.