Skip to content

Commit

Permalink
fix python
Browse files Browse the repository at this point in the history
  • Loading branch information
yczhang-nv committed Apr 5, 2024
1 parent 1b08bd3 commit 1d94c31
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 4 deletions.
11 changes: 7 additions & 4 deletions tests/_utils/stages/conv_msg.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import cudf

# pylint: disable=morpheus-incorrect-lib-from-import
from morpheus._lib.messages import TensorMemory as CppTensorMemory
from morpheus.cli.register_stage import register_stage
from morpheus.config import Config
Expand All @@ -36,7 +37,8 @@
@register_stage("unittest-conv-msg", ignore_args=["expected_data"])
class ConvMsg(SinglePortStage):
"""
Simple test stage to convert a MultiMessage to a MultiResponseProbsMessage, or a ControlMessage to a ControlMessage with probs tensor.
Simple test stage to convert a MultiMessage to a MultiResponseProbsMessage,
or a ControlMessage to a ControlMessage with probs tensor.
Basically a cheap replacement for running an inference stage.
Setting `message_type` to determine the input type of the stage.
Expand Down Expand Up @@ -103,12 +105,13 @@ def _conv_message(
probs = cp.zeros([len(df), 3], 'float')
else:
probs = cp.array(df.values, dtype=self._probs_type, copy=True, order=self._order)

if self._message_type == ControlMessage:
message.tensors(CppTensorMemory(count=len(probs), tensors={'probs': probs}))
return message
if self._message_type == MultiResponseMessage:
memory = ResponseMemory(count=len(probs), tensors={'probs': probs})
return MultiResponseMessage.from_message(message, memory=memory)
# if self._message_type == MultiResponseMessage:
memory = ResponseMemory(count=len(probs), tensors={'probs': probs})
return MultiResponseMessage.from_message(message, memory=memory)

def _build_single(self, builder: mrc.Builder, input_node: mrc.SegmentObject) -> mrc.SegmentObject:
node = builder.make_node(self.unique_name, ops.map(self._conv_message))
Expand Down
1 change: 1 addition & 0 deletions tests/test_add_classifications_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import cudf

from _utils.dataset_manager import DatasetManager
# pylint: disable=morpheus-incorrect-lib-from-import
from morpheus._lib.messages import TensorMemory as CppTensorMemory
from morpheus.config import Config
from morpheus.messages import ControlMessage
Expand Down
1 change: 1 addition & 0 deletions tests/test_add_scores_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import cudf

from _utils.dataset_manager import DatasetManager
# pylint: disable=morpheus-incorrect-lib-from-import
from morpheus._lib.messages import TensorMemory as CppTensorMemory
from morpheus.config import Config
from morpheus.messages import ControlMessage
Expand Down

0 comments on commit 1d94c31

Please sign in to comment.