Skip to content

Commit

Permalink
concatenate stream utility function.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 568997695
  • Loading branch information
MediaPipe Team authored and copybara-github committed Sep 27, 2023
1 parent 983fda5 commit 2ecccaf
Show file tree
Hide file tree
Showing 5 changed files with 304 additions and 0 deletions.
1 change: 1 addition & 0 deletions mediapipe/calculators/core/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,7 @@ cc_library(
":concatenate_vector_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/formats:body_rig_cc_proto",
"//mediapipe/framework/formats:classification_cc_proto",
"//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/port:ret_check",
Expand Down
14 changes: 14 additions & 0 deletions mediapipe/calculators/core/concatenate_proto_list_calculator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "mediapipe/calculators/core/concatenate_vector_calculator.pb.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/body_rig.pb.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/port/canonical_errors.h"
Expand Down Expand Up @@ -128,6 +129,19 @@ class ConcatenateClassificationListCalculator
};
MEDIAPIPE_REGISTER_NODE(ConcatenateClassificationListCalculator);

class ConcatenateJointListCalculator
: public ConcatenateListsCalculator<Joint, JointList> {
protected:
int ListSize(const JointList& list) const override {
return list.joint_size();
}
const Joint GetItem(const JointList& list, int idx) const override {
return list.joint(idx);
}
Joint* AddItem(JointList& list) const override { return list.add_joint(); }
};
MEDIAPIPE_REGISTER_NODE(ConcatenateJointListCalculator);

} // namespace api2
} // namespace mediapipe

Expand Down
32 changes: 32 additions & 0 deletions mediapipe/framework/api2/stream/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,38 @@ package(default_visibility = ["//visibility:public"])

licenses(["notice"])

cc_library(
name = "concatenate",
hdrs = ["concatenate.h"],
deps = [
"//mediapipe/calculators/core:concatenate_proto_list_calculator",
"//mediapipe/calculators/core:concatenate_vector_calculator",
"//mediapipe/calculators/core:concatenate_vector_calculator_cc_proto",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:body_rig_cc_proto",
"//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/formats:tensor",
],
)

cc_test(
name = "concatenate_test",
srcs = ["concatenate_test.cc"],
deps = [
":concatenate",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/formats:body_rig_cc_proto",
"//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/formats:tensor",
"//mediapipe/framework/port:gtest",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/port:status_matchers",
],
)

cc_library(
name = "detections_to_rects",
srcs = ["detections_to_rects.cc"],
Expand Down
69 changes: 69 additions & 0 deletions mediapipe/framework/api2/stream/concatenate.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
#ifndef MEDIAPIPE_FRAMEWORK_API2_STREAM_CONCATENATE_H_
#define MEDIAPIPE_FRAMEWORK_API2_STREAM_CONCATENATE_H_

#include <vector>

#include "mediapipe/calculators/core/concatenate_vector_calculator.pb.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/formats/body_rig.pb.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/tensor.h"

namespace mediapipe::api2::builder {

namespace internal_stream_concatenate {

// Helper function that adds a node to a graph, that is capable of concatenating
// a specific type (T).
template <class T>
GenericNode& AddConcatenateVectorNode(Graph& graph) {
if constexpr (std::is_same_v<T, mediapipe::LandmarkList>) {
return graph.AddNode("ConcatenateLandmarkListCalculator");
} else if constexpr (std::is_same_v<T, mediapipe::JointList>) {
return graph.AddNode("ConcatenateJointListCalculator");
} else if constexpr (std::is_same_v<T, std::vector<Tensor>>) {
return graph.AddNode("ConcatenateTensorVectorCalculator");
} else {
static_assert(dependent_false<T>::value,
"Concatenate node is not available for the specified type.");
}
}

template <typename StreamsT,
typename PayloadT = typename StreamsT::value_type::PayloadT>
Stream<PayloadT> Concatenate(StreamsT& streams,
const bool only_emit_if_all_present,
Graph& graph) {
auto& concatenator = AddConcatenateVectorNode<PayloadT>(graph);
for (int i = 0; i < streams.size(); ++i) {
streams[i].ConnectTo(concatenator.In("")[i]);
}

auto& concatenator_opts =
concatenator
.template GetOptions<mediapipe::ConcatenateVectorCalculatorOptions>();
concatenator_opts.set_only_emit_if_all_present(only_emit_if_all_present);

return concatenator.Out("").template Cast<PayloadT>();
}

} // namespace internal_stream_concatenate

template <typename StreamsT,
typename PayloadT = typename StreamsT::value_type::PayloadT>
Stream<PayloadT> Concatenate(StreamsT& streams, Graph& graph) {
return internal_stream_concatenate::Concatenate(
streams, /*only_emit_if_all_present=*/false, graph);
}

template <typename StreamsT,
typename PayloadT = typename StreamsT::value_type::PayloadT>
Stream<PayloadT> ConcatenateIfAllPresent(StreamsT& streams, Graph& graph) {
return internal_stream_concatenate::Concatenate(
streams, /*only_emit_if_all_present=*/true, graph);
}

} // namespace mediapipe::api2::builder

#endif // MEDIAPIPE_FRAMEWORK_API2_STREAM_CONCATENATE_H_
188 changes: 188 additions & 0 deletions mediapipe/framework/api2/stream/concatenate_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
#include "mediapipe/framework/api2/stream/concatenate.h"

#include <vector>

#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/body_rig.pb.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"

namespace mediapipe::api2::builder {
namespace {

TEST(Concatenate, ConcatenateLandmarkList) {
Graph graph;
std::vector<Stream<LandmarkList>> items = {
graph.In("LMK_LIST")[0].Cast<LandmarkList>(),
graph.In("LMK_LIST")[1].Cast<LandmarkList>()};
Stream<LandmarkList> landmark_list = Concatenate(items, graph);
landmark_list.SetName("landmark_list");
EXPECT_THAT(graph.GetConfig(),
EqualsProto(ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
calculator: "ConcatenateLandmarkListCalculator"
input_stream: "__stream_0"
input_stream: "__stream_1"
output_stream: "landmark_list"
options {
[mediapipe.ConcatenateVectorCalculatorOptions.ext] {
only_emit_if_all_present: false
}
}
}
input_stream: "LMK_LIST:0:__stream_0"
input_stream: "LMK_LIST:1:__stream_1"
)pb")));

CalculatorGraph calcualtor_graph;
MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig()));
}

TEST(Concatenate, ConcatenateLandmarkList_IfAllPresent) {
Graph graph;
std::vector<Stream<LandmarkList>> items = {
graph.In("LMK_LIST")[0].Cast<LandmarkList>(),
graph.In("LMK_LIST")[1].Cast<LandmarkList>()};
Stream<LandmarkList> landmark_list = ConcatenateIfAllPresent(items, graph);
landmark_list.SetName("landmark_list");
EXPECT_THAT(graph.GetConfig(),
EqualsProto(ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
calculator: "ConcatenateLandmarkListCalculator"
input_stream: "__stream_0"
input_stream: "__stream_1"
output_stream: "landmark_list"
options {
[mediapipe.ConcatenateVectorCalculatorOptions.ext] {
only_emit_if_all_present: true
}
}
}
input_stream: "LMK_LIST:0:__stream_0"
input_stream: "LMK_LIST:1:__stream_1"
)pb")));

CalculatorGraph calcualtor_graph;
MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig()));
}

