Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add dynamic batcher node #481

Draft
wants to merge 3 commits into
base: branch-24.06
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ enable_language(CUDA)
set(MRC_ROOT_DIR ${CMAKE_CURRENT_SOURCE_DIR})

# Set a default build type if none was specified
rapids_cmake_build_type(Release)
rapids_cmake_build_type(Debug)

set(CMAKE_CXX_STANDARD 20)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
Expand Down
132 changes: 132 additions & 0 deletions cpp/mrc/include/mrc/node/operators/dynamic_batcher.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include "mrc/channel/buffered_channel.hpp"
#include "mrc/channel/channel.hpp"
#include "mrc/constants.hpp"
#include "mrc/core/utils.hpp"
#include "mrc/core/watcher.hpp"
#include "mrc/exceptions/runtime_error.hpp"
#include "mrc/node/rx_epilogue_tap.hpp"
#include "mrc/node/rx_prologue_tap.hpp"
#include "mrc/node/rx_sink_base.hpp"
#include "mrc/node/rx_source_base.hpp"
#include "mrc/node/rx_subscribable.hpp"
#include "mrc/runnable/runnable.hpp"
#include "mrc/utils/type_utils.hpp"
#include "rxcpp/operators/rx-observe_on.hpp"

#include <glog/logging.h>
#include <rxcpp/operators/rx-buffer_time_count.hpp>
#include <rxcpp/rx.hpp>

#include <chrono>
#include <cstddef>
#include <exception>
#include <memory>
#include <mutex>
#include <stop_token>
#include <utility>
#include <vector>

namespace mrc::node {
/**
 * @brief Node that groups individual items of type `T` into `std::vector<T>` batches.
 *
 * Items are read from the upstream readable edge and fed through rxcpp's `buffer_with_time_or_count`
 * operator. Each resulting buffer is written to the downstream writable edge.
 *
 * @tparam T Type of the individual input items; the output type is `std::vector<T>`.
 * @tparam ContextT Context type forwarded to `mrc::runnable::RunnableWithContext`.
 */
template <typename T, typename ContextT>
class DynamicBatcher : public mrc::node::WritableProvider<T>,
                       public mrc::node::ReadableAcceptor<T>,
                       public mrc::node::SinkChannelOwner<T>,
                       public mrc::node::WritableAcceptor<std::vector<T>>,
                       public mrc::node::ReadableProvider<std::vector<T>>,
                       public mrc::node::SourceChannelOwner<std::vector<T>>,
                       public mrc::runnable::RunnableWithContext<ContextT>
{
    using state_t  = mrc::runnable::Runnable::State;
    using input_t  = T;
    using output_t = std::vector<T>;

  public:
    /**
     * @brief Construct a batcher with default buffered channels on both the input and output side.
     *
     * @param max_count Count passed to `buffer_with_time_or_count` (maximum items per batch).
     * @param duration Period passed to `buffer_with_time_or_count` (time window of a batch).
     */
    DynamicBatcher(size_t max_count, std::chrono::milliseconds duration) :
      // The rxcpp buffer operator takes its count as an `int`, so make the narrowing explicit.
      m_max_count(static_cast<int>(max_count)),
      m_duration(duration)
    {
        // Set the default channels
        mrc::node::SinkChannelOwner<input_t>::set_channel(std::make_unique<mrc::channel::BufferedChannel<input_t>>());
        mrc::node::SourceChannelOwner<output_t>::set_channel(
            std::make_unique<mrc::channel::BufferedChannel<output_t>>());
    }

    ~DynamicBatcher() override = default;

  private:
    /**
     * @brief Runnable's entrypoint.
     *
     * Reads items from the input edge until the read no longer returns `Status::success`, the subscriber
     * unsubscribes, or a stop has been requested via `on_state_update`. Batches are written to the output
     * edge, and the edge connections are released (rank 0 only) once the stream has completed.
     *
     * @param ctx Runnable context; only its rank is consulted here.
     */
    void run(mrc::runnable::Context& ctx) override
    {
        // Create an observable from the input channel
        auto input_observable = rxcpp::observable<>::create<T>([this](rxcpp::subscriber<T> s) {
            T input_data;

            // The stop flag is only checked between reads; a read that is already blocked is not interrupted.
            while (s.is_subscribed() && !m_stop_source.stop_requested() &&
                   this->get_readable_edge()->await_read(input_data) == mrc::channel::Status::success)
            {
                // Move instead of copy; `input_data` is overwritten by the next `await_read`.
                s.on_next(std::move(input_data));
            }

            s.on_completed();
        });

        // Buffer the items from the input observable
        auto buffered_observable =
            input_observable.buffer_with_time_or_count(m_duration, m_max_count, rxcpp::observe_on_event_loop());

        // The buffer operator delivers its output through the event-loop coordination, so a plain
        // `subscribe` may return while batches are still pending. Block until the stream has completed
        // so that every batch is written before the edges are released below.
        buffered_observable.as_blocking().subscribe(
            [this](const std::vector<T>& buffer) {
                this->get_writable_edge()->await_write(buffer);
            },
            []() {
                // Nothing to do on completion
            });

        // Only drop the output edges if we are rank 0
        if (ctx.rank() == 0)
        {
            // Need to drop the output edges
            mrc::node::SourceProperties<output_t>::release_edge_connection();
            mrc::node::SinkProperties<T>::release_edge_connection();
        }
    }

    /**
     * @brief Runnable's state control, for stopping from MRC.
     *
     * `Stop` is intentionally a no-op: the node finishes when the upstream channel closes.
     * `Kill` requests a stop, which the read loop in `run` observes before its next read.
     *
     * @param state New runnable state.
     */
    void on_state_update(const state_t& state) final
    {
        switch (state)
        {
        case state_t::Stop:
            // Do nothing, we wait for the upstream channel to return closed
            break;

        case state_t::Kill:
            m_stop_source.request_stop();
            break;

        default:
            break;
        }
    }

    std::stop_source m_stop_source;        // Signalled on Kill; polled by the read loop in run()
    int m_max_count;                       // Count argument for buffer_with_time_or_count
    std::chrono::milliseconds m_duration;  // Period argument for buffer_with_time_or_count
};
} // namespace mrc::node
43 changes: 43 additions & 0 deletions cpp/mrc/tests/test_segment.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "mrc/benchmarking/trace_statistics.hpp"
#include "mrc/exceptions/runtime_error.hpp"
#include "mrc/node/operators/broadcast.hpp"
#include "mrc/node/operators/dynamic_batcher.hpp"
#include "mrc/node/rx_node.hpp"
#include "mrc/node/rx_sink.hpp"
#include "mrc/node/rx_source.hpp"
Expand Down Expand Up @@ -1122,4 +1123,46 @@ TEST_F(TestSegment, SegmentGetEgressNotEgressError)
*/
}

