Skip to content

Commit

Permalink
Add CalculatorGraph::SetErrorCallback to receive errors in case of as…
Browse files Browse the repository at this point in the history
…ync graph use cases.

PiperOrigin-RevId: 599294023
  • Loading branch information
MediaPipe Team authored and copybara-github committed Jan 17, 2024
1 parent b62093b commit 4e7c7b5
Show file tree
Hide file tree
Showing 4 changed files with 216 additions and 0 deletions.
18 changes: 18 additions & 0 deletions mediapipe/framework/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1420,6 +1420,24 @@ cc_test(
],
)

cc_test(
name = "calculator_graph_error_callback_test",
srcs = ["calculator_graph_error_callback_test.cc"],
deps = [
":calculator_framework",
":packet",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
"@com_google_absl//absl/log:absl_check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/time",
],
)

cc_test(
name = "calculator_runner_test",
size = "medium",
Expand Down
14 changes: 14 additions & 0 deletions mediapipe/framework/calculator_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,17 @@ absl::Status CalculatorGraph::ObserveOutputStream(
return absl::OkStatus();
}

absl::Status CalculatorGraph::SetErrorCallback(
std::function<void(const absl::Status&)> error_callback) {
// Require setting error callback before initialization to:
// - impose the strictest requirement
// - save the future possibility of reporting initialization errors
RET_CHECK(!initialized_)
<< "SetErrorCallback must be called before Initialize()";
error_callback_ = error_callback;
return absl::OkStatus();
}

absl::StatusOr<OutputStreamPoller> CalculatorGraph::AddOutputStreamPoller(
const std::string& stream_name, bool observe_timestamp_bounds) {
RET_CHECK(initialized_).SetNoLogging()
Expand Down Expand Up @@ -1077,6 +1088,9 @@ void CalculatorGraph::RecordError(const absl::Status& error) {
"of memory.";
}
}
if (error_callback_) {
error_callback_(error);
}
}

