Skip to content

Commit

Permalink
smoothing stream utility function.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 569074973
  • Loading branch information
MediaPipe Team authored and copybara-github committed Sep 28, 2023
1 parent 9edb4cd commit a577dc3
Show file tree
Hide file tree
Showing 4 changed files with 517 additions and 0 deletions.
37 changes: 37 additions & 0 deletions mediapipe/framework/api2/stream/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,43 @@ cc_test(
],
)

cc_library(
name = "smoothing",
srcs = ["smoothing.cc"],
hdrs = ["smoothing.h"],
deps = [
"//mediapipe/calculators/util:landmarks_smoothing_calculator",
"//mediapipe/calculators/util:landmarks_smoothing_calculator_cc_proto",
"//mediapipe/calculators/util:multi_landmarks_smoothing_calculator",
"//mediapipe/calculators/util:multi_world_landmarks_smoothing_calculator",
"//mediapipe/calculators/util:visibility_smoothing_calculator",
"//mediapipe/calculators/util:visibility_smoothing_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/formats:rect_cc_proto",
"@com_google_absl//absl/types:optional",
],
)

cc_test(
name = "smoothing_test",
srcs = ["smoothing_test.cc"],
deps = [
":smoothing",
"//mediapipe/calculators/util:landmarks_smoothing_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/port:gtest",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/port:status_matchers",
"@com_google_absl//absl/types:optional",
],
)

