From 4e7c7b586e84907dded0aab619697a512aa7f452 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 17 Jan 2024 14:39:07 -0800 Subject: [PATCH] Add CalculatorGraph::SetErrorCallback to receive errors in case of async graph use cases. PiperOrigin-RevId: 599294023 --- mediapipe/framework/BUILD | 18 ++ mediapipe/framework/calculator_graph.cc | 14 ++ mediapipe/framework/calculator_graph.h | 24 +++ .../calculator_graph_error_callback_test.cc | 160 ++++++++++++++++++ 4 files changed, 216 insertions(+) create mode 100644 mediapipe/framework/calculator_graph_error_callback_test.cc diff --git a/mediapipe/framework/BUILD b/mediapipe/framework/BUILD index ab0b30ed80..8f0b0b8f79 100644 --- a/mediapipe/framework/BUILD +++ b/mediapipe/framework/BUILD @@ -1420,6 +1420,24 @@ cc_test( ], ) +cc_test( + name = "calculator_graph_error_callback_test", + srcs = ["calculator_graph_error_callback_test.cc"], + deps = [ + ":calculator_framework", + ":packet", + "//mediapipe/framework/api2:node", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:parse_text_proto", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + ], +) + cc_test( name = "calculator_runner_test", size = "medium", diff --git a/mediapipe/framework/calculator_graph.cc b/mediapipe/framework/calculator_graph.cc index 1bd356eac4..f556fc6def 100644 --- a/mediapipe/framework/calculator_graph.cc +++ b/mediapipe/framework/calculator_graph.cc @@ -495,6 +495,17 @@ absl::Status CalculatorGraph::ObserveOutputStream( return absl::OkStatus(); } +absl::Status CalculatorGraph::SetErrorCallback( + std::function error_callback) { + // Require setting error callback before initialization to: + // - impose the strictest requirement + // - save the future possibility of reporting initialization errors + RET_CHECK(!initialized_) + << "SetErrorCallback must be called before Initialize()"; + error_callback_ = error_callback; + return absl::OkStatus(); +} + absl::StatusOr CalculatorGraph::AddOutputStreamPoller( const std::string& stream_name, bool observe_timestamp_bounds) { RET_CHECK(initialized_).SetNoLogging() @@ -1077,6 +1088,9 @@ void CalculatorGraph::RecordError(const absl::Status& error) { "of memory."; } } + if (error_callback_) { + error_callback_(error); + } } bool CalculatorGraph::GetCombinedErrors(absl::Status* error_status) { diff --git a/mediapipe/framework/calculator_graph.h b/mediapipe/framework/calculator_graph.h index 80af726504..13b73e6a65 100644 --- a/mediapipe/framework/calculator_graph.h +++ b/mediapipe/framework/calculator_graph.h @@ -157,6 +157,9 @@ class CalculatorGraph { // object is destroyed, even if e.g. Cancel() or WaitUntilDone() have already // been called. After this object is destroyed so is packet_callback. // TODO: Rename to AddOutputStreamCallback. + // + // Note: use `SetErrorCallback` to subscribe for errors when using graph for + // async use cases. absl::Status ObserveOutputStream( const std::string& stream_name, std::function packet_callback, @@ -314,8 +317,26 @@ class CalculatorGraph { } CounterFactory* GetCounterFactory() { return counter_factory_.get(); } + // Sets the error callback to receive graph execution errors when blocking + // calls like `WaitUntilIdle()`, `WaitUntilDone()` cannot be used. + // + // Useful for async graph use cases: e.g. user entering words and each + // word is sent to the graph while graph outputs are received and rendered + // asynchronously. + // + // NOTE: + // - Must be called before graph is initialized. + // - May be executed from multiple threads. + // - Errors are first processed by the graph, then the graph transitions into + // the error state, and then finally the callback is invoked. + absl::Status SetErrorCallback( + std::function error_callback); + // Callback when an error is encountered. // Adds the error to the vector of errors. + // + // Use `SetErrorCallback` to subscribe for errors when using graph for async + // use cases. void RecordError(const absl::Status& error) ABSL_LOCKS_EXCLUDED(error_mutex_); // Combines errors into a status. Returns true if the vector of errors is @@ -693,6 +714,9 @@ class CalculatorGraph { // to add an error to this vector. std::vector errors_ ABSL_GUARDED_BY(error_mutex_); + // Optional error callback set by client. + std::function error_callback_; + // True if the default executor uses the application thread. bool use_application_thread_ = false; diff --git a/mediapipe/framework/calculator_graph_error_callback_test.cc b/mediapipe/framework/calculator_graph_error_callback_test.cc new file mode 100644 index 0000000000..428f476ace --- /dev/null +++ b/mediapipe/framework/calculator_graph_error_callback_test.cc @@ -0,0 +1,160 @@ +#include + +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "absl/time/time.h" +#include "mediapipe/framework/api2/node.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/packet.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status_matchers.h" + +namespace mediapipe { + +using ::mediapipe::api2::Input; +using ::mediapipe::api2::Node; +using ::mediapipe::api2::Output; +using ::testing::HasSubstr; + +namespace { + +constexpr absl::string_view kErrorMsgFromProcess = + "Error from Calculator::Process."; + +class ProcessFnErrorCalculator : public Node { + public: + static constexpr Input kIn{"IN"}; + static constexpr Output kOut{"OUT"}; + + MEDIAPIPE_NODE_CONTRACT(kIn, kOut); + + absl::Status Process(CalculatorContext* cc) override { + return absl::InternalError(kErrorMsgFromProcess); + } +}; +MEDIAPIPE_REGISTER_NODE(ProcessFnErrorCalculator); + +TEST(CalculatorGraphAsyncErrorsTest, ErrorCallbackReceivesProcessErrors) { + auto graph_config = ParseTextProtoOrDie(R"pb( + input_stream: 'input' + node { + calculator: "ProcessFnErrorCalculator" + input_stream: 'IN:input' + output_stream: 'OUT:output' + } + )pb"); + + CalculatorGraph graph; + + bool is_error_received = false; + absl::Status output_error; + absl::Mutex m; + auto error_callback_fn = [&graph, &m, &output_error, + &is_error_received](absl::Status error) { + EXPECT_TRUE(graph.HasError()); + + absl::MutexLock lock(&m); + output_error = std::move(error); + is_error_received = true; + }; + + MP_ASSERT_OK(graph.SetErrorCallback(error_callback_fn)); + MP_ASSERT_OK(graph.Initialize(graph_config, {})); + MP_ASSERT_OK(graph.StartRun({})); + MP_ASSERT_OK(graph.AddPacketToInputStream( + "input", MakePacket(0).At(Timestamp(10)))); + + { + absl::MutexLock lock(&m); + ASSERT_TRUE(m.AwaitWithTimeout(absl::Condition(&is_error_received), + absl::Seconds(1))); + } + EXPECT_THAT(output_error, StatusIs(absl::StatusCode::kInternal, + HasSubstr(kErrorMsgFromProcess))); + + EXPECT_THAT(graph.WaitUntilIdle(), StatusIs(absl::StatusCode::kInternal, + HasSubstr(kErrorMsgFromProcess))); +} + +constexpr absl::string_view kErrorMsgFromOpen = "Error from Calculator::Open."; + +class OpenFnErrorCalculator : public Node { + public: + static constexpr Input kIn{"IN"}; + static constexpr Output kOut{"OUT"}; + + MEDIAPIPE_NODE_CONTRACT(kIn, kOut); + + absl::Status Open(CalculatorContext* cc) override { + return absl::InternalError(kErrorMsgFromOpen); + } + + absl::Status Process(CalculatorContext* cc) override { + return absl::OkStatus(); + } +}; +MEDIAPIPE_REGISTER_NODE(OpenFnErrorCalculator); + +TEST(CalculatorGraphAsyncErrorsTest, ErrorCallbackReceivesOpenErrors) { + auto graph_config = ParseTextProtoOrDie(R"pb( + input_stream: 'input' + node { + calculator: "OpenFnErrorCalculator" + input_stream: 'IN:input' + output_stream: 'OUT:output' + } + )pb"); + + CalculatorGraph graph; + + bool is_error_received = false; + absl::Status output_error; + absl::Mutex m; + auto error_callback_fn = [&graph, &m, &output_error, + &is_error_received](absl::Status error) { + EXPECT_TRUE(graph.HasError()); + + absl::MutexLock lock(&m); + output_error = std::move(error); + is_error_received = true; + }; + + MP_ASSERT_OK(graph.SetErrorCallback(error_callback_fn)); + MP_ASSERT_OK(graph.Initialize(graph_config, {})); + MP_ASSERT_OK(graph.StartRun({})); + + { + absl::MutexLock lock(&m); + ASSERT_TRUE(m.AwaitWithTimeout(absl::Condition(&is_error_received), + absl::Seconds(1))); + } + EXPECT_THAT(output_error, StatusIs(absl::StatusCode::kInternal, + HasSubstr(kErrorMsgFromOpen))); + + EXPECT_THAT(graph.WaitUntilIdle(), StatusIs(absl::StatusCode::kInternal, + HasSubstr(kErrorMsgFromOpen))); +} + +TEST(CalculatorGraphAsyncErrorsTest, ErrorCallbackMustBeSetBeforeInit) { + auto graph_config = ParseTextProtoOrDie(R"pb( + input_stream: 'input' + node { + calculator: "OpenFnErrorCalculator" + input_stream: 'IN:input' + output_stream: 'OUT:output' + } + )pb"); + + CalculatorGraph graph; + ABSL_CHECK_OK(graph.Initialize(graph_config, {})); + EXPECT_THAT(graph.SetErrorCallback({}), + StatusIs(absl::StatusCode::kInternal)); +} + +} // namespace +} // namespace mediapipe