TEST(Concatenate, ConcatenateJointList) {
Graph graph;
std::vector<Stream<JointList>> items = {
graph.In("JT_LIST")[0].Cast<JointList>(),
graph.In("JT_LIST")[1].Cast<JointList>()};
Stream<JointList> joint_list = Concatenate(items, graph);
joint_list.SetName("joint_list");
EXPECT_THAT(graph.GetConfig(),
EqualsProto(ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
calculator: "ConcatenateJointListCalculator"
input_stream: "__stream_0"
input_stream: "__stream_1"
output_stream: "joint_list"
options {
[mediapipe.ConcatenateVectorCalculatorOptions.ext] {
only_emit_if_all_present: false
}
}
}
input_stream: "JT_LIST:0:__stream_0"
input_stream: "JT_LIST:1:__stream_1"
)pb")));

CalculatorGraph calcualtor_graph;
MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig()));
}

TEST(Concatenate, ConcatenateJointList_IfAllPresent) {
Graph graph;
std::vector<Stream<JointList>> items = {
graph.In("JT_LIST")[0].Cast<JointList>(),
graph.In("JT_LIST")[1].Cast<JointList>()};
Stream<JointList> joint_list = ConcatenateIfAllPresent(items, graph);
joint_list.SetName("joint_list");
EXPECT_THAT(graph.GetConfig(),
EqualsProto(ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
calculator: "ConcatenateJointListCalculator"
input_stream: "__stream_0"
input_stream: "__stream_1"
output_stream: "joint_list"
options {
[mediapipe.ConcatenateVectorCalculatorOptions.ext] {
only_emit_if_all_present: true
}
}
}
input_stream: "JT_LIST:0:__stream_0"
input_stream: "JT_LIST:1:__stream_1"
)pb")));

CalculatorGraph calcualtor_graph;
MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig()));
}

TEST(Concatenate, ConcatenateTensorVectorList) {
Graph graph;
std::vector<Stream<std::vector<Tensor>>> items = {
graph.In("VT_LIST")[0].Cast<std::vector<Tensor>>(),
graph.In("VT_LIST")[1].Cast<std::vector<Tensor>>()};
Stream<std::vector<Tensor>> tensors = Concatenate(items, graph);
tensors.SetName("joint_list");
EXPECT_THAT(graph.GetConfig(),
EqualsProto(ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
calculator: "ConcatenateTensorVectorCalculator"
input_stream: "__stream_0"
input_stream: "__stream_1"
output_stream: "joint_list"
options {
[mediapipe.ConcatenateVectorCalculatorOptions.ext] {
only_emit_if_all_present: false
}
}
}
input_stream: "VT_LIST:0:__stream_0"
input_stream: "VT_LIST:1:__stream_1"
)pb")));

CalculatorGraph calcualtor_graph;
MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig()));
}

TEST(Concatenate, ConcatenateTensorVectorList_IfAllPresent) {
Graph graph;
std::vector<Stream<std::vector<Tensor>>> items = {
graph.In("VT_LIST")[0].Cast<std::vector<Tensor>>(),
graph.In("VT_LIST")[1].Cast<std::vector<Tensor>>()};

Stream<std::vector<Tensor>> tensors = ConcatenateIfAllPresent(items, graph);
tensors.SetName("joint_list");
EXPECT_THAT(graph.GetConfig(),
EqualsProto(ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
calculator: "ConcatenateTensorVectorCalculator"
input_stream: "__stream_0"
input_stream: "__stream_1"
output_stream: "joint_list"
options {
[mediapipe.ConcatenateVectorCalculatorOptions.ext] {
only_emit_if_all_present: true
}
}
}
input_stream: "VT_LIST:0:__stream_0"
input_stream: "VT_LIST:1:__stream_1"
)pb")));

CalculatorGraph calcualtor_graph;
MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig()));
}

} // namespace
} // namespace mediapipe::api2::builder

0 comments on commit 2ecccaf

Please sign in to comment.