Skip to content

Commit

Permalink
Allow models to be uploaded via ReadableStreamDefaultReader
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 631097139
  • Loading branch information
schmidt-sebastian authored and copybara-github committed May 6, 2024
1 parent e323ace commit fb75f14
Show file tree
Hide file tree
Showing 5 changed files with 317 additions and 128 deletions.
194 changes: 137 additions & 57 deletions mediapipe/tasks/web/core/task_runner.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,16 @@ import {CalculatorGraphConfig} from '../../../framework/calculator_pb';
import {Acceleration} from '../../../tasks/cc/core/proto/acceleration_pb';
import {BaseOptions as BaseOptionsProto} from '../../../tasks/cc/core/proto/base_options_pb';
import {ExternalFile} from '../../../tasks/cc/core/proto/external_file_pb';
import {BaseOptions, TaskRunnerOptions} from '../../../tasks/web/core/task_runner_options';
import {createMediaPipeLib, FileLocator, GraphRunner, WasmMediaPipeConstructor} from '../../../web/graph_runner/graph_runner';
import {
BaseOptions,
TaskRunnerOptions,
} from '../../../tasks/web/core/task_runner_options';
import {
FileLocator,
GraphRunner,
WasmMediaPipeConstructor,
createMediaPipeLib,
} from '../../../web/graph_runner/graph_runner';
import {SupportModelResourcesGraphService} from '../../../web/graph_runner/register_model_resources_graph_service';

import {WasmFileset} from './wasm_fileset';
Expand All @@ -47,9 +55,11 @@ export class CachedGraphRunner extends CachedGraphRunnerType {}
* @return A fully instantiated instance of `T`.
*/
export async function createTaskRunner<T extends TaskRunner>(
type: WasmMediaPipeConstructor<T>,
canvas: HTMLCanvasElement|OffscreenCanvas|null|undefined,
fileset: WasmFileset, options: TaskRunnerOptions): Promise<T> {
type: WasmMediaPipeConstructor<T>,
canvas: HTMLCanvasElement | OffscreenCanvas | null | undefined,
fileset: WasmFileset,
options: TaskRunnerOptions,
): Promise<T> {
const fileLocator: FileLocator = {
locateFile(file): string {
// We currently only use a single .wasm file and a single .data file (for
Expand All @@ -62,12 +72,16 @@ export async function createTaskRunner<T extends TaskRunner>(
return fileset.assetBinaryPath.toString();
}
return file;
}
},
};

