Skip to content

Commit

Permalink
ControlMessage improvements (#1511)
Browse files Browse the repository at this point in the history
Resolves #1502 

Adds the ability to attach TensorMemory to ControlMessages
Improves get/set/list metadata functions
Adds the ability to attach grouped/keyed timestamps

Authors:
  - Devin Robison (https://github.com/drobison00)

Approvers:
  - Michael Demoret (https://github.com/mdemoret-nv)

URL: #1511
  • Loading branch information
drobison00 authored Feb 16, 2024
1 parent 5fd661b commit aa8d42e
Show file tree
Hide file tree
Showing 10 changed files with 858 additions and 82 deletions.
2 changes: 1 addition & 1 deletion examples/llm/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
callback=parse_log_level,
help="Specify the logging level to use.")
@click.option('--use_cpp',
default=True,
default=False,
type=bool,
help=("Whether or not to use C++ node and message types or to prefer python. "
"Only use as a last resort if bugs are encountered"))
Expand Down
241 changes: 220 additions & 21 deletions morpheus/_lib/include/morpheus/messages/control.hpp

Large diffs are not rendered by default.

20 changes: 18 additions & 2 deletions morpheus/_lib/messages/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,19 @@ class ControlMessage():
@typing.overload
def config(self, config: dict) -> None: ...
def copy(self) -> ControlMessage: ...
def get_metadata(self, key: typing.Optional[str] = None) -> object: ...
def filter_timestamp(self, regex_filter: str) -> dict:
"""
Retrieve timestamps matching a regex filter within a given group.
"""
def get_metadata(self, key: object = None, default_value: object = None) -> object: ...
def get_tasks(self) -> dict: ...
def get_timestamp(self, key: str, fail_if_nonexist: bool = False) -> object:
"""
Retrieve the timestamp for a given group and key. Returns None if the timestamp does not exist and fail_if_nonexist is False.
"""
def has_metadata(self, key: str) -> bool: ...
def has_task(self, task_type: str) -> bool: ...
def list_metadata(self) -> dict: ...
def list_metadata(self) -> list: ...
@typing.overload
def payload(self) -> MessageMeta: ...
@typing.overload
Expand All @@ -62,10 +70,18 @@ class ControlMessage():
def payload(self, meta: object) -> None: ...
def remove_task(self, task_type: str) -> dict: ...
def set_metadata(self, key: str, value: object) -> None: ...
def set_timestamp(self, key: str, timestamp: object) -> None:
"""
Set a timestamp for a given key and group.
"""
@typing.overload
def task_type(self) -> ControlMessageType: ...
@typing.overload
def task_type(self, task_type: ControlMessageType) -> None: ...
@typing.overload
def tensors(self) -> TensorMemory: ...
@typing.overload
def tensors(self, arg0: TensorMemory) -> None: ...
pass
class ControlMessageType():
"""
Expand Down
25 changes: 22 additions & 3 deletions morpheus/_lib/messages/module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,6 @@ PYBIND11_MODULE(messages, _module)
.value("NONE", ControlMessageType::INFERENCE)
.value("TRAINING", ControlMessageType::TRAINING);

// TODO(Devin): Circle back on return value policy choices
py::class_<ControlMessage, std::shared_ptr<ControlMessage>>(_module, "ControlMessage")
.def(py::init<>())
.def(py::init(py::overload_cast<py::dict&>(&ControlMessageProxy::create)))
Expand All @@ -369,17 +368,37 @@ PYBIND11_MODULE(messages, _module)
py::arg("config"))
.def("config", pybind11::overload_cast<ControlMessage&>(&ControlMessageProxy::config))
.def("copy", &ControlMessageProxy::copy)
.def("get_metadata", &ControlMessageProxy::get_metadata, py::arg("key") = py::none())
.def("get_metadata",
&ControlMessageProxy::get_metadata,
py::arg("key") = py::none(),
py::arg("default_value") = py::none())
.def("get_tasks", &ControlMessageProxy::get_tasks)
.def("filter_timestamp",
py::overload_cast<ControlMessage&, const std::string&>(&ControlMessageProxy::filter_timestamp),
"Retrieve timestamps matching a regex filter within a given group.",
py::arg("regex_filter"))
.def("get_timestamp",
py::overload_cast<ControlMessage&, const std::string&, bool>(&ControlMessageProxy::get_timestamp),
"Retrieve the timestamp for a given group and key. Returns None if the timestamp does not exist and "
"fail_if_nonexist is False.",
py::arg("key"),
py::arg("fail_if_nonexist") = false)
.def("set_timestamp",
&ControlMessageProxy::set_timestamp,
"Set a timestamp for a given key and group.",
py::arg("key"),
py::arg("timestamp"))
.def("has_metadata", &ControlMessage::has_metadata, py::arg("key"))
.def("has_task", &ControlMessage::has_task, py::arg("task_type"))
.def("list_metadata", &ControlMessageProxy::list_metadata)
.def("payload", pybind11::overload_cast<>(&ControlMessage::payload), py::return_value_policy::move)
.def("payload", pybind11::overload_cast<>(&ControlMessage::payload))
.def("payload", pybind11::overload_cast<const std::shared_ptr<MessageMeta>&>(&ControlMessage::payload))
.def(
"payload",
pybind11::overload_cast<ControlMessage&, const py::object&>(&ControlMessageProxy::payload_from_python_meta),
py::arg("meta"))
.def("tensors", pybind11::overload_cast<>(&ControlMessage::tensors))
.def("tensors", pybind11::overload_cast<const std::shared_ptr<TensorMemory>&>(&ControlMessage::tensors))
.def("remove_task", &ControlMessageProxy::remove_task, py::arg("task_type"))
.def("set_metadata", &ControlMessageProxy::set_metadata, py::arg("key"), py::arg("value"))
.def("task_type", pybind11::overload_cast<>(&ControlMessage::task_type))
Expand Down
162 changes: 143 additions & 19 deletions morpheus/_lib/src/messages/control.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,17 @@
#include "morpheus/messages/meta.hpp"

#include <glog/logging.h>
#include <pybind11/chrono.h> // IWYU pragma: keep
#include <pybind11/pybind11.h>
#include <pybind11/pytypes.h>
#include <pymrc/utils.hpp>

#include <chrono>
#include <optional>
#include <ostream>
#include <regex>
#include <stdexcept>
#include <utility>

namespace py = pybind11;

Expand Down Expand Up @@ -58,7 +63,6 @@ const nlohmann::json& ControlMessage::config() const

void ControlMessage::add_task(const std::string& task_type, const nlohmann::json& task)
{
// TODO(Devin) Schema check
VLOG(20) << "Adding task of type " << task_type << " to control message" << task.dump(4);
auto _task_type = s_task_type_map.contains(task_type) ? s_task_type_map[task_type] : ControlMessageType::NONE;

Expand All @@ -85,9 +89,9 @@ const nlohmann::json& ControlMessage::get_tasks() const
return m_tasks;
}

const nlohmann::json ControlMessage::list_metadata() const
std::vector<std::string> ControlMessage::list_metadata() const
{
nlohmann::json key_list = nlohmann::json::array();
std::vector<std::string> key_list{};

for (auto it = m_config["metadata"].begin(); it != m_config["metadata"].end(); ++it)
{
Expand All @@ -112,17 +116,31 @@ bool ControlMessage::has_metadata(const std::string& key) const
return m_config["metadata"].contains(key);
}

const nlohmann::json& ControlMessage::get_metadata() const
nlohmann::json ControlMessage::get_metadata() const
{
return m_config["metadata"];
auto metadata = m_config["metadata"];

return metadata;
}

const nlohmann::json ControlMessage::get_metadata(const std::string& key) const
nlohmann::json ControlMessage::get_metadata(const std::string& key, bool fail_on_nonexist) const
{
return m_config["metadata"].at(key);
// Assuming m_metadata is a std::map<std::string, nlohmann::json> storing metadata
auto metadata = m_config["metadata"];
auto it = metadata.find(key);
if (it != metadata.end())
{
return metadata.at(key);
}
else if (fail_on_nonexist)
{
throw std::runtime_error("Metadata key does not exist: " + key);
}

return {};
}

const nlohmann::json ControlMessage::remove_task(const std::string& task_type)
nlohmann::json ControlMessage::remove_task(const std::string& task_type)
{
auto& task_set = m_tasks.at(task_type);
auto iter_task = task_set.begin();
Expand All @@ -138,6 +156,43 @@ const nlohmann::json ControlMessage::remove_task(const std::string& task_type)
throw std::runtime_error("No tasks of type " + task_type + " found");
}

void ControlMessage::set_timestamp(const std::string& key, time_point_t timestamp_ns)
{
// Insert or update the timestamp in the map
m_timestamps[key] = timestamp_ns;
}

std::map<std::string, time_point_t> ControlMessage::filter_timestamp(const std::string& regex_filter)
{
std::map<std::string, time_point_t> matching_timestamps;
std::regex filter(regex_filter);

for (const auto& [key, timestamp] : m_timestamps)
{
// Check if the key matches the regex
if (std::regex_search(key, filter))
{
matching_timestamps[key] = timestamp;
}
}

return matching_timestamps;
}

std::optional<time_point_t> ControlMessage::get_timestamp(const std::string& key, bool fail_if_nonexist)
{
auto it = m_timestamps.find(key);
if (it != m_timestamps.end())
{
return it->second; // Return the found timestamp
}
else if (fail_if_nonexist)
{
throw std::runtime_error("Timestamp for the specified key does not exist.");
}
return std::nullopt;
}

void ControlMessage::config(const nlohmann::json& config)
{
if (config.contains("type"))
Expand Down Expand Up @@ -173,10 +228,6 @@ void ControlMessage::config(const nlohmann::json& config)

std::shared_ptr<MessageMeta> ControlMessage::payload()
{
// auto temp = std::move(m_payload);
// TODO(Devin): Decide if we copy or steal the payload
// m_payload = nullptr;

return m_payload;
}

Expand All @@ -185,6 +236,16 @@ void ControlMessage::payload(const std::shared_ptr<MessageMeta>& payload)
m_payload = payload;
}

std::shared_ptr<TensorMemory> ControlMessage::tensors()
{
return m_tensors;
}

void ControlMessage::tensors(const std::shared_ptr<TensorMemory>& tensors)
{
m_tensors = tensors;
}

ControlMessageType ControlMessage::task_type()
{
return m_cm_type;
Expand Down Expand Up @@ -236,26 +297,89 @@ py::dict ControlMessageProxy::config(ControlMessage& self)
return dict;
}

py::object ControlMessageProxy::get_metadata(ControlMessage& self, std::optional<std::string> const& key)
py::object ControlMessageProxy::get_metadata(ControlMessage& self,
const py::object& key,
pybind11::object default_value)
{
if (key == std::nullopt)
if (key.is_none())
{
auto metadata = self.get_metadata();
return mrc::pymrc::cast_from_json(metadata);
}

auto value = self.get_metadata(py::cast<std::string>(key), false);
if (value.empty())
{
return mrc::pymrc::cast_from_json(self.get_metadata());
return default_value;
}

return mrc::pymrc::cast_from_json(self.get_metadata(key.value()));
return mrc::pymrc::cast_from_json(value);
}

void ControlMessageProxy::set_metadata(ControlMessage& self, const std::string& key, pybind11::object& value)
{
self.set_metadata(key, mrc::pymrc::cast_from_pyobject(value));
}

py::dict ControlMessageProxy::list_metadata(ControlMessage& self)
py::list ControlMessageProxy::list_metadata(ControlMessage& self)
{
auto dict = mrc::pymrc::cast_from_json(self.list_metadata());
auto keys = self.list_metadata();
py::list py_keys;
for (const auto& key : keys)
{
py_keys.append(py::str(key));
}
return py_keys;
}

return dict;
py::dict ControlMessageProxy::filter_timestamp(ControlMessage& self, const std::string& regex_filter)
{
auto cpp_map = self.filter_timestamp(regex_filter);
py::dict py_dict;
for (const auto& [key, timestamp] : cpp_map)
{
// Directly use the timestamp as datetime.datetime in Python
py_dict[py::str(key)] = timestamp;
}
return py_dict;
}

// Get a specific timestamp and return it as datetime.datetime or None
py::object ControlMessageProxy::get_timestamp(ControlMessage& self, const std::string& key, bool fail_if_nonexist)
{
try
{
auto timestamp_opt = self.get_timestamp(key, fail_if_nonexist);
if (timestamp_opt)
{
// Directly return the timestamp as datetime.datetime in Python
return py::cast(*timestamp_opt);
}

return py::none();
} catch (const std::runtime_error& e)
{
if (fail_if_nonexist)
{
throw py::value_error(e.what());
}
return py::none();
}
}

// Set a timestamp using a datetime.datetime object from Python
void ControlMessageProxy::set_timestamp(ControlMessage& self, const std::string& key, py::object timestamp_ns)
{
if (!py::isinstance<py::none>(timestamp_ns))
{
// Convert Python datetime.datetime to std::chrono::system_clock::time_point before setting
auto _timestamp_ns = timestamp_ns.cast<time_point_t>();
self.set_timestamp(key, _timestamp_ns);
}
else
{
throw std::runtime_error("Timestamp cannot be None");
}
}

void ControlMessageProxy::config(ControlMessage& self, py::dict& config)
Expand Down
Loading

0 comments on commit aa8d42e

Please sign in to comment.