From f7b7280cea4ab7ee033b37c0b48bc58b88e48da3 Mon Sep 17 00:00:00 2001 From: Yuchen Zhang <134643420+yczhang-nv@users.noreply.github.com> Date: Thu, 29 Aug 2024 16:33:38 -0700 Subject: [PATCH] add unit tests for shared_process_pool --- .../morpheus/utils/shared_process_pool.py | 25 ++- tests/utils/test_shared_process_pool.py | 160 +++++++++++++++++- 2 files changed, 166 insertions(+), 19 deletions(-) diff --git a/python/morpheus/morpheus/utils/shared_process_pool.py b/python/morpheus/morpheus/utils/shared_process_pool.py index 9345b868bb..093573defb 100644 --- a/python/morpheus/morpheus/utils/shared_process_pool.py +++ b/python/morpheus/morpheus/utils/shared_process_pool.py @@ -48,6 +48,9 @@ def result(self): raise self._exception.value return self._result.value + def done(self): + return self._done.is_set() + class SharedProcessPool: @@ -85,7 +88,6 @@ def _initialize(self, total_max_workers): self._stage_semaphores = self._manager.dict() self._processes = [] - # TODO: Test the performance of reading the shared variable in each worker loop and try some alternatives self._shutdown_in_progress = self._manager.Value("b", False) for i in range(total_max_workers): @@ -123,13 +125,14 @@ def _worker(task_queues, stage_semaphores, shutdown_in_progress): semaphore.release() continue - # if task is None: # Stop signal - # semaphore.release() - # return + if task is None: + logger.warning("Worker process %s received a None task.", os.getpid()) + semaphore.release() + continue - process_fn, args, future = task + process_fn, args, kwargs, future = task try: - result = process_fn(*args) + result = process_fn(*args, **kwargs) future.set_result(result) except Exception as e: future.set_exception(e) @@ -138,12 +141,12 @@ def _worker(task_queues, stage_semaphores, shutdown_in_progress): time.sleep(0.1) # Avoid busy-waiting - def submit_task(self, stage_name, process_fn, *args): + def submit_task(self, stage_name, process_fn, *args, **kwargs): """ Submit a task to the corresponding task queue of the stage. """ future = SerializableFuture(self._context.Manager()) - task = (process_fn, args, future) + task = (process_fn, args, kwargs, future) self._task_queues[stage_name].put(task) return future @@ -152,7 +155,7 @@ def set_usage(self, stage_name, percentage): """ Set the maximum percentage of processes that can be used by each stage. """ - if not 0 < percentage <= 1: + if not 0 <= percentage <= 1: raise ValueError("Percentage must be between 0 and 1.") new_total_usage = self._total_usage - self._stage_usage.get(stage_name, 0.0) + percentage @@ -175,10 +178,6 @@ def set_usage(self, stage_name, percentage): def shutdown(self): if not self._shutdown: self._shutdown_in_progress.value = True - # for stage_name, task_queue in self._task_queues.items(): - # for _ in range(self._total_max_workers): - # task_queue.put(None) - # logger.debug("Task queue for stage %s has been cleared.", stage_name) for i, p in enumerate(self._processes): p.join() diff --git a/tests/utils/test_shared_process_pool.py b/tests/utils/test_shared_process_pool.py index 5f553e7355..222f271444 100644 --- a/tests/utils/test_shared_process_pool.py +++ b/tests/utils/test_shared_process_pool.py @@ -13,18 +13,29 @@ # See the License for the specific language governing permissions and # limitations under the License. +import atexit import logging import multiprocessing as mp +import threading import time import numpy as np +import pytest from morpheus.utils.shared_process_pool import SharedProcessPool -logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') +# logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') logger = logging.getLogger(__name__) +@pytest.fixture(name="shared_process_pool") +def shared_process_pool_fixture(): + pool = SharedProcessPool() + atexit.register(pool.shutdown) # make sure to shutdown the pool before the test exits + + return pool + + def _matrix_multiplication_task(size): matrix_a = np.random.rand(size, size) matrix_b = np.random.rand(size, size) @@ -34,6 +45,22 @@ def _matrix_multiplication_task(size): return result +def _simple_add_task(a, b): + return a + b + + +def _process_func_with_exception(): + raise ValueError("Exception is raised in the process.") + + +def _unserializable_function(): + return threading.Lock() + + +def _arbitrary_function(*args, **kwargs): + return args, kwargs + + def _test_worker(pool, stage_name, task_size, num_tasks): future_list = [] for i in range(num_tasks): @@ -47,6 +74,8 @@ def _test_worker(pool, stage_name, task_size, num_tasks): logging.info("All tasks in stage %s have been completed in %.2f seconds.", stage_name, (future_list[-1].result()[1] - future_list[0].result()[1])) + assert len(future_list) == num_tasks + def test_singleton(): pool_1 = SharedProcessPool() @@ -55,14 +84,81 @@ def test_singleton(): assert pool_1 is pool_2 -def test_shared_process_pool(): - pool = SharedProcessPool() +def test_single_task(shared_process_pool): + pool = shared_process_pool + + pool.set_usage("test_stage", 0.5) + + a = 10 + b = 20 + + future = pool.submit_task("test_stage", _simple_add_task, a, b) + assert future.result() == a + b + + future = pool.submit_task("test_stage", _simple_add_task, a=a, b=b) + assert future.result() == a + b + + future = pool.submit_task("test_stage", _simple_add_task, a, b=b) + assert future.result() == a + b + + +def test_multiple_tasks(shared_process_pool): + pool = shared_process_pool + + pool.set_usage("test_stage", 0.5) + + num_tasks = 100 + futures = [] + for _ in range(num_tasks): + futures.append(pool.submit_task("test_stage", _simple_add_task, 10, 20)) + + for future in futures: + assert future.result() == 30 + + +def test_error_process_function(shared_process_pool): + pool = shared_process_pool + + pool.set_usage("test_stage", 0.5) + + with pytest.raises(ValueError): + future = pool.submit_task("test_stage", _process_func_with_exception) + future.result() + + +def test_unserializable_function(shared_process_pool): + pool = shared_process_pool + + pool.set_usage("test_stage", 0.5) + + with pytest.raises(TypeError): + future = pool.submit_task("test_stage", _unserializable_function) + future.result() + + +def test_unserializable_arg(shared_process_pool): + pool = shared_process_pool + + pool.set_usage("test_stage", 0.5) + + with pytest.raises(TypeError): + future = pool.submit_task("test_stage", _arbitrary_function, threading.Lock()) + future.result() + + +def test_multiple_stages(shared_process_pool): + pool = shared_process_pool + + pool.set_usage("test_stage", 0.0) # Remove usage of test_stage in previous tests pool.set_usage("test_stage_1", 0.1) pool.set_usage("test_stage_2", 0.3) pool.set_usage("test_stage_3", 0.6) - tasks = [("test_stage_1", 8000, 30), ("test_stage_2", 8000, 30), ("test_stage_3", 8000, 30)] + task_size = 3000 + task_num = 30 + tasks = [("test_stage_1", task_size, task_num), ("test_stage_2", task_size, task_num), + ("test_stage_3", task_size, task_num)] processes = [] for task in tasks: @@ -77,5 +173,57 @@ def test_shared_process_pool(): p.join() -if __name__ == "__main__": - test_shared_process_pool() +def test_invalid_stage_usage(shared_process_pool): + pool = shared_process_pool + + # Remove usage of test_stage in previous tests + pool.set_usage("test_stage", 0.0) + pool.set_usage("test_stage_1", 0.0) + pool.set_usage("test_stage_2", 0.0) + pool.set_usage("test_stage_3", 0.0) + + with pytest.raises(ValueError): + pool.set_usage("test_stage", 1.1) + + with pytest.raises(ValueError): + pool.set_usage("test_stage", -0.1) + + pool.set_usage("test_stage_1", 0.5) + pool.set_usage("test_stage_2", 0.4) + + pool.set_usage("test_stage_1", 0.6) # ok to update the usage of an existing stage + + with pytest.raises(ValueError): + pool.set_usage("test_stage_1", 0.7) # not ok to exceed the total usage limit after updating + + with pytest.raises(ValueError): + pool.set_usage("test_stage_3", 0.1) + + +def test_task_completion_before_shutdown(shared_process_pool): + pool = shared_process_pool + + # Remove usage of test_stage in previous tests + pool.set_usage("test_stage", 0.0) + pool.set_usage("test_stage_1", 0.0) + pool.set_usage("test_stage_2", 0.0) + pool.set_usage("test_stage_3", 0.0) + + pool.set_usage("test_stage_1", 0.1) + pool.set_usage("test_stage_2", 0.3) + pool.set_usage("test_stage_3", 0.6) + + task_size = 3000 + task_num = 30 + futures = [] + for _ in range(task_num): + futures.append(pool.submit_task("test_stage_1", _matrix_multiplication_task, task_size)) + futures.append(pool.submit_task("test_stage_2", _matrix_multiplication_task, task_size)) + futures.append(pool.submit_task("test_stage_3", _matrix_multiplication_task, task_size)) + + pool.shutdown() + + # all tasks should be completed before shutdown + assert len(futures) == 3 * task_num + for future in futures: + assert future.done()