diff --git a/mediapipe/tasks/cc/vision/pose_landmarker/BUILD b/mediapipe/tasks/cc/vision/pose_landmarker/BUILD index f9bdb56131..823429ed54 100644 --- a/mediapipe/tasks/cc/vision/pose_landmarker/BUILD +++ b/mediapipe/tasks/cc/vision/pose_landmarker/BUILD @@ -54,8 +54,12 @@ cc_library( srcs = ["pose_landmarks_detector_graph.cc"], deps = [ "//mediapipe/calculators/core:begin_loop_calculator", + "//mediapipe/calculators/core:concatenate_vector_calculator", + "//mediapipe/calculators/core:constant_side_packet_calculator", + "//mediapipe/calculators/core:constant_side_packet_calculator_cc_proto", "//mediapipe/calculators/core:end_loop_calculator", "//mediapipe/calculators/core:gate_calculator", + "//mediapipe/calculators/core:side_packet_to_stream_calculator", "//mediapipe/calculators/core:split_proto_list_calculator", "//mediapipe/calculators/core:split_vector_calculator", "//mediapipe/calculators/core:split_vector_calculator_cc_proto", @@ -87,6 +91,9 @@ cc_library( "//mediapipe/framework:subgraph", "//mediapipe/framework/api2:builder", "//mediapipe/framework/api2:port", + "//mediapipe/framework/api2/stream:get_vector_item", + "//mediapipe/framework/api2/stream:image_size", + "//mediapipe/framework/api2/stream:smoothing", "//mediapipe/framework/formats:image", "//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/formats:rect_cc_proto", @@ -98,6 +105,7 @@ cc_library( "//mediapipe/tasks/cc/vision/pose_landmarker/proto:pose_landmarks_detector_graph_options_cc_proto", "//mediapipe/tasks/cc/vision/utils:image_tensor_specs", "//mediapipe/util:graph_builder_utils", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", ], alwayslink = 1, @@ -125,21 +133,18 @@ cc_library( "//mediapipe/framework/formats:image", "//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/formats:rect_cc_proto", - "//mediapipe/framework/formats:tensor", "//mediapipe/framework/port:status", - "//mediapipe/tasks/cc:common", "//mediapipe/tasks/cc/components/utils:gate", "//mediapipe/tasks/cc/core:model_asset_bundle_resources", "//mediapipe/tasks/cc/core:model_resources_cache", "//mediapipe/tasks/cc/core:model_task_graph", - "//mediapipe/tasks/cc/core:utils", "//mediapipe/tasks/cc/metadata/utils:zip_utils", "//mediapipe/tasks/cc/vision/pose_detector:pose_detector_graph", "//mediapipe/tasks/cc/vision/pose_detector/proto:pose_detector_graph_options_cc_proto", "//mediapipe/tasks/cc/vision/pose_landmarker/proto:pose_landmarker_graph_options_cc_proto", "//mediapipe/tasks/cc/vision/pose_landmarker/proto:pose_landmarks_detector_graph_options_cc_proto", "//mediapipe/util:graph_builder_utils", - "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/status", ], alwayslink = 1, ) diff --git a/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_graph.cc b/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_graph.cc index 7889212e8c..413835b03a 100644 --- a/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_graph.cc +++ b/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_graph.cc @@ -13,12 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include -#include -#include #include -#include "absl/strings/str_format.h" +#include "absl/status/status.h" #include "mediapipe/calculators/core/clip_vector_size_calculator.pb.h" #include "mediapipe/calculators/core/gate_calculator.pb.h" #include "mediapipe/calculators/util/association_calculator.pb.h" @@ -29,14 +26,11 @@ limitations under the License. #include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/landmark.pb.h" #include "mediapipe/framework/formats/rect.pb.h" -#include "mediapipe/framework/formats/tensor.h" #include "mediapipe/framework/port/status_macros.h" -#include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/components/utils/gate.h" #include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h" #include "mediapipe/tasks/cc/core/model_resources_cache.h" #include "mediapipe/tasks/cc/core/model_task_graph.h" -#include "mediapipe/tasks/cc/core/utils.h" #include "mediapipe/tasks/cc/metadata/utils/zip_utils.h" #include "mediapipe/tasks/cc/vision/pose_detector/proto/pose_detector_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/pose_landmarker/proto/pose_landmarker_graph_options.pb.h" @@ -292,7 +286,9 @@ class PoseLandmarkerGraph : public core::ModelTaskGraph { auto& pose_detector = graph.AddNode("mediapipe.tasks.vision.pose_detector.PoseDetectorGraph"); - pose_detector.GetOptions().Swap( + auto& pose_detector_options = + pose_detector.GetOptions(); + pose_detector_options.Swap( tasks_options.mutable_pose_detector_graph_options()); auto& clip_pose_rects = graph.AddNode("ClipNormalizedRectVectorSizeCalculator"); @@ -303,9 +299,23 @@ class PoseLandmarkerGraph : public core::ModelTaskGraph { auto& pose_landmarks_detector_graph = graph.AddNode( "mediapipe.tasks.vision.pose_landmarker." "MultiplePoseLandmarksDetectorGraph"); - pose_landmarks_detector_graph - .GetOptions() - .Swap(tasks_options.mutable_pose_landmarks_detector_graph_options()); + auto& pose_landmarks_detector_graph_options = + pose_landmarks_detector_graph + .GetOptions(); + pose_landmarks_detector_graph_options.Swap( + tasks_options.mutable_pose_landmarks_detector_graph_options()); + + // Apply smoothing filter only on the single pose landmarks, because + // landmarks smoothing calculator doesn't support multiple landmarks yet. + if (pose_detector_options.num_poses() == 1) { + pose_landmarks_detector_graph_options.set_smooth_landmarks( + tasks_options.base_options().use_stream_mode()); + } else if (pose_detector_options.num_poses() > 1 && + pose_landmarks_detector_graph_options.smooth_landmarks()) { + return absl::InvalidArgumentError( + "Currently pose landmarks smoothing only supports a single pose."); + } + image_in >> pose_landmarks_detector_graph.In(kImageTag); clipped_pose_rects >> pose_landmarks_detector_graph.In(kNormRectTag); diff --git a/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_test.cc b/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_test.cc index 239851b5ff..e1ac74c739 100644 --- a/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_test.cc +++ b/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_test.cc @@ -240,7 +240,7 @@ TEST_P(ImageModeTest, Succeeds) { } INSTANTIATE_TEST_SUITE_P( - PoseGestureTest, ImageModeTest, + PoseTest, ImageModeTest, Values(TestParams{ /* test_name= */ "Pose", /* test_image_name= */ kPoseImage, @@ -328,7 +328,7 @@ TEST_P(VideoModeTest, Succeeds) { // TODO Investigate PoseLandmarker performance in VideoMode. INSTANTIATE_TEST_SUITE_P( - PoseGestureTest, VideoModeTest, + PoseTest, VideoModeTest, Values(TestParams{ /* test_name= */ "Pose", /* test_image_name= */ kPoseImage, @@ -444,7 +444,7 @@ TEST_P(LiveStreamModeTest, Succeeds) { // Investigate PoseLandmarker performance in LiveStreamMode. INSTANTIATE_TEST_SUITE_P( - PoseGestureTest, LiveStreamModeTest, + PoseTest, LiveStreamModeTest, Values(TestParams{ /* test_name= */ "Pose", /* test_image_name= */ kPoseImage, diff --git a/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarks_detector_graph.cc b/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarks_detector_graph.cc index e8397192be..cc61aa212a 100644 --- a/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarks_detector_graph.cc +++ b/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarks_detector_graph.cc @@ -13,7 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include +#include + #include "absl/status/statusor.h" +#include "mediapipe/calculators/core/constant_side_packet_calculator.pb.h" #include "mediapipe/calculators/core/split_vector_calculator.pb.h" #include "mediapipe/calculators/image/warp_affine_calculator.pb.h" #include "mediapipe/calculators/tensor/image_to_tensor_calculator.pb.h" @@ -26,6 +31,9 @@ limitations under the License. #include "mediapipe/calculators/util/visibility_copy_calculator.pb.h" #include "mediapipe/framework/api2/builder.h" #include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/api2/stream/get_vector_item.h" +#include "mediapipe/framework/api2/stream/image_size.h" +#include "mediapipe/framework/api2/stream/smoothing.h" #include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/landmark.pb.h" #include "mediapipe/framework/formats/rect.pb.h" @@ -47,7 +55,10 @@ namespace pose_landmarker { using ::mediapipe::NormalizedRect; using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; +using ::mediapipe::api2::builder::GetImageSize; using ::mediapipe::api2::builder::Graph; +using ::mediapipe::api2::builder::SmoothLandmarks; +using ::mediapipe::api2::builder::SmoothLandmarksVisibility; using ::mediapipe::api2::builder::Source; using ::mediapipe::api2::builder::Stream; using ::mediapipe::tasks::core::ModelResources; @@ -213,6 +224,23 @@ void ConfigureWarpAffineCalculator( options->set_gpu_origin(mediapipe::GpuOrigin::TOP_LEFT); } +template +Stream CreateIntConstantStream(Stream tick_stream, int constant_int, + Graph& graph) { + auto& constant_side_packet_node = + graph.AddNode("ConstantSidePacketCalculator"); + constant_side_packet_node + .GetOptions() + .add_packet() + ->set_int_value(constant_int); + auto side_packet = constant_side_packet_node.SideOut("PACKET"); + + auto& side_packet_to_stream = graph.AddNode("SidePacketToStreamCalculator"); + tick_stream.ConnectTo(side_packet_to_stream.In("TICK")); + side_packet.ConnectTo(side_packet_to_stream.SideIn("")); + return side_packet_to_stream.Out("AT_TICK").Cast(); +} + // A "mediapipe.tasks.vision.pose_landmarker.SinglePoseLandmarksDetectorGraph" // performs pose landmarks detection. // - Accepts CPU input images and outputs Landmark on CPU. @@ -669,8 +697,8 @@ class MultiplePoseLandmarksDetectorGraph : public core::ModelTaskGraph { auto& pose_landmark_subgraph = graph.AddNode( "mediapipe.tasks.vision.pose_landmarker." "SinglePoseLandmarksDetectorGraph"); - pose_landmark_subgraph.GetOptions() - .CopyFrom(subgraph_options); + pose_landmark_subgraph.GetOptions() = + subgraph_options; image >> pose_landmark_subgraph.In(kImageTag); pose_rect >> pose_landmark_subgraph.In(kNormRectTag); auto landmarks = pose_landmark_subgraph.Out(kLandmarksTag); @@ -734,6 +762,70 @@ class MultiplePoseLandmarksDetectorGraph : public core::ModelTaskGraph { end_loop_segmentation_mask[Output>(kIterableTag)]; } + // Apply smoothing filter only on the single pose landmarks, because + // landmarks smoothing calculator doesn't support multiple landmarks yet. + // Notice the landmarks smoothing calculator cannot be put inside the for + // loop calculator, because the smoothing calculator utilize the timestamp + // to smoote landmarks across frames but the for loop calculator makes fake + // timestamps for the streams. + if (subgraph_options.smooth_landmarks()) { + Stream> image_size = GetImageSize(image_in, graph); + Stream zero_index = + CreateIntConstantStream(landmark_lists, 0, graph); + Stream landmarks = + GetItem(landmark_lists, zero_index, graph); + Stream world_landmarks = + GetItem(world_landmark_lists, zero_index, graph); + Stream roi = + GetItem(pose_rects_next_frame, zero_index, graph); + + // Apply smoothing filter on pose landmarks. + landmarks = SmoothLandmarksVisibility( + landmarks, /*low_pass_filter_alpha=*/0.1f, graph); + landmarks = SmoothLandmarks( + landmarks, image_size, roi, + {// Min cutoff 0.05 results into ~0.01 alpha in landmark EMA filter + // when landmark is static. + /*min_cutoff=*/0.05f, + // Beta 80.0 in combination with min_cutoff 0.05 results into ~0.94 + // alpha in landmark EMA filter when landmark is moving fast. + /*beta=*/80.0f, + // Derivative cutoff 1.0 results into ~0.17 alpha in landmark + // velocity EMA filter. + /*derivate_cutoff=*/1.0f}, + graph); + + // Apply smoothing filter on pose world landmarks. + world_landmarks = SmoothLandmarksVisibility( + world_landmarks, /*low_pass_filter_alpha=*/0.1f, graph); + world_landmarks = SmoothLandmarks( + world_landmarks, + /*scale_roi=*/std::nullopt, + {// Min cutoff 0.1 results into ~ 0.02 alpha in landmark EMA filter + // when landmark is static. + /*min_cutoff=*/0.1f, + // Beta 40.0 in combination with min_cutoff 0.1 results into ~0.8 + // alpha in landmark EMA filter when landmark is moving fast. + /*beta=*/40.0f, + // Derivative cutoff 1.0 results into ~0.17 alpha in landmark + // velocity EMA filter. + /*derivate_cutoff=*/1.0f}, + graph); + + // Wrap the single pose landmarks into a vector of landmarks. + auto& concat_landmarks = + graph.AddNode("ConcatenateNormalizedLandmarkListVectorCalculator"); + landmarks >> concat_landmarks.In(""); + landmark_lists = + concat_landmarks.Out("").Cast>(); + + auto& concat_world_landmarks = + graph.AddNode("ConcatenateLandmarkListVectorCalculator"); + world_landmarks >> concat_world_landmarks.In(""); + world_landmark_lists = + concat_world_landmarks.Out("").Cast>(); + } + return {{ /* landmark_lists= */ landmark_lists, /* world_landmark_lists= */ world_landmark_lists, diff --git a/mediapipe/tasks/cc/vision/pose_landmarker/proto/pose_landmarks_detector_graph_options.proto b/mediapipe/tasks/cc/vision/pose_landmarker/proto/pose_landmarks_detector_graph_options.proto index 9eb835d6af..bcb2459969 100644 --- a/mediapipe/tasks/cc/vision/pose_landmarker/proto/pose_landmarks_detector_graph_options.proto +++ b/mediapipe/tasks/cc/vision/pose_landmarker/proto/pose_landmarks_detector_graph_options.proto @@ -35,4 +35,11 @@ message PoseLandmarksDetectorGraphOptions { // Minimum confidence value ([0.0, 1.0]) for pose presence score to be // considered successfully detecting a pose in the image. optional float min_detection_confidence = 2 [default = 0.5]; + + // Whether to smooth the detected landmarks over timestamps. Note that + // landmarks smoothing is only applicable for a single pose. If multiple poses + // landmarks are given, and smooth_landmarks is true, only the first pose + // landmarks would be smoothed, and the remaining landmarks are discarded in + // the returned landmarks list. + optional bool smooth_landmarks = 3; }