Skip to content

Commit

Permalink
Move tensor clone for Worker ownership NaN issue
Browse files Browse the repository at this point in the history
  • Loading branch information
kungfooman committed Nov 19, 2023
1 parent b8719b1 commit f89a296
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 7 deletions.
2 changes: 1 addition & 1 deletion src/backends/onnx.js
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import * as ONNX_NODE from 'onnxruntime-node';
import * as ONNX_WEB from 'onnxruntime-web';

/** @type {module} The ONNX runtime module. */
/** @type {import('onnxruntime-web')} The ONNX runtime module. */
export let ONNX;

export const executionProviders = [
Expand Down
21 changes: 15 additions & 6 deletions src/models.js
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,9 @@ import {

import { executionProviders, ONNX } from './backends/onnx.js';
import { medianFilter } from './transformers.js';
const { InferenceSession, Tensor: ONNXTensor } = ONNX;
const { InferenceSession, Tensor: ONNXTensor, env } = ONNX;

/** @typedef {import('onnxruntime-web').InferenceSession} InferenceSession */

//////////////////////////////////////////////////
// Model types: used internally
Expand Down Expand Up @@ -142,20 +144,27 @@ async function constructSession(pretrained_model_name_or_path, fileName, options
/**
* Validate model inputs
* @param {InferenceSession} session The InferenceSession object that will be run.
* @param {Object} inputs The inputs to check.
* @param {Record<string, Tensor>} inputs The inputs to check.
* @returns {Promise<Object>} A Promise that resolves to the checked inputs.
* @throws {Error} If any inputs are missing.
* @private
*/
async function validateInputs(session, inputs) {
// NOTE: Only create a shallow copy
// NOTE: Create either a shallow or deep copy based on `onnx.wasm.proxy`
const checkedInputs = {};
const missingInputs = [];
for (let inputName of session.inputNames) {
if (inputs[inputName] === undefined) {
for (const inputName of session.inputNames) {
const tensor = inputs[inputName];
if (!tensor) {
missingInputs.push(inputName);
} else {
checkedInputs[inputName] = inputs[inputName];
if (env.wasm.proxy) {
// Moving the tensor across Worker boundary moves ownership to the worker,
// which invalidates the tensor. So we simply sacrifize the clone for it.
checkedInputs[inputName] = tensor.clone();
} else {
checkedInputs[inputName] = tensor;
}
}
}
if (missingInputs.length > 0) {
Expand Down

0 comments on commit f89a296

Please sign in to comment.