diff --git a/mediapipe/framework/api2/stream/BUILD b/mediapipe/framework/api2/stream/BUILD index 323c36713f..a33645b2d5 100644 --- a/mediapipe/framework/api2/stream/BUILD +++ b/mediapipe/framework/api2/stream/BUILD @@ -324,3 +324,34 @@ cc_test( "//mediapipe/framework/port:status_matchers", ], ) + +cc_library( + name = "tensor_to_joints", + srcs = ["tensor_to_joints.cc"], + hdrs = ["tensor_to_joints.h"], + deps = [ + "//mediapipe/calculators/tensor:tensor_to_joints_calculator", + "//mediapipe/calculators/tensor:tensor_to_joints_calculator_cc_proto", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:body_rig_cc_proto", + "//mediapipe/framework/formats:tensor", + ], +) + +cc_test( + name = "tensor_to_joints_test", + srcs = ["tensor_to_joints_test.cc"], + deps = [ + ":tensor_to_joints", + "//mediapipe/calculators/tensor:tensor_to_joints_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/formats:body_rig_cc_proto", + "//mediapipe/framework/formats:tensor", + "//mediapipe/framework/port:gtest", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/port:status_matchers", + ], +) diff --git a/mediapipe/framework/api2/stream/tensor_to_joints.cc b/mediapipe/framework/api2/stream/tensor_to_joints.cc new file mode 100644 index 0000000000..cce4001c50 --- /dev/null +++ b/mediapipe/framework/api2/stream/tensor_to_joints.cc @@ -0,0 +1,27 @@ +#include "mediapipe/framework/api2/stream/tensor_to_joints.h" + +#include "mediapipe/calculators/tensor/tensor_to_joints_calculator.h" +#include "mediapipe/calculators/tensor/tensor_to_joints_calculator.pb.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/formats/body_rig.pb.h" +#include "mediapipe/framework/formats/tensor.h" + +namespace mediapipe::api2::builder { + +namespace {} // namespace + +Stream ConvertTensorToJointsAtIndex(Stream tensor, + const int num_joints, + const int start_index, + Graph& graph) { + auto& to_joints = graph.AddNode("TensorToJointsCalculator"); + auto& to_joints_options = + to_joints.GetOptions(); + to_joints_options.set_num_joints(num_joints); + to_joints_options.set_start_index(start_index); + tensor.ConnectTo(to_joints[TensorToJointsCalculator::kInTensor]); + return to_joints[TensorToJointsCalculator::kOutJoints]; +} + +} // namespace mediapipe::api2::builder diff --git a/mediapipe/framework/api2/stream/tensor_to_joints.h b/mediapipe/framework/api2/stream/tensor_to_joints.h new file mode 100644 index 0000000000..54d8f822d9 --- /dev/null +++ b/mediapipe/framework/api2/stream/tensor_to_joints.h @@ -0,0 +1,26 @@ +#ifndef MEDIAPIPE_FRAMEWORK_API2_STREAM_TENSOR_TO_JOINTS_H_ +#define MEDIAPIPE_FRAMEWORK_API2_STREAM_TENSOR_TO_JOINTS_H_ + +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/formats/body_rig.pb.h" +#include "mediapipe/framework/formats/tensor.h" + +namespace mediapipe::api2::builder { + +// Updates @graph to convert @tensor to a JointList skipping first @start_index +// values of a @tensor. +Stream ConvertTensorToJointsAtIndex(Stream tensor, + const int num_joints, + const int start_index, + Graph& graph); + +// Updates @graph to convert @tensor to a JointList. +inline Stream<::mediapipe::JointList> ConvertTensorToJoints( + Stream tensor, const int num_joints, Graph& graph) { + return ConvertTensorToJointsAtIndex(tensor, num_joints, /*start_index=*/0, + graph); +} + +} // namespace mediapipe::api2::builder + +#endif // MEDIAPIPE_FRAMEWORK_API2_STREAM_TENSOR_TO_JOINTS_H_ diff --git a/mediapipe/framework/api2/stream/tensor_to_joints_test.cc b/mediapipe/framework/api2/stream/tensor_to_joints_test.cc new file mode 100644 index 0000000000..d76970cf7e --- /dev/null +++ b/mediapipe/framework/api2/stream/tensor_to_joints_test.cc @@ -0,0 +1,74 @@ +#include "mediapipe/framework/api2/stream/tensor_to_joints.h" + +#include + +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/body_rig.pb.h" +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status_matchers.h" + +namespace mediapipe::api2::builder { +namespace { + +TEST(ConvertTensorToJoints, ConvertTensorToJoints) { + Graph graph; + + Stream tensor = graph.In("TENSOR").Cast(); + Stream joint_list = + ConvertTensorToJoints(tensor, /*num_joints=*/56, graph); + joint_list.SetName("joints"); + + EXPECT_THAT(graph.GetConfig(), + EqualsProto(ParseTextProtoOrDie(R"pb( + node { + calculator: "TensorToJointsCalculator" + input_stream: "TENSOR:__stream_0" + output_stream: "JOINTS:joints" + options { + [mediapipe.TensorToJointsCalculatorOptions.ext] { + num_joints: 56 + start_index: 0 + } + } + } + input_stream: "TENSOR:__stream_0" + )pb"))); + + CalculatorGraph calcualtor_graph; + MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig())); +} + +TEST(ConvertTensorToJointsAtIndex, ConvertTensorToJointsAtIndex) { + Graph graph; + + Stream tensor = graph.In("TENSOR").Cast(); + Stream joint_list = ConvertTensorToJointsAtIndex( + tensor, /*num_joints=*/56, /*start_index=*/3, graph); + joint_list.SetName("joints"); + + EXPECT_THAT(graph.GetConfig(), + EqualsProto(ParseTextProtoOrDie(R"pb( + node { + calculator: "TensorToJointsCalculator" + input_stream: "TENSOR:__stream_0" + output_stream: "JOINTS:joints" + options { + [mediapipe.TensorToJointsCalculatorOptions.ext] { + num_joints: 56 + start_index: 3 + } + } + } + input_stream: "TENSOR:__stream_0" + )pb"))); + + CalculatorGraph calcualtor_graph; + MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig())); +} + +} // namespace +} // namespace mediapipe::api2::builder