Skip to content

Commit

Permalink
Pass a mrc.Subscription object to sources rather than a `mrc.Subscr…
Browse files Browse the repository at this point in the history
…iber` (#499)

* Remove the `make_source_subscriber` method in favor of inspecting the Python function signature.
* Since the `make_source_subscriber` method was never part of a release I think this can still be considered a non-breaking change.

Authors:
  - David Gardner (https://github.com/dagardner-nv)

Approvers:
  - Michael Demoret (https://github.com/mdemoret-nv)

URL: #499
  • Loading branch information
dagardner-nv authored Sep 17, 2024
1 parent ccbcd76 commit 48d17a1
Show file tree
Hide file tree
Showing 8 changed files with 83 additions and 26 deletions.
4 changes: 0 additions & 4 deletions python/mrc/_pymrc/include/pymrc/segment.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,10 +143,6 @@ class BuilderProxy
const std::string& name,
pybind11::function gen_factory);

static std::shared_ptr<mrc::segment::ObjectProperties> make_source_subscriber(mrc::segment::IBuilder& self,
const std::string& name,
pybind11::function gen_factory);

static std::shared_ptr<mrc::segment::ObjectProperties> make_source_component(mrc::segment::IBuilder& self,
const std::string& name,
pybind11::iterator source_iterator);
Expand Down
8 changes: 7 additions & 1 deletion python/mrc/_pymrc/include/pymrc/subscriber.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -47,6 +47,12 @@ class SubscriberProxy
static bool is_subscribed(PyObjectSubscriber* self);
};

class SubscriptionProxy
{
public:
static bool is_subscribed(PySubscription* self);
};

class ObservableProxy
{
public:
Expand Down
38 changes: 29 additions & 9 deletions python/mrc/_pymrc/src/segment.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -277,8 +277,9 @@ class SubscriberFuncWrapper : public mrc::pymrc::PythonSource<PyHolder>
{
DVLOG(10) << ctx.info() << " Starting source";
py::gil_scoped_acquire gil;
py::object py_sub = py::cast(subscriber);
auto py_iter = m_gen_factory.operator()<py::iterator>(std::move(py_sub));
PySubscription subscription = subscriber.get_subscription();
py::object py_sub = py::cast(subscription);
auto py_iter = m_gen_factory.operator()<py::iterator>(std::move(py_sub));
PyIteratorWrapper iter_wrapper{std::move(py_iter)};

for (auto next_val : iter_wrapper)
Expand Down Expand Up @@ -360,14 +361,33 @@ std::shared_ptr<mrc::segment::ObjectProperties> BuilderProxy::make_source(mrc::s
const std::string& name,
py::function gen_factory)
{
return build_source(self, name, PyIteratorWrapper(std::move(gen_factory)));
}
// Determine if the gen_factory is expecting to receive a subscription object
auto inspect_mod = py::module::import("inspect");
auto signature = inspect_mod.attr("signature")(gen_factory);
auto params = signature.attr("parameters");
auto num_params = py::len(params);
bool expects_subscription = false;

if (num_params > 0)
{
// We know there is at least one parameter. Check if the first parameter is a subscription object
// Note, when we receive a function that has been bound with `functools.partial(fn, arg1=some_value)`, the
// parameter is still visible in the signature of the partial object.
auto mrc_mod = py::module::import("mrc");
auto param_values = params.attr("values")();
auto first_param = py::iter(param_values);
auto type_hint = py::object((*first_param).attr("annotation"));
expects_subscription = (type_hint.is(mrc_mod.attr("Subscription")) ||
type_hint.equal(py::str("mrc.Subscription")) ||
type_hint.equal(py::str("Subscription")));
}

std::shared_ptr<mrc::segment::ObjectProperties> BuilderProxy::make_source_subscriber(mrc::segment::IBuilder& self,
const std::string& name,
py::function gen_factory)
{
return self.construct_object<SubscriberFuncWrapper>(name, std::move(gen_factory));
if (expects_subscription)
{
return self.construct_object<SubscriberFuncWrapper>(name, std::move(gen_factory));
}

return build_source(self, name, PyIteratorWrapper(std::move(gen_factory)));
}

std::shared_ptr<mrc::segment::ObjectProperties> BuilderProxy::make_source_component(mrc::segment::IBuilder& self,
Expand Down
8 changes: 7 additions & 1 deletion python/mrc/_pymrc/src/subscriber.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -115,6 +115,12 @@ bool SubscriberProxy::is_subscribed(PyObjectSubscriber* self)
return self->is_subscribed();
}

bool SubscriptionProxy::is_subscribed(PySubscription* self)
{
// No GIL here
return self->is_subscribed();
}

PySubscription ObservableProxy::subscribe(PyObjectObservable* self, PyObjectObserver& observer)
{
// Call the internal subscribe function
Expand Down
6 changes: 0 additions & 6 deletions python/mrc/core/segment.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,12 +134,6 @@ PYBIND11_MODULE(segment, py_mod)
const std::string&,
py::function)>(&BuilderProxy::make_source));

Builder.def("make_source_subscriber",
static_cast<std::shared_ptr<mrc::segment::ObjectProperties> (*)(mrc::segment::IBuilder&,
const std::string&,
py::function)>(
&BuilderProxy::make_source_subscriber));

Builder.def("make_source_component",
static_cast<std::shared_ptr<mrc::segment::ObjectProperties> (*)(mrc::segment::IBuilder&,
const std::string&,
Expand Down
3 changes: 2 additions & 1 deletion python/mrc/core/subscriber.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ PYBIND11_MODULE(subscriber, py_mod)
// Common must be first in every module
pymrc::import(py_mod, "mrc.core.common");

py::class_<PySubscription>(py_mod, "Subscription");
py::class_<PySubscription>(py_mod, "Subscription")
.def("is_subscribed", &SubscriptionProxy::is_subscribed, py::call_guard<py::gil_scoped_release>());

py::class_<PyObjectObserver>(py_mod, "Observer")
.def("on_next",
Expand Down
6 changes: 3 additions & 3 deletions python/tests/test_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,12 @@ def blocking_source():

def build(builder: mrc.Builder):

def gen_data(subscriber: mrc.Subscriber):
def gen_data(subscription: mrc.Subscription):
yield 1
while subscriber.is_subscribed():
while subscription.is_subscribed():
time.sleep(0.1)

return builder.make_source_subscriber("blocking_source", gen_data)
return builder.make_source("blocking_source", gen_data)

return build

Expand Down
36 changes: 35 additions & 1 deletion python/tests/test_node.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -489,5 +489,39 @@ def on_completed():
assert on_completed_count == 1


def test_source_with_bound_value():
"""
This test ensures that the bound values isn't confused with a subscription object
"""
on_next_value = None

def segment_init(seg: mrc.Builder):

def source_gen(a):
yield a

bound_gen = functools.partial(source_gen, a=1)
source = seg.make_source("my_src", bound_gen)

def on_next(x: int):
nonlocal on_next_value
on_next_value = x

sink = seg.make_sink("sink", on_next)
seg.make_edge(source, sink)

pipeline = mrc.Pipeline()
pipeline.make_segment("my_seg", segment_init)

options = mrc.Options()
executor = mrc.Executor(options)
executor.register_pipeline(pipeline)

executor.start()
executor.join()

assert on_next_value == 1


if (__name__ == "__main__"):
test_launch_options_properties()

0 comments on commit 48d17a1

Please sign in to comment.