From fb75f149a2946e3b1a2becce6cf181fc09ceddfe Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Mon, 6 May 2024 09:56:15 -0700 Subject: [PATCH] Allow models to be uploaded via ReadableStreamDefaultReader PiperOrigin-RevId: 631097139 --- mediapipe/tasks/web/core/task_runner.ts | 194 +++++++++++++----- .../tasks/web/core/task_runner_options.d.ts | 10 +- mediapipe/tasks/web/core/task_runner_test.ts | 189 ++++++++++++----- .../web/genai/llm_inference/llm_inference.ts | 19 +- .../graph_runner_streaming_reader.ts | 33 ++- 5 files changed, 317 insertions(+), 128 deletions(-) diff --git a/mediapipe/tasks/web/core/task_runner.ts b/mediapipe/tasks/web/core/task_runner.ts index 735f541ec1..0daa6b7583 100644 --- a/mediapipe/tasks/web/core/task_runner.ts +++ b/mediapipe/tasks/web/core/task_runner.ts @@ -19,8 +19,16 @@ import {CalculatorGraphConfig} from '../../../framework/calculator_pb'; import {Acceleration} from '../../../tasks/cc/core/proto/acceleration_pb'; import {BaseOptions as BaseOptionsProto} from '../../../tasks/cc/core/proto/base_options_pb'; import {ExternalFile} from '../../../tasks/cc/core/proto/external_file_pb'; -import {BaseOptions, TaskRunnerOptions} from '../../../tasks/web/core/task_runner_options'; -import {createMediaPipeLib, FileLocator, GraphRunner, WasmMediaPipeConstructor} from '../../../web/graph_runner/graph_runner'; +import { + BaseOptions, + TaskRunnerOptions, +} from '../../../tasks/web/core/task_runner_options'; +import { + FileLocator, + GraphRunner, + WasmMediaPipeConstructor, + createMediaPipeLib, +} from '../../../web/graph_runner/graph_runner'; import {SupportModelResourcesGraphService} from '../../../web/graph_runner/register_model_resources_graph_service'; import {WasmFileset} from './wasm_fileset'; @@ -47,9 +55,11 @@ export class CachedGraphRunner extends CachedGraphRunnerType {} * @return A fully instantiated instance of `T`. */ export async function createTaskRunner( - type: WasmMediaPipeConstructor, - canvas: HTMLCanvasElement|OffscreenCanvas|null|undefined, - fileset: WasmFileset, options: TaskRunnerOptions): Promise { + type: WasmMediaPipeConstructor, + canvas: HTMLCanvasElement | OffscreenCanvas | null | undefined, + fileset: WasmFileset, + options: TaskRunnerOptions, +): Promise { const fileLocator: FileLocator = { locateFile(file): string { // We currently only use a single .wasm file and a single .data file (for @@ -62,12 +72,16 @@ export async function createTaskRunner( return fileset.assetBinaryPath.toString(); } return file; - } + }, }; const instance = await createMediaPipeLib( - type, fileset.wasmLoaderPath, fileset.assetLoaderPath, canvas, - fileLocator); + type, + fileset.wasmLoaderPath, + fileset.assetLoaderPath, + canvas, + fileLocator, + ); await instance.setOptions(options); return instance; } @@ -85,9 +99,11 @@ export abstract class TaskRunner { * @return A fully instantiated instance of `T`. */ protected static async createInstance( - type: WasmMediaPipeConstructor, - canvas: HTMLCanvasElement|OffscreenCanvas|null|undefined, - fileset: WasmFileset, options: TaskRunnerOptions): Promise { + type: WasmMediaPipeConstructor, + canvas: HTMLCanvasElement | OffscreenCanvas | null | undefined, + fileset: WasmFileset, + options: TaskRunnerOptions, + ): Promise { return createTaskRunner(type, canvas, fileset, options); } @@ -112,22 +128,32 @@ export abstract class TaskRunner { * @param loadTfliteModel Whether to load the model specified in * `options.baseOptions`. */ - protected applyOptions(options: TaskRunnerOptions, loadTfliteModel = true): - Promise { + protected applyOptions( + options: TaskRunnerOptions, + loadTfliteModel = true, + ): Promise { if (loadTfliteModel) { const baseOptions: BaseOptions = options.baseOptions || {}; // Validate that exactly one model is configured - if (options.baseOptions?.modelAssetBuffer && - options.baseOptions?.modelAssetPath) { + if ( + options.baseOptions?.modelAssetBuffer && + options.baseOptions?.modelAssetPath + ) { throw new Error( - 'Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer'); - } else if (!(this.baseOptions.getModelAsset()?.hasFileContent() || - this.baseOptions.getModelAsset()?.hasFileName() || - options.baseOptions?.modelAssetBuffer || - options.baseOptions?.modelAssetPath)) { + 'Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer', + ); + } else if ( + !( + this.baseOptions.getModelAsset()?.hasFileContent() || + this.baseOptions.getModelAsset()?.hasFileName() || + options.baseOptions?.modelAssetBuffer || + options.baseOptions?.modelAssetPath + ) + ) { throw new Error( - 'Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set'); + 'Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set', + ); } this.setAcceleration(baseOptions); @@ -135,33 +161,45 @@ export abstract class TaskRunner { // We don't use `await` here since we want to apply most settings // synchronously. return fetch(baseOptions.modelAssetPath.toString()) - .then(response => { - if (!response.ok) { - throw new Error(`Failed to fetch model: ${ - baseOptions.modelAssetPath} (${response.status})`); - } else { - return response.arrayBuffer(); - } - }) - .then(buffer => { - try { - // Try to delete file as we cannot overwrite an existing file - // using our current API. - this.graphRunner.wasmModule.FS_unlink('/model.dat'); - } catch { - } - // TODO: Consider passing the model to the graph as an - // input side packet as this might reduce copies. - this.graphRunner.wasmModule.FS_createDataFile( - '/', 'model.dat', new Uint8Array(buffer), - /* canRead= */ true, /* canWrite= */ false, - /* canOwn= */ false); - this.setExternalFile('/model.dat'); - this.refreshGraph(); - this.onGraphRefreshed(); - }); - } else { + .then((response) => { + if (!response.ok) { + throw new Error( + `Failed to fetch model: ${baseOptions.modelAssetPath} (${response.status})`, + ); + } else { + return response.arrayBuffer(); + } + }) + .then((buffer) => { + try { + // Try to delete file as we cannot overwrite an existing file + // using our current API. + this.graphRunner.wasmModule.FS_unlink('/model.dat'); + } catch {} + // TODO: Consider passing the model to the graph as an + // input side packet as this might reduce copies. + this.graphRunner.wasmModule.FS_createDataFile( + '/', + 'model.dat', + new Uint8Array(buffer), + /* canRead= */ true, + /* canWrite= */ false, + /* canOwn= */ false, + ); + this.setExternalFile('/model.dat'); + this.refreshGraph(); + this.onGraphRefreshed(); + }); + } else if (baseOptions.modelAssetBuffer instanceof Uint8Array) { this.setExternalFile(baseOptions.modelAssetBuffer); + } else if (baseOptions.modelAssetBuffer) { + return streamToUint8Array(baseOptions.modelAssetBuffer).then( + (buffer) => { + this.setExternalFile(buffer); + this.refreshGraph(); + this.onGraphRefreshed(); + }, + ); } } @@ -182,8 +220,8 @@ export abstract class TaskRunner { /** Returns the current CalculatorGraphConfig. */ protected getCalculatorGraphConfig(): CalculatorGraphConfig { - let config: CalculatorGraphConfig|undefined; - this.graphRunner.getCalculatorGraphConfig(binaryData => { + let config: CalculatorGraphConfig | undefined; + this.graphRunner.getCalculatorGraphConfig((binaryData) => { config = CalculatorGraphConfig.deserializeBinary(binaryData); }); if (!config) { @@ -232,8 +270,10 @@ export abstract class TaskRunner { * ignored. */ protected setLatestOutputTimestamp(timestamp: number): void { - this.latestOutputTimestamp = - Math.max(this.latestOutputTimestamp, timestamp); + this.latestOutputTimestamp = Math.max( + this.latestOutputTimestamp, + timestamp, + ); } /** @@ -254,8 +294,9 @@ export abstract class TaskRunner { throw new Error(this.processingErrors[0].message); } else if (errorCount > 1) { throw new Error( - 'Encountered multiple errors: ' + - this.processingErrors.map(e => e.message).join(', ')); + 'Encountered multiple errors: ' + + this.processingErrors.map((e) => e.message).join(', '), + ); } } finally { this.processingErrors = []; @@ -265,7 +306,9 @@ export abstract class TaskRunner { /** Configures the `externalFile` option */ protected setExternalFile(modelAssetPath?: string): void; protected setExternalFile(modelAssetBuffer?: Uint8Array): void; - protected setExternalFile(modelAssetPathOrBuffer?: Uint8Array|string): void { + protected setExternalFile( + modelAssetPathOrBuffer?: Uint8Array | string, + ): void { const externalFile = this.baseOptions.getModelAsset() || new ExternalFile(); if (typeof modelAssetPathOrBuffer === 'string') { externalFile.setFileName(modelAssetPathOrBuffer); @@ -292,7 +335,8 @@ export abstract class TaskRunner { acceleration.setGpu(new InferenceCalculatorOptions.Delegate.Gpu()); } else { acceleration.setTflite( - new InferenceCalculatorOptions.Delegate.TfLite()); + new InferenceCalculatorOptions.Delegate.TfLite(), + ); } } @@ -309,7 +353,8 @@ export abstract class TaskRunner { this.keepaliveNode.setCalculator('PassThroughCalculator'); this.keepaliveNode.addInputStream(FREE_MEMORY_STREAM); this.keepaliveNode.addOutputStream( - FREE_MEMORY_STREAM + UNUSED_STREAM_SUFFIX); + FREE_MEMORY_STREAM + UNUSED_STREAM_SUFFIX, + ); graphConfig.addInputStream(FREE_MEMORY_STREAM); graphConfig.addNode(this.keepaliveNode); } @@ -323,7 +368,10 @@ export abstract class TaskRunner { /** Frees any streams being kept alive by the keepStreamAlive callback. */ protected freeKeepaliveStreams() { this.graphRunner.addBoolToStream( - true, FREE_MEMORY_STREAM, this.latestOutputTimestamp); + true, + FREE_MEMORY_STREAM, + this.latestOutputTimestamp, + ); } /** @@ -336,4 +384,36 @@ export abstract class TaskRunner { } } +/** Converts a ReadableStreamDefaultReader to a Uint8Array. */ +async function streamToUint8Array( + reader: ReadableStreamDefaultReader, +): Promise { + const chunks: Uint8Array[] = []; + let totalLength = 0; + + while (true) { + const {done, value} = await reader.read(); + if (done) { + break; + } + chunks.push(value); + totalLength += value.length; + } + + if (chunks.length === 0) { + return new Uint8Array(0); + } else if (chunks.length === 1) { + return chunks[0]; + } else { + // Merge chunks + const combined = new Uint8Array(totalLength); + let offset = 0; + for (const chunk of chunks) { + combined.set(chunk, offset); + offset += chunk.length; + } + return combined; + } +} + diff --git a/mediapipe/tasks/web/core/task_runner_options.d.ts b/mediapipe/tasks/web/core/task_runner_options.d.ts index 6b09608966..cc3dd1b34b 100644 --- a/mediapipe/tasks/web/core/task_runner_options.d.ts +++ b/mediapipe/tasks/web/core/task_runner_options.d.ts @@ -22,16 +22,16 @@ export declare interface BaseOptions { * The model path to the model asset file. Only one of `modelAssetPath` or * `modelAssetBuffer` can be set. */ - modelAssetPath?: string|undefined; + modelAssetPath?: string | undefined; /** - * A buffer containing the model aaset. Only one of `modelAssetPath` or - * `modelAssetBuffer` can be set. + * A buffer or stream reader containing the model asset. Only one of + * `modelAssetPath` or `modelAssetBuffer` can be set. */ - modelAssetBuffer?: Uint8Array|undefined; + modelAssetBuffer?: Uint8Array | ReadableStreamDefaultReader | undefined; /** Overrides the default backend to use for the provided model. */ - delegate?: 'CPU'|'GPU'|undefined; + delegate?: 'CPU' | 'GPU' | undefined; } /** Options to configure MediaPipe Tasks in general. */ diff --git a/mediapipe/tasks/web/core/task_runner_test.ts b/mediapipe/tasks/web/core/task_runner_test.ts index ab867c8f64..8b6f40f541 100644 --- a/mediapipe/tasks/web/core/task_runner_test.ts +++ b/mediapipe/tasks/web/core/task_runner_test.ts @@ -20,7 +20,10 @@ import {InferenceCalculatorOptions} from '../../../calculators/tensor/inference_ import {GpuOrigin as GpuOriginProto} from '../../../gpu/gpu_origin_pb'; import {BaseOptions as BaseOptionsProto} from '../../../tasks/cc/core/proto/base_options_pb'; import {TaskRunner} from '../../../tasks/web/core/task_runner'; -import {createSpyWasmModule, SpyWasmModule} from '../../../tasks/web/core/task_runner_test_utils'; +import { + createSpyWasmModule, + SpyWasmModule, +} from '../../../tasks/web/core/task_runner_test_utils'; import * as graphRunner from '../../../web/graph_runner/graph_runner'; import {ErrorListener} from '../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource URL builder @@ -30,11 +33,11 @@ import {TaskRunnerOptions} from './task_runner_options'; import {WasmFileset} from './wasm_fileset'; type Writeable = { - -readonly[P in keyof T]: T[P] + -readonly [P in keyof T]: T[P]; }; class TaskRunnerFake extends TaskRunner { - private errorListener: ErrorListener|undefined; + private errorListener: ErrorListener | undefined; private errors: string[] = []; baseOptions = new BaseOptionsProto(); @@ -44,14 +47,20 @@ class TaskRunnerFake extends TaskRunner { } constructor() { - super(jasmine.createSpyObj([ - 'setAutoRenderToScreen', 'setGraph', 'finishProcessing', - 'registerModelResourcesGraphService', 'attachErrorListener' - ])); - const graphRunner = - this.graphRunner as jasmine.SpyObj>; + super( + jasmine.createSpyObj([ + 'setAutoRenderToScreen', + 'setGraph', + 'finishProcessing', + 'registerModelResourcesGraphService', + 'attachErrorListener', + ]), + ); + const graphRunner = this.graphRunner as jasmine.SpyObj< + Writeable + >; expect(graphRunner.setAutoRenderToScreen).toHaveBeenCalled(); - graphRunner.attachErrorListener.and.callFake(listener => { + graphRunner.attachErrorListener.and.callFake((listener) => { this.errorListener = listener; }); graphRunner.setGraph.and.callFake(() => { @@ -79,8 +88,9 @@ class TaskRunnerFake extends TaskRunner { override setGraph(graphData: Uint8Array, isBinary: boolean): void { super.setGraph(graphData, isBinary); - expect(this.graphRunner.registerModelResourcesGraphService) - .toHaveBeenCalled(); + expect( + this.graphRunner.registerModelResourcesGraphService, + ).toHaveBeenCalled(); } setOptions(options: TaskRunnerOptions): Promise { @@ -124,11 +134,13 @@ describe('TaskRunner', () => { allowPrecisionLoss: true, cachedKernelPath: undefined, serializedModelDir: undefined, - cacheWritingBehavior: InferenceCalculatorOptions.Delegate.Gpu - .CacheWritingBehavior.WRITE_OR_ERROR, + cacheWritingBehavior: + InferenceCalculatorOptions.Delegate.Gpu.CacheWritingBehavior + .WRITE_OR_ERROR, modelToken: undefined, - usage: InferenceCalculatorOptions.Delegate.Gpu.InferenceUsage - .SUSTAINED_SPEED, + usage: + InferenceCalculatorOptions.Delegate.Gpu.InferenceUsage + .SUSTAINED_SPEED, }, tflite: undefined, nnapi: undefined, @@ -154,13 +166,13 @@ describe('TaskRunner', () => { let fetchSpy: jasmine.Spy; let taskRunner: TaskRunnerFake; let fetchStatus: number; - let locator: graphRunner.FileLocator|undefined; + let locator: graphRunner.FileLocator | undefined; let oldCreate = graphRunner.createMediaPipeLib; beforeEach(() => { fetchStatus = 200; - fetchSpy = jasmine.createSpy().and.callFake(async url => { + fetchSpy = jasmine.createSpy().and.callFake(async (url) => { return { arrayBuffer: () => mockBytes.buffer, ok: fetchStatus === 200, @@ -172,13 +184,15 @@ describe('TaskRunner', () => { // Monkeypatch an exported static method for testing! oldCreate = graphRunner.createMediaPipeLib; locator = undefined; - (graphRunner as {createMediaPipeLib: Function}).createMediaPipeLib = - jasmine.createSpy().and.callFake( - (type, wasmLoaderPath, assetLoaderPath, canvas, fileLocator) => { - locator = fileLocator; - // tslint:disable-next-line:no-any Monkeypatching for test mocks. - return Promise.resolve(taskRunner as any); - }); + (graphRunner as {createMediaPipeLib: Function}).createMediaPipeLib = jasmine + .createSpy() + .and.callFake( + (type, wasmLoaderPath, assetLoaderPath, canvas, fileLocator) => { + locator = fileLocator; + // tslint:disable-next-line:no-any Monkeypatching for test mocks. + return Promise.resolve(taskRunner as any); + }, + ); taskRunner = TaskRunnerFake.createFake(); }); @@ -186,7 +200,7 @@ describe('TaskRunner', () => { afterEach(() => { // Restore the monkeypatch. (graphRunner as {createMediaPipeLib: Function}).createMediaPipeLib = - oldCreate; + oldCreate; }); it('constructs with useful file locators for asset.data files', () => { @@ -200,7 +214,7 @@ describe('TaskRunner', () => { const options = { baseOptions: { modelAssetPath: `modelAssetPath`, - } + }, }; const runner = createTaskRunner(TaskRunnerFake, null, fileset, options); @@ -222,7 +236,7 @@ describe('TaskRunner', () => { const options = { baseOptions: { modelAssetPath: `modelAssetPath`, - } + }, }; const runner = createTaskRunner(TaskRunnerFake, null, fileset, options); @@ -275,9 +289,9 @@ describe('TaskRunner', () => { it('verifies that at least one model asset option is provided', () => { expect(() => { taskRunner.setOptions({}); - }) - .toThrowError( - /Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set/); + }).toThrowError( + /Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set/, + ); }); it('verifies that no more than one model asset option is provided', () => { @@ -285,25 +299,27 @@ describe('TaskRunner', () => { taskRunner.setOptions({ baseOptions: { modelAssetPath: `foo`, - modelAssetBuffer: new Uint8Array([]) - } + modelAssetBuffer: new Uint8Array([]), + }, }); - }) - .toThrowError( - /Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer/); + }).toThrowError( + /Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer/, + ); }); - it('doesn\'t require model once it is configured', async () => { - await taskRunner.setOptions( - {baseOptions: {modelAssetBuffer: new Uint8Array(mockBytes)}}); + it("doesn't require model once it is configured", async () => { + await taskRunner.setOptions({ + baseOptions: {modelAssetBuffer: new Uint8Array(mockBytes)}, + }); expect(() => { taskRunner.setOptions({}); }).not.toThrowError(); }); it('writes model to file system', async () => { - await taskRunner.setOptions( - {baseOptions: {modelAssetPath: `foo`}}); + await taskRunner.setOptions({ + baseOptions: {modelAssetPath: `foo`}, + }); expect(fetchSpy).toHaveBeenCalled(); expect(taskRunner.wasmModule.FS_createDataFile).toHaveBeenCalled(); @@ -311,16 +327,72 @@ describe('TaskRunner', () => { }); it('does not download model when bytes are provided', async () => { - await taskRunner.setOptions( - {baseOptions: {modelAssetBuffer: new Uint8Array(mockBytes)}}); + await taskRunner.setOptions({ + baseOptions: {modelAssetBuffer: new Uint8Array(mockBytes)}, + }); + + expect(fetchSpy).not.toHaveBeenCalled(); + expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult); + }); + + it('can read from ReadableStreamDefaultReader (with empty data)', async () => { + if (typeof ReadableStream === 'undefined') return; // No Node support + + const readableStream = new ReadableStream({ + start(controller) { + controller.close(); + }, + }); + await taskRunner.setOptions({ + baseOptions: {modelAssetBuffer: readableStream.getReader()}, + }); + + expect(fetchSpy).not.toHaveBeenCalled(); + expect(taskRunner.baseOptions.toObject().modelAsset?.fileContent).toEqual( + '', + ); + }); + + it('can read from ReadableStreamDefaultReader (with one chunk)', async () => { + if (typeof ReadableStream === 'undefined') return; // No Node support + + const bytes = new Uint8Array(mockBytes); + const readableStream = new ReadableStream({ + start(controller) { + controller.enqueue(bytes); + controller.close(); + }, + }); + await taskRunner.setOptions({ + baseOptions: {modelAssetBuffer: readableStream.getReader()}, + }); + + expect(fetchSpy).not.toHaveBeenCalled(); + expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult); + }); + + it('can read from ReadableStreamDefaultReader (with two chunks)', async () => { + if (typeof ReadableStream === 'undefined') return; // No Node support + + const readableStream = new ReadableStream({ + start(controller) { + controller.enqueue(new Uint8Array([0, 1])); + controller.enqueue(new Uint8Array([2, 3])); + controller.close(); + }, + }); + await taskRunner.setOptions({ + baseOptions: {modelAssetBuffer: readableStream.getReader()}, + }); expect(fetchSpy).not.toHaveBeenCalled(); expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult); }); it('changes model synchronously when bytes are provided', () => { - const resolvedPromise = taskRunner.setOptions( - {baseOptions: {modelAssetBuffer: new Uint8Array(mockBytes)}}); + const resolvedPromise = taskRunner.setOptions({ + baseOptions: {modelAssetBuffer: new Uint8Array(mockBytes)}, + }); // Check that the change has been applied even though we do not await the // above Promise @@ -330,10 +402,11 @@ describe('TaskRunner', () => { it('returns custom error if model download failed', () => { fetchStatus = 404; - return expectAsync(taskRunner.setOptions({ - baseOptions: {modelAssetPath: `notfound.tflite`} - })) - .toBeRejectedWithError('Failed to fetch model: notfound.tflite (404)'); + return expectAsync( + taskRunner.setOptions({ + baseOptions: {modelAssetPath: `notfound.tflite`}, + }), + ).toBeRejectedWithError('Failed to fetch model: notfound.tflite (404)'); }); it('can enable CPU delegate', async () => { @@ -341,7 +414,7 @@ describe('TaskRunner', () => { baseOptions: { modelAssetBuffer: new Uint8Array(mockBytes), delegate: 'CPU', - } + }, }); expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult); }); @@ -351,10 +424,11 @@ describe('TaskRunner', () => { baseOptions: { modelAssetBuffer: new Uint8Array(mockBytes), delegate: 'GPU', - } + }, }); - expect(taskRunner.baseOptions.toObject()) - .toEqual(mockBytesResultWithGpuDelegate); + expect(taskRunner.baseOptions.toObject()).toEqual( + mockBytesResultWithGpuDelegate, + ); }); it('can reset delegate', async () => { @@ -362,7 +436,7 @@ describe('TaskRunner', () => { baseOptions: { modelAssetBuffer: new Uint8Array(mockBytes), delegate: 'GPU', - } + }, }); // Clear delegate await taskRunner.setOptions({baseOptions: {delegate: undefined}}); @@ -374,10 +448,11 @@ describe('TaskRunner', () => { baseOptions: { modelAssetBuffer: new Uint8Array(mockBytes), delegate: 'GPU', - } + }, }); await taskRunner.setOptions({baseOptions: {}}); - expect(taskRunner.baseOptions.toObject()) - .toEqual(mockBytesResultWithGpuDelegate); + expect(taskRunner.baseOptions.toObject()).toEqual( + mockBytesResultWithGpuDelegate, + ); }); }); diff --git a/mediapipe/tasks/web/genai/llm_inference/llm_inference.ts b/mediapipe/tasks/web/genai/llm_inference/llm_inference.ts index 1f4e29542e..d64f64dbb4 100644 --- a/mediapipe/tasks/web/genai/llm_inference/llm_inference.ts +++ b/mediapipe/tasks/web/genai/llm_inference/llm_inference.ts @@ -153,7 +153,7 @@ export class LlmInference extends TaskRunner { */ static async createFromModelBuffer( wasmFileset: WasmFileset, - modelAssetBuffer: Uint8Array, + modelAssetBuffer: Uint8Array | ReadableStreamDefaultReader, ): Promise { const webgpuDevice = await LlmInference.createWebGpuDevice(); const llmInferenceOptions = { @@ -316,21 +316,32 @@ export class LlmInference extends TaskRunner { 'Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer', ); } + let consumedBuffer = false; if (options.baseOptions.modelAssetPath) { this.streamingReader = StreamingReader.loadFromUrl( options.baseOptions.modelAssetPath, onFinishedLoadingData, ); - } else if (options.baseOptions.modelAssetBuffer) { + } else if (options.baseOptions.modelAssetBuffer instanceof Uint8Array) { this.streamingReader = StreamingReader.loadFromArray( options.baseOptions.modelAssetBuffer, onFinishedLoadingData, ); + consumedBuffer = true; + } else if (options.baseOptions.modelAssetBuffer) { + this.streamingReader = StreamingReader.loadFromReader( + options.baseOptions.modelAssetBuffer, + onFinishedLoadingData, + ); + consumedBuffer = true; + } else { + onFinishedLoadingData(); + } + + if (consumedBuffer) { // Remove the reference on the asset buffer since it is now owned by // `streamingReader`. options.baseOptions.modelAssetBuffer = undefined; - } else { - onFinishedLoadingData(); } } // To allow graph closure across ASYNCIFY, where we cannot get a callback, diff --git a/mediapipe/web/graph_runner/graph_runner_streaming_reader.ts b/mediapipe/web/graph_runner/graph_runner_streaming_reader.ts index 64f179ef23..c38dafb6ed 100644 --- a/mediapipe/web/graph_runner/graph_runner_streaming_reader.ts +++ b/mediapipe/web/graph_runner/graph_runner_streaming_reader.ts @@ -97,12 +97,32 @@ export class StreamingReader { private readonly onFinished: () => void, ) {} + /* + * Creates a StreamingReader from a ReadableStreamDefaultReader. + * @param reader The reader + */ + static loadFromReader( + reader: ReadableStreamDefaultReader, + onFinished: () => void, + ): StreamingReader { + const fetchMore = async () => { + const {value, done} = await reader.read(); + if (done) { + return undefined; + } + return value; + }; + return new StreamingReader([], fetchMore, onFinished); + } + /* * Creates a StreamingReader from a URL. * @param url The URL to request the file. */ static loadFromUrl( - url: string, onFinished: () => void): StreamingReader { + url: string, + onFinished: () => void, + ): StreamingReader { const readerPromise = fetch(url.toString()).then( (response) => response?.body?.getReader() as ReadableStreamDefaultReader, ); @@ -148,7 +168,7 @@ export class StreamingReader { if (mode === ReadMode.DISCARD_ALL) { this.dataArray = []; this.fetchMoreData = () => Promise.resolve(undefined); - this.onFinished(); // Signal that we're done with data + this.onFinished(); // Signal that we're done with data return Promise.resolve(0); } @@ -210,9 +230,12 @@ export class StreamingReader { * @param array The file data. */ static loadFromArray( - array: Uint8Array, onFinished: () => void): StreamingReader { - return new StreamingReader([new DiscardableDataChunk(array)], () => - Promise.resolve(undefined), + array: Uint8Array, + onFinished: () => void, + ): StreamingReader { + return new StreamingReader( + [new DiscardableDataChunk(array)], + () => Promise.resolve(undefined), onFinished, ); }