Skip to content

Commit

Permalink
Introduce CombineJointsCalculator
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 570739088
  • Loading branch information
MediaPipe Team authored and copybara-github committed Oct 4, 2023
1 parent 7f1c170 commit c81624d
Show file tree
Hide file tree
Showing 5 changed files with 405 additions and 0 deletions.
42 changes: 42 additions & 0 deletions mediapipe/calculators/util/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1617,6 +1617,48 @@ cc_test(
],
)

cc_library(
name = "combine_joints_calculator",
srcs = ["combine_joints_calculator.cc"],
hdrs = ["combine_joints_calculator.h"],
deps = [
":combine_joints_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/formats:body_rig_cc_proto",
"//mediapipe/framework/port:ret_check",
],
alwayslink = 1,
)

mediapipe_proto_library(
name = "combine_joints_calculator_proto",
srcs = ["combine_joints_calculator.proto"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
"//mediapipe/framework/formats:body_rig_proto",
],
)

cc_test(
name = "combine_joints_calculator_test",
srcs = ["combine_joints_calculator_test.cc"],
deps = [
":combine_joints_calculator",
":combine_joints_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework:packet",
"//mediapipe/framework/formats:body_rig_cc_proto",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/port:status_matchers",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
],
)

cc_library(
name = "pass_through_or_empty_detection_vector_calculator",
srcs = ["pass_through_or_empty_detection_vector_calculator.cc"],
Expand Down
79 changes: 79 additions & 0 deletions mediapipe/calculators/util/combine_joints_calculator.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "mediapipe/calculators/util/combine_joints_calculator.h"

#include <utility>

#include "mediapipe/calculators/util/combine_joints_calculator.pb.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/body_rig.pb.h"
#include "mediapipe/framework/port/ret_check.h"

namespace mediapipe {
namespace api2 {

namespace {} // namespace

class CombineJointsCalculatorImpl : public NodeImpl<CombineJointsCalculator> {
public:
absl::Status Open(CalculatorContext* cc) override {
options_ = cc->Options<CombineJointsCalculatorOptions>();
RET_CHECK_GE(options_.num_joints(), 0);
RET_CHECK_GT(kInJoints(cc).Count(), 0);
RET_CHECK_EQ(kInJoints(cc).Count(), options_.joints_mapping_size());
RET_CHECK(options_.has_default_joint());
for (const auto& mapping : options_.joints_mapping()) {
for (int idx : mapping.idx()) {
RET_CHECK_GE(idx, 0);
RET_CHECK_LT(idx, options_.num_joints());
}
}
return absl::OkStatus();
}

absl::Status Process(CalculatorContext* cc) override {
// Initialize output joints with default values.
JointList out_joints;
for (int i = 0; i < options_.num_joints(); ++i) {
*out_joints.add_joint() = options_.default_joint();
}

// Override default joints with provided joints.
for (int i = 0; i < kInJoints(cc).Count(); ++i) {
// Skip empty joint streams.
if (kInJoints(cc)[i].IsEmpty()) {
continue;
}

const JointList& in_joints = kInJoints(cc)[i].Get();
const auto& mapping = options_.joints_mapping(i);
RET_CHECK_EQ(in_joints.joint_size(), mapping.idx_size());
for (int j = 0; j < in_joints.joint_size(); ++j) {
*out_joints.mutable_joint(mapping.idx(j)) = in_joints.joint(j);
}
}

kOutJoints(cc).Send(std::move(out_joints));
return absl::OkStatus();
}

private:
CombineJointsCalculatorOptions options_;
};
MEDIAPIPE_NODE_IMPLEMENTATION(CombineJointsCalculatorImpl);

} // namespace api2
} // namespace mediapipe
64 changes: 64 additions & 0 deletions mediapipe/calculators/util/combine_joints_calculator.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef MEDIAPIPE_CALCULATORS_UTIL_COMBINE_JOINTS_CALCULATOR_H_
#define MEDIAPIPE_CALCULATORS_UTIL_COMBINE_JOINTS_CALCULATOR_H_

#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/body_rig.pb.h"

namespace mediapipe {
namespace api2 {

// A calculator to combine several joint sets into one.
//
// Input:
// JOINTS - Multiple JointList
// Joint sets to combine into one. Subsets are applied in provided order and
// overwrite each other.
//
// Output:
// JOINTS - JointList
// Combined joints.
//
// Example:
// node {
// calculator: "CombineJointsCalculator"
// input_stream: "JOINTS:0:joints_0"
// input_stream: "JOINTS:1:joints_1"
// output_stream: "JOINTS:combined_joints"
// options: {
// [mediapipe.CombineJointsCalculatorOptions.ext] {
// num_joints: 63
// joints_mapping: { idx: [0, 1, 2] }
// joints_mapping: { idx: [2, 3] }
// default_joint: {
// rotation_6d: [1, 0, 0, 1, 0, 0]
// visibility: 1.0
// }
// }
// }
// }
class CombineJointsCalculator : public NodeIntf {
public:
static constexpr Input<mediapipe::JointList>::Multiple kInJoints{"JOINTS"};
static constexpr Output<mediapipe::JointList> kOutJoints{"JOINTS"};
MEDIAPIPE_NODE_INTERFACE(CombineJointsCalculator, kInJoints, kOutJoints);
};

} // namespace api2
} // namespace mediapipe

#endif // MEDIAPIPE_CALCULATORS_UTIL_COMBINE_JOINTS_CALCULATOR_H_
46 changes: 46 additions & 0 deletions mediapipe/calculators/util/combine_joints_calculator.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

syntax = "proto2";

package mediapipe;

import "mediapipe/framework/calculator.proto";
import "mediapipe/framework/formats/body_rig.proto";

message CombineJointsCalculatorOptions {
extend CalculatorOptions {
optional CombineJointsCalculatorOptions ext = 406440185;
}

// Mapping from joint set to the resulting set.
message JointsMapping {
// Indexes of provided joints in the resulting joint set.
// All indexes must be within the [0, num_joints - 1] range.
repeated int32 idx = 1 [packed = true];
}

// Number of joints in the resulting set.
optional int32 num_joints = 1;

// Mapping from joint sets to the resulting set.
// Number of mappings must be equal to number of provided joint sets. Number
// of indexes in each mapping must be equal to number of joints in
// corresponding joint set. Mappings are applied in the provided order and can
// overwrite each other.
repeated JointsMapping joints_mapping = 2;

// Default joint to initialize joints in the resulting set.
optional Joint default_joint = 3;
}
Loading

0 comments on commit c81624d

Please sign in to comment.