From bf9b553e6b429d220c1c87075e0cb71199873442 Mon Sep 17 00:00:00 2001 From: Yuchen Zhang Date: Wed, 5 Jun 2024 10:23:23 -0700 Subject: [PATCH 1/3] initial commit --- cpp/mrc/include/mrc/node/dynamic_batcher.hpp | 91 ++++++++++++++++++++ 1 file changed, 91 insertions(+) create mode 100644 cpp/mrc/include/mrc/node/dynamic_batcher.hpp diff --git a/cpp/mrc/include/mrc/node/dynamic_batcher.hpp b/cpp/mrc/include/mrc/node/dynamic_batcher.hpp new file mode 100644 index 000000000..f3b282e79 --- /dev/null +++ b/cpp/mrc/include/mrc/node/dynamic_batcher.hpp @@ -0,0 +1,91 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & + * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "mrc/channel/buffered_channel.hpp" +#include "mrc/channel/channel.hpp" +#include "mrc/constants.hpp" +#include "mrc/core/utils.hpp" +#include "mrc/core/watcher.hpp" +#include "mrc/exceptions/runtime_error.hpp" +#include "mrc/node/rx_epilogue_tap.hpp" +#include "mrc/node/rx_prologue_tap.hpp" +#include "mrc/node/rx_sink_base.hpp" +#include "mrc/node/rx_source_base.hpp" +#include "mrc/node/rx_subscribable.hpp" +#include "mrc/runnable/runnable.hpp" +#include "mrc/utils/type_utils.hpp" + +#include +#include + +#include +#include +#include + +template +class DynamicBatcher : public mrc::node::WritableProvider, + public mrc::node::ReadableAcceptor, + public mrc::node::SinkChannelOwner, + public mrc::node::WritableAcceptor>, + public mrc::node::ReadableProvider>, + public mrc::node::SourceChannelOwner>, + public mrc::runnable::RunnableWithContext { + using state_t = mrc::runnable::Runnable::State; + using input_t = T; + using output_t = std::vector; + +public: + DynamicBatcher(size_t max_count) { + // Set the default channel + mrc::node::SinkChannelOwner::set_channel( + std::make_unique>()); + mrc::node::SourceChannelOwner::set_channel( + std::make_unique>()); + } + ~DynamicBatcher() override = default; + +private: + /** + * @brief Runnable's entrypoint. + */ + void run(mrc::runnable::Context &ctx) override { + T input_data; + auto status = this->get_readable_edge()->await_read(input_data); + + // TODO(Yuchen): fill out the implementation here + + + + + + // Only drop the output edges if we are rank 0 + if (ctx.rank() == 0) { + // Need to drop the output edges + mrc::node::SourceProperties::release_edge_connection(); + mrc::node::SinkProperties::release_edge_connection(); + } + } + + /** + * @brief Runnable's state control, for stopping from MRC. + */ + void on_state_update(const state_t &state) final; + + std::stop_source m_stop_source; +}; From 7ac9ad9f9ac6cb649cbd81b1d2d8cf31ebed9ab4 Mon Sep 17 00:00:00 2001 From: Yuchen Zhang Date: Mon, 10 Jun 2024 15:28:12 -0700 Subject: [PATCH 2/3] initial impl dynamic_batcher --- cpp/mrc/include/mrc/node/dynamic_batcher.hpp | 50 +++++++++++++++++--- 1 file changed, 43 insertions(+), 7 deletions(-) diff --git a/cpp/mrc/include/mrc/node/dynamic_batcher.hpp b/cpp/mrc/include/mrc/node/dynamic_batcher.hpp index f3b282e79..539815423 100644 --- a/cpp/mrc/include/mrc/node/dynamic_batcher.hpp +++ b/cpp/mrc/include/mrc/node/dynamic_batcher.hpp @@ -32,6 +32,7 @@ #include "mrc/utils/type_utils.hpp" #include +#include #include #include @@ -51,7 +52,8 @@ class DynamicBatcher : public mrc::node::WritableProvider, using output_t = std::vector; public: - DynamicBatcher(size_t max_count) { + DynamicBatcher(size_t max_count, std::chrono::milliseconds duration) + : m_max_count(max_count), m_duration(duration) { // Set the default channel mrc::node::SinkChannelOwner::set_channel( std::make_unique>()); @@ -65,14 +67,32 @@ class DynamicBatcher : public mrc::node::WritableProvider, * @brief Runnable's entrypoint. */ void run(mrc::runnable::Context &ctx) override { - T input_data; - auto status = this->get_readable_edge()->await_read(input_data); - - // TODO(Yuchen): fill out the implementation here - + // T input_data; + // auto status = this->get_readable_edge()->await_read(input_data); + // Create an observable from the input channel + auto input_observable = + rxcpp::observable<>::create([this](rxcpp::subscriber s) { + T input_data; + while (this->get_readable_edge()->await_read(input_data) == + mrc::channel::Status::success) { + s.on_next(input_data); + } + s.on_completed(); + }); + // Buffer the items from the input observable + auto buffered_observable = input_observable.buffer_with_time_or_count( + m_duration, m_max_count, rxcpp::observe_on_new_thread()); + // Subscribe to the buffered observable + buffered_observable.subscribe( + [this](const std::vector &buffer) { + this->get_writable_edge()->await_write(buffer); + }, + []() { + // Handle completion + }); // Only drop the output edges if we are rank 0 if (ctx.rank() == 0) { @@ -85,7 +105,23 @@ class DynamicBatcher : public mrc::node::WritableProvider, /** * @brief Runnable's state control, for stopping from MRC. */ - void on_state_update(const state_t &state) final; + void on_state_update(const state_t &state) final { + switch (state) { + case state_t::Stop: + // Do nothing, we wait for the upstream channel to return closed + // m_stop_source.request_stop(); + break; + + case state_t::Kill: + m_stop_source.request_stop(); + break; + + default: + break; + } + } std::stop_source m_stop_source; + size_t m_max_count; + std::chrono::milliseconds m_duration; }; From f9dbdd70aca97e7bb845802e6ca3569a210255f7 Mon Sep 17 00:00:00 2001 From: Yuchen Zhang Date: Mon, 17 Jun 2024 20:51:09 -0700 Subject: [PATCH 3/3] temp commit --- CMakeLists.txt | 2 +- .../node/{ => operators}/dynamic_batcher.hpp | 9 +++- cpp/mrc/tests/test_segment.cpp | 43 +++++++++++++++++++ mrc.code-workspace | 40 +++++++++++++---- 4 files changed, 83 insertions(+), 11 deletions(-) rename cpp/mrc/include/mrc/node/{ => operators}/dynamic_batcher.hpp (94%) diff --git a/CMakeLists.txt b/CMakeLists.txt index 1e9931166..7a1accb04 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -122,7 +122,7 @@ enable_language(CUDA) set(MRC_ROOT_DIR ${CMAKE_CURRENT_SOURCE_DIR}) # Set a default build type if none was specified -rapids_cmake_build_type(Release) +rapids_cmake_build_type(Debug) set(CMAKE_CXX_STANDARD 20) set(CMAKE_CXX_STANDARD_REQUIRED ON) diff --git a/cpp/mrc/include/mrc/node/dynamic_batcher.hpp b/cpp/mrc/include/mrc/node/operators/dynamic_batcher.hpp similarity index 94% rename from cpp/mrc/include/mrc/node/dynamic_batcher.hpp rename to cpp/mrc/include/mrc/node/operators/dynamic_batcher.hpp index 539815423..8f68ac85c 100644 --- a/cpp/mrc/include/mrc/node/dynamic_batcher.hpp +++ b/cpp/mrc/include/mrc/node/operators/dynamic_batcher.hpp @@ -30,6 +30,7 @@ #include "mrc/node/rx_subscribable.hpp" #include "mrc/runnable/runnable.hpp" #include "mrc/utils/type_utils.hpp" +#include "rxcpp/operators/rx-observe_on.hpp" #include #include @@ -39,6 +40,7 @@ #include #include +namespace mrc::node { template class DynamicBatcher : public mrc::node::WritableProvider, public mrc::node::ReadableAcceptor, @@ -81,9 +83,11 @@ class DynamicBatcher : public mrc::node::WritableProvider, s.on_completed(); }); + // DVLOG(1) << "DynamicBatcher: m_duration: " << m_duration.count() << std::endl; + // Buffer the items from the input observable auto buffered_observable = input_observable.buffer_with_time_or_count( - m_duration, m_max_count, rxcpp::observe_on_new_thread()); + m_duration, m_max_count, rxcpp::observe_on_event_loop()); // Subscribe to the buffered observable buffered_observable.subscribe( @@ -122,6 +126,7 @@ class DynamicBatcher : public mrc::node::WritableProvider, } std::stop_source m_stop_source; - size_t m_max_count; + int m_max_count; std::chrono::milliseconds m_duration; }; +} // namespace mrc::node diff --git a/cpp/mrc/tests/test_segment.cpp b/cpp/mrc/tests/test_segment.cpp index bd3b09d78..4792ab28c 100644 --- a/cpp/mrc/tests/test_segment.cpp +++ b/cpp/mrc/tests/test_segment.cpp @@ -20,6 +20,7 @@ #include "mrc/benchmarking/trace_statistics.hpp" #include "mrc/exceptions/runtime_error.hpp" #include "mrc/node/operators/broadcast.hpp" +#include "mrc/node/operators/dynamic_batcher.hpp" #include "mrc/node/rx_node.hpp" #include "mrc/node/rx_sink.hpp" #include "mrc/node/rx_source.hpp" @@ -1122,4 +1123,46 @@ TEST_F(TestSegment, SegmentGetEgressNotEgressError) */ } +TEST_F(TestSegment, SegmentDynamicBatcher) +{ + unsigned int iterations{3}; + std::atomic sink1_results{0}; + float sink2_results{0}; + std::mutex mux; + + auto init = [&](segment::IBuilder& segment) { + auto src = segment.make_source("src", [&](rxcpp::subscriber& s) { + for (size_t i = 0; i < iterations && s.is_subscribed(); i++) + { + s.on_next(1); + s.on_next(2); + s.on_next(3); + } + + s.on_completed(); + }); + + auto dynamic_batcher = segment.construct_object>("dynamic_batcher", 2, std::chrono::milliseconds(100)); + + segment.make_edge(src, dynamic_batcher); + + auto sink = segment.make_sink>("sink", [&](std::vector x) { + DVLOG(1) << "Sink got vector" << std::endl; + for (auto i : x) + { + DVLOG(1) << "Sink got value: " << i << std::endl; + // sink1_results.fetch_add(i, std::memory_order_relaxed); + } + }); + + segment.make_edge(dynamic_batcher, sink); + }; + + auto segdef = Segment::create("dynamic_batcher_test", init); + + auto pipeline = mrc::make_pipeline(); + pipeline->register_segment(std::move(segdef)); + execute_pipeline(std::move(pipeline)); +} + } // namespace mrc diff --git a/mrc.code-workspace b/mrc.code-workspace index 632b0a0e6..82cd91226 100644 --- a/mrc.code-workspace +++ b/mrc.code-workspace @@ -86,25 +86,49 @@ "type": "cppdbg" }, { - "MIMode": "gdb", + "MIMode": "lldb", "args": [], "cwd": "${workspaceFolder}", "environment": [], "externalConsole": false, - "miDebuggerPath": "gdb", - "name": "debug bench_mrc.x", - "preLaunchTask": "C/C++: g++ build active file", - "program": "${workspaceFolder}/build/benchmarks/bench_mrc", + "miDebuggerPath": "lldb", + "name": "debug test_mrc.x with lldb", + // "preLaunchTask": "C/C++: g++ build active file", + "program": "${workspaceFolder}/build/cpp/mrc/tests/test_mrc.x", "request": "launch", "setupCommands": [ { - "description": "Enable pretty-printing for gdb", + "description": "Enable pretty-printing for lldb", "ignoreFailures": true, - "text": "-enable-pretty-printing" + "text": "command script import pretty_printers.py" } ], + "justMyCode": true, "stopAtEntry": false, - "type": "cppdbg" + "type": "lldb" + }, + { + "MIMode": "lldb", + "args": [], + "cwd": "${workspaceFolder}", + "environment": [], + "externalConsole": false, + "miDebuggerPath": "lldb", + "name": "debug TestSegment.SegmentDynamicBatcher with lldb", + // "preLaunchTask": "C/C++: g++ build active file", + "program": "${workspaceFolder}/build/cpp/mrc/tests/test_mrc.x", + "request": "launch", + "setupCommands": [ + { + "description": "Enable pretty-printing for lldb", + "ignoreFailures": true, + "text": "command script import pretty_printers.py" + } + ], + "args": ["--gtest_filter=TestSegment.SegmentDynamicBatcher"], + "justMyCode": true, + "stopAtEntry": false, + "type": "lldb" }, { "MIMode": "gdb",