cc_library(
name = "split",
hdrs = ["split.h"],
Expand Down
131 changes: 131 additions & 0 deletions mediapipe/framework/api2/stream/smoothing.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
#include "mediapipe/framework/api2/stream/smoothing.h"

#include <optional>
#include <utility>
#include <vector>

#include "absl/types/optional.h"
#include "mediapipe/calculators/util/landmarks_smoothing_calculator.pb.h"
#include "mediapipe/calculators/util/visibility_smoothing_calculator.pb.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/formats/landmark.pb.h"

namespace mediapipe::api2::builder {

namespace {

void SetFilterConfig(const OneEuroFilterConfig& config,
bool disable_value_scaling, GenericNode& node) {
auto& smoothing_node_opts =
node.GetOptions<LandmarksSmoothingCalculatorOptions>();
auto& one_euro_filter = *smoothing_node_opts.mutable_one_euro_filter();
one_euro_filter.set_min_cutoff(config.min_cutoff);
one_euro_filter.set_derivate_cutoff(config.derivate_cutoff);
one_euro_filter.set_beta(config.beta);
one_euro_filter.set_disable_value_scaling(disable_value_scaling);
}

void SetFilterConfig(const LandmarksSmoothingCalculatorOptions& config,
GenericNode& node) {
auto& smoothing_node_opts =
node.GetOptions<LandmarksSmoothingCalculatorOptions>();
smoothing_node_opts = config;
}

GenericNode& AddVisibilitySmoothingNode(float low_pass_filter_alpha,
Graph& graph) {
auto& smoothing_node = graph.AddNode("VisibilitySmoothingCalculator");
auto& smoothing_node_opts =
smoothing_node.GetOptions<VisibilitySmoothingCalculatorOptions>();
smoothing_node_opts.mutable_low_pass_filter()->set_alpha(
low_pass_filter_alpha);
return smoothing_node;
}

} // namespace

Stream<NormalizedLandmarkList> SmoothLandmarks(
Stream<NormalizedLandmarkList> landmarks,
Stream<std::pair<int, int>> image_size,
std::optional<Stream<NormalizedRect>> scale_roi,
const OneEuroFilterConfig& config, Graph& graph) {
auto& smoothing_node = graph.AddNode("LandmarksSmoothingCalculator");
SetFilterConfig(config, /*disable_value_scaling=*/false, smoothing_node);

landmarks.ConnectTo(smoothing_node.In("NORM_LANDMARKS"));
image_size.ConnectTo(smoothing_node.In("IMAGE_SIZE"));
if (scale_roi) {
scale_roi->ConnectTo(smoothing_node.In("OBJECT_SCALE_ROI"));
}
return smoothing_node.Out("NORM_FILTERED_LANDMARKS")
.Cast<NormalizedLandmarkList>();
}

Stream<LandmarkList> SmoothLandmarks(
Stream<LandmarkList> landmarks,
std::optional<Stream<NormalizedRect>> scale_roi,
const OneEuroFilterConfig& config, Graph& graph) {
auto& smoothing_node = graph.AddNode("LandmarksSmoothingCalculator");
SetFilterConfig(config, /*disable_value_scaling=*/true, smoothing_node);

landmarks.ConnectTo(smoothing_node.In("LANDMARKS"));
if (scale_roi) {
scale_roi->ConnectTo(smoothing_node.In("OBJECT_SCALE_ROI"));
}
return smoothing_node.Out("FILTERED_LANDMARKS").Cast<LandmarkList>();
}

Stream<std::vector<NormalizedLandmarkList>> SmoothMultiLandmarks(
Stream<std::vector<NormalizedLandmarkList>> landmarks,
Stream<std::vector<int64_t>> tracking_ids,
Stream<std::pair<int, int>> image_size,
std::optional<Stream<std::vector<NormalizedRect>>> scale_roi,
const LandmarksSmoothingCalculatorOptions& config, Graph& graph) {
auto& smoothing_node = graph.AddNode("MultiLandmarksSmoothingCalculator");
SetFilterConfig(config, smoothing_node);

landmarks.ConnectTo(smoothing_node.In("NORM_LANDMARKS"));
tracking_ids.ConnectTo(smoothing_node.In("TRACKING_IDS"));
image_size.ConnectTo(smoothing_node.In("IMAGE_SIZE"));
if (scale_roi) {
scale_roi->ConnectTo(smoothing_node.In("OBJECT_SCALE_ROI"));
}
return smoothing_node.Out("NORM_FILTERED_LANDMARKS")
.Cast<std::vector<NormalizedLandmarkList>>();
}

Stream<std::vector<LandmarkList>> SmoothMultiWorldLandmarks(
Stream<std::vector<LandmarkList>> landmarks,
Stream<std::vector<int64_t>> tracking_ids,
std::optional<Stream<std::vector<Rect>>> scale_roi,
const LandmarksSmoothingCalculatorOptions& config, Graph& graph) {
auto& smoothing_node =
graph.AddNode("MultiWorldLandmarksSmoothingCalculator");
SetFilterConfig(config, smoothing_node);

landmarks.ConnectTo(smoothing_node.In("LANDMARKS"));
tracking_ids.ConnectTo(smoothing_node.In("TRACKING_IDS"));
if (scale_roi) {
scale_roi->ConnectTo(smoothing_node.In("OBJECT_SCALE_ROI"));
}
return smoothing_node.Out("FILTERED_LANDMARKS")
.Cast<std::vector<LandmarkList>>();
}

Stream<NormalizedLandmarkList> SmoothLandmarksVisibility(
Stream<NormalizedLandmarkList> landmarks, float low_pass_filter_alpha,
Graph& graph) {
auto& node = AddVisibilitySmoothingNode(low_pass_filter_alpha, graph);
landmarks.ConnectTo(node.In("NORM_LANDMARKS"));
return node.Out("NORM_FILTERED_LANDMARKS").Cast<NormalizedLandmarkList>();
}

Stream<LandmarkList> SmoothLandmarksVisibility(Stream<LandmarkList> landmarks,
float low_pass_filter_alpha,
Graph& graph) {
auto& node = AddVisibilitySmoothingNode(low_pass_filter_alpha, graph);
landmarks.ConnectTo(node.In("LANDMARKS"));
return node.Out("FILTERED_LANDMARKS").Cast<LandmarkList>();
}

} // namespace mediapipe::api2::builder
119 changes: 119 additions & 0 deletions mediapipe/framework/api2/stream/smoothing.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
#ifndef MEDIAPIPE_FRAMEWORK_API2_STREAM_SMOOTHING_H_
#define MEDIAPIPE_FRAMEWORK_API2_STREAM_SMOOTHING_H_

#include <cstdint>
#include <optional>
#include <utility>
#include <vector>

#include "absl/types/optional.h"
#include "mediapipe/calculators/util/landmarks_smoothing_calculator.pb.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"

namespace mediapipe::api2::builder {

struct OneEuroFilterConfig {
float min_cutoff;
float beta;
float derivate_cutoff;
};

// Updates graph to smooth normalized landmarks and returns resulting stream.
//
// @landmarks - normalized landmarks.
// @image_size - size of image where landmarks were detected.
// @scale_roi - can be used to specify object scale.
// @config - filter config.
// @graph - graph to update.
//
// Returns: smoothed/filtered normalized landmarks.
//
// NOTE: one-euro filter is exposed only. Other filter options can be exposed
// on demand.
Stream<mediapipe::NormalizedLandmarkList> SmoothLandmarks(
Stream<mediapipe::NormalizedLandmarkList> landmarks,
Stream<std::pair<int, int>> image_size,
std::optional<Stream<NormalizedRect>> scale_roi,
const OneEuroFilterConfig& config, Graph& graph);

// Updates graph to smooth absolute landmarks and returns resulting stream.
//
// @landmarks - absolute landmarks.
// @scale_roi - can be used to specify object scale.
// @config - filter config.
// @graph - graph to update.
//
// Returns: smoothed/filtered absolute landmarks.
//
// NOTE: one-euro filter is exposed only. Other filter options can be exposed
// on demand.
Stream<mediapipe::LandmarkList> SmoothLandmarks(
Stream<mediapipe::LandmarkList> landmarks,
std::optional<Stream<NormalizedRect>> scale_roi,
const OneEuroFilterConfig& config, Graph& graph);

// Updates graph to smooth normalized landmarks and returns resulting stream.
//
// @landmarks - normalized landmarks vector.
// @tracking_ids - tracking IDs associated with landmarks
// @image_size - size of image where landmarks were detected.
// @scale_roi - can be used to specify object scales.
// @config - filter config.
// @graph - graph to update.
//
// Returns: smoothed/filtered normalized landmarks.
//
// NOTE: one-euro filter is exposed only. Other filter options can be exposed
// on demand.
Stream<std::vector<mediapipe::NormalizedLandmarkList>> SmoothMultiLandmarks(
Stream<std::vector<mediapipe::NormalizedLandmarkList>> landmarks,
Stream<std::vector<int64_t>> tracking_ids,
Stream<std::pair<int, int>> image_size,
std::optional<Stream<std::vector<NormalizedRect>>> scale_roi,
const mediapipe::LandmarksSmoothingCalculatorOptions& config, Graph& graph);

// Updates graph to smooth absolute landmarks and returns resulting stream.
//
// @landmarks - absolute landmarks vector.
// @tracking_ids - tracking IDs associated with landmarks
// @scale_roi - can be used to specify object scales.
// @config - filter config.
// @graph - graph to update.
//
// Returns: smoothed/filtered absolute landmarks.
//
// NOTE: one-euro filter is exposed only. Other filter options can be exposed
// on demand.
Stream<std::vector<mediapipe::LandmarkList>> SmoothMultiWorldLandmarks(
Stream<std::vector<mediapipe::LandmarkList>> landmarks,
Stream<std::vector<int64_t>> tracking_ids,
std::optional<Stream<std::vector<mediapipe::Rect>>> scale_roi,
const mediapipe::LandmarksSmoothingCalculatorOptions& config, Graph& graph);

// Updates graph to smooth visibility of landmarks.
//
// @landmarks - normalized landmarks.
// @low_pass_filter_alpha - low pass filter alpha to use for smoothing.
// @graph - graph to update.
//
// Returns: normalized landmarks containing smoothed visibility.
Stream<mediapipe::NormalizedLandmarkList> SmoothLandmarksVisibility(
Stream<mediapipe::NormalizedLandmarkList> landmarks,
float low_pass_filter_alpha, Graph& graph);

// Updates graph to smooth visibility of landmarks.
//
// @landmarks - absolute landmarks.
// @low_pass_filter_alpha - low pass filter alpha to use for smoothing.
// @graph - graph to update.
//
// Returns: absolute landmarks containing smoothed visibility.
Stream<mediapipe::LandmarkList> SmoothLandmarksVisibility(
Stream<mediapipe::LandmarkList> landmarks, float low_pass_filter_alpha,
mediapipe::api2::builder::Graph& graph);

} // namespace mediapipe::api2::builder

#endif // MEDIAPIPE_FRAMEWORK_API2_STREAM_SMOOTHING_H_
Loading

0 comments on commit a577dc3

Please sign in to comment.