Skip to content

Commit

Permalink
tensor_to_joints stream utility function.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 569043195
  • Loading branch information
MediaPipe Team authored and copybara-github committed Sep 28, 2023
1 parent 0ae9ff6 commit 66a2794
Show file tree
Hide file tree
Showing 4 changed files with 158 additions and 0 deletions.
31 changes: 31 additions & 0 deletions mediapipe/framework/api2/stream/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -324,3 +324,34 @@ cc_test(
"//mediapipe/framework/port:status_matchers",
],
)

cc_library(
name = "tensor_to_joints",
srcs = ["tensor_to_joints.cc"],
hdrs = ["tensor_to_joints.h"],
deps = [
"//mediapipe/calculators/tensor:tensor_to_joints_calculator",
"//mediapipe/calculators/tensor:tensor_to_joints_calculator_cc_proto",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:body_rig_cc_proto",
"//mediapipe/framework/formats:tensor",
],
)

cc_test(
name = "tensor_to_joints_test",
srcs = ["tensor_to_joints_test.cc"],
deps = [
":tensor_to_joints",
"//mediapipe/calculators/tensor:tensor_to_joints_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/formats:body_rig_cc_proto",
"//mediapipe/framework/formats:tensor",
"//mediapipe/framework/port:gtest",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/port:status_matchers",
],
)
27 changes: 27 additions & 0 deletions mediapipe/framework/api2/stream/tensor_to_joints.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#include "mediapipe/framework/api2/stream/tensor_to_joints.h"

#include "mediapipe/calculators/tensor/tensor_to_joints_calculator.h"
#include "mediapipe/calculators/tensor/tensor_to_joints_calculator.pb.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/formats/body_rig.pb.h"
#include "mediapipe/framework/formats/tensor.h"

namespace mediapipe::api2::builder {

namespace {} // namespace

Stream<JointList> ConvertTensorToJointsAtIndex(Stream<Tensor> tensor,
const int num_joints,
const int start_index,
Graph& graph) {
auto& to_joints = graph.AddNode("TensorToJointsCalculator");
auto& to_joints_options =
to_joints.GetOptions<TensorToJointsCalculatorOptions>();
to_joints_options.set_num_joints(num_joints);
to_joints_options.set_start_index(start_index);
tensor.ConnectTo(to_joints[TensorToJointsCalculator::kInTensor]);
return to_joints[TensorToJointsCalculator::kOutJoints];
}

} // namespace mediapipe::api2::builder
26 changes: 26 additions & 0 deletions mediapipe/framework/api2/stream/tensor_to_joints.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#ifndef MEDIAPIPE_FRAMEWORK_API2_STREAM_TENSOR_TO_JOINTS_H_
#define MEDIAPIPE_FRAMEWORK_API2_STREAM_TENSOR_TO_JOINTS_H_

#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/formats/body_rig.pb.h"
#include "mediapipe/framework/formats/tensor.h"

namespace mediapipe::api2::builder {

// Updates @graph to convert @tensor to a JointList skipping first @start_index
// values of a @tensor.
Stream<mediapipe::JointList> ConvertTensorToJointsAtIndex(Stream<Tensor> tensor,
const int num_joints,
const int start_index,
Graph& graph);

// Updates @graph to convert @tensor to a JointList.
inline Stream<::mediapipe::JointList> ConvertTensorToJoints(
Stream<Tensor> tensor, const int num_joints, Graph& graph) {
return ConvertTensorToJointsAtIndex(tensor, num_joints, /*start_index=*/0,
graph);
}

} // namespace mediapipe::api2::builder

#endif // MEDIAPIPE_FRAMEWORK_API2_STREAM_TENSOR_TO_JOINTS_H_
74 changes: 74 additions & 0 deletions mediapipe/framework/api2/stream/tensor_to_joints_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
#include "mediapipe/framework/api2/stream/tensor_to_joints.h"

#include <vector>

#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/body_rig.pb.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"

namespace mediapipe::api2::builder {
namespace {

TEST(ConvertTensorToJoints, ConvertTensorToJoints) {
Graph graph;

Stream<Tensor> tensor = graph.In("TENSOR").Cast<Tensor>();
Stream<JointList> joint_list =
ConvertTensorToJoints(tensor, /*num_joints=*/56, graph);
joint_list.SetName("joints");

EXPECT_THAT(graph.GetConfig(),
EqualsProto(ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
calculator: "TensorToJointsCalculator"
input_stream: "TENSOR:__stream_0"
output_stream: "JOINTS:joints"
options {
[mediapipe.TensorToJointsCalculatorOptions.ext] {
num_joints: 56
start_index: 0
}
}
}
input_stream: "TENSOR:__stream_0"
)pb")));

CalculatorGraph calcualtor_graph;
MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig()));
}

TEST(ConvertTensorToJointsAtIndex, ConvertTensorToJointsAtIndex) {
Graph graph;

Stream<Tensor> tensor = graph.In("TENSOR").Cast<Tensor>();
Stream<JointList> joint_list = ConvertTensorToJointsAtIndex(
tensor, /*num_joints=*/56, /*start_index=*/3, graph);
joint_list.SetName("joints");

EXPECT_THAT(graph.GetConfig(),
EqualsProto(ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
calculator: "TensorToJointsCalculator"
input_stream: "TENSOR:__stream_0"
output_stream: "JOINTS:joints"
options {
[mediapipe.TensorToJointsCalculatorOptions.ext] {
num_joints: 56
start_index: 3
}
}
}
input_stream: "TENSOR:__stream_0"
)pb")));

CalculatorGraph calcualtor_graph;
MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig()));
}

} // namespace
} // namespace mediapipe::api2::builder

0 comments on commit 66a2794

Please sign in to comment.