Raise WebGPU errors to JavaScript.
PiperOrigin-RevId: 631569286
MediaPipe Team authored and copybara-github committed May 7, 2024
1 parent 99fc736 commit 28c9032
Showing 1 changed file with 119 additions and 14 deletions: mediapipe/tasks/web/genai/llm_inference/llm_inference.ts
@@ -107,6 +107,31 @@ export class LoraModel {
}
}

/**
* A wrapper around native Promise; exposes functions to resolve or reject it.
*/
class Deferred<T> {
/** The promise wrapped by this Deferred. */
readonly promise: Promise<T>;

/** Resolve with the provided value. */
readonly resolve: (result: T) => void;

/** Reject with the provided reasons. */
readonly reject: (reasons?: Array<Error | GPUError>) => void;

constructor() {
let resolve!: (value: T) => void;
let reject!: (reasons?: Array<Error | GPUError>) => void;
this.promise = new Promise<T>((res, rej) => {
resolve = res;
reject = rej;
});
this.resolve = resolve;
this.reject = reject;
}
}
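
// A minimal usage sketch of Deferred (illustrative only; `startWork` is a
// hypothetical stand-in for the graph listeners that settle the real
// promise):
//
//   function startWork(): Promise<string> {
//     const deferred = new Deferred<string>();
//     setTimeout(() => {
//       deferred.resolve('finished');           // settle from another callback
//       // deferred.reject([new Error('boom')]) // or fail with a list of errors
//     }, 0);
//     return deferred.promise;
//   }
//
//   startWork().then(console.log);  // logs 'finished'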

/**
* Performs LLM Inference on text.
*/
@@ -122,10 +147,42 @@ export class LlmInference extends TaskRunner {
private readonly samplerParams: SamplerParameters;
private isProcessing = false;
private latestTokenCostQueryResult?: number;
private resultDeferred?: Deferred<string>;
private userProgressListener?: ProgressListener;
private streamingReader?: StreamingReader;

// The WebGPU device used for LLM inference.
private wgpuDevice?: GPUDevice;
// Holds WebGPU errors for WebGPU-involved invocations. Should be checked and
// cleaned up after each WebGPU-involved invocation.
private readonly wgpuErrors: Array<Error | GPUError> = [];

/**
* Handles WebGPU 'uncapturederror' events by recording each error. When the
* failure mode is known to the task, hints are added to the error message.
*/
private readonly wgpuErrorHandler = (event: Event) => {
let error = (event as GPUUncapturedErrorEvent).error;
const bufferSizeError = error.message.match(
/exceeds the max buffer size limit \(([0-9]+)\)\./,
);
if (
bufferSizeError &&
Number(bufferSizeError[1]) > MAX_BUFFER_SIZE_FOR_LLM
) {
error = new Error(
`Failed to run this LLM model, but you could try a smaller LLM ` +
`model. WebGPU throws: "${error.message}"`,
);
} else if (error.message.match(/is larger than the maximum binding size/)) {
error = new Error(
`Failed to run LLM inference, the supported max binding size is ` +
`smaller than the required size. WebGPU throws: "${error.message}"`,
);
}
this.wgpuErrors.push(error);
};
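
// For context, a sketch of how these events arise: WebGPU reports any error
// not captured by a pushErrorScope()/popErrorScope() pair as an
// 'uncapturederror' event on the device (assumes a valid GPUDevice named
// `device`):
//
//   device.addEventListener('uncapturederror', (event) => {
//     const error = (event as GPUUncapturedErrorEvent).error;
//     console.warn('Uncaptured WebGPU error:', error.message);
//   });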

/**
* Initializes the Wasm runtime and creates a new `LlmInference` based
* on the provided options.
@@ -288,8 +345,19 @@ export class LlmInference extends TaskRunner {
this.isProcessing = true;

if (options.baseOptions?.gpuOptions?.device) {
if (this.wgpuDevice) {
this.wgpuDevice.removeEventListener(
'uncapturederror',
this.wgpuErrorHandler,
);
}
this.wgpuDevice = options.baseOptions.gpuOptions.device;
(this.graphRunner as unknown as LlmGraphRunner).initializeForWebGpu(
this.wgpuDevice,
);
this.wgpuDevice.addEventListener(
'uncapturederror',
this.wgpuErrorHandler,
);
}
if ('maxTokens' in options) {
@@ -376,12 +444,13 @@
// code hasn't quite finished running by this point in time. However, that
// microtask seems to complete before any code await-ing this function, so
// this should be fine. This seems to be similarly true for our
// resultDeferred usage as well.
return Promise.all([finishedLoadingDataPromise, refreshGraphPromise]).then(
() => {
this.isProcessing = false;
this.checkWgpuErrors();
},
);
}
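
// A small sketch of the microtask-ordering assumption above: a .then()
// callback always runs before any code that awaits the promise it returns:
//
//   let flag = false;
//   const p = Promise.resolve().then(() => {
//     flag = true;
//   });
//   await p;               // the callback above has already run by this point
//   console.assert(flag);  // passes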

protected override get baseOptions(): BaseOptionsProto {
@@ -481,9 +550,8 @@
);
}
this.finishProcessing();
this.resultDeferred = new Deferred<string>();
return this.resultDeferred.promise;
}

/**
@@ -592,6 +660,22 @@
this.samplerParams.setTemperature(DEFAULT_TEMPERATURE);
}

/**
* Checks for pending WebGPU errors; if any exist, rejects the in-flight
* generation's promise with them, or throws them when none is pending.
*/
private checkWgpuErrors(): void {
if (this.wgpuErrors.length > 0) {
// Take a snapshot of the errors, then clear the stack.
const errors = [...this.wgpuErrors];
this.wgpuErrors.length = 0;

if (this.resultDeferred) {
this.resultDeferred.reject(errors);
this.resultDeferred = undefined;
} else {
throw errors;
}
}
}
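
// Callers observe these errors as the rejection value of the generation
// promise (sketch; `llm` is a hypothetical LlmInference instance):
//
//   try {
//     const response = await llm.generateResponse(prompt);
//   } catch (reasons) {
//     for (const error of reasons as Array<Error | GPUError>) {
//       console.error(error.message);
//     }
//   }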

// TODO: b/324919242 - Add sync API for BYOM Web API when Chrome JSPI is
// available

@@ -608,7 +692,8 @@
stripLeadingWhitespace,
);
this.generationResult.push(decodedText);
// Don't trigger the user progress listener if there are WebGPU errors.
if (this.userProgressListener && this.wgpuErrors.length === 0) {
this.userProgressListener(decodedText, /* done= */ false);
}
this.setLatestOutputTimestamp(timestamp);
@@ -622,20 +707,27 @@
OUTPUT_END_STREAM,
(bool, timestamp) => {
this.isProcessing = false;
this.checkWgpuErrors();
if (this.resultDeferred) {
this.resultDeferred.resolve(this.generationResult.join(''));
this.resultDeferred = undefined;
}
if (this.userProgressListener) {
this.userProgressListener(/* partialResult= */ '', /* done= */ true);
}
this.setLatestOutputTimestamp(timestamp);
},
);
this.graphRunner.attachEmptyPacketListener(
OUTPUT_END_STREAM,
(timestamp) => {
this.isProcessing = false;
this.setLatestOutputTimestamp(timestamp);
this.checkWgpuErrors();
if (this.resultDeferred) {
this.resultDeferred.resolve(this.generationResult.join(''));
this.resultDeferred = undefined;
}
},
);
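
// Streaming usage sketch (hypothetical `llm` instance): the string listener
// above reports partial results, and these end-of-stream listeners settle
// the promise returned by generateResponse():
//
//   const full = await llm.generateResponse(prompt, (partial, done) => {
//     if (!done) console.log(partial);
//   });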

@@ -665,9 +757,18 @@
// instead, we use a special async-only variant of closeGraph which we can
// chain into our promises to ensure proper ordering, calling that first so
// the built-in closeGraph becomes a no-op.
this.wgpuDevice?.removeEventListener(
'uncapturederror',
this.wgpuErrorHandler,
);
return (this.graphRunner as unknown as LlmGraphRunner)
.closeGraphAsync()
.then(() => {
this.wgpuDevice?.addEventListener(
'uncapturederror',
this.wgpuErrorHandler,
);
this.wgpuErrors.length = 0;
this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);
// Start initialization; this is async when StreamingReader is used.
this.finishProcessing();
@@ -837,6 +938,10 @@
}

override close() {
this.wgpuDevice?.removeEventListener(
'uncapturederror',
this.wgpuErrorHandler,
);
super.close();
}
}
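
// End-to-end sketch of the new error path (hypothetical fileset and model
// path; option names follow the task's public API):
//
//   const adapter = await navigator.gpu.requestAdapter();
//   const device = await adapter!.requestDevice();
//   const llm = await LlmInference.createFromOptions(genaiFileset, {
//     baseOptions: {modelAssetPath: '/llm_model.bin', gpuOptions: {device}},
//   });
//   llm.generateResponse('Hello').catch((reasons) => {
//     // WebGPU failures now reject here instead of being dropped silently.
//     console.error(reasons);
//   });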