const instance = await createMediaPipeLib(
type, fileset.wasmLoaderPath, fileset.assetLoaderPath, canvas,
fileLocator);
type,
fileset.wasmLoaderPath,
fileset.assetLoaderPath,
canvas,
fileLocator,
);
await instance.setOptions(options);
return instance;
}
Expand All @@ -85,9 +99,11 @@ export abstract class TaskRunner {
* @return A fully instantiated instance of `T`.
*/
protected static async createInstance<T extends TaskRunner>(
type: WasmMediaPipeConstructor<T>,
canvas: HTMLCanvasElement|OffscreenCanvas|null|undefined,
fileset: WasmFileset, options: TaskRunnerOptions): Promise<T> {
type: WasmMediaPipeConstructor<T>,
canvas: HTMLCanvasElement | OffscreenCanvas | null | undefined,
fileset: WasmFileset,
options: TaskRunnerOptions,
): Promise<T> {
return createTaskRunner(type, canvas, fileset, options);
}

Expand All @@ -112,56 +128,78 @@ export abstract class TaskRunner {
* @param loadTfliteModel Whether to load the model specified in
* `options.baseOptions`.
*/
protected applyOptions(options: TaskRunnerOptions, loadTfliteModel = true):
Promise<void> {
protected applyOptions(
options: TaskRunnerOptions,
loadTfliteModel = true,
): Promise<void> {
if (loadTfliteModel) {
const baseOptions: BaseOptions = options.baseOptions || {};

// Validate that exactly one model is configured
if (options.baseOptions?.modelAssetBuffer &&
options.baseOptions?.modelAssetPath) {
if (
options.baseOptions?.modelAssetBuffer &&
options.baseOptions?.modelAssetPath
) {
throw new Error(
'Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer');
} else if (!(this.baseOptions.getModelAsset()?.hasFileContent() ||
this.baseOptions.getModelAsset()?.hasFileName() ||
options.baseOptions?.modelAssetBuffer ||
options.baseOptions?.modelAssetPath)) {
'Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer',
);
} else if (
!(
this.baseOptions.getModelAsset()?.hasFileContent() ||
this.baseOptions.getModelAsset()?.hasFileName() ||
options.baseOptions?.modelAssetBuffer ||
options.baseOptions?.modelAssetPath
)
) {
throw new Error(
'Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set');
'Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set',
);
}

this.setAcceleration(baseOptions);
if (baseOptions.modelAssetPath) {
// We don't use `await` here since we want to apply most settings
// synchronously.
return fetch(baseOptions.modelAssetPath.toString())
.then(response => {
if (!response.ok) {
throw new Error(`Failed to fetch model: ${
baseOptions.modelAssetPath} (${response.status})`);
} else {
return response.arrayBuffer();
}
})
.then(buffer => {
try {
// Try to delete file as we cannot overwrite an existing file
// using our current API.
this.graphRunner.wasmModule.FS_unlink('/model.dat');
} catch {
}
// TODO: Consider passing the model to the graph as an
// input side packet as this might reduce copies.
this.graphRunner.wasmModule.FS_createDataFile(
'/', 'model.dat', new Uint8Array(buffer),
/* canRead= */ true, /* canWrite= */ false,
/* canOwn= */ false);
this.setExternalFile('/model.dat');
this.refreshGraph();
this.onGraphRefreshed();
});
} else {
.then((response) => {
if (!response.ok) {
throw new Error(
`Failed to fetch model: ${baseOptions.modelAssetPath} (${response.status})`,
);
} else {
return response.arrayBuffer();
}
})
.then((buffer) => {
try {
// Try to delete file as we cannot overwrite an existing file
// using our current API.
this.graphRunner.wasmModule.FS_unlink('/model.dat');
} catch {}
// TODO: Consider passing the model to the graph as an
// input side packet as this might reduce copies.
this.graphRunner.wasmModule.FS_createDataFile(
'/',
'model.dat',
new Uint8Array(buffer),
/* canRead= */ true,
/* canWrite= */ false,
/* canOwn= */ false,
);
this.setExternalFile('/model.dat');
this.refreshGraph();
this.onGraphRefreshed();
});
} else if (baseOptions.modelAssetBuffer instanceof Uint8Array) {
this.setExternalFile(baseOptions.modelAssetBuffer);
} else if (baseOptions.modelAssetBuffer) {
return streamToUint8Array(baseOptions.modelAssetBuffer).then(
(buffer) => {
this.setExternalFile(buffer);
this.refreshGraph();
this.onGraphRefreshed();
},
);
}
}

Expand All @@ -182,8 +220,8 @@ export abstract class TaskRunner {

/** Returns the current CalculatorGraphConfig. */
protected getCalculatorGraphConfig(): CalculatorGraphConfig {
let config: CalculatorGraphConfig|undefined;
this.graphRunner.getCalculatorGraphConfig(binaryData => {
let config: CalculatorGraphConfig | undefined;
this.graphRunner.getCalculatorGraphConfig((binaryData) => {
config = CalculatorGraphConfig.deserializeBinary(binaryData);
});
if (!config) {
Expand Down Expand Up @@ -232,8 +270,10 @@ export abstract class TaskRunner {
* ignored.
*/
protected setLatestOutputTimestamp(timestamp: number): void {
this.latestOutputTimestamp =
Math.max(this.latestOutputTimestamp, timestamp);
this.latestOutputTimestamp = Math.max(
this.latestOutputTimestamp,
timestamp,
);
}

/**
Expand All @@ -254,8 +294,9 @@ export abstract class TaskRunner {
throw new Error(this.processingErrors[0].message);
} else if (errorCount > 1) {
throw new Error(
'Encountered multiple errors: ' +
this.processingErrors.map(e => e.message).join(', '));
'Encountered multiple errors: ' +
this.processingErrors.map((e) => e.message).join(', '),
);
}
} finally {
this.processingErrors = [];
Expand All @@ -265,7 +306,9 @@ export abstract class TaskRunner {
/** Configures the `externalFile` option */
protected setExternalFile(modelAssetPath?: string): void;
protected setExternalFile(modelAssetBuffer?: Uint8Array): void;
protected setExternalFile(modelAssetPathOrBuffer?: Uint8Array|string): void {
protected setExternalFile(
modelAssetPathOrBuffer?: Uint8Array | string,
): void {
const externalFile = this.baseOptions.getModelAsset() || new ExternalFile();
if (typeof modelAssetPathOrBuffer === 'string') {
externalFile.setFileName(modelAssetPathOrBuffer);
Expand All @@ -292,7 +335,8 @@ export abstract class TaskRunner {
acceleration.setGpu(new InferenceCalculatorOptions.Delegate.Gpu());
} else {
acceleration.setTflite(
new InferenceCalculatorOptions.Delegate.TfLite());
new InferenceCalculatorOptions.Delegate.TfLite(),
);
}
}

Expand All @@ -309,7 +353,8 @@ export abstract class TaskRunner {
this.keepaliveNode.setCalculator('PassThroughCalculator');
this.keepaliveNode.addInputStream(FREE_MEMORY_STREAM);
this.keepaliveNode.addOutputStream(
FREE_MEMORY_STREAM + UNUSED_STREAM_SUFFIX);
FREE_MEMORY_STREAM + UNUSED_STREAM_SUFFIX,
);
graphConfig.addInputStream(FREE_MEMORY_STREAM);
graphConfig.addNode(this.keepaliveNode);
}
Expand All @@ -323,7 +368,10 @@ export abstract class TaskRunner {
/** Frees any streams being kept alive by the keepStreamAlive callback. */
protected freeKeepaliveStreams() {
this.graphRunner.addBoolToStream(
true, FREE_MEMORY_STREAM, this.latestOutputTimestamp);
true,
FREE_MEMORY_STREAM,
this.latestOutputTimestamp,
);
}

/**
Expand All @@ -336,4 +384,36 @@ export abstract class TaskRunner {
}
}

/** Converts a ReadableStreamDefaultReader to a Uint8Array. */
async function streamToUint8Array(
reader: ReadableStreamDefaultReader<Uint8Array>,
): Promise<Uint8Array> {
const chunks: Uint8Array[] = [];
let totalLength = 0;

while (true) {
const {done, value} = await reader.read();
if (done) {
break;
}
chunks.push(value);
totalLength += value.length;
}

if (chunks.length === 0) {
return new Uint8Array(0);
} else if (chunks.length === 1) {
return chunks[0];
} else {
// Merge chunks
const combined = new Uint8Array(totalLength);
let offset = 0;
for (const chunk of chunks) {
combined.set(chunk, offset);
offset += chunk.length;
}
return combined;
}
}


10 changes: 5 additions & 5 deletions mediapipe/tasks/web/core/task_runner_options.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,16 @@ export declare interface BaseOptions {
* The model path to the model asset file. Only one of `modelAssetPath` or
* `modelAssetBuffer` can be set.
*/
modelAssetPath?: string|undefined;
modelAssetPath?: string | undefined;

/**
* A buffer containing the model aaset. Only one of `modelAssetPath` or
* `modelAssetBuffer` can be set.
* A buffer or stream reader containing the model asset. Only one of
* `modelAssetPath` or `modelAssetBuffer` can be set.
*/
modelAssetBuffer?: Uint8Array|undefined;
modelAssetBuffer?: Uint8Array | ReadableStreamDefaultReader | undefined;

/** Overrides the default backend to use for the provided model. */
delegate?: 'CPU'|'GPU'|undefined;
delegate?: 'CPU' | 'GPU' | undefined;
}

/** Options to configure MediaPipe Tasks in general. */
Expand Down
Loading

0 comments on commit fb75f14

Please sign in to comment.