Add tensor_count property for ControlMessage #2078

Merged
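
This PR replaces the repeated msg.tensors().count pattern with a dedicated ControlMessage.tensor_count() accessor, implemented in both the C++ and pure-Python message classes and cached whenever tensor memory is attached. A minimal usage sketch of the API change (the 'probs' tensor name is illustrative, not part of this PR):

    import cupy as cp

    from morpheus.messages import ControlMessage, TensorMemory

    msg = ControlMessage()
    msg.tensors(TensorMemory(count=4, tensors={'probs': cp.zeros((4, 2))}))

    count = msg.tensors().count  # old pattern: reach through the tensor memory
    count = msg.tensor_count()   # new pattern: ask the message directly
    assert count == 4
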
24 changes: 12 additions & 12 deletions examples/log_parsing/inference.py
@@ -57,16 +57,16 @@ class TritonInferenceLogParsing(TritonInferenceWorker):
"""

def build_output_message(self, msg: ControlMessage) -> ControlMessage:
seq_ids = cp.zeros((msg.tensors().count, 3), dtype=cp.uint32)
seq_ids[:, 0] = cp.arange(0, msg.tensors().count, dtype=cp.uint32)
seq_ids = cp.zeros((msg.tensor_count(), 3), dtype=cp.uint32)
seq_ids[:, 0] = cp.arange(0, msg.tensor_count(), dtype=cp.uint32)
seq_ids[:, 2] = msg.tensors().get_tensor('seq_ids')[:, 2]

memory = TensorMemory(
count=msg.tensors().count,
count=msg.tensor_count(),
tensors={
'confidences': cp.zeros((msg.tensors().count, self._inputs[list(self._inputs.keys())[0]].shape[1])),
'labels': cp.zeros((msg.tensors().count, self._inputs[list(self._inputs.keys())[0]].shape[1])),
'input_ids': cp.zeros((msg.tensors().count, msg.tensors().get_tensor('input_ids').shape[1])),
'confidences': cp.zeros((msg.tensor_count(), self._inputs[list(self._inputs.keys())[0]].shape[1])),
'labels': cp.zeros((msg.tensor_count(), self._inputs[list(self._inputs.keys())[0]].shape[1])),
'input_ids': cp.zeros((msg.tensor_count(), msg.tensors().get_tensor('input_ids').shape[1])),
'seq_ids': seq_ids
})

@@ -154,19 +154,19 @@ def _convert_one_response(output: ControlMessage, inf: ControlMessage, res: Tens
         seq_offset = seq_ids[0, 0].item()
         seq_count = seq_ids[-1, 0].item() + 1 - seq_offset
 
-        input_ids[batch_offset:inf.tensors().count + batch_offset, :] = inf.tensors().get_tensor('input_ids')
-        out_seq_ids[batch_offset:inf.tensors().count + batch_offset, :] = seq_ids
+        input_ids[batch_offset:inf.tensor_count() + batch_offset, :] = inf.tensors().get_tensor('input_ids')
+        out_seq_ids[batch_offset:inf.tensor_count() + batch_offset, :] = seq_ids
 
         resp_confidences = res.get_tensor('confidences')
         resp_labels = res.get_tensor('labels')
 
         # Two scenarios:
-        if (inf.payload().count == inf.tensors().count):
+        if (inf.payload().count == inf.tensor_count()):
             assert seq_count == res.count
-            confidences[batch_offset:inf.tensors().count + batch_offset, :] = resp_confidences
-            labels[batch_offset:inf.tensors().count + batch_offset, :] = resp_labels
+            confidences[batch_offset:inf.tensor_count() + batch_offset, :] = resp_confidences
+            labels[batch_offset:inf.tensor_count() + batch_offset, :] = resp_labels
         else:
-            assert inf.tensors().count == res.count
+            assert inf.tensor_count() == res.count
 
             mess_ids = seq_ids[:, 0].get().tolist()
 
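
The log-parsing worker above sizes every pre-allocated output tensor from the message's tensor count instead of reading it off the TensorMemory each time. A simplified, hedged sketch of that allocation pattern (allocate_outputs and num_labels are hypothetical names, shapes illustrative):

    import cupy as cp

    from morpheus.messages import ControlMessage, TensorMemory

    def allocate_outputs(msg: ControlMessage, num_labels: int) -> TensorMemory:
        # Size each output tensor by the count cached on the message.
        count = msg.tensor_count()
        return TensorMemory(count=count,
                            tensors={
                                'confidences': cp.zeros((count, num_labels)),
                                'labels': cp.zeros((count, num_labels)),
                            })
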
@@ -19,6 +19,7 @@

#include "morpheus/export.h" // for MORPHEUS_EXPORT
#include "morpheus/messages/meta.hpp" // for MessageMeta
#include "morpheus/types.hpp"
#include "morpheus/utilities/json_types.hpp" // for json_t

#include <pybind11/pytypes.h> // for object, dict, list
@@ -197,6 +198,13 @@ class MORPHEUS_EXPORT ControlMessage
      */
     void tensors(const std::shared_ptr<TensorMemory>& tensor_memory);
 
+    /**
+     * @brief Get the length of tensors in the tensor memory.
+     *
+     * @return The length of tensors in the tensor memory.
+     */
+    TensorIndex tensor_count();
+
     /**
      * @brief Get the type of task associated with the control message.
      * @return An enum value indicating the task type.
@@ -262,6 +270,7 @@ class MORPHEUS_EXPORT ControlMessage
     ControlMessageType m_cm_type{ControlMessageType::NONE};
     std::shared_ptr<MessageMeta> m_payload{nullptr};
     std::shared_ptr<TensorMemory> m_tensors{nullptr};
+    TensorIndex m_tensor_count{0};
 
     morpheus::utilities::json_t m_tasks{};
     morpheus::utilities::json_t m_config{};
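
The header above declares tensor_count() as returning TensorIndex (hence the new morpheus/types.hpp include) and backs it with an m_tensor_count member that defaults to zero, so a message with no tensor memory attached reports a count of 0. A hedged sketch of that contract as seen from Python:

    from morpheus.messages import ControlMessage, TensorMemory

    msg = ControlMessage()
    assert msg.tensor_count() == 0  # default before any tensors are attached

    msg.tensors(TensorMemory(count=5))
    assert msg.tensor_count() == 5  # cached from TensorMemory.count
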
1 change: 1 addition & 0 deletions python/morpheus/morpheus/_lib/messages/__init__.pyi
@@ -73,6 +73,7 @@ class ControlMessage():
     def task_type(self) -> ControlMessageType: ...
     @typing.overload
     def task_type(self, task_type: ControlMessageType) -> None: ...
+    def tensor_count(self) -> int: ...
     @typing.overload
     def tensors(self) -> TensorMemory: ...
     @typing.overload
1 change: 1 addition & 0 deletions python/morpheus/morpheus/_lib/messages/module.cpp
@@ -290,6 +290,7 @@ PYBIND11_MODULE(messages, _module)
             py::arg("meta"))
         .def("tensors", pybind11::overload_cast<>(&ControlMessage::tensors))
         .def("tensors", pybind11::overload_cast<const std::shared_ptr<TensorMemory>&>(&ControlMessage::tensors))
+        .def("tensor_count", &ControlMessage::tensor_count)
         .def("remove_task", &ControlMessage::remove_task, py::arg("task_type"))
         .def("set_metadata", &ControlMessage::set_metadata, py::arg("key"), py::arg("value"))
         .def("task_type", pybind11::overload_cast<>(&ControlMessage::task_type))
15 changes: 11 additions & 4 deletions python/morpheus/morpheus/_lib/src/messages/control.cpp
@@ -59,9 +59,10 @@ ControlMessage::ControlMessage(const morpheus::utilities::json_t& _config) :

 ControlMessage::ControlMessage(const ControlMessage& other)
 {
-    m_cm_type = other.m_cm_type;
-    m_payload = other.m_payload;
-    m_tensors = other.m_tensors;
+    m_cm_type      = other.m_cm_type;
+    m_payload      = other.m_payload;
+    m_tensors      = other.m_tensors;
+    m_tensor_count = other.m_tensor_count;
 
     m_config = other.m_config;
     m_tasks = other.m_tasks;
@@ -256,7 +257,13 @@ std::shared_ptr<TensorMemory> ControlMessage::tensors()

 void ControlMessage::tensors(const std::shared_ptr<TensorMemory>& tensors)
 {
-    m_tensors = tensors;
+    m_tensors      = tensors;
+    m_tensor_count = tensors ? tensors->count : 0;
 }
 
+TensorIndex ControlMessage::tensor_count()
+{
+    return m_tensor_count;
+}
+
 ControlMessageType ControlMessage::task_type()
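
Two details in control.cpp are worth noting: the copy constructor now carries m_tensor_count across copies, and the setter derives the cache from the incoming memory (tensors ? tensors->count : 0), so attaching a null TensorMemory resets the count. A hedged sketch of the nullptr path through the C++ bindings (assuming the bound setter accepts None, consistent with the "setting TensorMemory to nullptr" test further down):

    from morpheus._lib.messages import ControlMessage, TensorMemory

    msg = ControlMessage()
    msg.tensors(TensorMemory(count=5))
    assert msg.tensor_count() == 5

    msg.tensors(None)  # nullptr path: the cached count resets to zero
    assert msg.tensor_count() == 0
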
@@ -58,7 +58,7 @@ static ShapeType get_seq_ids(const std::shared_ptr<ControlMessage>& message)
     auto seq_ids = message->tensors()->get_tensor("seq_ids");
     const auto item_size = seq_ids.dtype().item_size();
 
-    ShapeType host_seq_ids(message->tensors()->count);
+    ShapeType host_seq_ids(message->tensor_count());
     MRC_CHECK_CUDA(cudaMemcpy2D(host_seq_ids.data(),
                                 item_size,
                                 seq_ids.data(),
@@ -82,7 +82,7 @@ static TensorObject get_tensor(std::shared_ptr<ControlMessage> message, std::str

 static void reduce_outputs(std::shared_ptr<ControlMessage> const& message, TensorMap& output_tensors)
 {
-    if (message->payload()->count() == message->tensors()->count)
+    if (message->payload()->count() == message->tensor_count())
     {
         return;
     }
@@ -21,6 +21,7 @@
#include "morpheus/messages/control.hpp" // for ControlMessage
#include "morpheus/messages/memory/tensor_memory.hpp" // for TensorMemory
#include "morpheus/messages/meta.hpp" // for MessageMeta
#include "morpheus/types.hpp"
#include "morpheus/utilities/json_types.hpp" // for PythonByteContainer

#include <gtest/gtest.h> // for Message, TestPartResult, AssertionResult, TestInfo
@@ -298,7 +299,8 @@ TEST_F(TestControlMessage, SetAndGetTensorMemory)
 {
     auto msg = ControlMessage();
 
-    auto tensorMemory = std::make_shared<TensorMemory>(0);
+    TensorIndex count = 5;
+    auto tensorMemory = std::make_shared<TensorMemory>(count);
     // Optionally, modify tensorMemory here if it has any mutable state to test
 
     // Set the tensor memory
@@ -309,6 +311,7 @@

     // Verify that the retrieved tensor memory matches what was set
     EXPECT_EQ(tensorMemory, retrievedTensorMemory);
+    EXPECT_EQ(count, msg.tensor_count());
 }
 
 // Test setting TensorMemory to nullptr
5 changes: 5 additions & 0 deletions python/morpheus/morpheus/messages/control_message.py
@@ -46,6 +46,7 @@ def __init__(self, config_or_message: typing.Union["ControlMessage", dict] = Non

         self._payload: MessageMeta = None
         self._tensors: TensorMemory = None
+        self._tensor_count: int = 0
 
         self._tasks: dict[str, deque] = defaultdict(deque)
         self._timestamps: dict[str, datetime] = {}
@@ -147,9 +148,13 @@ def payload(self, payload: MessageMeta = None) -> MessageMeta | None:
     def tensors(self, tensors: TensorMemory = None) -> TensorMemory | None:
         if tensors is not None:
             self._tensors = tensors
+            self._tensor_count = tensors.count
 
         return self._tensors
 
+    def tensor_count(self) -> int:
+        return self._tensor_count
+
     def task_type(self, new_task_type: ControlMessageType = None) -> ControlMessageType:
         if new_task_type is not None:
             self._type = new_task_type
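
In the pure-Python implementation the count is captured only when tensors(...) is called with a memory object, so tensor_count() is a snapshot taken at attach time rather than a live view of the attached TensorMemory. A hedged sketch of that snapshot behavior (tensor names illustrative):

    import cupy as cp

    from morpheus.messages import ControlMessage, TensorMemory

    msg = ControlMessage()
    memory = TensorMemory(count=3, tensors={'input_ids': cp.array([1, 2, 3])})
    msg.tensors(memory)
    assert msg.tensor_count() == 3

    # Mutating the attached memory afterwards does not refresh the cached
    # count; re-attach the memory if its size changes.
    memory.set_tensor('new_tensor', cp.array([4, 5, 6]))
    assert msg.tensor_count() == 3
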
@@ -45,12 +45,12 @@ def __init__(self, inf_queue: ProducerConsumerQueue, c: Config):
         self._seq_length = c.feature_length
 
     def calc_output_dims(self, msg: ControlMessage) -> typing.Tuple:
-        return (msg.tensors().count, self._seq_length)
+        return (msg.tensor_count(), self._seq_length)
 
     def process(self, batch: ControlMessage, callback: typing.Callable[[TensorMemory], None]):
 
         def tmp(batch: ControlMessage, f):
-            count = batch.tensors().count
+            count = batch.tensor_count()
             f(TensorMemory(
                 count=count,
                 tensors={'probs': cp.zeros((count, self._seq_length), dtype=cp.float32)},
@@ -244,7 +244,7 @@ def set_output_fut(resp: TensorMemory, inner_batch, batch_future: mrc.Future):
             nonlocal outstanding_requests
             nonlocal batch_offset
             mess = self._convert_one_response(output_message, inner_batch, resp, batch_offset)
-            batch_offset += inner_batch.tensors().count
+            batch_offset += inner_batch.tensor_count()
             outstanding_requests -= 1
 
             batch_future.set_result(mess)
@@ -359,13 +359,13 @@ def _convert_one_response(output: ControlMessage, inf: ControlMessage, res: Tens
         seq_count = seq_ids[-1, 0].item() + 1 - seq_offset
 
         # Two scenarios:
-        if (inf.payload().count == inf.tensors().count):
+        if (inf.payload().count == inf.tensor_count()):
             assert seq_count == res.count
 
             # In message and out message have same count. Just use probs as is
             probs[seq_offset:seq_offset + seq_count, :] = resp_probs
         else:
-            assert inf.tensors().count == res.count
+            assert inf.tensor_count() == res.count
 
             mess_ids = seq_ids[:, 0].get().tolist()
 
@@ -73,7 +73,7 @@ def init(self):
     def calc_output_dims(self, msg: ControlMessage) -> typing.Tuple:
         input_ids = msg.tensors().get_tensor("input_ids")
         input_mask = msg.tensors().get_tensor("input_mask")
-        count = msg.tensors().count
+        count = msg.tensor_count()
         # If we haven't cached the output dimension, do that here
         if (not self._output_size):
             test_intput_ids_shape = (self._max_batch_size, ) + input_ids.shape[1:]
@@ -91,7 +91,7 @@ def calc_output_dims(self, msg: ControlMessage) -> typing.Tuple:
     def process(self, batch: ControlMessage, callback: typing.Callable[[TensorMemory], None]):
         input_ids = batch.tensors().get_tensor("input_ids")
         input_mask = batch.tensors().get_tensor("input_mask")
-        count = batch.tensors().count
+        count = batch.tensor_count()
 
         # convert from cupy to torch tensor using dlpack
         input_ids = from_dlpack(input_ids.astype(cp.float).toDlpack()).type(torch.long)
@@ -568,7 +568,7 @@ def create_wrapper():
                 raise ex
 
     def calc_output_dims(self, msg: ControlMessage) -> typing.Tuple:
-        return (msg.tensors().count, self._outputs[list(self._outputs.keys())[0]].shape[1])
+        return (msg.tensor_count(), self._outputs[list(self._outputs.keys())[0]].shape[1])
 
     def _build_response(
         self,
@@ -138,7 +138,7 @@ def _calc_drift(self, msg: ControlMessage):
         for label in range(len(self._labels), shifted.shape[1]):
             self._labels.append(str(label))
 
-        count = msg.tensors().count
+        count = msg.tensor_count()
 
         for i in list(range(0, count, self._batch_size)):
             start = i
@@ -41,7 +41,7 @@ def check_inf_message(msg: ControlMessage,
                       expected_input__0: cp.ndarray):
     assert isinstance(msg, ControlMessage)
     assert msg.payload().count == expected_mess_count
-    assert msg.tensors().count == expected_count
+    assert msg.tensor_count() == expected_count
 
     df = msg.payload().get_data()
     assert 'flow_id' in df
4 changes: 2 additions & 2 deletions tests/examples/log_parsing/test_inference.py
@@ -138,7 +138,7 @@ def test_log_parsing_triton_inference_log_parsing_build_output_message(config: C
     msg = worker.build_output_message(input_msg)
     assert msg.payload() is input_msg.payload()
     assert msg.payload().count == mess_count
-    assert msg.tensors().count == count
+    assert msg.tensor_count() == count
 
     assert set(msg.tensors().tensor_names).issuperset(('confidences', 'labels', 'input_ids', 'seq_ids'))
     assert msg.tensors().get_tensor('confidences').shape == (count, 2)
@@ -187,7 +187,7 @@ def test_log_parsing_inference_stage_convert_one_response(import_mod: typing.Lis
     assert isinstance(output_msg, ControlMessage)
     assert output_msg.payload() is input_inf.payload()
     assert output_msg.payload().count == mess_count
-    assert output_msg.tensors().count == count
+    assert output_msg.tensor_count() == count
 
     assert (output_msg.tensors().get_tensor('seq_ids') == input_inf.tensors().get_tensor('seq_ids')).all()
     assert (output_msg.tensors().get_tensor('input_ids') == input_inf.tensors().get_tensor('input_ids')).all()
16 changes: 10 additions & 6 deletions tests/morpheus/messages/test_control_message.py
@@ -38,7 +38,7 @@ def _verify_metadata(msg: messages.ControlMessage, metadata: dict):

 @pytest.mark.gpu_and_cpu_mode
 def test_control_message_init(dataset: DatasetManager):
-    # Explicitly performing copies of the metadata, config and the dataframe, to ensure tha the original data is not
+    # Explicitly performing copies of the metadata, config and the dataframe, to ensure that the original data is not
     # being modified in place in some way.
     msg = messages.ControlMessage()
     assert msg.get_metadata() == {}  # pylint: disable=use-implicit-booleaness-not-comparison
@@ -318,9 +318,9 @@ def test_tensors_setting_and_getting(config: Config):

     message.tensors(tensor_memory)
 
-    retrieved_tensors = message.tensors()
-    assert retrieved_tensors.count == data["input_ids"].shape[0], "Tensor count mismatch."
+    assert message.tensor_count() == data["input_ids"].shape[0], "Tensor count mismatch."
 
+    retrieved_tensors = message.tensors()
     for key, val in data.items():
         assert array_pkg.allclose(retrieved_tensors.get_tensor(key), val), f"Mismatch in tensor data for {key}."

@@ -363,6 +363,7 @@ def test_tensor_manipulation_after_retrieval(config: Config):
     new_tensor = array_pkg.array([4, 5, 6])
     retrieved_tensors.set_tensor("new_tensor", new_tensor)
 
+    assert message.tensor_count() == tokenized_data["input_ids"].shape[0], "Tensor count mismatch"
     assert array_pkg.allclose(retrieved_tensors.get_tensor("new_tensor"), new_tensor), "New tensor data mismatch."


@@ -389,8 +390,9 @@ def test_tensor_update(config: Config):

     tensor_memory.set_tensors(new_tensors)
 
-    updated_tensors = message.tensors()
+    assert message.tensor_count() == tokenized_data["input_ids"].shape[0], "Tensor count mismatch"
 
+    updated_tensors = message.tensors()
     for key, val in new_tensors.items():
         assert array_pkg.allclose(updated_tensors.get_tensor(key), val), f"Mismatch in updated tensor data for {key}."

@@ -408,6 +410,7 @@ def test_update_individual_tensor(config: Config):
tensor_memory.set_tensor("input_ids", update_data["input_ids"])
retrieved_tensors = message.tensors()

assert message.tensor_count() == initial_data["input_ids"].shape[0], "Tensor count mismatch"
# Check updated tensor
assert array_pkg.allclose(retrieved_tensors.get_tensor("input_ids"),
update_data["input_ids"]), "Input IDs update mismatch."
@@ -422,8 +425,9 @@ def test_behavior_with_empty_tensors():
     tensor_memory = TensorMemory(count=0)
     message.tensors(tensor_memory)
 
+    assert message.tensor_count() == 0, "Tensor count should be 0 for empty tensor memory."
+
     retrieved_tensors = message.tensors()
-    assert retrieved_tensors.count == 0, "Tensor count should be 0 for empty tensor memory."
     assert len(retrieved_tensors.tensor_names) == 0, "There should be no tensor names for empty tensor memory."


@@ -442,8 +446,8 @@ def test_consistency_after_multiple_operations(config: Config):
new_tensor = {"new_tensor": array_pkg.array([7, 8, 9])}
tensor_memory.set_tensor("new_tensor", new_tensor["new_tensor"])

assert message.tensor_count() == initial_data["input_ids"].shape[0], "Tensor count mismatch."
retrieved_tensors = message.tensors()
assert retrieved_tensors.count == 3, "Tensor count mismatch after multiple operations."
assert array_pkg.allclose(retrieved_tensors.get_tensor("input_ids"),
array_pkg.array([4, 5, 6])), "Mismatch in input_ids after update."
assert array_pkg.allclose(retrieved_tensors.get_tensor("new_tensor"),
2 changes: 1 addition & 1 deletion tests/morpheus/stages/test_inference_stage.py
@@ -110,7 +110,7 @@ def test_convert_one_response():
     cm = InferenceStageT._convert_one_response(output, inf, res, batch_offset)
     assert cm.payload() == inf.payload()
     assert cm.payload().count == 4
-    assert cm.tensors().count == 4
+    assert cm.tensor_count() == 4
     assert cp.all(cm.tensors().get_tensor("probs") == res.get_tensor("probs"))
 
     # Test for the second branch