Skip to content

Commit

Permalink
Introduce FixGraphBackEdges utils function.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 573925628
  • Loading branch information
MediaPipe Team authored and copybara-github committed Oct 16, 2023
1 parent a1e1b5d commit 2e11444
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 34 deletions.
14 changes: 14 additions & 0 deletions mediapipe/tasks/cc/core/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ namespace core {
namespace {
constexpr char kFinishedTag[] = "FINISHED";
constexpr char kFlowLimiterCalculatorName[] = "FlowLimiterCalculator";
constexpr char kPreviousLoopbackCalculatorName[] = "PreviousLoopbackCalculator";

} // namespace

Expand Down Expand Up @@ -89,6 +90,19 @@ CalculatorGraphConfig AddFlowLimiterCalculator(
return config;
}

void FixGraphBackEdges(::mediapipe::CalculatorGraphConfig& graph_config) {
// TODO remove when support is fixed.
// As mediapipe GraphBuilder currently doesn't support configuring
// InputStreamInfo, modifying the CalculatorGraphConfig proto directly.
for (int i = 0; i < graph_config.node_size(); ++i) {
if (graph_config.node(i).calculator() == kPreviousLoopbackCalculatorName) {
auto* info = graph_config.mutable_node(i)->add_input_stream_info();
info->set_tag_index("LOOP");
info->set_back_edge(true);
}
}
}

} // namespace core
} // namespace tasks
} // namespace mediapipe
4 changes: 4 additions & 0 deletions mediapipe/tasks/cc/core/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@ ::mediapipe::CalculatorGraphConfig AddFlowLimiterCalculator(
std::vector<std::string> input_stream_tags, std::string finished_stream_tag,
int max_in_flight = 1, int max_in_queue = 1);

// Fixs the graph config containing PreviousLoopbackCalculator where the edge
// forming a loop needs to be tagged as back edge.
void FixGraphBackEdges(::mediapipe::CalculatorGraphConfig& graph_config);

} // namespace core
} // namespace tasks
} // namespace mediapipe
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -393,18 +393,8 @@ class FaceLandmarkerGraph : public core::ModelTaskGraph {
graph[Output<std::vector<FaceGeometry>>(kFaceGeometryTag)];
}

// TODO remove when support is fixed.
// As mediapipe GraphBuilder currently doesn't support configuring
// InputStreamInfo, modifying the CalculatorGraphConfig proto directly.
CalculatorGraphConfig config = graph.GetConfig();
for (int i = 0; i < config.node_size(); ++i) {
if (config.node(i).calculator() == "PreviousLoopbackCalculator") {
auto* info = config.mutable_node(i)->add_input_stream_info();
info->set_tag_index(kLoopTag);
info->set_back_edge(true);
break;
}
}
core::FixGraphBackEdges(config);
return config;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -254,18 +254,8 @@ class HandLandmarkerGraph : public core::ModelTaskGraph {
graph[Output<std::vector<Detection>>(kPalmDetectionsTag)];
hand_landmarker_outputs.image >> graph[Output<Image>(kImageTag)];

// TODO remove when support is fixed.
// As mediapipe GraphBuilder currently doesn't support configuring
// InputStreamInfo, modifying the CalculatorGraphConfig proto directly.
CalculatorGraphConfig config = graph.GetConfig();
for (int i = 0; i < config.node_size(); ++i) {
if (config.node(i).calculator() == kPreviousLoopbackCalculatorName) {
auto* info = config.mutable_node(i)->add_input_stream_info();
info->set_tag_index("LOOP");
info->set_back_edge(true);
break;
}
}
core::FixGraphBackEdges(config);
return config;
}

Expand Down
1 change: 1 addition & 0 deletions mediapipe/tasks/cc/vision/pose_landmarker/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ cc_library(
"//mediapipe/tasks/cc/core:model_asset_bundle_resources",
"//mediapipe/tasks/cc/core:model_resources_cache",
"//mediapipe/tasks/cc/core:model_task_graph",
"//mediapipe/tasks/cc/core:utils",
"//mediapipe/tasks/cc/metadata/utils:zip_utils",
"//mediapipe/tasks/cc/vision/pose_detector:pose_detector_graph",
"//mediapipe/tasks/cc/vision/pose_detector/proto:pose_detector_graph_options_cc_proto",
Expand Down
14 changes: 2 additions & 12 deletions mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ limitations under the License.
#include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h"
#include "mediapipe/tasks/cc/core/model_resources_cache.h"
#include "mediapipe/tasks/cc/core/model_task_graph.h"
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/metadata/utils/zip_utils.h"
#include "mediapipe/tasks/cc/vision/pose_detector/proto/pose_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/pose_landmarker/proto/pose_landmarker_graph_options.pb.h"
Expand Down Expand Up @@ -259,19 +260,8 @@ class PoseLandmarkerGraph : public core::ModelTaskGraph {
graph[Output<std::vector<Image>>(kSegmentationMaskTag)];
}

// TODO remove when support is fixed.
// As mediapipe GraphBuilder currently doesn't support configuring
// InputStreamInfo, modifying the CalculatorGraphConfig proto directly.
CalculatorGraphConfig config = graph.GetConfig();
for (int i = 0; i < config.node_size(); ++i) {
if (config.node(i).calculator() == "PreviousLoopbackCalculator") {
auto* info = config.mutable_node(i)->add_input_stream_info();
info->set_tag_index(kLoopTag);
info->set_back_edge(true);
break;
}
}

core::FixGraphBackEdges(config);
return config;
}

Expand Down

0 comments on commit 2e11444

Please sign in to comment.