Skip to content

Commit

Permalink
initial test for test_multi_processing_stage
Browse files Browse the repository at this point in the history
  • Loading branch information
yczhang-nv committed Aug 24, 2024
1 parent 21fd8e8 commit 46865ae
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 20 deletions.
48 changes: 28 additions & 20 deletions python/morpheus/morpheus/stages/general/multi_processing_stage.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from abc import abstractmethod
from abc import ABC, abstractmethod
import typing

from morpheus.pipeline.stage_schema import StageSchema
Expand All @@ -12,12 +12,22 @@
OutputT = typing.TypeVar('OutputT')


class MultiProcessingBaseStage(SinglePortStage, typing.Generic[InputT, OutputT]):
class MultiProcessingBaseStage(SinglePortStage, ABC, typing.Generic[InputT, OutputT]):

def __init__(self, *, c: Config, process_pool_usage: float):
def __init__(self, *, c: Config, process_pool_usage: float, max_in_flight_messages: int = None):
super().__init__(c=c)

self._process_pool_usage = process_pool_usage
self._shared_process_pool = SharedProcessPool()
self._shared_process_pool.set_usage(self.name, self._process_pool_usage)

if max_in_flight_messages is None:
# set the multiplier to 1.5 to keep the workers busy
self._max_in_flight_messages = int(self._shared_process_pool.total_max_workers * 1.5)
else:
self._max_in_flight_messages = max_in_flight_messages

# self._max_in_flight_messages = 1

@property
def name(self) -> str:
Expand All @@ -27,12 +37,25 @@ def accepted_types(self) -> typing.Tuple:
return (InputT, )

def compute_schema(self, schema: StageSchema):
return super().compute_schema(schema)
for (port_idx, port_schema) in enumerate(schema.input_schemas):
schema.output_schemas[port_idx].set_type(port_schema.get_type())

@abstractmethod
def _on_data(self, data: InputT) -> OutputT:
pass

def supports_cpp_node(self):
return False

def _build_single(self, builder: mrc.Builder, input_node: mrc.SegmentObject) -> mrc.SegmentObject:
node = builder.make_node(self.name, ops.map(self._on_data))
node.launch_options.pe_count = self._max_in_flight_messages

builder.make_edge(input_node, node)

return node



class MultiProcessingStage(MultiProcessingBaseStage[InputT, OutputT]):

Expand All @@ -42,16 +65,9 @@ def __init__(self,
process_pool_usage: float,
process_fn: typing.Callable[[InputT], OutputT],
max_in_flight_messages: int = None):
super().__init__(c=c, process_pool_usage=process_pool_usage)
super().__init__(c=c, process_pool_usage=process_pool_usage, max_in_flight_messages=max_in_flight_messages)

self._process_fn = process_fn
self._shared_process_pool = SharedProcessPool()
self._shared_process_pool.set_usage(self.name, self._process_pool_usage)
if max_in_flight_messages is None:
# set the multiplier to 1.5 to keep the workers busy
self._max_in_flight_messages = self._shared_process_pool.total_max_workers * 1.5
else:
self._max_in_flight_messages = max_in_flight_messages

@property
def name(self) -> str:
Expand All @@ -69,14 +85,6 @@ def create(*, c: Config, process_fn: typing.Callable[[InputT], OutputT], process

return MultiProcessingStage[InputT, OutputT](c=c, process_pool_usage=process_pool_usage, process_fn=process_fn)

def _build_single(self, builder: mrc.Builder, input_node: mrc.SegmentObject) -> mrc.SegmentObject:
node = builder.make_node(self.name, ops.map(self._on_data))
node.lanuch_options.pe_count = self._max_in_flight_messages

builder.make_edge(input_node, node)

return node


# pipe = LinearPipeline(config)

Expand Down
64 changes: 64 additions & 0 deletions tests/test_multi_processing_stage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from typing import Tuple
import cudf
import pytest
import os

from _utils.dataset_manager import DatasetManager
from morpheus.config import Config
from morpheus.messages import ControlMessage
from morpheus.pipeline import LinearPipeline
from morpheus.stages.general.multi_processing_stage import MultiProcessingBaseStage
from morpheus.stages.general.multi_processing_stage import MultiProcessingStage
from morpheus.stages.input.in_memory_source_stage import InMemorySourceStage
from morpheus.stages.preprocess.deserialize_stage import DeserializeStage


def test_constructor(config: Config):
stage = MultiProcessingStage.create(c=config, process_fn=lambda x: x, process_pool_usage=0.5)
assert stage.name == "multi-processing-stage"


class DerivedMultiProcessingStage(MultiProcessingBaseStage[ControlMessage, ControlMessage]):

def __init__(self,
*,
c: Config,
process_pool_usage: float,
add_column_name: str,
max_in_flight_messages: int = None):
super().__init__(c=c, process_pool_usage=process_pool_usage, max_in_flight_messages=max_in_flight_messages)

self._add_column_name = add_column_name

@property
def name(self) -> str:
return "derived-multi-processing-stage"

def accepted_types(self) -> Tuple:
return (ControlMessage, )

def _on_data(self, data: ControlMessage) -> ControlMessage:
with data.payload().mutable_dataframe() as df:
df[self._add_column_name] = "Hello"

return data

@pytest.mark.use_python
def test_stage_pipe(config: Config, dataset_pandas: DatasetManager):

config.num_threads = os.cpu_count()
input_df = dataset_pandas["filter_probs.csv"]

pipe = LinearPipeline(config)
pipe.set_source(InMemorySourceStage(config, [cudf.DataFrame(input_df)]))
pipe.add_stage(DeserializeStage(config, ensure_sliceable_index=True, message_type=ControlMessage))
pipe.add_stage(DerivedMultiProcessingStage(c=config, process_pool_usage=0.5, add_column_name="new_column"))

pipe.run()


# if __name__ == "__main__":
# config = Config()
# dataset_pandas = DatasetManager()
# # test_constructor(config)
# test_stage_pipe(config, dataset_pandas)

0 comments on commit 46865ae

Please sign in to comment.