diff --git a/README.md b/README.md
index 4eb0c160..2286326a 100644
--- a/README.md
+++ b/README.md
@@ -65,8 +65,49 @@ async function main() {
 main();
 ```
 
-Finally, you can find a complete
-You can also find a complete chat app in [examples/simple-chat](examples/simple-chat/).
+### Using Web Worker
+
+WebLLM comes with API support for WebWorker, so you can
+offload the generation process into a separate worker thread
+and keep the compute there from disrupting the UI.
+
+We first create a worker script that creates a ChatModule and
+hooks it up to a handler that processes requests.
+
+```typescript
+// worker.ts
+import { ChatWorkerHandler, ChatModule } from "@mlc-ai/web-llm";
+
+// Hookup a chat module to a worker handler
+const chat = new ChatModule();
+const handler = new ChatWorkerHandler(chat);
+self.onmessage = (msg: MessageEvent) => {
+  handler.onmessage(msg);
+};
+```
+
+Then in the main logic, we create a `ChatWorkerClient` that
+implements the same `ChatInterface`. The rest of the logic remains the same.
+
+```typescript
+// main.ts
+import * as webllm from "@mlc-ai/web-llm";
+
+async function main() {
+  // Use a chat worker client instead of ChatModule here
+  const chat = new webllm.ChatWorkerClient(new Worker(
+    new URL('./worker.ts', import.meta.url),
+    {type: 'module'}
+  ));
+  // everything else remains the same
+}
+```
+
+
+### Build a ChatApp
+
+You can find a complete chat app example in [examples/simple-chat](examples/simple-chat/).
+
 
 ## Customized Model Weights
diff --git a/examples/README.md b/examples/README.md
index 57998bcd..8b83ddd1 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -6,5 +6,5 @@ Please send a pull request if you find things that belongs to here.
 ## Tutorial Examples
 
 - [get-started](get-started): minimum get started example.
+- [web-worker](web-worker): get started with web-worker-backed chat.
 - [simple-chat](simple-chat): a mininum and complete chat app.
-
diff --git a/examples/get-started/README.md b/examples/get-started/README.md
index 9896513e..b872605f 100644
--- a/examples/get-started/README.md
+++ b/examples/get-started/README.md
@@ -7,7 +7,7 @@ To try it out, you can do the following steps
 - `@mlc-ai/web-llm` points to a valid npm version e.g.
   ```js
   "dependencies": {
-    "@mlc-ai/web-llm": "^0.1.3"
+    "@mlc-ai/web-llm": "^0.2.0"
   }
   ```
 Try this option if you would like to use WebLLM without building it yourself.
diff --git a/examples/simple-chat/README.md b/examples/simple-chat/README.md
index 461c6488..18f38a72 100644
--- a/examples/simple-chat/README.md
+++ b/examples/simple-chat/README.md
@@ -7,7 +7,7 @@ chat app based on WebLLM. To try it out, you can do the following steps
 - Option 1: `@mlc-ai/web-llm` points to a valid npm version e.g.
   ```js
   "dependencies": {
-    "@mlc-ai/web-llm": "^0.1.3"
+    "@mlc-ai/web-llm": "^0.2.0"
   }
   ```
 Try this option if you would like to use WebLLM.
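For reference, "everything else remains the same" in the README's `main()` above can follow the same flow as the get-started example. A minimal sketch, mirroring `examples/web-worker/src/main.ts` further down in this diff (the model id and prompt are the ones used in the examples; adapt them to your app):

```typescript
// main.ts -- sketch of the remainder of main(), following the web-worker example
import * as webllm from "@mlc-ai/web-llm";

async function main() {
  const chat = new webllm.ChatWorkerClient(new Worker(
    new URL('./worker.ts', import.meta.url),
    {type: 'module'}
  ));
  // init progress reports are forwarded from the worker thread
  chat.setInitProgressCallback((report: webllm.InitProgressReport) => {
    console.log(report.text);
  });
  await chat.reload("vicuna-v1-7b-q4f32_0");
  // the streaming callback runs on the main thread while the
  // actual generation happens inside the worker
  const reply = await chat.generate(
    "What is the capital of Canada?",
    (_step: number, message: string) => console.log(message)
  );
  console.log(reply);
  console.log(await chat.runtimeStatsText());
}

main();
```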
diff --git a/examples/simple-chat/src/gh-config.js b/examples/simple-chat/src/gh-config.js
index 3205ef2d..0a971502 100644
--- a/examples/simple-chat/src/gh-config.js
+++ b/examples/simple-chat/src/gh-config.js
@@ -18,5 +18,6 @@ export default {
     "vicuna-v1-7b-q4f32_0": "https://raw.githubusercontent.com/mlc-ai/binary-mlc-llm-libs/main/vicuna-v1-7b-q4f32_0-webgpu.wasm",
     "RedPajama-INCITE-Chat-3B-v1-q4f32_0": "https://raw.githubusercontent.com/mlc-ai/binary-mlc-llm-libs/main/RedPajama-INCITE-Chat-3B-v1-q4f32_0-webgpu.wasm",
     "RedPajama-INCITE-Chat-3B-v1-q4f16_0": "https://raw.githubusercontent.com/mlc-ai/binary-mlc-llm-libs/main/RedPajama-INCITE-Chat-3B-v1-q4f16_0-webgpu.wasm"
-  }
+  },
+  "use_web_worker": true
 }
diff --git a/examples/simple-chat/src/llm_chat.html b/examples/simple-chat/src/llm_chat.html
index f14e04f3..065a76be 100644
--- a/examples/simple-chat/src/llm_chat.html
+++ b/examples/simple-chat/src/llm_chat.html
@@ -1,7 +1,6 @@
-
diff --git a/examples/simple-chat/src/mlc-local-config.js b/examples/simple-chat/src/mlc-local-config.js
index e0d9ce4a..8fbe9b83 100644
--- a/examples/simple-chat/src/mlc-local-config.js
+++ b/examples/simple-chat/src/mlc-local-config.js
@@ -22,5 +22,6 @@ export default {
     "vicuna-v1-7b-q4f32_0": "http://localhost:8000/vicuna-v1-7b-q4f32_0/vicuna-v1-7b-q4f32_0-webgpu.wasm",
     "RedPajama-INCITE-Chat-3B-v1-q4f32_0": "http://localhost:8000/RedPajama-INCITE-Chat-3B-v1-q4f32_0/RedPajama-INCITE-Chat-3B-v1-q4f32_0-webgpu.wasm",
     "RedPajama-INCITE-Chat-3B-v1-q4f16_0": "http://localhost:8000/RedPajama-INCITE-Chat-3B-v1-q4f16_0/RedPajama-INCITE-Chat-3B-v1-q4f16_0-webgpu.wasm"
-  }
+  },
+  "use_web_worker": true
 }
diff --git a/examples/simple-chat/src/simple_chat.ts b/examples/simple-chat/src/simple_chat.ts
index c98c35ab..618c6595 100644
--- a/examples/simple-chat/src/simple_chat.ts
+++ b/examples/simple-chat/src/simple_chat.ts
@@ -1,5 +1,5 @@
 import appConfig from "./app-config";
-import { ChatModule, ModelRecord } from "@mlc-ai/web-llm";
+import { ChatInterface, ChatModule, ChatWorkerClient, ModelRecord } from "@mlc-ai/web-llm";
 
 function getElementAndCheck(id: string): HTMLElement {
   const element = document.getElementById(id);
@@ -18,7 +18,7 @@ class ChatUI {
   private uiChat: HTMLElement;
   private uiChatInput: HTMLInputElement;
   private uiChatInfoLabel: HTMLLabelElement;
-  private chat: ChatModule;
+  private chat: ChatInterface;
   private config: AppConfig = appConfig;
   private selectedModel: string;
   private chatLoaded = false;
@@ -27,8 +27,9 @@ class ChatUI {
   // all requests send to chat are sequentialized
   private chatRequestChain: Promise<void> = Promise.resolve();
 
-  constructor() {
-    this.chat = new ChatModule();
+  constructor(chat: ChatInterface) {
+    // the chat module or worker client that runs generation
+    this.chat = chat;
     // get the elements
     this.uiChat = getElementAndCheck("chatui-chat");
     this.uiChatInput = getElementAndCheck("chatui-input") as HTMLInputElement;
@@ -156,9 +157,10 @@ class ChatUI {
   private resetChatHistory() {
     const clearTags = ["left", "right", "init", "error"];
     for (const tag of clearTags) {
-      const matches = this.uiChat.getElementsByClassName(`msg ${tag}-msg`);
+      // unpack into an array so iteration is not affected by the removals
+      const matches = [...this.uiChat.getElementsByClassName(`msg ${tag}-msg`)];
       for (const item of matches) {
-        item.remove();
+        this.uiChat.removeChild(item);
       }
     }
     if (this.uiChatInfoLabel !== undefined) {
@@ -211,11 +213,6 @@ class ChatUI {
     this.appendMessage("left", "");
     const callbackUpdateResponse = (step, msg) => {
-      if (msg.endsWith("##")) {
-        msg = msg.substring(0, msg.length - 2);
-      } else if (msg.endsWith("#")) {
-        msg = msg.substring(0, msg.length - 1);
-      }
       this.updateLastMessage("left", msg);
     };
 
@@ -233,4 +230,15 @@ class ChatUI {
   }
 }
 
-new ChatUI();
+const useWebWorker = appConfig.use_web_worker;
+let chat: ChatInterface;
+
+if (useWebWorker) {
+  chat = new ChatWorkerClient(new Worker(
+    new URL('./worker.ts', import.meta.url),
+    {type: 'module'}
+  ));
+} else {
+  chat = new ChatModule();
+}
+new ChatUI(chat);
diff --git a/examples/simple-chat/src/worker.ts b/examples/simple-chat/src/worker.ts
new file mode 100644
index 00000000..5495c13d
--- /dev/null
+++ b/examples/simple-chat/src/worker.ts
@@ -0,0 +1,8 @@
+// Serve the chat workload through a web worker
+import { ChatWorkerHandler, ChatModule } from "@mlc-ai/web-llm";
+
+const chat = new ChatModule();
+const handler = new ChatWorkerHandler(chat);
+self.onmessage = (msg: MessageEvent) => {
+  handler.onmessage(msg);
+};
diff --git a/examples/web-worker/README.md b/examples/web-worker/README.md
new file mode 100644
index 00000000..c6e00d62
--- /dev/null
+++ b/examples/web-worker/README.md
@@ -0,0 +1,25 @@
+# WebLLM Get Started with WebWorker
+
+This folder provides a minimal demo that shows the WebLLM API using
+[WebWorker](https://developer.mozilla.org/en-US/docs/Web/API/Web_Workers_API/Using_web_workers).
+The main benefit of a web worker is that all ML workloads run on a separate thread,
+so they are much less likely to block the UI.
+
+To try it out, you can do the following steps
+
+- Modify [package.json](package.json) to make sure either
+  - `@mlc-ai/web-llm` points to a valid npm version e.g.
+    ```js
+    "dependencies": {
+      "@mlc-ai/web-llm": "^0.2.0"
+    }
+    ```
+    Try this option if you would like to use WebLLM without building it yourself.
+  - Or keep the dependencies as `"file:../.."`, and follow the build-from-source
+    instructions in the project to build WebLLM locally. This option is more useful
+    for developers who would like to hack on the WebLLM core package.
+- Run the following command
+  ```bash
+  npm install
+  npm start
+  ```
diff --git a/examples/web-worker/package.json b/examples/web-worker/package.json
new file mode 100644
index 00000000..424756dd
--- /dev/null
+++ b/examples/web-worker/package.json
@@ -0,0 +1,17 @@
+{
+  "name": "get-started-web-worker",
+  "version": "0.1.0",
+  "private": true,
+  "scripts": {
+    "start": "parcel src/get_started.html --port 8888",
+    "build": "parcel build src/get_started.html --dist-dir lib"
+  },
+  "devDependencies": {
+    "parcel": "^2.8.3",
+    "typescript": "^4.9.5",
+    "tslib": "^2.3.1"
+  },
+  "dependencies": {
+    "@mlc-ai/web-llm": "file:../.."
+  }
+}
diff --git a/examples/web-worker/src/get_started.html b/examples/web-worker/src/get_started.html
new file mode 100644
index 00000000..a376ef62
--- /dev/null
+++ b/examples/web-worker/src/get_started.html
@@ -0,0 +1,22 @@
+<!DOCTYPE html>
+<html>
+  <head>
+  </head>
+  <body>
+    <h2>WebLLM Test Page</h2>
+    Open console to see output
+    <br>
+    <br>
+    <label id="init-label"> </label>
+
+    <h3>Prompt</h3>
+    <label id="prompt-label"> </label>
+
+    <h3>Response</h3>
+    <label id="generate-label"> </label>
+    <br>
+    <label id="stats-label"> </label>
+
+    <script type="module" src="./main.ts"></script>
+  </body>
+</html>
diff --git a/examples/web-worker/src/main.ts b/examples/web-worker/src/main.ts
new file mode 100644
index 00000000..09812377
--- /dev/null
+++ b/examples/web-worker/src/main.ts
@@ -0,0 +1,41 @@
+import * as webllm from "@mlc-ai/web-llm";
+
+function setLabel(id: string, text: string) {
+  const label = document.getElementById(id);
+  if (label == null) {
+    throw Error("Cannot find label " + id);
+  }
+  label.innerText = text;
+}
+
+async function main() {
+  // Use a chat worker client instead of ChatModule here
+  const chat = new webllm.ChatWorkerClient(new Worker(
+    new URL('./worker.ts', import.meta.url),
+    {type: 'module'}
+  ));
+
+  chat.setInitProgressCallback((report: webllm.InitProgressReport) => {
+    setLabel("init-label", report.text);
+  });
+
+  await chat.reload("vicuna-v1-7b-q4f32_0");
+
+  const generateProgressCallback = (_step: number, message: string) => {
+    setLabel("generate-label", message);
+  };
+
+  const prompt0 = "What is the capital of Canada?";
+  setLabel("prompt-label", prompt0);
+  const reply0 = await chat.generate(prompt0, generateProgressCallback);
+  console.log(reply0);
+
+  const prompt1 = "Can you write a poem about it?";
+  setLabel("prompt-label", prompt1);
+  const reply1 = await chat.generate(prompt1, generateProgressCallback);
+  console.log(reply1);
+
+  console.log(await chat.runtimeStatsText());
+}
+
+main();
diff --git a/examples/web-worker/src/worker.ts b/examples/web-worker/src/worker.ts
new file mode 100644
index 00000000..d480aec0
--- /dev/null
+++ b/examples/web-worker/src/worker.ts
@@ -0,0 +1,8 @@
+import { ChatWorkerHandler, ChatModule } from "@mlc-ai/web-llm";
+
+// Hookup a chat module to a worker handler
+const chat = new ChatModule();
+const handler = new ChatWorkerHandler(chat);
+self.onmessage = (msg: MessageEvent) => {
+  handler.onmessage(msg);
+};
diff --git a/package-lock.json b/package-lock.json
index a3581895..3ff014f6 100644
--- a/package-lock.json
+++ b/package-lock.json
@@ -1,12 +1,12 @@
 {
   "name": "@mlc-ai/web-llm",
-  "version": "0.1.2",
+  "version": "0.2.0",
   "lockfileVersion": 3,
   "requires": true,
   "packages": {
     "": {
       "name": "@mlc-ai/web-llm",
-      "version": "0.1.2",
+      "version": "0.2.0",
       "license": "Apache-2.0",
       "devDependencies": {
         "@mlc-ai/web-tokenizers": "^0.1.0",
diff --git a/package.json b/package.json
index 34d2c20e..e93fc9d6 100644
--- a/package.json
+++ b/package.json
@@ -1,6 +1,6 @@
 {
   "name": "@mlc-ai/web-llm",
-  "version": "0.1.3",
+  "version": "0.2.0",
   "description": "Hardware accelerated language model chats on browsers",
   "main": "lib/index.js",
   "types": "lib/index.d.ts",
diff --git a/src/index.ts b/src/index.ts
index 31fed193..3a2f1ec9 100644
--- a/src/index.ts
+++ b/src/index.ts
@@ -13,3 +13,8 @@ export {
 export {
   ChatModule,
 } from "./chat_module";
+
+export {
+  ChatWorkerHandler,
+  ChatWorkerClient
+} from "./web_worker";
diff --git a/src/web_worker.ts b/src/web_worker.ts
new file mode 100644
index 00000000..2e95260f
--- /dev/null
+++ b/src/web_worker.ts
@@ -0,0 +1,327 @@
+import { AppConfig } from "./config";
+import {
+  ChatInterface,
+  ChatOptions,
+  GenerateProgressCallback,
+  InitProgressCallback,
+  InitProgressReport
+} from "./types";
+
+/**
+ * Message kind used by the worker
+ */
+type RequestKind = (
+  "return" | "throw" |
+  "reload" | "generate" | "runtimeStatsText" |
+  "interruptGenerate" | "unload" | "resetChat" |
+  "initProgressCallback" | "generateProgressCallback"
+);
+
+interface ReloadParams {
+  localIdOrUrl: string;
+  chatOpts?: ChatOptions;
+  appConfig?: AppConfig;
+}
+
+interface GenerateParams {
+  input: string,
+  streamInterval?: number;
+}
+
+interface GenerateProgressCallbackParams {
+  step: number,
+  currentMessage: string;
+}
+
+type MessageContent =
+  GenerateProgressCallbackParams |
+  ReloadParams |
+  GenerateParams |
+  InitProgressReport |
+  string |
+  null;
+
+/**
+ * The message used in exchange between worker
+ * and the main thread.
+ */
+interface WorkerMessage {
+  kind: RequestKind,
+  uuid: string,
+  content: MessageContent;
+}
+
+/**
+ * Worker handler that can be used in a WebWorker
+ *
+ * @example
+ *
+ * // set up a chat worker handler that routes
+ * // requests to the chat
+ * const chat = new ChatModule();
+ * const handler = new ChatWorkerHandler(chat);
+ * onmessage = msg => handler.onmessage(msg);
+ */
+export class ChatWorkerHandler {
+  private chat: ChatInterface;
+
+  constructor(chat: ChatInterface) {
+    this.chat = chat;
+    this.chat.setInitProgressCallback((report: InitProgressReport) => {
+      const msg: WorkerMessage = {
+        kind: "initProgressCallback",
+        uuid: "",
+        content: report
+      };
+      postMessage(msg);
+    });
+  }
+
+  async handleTask<T extends MessageContent>(uuid: string, task: () => Promise<T>) {
+    try {
+      const res = await task();
+      const msg: WorkerMessage = {
+        kind: "return",
+        uuid: uuid,
+        content: res
+      };
+      postMessage(msg);
+    } catch (err) {
+      const errStr = (err as object).toString();
+      const msg: WorkerMessage = {
+        kind: "throw",
+        uuid: uuid,
+        content: errStr
+      };
+      postMessage(msg);
+    }
+  }
+
+  onmessage(event: MessageEvent) {
+    const msg = event.data as WorkerMessage;
+    switch (msg.kind) {
+      case "reload": {
+        this.handleTask(msg.uuid, async () => {
+          const params = msg.content as ReloadParams;
+          await this.chat.reload(params.localIdOrUrl, params.chatOpts, params.appConfig);
+          return null;
+        })
+        return;
+      }
+      case "generate": {
+        this.handleTask(msg.uuid, async () => {
+          const params = msg.content as GenerateParams;
+          const progressCallback = (step: number, currentMessage: string) => {
+            const cbMessage: WorkerMessage = {
+              kind: "generateProgressCallback",
+              uuid: msg.uuid,
+              content: {
+                step: step,
+                currentMessage: currentMessage
+              }
+            };
+            postMessage(cbMessage);
+          };
+          return await this.chat.generate(params.input, progressCallback, params.streamInterval);
+        })
+        return;
+      }
+      case "runtimeStatsText": {
+        this.handleTask(msg.uuid, async () => {
+          return await this.chat.runtimeStatsText();
+        });
+        return;
+      }
+      case "interruptGenerate": {
+        this.handleTask(msg.uuid, async () => {
+          this.chat.interruptGenerate();
+          return null;
+        });
+        return;
+      }
+      case "unload": {
+        this.handleTask(msg.uuid, async () => {
+          await this.chat.unload();
+          return null;
+        });
+        return;
+      }
+      case "resetChat": {
+        this.handleTask(msg.uuid, async () => {
+          await this.chat.resetChat();
+          return null;
+        });
+        return;
+      }
+      default: {
+        throw Error("Invalid kind, msg=" + msg);
+      }
+    }
+  }
+}
+
+interface ChatWorker {
+  onmessage: any,
+  postMessage: (message: any) => void;
+}
+
+/**
+ * A client of chat worker that exposes the chat interface
+ *
+ * @example
+ *
+ * const chat = new webllm.ChatWorkerClient(new Worker(
+ *   new URL('./worker.ts', import.meta.url),
+ *   {type: 'module'}
+ * ));
+ */
+export class ChatWorkerClient implements ChatInterface {
+  public worker: ChatWorker;
+  private initProgressCallback?: InitProgressCallback;
+  private generateCallbackRegistry = new Map<string, GenerateProgressCallback>();
+  private pendingPromise = new Map<string, (msg: WorkerMessage) => void>();
+
+  constructor(worker: any) {
+    this.worker = worker;
+    worker.onmessage = (event: any) => {
+      this.onmessage(event);
+    }
+  }
+
+  setInitProgressCallback(initProgressCallback: InitProgressCallback) {
+    this.initProgressCallback = initProgressCallback;
+  }
+
+  private getPromise<T extends MessageContent>(msg: WorkerMessage): Promise<T> {
+    const uuid = msg.uuid;
+    const executor = (
+      resolve: (arg: T) => void,
+      reject: (arg: any) => void
+    ) => {
+      const cb = (msg: WorkerMessage) => {
+        if (msg.kind == "return") {
+          resolve(msg.content as T);
+        } else {
+          if (msg.kind != "throw") {
+            reject("Unknown msg kind " + msg.kind);
+          } else {
+            reject(msg.content);
+          }
+        }
+      };
+      this.pendingPromise.set(uuid, cb);
+    };
+    const promise = new Promise<T>(executor);
+    this.worker.postMessage(msg);
+    return promise;
+  }
+
+  async reload(localIdOrUrl: string, chatOpts?: ChatOptions, appConfig?: AppConfig): Promise<void> {
+    const msg: WorkerMessage = {
+      kind: "reload",
+      uuid: crypto.randomUUID(),
+      content: {
+        localIdOrUrl: localIdOrUrl,
+        chatOpts: chatOpts,
+        appConfig: appConfig
+      }
+    };
+    await this.getPromise<null>(msg);
+  }
+
+  async generate(
+    input: string,
+    progressCallback?: GenerateProgressCallback,
+    streamInterval?: number
+  ): Promise<string> {
+    const msg: WorkerMessage = {
+      kind: "generate",
+      uuid: crypto.randomUUID(),
+      content: {
+        input: input,
+        streamInterval: streamInterval
+      }
+    };
+    if (progressCallback !== undefined) {
+      this.generateCallbackRegistry.set(msg.uuid, progressCallback);
+    }
+    return await this.getPromise<string>(msg);
+  }
+
+  async runtimeStatsText(): Promise<string> {
+    const msg: WorkerMessage = {
+      kind: "runtimeStatsText",
+      uuid: crypto.randomUUID(),
+      content: null
+    };
+    return await this.getPromise<string>(msg);
+  }
+
+  interruptGenerate(): void {
+    const msg: WorkerMessage = {
+      kind: "interruptGenerate",
+      uuid: crypto.randomUUID(),
+      content: null
+    };
+    this.getPromise<null>(msg);
+  }
+
+  async unload(): Promise<void> {
+    const msg: WorkerMessage = {
+      kind: "unload",
+      uuid: crypto.randomUUID(),
+      content: null
+    };
+    await this.getPromise<null>(msg);
+  }
+
+  async resetChat(): Promise<void> {
+    const msg: WorkerMessage = {
+      kind: "resetChat",
+      uuid: crypto.randomUUID(),
+      content: null
+    };
+    await this.getPromise<null>(msg);
+  }
+
+  onmessage(event: any) {
+    const msg = event.data as WorkerMessage;
+    switch (msg.kind) {
+      case "initProgressCallback": {
+        if (this.initProgressCallback !== undefined) {
+          this.initProgressCallback(msg.content as InitProgressReport);
+        }
+        return;
+      }
+      case "generateProgressCallback": {
+        const params = msg.content as GenerateProgressCallbackParams;
+        const cb = this.generateCallbackRegistry.get(msg.uuid);
+        if (cb !== undefined) {
+          cb(params.step, params.currentMessage);
+        }
+        return;
+      }
+      case "return": {
+        const cb = this.pendingPromise.get(msg.uuid);
+        if (cb === undefined) {
+          throw Error("return from an unknown uuid msg=" + msg.uuid);
+        }
+        this.pendingPromise.delete(msg.uuid);
+        cb(msg);
+        return;
+      }
+      case "throw": {
+        const cb = this.pendingPromise.get(msg.uuid);
+        if (cb === undefined) {
+          throw Error("return from an unknown uuid, msg=" + msg);
+        }
+        this.pendingPromise.delete(msg.uuid);
+        cb(msg);
+        return;
+      }
+      default: {
+        throw Error("Unknown msg kind, msg=" + msg);
+      }
+    }
+  }
+}
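Since every `ChatInterface` call goes through the uuid-tagged `WorkerMessage` round trip above, interruption also works across the thread boundary. A minimal sketch of driving `ChatWorkerClient` end to end (model id as in the examples; the timeout is arbitrary, and the expectation that `generate` settles with the partial output after an interrupt follows from the handler resolving the pending request once `chat.generate` returns inside the worker):

```typescript
import * as webllm from "@mlc-ai/web-llm";

const chat = new webllm.ChatWorkerClient(new Worker(
  new URL('./worker.ts', import.meta.url),
  {type: 'module'}
));

async function demo() {
  await chat.reload("vicuna-v1-7b-q4f32_0");
  // generation streams progress back via "generateProgressCallback" messages
  const reply = chat.generate("Write a long story.", (step: number, msg: string) => {
    console.log(`step ${step}: ${msg.length} chars so far`);
  });
  // interruptGenerate() posts its own uuid-tagged request; the pending
  // generate() promise then resolves on the worker's "return" message
  setTimeout(() => chat.interruptGenerate(), 3000);
  console.log(await reply);
  await chat.unload();
}

demo();
```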