From 2ecccaf076705a273cf3d00590ae168b69ca2cfb Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 27 Sep 2023 16:47:07 -0700 Subject: [PATCH] concatenate stream utility function. PiperOrigin-RevId: 568997695 --- mediapipe/calculators/core/BUILD | 1 + .../core/concatenate_proto_list_calculator.cc | 14 ++ mediapipe/framework/api2/stream/BUILD | 32 +++ mediapipe/framework/api2/stream/concatenate.h | 69 +++++++ .../framework/api2/stream/concatenate_test.cc | 188 ++++++++++++++++++ 5 files changed, 304 insertions(+) create mode 100644 mediapipe/framework/api2/stream/concatenate.h create mode 100644 mediapipe/framework/api2/stream/concatenate_test.cc diff --git a/mediapipe/calculators/core/BUILD b/mediapipe/calculators/core/BUILD index dcc7ab7243..b8f86b1334 100644 --- a/mediapipe/calculators/core/BUILD +++ b/mediapipe/calculators/core/BUILD @@ -325,6 +325,7 @@ cc_library( ":concatenate_vector_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", + "//mediapipe/framework/formats:body_rig_cc_proto", "//mediapipe/framework/formats:classification_cc_proto", "//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/port:ret_check", diff --git a/mediapipe/calculators/core/concatenate_proto_list_calculator.cc b/mediapipe/calculators/core/concatenate_proto_list_calculator.cc index 6c58e11107..a66e71464a 100644 --- a/mediapipe/calculators/core/concatenate_proto_list_calculator.cc +++ b/mediapipe/calculators/core/concatenate_proto_list_calculator.cc @@ -18,6 +18,7 @@ #include "mediapipe/calculators/core/concatenate_vector_calculator.pb.h" #include "mediapipe/framework/api2/node.h" #include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/body_rig.pb.h" #include "mediapipe/framework/formats/classification.pb.h" #include "mediapipe/framework/formats/landmark.pb.h" #include "mediapipe/framework/port/canonical_errors.h" @@ -128,6 +129,19 @@ class ConcatenateClassificationListCalculator }; MEDIAPIPE_REGISTER_NODE(ConcatenateClassificationListCalculator); +class ConcatenateJointListCalculator + : public ConcatenateListsCalculator { + protected: + int ListSize(const JointList& list) const override { + return list.joint_size(); + } + const Joint GetItem(const JointList& list, int idx) const override { + return list.joint(idx); + } + Joint* AddItem(JointList& list) const override { return list.add_joint(); } +}; +MEDIAPIPE_REGISTER_NODE(ConcatenateJointListCalculator); + } // namespace api2 } // namespace mediapipe diff --git a/mediapipe/framework/api2/stream/BUILD b/mediapipe/framework/api2/stream/BUILD index 1391165fae..5753374551 100644 --- a/mediapipe/framework/api2/stream/BUILD +++ b/mediapipe/framework/api2/stream/BUILD @@ -2,6 +2,38 @@ package(default_visibility = ["//visibility:public"]) licenses(["notice"]) +cc_library( + name = "concatenate", + hdrs = ["concatenate.h"], + deps = [ + "//mediapipe/calculators/core:concatenate_proto_list_calculator", + "//mediapipe/calculators/core:concatenate_vector_calculator", + "//mediapipe/calculators/core:concatenate_vector_calculator_cc_proto", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:body_rig_cc_proto", + "//mediapipe/framework/formats:landmark_cc_proto", + "//mediapipe/framework/formats:tensor", + ], +) + +cc_test( + name = "concatenate_test", + srcs = ["concatenate_test.cc"], + deps = [ + ":concatenate", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/formats:body_rig_cc_proto", + "//mediapipe/framework/formats:landmark_cc_proto", + "//mediapipe/framework/formats:tensor", + "//mediapipe/framework/port:gtest", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/port:status_matchers", + ], +) + cc_library( name = "detections_to_rects", srcs = ["detections_to_rects.cc"], diff --git a/mediapipe/framework/api2/stream/concatenate.h b/mediapipe/framework/api2/stream/concatenate.h new file mode 100644 index 0000000000..573d436a8d --- /dev/null +++ b/mediapipe/framework/api2/stream/concatenate.h @@ -0,0 +1,69 @@ +#ifndef MEDIAPIPE_FRAMEWORK_API2_STREAM_CONCATENATE_H_ +#define MEDIAPIPE_FRAMEWORK_API2_STREAM_CONCATENATE_H_ + +#include + +#include "mediapipe/calculators/core/concatenate_vector_calculator.pb.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/formats/body_rig.pb.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/formats/tensor.h" + +namespace mediapipe::api2::builder { + +namespace internal_stream_concatenate { + +// Helper function that adds a node to a graph, that is capable of concatenating +// a specific type (T). +template +GenericNode& AddConcatenateVectorNode(Graph& graph) { + if constexpr (std::is_same_v) { + return graph.AddNode("ConcatenateLandmarkListCalculator"); + } else if constexpr (std::is_same_v) { + return graph.AddNode("ConcatenateJointListCalculator"); + } else if constexpr (std::is_same_v>) { + return graph.AddNode("ConcatenateTensorVectorCalculator"); + } else { + static_assert(dependent_false::value, + "Concatenate node is not available for the specified type."); + } +} + +template +Stream Concatenate(StreamsT& streams, + const bool only_emit_if_all_present, + Graph& graph) { + auto& concatenator = AddConcatenateVectorNode(graph); + for (int i = 0; i < streams.size(); ++i) { + streams[i].ConnectTo(concatenator.In("")[i]); + } + + auto& concatenator_opts = + concatenator + .template GetOptions(); + concatenator_opts.set_only_emit_if_all_present(only_emit_if_all_present); + + return concatenator.Out("").template Cast(); +} + +} // namespace internal_stream_concatenate + +template +Stream Concatenate(StreamsT& streams, Graph& graph) { + return internal_stream_concatenate::Concatenate( + streams, /*only_emit_if_all_present=*/false, graph); +} + +template +Stream ConcatenateIfAllPresent(StreamsT& streams, Graph& graph) { + return internal_stream_concatenate::Concatenate( + streams, /*only_emit_if_all_present=*/true, graph); +} + +} // namespace mediapipe::api2::builder + +#endif // MEDIAPIPE_FRAMEWORK_API2_STREAM_CONCATENATE_H_ diff --git a/mediapipe/framework/api2/stream/concatenate_test.cc b/mediapipe/framework/api2/stream/concatenate_test.cc new file mode 100644 index 0000000000..e7785f780d --- /dev/null +++ b/mediapipe/framework/api2/stream/concatenate_test.cc @@ -0,0 +1,188 @@ +#include "mediapipe/framework/api2/stream/concatenate.h" + +#include + +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/body_rig.pb.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status_matchers.h" + +namespace mediapipe::api2::builder { +namespace { + +TEST(Concatenate, ConcatenateLandmarkList) { + Graph graph; + std::vector> items = { + graph.In("LMK_LIST")[0].Cast(), + graph.In("LMK_LIST")[1].Cast()}; + Stream landmark_list = Concatenate(items, graph); + landmark_list.SetName("landmark_list"); + EXPECT_THAT(graph.GetConfig(), + EqualsProto(ParseTextProtoOrDie(R"pb( + node { + calculator: "ConcatenateLandmarkListCalculator" + input_stream: "__stream_0" + input_stream: "__stream_1" + output_stream: "landmark_list" + options { + [mediapipe.ConcatenateVectorCalculatorOptions.ext] { + only_emit_if_all_present: false + } + } + } + input_stream: "LMK_LIST:0:__stream_0" + input_stream: "LMK_LIST:1:__stream_1" + )pb"))); + + CalculatorGraph calcualtor_graph; + MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig())); +} + +TEST(Concatenate, ConcatenateLandmarkList_IfAllPresent) { + Graph graph; + std::vector> items = { + graph.In("LMK_LIST")[0].Cast(), + graph.In("LMK_LIST")[1].Cast()}; + Stream landmark_list = ConcatenateIfAllPresent(items, graph); + landmark_list.SetName("landmark_list"); + EXPECT_THAT(graph.GetConfig(), + EqualsProto(ParseTextProtoOrDie(R"pb( + node { + calculator: "ConcatenateLandmarkListCalculator" + input_stream: "__stream_0" + input_stream: "__stream_1" + output_stream: "landmark_list" + options { + [mediapipe.ConcatenateVectorCalculatorOptions.ext] { + only_emit_if_all_present: true + } + } + } + input_stream: "LMK_LIST:0:__stream_0" + input_stream: "LMK_LIST:1:__stream_1" + )pb"))); + + CalculatorGraph calcualtor_graph; + MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig())); +} + +TEST(Concatenate, ConcatenateJointList) { + Graph graph; + std::vector> items = { + graph.In("JT_LIST")[0].Cast(), + graph.In("JT_LIST")[1].Cast()}; + Stream joint_list = Concatenate(items, graph); + joint_list.SetName("joint_list"); + EXPECT_THAT(graph.GetConfig(), + EqualsProto(ParseTextProtoOrDie(R"pb( + node { + calculator: "ConcatenateJointListCalculator" + input_stream: "__stream_0" + input_stream: "__stream_1" + output_stream: "joint_list" + options { + [mediapipe.ConcatenateVectorCalculatorOptions.ext] { + only_emit_if_all_present: false + } + } + } + input_stream: "JT_LIST:0:__stream_0" + input_stream: "JT_LIST:1:__stream_1" + )pb"))); + + CalculatorGraph calcualtor_graph; + MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig())); +} + +TEST(Concatenate, ConcatenateJointList_IfAllPresent) { + Graph graph; + std::vector> items = { + graph.In("JT_LIST")[0].Cast(), + graph.In("JT_LIST")[1].Cast()}; + Stream joint_list = ConcatenateIfAllPresent(items, graph); + joint_list.SetName("joint_list"); + EXPECT_THAT(graph.GetConfig(), + EqualsProto(ParseTextProtoOrDie(R"pb( + node { + calculator: "ConcatenateJointListCalculator" + input_stream: "__stream_0" + input_stream: "__stream_1" + output_stream: "joint_list" + options { + [mediapipe.ConcatenateVectorCalculatorOptions.ext] { + only_emit_if_all_present: true + } + } + } + input_stream: "JT_LIST:0:__stream_0" + input_stream: "JT_LIST:1:__stream_1" + )pb"))); + + CalculatorGraph calcualtor_graph; + MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig())); +} + +TEST(Concatenate, ConcatenateTensorVectorList) { + Graph graph; + std::vector>> items = { + graph.In("VT_LIST")[0].Cast>(), + graph.In("VT_LIST")[1].Cast>()}; + Stream> tensors = Concatenate(items, graph); + tensors.SetName("joint_list"); + EXPECT_THAT(graph.GetConfig(), + EqualsProto(ParseTextProtoOrDie(R"pb( + node { + calculator: "ConcatenateTensorVectorCalculator" + input_stream: "__stream_0" + input_stream: "__stream_1" + output_stream: "joint_list" + options { + [mediapipe.ConcatenateVectorCalculatorOptions.ext] { + only_emit_if_all_present: false + } + } + } + input_stream: "VT_LIST:0:__stream_0" + input_stream: "VT_LIST:1:__stream_1" + )pb"))); + + CalculatorGraph calcualtor_graph; + MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig())); +} + +TEST(Concatenate, ConcatenateTensorVectorList_IfAllPresent) { + Graph graph; + std::vector>> items = { + graph.In("VT_LIST")[0].Cast>(), + graph.In("VT_LIST")[1].Cast>()}; + + Stream> tensors = ConcatenateIfAllPresent(items, graph); + tensors.SetName("joint_list"); + EXPECT_THAT(graph.GetConfig(), + EqualsProto(ParseTextProtoOrDie(R"pb( + node { + calculator: "ConcatenateTensorVectorCalculator" + input_stream: "__stream_0" + input_stream: "__stream_1" + output_stream: "joint_list" + options { + [mediapipe.ConcatenateVectorCalculatorOptions.ext] { + only_emit_if_all_present: true + } + } + } + input_stream: "VT_LIST:0:__stream_0" + input_stream: "VT_LIST:1:__stream_1" + )pb"))); + + CalculatorGraph calcualtor_graph; + MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig())); +} + +} // namespace +} // namespace mediapipe::api2::builder