diff --git a/mediapipe/tasks/web/genai/llm_inference/llm_inference.ts b/mediapipe/tasks/web/genai/llm_inference/llm_inference.ts
index 674e17a938..1a9b89aace 100644
--- a/mediapipe/tasks/web/genai/llm_inference/llm_inference.ts
+++ b/mediapipe/tasks/web/genai/llm_inference/llm_inference.ts
@@ -107,6 +107,31 @@ export class LoraModel {
   }
 }
 
+/**
+ * A wrapper around native Promise; exposes functions to resolve or reject it.
+ */
+class Deferred<T> {
+  /** The promise wrapped by this Deferred. */
+  readonly promise: Promise<T>;
+
+  /** Resolve with the provided value. */
+  readonly resolve: (result: T) => void;
+
+  /** Reject with the provided reasons. */
+  readonly reject: (reasons?: Array<Error | GPUError>) => void;
+
+  constructor() {
+    let resolve!: (value: T) => void;
+    let reject!: (reasons?: Array<Error | GPUError>) => void;
+    this.promise = new Promise((res, rej) => {
+      resolve = res;
+      reject = rej;
+    });
+    this.resolve = resolve;
+    this.reject = reject;
+  }
+}
+
 /**
  * Performs LLM Inference on text.
  */
@@ -122,10 +147,42 @@ export class LlmInference extends TaskRunner {
   private readonly samplerParams: SamplerParameters;
   private isProcessing = false;
   private latestTokenCostQueryResult?: number;
-  private resolveGeneration?: (result: string) => void;
+  private resultDeferred?: Deferred<string>;
   private userProgressListener?: ProgressListener;
   private streamingReader?: StreamingReader;
 
+  // The WebGPU device used for LLM inference.
+  private wgpuDevice?: GPUDevice;
+  // Holds WebGPU errors for WebGPU-involved invocations. Should be checked and
+  // cleaned up after each WebGPU-involved invocation.
+  private readonly wgpuErrors: Array<Error | GPUError> = [];
+
+  /**
+   * For each WebGPU's 'uncapturederror' event, hold the error. Also, add hints
+   * into error message, if it's known by the task.
+   */
+  private readonly wgpuErrorHandler = (event: Event) => {
+    let error = (event as GPUUncapturedErrorEvent).error;
+    const bufferSizeError = error.message.match(
+      /exceeds the max buffer size limit \(([0-9]+)\)\./,
+    );
+    if (
+      bufferSizeError &&
+      Number(bufferSizeError[1]) > MAX_BUFFER_SIZE_FOR_LLM
+    ) {
+      error = new Error(
+        `Failed to run this LLM model, but you could try a smaller LLM ` +
+          `model. WebGPU throws: "${error.message}"`,
+      );
+    } else if (error.message.match(/is larger than the maximum binding size/)) {
+      error = new Error(
+        `Failed to run LLM inference, the supported max binding size is ` +
+          `smaller than the required size. WebGPU throws: "${error.message}"`,
+      );
+    }
+    this.wgpuErrors.push(error);
+  };
+
   /**
    * Initializes the Wasm runtime and creates a new `LlmInference` based
    * on the provided options.
@@ -288,8 +345,19 @@ export class LlmInference extends TaskRunner {
     this.isProcessing = true;
 
     if (options.baseOptions?.gpuOptions?.device) {
+      if (this.wgpuDevice) {
+        this.wgpuDevice.removeEventListener(
+          'uncapturederror',
+          this.wgpuErrorHandler,
+        );
+      }
+      this.wgpuDevice = options.baseOptions.gpuOptions.device;
       (this.graphRunner as unknown as LlmGraphRunner).initializeForWebGpu(
-        options.baseOptions.gpuOptions.device,
+        this.wgpuDevice,
+      );
+      this.wgpuDevice.addEventListener(
+        'uncapturederror',
+        this.wgpuErrorHandler,
       );
     }
     if ('maxTokens' in options) {
@@ -376,12 +444,13 @@ export class LlmInference extends TaskRunner {
     // code hasn't quite finished running by this point in time. However, that
     // microtask seems to complete before any code await-ing this function, so
     // this should be fine. This seems to be similarly true for our
-    // resolveGeneration usage as well.
-    return finishedLoadingDataPromise.then(() => {
-      refreshGraphPromise.then(() => {
+    // resultDeferred usage as well.
+    return Promise.all([finishedLoadingDataPromise, refreshGraphPromise]).then(
+      () => {
         this.isProcessing = false;
-      });
-    });
+        this.checkWgpuErrors();
+      },
+    );
   }
 
   protected override get baseOptions(): BaseOptionsProto {
@@ -481,9 +550,8 @@ export class LlmInference extends TaskRunner {
       );
     }
     this.finishProcessing();
-    return new Promise((resolve, reject) => {
-      this.resolveGeneration = resolve;
-    });
+    this.resultDeferred = new Deferred<string>();
+    return this.resultDeferred.promise;
   }
 
   /**
@@ -592,6 +660,22 @@ export class LlmInference extends TaskRunner {
     this.samplerParams.setTemperature(DEFAULT_TEMPERATURE);
   }
 
+  /** Checks if there are any WebGPU errors and throws them if so. */
+  private checkWgpuErrors(): void {
+    if (this.wgpuErrors.length > 0) {
+      // Clean the stack of errors.
+      const errors = [...this.wgpuErrors];
+      this.wgpuErrors.length = 0;
+
+      if (this.resultDeferred) {
+        this.resultDeferred.reject(errors);
+        this.resultDeferred = undefined;
+      } else {
+        throw errors;
+      }
+    }
+  }
+
   // TODO: b/324919242 - Add sync API for BYOM Web API when Chrome JSPI is
   // available
 
@@ -608,7 +692,8 @@ export class LlmInference extends TaskRunner {
           stripLeadingWhitespace,
         );
         this.generationResult.push(decodedText);
-        if (this.userProgressListener) {
+        // Don't trigger the user progress listener if there are WebGPU errors.
+        if (this.userProgressListener && this.wgpuErrors.length === 0) {
           this.userProgressListener(decodedText, /* done= */ false);
         }
         this.setLatestOutputTimestamp(timestamp);
@@ -622,13 +707,15 @@ export class LlmInference extends TaskRunner {
       OUTPUT_END_STREAM,
       (bool, timestamp) => {
         this.isProcessing = false;
-        if (this.resolveGeneration) {
-          this.resolveGeneration(this.generationResult.join(''));
+        this.setLatestOutputTimestamp(timestamp);
+        this.checkWgpuErrors();
+        if (this.resultDeferred) {
+          this.resultDeferred.resolve(this.generationResult.join(''));
+          this.resultDeferred = undefined;
         }
         if (this.userProgressListener) {
           this.userProgressListener(/* partialResult= */ '', /* done= */ true);
         }
-        this.setLatestOutputTimestamp(timestamp);
       },
     );
     this.graphRunner.attachEmptyPacketListener(
@@ -636,6 +723,11 @@ export class LlmInference extends TaskRunner {
       (timestamp) => {
         this.isProcessing = false;
        this.setLatestOutputTimestamp(timestamp);
+        this.checkWgpuErrors();
+        if (this.resultDeferred) {
+          this.resultDeferred.resolve(this.generationResult.join(''));
+          this.resultDeferred = undefined;
+        }
       },
     );
 
@@ -665,9 +757,18 @@ export class LlmInference extends TaskRunner {
     // instead, we use a special async-only variant of closeGraph which we can
     // chain into our promises to ensure proper ordering, calling that first so
     // the built-in closeGraph becomes a no-op.
+    this.wgpuDevice?.removeEventListener(
+      'uncapturederror',
+      this.wgpuErrorHandler,
+    );
     return (this.graphRunner as unknown as LlmGraphRunner)
      .closeGraphAsync()
      .then(() => {
+        this.wgpuDevice?.addEventListener(
+          'uncapturederror',
+          this.wgpuErrorHandler,
+        );
+        this.wgpuErrors.length = 0;
         this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);
         // Start initialization; this is async when StreamingReader is used.
         this.finishProcessing();
@@ -837,6 +938,10 @@ export class LlmInference extends TaskRunner {
   }
 
   override close() {
+    this.wgpuDevice?.removeEventListener(
+      'uncapturederror',
+      this.wgpuErrorHandler,
+    );
     super.close();
   }
 }
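
Caller-facing effect of this patch, sketched below for review: a WebGPU 'uncapturederror' raised while a response is being generated now rejects the promise returned by generateResponse() with the array of captured (and possibly hint-annotated) errors, instead of leaving the returned promise pending. The snippet is illustrative only, not part of the patch; it assumes the public @mediapipe/tasks-genai entry points, and the CDN URL and model path are placeholders.

import {FilesetResolver, LlmInference} from '@mediapipe/tasks-genai';

async function demo() {
  // Placeholder wasm location; use whatever path fits your deployment.
  const genai = await FilesetResolver.forGenAiTasks(
    'https://cdn.jsdelivr.net/npm/@mediapipe/tasks-genai/wasm',
  );
  const llm = await LlmInference.createFromOptions(genai, {
    baseOptions: {modelAssetPath: '/assets/model.bin'}, // placeholder model
  });
  try {
    console.log(await llm.generateResponse('Tell me about WebGPU.'));
  } catch (reasons) {
    // With this change, `reasons` is the Array of errors collected by
    // wgpuErrorHandler, e.g. the buffer-size or binding-size hints above.
    console.error('LLM inference failed:', reasons);
  }
}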