bool CalculatorGraph::GetCombinedErrors(absl::Status* error_status) {
Expand Down
24 changes: 24 additions & 0 deletions mediapipe/framework/calculator_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,9 @@ class CalculatorGraph {
// object is destroyed, even if e.g. Cancel() or WaitUntilDone() have already
// been called. After this object is destroyed so is packet_callback.
// TODO: Rename to AddOutputStreamCallback.
//
// Note: use `SetErrorCallback` to subscribe for errors when using graph for
// async use cases.
absl::Status ObserveOutputStream(
const std::string& stream_name,
std::function<absl::Status(const Packet&)> packet_callback,
Expand Down Expand Up @@ -314,8 +317,26 @@ class CalculatorGraph {
}
CounterFactory* GetCounterFactory() { return counter_factory_.get(); }

// Sets the error callback to receive graph execution errors when blocking
// calls like `WaitUntilIdle()`, `WaitUntilDone()` cannot be used.
//
// Useful for async graph use cases: e.g. user entering words and each
// word is sent to the graph while graph outputs are received and rendered
// asynchronously.
//
// NOTE:
// - Must be called before graph is initialized.
// - May be executed from multiple threads.
// - Errors are first processed by the graph, then the graph transitions into
// the error state, and then finally the callback is invoked.
absl::Status SetErrorCallback(
std::function<void(const absl::Status&)> error_callback);

// Callback when an error is encountered.
// Adds the error to the vector of errors.
//
// Use `SetErrorCallback` to subscribe for errors when using graph for async
// use cases.
void RecordError(const absl::Status& error) ABSL_LOCKS_EXCLUDED(error_mutex_);

// Combines errors into a status. Returns true if the vector of errors is
Expand Down Expand Up @@ -693,6 +714,9 @@ class CalculatorGraph {
// to add an error to this vector.
std::vector<absl::Status> errors_ ABSL_GUARDED_BY(error_mutex_);

// Optional error callback set by client.
std::function<void(const absl::Status&)> error_callback_;

// True if the default executor uses the application thread.
bool use_application_thread_ = false;

Expand Down
160 changes: 160 additions & 0 deletions mediapipe/framework/calculator_graph_error_callback_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
#include <utility>

#include "absl/log/absl_check.h"
#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "absl/time/time.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"

namespace mediapipe {

using ::mediapipe::api2::Input;
using ::mediapipe::api2::Node;
using ::mediapipe::api2::Output;
using ::testing::HasSubstr;

namespace {

constexpr absl::string_view kErrorMsgFromProcess =
"Error from Calculator::Process.";

class ProcessFnErrorCalculator : public Node {
public:
static constexpr Input<int> kIn{"IN"};
static constexpr Output<int> kOut{"OUT"};

MEDIAPIPE_NODE_CONTRACT(kIn, kOut);

absl::Status Process(CalculatorContext* cc) override {
return absl::InternalError(kErrorMsgFromProcess);
}
};
MEDIAPIPE_REGISTER_NODE(ProcessFnErrorCalculator);

TEST(CalculatorGraphAsyncErrorsTest, ErrorCallbackReceivesProcessErrors) {
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: 'input'
node {
calculator: "ProcessFnErrorCalculator"
input_stream: 'IN:input'
output_stream: 'OUT:output'
}
)pb");

CalculatorGraph graph;

bool is_error_received = false;
absl::Status output_error;
absl::Mutex m;
auto error_callback_fn = [&graph, &m, &output_error,
&is_error_received](absl::Status error) {
EXPECT_TRUE(graph.HasError());

absl::MutexLock lock(&m);
output_error = std::move(error);
is_error_received = true;
};

MP_ASSERT_OK(graph.SetErrorCallback(error_callback_fn));
MP_ASSERT_OK(graph.Initialize(graph_config, {}));
MP_ASSERT_OK(graph.StartRun({}));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"input", MakePacket<int>(0).At(Timestamp(10))));

{
absl::MutexLock lock(&m);
ASSERT_TRUE(m.AwaitWithTimeout(absl::Condition(&is_error_received),
absl::Seconds(1)));
}
EXPECT_THAT(output_error, StatusIs(absl::StatusCode::kInternal,
HasSubstr(kErrorMsgFromProcess)));

EXPECT_THAT(graph.WaitUntilIdle(), StatusIs(absl::StatusCode::kInternal,
HasSubstr(kErrorMsgFromProcess)));
}

constexpr absl::string_view kErrorMsgFromOpen = "Error from Calculator::Open.";

class OpenFnErrorCalculator : public Node {
public:
static constexpr Input<int> kIn{"IN"};
static constexpr Output<int> kOut{"OUT"};

MEDIAPIPE_NODE_CONTRACT(kIn, kOut);

absl::Status Open(CalculatorContext* cc) override {
return absl::InternalError(kErrorMsgFromOpen);
}

absl::Status Process(CalculatorContext* cc) override {
return absl::OkStatus();
}
};
MEDIAPIPE_REGISTER_NODE(OpenFnErrorCalculator);

TEST(CalculatorGraphAsyncErrorsTest, ErrorCallbackReceivesOpenErrors) {
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: 'input'
node {
calculator: "OpenFnErrorCalculator"
input_stream: 'IN:input'
output_stream: 'OUT:output'
}
)pb");

CalculatorGraph graph;

bool is_error_received = false;
absl::Status output_error;
absl::Mutex m;
auto error_callback_fn = [&graph, &m, &output_error,
&is_error_received](absl::Status error) {
EXPECT_TRUE(graph.HasError());

absl::MutexLock lock(&m);
output_error = std::move(error);
is_error_received = true;
};

MP_ASSERT_OK(graph.SetErrorCallback(error_callback_fn));
MP_ASSERT_OK(graph.Initialize(graph_config, {}));
MP_ASSERT_OK(graph.StartRun({}));

{
absl::MutexLock lock(&m);
ASSERT_TRUE(m.AwaitWithTimeout(absl::Condition(&is_error_received),
absl::Seconds(1)));
}
EXPECT_THAT(output_error, StatusIs(absl::StatusCode::kInternal,
HasSubstr(kErrorMsgFromOpen)));

EXPECT_THAT(graph.WaitUntilIdle(), StatusIs(absl::StatusCode::kInternal,
HasSubstr(kErrorMsgFromOpen)));
}

TEST(CalculatorGraphAsyncErrorsTest, ErrorCallbackMustBeSetBeforeInit) {
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: 'input'
node {
calculator: "OpenFnErrorCalculator"
input_stream: 'IN:input'
output_stream: 'OUT:output'
}
)pb");

CalculatorGraph graph;
ABSL_CHECK_OK(graph.Initialize(graph_config, {}));
EXPECT_THAT(graph.SetErrorCallback({}),
StatusIs(absl::StatusCode::kInternal));
}

} // namespace
} // namespace mediapipe

0 comments on commit 4e7c7b5

Please sign in to comment.