Skip to content

Commit

Permalink
add overload to TensorObject
Browse files Browse the repository at this point in the history
  • Loading branch information
yczhang-nv committed Aug 28, 2024
1 parent d3e1b45 commit 371f001
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,12 @@ struct MORPHEUS_EXPORT ControlMessageProxy
*/
static void payload_from_python_meta(ControlMessage& self, const pybind11::object& meta);

/**
* @brief Set the tensors given a Python instance of TensorMemory
* @param tensor_memory the Python instance of TensorMemory
*/
static void set_tensors_from_python(ControlMessage& self, const pybind11::object& tensor_memory);

/**
* @brief Set the tensors from python object
*
Expand Down
2 changes: 2 additions & 0 deletions python/morpheus/morpheus/_lib/messages/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ class ControlMessage():
def tensors(self, arg0: TensorMemory) -> None: ...
@typing.overload
def tensors(self, count: int, tensors: object) -> None: ...
@typing.overload
def tensors(self, tensors: object) -> None: ...
pass
class ControlMessageType():
"""
Expand Down
3 changes: 3 additions & 0 deletions python/morpheus/morpheus/_lib/messages/module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
#include <glog/logging.h> // for COMPACT_GOOGLE_LOG_INFO, LogMessage, VLOG
#include <mrc/edge/edge_connector.hpp>
#include <nlohmann/json.hpp> // for basic_json
#include <pybind11/detail/common.h>
#include <pybind11/functional.h> // IWYU pragma: keep
#include <pybind11/pybind11.h>
#include <pybind11/pytypes.h>
Expand Down Expand Up @@ -442,6 +443,8 @@ PYBIND11_MODULE(messages, _module)
py::arg("meta"))
.def("tensors", pybind11::overload_cast<>(&ControlMessage::tensors))
.def("tensors", pybind11::overload_cast<const std::shared_ptr<TensorMemory>&>(&ControlMessage::tensors))
.def("tensors", pybind11::overload_cast<ControlMessage&, const py::object&>(&ControlMessageProxy::set_tensors_from_python),
py::arg("tensors"))
.def("tensors",
pybind11::overload_cast<ControlMessage&, TensorIndex, py::object&>(
&ControlMessageProxy::set_tensors_from_python),
Expand Down
7 changes: 7 additions & 0 deletions python/morpheus/morpheus/_lib/src/messages/control.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,13 @@ void ControlMessageProxy::payload_from_python_meta(ControlMessage& self, const p
self.payload(MessageMetaInterfaceProxy::init_python_meta(meta));
}

void ControlMessageProxy::set_tensors_from_python(ControlMessage& self, const pybind11::object& tensor_memory)
{
TensorIndex count = tensor_memory.attr("count").cast<TensorIndex>();
pybind11::object tensors = tensor_memory.attr("get_tensors")();
self.tensors(TensorMemoryInterfaceProxy::init(count, tensors));
}

void ControlMessageProxy::set_tensors_from_python(ControlMessage& self, TensorIndex count, pybind11::object& tensors)
{
self.tensors(TensorMemoryInterfaceProxy::init(count, tensors));
Expand Down
3 changes: 1 addition & 2 deletions tests/test_add_scores_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,7 @@ def test_add_labels_with_multi_response_message_and_control_message():

cm = ControlMessage()
cm.payload(MessageMeta(df))
tensor_memory = TensorMemory(count=2, tensors={"probs": probs_array})
cm.tensors(tensor_memory.count, tensor_memory.get_tensors())
cm.tensors(TensorMemory(count=2, tensors={"probs": probs_array}))

labeled_cm = AddClassificationsStage._add_labels(cm, idx2label=class_labels, threshold=None)

Expand Down

0 comments on commit 371f001

Please sign in to comment.