From aa8d42e79936bc7b2558682ca1197cedca8c7041 Mon Sep 17 00:00:00 2001 From: Devin Robison Date: Thu, 15 Feb 2024 20:33:47 -0700 Subject: [PATCH] ControlMessage improvements (#1511) Resolves #1502 Adds the ability to attach TensorMemory to ControlMessages Improves get/set/list metadata functions Adds the ability to attach grouped/keyed timestamps Authors: - Devin Robison (https://github.com/drobison00) Approvers: - Michael Demoret (https://github.com/mdemoret-nv) URL: https://github.com/nv-morpheus/Morpheus/pull/1511 --- examples/llm/cli.py | 2 +- .../include/morpheus/messages/control.hpp | 241 +++++++++++++++-- morpheus/_lib/messages/__init__.pyi | 20 +- morpheus/_lib/messages/module.cpp | 25 +- morpheus/_lib/src/messages/control.cpp | 162 +++++++++-- .../tests/messages/test_control_message.cpp | 199 ++++++++++++++ morpheus/stages/inference/inference_stage.py | 12 +- .../stages/preprocess/preprocess_nlp_stage.py | 20 +- tests/messages/test_control_message.py | 251 +++++++++++++++++- tests/utils/test_control_message_utils.py | 8 +- 10 files changed, 858 insertions(+), 82 deletions(-) diff --git a/examples/llm/cli.py b/examples/llm/cli.py index c8aea20320..1ea9198dc1 100644 --- a/examples/llm/cli.py +++ b/examples/llm/cli.py @@ -32,7 +32,7 @@ callback=parse_log_level, help="Specify the logging level to use.") @click.option('--use_cpp', - default=True, + default=False, type=bool, help=("Whether or not to use C++ node and message types or to prefer python. 
" "Only use as a last resort if bugs are encountered")) diff --git a/morpheus/_lib/include/morpheus/messages/control.hpp b/morpheus/_lib/include/morpheus/messages/control.hpp index 9adb568f90..8ee020c76d 100644 --- a/morpheus/_lib/include/morpheus/messages/control.hpp +++ b/morpheus/_lib/include/morpheus/messages/control.hpp @@ -22,10 +22,12 @@ #include #include +#include #include #include #include #include +#include namespace morpheus { @@ -159,6 +161,11 @@ enum class ControlMessageType // std::shared_ptr m_tensors; // }; +class TensorMemory; + +// System-clock for better compatibility with pybind11/chrono +using time_point_t = std::chrono::time_point; + /** * @brief Class representing a control message for coordinating data processing tasks. * @@ -170,7 +177,8 @@ class ControlMessage { public: ControlMessage(); - ControlMessage(const nlohmann::json& config); + explicit ControlMessage(const nlohmann::json& config); + ControlMessage(const ControlMessage& other); // Copies config and metadata, but not payload /** @@ -183,7 +191,7 @@ class ControlMessage * @brief Get the configuration object for the control message. * @return A const reference to the json object containing configuration information. */ - const nlohmann::json& config() const; + [[nodiscard]] const nlohmann::json& config() const; /** * @brief Add a task of the given type to the control message. @@ -197,19 +205,19 @@ class ControlMessage * @param task_type A string indicating the type of the task. * @return True if a task of the given type exists, false otherwise. */ - bool has_task(const std::string& task_type) const; + [[nodiscard]] bool has_task(const std::string& task_type) const; /** * @brief Remove and return a task of the given type from the control message. * @param task_type A string indicating the type of the task. * @return A json object describing the task. 
*/ - const nlohmann::json remove_task(const std::string& task_type); + nlohmann::json remove_task(const std::string& task_type); /** * @brief Get the tasks for the control message. */ - const nlohmann::json& get_tasks() const; + [[nodiscard]] const nlohmann::json& get_tasks() const; /** * @brief Add a key-value pair to the metadata for the control message. @@ -223,27 +231,47 @@ class ControlMessage * @param key A string indicating the metadata key. * @return True if the metadata key exists, false otherwise. */ - bool has_metadata(const std::string& key) const; + [[nodiscard]] bool has_metadata(const std::string& key) const; /** * @brief Get the metadata for the control message. */ - const nlohmann::json& get_metadata() const; + [[nodiscard]] nlohmann::json get_metadata() const; /** * @brief Get the metadata value for the given key from the control message. + * If the key does not exist, the behavior depends on the fail_on_nonexist parameter. + * * @param key A string indicating the metadata key. - * @return A json object describing the metadata value. + * @param fail_on_nonexist If true, throws an exception when the key does not exist. + * If false, returns an empty json value for non-existing keys. + * @return A json object describing the metadata value, or an empty json value if the key does not exist. */ - const nlohmann::json get_metadata(const std::string& key) const; + [[nodiscard]] nlohmann::json get_metadata(const std::string& key, bool fail_on_nonexist = false) const; /** - * @brief Get all metadata keys for the control message. - * @return A json object containing all metadata keys and values. + * @brief Lists all metadata keys currently stored in the control message. + * + * This method retrieves a list of all metadata keys present in the control message.
+ * Metadata within a control message typically includes supplementary information + * such as configuration settings, operational parameters, or annotations that + * are not directly part of the message payload but are crucial for processing + * or understanding the message. + * + * @return A std::vector containing the keys of all metadata entries + * in the control message. If no metadata has been set, the returned vector + * will be empty. */ - const nlohmann::json list_metadata() const; + [[nodiscard]] std::vector list_metadata() const; /** + * @brief Retrieves the current payload object of the control message. + * + * This method returns a shared pointer to the current payload object associated + * with this control message. The payload object encapsulates metadata or data + * specific to this message instance. + * + * @return A shared pointer to the MessageMeta instance representing the message payload. * @brief Get the payload object for the control message. * @param payload * A shared pointer to the message payload. @@ -251,11 +279,42 @@ class ControlMessage std::shared_ptr payload(); /** - * @brief Set the payload object - * @param payload + * @brief Assigns a new payload object to the control message. + * + * Sets the payload of the control message to the provided MessageMeta instance. + * The payload contains data or metadata pertinent to the message. Using a shared + * pointer ensures that the payload is managed efficiently with automatic reference + * counting. + * + * @param payload A shared pointer to the MessageMeta instance to be set as the new payload. */ void payload(const std::shared_ptr& payload); + /** + * @brief Retrieves the tensor memory associated with the control message. + * + * This method returns a shared pointer to the TensorMemory object linked with + * the control message, if any. TensorMemory typically contains or references + * tensors or other large data blobs relevant to the message's purpose. 
+ * + * @return A shared pointer to the TensorMemory instance associated with the message, + * or nullptr if no tensor memory is set. + */ + std::shared_ptr tensors(); + + /** + * @brief Associates tensor memory with the control message. + * + * Sets the tensor memory for the control message to the provided TensorMemory instance. + * This tensor memory can contain tensors or large data blobs pertinent to the message. + * Utilizing a shared pointer facilitates efficient memory management through automatic + * reference counting. + * + * @param tensor_memory A shared pointer to the TensorMemory instance to be associated + * with the control message. + */ + void tensors(const std::shared_ptr& tensor_memory); + /** * @brief Get the type of task associated with the control message. * @return An enum value indicating the task type. @@ -269,49 +328,189 @@ class ControlMessage */ void task_type(ControlMessageType task_type); + /** + * @brief Sets a timestamp for a specific key. + * + * This method stores a timestamp associated with a unique identifier. + * If the key already exists, its timestamp will be updated to the new value. + * + * @param key The specific key for which the timestamp is to be set. + * @param timestamp_ns The timestamp to be associated with the key. + */ + void set_timestamp(const std::string& key, time_point_t timestamp_ns); + + /** + * @brief Retrieves the timestamp for a specific key. + * + * Attempts to find and return the timestamp associated with the specified key. + * If the key does not exist, the method's behavior is determined by the fail_if_nonexist flag. + * + * @param key The specific key for which the timestamp is requested. + * @param fail_if_nonexist If true, the method throws an exception if the timestamp doesn't exist. + * If false, returns std::nullopt for non-existing timestamps. + * @return An optional containing the timestamp if found, or std::nullopt + * otherwise.
+ */ + std::optional get_timestamp(const std::string& key, bool fail_if_nonexist = false); + + /** + * @brief Retrieves timestamps for all keys that match a regex pattern. + * + * Searches for keys that match the provided regex filter and returns + * a map of these keys and their associated timestamps. + * + * @param regex_filter A regular expression pattern that keys must match to be included in the result. + * @return A map containing the matching keys and their timestamps. The map will be empty if no matches are found. + */ + std::map filter_timestamp(const std::string& regex_filter); + private: static const std::string s_config_schema; // NOLINT static std::map s_task_type_map; // NOLINT ControlMessageType m_cm_type{ControlMessageType::NONE}; std::shared_ptr m_payload{nullptr}; + std::shared_ptr m_tensors{nullptr}; nlohmann::json m_tasks{}; nlohmann::json m_config{}; + + std::map m_timestamps{}; }; struct ControlMessageProxy { + /** + * @brief Creates a new ControlMessage instance from a configuration dictionary. + * @param config A pybind11::dict representing the configuration for the ControlMessage. + * @return A shared_ptr to a newly created ControlMessage instance. + */ static std::shared_ptr create(pybind11::dict& config); + + /** + * @brief Creates a new ControlMessage instance as a copy of an existing one. + * @param other A shared_ptr to another ControlMessage instance to copy. + * @return A shared_ptr to the newly copied ControlMessage instance. + */ static std::shared_ptr create(std::shared_ptr other); + /** + * @brief Creates a deep copy of the ControlMessage instance. + * @param self Reference to the underlying ControlMessage object. + * @return A shared_ptr to the copied ControlMessage instance. + */ static std::shared_ptr copy(ControlMessage& self); + /** + * @brief Retrieves the configuration of the ControlMessage as a dictionary. + * @param self Reference to the underlying ControlMessage object.
+ * @return A pybind11::dict representing the ControlMessage's configuration. + */ static pybind11::dict config(ControlMessage& self); - // Required for proxy conversion of json -> dict in python + /** + * @brief Updates the configuration of the ControlMessage from a dictionary. + * @param self Reference to the underlying ControlMessage object. + * @param config A pybind11::dict representing the new configuration. + */ static void config(ControlMessage& self, pybind11::dict& config); + /** + * @brief Adds a task to the ControlMessage. + * @param self Reference to the underlying ControlMessage object. + * @param type The type of the task to be added. + * @param task A pybind11::dict representing the task to be added. + */ static void add_task(ControlMessage& self, const std::string& type, pybind11::dict& task); + + /** + * @brief Removes and returns a task of the given type from the ControlMessage. + * @param self Reference to the underlying ControlMessage object. + * @param type The type of the task to be removed. + * @return A pybind11::dict representing the removed task. + */ static pybind11::dict remove_task(ControlMessage& self, const std::string& type); + + /** + * @brief Retrieves all tasks from the ControlMessage. + * @param self Reference to the underlying ControlMessage object. + * @return A pybind11::dict containing all tasks. + */ static pybind11::dict get_tasks(ControlMessage& self); /** - * @brief Set a metadata key-value pair -- value must be json serializable - * @param self - * @param key - * @param value + * @brief Sets a metadata key-value pair. + * @param self Reference to the underlying ControlMessage object. + * @param key The key for the metadata entry. + * @param value The value for the metadata entry, must be JSON serializable. 
*/ static void set_metadata(ControlMessage& self, const std::string& key, pybind11::object& value); - static pybind11::object get_metadata(ControlMessage& self, std::optional const& key); - static pybind11::dict list_metadata(ControlMessage& self); + /** + * @brief Retrieves a metadata value by key, with an optional default value. + * + * @param self Reference to the underlying ControlMessage object. + * @param key The key for the metadata entry. If not provided, retrieves all metadata. + * @param default_value An optional default value to return if the key does not exist. + * @return The value associated with the key, the default value if the key is not found, or all metadata if the key + * is not provided. + */ + static pybind11::object get_metadata(ControlMessage& self, + const pybind11::object& key, + pybind11::object default_value); + + /** + * @brief Lists all metadata keys of the ControlMessage. + * @param self Reference to the underlying ControlMessage object. + * @return A pybind11::list containing all metadata keys. + */ + static pybind11::list list_metadata(ControlMessage& self); /** * @brief Set the payload object given a Python instance of MessageMeta * @param meta */ static void payload_from_python_meta(ControlMessage& self, const pybind11::object& meta); + + /** + * @brief Sets a timestamp for a given key. + * @param self Reference to the underlying ControlMessage object. + * @param key The key associated with the timestamp. + * @param timestamp A datetime.datetime object representing the timestamp. + * + * This method directly takes a datetime.datetime object from Python and sets the corresponding + * std::chrono::system_clock::time_point for the specified key in the ControlMessage object. + */ + static void set_timestamp(ControlMessage& self, const std::string& key, pybind11::object timestamp); + + /** + * @brief Retrieves the timestamp for a specific key from the ControlMessage object. 
+ * + * @param self Reference to the underlying ControlMessage object. + * @param key The specific key for which the timestamp is requested. + * @param fail_if_nonexist Determines the behavior when the requested timestamp does not exist. + * If true, an exception is thrown. If false, py::none is returned. + * @return A datetime.datetime object representing the timestamp if found, or py::none if not found + * and fail_if_nonexist is false. + * + * This method fetches the timestamp associated with the specified key and returns it as a + * datetime.datetime object in Python. If the timestamp does not exist and fail_if_nonexist is true, + * an exception is raised. + */ + static pybind11::object get_timestamp(ControlMessage& self, const std::string& key, bool fail_if_nonexist = false); + + /** + * @brief Retrieves timestamps for all keys that match a regex pattern from the ControlMessage object. + * + * @param self Reference to the underlying ControlMessage object. + * @param regex_filter The regex pattern that keys must match to be included in the result. + * @return A Python dictionary of matching keys and their timestamps as datetime.datetime objects. + * + * This method retrieves all timestamps within the ControlMessage object that match a specified + * regex pattern. Each key and its associated timestamp are returned in a Python dictionary, with + * timestamps represented as datetime.datetime objects. + */ + static pybind11::dict filter_timestamp(ControlMessage& self, const std::string& regex_filter); }; #pragma GCC visibility pop diff --git a/morpheus/_lib/messages/__init__.pyi b/morpheus/_lib/messages/__init__.pyi index 4f7137a60a..937e0a6084 100644 --- a/morpheus/_lib/messages/__init__.pyi +++ b/morpheus/_lib/messages/__init__.pyi @@ -49,11 +49,19 @@ class ControlMessage(): @typing.overload def config(self, config: dict) -> None: ... def copy(self) -> ControlMessage: ... - def get_metadata(self, key: typing.Optional[str] = None) -> object: ... 
+ def filter_timestamp(self, regex_filter: str) -> dict: + """ + Retrieve timestamps matching a regex filter within a given group. + """ + def get_metadata(self, key: object = None, default_value: object = None) -> object: ... def get_tasks(self) -> dict: ... + def get_timestamp(self, key: str, fail_if_nonexist: bool = False) -> object: + """ + Retrieve the timestamp for a given group and key. Returns None if the timestamp does not exist and fail_if_nonexist is False. + """ def has_metadata(self, key: str) -> bool: ... def has_task(self, task_type: str) -> bool: ... - def list_metadata(self) -> dict: ... + def list_metadata(self) -> list: ... @typing.overload def payload(self) -> MessageMeta: ... @typing.overload @@ -62,10 +70,18 @@ class ControlMessage(): def payload(self, meta: object) -> None: ... def remove_task(self, task_type: str) -> dict: ... def set_metadata(self, key: str, value: object) -> None: ... + def set_timestamp(self, key: str, timestamp: object) -> None: + """ + Set a timestamp for a given key and group. + """ @typing.overload def task_type(self) -> ControlMessageType: ... @typing.overload def task_type(self, task_type: ControlMessageType) -> None: ... + @typing.overload + def tensors(self) -> TensorMemory: ... + @typing.overload + def tensors(self, arg0: TensorMemory) -> None: ... 
pass class ControlMessageType(): """ diff --git a/morpheus/_lib/messages/module.cpp b/morpheus/_lib/messages/module.cpp index 7aa21f24d1..b5b84ee071 100644 --- a/morpheus/_lib/messages/module.cpp +++ b/morpheus/_lib/messages/module.cpp @@ -358,7 +358,6 @@ PYBIND11_MODULE(messages, _module) .value("NONE", ControlMessageType::INFERENCE) .value("TRAINING", ControlMessageType::TRAINING); - // TODO(Devin): Circle back on return value policy choices py::class_>(_module, "ControlMessage") .def(py::init<>()) .def(py::init(py::overload_cast(&ControlMessageProxy::create))) @@ -369,17 +368,37 @@ PYBIND11_MODULE(messages, _module) py::arg("config")) .def("config", pybind11::overload_cast(&ControlMessageProxy::config)) .def("copy", &ControlMessageProxy::copy) - .def("get_metadata", &ControlMessageProxy::get_metadata, py::arg("key") = py::none()) + .def("get_metadata", + &ControlMessageProxy::get_metadata, + py::arg("key") = py::none(), + py::arg("default_value") = py::none()) .def("get_tasks", &ControlMessageProxy::get_tasks) + .def("filter_timestamp", + py::overload_cast(&ControlMessageProxy::filter_timestamp), + "Retrieve timestamps matching a regex filter within a given group.", + py::arg("regex_filter")) + .def("get_timestamp", + py::overload_cast(&ControlMessageProxy::get_timestamp), + "Retrieve the timestamp for a given group and key. 
Returns None if the timestamp does not exist and " + "fail_if_nonexist is False.", + py::arg("key"), + py::arg("fail_if_nonexist") = false) + .def("set_timestamp", + &ControlMessageProxy::set_timestamp, + "Set a timestamp for a given key and group.", + py::arg("key"), + py::arg("timestamp")) .def("has_metadata", &ControlMessage::has_metadata, py::arg("key")) .def("has_task", &ControlMessage::has_task, py::arg("task_type")) .def("list_metadata", &ControlMessageProxy::list_metadata) - .def("payload", pybind11::overload_cast<>(&ControlMessage::payload), py::return_value_policy::move) + .def("payload", pybind11::overload_cast<>(&ControlMessage::payload)) .def("payload", pybind11::overload_cast&>(&ControlMessage::payload)) .def( "payload", pybind11::overload_cast(&ControlMessageProxy::payload_from_python_meta), py::arg("meta")) + .def("tensors", pybind11::overload_cast<>(&ControlMessage::tensors)) + .def("tensors", pybind11::overload_cast&>(&ControlMessage::tensors)) .def("remove_task", &ControlMessageProxy::remove_task, py::arg("task_type")) .def("set_metadata", &ControlMessageProxy::set_metadata, py::arg("key"), py::arg("value")) .def("task_type", pybind11::overload_cast<>(&ControlMessage::task_type)) diff --git a/morpheus/_lib/src/messages/control.cpp b/morpheus/_lib/src/messages/control.cpp index f1413c2650..dd54b80a43 100644 --- a/morpheus/_lib/src/messages/control.cpp +++ b/morpheus/_lib/src/messages/control.cpp @@ -20,12 +20,17 @@ #include "morpheus/messages/meta.hpp" #include +#include // IWYU pragma: keep +#include #include #include +#include #include #include +#include #include +#include namespace py = pybind11; @@ -58,7 +63,6 @@ const nlohmann::json& ControlMessage::config() const void ControlMessage::add_task(const std::string& task_type, const nlohmann::json& task) { - // TODO(Devin) Schema check VLOG(20) << "Adding task of type " << task_type << " to control message" << task.dump(4); auto _task_type = s_task_type_map.contains(task_type) ? 
s_task_type_map[task_type] : ControlMessageType::NONE; @@ -85,9 +89,9 @@ const nlohmann::json& ControlMessage::get_tasks() const return m_tasks; } -const nlohmann::json ControlMessage::list_metadata() const +std::vector ControlMessage::list_metadata() const { - nlohmann::json key_list = nlohmann::json::array(); + std::vector key_list{}; for (auto it = m_config["metadata"].begin(); it != m_config["metadata"].end(); ++it) { @@ -112,17 +116,31 @@ bool ControlMessage::has_metadata(const std::string& key) const return m_config["metadata"].contains(key); } -const nlohmann::json& ControlMessage::get_metadata() const +nlohmann::json ControlMessage::get_metadata() const { - return m_config["metadata"]; + auto metadata = m_config["metadata"]; + + return metadata; } -const nlohmann::json ControlMessage::get_metadata(const std::string& key) const +nlohmann::json ControlMessage::get_metadata(const std::string& key, bool fail_on_nonexist) const { - return m_config["metadata"].at(key); + // Metadata is stored as a json object under the "metadata" key of m_config + auto metadata = m_config["metadata"]; + auto it = metadata.find(key); + if (it != metadata.end()) + { + return metadata.at(key); + } + else if (fail_on_nonexist) + { + throw std::runtime_error("Metadata key does not exist: " + key); + } + + return {}; } -const nlohmann::json ControlMessage::remove_task(const std::string& task_type) +nlohmann::json ControlMessage::remove_task(const std::string& task_type) { auto& task_set = m_tasks.at(task_type); auto iter_task = task_set.begin(); @@ -138,6 +156,43 @@ const nlohmann::json ControlMessage::remove_task(const std::string& task_type) throw std::runtime_error("No tasks of type " + task_type + " found"); } +void ControlMessage::set_timestamp(const std::string& key, time_point_t timestamp_ns) +{ + // Insert or update the timestamp in the map + m_timestamps[key] = timestamp_ns; +} + +std::map ControlMessage::filter_timestamp(const std::string& regex_filter) +{ + std::map matching_timestamps; +
std::regex filter(regex_filter); + + for (const auto& [key, timestamp] : m_timestamps) + { + // Check if the key matches the regex + if (std::regex_search(key, filter)) + { + matching_timestamps[key] = timestamp; + } + } + + return matching_timestamps; +} + +std::optional ControlMessage::get_timestamp(const std::string& key, bool fail_if_nonexist) +{ + auto it = m_timestamps.find(key); + if (it != m_timestamps.end()) + { + return it->second; // Return the found timestamp + } + else if (fail_if_nonexist) + { + throw std::runtime_error("Timestamp for the specified key does not exist."); + } + return std::nullopt; +} + void ControlMessage::config(const nlohmann::json& config) { if (config.contains("type")) @@ -173,10 +228,6 @@ void ControlMessage::config(const nlohmann::json& config) std::shared_ptr ControlMessage::payload() { - // auto temp = std::move(m_payload); - // TODO(Devin): Decide if we copy or steal the payload - // m_payload = nullptr; - return m_payload; } @@ -185,6 +236,16 @@ void ControlMessage::payload(const std::shared_ptr& payload) m_payload = payload; } +std::shared_ptr ControlMessage::tensors() +{ + return m_tensors; +} + +void ControlMessage::tensors(const std::shared_ptr& tensors) +{ + m_tensors = tensors; +} + ControlMessageType ControlMessage::task_type() { return m_cm_type; @@ -236,14 +297,23 @@ py::dict ControlMessageProxy::config(ControlMessage& self) return dict; } -py::object ControlMessageProxy::get_metadata(ControlMessage& self, std::optional const& key) +py::object ControlMessageProxy::get_metadata(ControlMessage& self, + const py::object& key, + pybind11::object default_value) { - if (key == std::nullopt) + if (key.is_none()) + { + auto metadata = self.get_metadata(); + return mrc::pymrc::cast_from_json(metadata); + } + + auto value = self.get_metadata(py::cast(key), false); + if (value.empty()) { - return mrc::pymrc::cast_from_json(self.get_metadata()); + return default_value; } - return 
mrc::pymrc::cast_from_json(self.get_metadata(key.value())); + return mrc::pymrc::cast_from_json(value); } void ControlMessageProxy::set_metadata(ControlMessage& self, const std::string& key, pybind11::object& value) @@ -251,11 +321,65 @@ void ControlMessageProxy::set_metadata(ControlMessage& self, const std::string& self.set_metadata(key, mrc::pymrc::cast_from_pyobject(value)); } -py::dict ControlMessageProxy::list_metadata(ControlMessage& self) +py::list ControlMessageProxy::list_metadata(ControlMessage& self) { - auto dict = mrc::pymrc::cast_from_json(self.list_metadata()); + auto keys = self.list_metadata(); + py::list py_keys; + for (const auto& key : keys) + { + py_keys.append(py::str(key)); + } + return py_keys; +} - return dict; +py::dict ControlMessageProxy::filter_timestamp(ControlMessage& self, const std::string& regex_filter) +{ + auto cpp_map = self.filter_timestamp(regex_filter); + py::dict py_dict; + for (const auto& [key, timestamp] : cpp_map) + { + // Directly use the timestamp as datetime.datetime in Python + py_dict[py::str(key)] = timestamp; + } + return py_dict; +} + +// Get a specific timestamp and return it as datetime.datetime or None +py::object ControlMessageProxy::get_timestamp(ControlMessage& self, const std::string& key, bool fail_if_nonexist) +{ + try + { + auto timestamp_opt = self.get_timestamp(key, fail_if_nonexist); + if (timestamp_opt) + { + // Directly return the timestamp as datetime.datetime in Python + return py::cast(*timestamp_opt); + } + + return py::none(); + } catch (const std::runtime_error& e) + { + if (fail_if_nonexist) + { + throw py::value_error(e.what()); + } + return py::none(); + } +} + +// Set a timestamp using a datetime.datetime object from Python +void ControlMessageProxy::set_timestamp(ControlMessage& self, const std::string& key, py::object timestamp_ns) +{ + if (!py::isinstance(timestamp_ns)) + { + // Convert Python datetime.datetime to std::chrono::system_clock::time_point before setting + auto 
_timestamp_ns = timestamp_ns.cast(); + self.set_timestamp(key, _timestamp_ns); + } + else + { + throw std::runtime_error("Timestamp cannot be None"); + } } void ControlMessageProxy::config(ControlMessage& self, py::dict& config) diff --git a/morpheus/_lib/tests/messages/test_control_message.cpp b/morpheus/_lib/tests/messages/test_control_message.cpp index 61bc59b72e..7fe86afd6c 100644 --- a/morpheus/_lib/tests/messages/test_control_message.cpp +++ b/morpheus/_lib/tests/messages/test_control_message.cpp @@ -19,18 +19,25 @@ #include "test_messages.hpp" #include "morpheus/messages/control.hpp" +#include "morpheus/messages/memory/tensor_memory.hpp" #include "morpheus/messages/meta.hpp" #include #include +#include +#include +#include #include +#include #include #include using namespace morpheus; using namespace morpheus::test; +using clock_type_t = std::chrono::system_clock; + TEST_F(TestControlMessage, InitializationTest) { auto msg_one = ControlMessage(); @@ -48,6 +55,76 @@ TEST_F(TestControlMessage, InitializationTest) ASSERT_EQ(msg_two.has_task("load"), true); } +TEST_F(TestControlMessage, SetAndGetMetadata) +{ + auto msg = ControlMessage(); + + nlohmann::json value = {{"property", "value"}}; + std::string key = "testKey"; + + // Set metadata + msg.set_metadata(key, value); + + // Verify metadata can be retrieved and matches what was set + EXPECT_TRUE(msg.has_metadata(key)); + auto retrievedValue = msg.get_metadata(key, true); + EXPECT_EQ(value, retrievedValue); + + // Verify listing metadata includes the key + auto keys = msg.list_metadata(); + auto it = std::find(keys.begin(), keys.end(), key); + EXPECT_NE(it, keys.end()); +} + +// Test for overwriting metadata +TEST_F(TestControlMessage, OverwriteMetadata) +{ + auto msg = ControlMessage(); + + nlohmann::json value1 = {{"initial", "data"}}; + nlohmann::json value2 = {{"updated", "data"}}; + std::string key = "overwriteKey"; + + // Set initial metadata + msg.set_metadata(key, value1); + + // Overwrite metadata + 
msg.set_metadata(key, value2); + + // Verify metadata was overwritten + auto retrievedValue = msg.get_metadata(key, false); + EXPECT_EQ(value2, retrievedValue); +} + +// Test retrieving metadata when it does not exist +TEST_F(TestControlMessage, GetNonexistentMetadata) +{ + auto msg = ControlMessage(); + + std::string key = "nonexistentKey"; + + // Attempt to retrieve metadata that does not exist + EXPECT_FALSE(msg.has_metadata(key)); + EXPECT_THROW(auto const x = msg.get_metadata(key, true), std::runtime_error); + EXPECT_NO_THROW(auto const x = msg.get_metadata(key, false)); // Should not throw, but return empty json +} + +// Test retrieving all metadata +TEST_F(TestControlMessage, GetAllMetadata) +{ + auto msg = ControlMessage(); + + // Setup - add some metadata + msg.set_metadata("key1", {{"data", "value1"}}); + msg.set_metadata("key2", {{"data", "value2"}}); + + // Retrieve all metadata + auto metadata = msg.get_metadata(); + EXPECT_EQ(2, metadata.size()); // Assuming get_metadata() returns a json object with all metadata + EXPECT_TRUE(metadata.contains("key1")); + EXPECT_TRUE(metadata.contains("key2")); +} + TEST_F(TestControlMessage, SetMessageTest) { auto msg = ControlMessage(); @@ -131,4 +208,126 @@ TEST_F(TestControlMessage, PayloadTest) msg.payload(data_payload); ASSERT_EQ(msg.payload(), data_payload); +} + +TEST_F(TestControlMessage, SetAndGetTimestamp) +{ + auto msg = ControlMessage(); + + // Test setting a timestamp + auto start = clock_type_t::now(); + msg.set_timestamp("group1::key1", start); + + auto result = msg.get_timestamp("group1::key1", false); + ASSERT_TRUE(result.has_value()); + + // Direct comparison since we're using time points now + EXPECT_EQ(start, result.value()); +} + +TEST_F(TestControlMessage, GetTimestampWithRegex) +{ + auto start = clock_type_t::now(); + auto msg = ControlMessage(); + + // Set two timestamps slightly apart + msg.set_timestamp("group1::key1", start); + auto later = clock_type_t::now(); + 
msg.set_timestamp("group1::key2", later); + + auto result = msg.filter_timestamp("group1::key.*"); + ASSERT_EQ(2, result.size()); + + // Check using the actual time points + EXPECT_EQ(start, result["group1::key1"]); + EXPECT_EQ(later, result["group1::key2"]); + + auto resultSingle = msg.filter_timestamp("group1::key1"); + ASSERT_EQ(1, resultSingle.size()); + EXPECT_EQ(start, resultSingle["group1::key1"]); +} + +TEST_F(TestControlMessage, GetTimestampNonExistentKey) +{ + auto msg = ControlMessage(); + + auto result = msg.get_timestamp("group1::nonexistent", false); + EXPECT_FALSE(result.has_value()); + + EXPECT_THROW( + { + try + { + msg.get_timestamp("group1::nonexistent", true); + } catch (const std::runtime_error& e) + { + EXPECT_STREQ("Timestamp for the specified key does not exist.", e.what()); + throw; + } + }, + std::runtime_error); +} + +TEST_F(TestControlMessage, UpdateTimestamp) +{ + auto msg = ControlMessage(); + + auto start = clock_type_t::now(); + msg.set_timestamp("group1::key1", start); + auto later = clock_type_t::now(); + msg.set_timestamp("group1::key1", later); + + auto result = msg.get_timestamp("group1::key1", false); + ASSERT_TRUE(result.has_value()); + + // Check using the actual time points for update + EXPECT_EQ(later, result.value()); +} + +// Test setting and getting TensorMemory +TEST_F(TestControlMessage, SetAndGetTensorMemory) +{ + auto msg = ControlMessage(); + + auto tensorMemory = std::make_shared(0); + // Optionally, modify tensorMemory here if it has any mutable state to test + + // Set the tensor memory + msg.tensors(tensorMemory); + + // Retrieve the tensor memory + auto retrievedTensorMemory = msg.tensors(); + + // Verify that the retrieved tensor memory matches what was set + EXPECT_EQ(tensorMemory, retrievedTensorMemory); +} + +// Test setting TensorMemory to nullptr +TEST_F(TestControlMessage, SetTensorMemoryToNull) +{ + auto msg = ControlMessage(); + + // Set tensor memory to a valid object first +
msg.tensors(std::make_shared(0)); + + // Now set it to nullptr + msg.tensors(nullptr); + + // Retrieve the tensor memory + auto retrievedTensorMemory = msg.tensors(); + + // Verify that the retrieved tensor memory is nullptr + EXPECT_EQ(nullptr, retrievedTensorMemory); +} + +// Test retrieving TensorMemory when none has been set +TEST_F(TestControlMessage, GetTensorMemoryWhenNoneSet) +{ + auto msg = ControlMessage(); + + // Attempt to retrieve tensor memory without setting it first + auto retrievedTensorMemory = msg.tensors(); + + // Verify that the retrieved tensor memory is nullptr + EXPECT_EQ(nullptr, retrievedTensorMemory); } \ No newline at end of file diff --git a/morpheus/stages/inference/inference_stage.py b/morpheus/stages/inference/inference_stage.py index 1cc6703fc6..d601d3880d 100644 --- a/morpheus/stages/inference/inference_stage.py +++ b/morpheus/stages/inference/inference_stage.py @@ -37,7 +37,6 @@ from morpheus.messages.memory.tensor_memory import TensorMemory from morpheus.pipeline.multi_message_stage import MultiMessageStage from morpheus.pipeline.stage_schema import StageSchema -from morpheus.stages.preprocess.preprocess_nlp_stage import base64_to_cupyarray from morpheus.utils.producer_consumer_queue import ProducerConsumerQueue logger = logging.getLogger(__name__) @@ -240,19 +239,12 @@ def on_next(message: typing.Union[MultiInferenceMessage, ControlMessage]): _message = None if (isinstance(message, ControlMessage)): _message = message + tensors = message.tensors() memory_params: dict = message.get_metadata("inference_memory_params") inference_type: str = memory_params["inference_type"] - count = int(memory_params["count"]) - segment_ids = base64_to_cupyarray(memory_params["segment_ids"]) - input_ids = base64_to_cupyarray(memory_params["input_ids"]) - input_mask = base64_to_cupyarray(memory_params["input_mask"]) if (inference_type == "nlp"): - memory = InferenceMemoryNLP(count=count, - input_ids=input_ids, - input_mask=input_mask, - 
seq_ids=segment_ids) - + memory = InferenceMemoryNLP(count=tensors.count, **tensors.get_tensors()) meta_message = MessageMeta(df=message.payload().df) multi_message = MultiMessage(meta=meta_message) diff --git a/morpheus/stages/preprocess/preprocess_nlp_stage.py b/morpheus/stages/preprocess/preprocess_nlp_stage.py index b5587ee90e..8b45dafe37 100644 --- a/morpheus/stages/preprocess/preprocess_nlp_stage.py +++ b/morpheus/stages/preprocess/preprocess_nlp_stage.py @@ -25,6 +25,8 @@ import cudf import morpheus._lib.stages as _stages +# pylint: disable=morpheus-incorrect-lib-from-import +from morpheus._lib.messages import TensorMemory as CppTensorMemory from morpheus.cli.register_stage import register_stage from morpheus.cli.utils import MorpheusRelativePath from morpheus.cli.utils import get_package_relative_file @@ -203,15 +205,15 @@ def process_control_message(message: ControlMessage, del text_series - message.set_metadata( - "inference_memory_params", - { - "inference_type": "nlp", - "count": tokenized.input_ids.shape[0], - "segment_ids": cupyarray_to_base64(tokenized.segment_ids), - "input_ids": cupyarray_to_base64(tokenized.input_ids), - "input_mask": cupyarray_to_base64(tokenized.input_mask), - }) + message.tensors( + CppTensorMemory(count=tokenized.input_ids.shape[0], + tensors={ + "input_ids": tokenized.input_ids, + "input_mask": tokenized.input_mask, + "seq_ids": tokenized.segment_ids + })) + + message.set_metadata("inference_memory_params", {"inference_type": "nlp"}) return message diff --git a/tests/messages/test_control_message.py b/tests/messages/test_control_message.py index 4e913be066..dc2c1a3c2b 100644 --- a/tests/messages/test_control_message.py +++ b/tests/messages/test_control_message.py @@ -14,11 +14,16 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import datetime + +import cupy as cp import pytest import cudf from morpheus import messages +# pylint: disable=morpheus-incorrect-lib-from-import +from morpheus.messages import TensorMemory # pylint: disable=unsupported-membership-test # pylint: disable=unsubscriptable-object @@ -84,18 +89,20 @@ def test_control_message_tasks(): @pytest.mark.usefixtures("config_only_cpp") def test_control_message_metadata(): message = messages.ControlMessage() + message.set_metadata("key_x", "value_x") message.set_metadata("key_y", "value_y") message.set_metadata("key_z", "value_z") - assert len(message.get_metadata()) == 3 + metadata_tags = message.list_metadata() + assert len(metadata_tags) == 3 - assert "key_x" in message.get_metadata() - assert "key_y" in message.get_metadata() - assert "key_z" in message.get_metadata() - assert message.get_metadata()["key_x"] == "value_x" - assert message.get_metadata()["key_y"] == "value_y" - assert message.get_metadata()["key_z"] == "value_z" + assert "key_x" in metadata_tags + assert "key_y" in metadata_tags + assert "key_z" in metadata_tags + assert message.get_metadata("key_x") == "value_x" + assert message.get_metadata("key_y") == "value_y" + assert message.get_metadata("key_z") == "value_z" message.set_metadata("key_y", "value_yy") @@ -106,6 +113,52 @@ def test_control_message_metadata(): assert "not_mutable" not in message.get_metadata() +def test_set_and_get_metadata(): + message = messages.ControlMessage() + + # Test setting and getting metadata + message.set_metadata("test_key", "test_value") + assert message.get_metadata("test_key") == "test_value" + + # Test getting metadata with a default value when the key does not exist + default_value = "default" + assert message.get_metadata("nonexistent_key", default_value) == default_value + + # Test getting all metadata + message.set_metadata("another_key", "another_value") + all_metadata = message.get_metadata() + assert isinstance(all_metadata, dict) + assert all_metadata["test_key"] 
== "test_value" + assert all_metadata["another_key"] == "another_value" + + +def test_list_metadata(): + message = messages.ControlMessage() + + # Setting some metadata + message.set_metadata("key1", "value1") + message.set_metadata("key2", "value2") + message.set_metadata("key3", "value3") + + # Listing all metadata keys + keys = message.list_metadata() + assert isinstance(keys, list) + assert set(keys) == {"key1", "key2", "key3"} + + +def test_get_metadata_default_value(): + message = messages.ControlMessage() + + # Setting metadata to test default value retrieval + message.set_metadata("existing_key", "existing_value") + + # Getting an existing key without default value + assert message.get_metadata("existing_key") == "existing_value" + + # Getting a non-existing key with default value provided + assert message.get_metadata("non_existing_key", "default_value") == "default_value" + + @pytest.mark.usefixtures("config_only_cpp") def test_control_message_get(): raw_control_message = messages.ControlMessage({ @@ -168,8 +221,182 @@ def test_control_message_set_and_get_payload(): assert payload.df == payload2.df -if (__name__ == "__main__"): - test_control_message_init() - test_control_message_get() - test_control_message_set() - test_control_message_set_and_get_payload() +@pytest.mark.usefixtures("config_only_cpp") +def test_set_and_get_timestamp_single(): + # Create a ControlMessage instance + msg = messages.ControlMessage() + + # Define test data + key = "group1::key1" + timestamp = datetime.datetime.now() + + # Set timestamp + msg.set_timestamp(key, timestamp) + + # Get timestamp and assert it's as expected + result = msg.get_timestamp(key, True) + assert result == timestamp, "The retrieved timestamp should match the one that was set." 
+ + +@pytest.mark.usefixtures("config_only_cpp") +def test_filter_timestamp(): + # Create a ControlMessage instance + msg = messages.ControlMessage() + + # Setup test data + group = "group1" + timestamp1 = datetime.datetime.now() + timestamp2 = timestamp1 + datetime.timedelta(seconds=1) + msg.set_timestamp(f"{group}::key1", timestamp1) + msg.set_timestamp(f"{group}::key2", timestamp2) + + # Use a regex that matches both keys + result = msg.filter_timestamp(f"{group}::key.*") + + # Assert both keys are in the result and have correct timestamps + assert len(result) == 2, "Both keys should be present in the result." + assert result[f"{group}::key1"] == timestamp1, "The timestamp for key1 should match." + assert result[f"{group}::key2"] == timestamp2, "The timestamp for key2 should match." + + +@pytest.mark.usefixtures("config_only_cpp") +def test_get_timestamp_fail_if_nonexist(): + # Create a ControlMessage instance + msg = messages.ControlMessage() + + # Setup test data + key = "nonexistent_key" + + # Attempt to get a timestamp for a non-existent key, expecting failure + with pytest.raises(ValueError) as exc_info: + msg.get_timestamp(key, True) + assert str(exc_info.value) == "Timestamp for the specified key does not exist." + + +# Test setting and getting tensors with cupy arrays +@pytest.mark.usefixtures("config_only_cpp") +def test_tensors_setting_and_getting(): + data = {"input_ids": cp.array([1, 2, 3]), "input_mask": cp.array([1, 1, 1]), "segment_ids": cp.array([0, 0, 1])} + message = messages.ControlMessage() + tensor_memory = TensorMemory(count=data["input_ids"].shape[0]) + tensor_memory.set_tensors(data) + + message.tensors(tensor_memory) + + retrieved_tensors = message.tensors() + assert retrieved_tensors.count == data["input_ids"].shape[0], "Tensor count mismatch." + + for key, val in data.items(): + assert cp.allclose(retrieved_tensors.get_tensor(key), val), f"Mismatch in tensor data for {key}." 
+ + +# Test retrieving tensor names and checking specific tensor existence +@pytest.mark.usefixtures("config_only_cpp") +def test_tensor_names_and_existence(): + tokenized_data = { + "input_ids": cp.array([1, 2, 3]), "input_mask": cp.array([1, 1, 1]), "segment_ids": cp.array([0, 0, 1]) + } + message = messages.ControlMessage() + tensor_memory = TensorMemory(count=tokenized_data["input_ids"].shape[0], tensors=tokenized_data) + + message.tensors(tensor_memory) + retrieved_tensors = message.tensors() + + for key in tokenized_data: + assert key in retrieved_tensors.tensor_names, f"Tensor {key} should be listed in tensor names." + assert retrieved_tensors.has_tensor(key), f"Tensor {key} should exist." + + +# Test manipulating tensors after retrieval +@pytest.mark.usefixtures("config_only_cpp") +def test_tensor_manipulation_after_retrieval(): + tokenized_data = { + "input_ids": cp.array([1, 2, 3]), "input_mask": cp.array([1, 1, 1]), "segment_ids": cp.array([0, 0, 1]) + } + message = messages.ControlMessage() + tensor_memory = TensorMemory(count=3, tensors=tokenized_data) + + message.tensors(tensor_memory) + + retrieved_tensors = message.tensors() + new_tensor = cp.array([4, 5, 6]) + retrieved_tensors.set_tensor("new_tensor", new_tensor) + + assert cp.allclose(retrieved_tensors.get_tensor("new_tensor"), new_tensor), "New tensor data mismatch." 
+ + +# Assuming there's functionality to update all tensors at once +@pytest.mark.usefixtures("config_only_cpp") +def test_tensor_update(): + tokenized_data = { + "input_ids": cp.array([1, 2, 3]), "input_mask": cp.array([1, 1, 1]), "segment_ids": cp.array([0, 0, 1]) + } + message = messages.ControlMessage() + tensor_memory = TensorMemory(count=3, tensors=tokenized_data) + + message.tensors(tensor_memory) + + # Update tensors with new data + new_tensors = { + "input_ids": cp.array([4, 5, 6]), "input_mask": cp.array([1, 0, 1]), "segment_ids": cp.array([1, 1, 0]) + } + + tensor_memory.set_tensors(new_tensors) + + updated_tensors = message.tensors() + + for key, val in new_tensors.items(): + assert cp.allclose(updated_tensors.get_tensor(key), val), f"Mismatch in updated tensor data for {key}." + + +@pytest.mark.usefixtures("config_only_cpp") +def test_update_individual_tensor(): + initial_data = {"input_ids": cp.array([1, 2, 3]), "input_mask": cp.array([1, 1, 1])} + update_data = {"input_ids": cp.array([4, 5, 6])} + message = messages.ControlMessage() + tensor_memory = TensorMemory(count=3, tensors=initial_data) + message.tensors(tensor_memory) + + # Update one tensor and retrieve all to ensure update integrity + tensor_memory.set_tensor("input_ids", update_data["input_ids"]) + retrieved_tensors = message.tensors() + + # Check updated tensor + assert cp.allclose(retrieved_tensors.get_tensor("input_ids"), + update_data["input_ids"]), "Input IDs update mismatch." + # Ensure other tensor remains unchanged + assert cp.allclose(retrieved_tensors.get_tensor("input_mask"), + initial_data["input_mask"]), "Input mask should remain unchanged after updating input_ids." 
+ + +@pytest.mark.usefixtures("config_only_cpp") +def test_behavior_with_empty_tensors(): + message = messages.ControlMessage() + tensor_memory = TensorMemory(count=0) + message.tensors(tensor_memory) + + retrieved_tensors = message.tensors() + assert retrieved_tensors.count == 0, "Tensor count should be 0 for empty tensor memory." + assert len(retrieved_tensors.tensor_names) == 0, "There should be no tensor names for empty tensor memory." + + +@pytest.mark.usefixtures("config_only_cpp") +def test_consistency_after_multiple_operations(): + initial_data = {"input_ids": cp.array([1, 2, 3]), "input_mask": cp.array([1, 1, 1])} + message = messages.ControlMessage() + tensor_memory = TensorMemory(count=3, tensors=initial_data) + message.tensors(tensor_memory) + + # Update a tensor + tensor_memory.set_tensor("input_ids", cp.array([4, 5, 6])) + # Remove another tensor + # Add a new tensor + new_tensor = {"new_tensor": cp.array([7, 8, 9])} + tensor_memory.set_tensor("new_tensor", new_tensor["new_tensor"]) + + retrieved_tensors = message.tensors() + assert retrieved_tensors.count == 3, "Tensor count mismatch after multiple operations." + assert cp.allclose(retrieved_tensors.get_tensor("input_ids"), + cp.array([4, 5, 6])), "Mismatch in input_ids after update." + assert cp.allclose(retrieved_tensors.get_tensor("new_tensor"), + new_tensor["new_tensor"]), "New tensor data mismatch." diff --git a/tests/utils/test_control_message_utils.py b/tests/utils/test_control_message_utils.py index 95fee92a73..71be9c7074 100644 --- a/tests/utils/test_control_message_utils.py +++ b/tests/utils/test_control_message_utils.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import pytest - from morpheus.messages import ControlMessage from morpheus.utils.control_message_utils import CMDefaultFailureContextManager from morpheus.utils.control_message_utils import cm_set_failure @@ -37,7 +35,7 @@ def test_skip_forward_on_cm_failed(): # pylint: disable=unused-argument @cm_skip_processing_if_failed - def dummy_func(control_message, *args, **kwargs): + def dummy_func(cm, *args, **kwargs): return "Function Executed" assert dummy_func(control_message) == control_message @@ -50,8 +48,8 @@ def test_cm_default_failure_context_manager_no_exception(): control_message = ControlMessage() with CMDefaultFailureContextManager(control_message): pass - with pytest.raises(RuntimeError): - control_message.get_metadata("cm_failed") + + assert control_message.get_metadata("cm_failed") is None def test_cm_default_failure_context_manager_with_exception():