Skip to content

Commit

Permalink
No public description
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 565446429
  • Loading branch information
MediaPipe Team authored and copybara-github committed Sep 14, 2023
1 parent 85b1938 commit f2b11bf
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 9 deletions.
70 changes: 70 additions & 0 deletions mediapipe/calculators/core/begin_end_loop_calculator_graph_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// limitations under the License.

#include <string>
#include <utility>
#include <vector>

#include "absl/memory/memory.h"
Expand Down Expand Up @@ -163,6 +164,75 @@ TEST_F(BeginEndLoopCalculatorGraphTest, MultipleVectors) {
PacketOfIntsEq(input_timestamp2, std::vector<int>{3, 4})));
}

TEST(BeginEndLoopCalculatorPossibleDataRaceTest,
EndLoopForIntegersDoesNotRace) {
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(
R"pb(
num_threads: 4
input_stream: "ints"
node {
calculator: "BeginLoopIntegerCalculator"
input_stream: "ITERABLE:ints"
output_stream: "ITEM:int"
output_stream: "BATCH_END:timestamp"
}
node {
calculator: "IncrementCalculator"
input_stream: "int"
output_stream: "int_plus_one"
}
# BEGIN: Data race possibility
# EndLoop###Calculator and another calculator using the same input
# may introduce race due to EndLoop###Calculator possibly consuming
# packet.
node {
calculator: "EndLoopIntegersCalculator"
input_stream: "ITEM:int_plus_one"
input_stream: "BATCH_END:timestamp"
output_stream: "ITERABLE:ints_plus_one"
}
node {
calculator: "IncrementCalculator"
input_stream: "int_plus_one"
output_stream: "int_plus_two"
}
# END: Data race possibility
node {
calculator: "EndLoopIntegersCalculator"
input_stream: "ITEM:int_plus_two"
input_stream: "BATCH_END:timestamp"
output_stream: "ITERABLE:ints_plus_two"
}
)pb");
std::vector<Packet> int_plus_one_packets;
tool::AddVectorSink("ints_plus_one", &graph_config, &int_plus_one_packets);
std::vector<Packet> int_original_packets;
tool::AddVectorSink("ints_plus_two", &graph_config, &int_original_packets);

CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(graph_config));
MP_ASSERT_OK(graph.StartRun({}));
for (int i = 0; i < 100; ++i) {
std::vector<int> ints = {i, i + 1, i + 2};
Timestamp ts = Timestamp(i);
MP_ASSERT_OK(graph.AddPacketToInputStream(
"ints", MakePacket<std::vector<int>>(std::move(ints)).At(ts)));
MP_ASSERT_OK(graph.WaitUntilIdle());
EXPECT_THAT(int_plus_one_packets,
testing::ElementsAre(
PacketOfIntsEq(ts, std::vector<int>{i + 1, i + 2, i + 3})));
EXPECT_THAT(int_original_packets,
testing::ElementsAre(
PacketOfIntsEq(ts, std::vector<int>{i + 2, i + 3, i + 4})));

int_plus_one_packets.clear();
int_original_packets.clear();
}

MP_ASSERT_OK(graph.CloseAllPacketSources());
MP_ASSERT_OK(graph.WaitUntilDone());
}

// Passes non empty vector through or outputs empty vector in case of timestamp
// bound update.
class PassThroughOrEmptyVectorCalculator : public CalculatorBase {
Expand Down
18 changes: 9 additions & 9 deletions mediapipe/calculators/core/end_loop_calculator.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,16 +55,16 @@ class EndLoopCalculator : public CalculatorBase {
if (!input_stream_collection_) {
input_stream_collection_.reset(new IterableT);
}
// Try to consume the item and move it into the collection. If the items
// are not consumable, then try to copy them instead. If the items are
// not copyable, then an error will be returned.
auto item_ptr_or = cc->Inputs().Tag("ITEM").Value().Consume<ItemT>();
if (item_ptr_or.ok()) {
input_stream_collection_->push_back(std::move(*item_ptr_or.value()));

if constexpr (std::is_copy_constructible_v<ItemT>) {
input_stream_collection_->push_back(
cc->Inputs().Tag("ITEM").Get<ItemT>());
} else {
if constexpr (std::is_copy_constructible_v<ItemT>) {
input_stream_collection_->push_back(
cc->Inputs().Tag("ITEM").template Get<ItemT>());
// Try to consume the item and move it into the collection. Return an
// error if the items are not consumable.
auto item_ptr_or = cc->Inputs().Tag("ITEM").Value().Consume<ItemT>();
if (item_ptr_or.ok()) {
input_stream_collection_->push_back(std::move(*item_ptr_or.value()));
} else {
return absl::InternalError(
"The item type is not copiable. Consider making the "
Expand Down

0 comments on commit f2b11bf

Please sign in to comment.