From 371f001248bf312be1d9e7cb81da3af8cfb138a1 Mon Sep 17 00:00:00 2001 From: Yuchen Zhang <134643420+yczhang-nv@users.noreply.github.com> Date: Wed, 28 Aug 2024 16:05:41 -0700 Subject: [PATCH] add overload to TensorObject --- .../morpheus/_lib/include/morpheus/messages/control.hpp | 6 ++++++ python/morpheus/morpheus/_lib/messages/__init__.pyi | 2 ++ python/morpheus/morpheus/_lib/messages/module.cpp | 3 +++ python/morpheus/morpheus/_lib/src/messages/control.cpp | 7 +++++++ tests/test_add_scores_stage.py | 3 +-- 5 files changed, 19 insertions(+), 2 deletions(-) diff --git a/python/morpheus/morpheus/_lib/include/morpheus/messages/control.hpp b/python/morpheus/morpheus/_lib/include/morpheus/messages/control.hpp index df643717a2..199848c62b 100644 --- a/python/morpheus/morpheus/_lib/include/morpheus/messages/control.hpp +++ b/python/morpheus/morpheus/_lib/include/morpheus/messages/control.hpp @@ -427,6 +427,12 @@ struct MORPHEUS_EXPORT ControlMessageProxy */ static void payload_from_python_meta(ControlMessage& self, const pybind11::object& meta); + /** + * @brief Set the tensors given a Python instance of TensorMemory + * @param tensor_memory the Python instance of TensorMemory + */ + static void set_tensors_from_python(ControlMessage& self, const pybind11::object& tensor_memory); + /** * @brief Set the tensors from python object * diff --git a/python/morpheus/morpheus/_lib/messages/__init__.pyi b/python/morpheus/morpheus/_lib/messages/__init__.pyi index 8cd3025fc7..48d95358e1 100644 --- a/python/morpheus/morpheus/_lib/messages/__init__.pyi +++ b/python/morpheus/morpheus/_lib/messages/__init__.pyi @@ -85,6 +85,8 @@ class ControlMessage(): def tensors(self, arg0: TensorMemory) -> None: ... @typing.overload def tensors(self, count: int, tensors: object) -> None: ... + @typing.overload + def tensors(self, tensors: object) -> None: ... pass class ControlMessageType(): """ diff --git a/python/morpheus/morpheus/_lib/messages/module.cpp b/python/morpheus/morpheus/_lib/messages/module.cpp index b7c025ddf3..61e9a60866 100644 --- a/python/morpheus/morpheus/_lib/messages/module.cpp +++ b/python/morpheus/morpheus/_lib/messages/module.cpp @@ -45,6 +45,7 @@ #include // for COMPACT_GOOGLE_LOG_INFO, LogMessage, VLOG #include #include // for basic_json +#include #include // IWYU pragma: keep #include #include @@ -442,6 +443,8 @@ PYBIND11_MODULE(messages, _module) py::arg("meta")) .def("tensors", pybind11::overload_cast<>(&ControlMessage::tensors)) .def("tensors", pybind11::overload_cast&>(&ControlMessage::tensors)) + .def("tensors", pybind11::overload_cast(&ControlMessageProxy::set_tensors_from_python), + py::arg("tensors")) .def("tensors", pybind11::overload_cast( &ControlMessageProxy::set_tensors_from_python), diff --git a/python/morpheus/morpheus/_lib/src/messages/control.cpp b/python/morpheus/morpheus/_lib/src/messages/control.cpp index 34f873761e..f7dfe85bb3 100644 --- a/python/morpheus/morpheus/_lib/src/messages/control.cpp +++ b/python/morpheus/morpheus/_lib/src/messages/control.cpp @@ -358,6 +358,13 @@ void ControlMessageProxy::payload_from_python_meta(ControlMessage& self, const p self.payload(MessageMetaInterfaceProxy::init_python_meta(meta)); } +void ControlMessageProxy::set_tensors_from_python(ControlMessage& self, const pybind11::object& tensor_memory) +{ + TensorIndex count = tensor_memory.attr("count").cast(); + pybind11::object tensors = tensor_memory.attr("get_tensors")(); + self.tensors(TensorMemoryInterfaceProxy::init(count, tensors)); +} + void ControlMessageProxy::set_tensors_from_python(ControlMessage& self, TensorIndex count, pybind11::object& tensors) { self.tensors(TensorMemoryInterfaceProxy::init(count, tensors)); diff --git a/tests/test_add_scores_stage.py b/tests/test_add_scores_stage.py index 36328f0014..ea96b6a6d6 100755 --- a/tests/test_add_scores_stage.py +++ b/tests/test_add_scores_stage.py @@ -81,8 +81,7 @@ def test_add_labels_with_multi_response_message_and_control_message(): cm = ControlMessage() cm.payload(MessageMeta(df)) - tensor_memory = TensorMemory(count=2, tensors={"probs": probs_array}) - cm.tensors(tensor_memory.count, tensor_memory.get_tensors()) + cm.tensors(TensorMemory(count=2, tensors={"probs": probs_array})) labeled_cm = AddClassificationsStage._add_labels(cm, idx2label=class_labels, threshold=None)