Skip to content

Commit

Permalink
Switch inference tests to ResourceProviderCalculator & update builder…
Browse files Browse the repository at this point in the history
… to refer MODEL_RESOURCE.

PiperOrigin-RevId: 675740211
  • Loading branch information
MediaPipe Team authored and copybara-github committed Sep 17, 2024
1 parent f621694 commit 313c72d
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 31 deletions.
18 changes: 6 additions & 12 deletions mediapipe/calculators/tensor/inference_calculator_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,24 +67,18 @@ constexpr char kGraphWithModelAsInputSidePacket[] = R"(
input_stream: "tensor_in"
node {
calculator: "ConstantSidePacketCalculator"
output_side_packet: "PACKET:model_path"
options: {
[mediapipe.ConstantSidePacketCalculatorOptions.ext]: {
packet { string_value: "mediapipe/calculators/tensor/testdata/add.bin" }
calculator: "ResourceProviderCalculator"
output_side_packet: "RESOURCE:model_resource"
node_options {
[type.googleapis.com/mediapipe.ResourceProviderCalculatorOptions]: {
resource_id: "mediapipe/calculators/tensor/testdata/add.bin"
}
}
}
node {
calculator: "LocalFileContentsCalculator"
input_side_packet: "FILE_PATH:model_path"
output_side_packet: "CONTENTS:model_blob"
}
node {
calculator: "TfLiteModelCalculator"
input_side_packet: "MODEL_BLOB:model_blob"
input_side_packet: "MODEL_RESOURCE:model_resource"
output_side_packet: "MODEL:model"
}
Expand Down
3 changes: 1 addition & 2 deletions mediapipe/calculators/tflite/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -503,8 +503,7 @@ cc_test(
":tflite_inference_calculator",
":tflite_inference_calculator_cc_proto",
":tflite_model_calculator",
"//mediapipe/calculators/core:constant_side_packet_calculator",
"//mediapipe/calculators/util:local_file_contents_calculator",
"//mediapipe/calculators/util:resource_provider_calculator",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework/deps:file_path",
Expand Down
18 changes: 6 additions & 12 deletions mediapipe/calculators/tflite/tflite_inference_calculator_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,24 +54,18 @@ TEST(TfLiteInferenceCalculatorTest, SmokeTest_ModelAsInputSidePacket) {
input_stream: "tensor_in"
node {
calculator: "ConstantSidePacketCalculator"
output_side_packet: "PACKET:model_path"
options: {
[mediapipe.ConstantSidePacketCalculatorOptions.ext]: {
packet { string_value: "mediapipe/calculators/tflite/testdata/add.bin" }
calculator: "ResourceProviderCalculator"
output_side_packet: "RESOURCE:model_resource"
node_options {
[type.googleapis.com/mediapipe.ResourceProviderCalculatorOptions]: {
resource_id: "mediapipe/calculators/tflite/testdata/add.bin"
}
}
}
node {
calculator: "LocalFileContentsCalculator"
input_side_packet: "FILE_PATH:model_path"
output_side_packet: "CONTENTS:model_blob"
}
node {
calculator: "TfLiteModelCalculator"
input_side_packet: "MODEL_BLOB:model_blob"
input_side_packet: "MODEL_RESOURCE:model_resource"
output_side_packet: "MODEL:model"
}
Expand Down
10 changes: 5 additions & 5 deletions mediapipe/framework/api2/builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -307,18 +307,18 @@ using SideSource = SourceImpl<true, T>;
// parts utility/convenience functions or classes.
//
// For example:
// SidePacket<TfLiteModelPtr> GetModel(SidePacket<std::string> model_blob,
// SidePacket<TfLiteModelPtr> GetModel(SidePacket<Resource> model_resource,
// Graph& graph) {
// auto& model_node = graph.AddNode("TfLiteModelCalculator");
// model_blob >> model_node.SideIn("MODEL_BLOB");
// model_resource >> model_node.SideIn("MODEL_RESOURCE");
// return model_node.SideOut("MODEL").Cast<TfLiteModelPtr>();
// }
//
// Where graph can use it as:
// Graph graph;
// SidePacket<std::string> model_blob =
// graph.SideIn("MODEL_BLOB").Cast<std::string>();
// SidePacket<TfLiteModelPtr> model = GetModel(model_blob, graph);
// SidePacket<Resource> model_resource =
// graph.SideIn("MODEL_RESOURCE").Cast<Resource>();
// SidePacket<TfLiteModelPtr> model = GetModel(model_resource, graph);
template <typename T>
using SidePacket = SideSource<T>;

Expand Down

0 comments on commit 313c72d

Please sign in to comment.