From 2e11444f5c4cb51aa6e793ba2603ad7f657ec1aa Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 16 Oct 2023 14:11:04 -0700 Subject: [PATCH] Introduce FixGraphBackEdges utils function. PiperOrigin-RevId: 573925628 --- mediapipe/tasks/cc/core/utils.cc | 14 ++++++++++++++ mediapipe/tasks/cc/core/utils.h | 4 ++++ .../face_landmarker/face_landmarker_graph.cc | 12 +----------- .../hand_landmarker/hand_landmarker_graph.cc | 12 +----------- mediapipe/tasks/cc/vision/pose_landmarker/BUILD | 1 + .../pose_landmarker/pose_landmarker_graph.cc | 14 ++------------ 6 files changed, 23 insertions(+), 34 deletions(-) diff --git a/mediapipe/tasks/cc/core/utils.cc b/mediapipe/tasks/cc/core/utils.cc index 168c4363c8..d6db1d69c3 100644 --- a/mediapipe/tasks/cc/core/utils.cc +++ b/mediapipe/tasks/cc/core/utils.cc @@ -32,6 +32,7 @@ namespace core { namespace { constexpr char kFinishedTag[] = "FINISHED"; constexpr char kFlowLimiterCalculatorName[] = "FlowLimiterCalculator"; +constexpr char kPreviousLoopbackCalculatorName[] = "PreviousLoopbackCalculator"; } // namespace @@ -89,6 +90,19 @@ CalculatorGraphConfig AddFlowLimiterCalculator( return config; } +void FixGraphBackEdges(::mediapipe::CalculatorGraphConfig& graph_config) { + // TODO remove when support is fixed. + // As mediapipe GraphBuilder currently doesn't support configuring + // InputStreamInfo, modifying the CalculatorGraphConfig proto directly. + for (int i = 0; i < graph_config.node_size(); ++i) { + if (graph_config.node(i).calculator() == kPreviousLoopbackCalculatorName) { + auto* info = graph_config.mutable_node(i)->add_input_stream_info(); + info->set_tag_index("LOOP"); + info->set_back_edge(true); + } + } +} + } // namespace core } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/core/utils.h b/mediapipe/tasks/cc/core/utils.h index 54d63866d9..5b56c8f575 100644 --- a/mediapipe/tasks/cc/core/utils.h +++ b/mediapipe/tasks/cc/core/utils.h @@ -84,6 +84,10 @@ ::mediapipe::CalculatorGraphConfig AddFlowLimiterCalculator( std::vector input_stream_tags, std::string finished_stream_tag, int max_in_flight = 1, int max_in_queue = 1); +// Fixs the graph config containing PreviousLoopbackCalculator where the edge +// forming a loop needs to be tagged as back edge. +void FixGraphBackEdges(::mediapipe::CalculatorGraphConfig& graph_config); + } // namespace core } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/face_landmarker/face_landmarker_graph.cc b/mediapipe/tasks/cc/vision/face_landmarker/face_landmarker_graph.cc index e563ba29a0..c9a9a19326 100644 --- a/mediapipe/tasks/cc/vision/face_landmarker/face_landmarker_graph.cc +++ b/mediapipe/tasks/cc/vision/face_landmarker/face_landmarker_graph.cc @@ -393,18 +393,8 @@ class FaceLandmarkerGraph : public core::ModelTaskGraph { graph[Output>(kFaceGeometryTag)]; } - // TODO remove when support is fixed. - // As mediapipe GraphBuilder currently doesn't support configuring - // InputStreamInfo, modifying the CalculatorGraphConfig proto directly. CalculatorGraphConfig config = graph.GetConfig(); - for (int i = 0; i < config.node_size(); ++i) { - if (config.node(i).calculator() == "PreviousLoopbackCalculator") { - auto* info = config.mutable_node(i)->add_input_stream_info(); - info->set_tag_index(kLoopTag); - info->set_back_edge(true); - break; - } - } + core::FixGraphBackEdges(config); return config; } diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc index bb0eb4833f..34f7e7a9f6 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc @@ -254,18 +254,8 @@ class HandLandmarkerGraph : public core::ModelTaskGraph { graph[Output>(kPalmDetectionsTag)]; hand_landmarker_outputs.image >> graph[Output(kImageTag)]; - // TODO remove when support is fixed. - // As mediapipe GraphBuilder currently doesn't support configuring - // InputStreamInfo, modifying the CalculatorGraphConfig proto directly. CalculatorGraphConfig config = graph.GetConfig(); - for (int i = 0; i < config.node_size(); ++i) { - if (config.node(i).calculator() == kPreviousLoopbackCalculatorName) { - auto* info = config.mutable_node(i)->add_input_stream_info(); - info->set_tag_index("LOOP"); - info->set_back_edge(true); - break; - } - } + core::FixGraphBackEdges(config); return config; } diff --git a/mediapipe/tasks/cc/vision/pose_landmarker/BUILD b/mediapipe/tasks/cc/vision/pose_landmarker/BUILD index a706b405e2..51ae92adc9 100644 --- a/mediapipe/tasks/cc/vision/pose_landmarker/BUILD +++ b/mediapipe/tasks/cc/vision/pose_landmarker/BUILD @@ -138,6 +138,7 @@ cc_library( "//mediapipe/tasks/cc/core:model_asset_bundle_resources", "//mediapipe/tasks/cc/core:model_resources_cache", "//mediapipe/tasks/cc/core:model_task_graph", + "//mediapipe/tasks/cc/core:utils", "//mediapipe/tasks/cc/metadata/utils:zip_utils", "//mediapipe/tasks/cc/vision/pose_detector:pose_detector_graph", "//mediapipe/tasks/cc/vision/pose_detector/proto:pose_detector_graph_options_cc_proto", diff --git a/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_graph.cc b/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_graph.cc index 5e33e744bf..2f5a8b99b2 100644 --- a/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_graph.cc +++ b/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_graph.cc @@ -31,6 +31,7 @@ limitations under the License. #include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h" #include "mediapipe/tasks/cc/core/model_resources_cache.h" #include "mediapipe/tasks/cc/core/model_task_graph.h" +#include "mediapipe/tasks/cc/core/utils.h" #include "mediapipe/tasks/cc/metadata/utils/zip_utils.h" #include "mediapipe/tasks/cc/vision/pose_detector/proto/pose_detector_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/pose_landmarker/proto/pose_landmarker_graph_options.pb.h" @@ -259,19 +260,8 @@ class PoseLandmarkerGraph : public core::ModelTaskGraph { graph[Output>(kSegmentationMaskTag)]; } - // TODO remove when support is fixed. - // As mediapipe GraphBuilder currently doesn't support configuring - // InputStreamInfo, modifying the CalculatorGraphConfig proto directly. CalculatorGraphConfig config = graph.GetConfig(); - for (int i = 0; i < config.node_size(); ++i) { - if (config.node(i).calculator() == "PreviousLoopbackCalculator") { - auto* info = config.mutable_node(i)->add_input_stream_info(); - info->set_tag_index(kLoopTag); - info->set_back_edge(true); - break; - } - } - + core::FixGraphBackEdges(config); return config; }