-
Notifications
You must be signed in to change notification settings - Fork 5.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing
5 changed files
with
345 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
84 changes: 84 additions & 0 deletions
84
mediapipe/calculators/tensor/tensor_to_joints_calculator.cc
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
// Copyright 2023 The MediaPipe Authors. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "mediapipe/calculators/tensor/tensor_to_joints_calculator.h" | ||
|
||
#include <utility> | ||
|
||
#include "mediapipe/calculators/tensor/tensor_to_joints_calculator.pb.h" | ||
#include "mediapipe/framework/calculator_framework.h" | ||
#include "mediapipe/framework/formats/body_rig.pb.h" | ||
#include "mediapipe/framework/formats/tensor.h" | ||
#include "mediapipe/framework/port/ret_check.h" | ||
|
||
namespace mediapipe { | ||
namespace api2 { | ||
namespace { | ||
|
||
// Number of values in 6D representation of rotation. | ||
constexpr int kRotation6dSize = 6; | ||
|
||
} // namespace | ||
|
||
class TensorToJointsCalculatorImpl | ||
: public mediapipe::api2::NodeImpl<TensorToJointsCalculator> { | ||
public: | ||
absl::Status Open(CalculatorContext* cc) override { | ||
const auto& options = cc->Options<TensorToJointsCalculatorOptions>(); | ||
|
||
// Get number of joints. | ||
RET_CHECK_GE(options.num_joints(), 0); | ||
num_joints_ = options.num_joints(); | ||
|
||
// Get start index. | ||
start_index_ = options.start_index(); | ||
|
||
return absl::OkStatus(); | ||
} | ||
|
||
absl::Status Process(CalculatorContext* cc) override { | ||
// Skip if Tensor is empty. | ||
if (kInTensor(cc).IsEmpty()) { | ||
return absl::OkStatus(); | ||
} | ||
|
||
// Get raw floats from the Tensor. | ||
const Tensor& tensor = kInTensor(cc).Get(); | ||
RET_CHECK_EQ(tensor.shape().num_elements(), | ||
num_joints_ * kRotation6dSize + start_index_) | ||
<< "Unexpected number of values in Tensor"; | ||
const float* raw_floats = tensor.GetCpuReadView().buffer<float>(); | ||
|
||
// Convert raw floats into Joint rotations. | ||
JointList joints; | ||
for (int joint_idx = 0; joint_idx < num_joints_; ++joint_idx) { | ||
Joint* joint = joints.add_joint(); | ||
for (int idx_6d = 0; idx_6d < kRotation6dSize; ++idx_6d) { | ||
joint->add_rotation_6d( | ||
raw_floats[start_index_ + joint_idx * kRotation6dSize + idx_6d]); | ||
} | ||
} | ||
|
||
kOutJoints(cc).Send(std::move(joints)); | ||
return absl::OkStatus(); | ||
} | ||
|
||
private: | ||
int num_joints_ = 0; | ||
int start_index_ = 0; | ||
}; | ||
MEDIAPIPE_NODE_IMPLEMENTATION(TensorToJointsCalculatorImpl); | ||
|
||
} // namespace api2 | ||
} // namespace mediapipe |
64 changes: 64 additions & 0 deletions
64
mediapipe/calculators/tensor/tensor_to_joints_calculator.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
// Copyright 2023 The MediaPipe Authors. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#ifndef MEDIAPIPE_CALCULATORS_TENSOR_TENSOR_TO_JOINTS_CALCULATOR_H_ | ||
#define MEDIAPIPE_CALCULATORS_TENSOR_TENSOR_TO_JOINTS_CALCULATOR_H_ | ||
|
||
#include <memory> | ||
|
||
#include "mediapipe/framework/api2/node.h" | ||
#include "mediapipe/framework/calculator_framework.h" | ||
#include "mediapipe/framework/formats/body_rig.pb.h" | ||
#include "mediapipe/framework/formats/tensor.h" | ||
|
||
namespace mediapipe { | ||
namespace api2 { | ||
|
||
// A calculator to convert Tensors to JointList. | ||
// | ||
// Calculator fills in only rotation of the joints leaving visibility undefined. | ||
// | ||
// Input: | ||
// TENSOR - std::vector<Tensor> with kFloat32 values | ||
// Vector of tensors to be converted to joints. Only the first tensor will | ||
// be used. Number of values is expected to be multiple of six. | ||
// | ||
// Output: | ||
// JOINTS - JointList | ||
// List of joints with rotations extracted from given tensor and undefined | ||
// visibility. | ||
// | ||
// Example: | ||
// node { | ||
// calculator: "TensorToJointsCalculator" | ||
// input_stream: "TENSOR:tensor" | ||
// output_stream: "JOINTS:joints" | ||
// options: { | ||
// [mediapipe.TensorToJointsCalculatorOptions.ext] { | ||
// num_joints: 56 | ||
// start_index: 3 | ||
// } | ||
// } | ||
// } | ||
class TensorToJointsCalculator : public NodeIntf { | ||
public: | ||
static constexpr Input<mediapipe::Tensor> kInTensor{"TENSOR"}; | ||
static constexpr Output<mediapipe::JointList> kOutJoints{"JOINTS"}; | ||
MEDIAPIPE_NODE_INTERFACE(TensorToJointsCalculator, kInTensor, kOutJoints); | ||
}; | ||
|
||
} // namespace api2 | ||
} // namespace mediapipe | ||
|
||
#endif // MEDIAPIPE_CALCULATORS_TENSOR_TENSOR_TO_JOINTS_CALCULATOR_H_ |
32 changes: 32 additions & 0 deletions
32
mediapipe/calculators/tensor/tensor_to_joints_calculator.proto
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
// Copyright 2023 The MediaPipe Authors. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
syntax = "proto2"; | ||
|
||
package mediapipe; | ||
|
||
import "mediapipe/framework/calculator.proto"; | ||
|
||
message TensorToJointsCalculatorOptions { | ||
extend CalculatorOptions { | ||
optional TensorToJointsCalculatorOptions ext = 406440177; | ||
} | ||
|
||
// Number of joints from the output of the model. Calculator will expect the | ||
// tensor to contain `6 * num_joints + start_index` values. | ||
optional int32 num_joints = 1; | ||
|
||
// Index to start reading 6 value blocks from. | ||
optional int32 start_index = 2 [default = 0]; | ||
} |
123 changes: 123 additions & 0 deletions
123
mediapipe/calculators/tensor/tensor_to_joints_calculator_test.cc
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
// Copyright 2023 The MediaPipe Authors. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include <cstdint> | ||
#include <memory> | ||
#include <string> | ||
#include <utility> | ||
#include <vector> | ||
|
||
#include "absl/strings/substitute.h" | ||
#include "mediapipe/framework/calculator_framework.h" | ||
#include "mediapipe/framework/calculator_runner.h" | ||
#include "mediapipe/framework/formats/body_rig.pb.h" | ||
#include "mediapipe/framework/formats/tensor.h" | ||
#include "mediapipe/framework/port/gmock.h" | ||
#include "mediapipe/framework/port/gtest.h" | ||
#include "mediapipe/framework/port/parse_text_proto.h" | ||
#include "mediapipe/framework/port/status_matchers.h" | ||
#include "mediapipe/framework/timestamp.h" | ||
|
||
namespace mediapipe { | ||
namespace api2 { | ||
namespace { | ||
|
||
using Node = ::mediapipe::CalculatorGraphConfig::Node; | ||
|
||
struct TensorToJointsTestCase { | ||
std::string test_name; | ||
int num_joints; | ||
int start_index; | ||
std::vector<float> raw_values; | ||
std::vector<std::vector<float>> expected_rotations; | ||
}; | ||
|
||
using TensorToJointsTest = ::testing::TestWithParam<TensorToJointsTestCase>; | ||
|
||
TEST_P(TensorToJointsTest, TensorToJointsTest) { | ||
const TensorToJointsTestCase& tc = GetParam(); | ||
|
||
// Prepare graph. | ||
mediapipe::CalculatorRunner runner(ParseTextProtoOrDie<Node>(absl::Substitute( | ||
R"( | ||
calculator: "TensorToJointsCalculator" | ||
input_stream: "TENSOR:tensor" | ||
output_stream: "JOINTS:joints" | ||
options: { | ||
[mediapipe.TensorToJointsCalculatorOptions.ext] { | ||
num_joints: $0 | ||
start_index: $1 | ||
} | ||
} | ||
)", | ||
tc.num_joints, tc.start_index))); | ||
|
||
// Prepare tensor. | ||
Tensor tensor(Tensor::ElementType::kFloat32, | ||
Tensor::Shape{1, 1, static_cast<int>(tc.raw_values.size()), 1}); | ||
float* tensor_buffer = tensor.GetCpuWriteView().buffer<float>(); | ||
ASSERT_NE(tensor_buffer, nullptr); | ||
for (int i = 0; i < tc.raw_values.size(); ++i) { | ||
tensor_buffer[i] = tc.raw_values[i]; | ||
} | ||
|
||
// Send tensor to the graph. | ||
runner.MutableInputs()->Tag("TENSOR").packets.push_back( | ||
mediapipe::MakePacket<Tensor>(std::move(tensor)).At(Timestamp(0))); | ||
|
||
// Run the graph. | ||
MP_ASSERT_OK(runner.Run()); | ||
|
||
const auto& output_packets = runner.Outputs().Tag("JOINTS").packets; | ||
EXPECT_EQ(1, output_packets.size()); | ||
|
||
const auto& joints = output_packets[0].Get<JointList>(); | ||
EXPECT_EQ(joints.joint_size(), tc.expected_rotations.size()); | ||
for (int i = 0; i < joints.joint_size(); ++i) { | ||
const Joint& joint = joints.joint(i); | ||
std::vector<float> expected_rotation_6d = tc.expected_rotations[i]; | ||
EXPECT_EQ(joint.rotation_6d_size(), expected_rotation_6d.size()) | ||
<< "Unexpected joint #" << i << " rotation"; | ||
for (int j = 0; j < joint.rotation_6d_size(); ++j) { | ||
EXPECT_EQ(joint.rotation_6d(j), expected_rotation_6d[j]) | ||
<< "Unexpected joint #" << i << " rotation"; | ||
} | ||
EXPECT_FALSE(joint.has_visibility()); | ||
} | ||
} | ||
|
||
INSTANTIATE_TEST_SUITE_P( | ||
TensorToJointsTests, TensorToJointsTest, | ||
testing::ValuesIn<TensorToJointsTestCase>({ | ||
{"Empty", 0, 3, {0, 0, 0}, {}}, | ||
|
||
{"Single", | ||
1, | ||
3, | ||
{0, 0, 0, 10, 11, 12, 13, 14, 15}, | ||
{{10, 11, 12, 13, 14, 15}}}, | ||
|
||
{"Double", | ||
2, | ||
3, | ||
{0, 0, 0, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21}, | ||
{{10, 11, 12, 13, 14, 15}, {16, 17, 18, 19, 20, 21}}}, | ||
}), | ||
[](const testing::TestParamInfo<TensorToJointsTest::ParamType>& info) { | ||
return info.param.test_name; | ||
}); | ||
|
||
} // namespace | ||
} // namespace api2 | ||
} // namespace mediapipe |