Skip to content

Commit

Permalink
[DO NOT MERGE] Code example for testing ORT-Web WebNN EP
Browse files Browse the repository at this point in the history
This is a very rough example of enabling WebNN in transformers.js.
I just added some hard-coded values to make the "Image classification w/ google/vit-base-patch16-224"
fp32 model work with the ORT Web WebNN EP.

This PR depends on huggingface#596
  • Loading branch information
Honry committed Feb 28, 2024
1 parent 2a95f48 commit 0e5737a
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 16 deletions.
4 changes: 2 additions & 2 deletions examples/demo-site/src/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ <h2 class="fw-bolder">Demo</h2>
<label>Task: </label>
<div class="col-12 mt-1">
<select id="task" class="form-select">
<option value="translation" selected>
<option value="translation">
Translation w/ t5-small (78 MB)
</option>
<option value="text-generation">
Expand Down Expand Up @@ -119,7 +119,7 @@ <h2 class="fw-bolder">Demo</h2>
<option value="image-to-text">
Image to text w/ vit-gpt2-image-captioning (246 MB)
</option>
<option value="image-classification">
<option value="image-classification" selected>
Image classification w/ google/vit-base-patch16-224 (88 MB)
</option>
<option value="zero-shot-image-classification">
Expand Down
8 changes: 6 additions & 2 deletions examples/demo-site/src/worker.js
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,15 @@ self.addEventListener('message', async (event) => {
class PipelineFactory {
static task = null;
static model = null;
static quantized = true;

// NOTE: instance stores a promise that resolves to the pipeline
static instance = null;

constructor(tokenizer, model) {
constructor(tokenizer, model, quantized) {
this.tokenizer = tokenizer;
this.model = model;
this.quantized = quantized;
}

/**
Expand All @@ -65,7 +67,8 @@ class PipelineFactory {
}
if (this.instance === null) {
this.instance = pipeline(this.task, this.model, {
progress_callback: progressCallback
progress_callback: progressCallback,
quantized: this.quantized,
});
}

Expand Down Expand Up @@ -131,6 +134,7 @@ class ImageToTextPipelineFactory extends PipelineFactory {
/**
 * Pipeline factory for the demo's image-classification task.
 * Only overrides the static configuration consumed by PipelineFactory.
 */
class ImageClassificationPipelineFactory extends PipelineFactory {
static task = 'image-classification';
static model = 'Xenova/vit-base-patch16-224';
// Overrides the PipelineFactory default (`quantized = true`); the value is
// forwarded to pipeline() as the `quantized` option, so the non-quantized
// (fp32) model is requested for this task.
static quantized = false;
}


Expand Down
4 changes: 2 additions & 2 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,12 @@
},
"homepage": "https://github.com/xenova/transformers.js#readme",
"dependencies": {
"onnxruntime-web": "1.17.0",
"onnxruntime-web": "1.18.0-dev.20240130-9f68a27c7a",
"sharp": "^0.32.0",
"@huggingface/jinja": "^0.1.0"
},
"optionalDependencies": {
"onnxruntime-node": "1.17.0"
"onnxruntime-node": "1.18.0-dev.20240130-9f68a27c7a"
},
"devDependencies": {
"@types/jest": "^29.5.1",
Expand Down
3 changes: 2 additions & 1 deletion src/backends/onnx.js
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,13 @@
// NOTE: Import order matters here. We need to import `onnxruntime-node` before `onnxruntime-web`.
// In either case, we select the default export if it exists, otherwise we use the named export.
import * as ONNX_NODE from 'onnxruntime-node';
import * as ONNX_WEB from 'onnxruntime-web';
import * as ONNX_WEB from 'onnxruntime-web/experimental';

/** @type {import('onnxruntime-web')} The ONNX runtime module. */
export let ONNX;

export const executionProviders = [
// 'webnn',
// 'webgpu',
'wasm'
];
Expand Down
5 changes: 3 additions & 2 deletions src/env.js
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,9 @@ const localModelPath = RUNNING_LOCALLY
// In practice, users should probably self-host the necessary .wasm files.
onnx_env.wasm.wasmPaths = RUNNING_LOCALLY
? path.join(__dirname, '/dist/')
: `https://cdn.jsdelivr.net/npm/@xenova/transformers@${VERSION}/dist/`;

// : `https://cdn.jsdelivr.net/npm/@xenova/transformers@${VERSION}/dist/`;
// Copy ort-web wasm files to examples/demo-site/src/dist/
: location.origin + '/dist/';

/**
* Global variable used to control execution. This provides users a simple way to configure Transformers.js.
Expand Down
30 changes: 23 additions & 7 deletions src/models.js
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,25 @@ async function constructSession(pretrained_model_name_or_path, fileName, options
let buffer = await getModelFile(pretrained_model_name_or_path, modelFileName, true, options);

try {
return await InferenceSession.create(buffer, {
executionProviders,
});
let sessionOptions = { executionProviders };
if (pretrained_model_name_or_path == 'Xenova/vit-base-patch16-224') {
// Hard code example to use webnn for Xenova/vit-base-patch16-224
sessionOptions = {
executionProviders: [{
name: "webnn",
deviceType: "gpu",
}],
// input name: pixel_values, tensor: float32[batch_size,num_channels,height,width]
// WebNN only supports static-shape models; use the freeDimensionOverrides option to fix the input shape.
freeDimensionOverrides: {
batch_size: 1,
num_channels: 3,
height: 224,
width: 224,
},
}
}
return await InferenceSession.create(buffer, sessionOptions);
} catch (err) {
// If the only execution provider was wasm, throw the error
if (executionProviders.length === 1 && executionProviders[0] === 'wasm') {
Expand Down Expand Up @@ -205,13 +221,13 @@ async function sessionRun(session, inputs) {
try {
// pass the original ort tensor
const ortFeed = Object.fromEntries(Object.entries(checkedInputs).map(([k, v]) => [k, v.ort_tensor]));
let output = await session.run(ortFeed);
let output = await session.run(ortFeed);
output = replaceTensors(output);
for (const [name, t] of Object.entries(checkedInputs)) {
// if we use gpu buffers for kv_caches, we own them and need to dispose()
if (name.startsWith('past_key_values')) {
t.dispose();
};
// if (name.startsWith('past_key_values')) {
// t.dispose();
// };
}
return output;
} catch (e) {
Expand Down

0 comments on commit 0e5737a

Please sign in to comment.