Skip to content

Commit

Permalink
add unit tests for shared_process_pool
Browse files Browse the repository at this point in the history
  • Loading branch information
yczhang-nv committed Aug 29, 2024
1 parent b073b7d commit f7b7280
Show file tree
Hide file tree
Showing 2 changed files with 166 additions and 19 deletions.
25 changes: 12 additions & 13 deletions python/morpheus/morpheus/utils/shared_process_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ def result(self):
raise self._exception.value
return self._result.value

def done(self):
    """Return True once the task's completion event has been set.

    Mirrors ``concurrent.futures.Future.done``: True whether the task
    finished successfully or raised.
    """
    finished = self._done.is_set()
    return finished


class SharedProcessPool:

Expand Down Expand Up @@ -85,7 +88,6 @@ def _initialize(self, total_max_workers):
self._stage_semaphores = self._manager.dict()
self._processes = []

# TODO: Test the performance of reading the shared variable in each worker loop and try some alternatives
self._shutdown_in_progress = self._manager.Value("b", False)

for i in range(total_max_workers):
Expand Down Expand Up @@ -123,13 +125,14 @@ def _worker(task_queues, stage_semaphores, shutdown_in_progress):
semaphore.release()
continue

# if task is None: # Stop signal
# semaphore.release()
# return
if task is None:
logger.warning("Worker process %s received a None task.", os.getpid())
semaphore.release()
continue

process_fn, args, future = task
process_fn, args, kwargs, future = task
try:
result = process_fn(*args)
result = process_fn(*args, **kwargs)
future.set_result(result)
except Exception as e:
future.set_exception(e)
Expand All @@ -138,12 +141,12 @@ def _worker(task_queues, stage_semaphores, shutdown_in_progress):

time.sleep(0.1) # Avoid busy-waiting

def submit_task(self, stage_name, process_fn, *args):
def submit_task(self, stage_name, process_fn, *args, **kwargs):
"""
Submit a task to the corresponding task queue of the stage.
"""
future = SerializableFuture(self._context.Manager())
task = (process_fn, args, future)
task = (process_fn, args, kwargs, future)
self._task_queues[stage_name].put(task)

return future
Expand All @@ -152,7 +155,7 @@ def set_usage(self, stage_name, percentage):
"""
Set the maximum percentage of processes that can be used by each stage.
"""
if not 0 < percentage <= 1:
if not 0 <= percentage <= 1:
raise ValueError("Percentage must be between 0 and 1.")

new_total_usage = self._total_usage - self._stage_usage.get(stage_name, 0.0) + percentage
Expand All @@ -175,10 +178,6 @@ def set_usage(self, stage_name, percentage):
def shutdown(self):
if not self._shutdown:
self._shutdown_in_progress.value = True
# for stage_name, task_queue in self._task_queues.items():
# for _ in range(self._total_max_workers):
# task_queue.put(None)
# logger.debug("Task queue for stage %s has been cleared.", stage_name)

for i, p in enumerate(self._processes):
p.join()
Expand Down
160 changes: 154 additions & 6 deletions tests/utils/test_shared_process_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,29 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import atexit
import logging
import multiprocessing as mp
import threading
import time

import numpy as np
import pytest

from morpheus.utils.shared_process_pool import SharedProcessPool

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


@pytest.fixture(name="shared_process_pool")
def shared_process_pool_fixture():
    """Provide the process pool to tests under the name ``shared_process_pool``.

    SharedProcessPool is a singleton (see test_singleton), so every test
    receives the same underlying pool instance.
    """
    pool = SharedProcessPool()
    # Ensure the pool's worker processes are torn down when the test session
    # exits, even if no test calls shutdown() explicitly.
    # NOTE(review): this registers pool.shutdown once per fixture use; since the
    # pool is a singleton, shutdown may be invoked multiple times at interpreter
    # exit — presumably shutdown() is idempotent, verify in SharedProcessPool.
    atexit.register(pool.shutdown)  # make sure to shutdown the pool before the test exits

    return pool


def _matrix_multiplication_task(size):
matrix_a = np.random.rand(size, size)
matrix_b = np.random.rand(size, size)
Expand All @@ -34,6 +45,22 @@ def _matrix_multiplication_task(size):
return result


def _simple_add_task(a, b):
return a + b


def _process_func_with_exception():
raise ValueError("Exception is raised in the process.")


def _unserializable_function():
return threading.Lock()


def _arbitrary_function(*args, **kwargs):
return args, kwargs


def _test_worker(pool, stage_name, task_size, num_tasks):
future_list = []
for i in range(num_tasks):
Expand All @@ -47,6 +74,8 @@ def _test_worker(pool, stage_name, task_size, num_tasks):
logging.info("All tasks in stage %s have been completed in %.2f seconds.",
stage_name, (future_list[-1].result()[1] - future_list[0].result()[1]))

assert len(future_list) == num_tasks


