Multiple fixes related to SharedProcessPool & MultiProcessingStage (#1940)

Several fixes related to `SharedProcessPool` and `MultiProcessingStage`.

- Add a `pytest` fixture that should be applied to any tests that make use of `SharedProcessPool`.
- Switch the start method of `SharedProcessPool` from `fork` to `forkserver` to avoid inheriting unnecessary resources from the parent process (this resolves issues in CPU-only mode); a sketch of the difference follows this list.
- Add missing docstrings to `MultiProcessingStage`.
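
For context, a minimal standalone sketch (standard library only, Unix-only, not Morpheus code) of what the `fork` → `forkserver` switch looks like; `_double` and the pool size are placeholders for this example:

```python
import multiprocessing as mp


def _double(value: int) -> int:
    """Trivial task so the example is runnable."""
    return value * 2


if __name__ == "__main__":
    # "fork" duplicates the parent's full state (open device handles, threads,
    # sockets); "forkserver" forks workers from a minimal server process, so
    # they start without those inherited resources.
    ctx = mp.get_context("forkserver")
    with ctx.Pool(processes=2) as pool:
        print(pool.map(_double, [1, 2, 3]))  # [2, 4, 6]
```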

Closes #1939 

## By Submitting this PR I confirm:
- I am familiar with the [Contributing Guidelines](https://github.com/nv-morpheus/Morpheus/blob/main/docs/source/developer_guide/contributing.md).
- When the PR is ready for review, new or existing tests cover these changes.
- When the PR is ready for review, the documentation is up to date with these changes.

Authors:
  - Yuchen Zhang (https://github.com/yczhang-nv)

Approvers:
  - David Gardner (https://github.com/dagardner-nv)

URL: #1940
yczhang-nv authored Oct 11, 2024
1 parent 967216b commit d95a5cf
Showing 5 changed files with 130 additions and 50 deletions.
118 changes: 93 additions & 25 deletions python/morpheus/morpheus/stages/general/multi_processing_stage.py
@@ -31,8 +31,27 @@


class MultiProcessingBaseStage(SinglePortStage, typing.Generic[InputT, OutputT]):
"""
Base class for all MultiProcessing stages that make use of the SharedProcessPool.
Parameters
----------
c : `morpheus.config.Config`
Pipeline configuration instance.
process_pool_usage : float
The fraction of the process pool workers that this stage could use. Should be between 0 and 1.
max_in_flight_messages : int, default = None
The number of progress engines used by the stage. If None, it will be set to 1.5 times the total
number of process pool workers
Raises
------
ValueError
If the process pool usage is not between 0 and 1.
"""

def __init__(self, *, c: Config, process_pool_usage: float, max_in_flight_messages: int = None):

super().__init__(c=c)

if not 0 <= process_pool_usage <= 1:
@@ -50,36 +69,43 @@ def __init__(self, *, c: Config, process_pool_usage: float, max_in_flight_messag

def accepted_types(self) -> typing.Tuple:
"""
There are two approaches to inherit from this class:
- With generic types: MultiProcessingDerivedStage(MultiProcessingBaseStage[InputT, OutputT])
- With concrete types: MultiProcessingDerivedStage(MultiProcessingBaseStage[int, str])
Accepted input types for this stage are returned.
Raises
------
RuntimeError
If the accepted types cannot be deduced from either __orig_class__ or __orig_bases__.
When inheriting with generic types, the derived class can be instantiated like this:
Returns
-------
typing.Tuple
Accepted input types.
"""

stage = MultiProcessingDerivedStage[int, str]()
# There are two approaches to inherit from this class:
# - With generic types: MultiProcessingDerivedStage(MultiProcessingBaseStage[InputT, OutputT])
# - With concrete types: MultiProcessingDerivedStage(MultiProcessingBaseStage[int, str])

In this case, typing.Generic stores the stage type in stage.__orig_class__, the concrete types can be accessed
as below:
# When inheriting with generic types, the derived class can be instantiated like this:

input_type = typing.get_args(stage.__orig_class__)[0] # int
output_type = typing.get_args(stage.__orig_class__)[1] # str
# stage = MultiProcessingDerivedStage[int, str]()

However, when instantiating a stage which inherits with concrete types:
# In this case, typing.Generic stores the stage type in stage.__orig_class__, the concrete types can be accessed
# as below:

stage = MultiProcessingDerivedStage()
# input_type = typing.get_args(stage.__orig_class__)[0] # int
# output_type = typing.get_args(stage.__orig_class__)[1] # str

The stage instance does not have __orig_class__ attribute (since it is not a generic type). Thus, the concrete
types need be retrieved from its base class (which is a generic type):
# However, when instantiating a stage which inherits with concrete types:

input_type = typing.get_args(stage.__orig_bases__[0])[0] # int
output_type = typing.get_args(stage.__orig_bases__[0])[1] # str
# stage = MultiProcessingDerivedStage()

Raises:
RuntimeError: if the accepted cannot be deducted from either __orig_class__ or __orig_bases__
# The stage instance does not have __orig_class__ attribute (since it is not a generic type). Thus, the concrete
# types need be retrieved from its base class (which is a generic type):

# input_type = typing.get_args(stage.__orig_bases__[0])[0] # int
# output_type = typing.get_args(stage.__orig_bases__[0])[1] # str

Returns:
typing.Tuple: accepted input types
"""
if hasattr(self, "__orig_class__"):
# inherited with generic types
input_type = typing.get_args(self.__orig_class__)[0] # pylint: disable=no-member
@@ -95,14 +121,20 @@ def accepted_types(self) -> typing.Tuple:

def compute_schema(self, schema: StageSchema):
"""
See the comment on `accepted_types` for more information on accessing the input and output types.
Compute the output schema for the stage.
Args:
schema (StageSchema): StageSchema
Parameters
----------
schema : StageSchema
The schema for the stage.
Raises:
RuntimeError: if the output type cannot be deducted from either __orig_class__ or __orig_bases__
Raises
------
RuntimeError
If the output type cannot be deduced from either __orig_class__ or __orig_bases__.
"""

# See the comment on `accepted_types` for more information on accessing the input and output types.
if hasattr(self, "__orig_class__"):
# inherited with abstract types
output_type = typing.get_args(self.__orig_class__)[1] # pylint: disable=no-member
@@ -117,6 +149,7 @@ def compute_schema(self, schema: StageSchema):
schema.output_schema.set_type(output_type)

def supports_cpp_node(self):
"""Whether this stage supports a C++ node."""
return False

@abstractmethod
@@ -162,6 +195,21 @@ def _get_func_signature(func: typing.Callable[[InputT], OutputT]) -> tuple[type,


class MultiProcessingStage(MultiProcessingBaseStage[InputT, OutputT]):
"""
A derived class of MultiProcessingBaseStage that allows the user to define a process function that will be executed
based on shared process pool.
Parameters
----------
c : `morpheus.config.Config`
Pipeline configuration instance.
unique_name : str
A unique name for the stage.
process_fn: typing.Callable[[InputT], OutputT]
The function that will be executed in the process pool.
max_in_flight_messages : int, default = None
The number of progress engines used by the stage.
"""

def __init__(self,
*,
@@ -178,6 +226,7 @@ def __init__(self,

@property
def name(self) -> str:
"""Return the name of the stage."""
return self._name

def _on_data(self, data: InputT) -> OutputT:
@@ -192,6 +241,25 @@ def create(*,
unique_name: str,
process_fn: typing.Callable[[InputT], OutputT],
process_pool_usage: float):
"""
Create a MultiProcessingStage instance by deducing the input and output types from the process function.
Parameters
----------
c : morpheus.config.Config
Pipeline configuration instance.
unique_name : str
A unique name for the stage.
process_fn : typing.Callable[[InputT], OutputT]
The function that will be executed in the process pool.
process_pool_usage : float
The fraction of the process pool workers that this stage could use. Should be between 0 and 1.
Returns
-------
MultiProcessingStage[InputT, OutputT]
A MultiProcessingStage instance with deduced input and output types.
"""

input_t, output_t = _get_func_signature(process_fn)
return MultiProcessingStage[input_t, output_t](c=c,
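As a usage illustration (not part of this commit), a hedged sketch of building a stage with `create`; `to_upper`, the stage name, and the `0.5` pool usage are assumptions, and the import path is inferred from the file path above:

```python
from morpheus.config import Config
from morpheus.stages.general.multi_processing_stage import MultiProcessingStage


def to_upper(text: str) -> str:
    # Hypothetical CPU-bound work to run inside the shared process pool.
    return text.upper()


config = Config()

# The input/output types (str -> str) are deduced from to_upper's type hints,
# so accepted_types() and compute_schema() can later resolve them from
# __orig_class__.
stage = MultiProcessingStage.create(c=config,
                                    unique_name="to_upper_stage",
                                    process_fn=to_upper,
                                    process_pool_usage=0.5)
```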
11 changes: 5 additions & 6 deletions python/morpheus/morpheus/utils/shared_process_pool.py
@@ -136,7 +136,7 @@ def _initialize(self):
self._total_max_workers = math.floor(max(1, len(os.sched_getaffinity(0)) * cpu_usage))
self._processes = []

self._context = mp.get_context("fork")
self._context = mp.get_context("forkserver")
self._manager = self._context.Manager()
self._task_queues = self._manager.dict()
self._stage_semaphores = self._manager.dict()
@@ -196,8 +196,7 @@ def _worker(cancellation_token, task_queues, stage_semaphores):
continue

if task is None:
logger.warning("SharedProcessPool._worker: Worker process %s has received a None task.",
os.getpid())
logger.debug("SharedProcessPool._worker: Worker process %s has received a None task.", os.getpid())
semaphore.release()
continue

@@ -316,7 +315,7 @@ def start(self):
If the SharedProcessPool is not shutdown.
"""
if self._status == PoolStatus.RUNNING:
logger.warning("SharedProcessPool.start(): process pool is already running.")
logger.debug("SharedProcessPool.start(): process pool is already running.")
return

process_launcher = threading.Thread(target=self._launch_workers)
@@ -373,7 +372,7 @@ def stop(self):
Stop receiving any new tasks.
"""
if self._status not in (PoolStatus.RUNNING, PoolStatus.INITIALIZING):
logger.warning("SharedProcessPool.stop(): Cannot stop a SharedProcessPool that is not running.")
logger.debug("SharedProcessPool.stop(): Cannot stop a SharedProcessPool that is not running.")
return

# No new tasks will be accepted from this point
@@ -400,7 +399,7 @@ def join(self, timeout: float | None = None):

if self._status != PoolStatus.STOPPED:
if self._status == PoolStatus.SHUTDOWN:
logger.debug("SharedProcessPool.join(): process pool is already shut down.")
return

raise RuntimeError("Cannot join SharedProcessPool that is not stopped.")
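For reference, a sketch of the pool lifecycle that these log-level changes touch, using only the methods visible in this diff (`start`, `stop`, `join`, `reset`); treating the pool as a shared singleton is an assumption based on how the test fixtures below reset it:

```python
from morpheus.utils.shared_process_pool import SharedProcessPool

pool = SharedProcessPool()  # assumed shared instance, as used in the fixtures below
pool.start()                # now logs at debug level (not warning) if already running

# ... tasks are submitted via stages such as MultiProcessingStage ...

pool.stop()   # stop accepting new tasks
pool.join()   # wait for in-flight tasks and worker shutdown
pool.reset()  # restore a clean state so later tests can reuse the pool
```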
23 changes: 23 additions & 0 deletions tests/conftest.py
@@ -36,13 +36,15 @@
from _utils.kafka import kafka_bootstrap_servers_fixture # noqa: F401 pylint:disable=unused-import
from _utils.kafka import kafka_consumer_fixture # noqa: F401 pylint:disable=unused-import
from _utils.kafka import kafka_topics_fixture # noqa: F401 pylint:disable=unused-import
from morpheus.utils.shared_process_pool import SharedProcessPool

# Don't let pylint complain about pytest fixtures
# pylint: disable=redefined-outer-name,unused-argument

(PYTEST_KAFKA_AVAIL, PYTEST_KAFKA_ERROR) = _init_pytest_kafka()
if PYTEST_KAFKA_AVAIL:
# Pull out the fixtures into this namespace
# pylint: disable=ungrouped-imports
from _utils.kafka import _kafka_consumer # noqa: F401 pylint:disable=unused-import
from _utils.kafka import kafka_server # noqa: F401 pylint:disable=unused-import
from _utils.kafka import zookeeper_proc # noqa: F401 pylint:disable=unused-import
@@ -1150,3 +1152,24 @@ def mock_subscription_fixture():
ms = mock.MagicMock()
ms.is_subscribed.return_value = True
return ms


# ==== SharedProcessPool Fixtures ====
# Any tests that use the SharedProcessPool should use this fixture
@pytest.fixture(scope="module")
def shared_process_pool_setup_and_teardown():
# Set lower CPU usage for unit tests to avoid slowing down the test suite
os.environ["MORPHEUS_SHARED_PROCESS_POOL_CPU_USAGE"] = "0.1"

pool = SharedProcessPool()

# Since SharedProcessPool might be configured and used in other tests, stop and reset the pool before the tests start
pool.stop()
pool.join()
pool.reset()
yield pool

# Stop the pool after all tests are done
pool.stop()
pool.join()
os.environ.pop("MORPHEUS_SHARED_PROCESS_POOL_CPU_USAGE", None)
5 changes: 5 additions & 0 deletions tests/test_multi_processing_stage.py
@@ -40,6 +40,11 @@
from morpheus.stages.preprocess.deserialize_stage import DeserializeStage


@pytest.fixture(scope="module", autouse=True)
def setup_and_teardown(shared_process_pool_setup_and_teardown): # pylint: disable=unused-argument
pass


def _create_df(count: int) -> pd.DataFrame:
return pd.DataFrame({"a": range(count), "b": range(count)})

23 changes: 4 additions & 19 deletions tests/utils/test_shared_process_pool.py
@@ -15,7 +15,6 @@

import logging
import multiprocessing as mp
import os
import threading
from decimal import Decimal
from fractions import Fraction
@@ -30,26 +29,12 @@
# This test has issues with joining processes when run with the pytest `-s` option. Run pytest without the `-s` flag.


@pytest.fixture(scope="session", autouse=True)
def setup_and_teardown():
# Set lower CPU usage for unit test to avoid slowing down the test
os.environ["MORPHEUS_SHARED_PROCESS_POOL_CPU_USAGE"] = "0.1"

pool = SharedProcessPool()

# Since SharedProcessPool might be used in other tests, stop and reset the pool before the test starts
pool.stop()
pool.join()
pool.reset()
yield

# Stop the pool after all tests are done
pool.stop()
pool.join()
os.environ.pop("MORPHEUS_SHARED_PROCESS_POOL_CPU_USAGE", None)
@pytest.fixture(scope="module", autouse=True)
def setup_and_teardown(shared_process_pool_setup_and_teardown): # pylint: disable=unused-argument
pass


@pytest.fixture(name="shared_process_pool")
@pytest.fixture(name="shared_process_pool", scope="function")
def shared_process_pool_fixture():

pool = SharedProcessPool()
Expand Down
