Skip to content

Commit

Permalink
Extend verifyGraph to be compatible with proto3.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 591047275
  • Loading branch information
MediaPipe Team authored and copybara-github committed Dec 14, 2023
1 parent df7fead commit 746d775
Showing 1 changed file with 22 additions and 5 deletions.
27 changes: 22 additions & 5 deletions mediapipe/tasks/web/core/task_runner_test_utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -71,14 +71,23 @@ export interface MediapipeTasksFake {
/** An map of field paths to values */
export type FieldPathToValue = [string[] | string, unknown];

type JsonObject = Record<string, unknown>;

type Deserializer = (binaryProto: string | Uint8Array) => JsonObject;

/**
* Verifies that the graph has been initialized and that it contains the
* provided options.
*
* @param deserializer - the function to convert a binary proto to a JsonObject.
* For example, the deserializer of HolisticLandmarkerOptions's binary proto is
* HolisticLandmarkerOptions.deserializeBinary(binaryProto).toObject().
*/
export function verifyGraph(
tasksFake: MediapipeTasksFake,
expectedCalculatorOptions?: FieldPathToValue,
expectedBaseOptions?: FieldPathToValue,
deserializer?: Deserializer,
): void {
expect(tasksFake.graph).toBeDefined();
// Our graphs should have at least one node in them for processing, and
Expand All @@ -89,22 +98,30 @@ export function verifyGraph(
expect(node).toEqual(
jasmine.objectContaining({calculator: tasksFake.calculatorName}));

let proto;
if (deserializer) {
const binaryProto =
tasksFake.graph!.getNodeList()[0].getNodeOptionsList()[0].getValue();
proto = deserializer(binaryProto);
} else {
proto = (node.options as {ext: unknown}).ext;
}

if (expectedBaseOptions) {
const [fieldPath, value] = expectedBaseOptions;
let proto = (node.options as {ext: {baseOptions: unknown}}).ext.baseOptions;
let baseOptions = (proto as {baseOptions: unknown}).baseOptions;
for (const fieldName of (
Array.isArray(fieldPath) ? fieldPath : [fieldPath])) {
proto = ((proto ?? {}) as Record<string, unknown>)[fieldName];
baseOptions = ((baseOptions ?? {}) as JsonObject)[fieldName];
}
expect(proto).toEqual(value);
expect(baseOptions).toEqual(value);
}

if (expectedCalculatorOptions) {
const [fieldPath, value] = expectedCalculatorOptions;
let proto = (node.options as {ext: unknown}).ext;
for (const fieldName of (
Array.isArray(fieldPath) ? fieldPath : [fieldPath])) {
proto = ((proto ?? {}) as Record<string, unknown>)[fieldName];
proto = ((proto ?? {}) as JsonObject)[fieldName];
}
expect(proto).toEqual(value);
}
Expand Down

0 comments on commit 746d775

Please sign in to comment.