// Pushes 3 * iterations integers through a DynamicBatcher and verifies that no batch exceeds the configured
// maximum size and that every emitted item reaches the sink exactly once (checked via count and sum).
TEST_F(TestSegment, SegmentDynamicBatcher)
{
    const std::size_t iterations{3};
    const std::size_t max_batch_size{2};

    // Updated from the sink's thread and read on the test thread after the pipeline has finished
    std::atomic<std::size_t> item_count{0};
    std::atomic<std::size_t> item_sum{0};

    auto init = [&](segment::IBuilder& segment) {
        auto src = segment.make_source<int>("src", [&](rxcpp::subscriber<int>& s) {
            for (std::size_t i = 0; i < iterations && s.is_subscribed(); i++)
            {
                s.on_next(1);
                s.on_next(2);
                s.on_next(3);
            }

            s.on_completed();
        });

        auto dynamic_batcher = segment.construct_object<node::DynamicBatcher<int, runnable::Context>>(
            "dynamic_batcher",
            max_batch_size,
            std::chrono::milliseconds(100));

        segment.make_edge(src, dynamic_batcher);

        auto sink = segment.make_sink<std::vector<int>>("sink", [&](std::vector<int> x) {
            DVLOG(1) << "Sink got vector";

            // A batch may be smaller than the maximum (time window elapsed) but never larger
            EXPECT_LE(x.size(), max_batch_size);

            for (auto i : x)
            {
                DVLOG(1) << "Sink got value: " << i;
                item_count.fetch_add(1, std::memory_order_relaxed);
                item_sum.fetch_add(static_cast<std::size_t>(i), std::memory_order_relaxed);
            }
        });

        segment.make_edge(dynamic_batcher, sink);
    };

    auto segdef = Segment::create("dynamic_batcher_test", init);

    auto pipeline = mrc::make_pipeline();
    pipeline->register_segment(std::move(segdef));
    execute_pipeline(std::move(pipeline));

    // Each iteration emits the values 1, 2 and 3
    EXPECT_EQ(item_count.load(), iterations * 3);
    EXPECT_EQ(item_sum.load(), iterations * (1 + 2 + 3));
}

} // namespace mrc
40 changes: 32 additions & 8 deletions mrc.code-workspace
Original file line number Diff line number Diff line change
Expand Up @@ -86,25 +86,49 @@
"type": "cppdbg"
},
{
"MIMode": "gdb",
"MIMode": "lldb",
"args": [],
"cwd": "${workspaceFolder}",
"environment": [],
"externalConsole": false,
"miDebuggerPath": "gdb",
"name": "debug bench_mrc.x",
"preLaunchTask": "C/C++: g++ build active file",
"program": "${workspaceFolder}/build/benchmarks/bench_mrc",
"miDebuggerPath": "lldb",
"name": "debug test_mrc.x with lldb",
// "preLaunchTask": "C/C++: g++ build active file",
"program": "${workspaceFolder}/build/cpp/mrc/tests/test_mrc.x",
"request": "launch",
"setupCommands": [
{
"description": "Enable pretty-printing for gdb",
"description": "Enable pretty-printing for lldb",
"ignoreFailures": true,
"text": "-enable-pretty-printing"
"text": "command script import pretty_printers.py"
}
],
"justMyCode": true,
"stopAtEntry": false,
"type": "cppdbg"
"type": "lldb"
},
{
    "MIMode": "lldb",
    // Run only the dynamic batcher test ("args" was previously declared twice in this object)
    "args": ["--gtest_filter=TestSegment.SegmentDynamicBatcher"],
    "cwd": "${workspaceFolder}",
    "environment": [],
    "externalConsole": false,
    "miDebuggerPath": "lldb",
    "name": "debug TestSegment.SegmentDynamicBatcher with lldb",
    // "preLaunchTask": "C/C++: g++ build active file",
    "program": "${workspaceFolder}/build/cpp/mrc/tests/test_mrc.x",
    "request": "launch",
    "setupCommands": [
        {
            "description": "Enable pretty-printing for lldb",
            "ignoreFailures": true,
            "text": "command script import pretty_printers.py"
        }
    ],
    "justMyCode": true,
    "stopAtEntry": false,
    "type": "lldb"
},
{
"MIMode": "gdb",
Expand Down
Loading