# Using quantized models (dtypes)

Before Transformers.js v3, we used the `quantized` option to specify whether to use a quantized (q8) or full-precision (fp32) variant of the model by setting `quantized` to `true` or `false`, respectively. Now, we've added the ability to select from a much larger list with the `dtype` parameter.

The list of available quantizations depends on the model, but some common ones are: full-precision (`"fp32"`), half-precision (`"fp16"`), 8-bit (`"q8"`, `"int8"`, `"uint8"`), and 4-bit (`"q4"`, `"bnb4"`, `"q4f16"`).

<p align="center">
  <picture>
    <source media="(prefers-color-scheme: dark)" srcset="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/transformersjs-v3/dtypes-dark.jpg" style="max-width: 100%;">
    <source media="(prefers-color-scheme: light)" srcset="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/transformersjs-v3/dtypes-light.jpg" style="max-width: 100%;">
    <img alt="Available dtypes for mixedbread-ai/mxbai-embed-xsmall-v1" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/transformersjs-v3/dtypes-dark.jpg" style="max-width: 100%;">
  </picture>
  <a href="https://huggingface.co/mixedbread-ai/mxbai-embed-xsmall-v1/tree/main/onnx">(e.g., mixedbread-ai/mxbai-embed-xsmall-v1)</a>
</p>

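To migrate from the old boolean `quantized` option, you simply replace it with a `dtype` string. A minimal before/after sketch (the model shown here is just an illustration):

```js
import { pipeline } from "@huggingface/transformers";

// Transformers.js v2 (old API): boolean `quantized` option
// const extractor = await pipeline("feature-extraction", "Xenova/all-MiniLM-L6-v2", { quantized: true });

// Transformers.js v3: select a specific quantization with `dtype`
const extractor = await pipeline(
  "feature-extraction",
  "Xenova/all-MiniLM-L6-v2",
  { dtype: "q8" }, // roughly equivalent to the old `quantized: true`
);
```
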
## Basic usage

**Example:** Run Qwen2.5-0.5B-Instruct in 4-bit quantization ([demo](https://v2.scrimba.com/s0dlcpv0ci))

```js
import { pipeline } from "@huggingface/transformers";

// Create a text generation pipeline
const generator = await pipeline(
  "text-generation",
  "onnx-community/Qwen2.5-0.5B-Instruct",
  { dtype: "q4", device: "webgpu" },
);

// Define the list of messages
const messages = [
  { role: "system", content: "You are a helpful assistant." },
  { role: "user", content: "Tell me a funny joke." },
];

// Generate a response
const output = await generator(messages, { max_new_tokens: 128 });
console.log(output[0].generated_text.at(-1).content);
```

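If you'd like to print tokens as they are generated instead of waiting for the full response, you can pass a streamer. A short sketch, assuming the `TextStreamer` utility exported by the package:

```js
import { TextStreamer } from "@huggingface/transformers";

// Stream the response token-by-token to the console as it is generated
const streamer = new TextStreamer(generator.tokenizer, { skip_prompt: true });
await generator(messages, { max_new_tokens: 128, streamer });
```
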
## Per-module dtypes

Some encoder-decoder models, like Whisper or Florence-2, are extremely sensitive to quantization settings, especially for the encoder. For this reason, we added the ability to select per-module dtypes by providing a mapping from module name to dtype.

**Example:** Run Florence-2 on WebGPU ([demo](https://v2.scrimba.com/s0pdm485fo))

```js
import { Florence2ForConditionalGeneration } from "@huggingface/transformers";

const model = await Florence2ForConditionalGeneration.from_pretrained(
  "onnx-community/Florence-2-base-ft",
  {
    dtype: {
      embed_tokens: "fp16",
      vision_encoder: "fp16",
      encoder_model: "q4",
      decoder_model_merged: "q4",
    },
    device: "webgpu",
  },
);
```

<p align="middle">
  <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/transformersjs-v3/florence-2-webgpu.gif" alt="Florence-2 running on WebGPU" />
</p>

<details>
<summary>
See full code example
</summary>

```js
import {
  Florence2ForConditionalGeneration,
  AutoProcessor,
  AutoTokenizer,
  RawImage,
} from "@huggingface/transformers";

// Load model, processor, and tokenizer
const model_id = "onnx-community/Florence-2-base-ft";
const model = await Florence2ForConditionalGeneration.from_pretrained(
  model_id,
  {
    dtype: {
      embed_tokens: "fp16",
      vision_encoder: "fp16",
      encoder_model: "q4",
      decoder_model_merged: "q4",
    },
    device: "webgpu",
  },
);
const processor = await AutoProcessor.from_pretrained(model_id);
const tokenizer = await AutoTokenizer.from_pretrained(model_id);

// Load image and prepare vision inputs
const url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg";
const image = await RawImage.fromURL(url);
const vision_inputs = await processor(image);

// Specify task and prepare text inputs
const task = "<MORE_DETAILED_CAPTION>";
const prompts = processor.construct_prompts(task);
const text_inputs = tokenizer(prompts);

// Generate text
const generated_ids = await model.generate({
  ...text_inputs,
  ...vision_inputs,
  max_new_tokens: 100,
});

// Decode generated text
const generated_text = tokenizer.batch_decode(generated_ids, {
  skip_special_tokens: false,
})[0];

// Post-process the generated text
const result = processor.post_process_generation(
  generated_text,
  task,
  image.size,
);
console.log(result);
// { '<MORE_DETAILED_CAPTION>': 'A green car is parked in front of a tan building. The building has a brown door and two brown windows. The car is a two door and the door is closed. The green car has black tires.' }
```

</details>

# Running models on WebGPU

WebGPU is a new web standard for accelerated graphics and compute. The [API](https://developer.mozilla.org/en-US/docs/Web/API/WebGPU_API) enables web developers to use the underlying system's GPU to carry out high-performance computations directly in the browser. WebGPU is the successor to [WebGL](https://developer.mozilla.org/en-US/docs/Web/API/WebGL_API) and provides significantly better performance, because it allows for more direct interaction with modern GPUs. Lastly, it supports general-purpose GPU computations, which makes it just perfect for machine learning!

> [!WARNING]
> As of October 2024, global WebGPU support is around 70% (according to [caniuse.com](https://caniuse.com/webgpu)), meaning some users may not be able to use the API.
>
> If the following demos do not work in your browser, you may need to enable it using a feature flag:
>
> - Firefox: with the `dom.webgpu.enabled` flag (see [here](https://developer.mozilla.org/en-US/docs/Mozilla/Firefox/Experimental_features#:~:text=tested%20by%20Firefox.-,WebGPU%20API,-The%20WebGPU%20API)).
> - Safari: with the `WebGPU` feature flag (see [here](https://webkit.org/blog/14879/webgpu-now-available-for-testing-in-safari-technology-preview/)).
> - Older Chromium browsers (on Windows, macOS, Linux): with the `enable-unsafe-webgpu` flag (see [here](https://developer.chrome.com/docs/web-platform/webgpu/troubleshooting-tips)).

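You can also detect WebGPU support at runtime before loading a model. A minimal sketch using the standard `navigator.gpu` API (plain browser JavaScript, not part of Transformers.js):

```js
// Check whether the browser exposes WebGPU and a usable GPU adapter
if (navigator.gpu) {
  const adapter = await navigator.gpu.requestAdapter();
  console.log(adapter ? "WebGPU is available" : "No suitable GPU adapter found");
} else {
  console.log("WebGPU is not supported in this browser");
}
```
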
## Usage in Transformers.js v3

Thanks to our collaboration with [ONNX Runtime Web](https://www.npmjs.com/package/onnxruntime-web), enabling WebGPU acceleration is as simple as setting `device: 'webgpu'` when loading a model. Let's see some examples!

**Example:** Compute text embeddings on WebGPU ([demo](https://v2.scrimba.com/s06a2smeej))

```js
import { pipeline } from "@huggingface/transformers";

// Create a feature-extraction pipeline
const extractor = await pipeline(
  "feature-extraction",
  "mixedbread-ai/mxbai-embed-xsmall-v1",
  { device: "webgpu" },
);

// Compute embeddings
const texts = ["Hello world!", "This is an example sentence."];
const embeddings = await extractor(texts, { pooling: "mean", normalize: true });
console.log(embeddings.tolist());
// [
//   [-0.016986183822155, 0.03228696808218956, -0.0013630966423079371, ... ],
//   [0.09050482511520386, 0.07207386940717697, 0.05762749910354614, ... ],
// ]
```

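Because the embeddings are mean-pooled and normalized, a plain dot product between two rows gives their cosine similarity. A small follow-up sketch (not part of the original example):

```js
// Cosine similarity of the two normalized embeddings via a dot product
const [a, b] = embeddings.tolist();
const similarity = a.reduce((sum, value, i) => sum + value * b[i], 0);
console.log(similarity); // 1.0 = identical direction, 0 = unrelated
```
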
**Example:** Perform automatic speech recognition with OpenAI Whisper on WebGPU ([demo](https://v2.scrimba.com/s0oi76h82g))

```js
import { pipeline } from "@huggingface/transformers";

// Create automatic speech recognition pipeline
const transcriber = await pipeline(
  "automatic-speech-recognition",
  "onnx-community/whisper-tiny.en",
  { device: "webgpu" },
);

// Transcribe audio from a URL
const url = "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/jfk.wav";
const output = await transcriber(url);
console.log(output);
// { text: ' And so my fellow Americans ask not what your country can do for you, ask what you can do for your country.' }
```

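The pipeline can also return timestamps alongside the transcription. A short sketch, assuming the `return_timestamps` option of the automatic-speech-recognition pipeline:

```js
// Request chunk-level timestamps in addition to the transcribed text
const result = await transcriber(url, { return_timestamps: true });
console.log(result.chunks); // e.g., [{ timestamp: [start, end], text: '...' }, ...]
```
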
**Example:** Perform image classification with MobileNetV4 on WebGPU ([demo](https://v2.scrimba.com/s0fv2uab1t))

```js
import { pipeline } from "@huggingface/transformers";

// Create image classification pipeline
const classifier = await pipeline(
  "image-classification",
  "onnx-community/mobilenetv4_conv_small.e2400_r224_in1k",
  { device: "webgpu" },
);

// Classify an image from a URL
const url = "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/tiger.jpg";
const output = await classifier(url);
console.log(output);
// [
//   { label: 'tiger, Panthera tigris', score: 0.6149784922599792 },
//   { label: 'tiger cat', score: 0.30281734466552734 },
//   { label: 'tabby, tabby cat', score: 0.0019135422771796584 },
//   { label: 'lynx, catamount', score: 0.0012161266058683395 },
//   { label: 'Egyptian cat', score: 0.0011465961579233408 }
// ]
```

## Reporting bugs and providing feedback

Due to the experimental nature of the WebGPU API, especially in non-Chromium browsers, you may experience issues when running the above demos.