def test_singleton():
pool_1 = SharedProcessPool()
Expand All @@ -55,14 +84,81 @@ def test_singleton():
assert pool_1 is pool_2


def test_shared_process_pool():
pool = SharedProcessPool()
def test_single_task(shared_process_pool):
    """submit_task must accept positional, keyword, and mixed argument styles."""
    pool = shared_process_pool
    pool.set_usage("test_stage", 0.5)

    lhs = 10
    rhs = 20

    # Same call three ways: all positional, all keyword, and mixed.
    call_variants = [
        ((lhs, rhs), {}),
        ((), {"a": lhs, "b": rhs}),
        ((lhs,), {"b": rhs}),
    ]
    for pos_args, kw_args in call_variants:
        future = pool.submit_task("test_stage", _simple_add_task, *pos_args, **kw_args)
        assert future.result() == lhs + rhs


def test_multiple_tasks(shared_process_pool):
    """Many tasks queued on one stage must all complete with correct results."""
    pool = shared_process_pool
    pool.set_usage("test_stage", 0.5)

    task_count = 100
    futures = [pool.submit_task("test_stage", _simple_add_task, 10, 20) for _ in range(task_count)]

    for future in futures:
        assert future.result() == 30


def test_error_process_function(shared_process_pool):
    """An exception raised inside the worker must surface via result()."""
    pool = shared_process_pool
    pool.set_usage("test_stage", 0.5)

    with pytest.raises(ValueError):
        pool.submit_task("test_stage", _process_func_with_exception).result()


def test_unserializable_function(shared_process_pool):
    """A task returning an unpicklable value must raise TypeError on result()."""
    pool = shared_process_pool
    pool.set_usage("test_stage", 0.5)

    with pytest.raises(TypeError):
        pool.submit_task("test_stage", _unserializable_function).result()


def test_unserializable_arg(shared_process_pool):
    """Submitting an unpicklable argument must raise TypeError on result()."""
    pool = shared_process_pool
    pool.set_usage("test_stage", 0.5)

    with pytest.raises(TypeError):
        pool.submit_task("test_stage", _arbitrary_function, threading.Lock()).result()


def test_multiple_stages(shared_process_pool):
pool = shared_process_pool

pool.set_usage("test_stage", 0.0) # Remove usage of test_stage in previous tests

pool.set_usage("test_stage_1", 0.1)
pool.set_usage("test_stage_2", 0.3)
pool.set_usage("test_stage_3", 0.6)

tasks = [("test_stage_1", 8000, 30), ("test_stage_2", 8000, 30), ("test_stage_3", 8000, 30)]
task_size = 3000
task_num = 30
tasks = [("test_stage_1", task_size, task_num), ("test_stage_2", task_size, task_num),
("test_stage_3", task_size, task_num)]

processes = []
for task in tasks:
Expand All @@ -77,5 +173,57 @@ def test_shared_process_pool():
p.join()


if __name__ == "__main__":
test_shared_process_pool()
def test_invalid_stage_usage(shared_process_pool):
    """set_usage must reject out-of-range percentages and over-allocation."""
    pool = shared_process_pool

    # The pool is a singleton, so clear any usage left over from earlier tests.
    for stage in ("test_stage", "test_stage_1", "test_stage_2", "test_stage_3"):
        pool.set_usage(stage, 0.0)

    with pytest.raises(ValueError):
        pool.set_usage("test_stage", 1.1)

    with pytest.raises(ValueError):
        pool.set_usage("test_stage", -0.1)

    pool.set_usage("test_stage_1", 0.5)
    pool.set_usage("test_stage_2", 0.4)

    # Re-setting an existing stage within the remaining budget is allowed...
    pool.set_usage("test_stage_1", 0.6)

    # ...but any update that pushes the total past 1.0 must fail.
    with pytest.raises(ValueError):
        pool.set_usage("test_stage_1", 0.7)

    with pytest.raises(ValueError):
        pool.set_usage("test_stage_3", 0.1)


def test_task_completion_before_shutdown(shared_process_pool):
    """shutdown() must not return until every queued task has finished."""
    pool = shared_process_pool

    # The pool is a singleton, so clear any usage left over from earlier tests.
    for stage in ("test_stage", "test_stage_1", "test_stage_2", "test_stage_3"):
        pool.set_usage(stage, 0.0)

    pool.set_usage("test_stage_1", 0.1)
    pool.set_usage("test_stage_2", 0.3)
    pool.set_usage("test_stage_3", 0.6)

    matrix_size = 3000
    rounds = 30
    futures = []
    for _ in range(rounds):
        for stage in ("test_stage_1", "test_stage_2", "test_stage_3"):
            futures.append(pool.submit_task(stage, _matrix_multiplication_task, matrix_size))

    pool.shutdown()

    # Every submitted future must be completed by the time shutdown returns.
    assert len(futures) == 3 * rounds
    for future in futures:
        assert future.done()

0 comments on commit f7b7280

Please sign in to comment.