diff --git a/mediapipe/calculators/core/packet_resampler_calculator.cc b/mediapipe/calculators/core/packet_resampler_calculator.cc index 81a68f03f9..06bfa0b4d1 100644 --- a/mediapipe/calculators/core/packet_resampler_calculator.cc +++ b/mediapipe/calculators/core/packet_resampler_calculator.cc @@ -14,10 +14,12 @@ #include "mediapipe/calculators/core/packet_resampler_calculator.h" +#include #include #include "absl/log/absl_check.h" #include "absl/log/absl_log.h" +#include "mediapipe/framework/port/ret_check.h" namespace { // Reflect an integer against the lower and upper bound of an interval. @@ -76,6 +78,8 @@ absl::Status PacketResamplerCalculator::GetContract(CalculatorContract* cc) { } cc->Outputs().Get(output_data_id).SetSameAs(&cc->Inputs().Get(input_data_id)); if (cc->Outputs().HasTag(kVideoHeaderTag)) { + RET_CHECK(resampler_options.max_frame_rate() <= 0) + << "VideoHeader output is not supported with max_frame_rate."; cc->Outputs().Tag(kVideoHeaderTag).Set(); } @@ -88,24 +92,13 @@ absl::Status PacketResamplerCalculator::GetContract(CalculatorContract* cc) { return absl::OkStatus(); } -absl::Status PacketResamplerCalculator::Open(CalculatorContext* cc) { - const auto resampler_options = - tool::RetrieveOptions(cc->Options(), - cc->InputSidePackets(), "OPTIONS"); - - flush_last_packet_ = resampler_options.flush_last_packet(); - jitter_ = resampler_options.jitter(); - - input_data_id_ = cc->Inputs().GetId("DATA", 0); - if (!input_data_id_.IsValid()) { - input_data_id_ = cc->Inputs().GetId("", 0); +absl::Status PacketResamplerCalculator::UpdateFrameRate( + const PacketResamplerCalculatorOptions& resampler_options, + double frame_rate) { + frame_rate_ = frame_rate; + if (resampler_options.max_frame_rate() > 0) { + frame_rate_ = std::min(frame_rate_, resampler_options.max_frame_rate()); } - output_data_id_ = cc->Outputs().GetId("DATA", 0); - if (!output_data_id_.IsValid()) { - output_data_id_ = cc->Outputs().GetId("", 0); - } - - frame_rate_ = resampler_options.frame_rate(); start_time_ = resampler_options.has_start_time() ? Timestamp(resampler_options.start_time()) : Timestamp::Min(); @@ -125,6 +118,28 @@ absl::Status PacketResamplerCalculator::Open(CalculatorContext* cc) { RET_CHECK_LE(jitter_usec_, frame_time_usec_); video_header_.frame_rate = frame_rate_; + return absl::OkStatus(); +} + +absl::Status PacketResamplerCalculator::Open(CalculatorContext* cc) { + const auto resampler_options = + tool::RetrieveOptions(cc->Options(), + cc->InputSidePackets(), "OPTIONS"); + + flush_last_packet_ = resampler_options.flush_last_packet(); + jitter_ = resampler_options.jitter(); + + input_data_id_ = cc->Inputs().GetId("DATA", 0); + if (!input_data_id_.IsValid()) { + input_data_id_ = cc->Inputs().GetId("", 0); + } + output_data_id_ = cc->Outputs().GetId("DATA", 0); + if (!output_data_id_.IsValid()) { + output_data_id_ = cc->Outputs().GetId("", 0); + } + + RET_CHECK_OK( + UpdateFrameRate(resampler_options, resampler_options.frame_rate())); if (resampler_options.output_header() != PacketResamplerCalculatorOptions::NONE && @@ -150,10 +165,18 @@ absl::Status PacketResamplerCalculator::Open(CalculatorContext* cc) { } absl::Status PacketResamplerCalculator::Process(CalculatorContext* cc) { + const auto resampler_options = + tool::RetrieveOptions(cc->Options(), + cc->InputSidePackets(), "OPTIONS"); + if (cc->InputTimestamp() == Timestamp::PreStream() && cc->Inputs().UsesTags() && cc->Inputs().HasTag(kVideoHeaderTag) && !cc->Inputs().Tag(kVideoHeaderTag).IsEmpty()) { video_header_ = cc->Inputs().Tag(kVideoHeaderTag).Get(); + if (resampler_options.use_input_frame_rate()) { + RET_CHECK_OK( + UpdateFrameRate(resampler_options, video_header_.frame_rate)); + } video_header_.frame_rate = frame_rate_; if (cc->Inputs().Get(input_data_id_).IsEmpty()) { return absl::OkStatus(); diff --git a/mediapipe/calculators/core/packet_resampler_calculator.h b/mediapipe/calculators/core/packet_resampler_calculator.h index 1cf425b5e5..5635d1fb81 100644 --- a/mediapipe/calculators/core/packet_resampler_calculator.h +++ b/mediapipe/calculators/core/packet_resampler_calculator.h @@ -146,6 +146,14 @@ class PacketResamplerCalculator : public CalculatorBase { const mediapipe::PacketResamplerCalculatorOptions& options); private: + // Updates the frame rate of the calculator. + // + // This updates the metadata of the frame rate of the calculator moving + // forward. All already processed packets will be ignored. + absl::Status UpdateFrameRate( + const mediapipe::PacketResamplerCalculatorOptions& resampler_options, + double frame_rate); + std::unique_ptr strategy_; // The timestamp of the first packet received. diff --git a/mediapipe/calculators/core/packet_resampler_calculator.proto b/mediapipe/calculators/core/packet_resampler_calculator.proto index 29ca8082a0..97f717adc7 100644 --- a/mediapipe/calculators/core/packet_resampler_calculator.proto +++ b/mediapipe/calculators/core/packet_resampler_calculator.proto @@ -108,4 +108,17 @@ message PacketResamplerCalculatorOptions { // are included in the output, even if the nearest timestamp is not // between start_time and end_time. optional bool round_limits = 8 [default = false]; + + // If set, the output frame rate is the same as the input frame rate. + // You need to provide the frame rate of the input images in the header in the + // input_side_packet. + // This option only makes sense in combination with max_frame_rate. It will + // hold on to the original frame rate unless it's higher than the + // max_frame_rate. + optional bool use_input_frame_rate = 11 [default = false]; + + // If set, the output frame rate is limited to this value. + // You need to provide the frame rate of the input images in the header in the + // input_side_packet. + optional double max_frame_rate = 12 [default = -1.0]; } diff --git a/mediapipe/calculators/core/packet_resampler_calculator_test.cc b/mediapipe/calculators/core/packet_resampler_calculator_test.cc index d80793da4a..ab74fa5463 100644 --- a/mediapipe/calculators/core/packet_resampler_calculator_test.cc +++ b/mediapipe/calculators/core/packet_resampler_calculator_test.cc @@ -14,6 +14,7 @@ #include "mediapipe/calculators/core/packet_resampler_calculator.h" +#include #include #include #include @@ -271,6 +272,150 @@ TEST(PacketResamplerCalculatorTest, TwoPacketsInStream) { } } +TEST(PacketResamplerCalculatorTest, UseInputFrameRate_HeaderHasSameFramerate) { + CalculatorRunner runner(ParseTextProtoOrDie(R"pb( + calculator: "PacketResamplerCalculator" + input_stream: "DATA:in_data" + input_stream: "VIDEO_HEADER:in_video_header" + output_stream: "DATA:out_data" + options { + [mediapipe.PacketResamplerCalculatorOptions.ext] { + use_input_frame_rate: true + frame_rate: 1000.0 + } + } + )pb")); + + for (const int64_t ts : {0, 5000, 10010, 15001, 19990}) { + runner.MutableInputs()->Tag(kDataTag).packets.push_back( + Adopt(new std::string(absl::StrCat("Frame #", ts))).At(Timestamp(ts))); + } + VideoHeader video_header_in; + video_header_in.width = 10; + video_header_in.height = 100; + video_header_in.frame_rate = 200.0; + video_header_in.duration = 1.0; + video_header_in.format = ImageFormat::SRGB; + runner.MutableInputs() + ->Tag(kVideoHeaderTag) + .packets.push_back( + Adopt(new VideoHeader(video_header_in)).At(Timestamp::PreStream())); + MP_ASSERT_OK(runner.Run()); + + std::vector expected_frames = {0, 5000, 10010, 15001, 19990}; + std::vector expected_timestamps = {0, 5000, 10000, 15000, 20000}; + EXPECT_EQ(expected_frames.size(), + runner.Outputs().Tag(kDataTag).packets.size()); + EXPECT_EQ(expected_timestamps.size(), + runner.Outputs().Tag(kDataTag).packets.size()); + + int count = 0; + for (const Packet& packet : runner.Outputs().Tag(kDataTag).packets) { + EXPECT_EQ(Timestamp(expected_timestamps[count]), packet.Timestamp()); + const std::string& packet_contents = packet.Get(); + EXPECT_EQ(std::string(absl::StrCat("Frame #", expected_frames[count])), + packet_contents); + ++count; + } +} + +TEST(PacketResamplerCalculatorTest, + UseInputFrameRate_HeaderHasSmallerFramerate) { + CalculatorRunner runner(ParseTextProtoOrDie(R"pb( + calculator: "PacketResamplerCalculator" + input_stream: "DATA:in_data" + input_stream: "VIDEO_HEADER:in_video_header" + output_stream: "DATA:out_data" + options { + [mediapipe.PacketResamplerCalculatorOptions.ext] { + use_input_frame_rate: true + frame_rate: 1000.0 + } + } + )pb")); + + for (const int64_t ts : {0, 5000, 10010, 15001}) { + runner.MutableInputs()->Tag(kDataTag).packets.push_back( + Adopt(new std::string(absl::StrCat("Frame #", ts))).At(Timestamp(ts))); + } + VideoHeader video_header_in; + video_header_in.width = 10; + video_header_in.height = 100; + video_header_in.frame_rate = 100.0; + video_header_in.duration = 1.0; + video_header_in.format = ImageFormat::SRGB; + runner.MutableInputs() + ->Tag(kVideoHeaderTag) + .packets.push_back( + Adopt(new VideoHeader(video_header_in)).At(Timestamp::PreStream())); + MP_ASSERT_OK(runner.Run()); + + std::vector expected_frames = {0, 10010, 15001}; + std::vector expected_timestamps = {0, 10000, 20000}; + EXPECT_EQ(expected_frames.size(), + runner.Outputs().Tag(kDataTag).packets.size()); + EXPECT_EQ(expected_timestamps.size(), + runner.Outputs().Tag(kDataTag).packets.size()); + + int count = 0; + for (const Packet& packet : runner.Outputs().Tag(kDataTag).packets) { + EXPECT_EQ(Timestamp(expected_timestamps[count]), packet.Timestamp()); + const std::string& packet_contents = packet.Get(); + EXPECT_EQ(std::string(absl::StrCat("Frame #", expected_frames[count])), + packet_contents); + ++count; + } +} + +TEST(PacketResamplerCalculatorTest, + UseInputFrameRate_MaxFrameRateSmallerThanInput) { + CalculatorRunner runner(ParseTextProtoOrDie(R"pb( + calculator: "PacketResamplerCalculator" + input_stream: "DATA:in_data" + input_stream: "VIDEO_HEADER:in_video_header" + output_stream: "DATA:out_data" + options { + [mediapipe.PacketResamplerCalculatorOptions.ext] { + use_input_frame_rate: true + frame_rate: 1000.0 + max_frame_rate: 50.0 + } + } + )pb")); + + for (const int64_t ts : {0, 5000, 10010, 15001, 20010}) { + runner.MutableInputs()->Tag(kDataTag).packets.push_back( + Adopt(new std::string(absl::StrCat("Frame #", ts))).At(Timestamp(ts))); + } + VideoHeader video_header_in; + video_header_in.width = 10; + video_header_in.height = 200; + video_header_in.frame_rate = 100.0; + video_header_in.duration = 1.0; + video_header_in.format = ImageFormat::SRGB; + runner.MutableInputs() + ->Tag(kVideoHeaderTag) + .packets.push_back( + Adopt(new VideoHeader(video_header_in)).At(Timestamp::PreStream())); + MP_ASSERT_OK(runner.Run()); + + std::vector expected_frames = {0, 20010}; + std::vector expected_timestamps = {0, 20000}; + EXPECT_EQ(expected_frames.size(), + runner.Outputs().Tag(kDataTag).packets.size()); + EXPECT_EQ(expected_timestamps.size(), + runner.Outputs().Tag(kDataTag).packets.size()); + + int count = 0; + for (const Packet& packet : runner.Outputs().Tag(kDataTag).packets) { + EXPECT_EQ(Timestamp(expected_timestamps[count]), packet.Timestamp()); + const std::string& packet_contents = packet.Get(); + EXPECT_EQ(std::string(absl::StrCat("Frame #", expected_frames[count])), + packet_contents); + ++count; + } +} + TEST(PacketResamplerCalculatorTest, InputAtExactFrequencyMiddlepoints) { SimpleRunner runner( "[mediapipe.PacketResamplerCalculatorOptions.ext]: "