From ccb3d4cc41db79b84df5ef0bbb70171b53cacb97 Mon Sep 17 00:00:00 2001 From: Kirill Suvorov Date: Fri, 1 Dec 2023 17:09:49 +0000 Subject: [PATCH 1/6] FIX-#407: Make the data put into unidist immutable. Signed-off-by: Kirill Suvorov --- .../core/backends/mpi/core/controller/api.py | 2 +- .../backends/mpi/core/controller/common.py | 25 +++-- .../backends/mpi/core/local_object_store.py | 35 ++++++- .../backends/mpi/core/shared_object_store.py | 98 ++++++++++--------- .../backends/mpi/core/worker/task_store.py | 2 +- unidist/test/test_task.py | 22 ++++- 6 files changed, 121 insertions(+), 63 deletions(-) diff --git a/unidist/core/backends/mpi/core/controller/api.py b/unidist/core/backends/mpi/core/controller/api.py index e573aafd..caee566f 100644 --- a/unidist/core/backends/mpi/core/controller/api.py +++ b/unidist/core/backends/mpi/core/controller/api.py @@ -386,7 +386,7 @@ def put(data): data_id = local_store.generate_data_id(garbage_collector) serialized_data = serialize_complex_data(data) - local_store.put(data_id, data) + # data is prepared for sending to another process, but is not saved to local storage if shared_store.is_allocated(): shared_store.put(data_id, serialized_data) else: diff --git a/unidist/core/backends/mpi/core/controller/common.py b/unidist/core/backends/mpi/core/controller/common.py index 0646fd86..b1fe1fe3 100644 --- a/unidist/core/backends/mpi/core/controller/common.py +++ b/unidist/core/backends/mpi/core/controller/common.py @@ -394,22 +394,19 @@ def push_data(dest_rank, value, is_blocking_op=False): data_id = value if shared_store.contains(data_id): _push_shared_data(dest_rank, data_id, is_blocking_op) + elif local_store.is_already_serialized(data_id): + _push_local_data(dest_rank, data_id, is_blocking_op, is_serialized=True) elif local_store.contains(data_id): - if local_store.is_already_serialized(data_id): - _push_local_data(dest_rank, data_id, is_blocking_op, is_serialized=True) + data = local_store.get(data_id) + serialized_data = 
serialize_complex_data(data) + if shared_store.is_allocated() and shared_store.should_be_shared( + serialized_data + ): + shared_store.put(data_id, serialized_data) + _push_shared_data(dest_rank, data_id, is_blocking_op) else: - data = local_store.get(data_id) - serialized_data = serialize_complex_data(data) - if shared_store.is_allocated() and shared_store.should_be_shared( - serialized_data - ): - shared_store.put(data_id, serialized_data) - _push_shared_data(dest_rank, data_id, is_blocking_op) - else: - local_store.cache_serialized_data(data_id, serialized_data) - _push_local_data( - dest_rank, data_id, is_blocking_op, is_serialized=True - ) + local_store.cache_serialized_data(data_id, serialized_data) + _push_local_data(dest_rank, data_id, is_blocking_op, is_serialized=True) elif local_store.contains_data_owner(data_id): _push_data_owner(dest_rank, data_id) else: diff --git a/unidist/core/backends/mpi/core/local_object_store.py b/unidist/core/backends/mpi/core/local_object_store.py index bc80be5c..64fa50b4 100644 --- a/unidist/core/backends/mpi/core/local_object_store.py +++ b/unidist/core/backends/mpi/core/local_object_store.py @@ -8,6 +8,8 @@ import unidist.core.backends.mpi.core.common as common import unidist.core.backends.mpi.core.communication as communication +from unidist.core.backends.mpi.core.serialization import deserialize_complex_data +from unidist.core.backends.mpi.core.shared_object_store import SharedObjectStore class LocalObjectStore: @@ -112,7 +114,26 @@ def get(self, data_id): object Return local data associated with `data_id`. 
""" - return self._data_map[data_id] + if data_id in self._data_map: + return self._data_map[data_id] + # data can be prepared for sending to another process, but is not saved to local storage + else: + shared_object_store = SharedObjectStore.get_instance() + if self.is_already_serialized(data_id): + serialized_data = self.get_serialized_data(data_id) + value = deserialize_complex_data( + serialized_data["s_data"], + serialized_data["raw_buffers"], + serialized_data["buffer_count"], + ) + elif shared_object_store.contains(data_id): + value = shared_object_store.get(data_id) + else: + raise ValueError( + "The current data ID is not contained in the LocalObjectStore." + ) + self.put(data_id, value) + return value def get_data_owner(self, data_id): """ @@ -144,7 +165,12 @@ def contains(self, data_id): bool Return the status if an object exist in local dictionary. """ - return data_id in self._data_map + shared_object_store = SharedObjectStore.get_instance() + return ( + data_id in self._data_map + or self.is_already_serialized(data_id) + or shared_object_store.contains(data_id) + ) def contains_data_owner(self, data_id): """ @@ -277,6 +303,11 @@ def cache_serialized_data(self, data_id, data): data : object Serialized data to cache. """ + # Copying is necessary to avoid corruption of data obtained through out-of-band serialization, + # and buffers are marked read-only to prevent them from being modified. 
+ data["raw_buffers"] = [ + memoryview(buf.tobytes()).toreadonly() for buf in data["raw_buffers"] + ] self._serialization_cache[data_id] = data self.maybe_update_data_id_map(data_id) diff --git a/unidist/core/backends/mpi/core/shared_object_store.py b/unidist/core/backends/mpi/core/shared_object_store.py index c3f76405..d898e900 100644 --- a/unidist/core/backends/mpi/core/shared_object_store.py +++ b/unidist/core/backends/mpi/core/shared_object_store.py @@ -747,7 +747,7 @@ def put(self, data_id, serialized_data): # put shared info self._put_shared_info(data_id, shared_info) - def get(self, data_id, owner_rank, shared_info): + def get(self, data_id, owner_rank=None, shared_info=None): """ Get data from another worker using shared memory. @@ -755,54 +755,64 @@ def get(self, data_id, owner_rank, shared_info): ---------- data_id : unidist.core.backends.mpi.core.common.MpiDataID An ID to data. - owner_rank : int + owner_rank : int, default: None The rank that sent the data. - shared_info : dict - The necessary information to properly deserialize data from shared memory. + This value is used to synchronize data in shared memory between different hosts if the value is defined. + shared_info : dict, default: None + The necessary information to properly deserialize data from shared memory. If `shared_info` is None, the data already exists in shared memory. 
""" - mpi_state = communication.MPIState.get_instance() - s_data_len = shared_info["s_data_len"] - raw_buffers_len = shared_info["raw_buffers_len"] - service_index = shared_info["service_index"] - buffer_count = shared_info["buffer_count"] - - # check data in shared memory - if not self._check_service_info(data_id, service_index): - # reserve shared memory - shared_data_len = s_data_len + sum([buf for buf in raw_buffers_len]) - reservation_info = communication.send_reserve_operation( - mpi_state.global_comm, data_id, shared_data_len - ) - - service_index = reservation_info["service_index"] - # check if worker should sync shared buffer or it is doing by another worker - if reservation_info["is_first_request"]: - # syncronize shared buffer - self._sync_shared_memory_from_another_host( - mpi_state.global_comm, - data_id, - owner_rank, - reservation_info["first_index"], - reservation_info["last_index"], - service_index, - ) - # put service info - self._put_service_info( - service_index, data_id, reservation_info["first_index"] + if shared_info is None: + shared_info = self.get_shared_info(data_id) + else: + mpi_state = communication.MPIState.get_instance() + s_data_len = shared_info["s_data_len"] + raw_buffers_len = shared_info["raw_buffers_len"] + service_index = shared_info["service_index"] + buffer_count = shared_info["buffer_count"] + + # check data in shared memory + if not self._check_service_info(data_id, service_index): + # reserve shared memory + shared_data_len = s_data_len + sum([buf for buf in raw_buffers_len]) + reservation_info = communication.send_reserve_operation( + mpi_state.global_comm, data_id, shared_data_len ) - else: - # wait while another worker syncronize shared buffer - while not self._check_service_info(data_id, service_index): - time.sleep(MpiBackoff.get()) - # put shared info with updated data_id and service_index - shared_info = common.MetadataPackage.get_shared_info( - data_id, s_data_len, raw_buffers_len, buffer_count, service_index - ) 
- self._put_shared_info(data_id, shared_info) + service_index = reservation_info["service_index"] + # check if worker should sync shared buffer or it is doing by another worker + if reservation_info["is_first_request"]: + # syncronize shared buffer + if owner_rank is None: + raise ValueError( + "The data is not in the host's shared memory and the data must be synchronized, " + + "but the owner rank is not defined." + ) + + self._sync_shared_memory_from_another_host( + mpi_state.global_comm, + data_id, + owner_rank, + reservation_info["first_index"], + reservation_info["last_index"], + service_index, + ) + # put service info + self._put_service_info( + service_index, data_id, reservation_info["first_index"] + ) + else: + # wait while another worker syncronize shared buffer + while not self._check_service_info(data_id, service_index): + time.sleep(MpiBackoff.get()) + + # put shared info with updated data_id and service_index + shared_info = common.MetadataPackage.get_shared_info( + data_id, s_data_len, raw_buffers_len, buffer_count, service_index + ) + self._put_shared_info(data_id, shared_info) - # increment ref - self._increment_ref_number(data_id, shared_info["service_index"]) + # increment ref + self._increment_ref_number(data_id, shared_info["service_index"]) # read from shared buffer and deserialized return self._read_from_shared_buffer(data_id, shared_info) diff --git a/unidist/core/backends/mpi/core/worker/task_store.py b/unidist/core/backends/mpi/core/worker/task_store.py index 4e8b7bc1..ea1b8efd 100644 --- a/unidist/core/backends/mpi/core/worker/task_store.py +++ b/unidist/core/backends/mpi/core/worker/task_store.py @@ -188,7 +188,7 @@ def unwrap_local_data_id(self, arg): if is_data_id(arg): local_store = LocalObjectStore.get_instance() if local_store.contains(arg): - value = LocalObjectStore.get_instance().get(arg) + value = local_store.get(arg) # Data is already local or was pushed from master return value, False elif local_store.contains_data_owner(arg): 
diff --git a/unidist/test/test_task.py b/unidist/test/test_task.py index c53b8f11..02df73cc 100644 --- a/unidist/test/test_task.py +++ b/unidist/test/test_task.py @@ -6,7 +6,7 @@ import pytest import unidist -from unidist.config import Backend +from unidist.config import Backend, MpiSharedObjectStore from unidist.core.base.common import BackendName from .utils import ( assert_equal, @@ -122,3 +122,23 @@ def f(params): } assert_equal(f.remote(data), data) + + +@pytest.mark.xfail( + Backend.get() == BackendName.PYMP + or Backend.get() == BackendName.PYSEQ + or Backend.get() == BackendName.MPI + and not MpiSharedObjectStore.get(), + reason="PYMP, PYSEQ and MPI (disabled shared object store) do not copy data into the object store", +) +def test_data_immutability(): + data = [1, 2, 3] + object_ref = unidist.put(data) + + data[0] = 111 + try: + assert_equal(object_ref, data) + except AssertionError: + pass + else: + assert False From 42dc6dabb9dec50a5f80ae3f729f3d309406c526 Mon Sep 17 00:00:00 2001 From: Kirill Suvorov Date: Tue, 5 Dec 2023 15:51:08 +0000 Subject: [PATCH 2/6] Test update --- unidist/test/test_task.py | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/unidist/test/test_task.py b/unidist/test/test_task.py index 02df73cc..1cb346d8 100644 --- a/unidist/test/test_task.py +++ b/unidist/test/test_task.py @@ -6,7 +6,7 @@ import pytest import unidist -from unidist.config import Backend, MpiSharedObjectStore +from unidist.config import Backend from unidist.core.base.common import BackendName from .utils import ( assert_equal, @@ -125,20 +125,12 @@ def f(params): @pytest.mark.xfail( - Backend.get() == BackendName.PYMP - or Backend.get() == BackendName.PYSEQ - or Backend.get() == BackendName.MPI - and not MpiSharedObjectStore.get(), - reason="PYMP, PYSEQ and MPI (disabled shared object store) do not copy data into the object store", + Backend.get() == BackendName.PYSEQ, + reason="PUT using PYSEQ does not provide immutable data", ) 
def test_data_immutability(): data = [1, 2, 3] object_ref = unidist.put(data) data[0] = 111 - try: - assert_equal(object_ref, data) - except AssertionError: - pass - else: - assert False + assert_equal(object_ref, [1, 2, 3]) From fb7b2ab7e21359f75518514eb55d535f30e95e77 Mon Sep 17 00:00:00 2001 From: Kirill Suvorov Date: Wed, 13 Dec 2023 06:40:35 -0700 Subject: [PATCH 3/6] Separate local storage from shared storage --- .../core/backends/mpi/core/controller/api.py | 14 ++-- .../backends/mpi/core/controller/common.py | 67 ++++++++++++++++++- .../backends/mpi/core/local_object_store.py | 34 +--------- unidist/core/backends/mpi/core/worker/loop.py | 4 +- .../backends/mpi/core/worker/request_store.py | 8 +-- .../backends/mpi/core/worker/task_store.py | 8 +-- 6 files changed, 81 insertions(+), 54 deletions(-) diff --git a/unidist/core/backends/mpi/core/controller/api.py b/unidist/core/backends/mpi/core/controller/api.py index caee566f..655dcb80 100644 --- a/unidist/core/backends/mpi/core/controller/api.py +++ b/unidist/core/backends/mpi/core/controller/api.py @@ -28,6 +28,8 @@ request_worker_data, push_data, RoundRobin, + get_data, + contains_data, ) import unidist.core.backends.mpi.core.common as common import unidist.core.backends.mpi.core.communication as communication @@ -411,21 +413,17 @@ def get(data_ids): object A Python object. 
""" - local_store = LocalObjectStore.get_instance() - is_list = isinstance(data_ids, list) if not is_list: data_ids = [data_ids] - remote_data_ids = [ - data_id for data_id in data_ids if not local_store.contains(data_id) - ] + remote_data_ids = [data_id for data_id in data_ids if not contains_data(data_id)] # Remote data gets available in the local store inside `request_worker_data` if remote_data_ids: request_worker_data(remote_data_ids) logger.debug("GET {} ids".format(common.unwrapped_data_ids_list(data_ids))) - values = [local_store.get(data_id) for data_id in data_ids] + values = [get_data(data_id) for data_id in data_ids] # Initiate reference count based cleaup # if all the tasks were completed @@ -463,11 +461,9 @@ def wait(data_ids, num_returns=1): not_ready = data_ids pending_returns = num_returns ready = [] - local_store = LocalObjectStore.get_instance() - logger.debug("WAIT {} ids".format(common.unwrapped_data_ids_list(data_ids))) for data_id in not_ready.copy(): - if local_store.contains(data_id): + if contains_data(data_id): ready.append(data_id) not_ready.remove(data_id) pending_returns -= 1 diff --git a/unidist/core/backends/mpi/core/controller/common.py b/unidist/core/backends/mpi/core/controller/common.py index b1fe1fe3..ef4980e2 100644 --- a/unidist/core/backends/mpi/core/controller/common.py +++ b/unidist/core/backends/mpi/core/controller/common.py @@ -12,7 +12,10 @@ from unidist.core.backends.mpi.core.async_operations import AsyncOperations from unidist.core.backends.mpi.core.local_object_store import LocalObjectStore from unidist.core.backends.mpi.core.shared_object_store import SharedObjectStore -from unidist.core.backends.mpi.core.serialization import serialize_complex_data +from unidist.core.backends.mpi.core.serialization import ( + deserialize_complex_data, + serialize_complex_data, +) logger = common.get_logger("common", "common.log") @@ -138,10 +141,10 @@ def pull_data(comm, owner_rank=None): shared_store = 
SharedObjectStore.get_instance() data_id = info_package["id"] - if local_store.contains(data_id): + if contains_data(data_id): return { "id": data_id, - "data": local_store.get(data_id), + "data": get_data(data_id), } data = shared_store.get(data_id, owner_rank, info_package) @@ -411,3 +414,61 @@ def push_data(dest_rank, value, is_blocking_op=False): _push_data_owner(dest_rank, data_id) else: raise ValueError("Unknown DataID!") + + +def contains_data(data_id): + """ + Check if the data associated with `data_id` exists in the current process. + + Parameters + ---------- + data_id : unidist.core.backends.mpi.core.common.MpiDataID + An ID to data. + + Returns + ------- + bool + Return the status if an object exist in the current process. + """ + local_store = LocalObjectStore.get_instance() + shared_store = SharedObjectStore.get_instance() + return ( + local_store.contains(data_id) + or local_store.is_already_serialized(data_id) + or shared_store.contains(data_id) + ) + + +def get_data(data_id): + """ + Get data from any location in the current process. + + Parameters + ---------- + data_id : unidist.core.backends.mpi.core.common.MpiDataID + An ID to data. + + Returns + ------- + object + Return data associated with `data_id`. 
+ """ + local_store = LocalObjectStore.get_instance() + shared_store = SharedObjectStore.get_instance() + + if local_store.contains(data_id): + return local_store.get(data_id) + + if local_store.is_already_serialized(data_id): + serialized_data = local_store.get_serialized_data(data_id) + value = deserialize_complex_data( + serialized_data["s_data"], + serialized_data["raw_buffers"], + serialized_data["buffer_count"], + ) + elif shared_store.contains(data_id): + value = shared_store.get(data_id) + else: + raise ValueError("The current data ID is not contained in the procces.") + local_store.put(data_id, value) + return value diff --git a/unidist/core/backends/mpi/core/local_object_store.py b/unidist/core/backends/mpi/core/local_object_store.py index 64fa50b4..64428047 100644 --- a/unidist/core/backends/mpi/core/local_object_store.py +++ b/unidist/core/backends/mpi/core/local_object_store.py @@ -8,8 +8,6 @@ import unidist.core.backends.mpi.core.common as common import unidist.core.backends.mpi.core.communication as communication -from unidist.core.backends.mpi.core.serialization import deserialize_complex_data -from unidist.core.backends.mpi.core.shared_object_store import SharedObjectStore class LocalObjectStore: @@ -114,26 +112,7 @@ def get(self, data_id): object Return local data associated with `data_id`. """ - if data_id in self._data_map: - return self._data_map[data_id] - # data can be prepared for sending to another process, but is not saved to local storage - else: - shared_object_store = SharedObjectStore.get_instance() - if self.is_already_serialized(data_id): - serialized_data = self.get_serialized_data(data_id) - value = deserialize_complex_data( - serialized_data["s_data"], - serialized_data["raw_buffers"], - serialized_data["buffer_count"], - ) - elif shared_object_store.contains(data_id): - value = shared_object_store.get(data_id) - else: - raise ValueError( - "The current data ID is not contained in the LocalObjectStore." 
- ) - self.put(data_id, value) - return value + return self._data_map[data_id] def get_data_owner(self, data_id): """ @@ -165,12 +144,7 @@ def contains(self, data_id): bool Return the status if an object exist in local dictionary. """ - shared_object_store = SharedObjectStore.get_instance() - return ( - data_id in self._data_map - or self.is_already_serialized(data_id) - or shared_object_store.contains(data_id) - ) + return data_id in self._data_map def contains_data_owner(self, data_id): """ @@ -305,9 +279,7 @@ def cache_serialized_data(self, data_id, data): """ # Copying is necessary to avoid corruption of data obtained through out-of-band serialization, # and buffers are marked read-only to prevent them from being modified. - data["raw_buffers"] = [ - memoryview(buf.tobytes()).toreadonly() for buf in data["raw_buffers"] - ] + data["raw_buffers"] = [buf.tobytes() for buf in data["raw_buffers"]] self._serialization_cache[data_id] = data self.maybe_update_data_id_map(data_id) diff --git a/unidist/core/backends/mpi/core/worker/loop.py b/unidist/core/backends/mpi/core/worker/loop.py index 32d6eda7..12d3a0d6 100644 --- a/unidist/core/backends/mpi/core/worker/loop.py +++ b/unidist/core/backends/mpi/core/worker/loop.py @@ -20,7 +20,7 @@ from unidist.core.backends.mpi.core.worker.request_store import RequestStore from unidist.core.backends.mpi.core.worker.task_store import TaskStore from unidist.core.backends.mpi.core.async_operations import AsyncOperations -from unidist.core.backends.mpi.core.controller.common import pull_data +from unidist.core.backends.mpi.core.controller.common import pull_data, get_data from unidist.core.backends.mpi.core.shared_object_store import SharedObjectStore # TODO: Find a way to move this after all imports @@ -185,7 +185,7 @@ async def worker_loop(): if not ready_to_shutdown_posted: # Prepare the data # Actor method here is a data id so we have to retrieve it from the storage - method_name = local_store.get(request["task"]) + method_name = 
get_data(request["task"]) handler = request["handler"] actor_method = getattr(actor_map[handler], method_name) request["task"] = actor_method diff --git a/unidist/core/backends/mpi/core/worker/request_store.py b/unidist/core/backends/mpi/core/worker/request_store.py index 47f2dc9c..df8450cb 100644 --- a/unidist/core/backends/mpi/core/worker/request_store.py +++ b/unidist/core/backends/mpi/core/worker/request_store.py @@ -6,8 +6,7 @@ import unidist.core.backends.mpi.core.common as common import unidist.core.backends.mpi.core.communication as communication -from unidist.core.backends.mpi.core.local_object_store import LocalObjectStore -from unidist.core.backends.mpi.core.controller.common import push_data +from unidist.core.backends.mpi.core.controller.common import push_data, contains_data mpi_state = communication.MPIState.get_instance() @@ -213,7 +212,7 @@ def process_wait_request(self, data_id): ----- Only ROOT rank is supported for now, therefore no rank argument needed. """ - if LocalObjectStore.get_instance().contains(data_id): + if contains_data(data_id): # Executor wait just for signal # We use a blocking send here because the receiver is waiting for the result. communication.mpi_send_object( @@ -247,8 +246,7 @@ def process_get_request(self, source_rank, data_id, is_blocking_op=False): ----- Request is asynchronous, no wait for the data sending. 
""" - local_store = LocalObjectStore.get_instance() - if local_store.contains(data_id): + if contains_data(data_id): push_data( source_rank, data_id, diff --git a/unidist/core/backends/mpi/core/worker/task_store.py b/unidist/core/backends/mpi/core/worker/task_store.py index ea1b8efd..7231459d 100644 --- a/unidist/core/backends/mpi/core/worker/task_store.py +++ b/unidist/core/backends/mpi/core/worker/task_store.py @@ -10,6 +10,7 @@ from unidist.core.backends.common.data_id import is_data_id import unidist.core.backends.mpi.core.common as common import unidist.core.backends.mpi.core.communication as communication +from unidist.core.backends.mpi.core.controller.common import get_data, contains_data from unidist.core.backends.mpi.core.async_operations import AsyncOperations from unidist.core.backends.mpi.core.local_object_store import LocalObjectStore from unidist.core.backends.mpi.core.shared_object_store import SharedObjectStore @@ -187,8 +188,8 @@ def unwrap_local_data_id(self, arg): """ if is_data_id(arg): local_store = LocalObjectStore.get_instance() - if local_store.contains(arg): - value = local_store.get(arg) + if contains_data(arg): + value = get_data(arg) # Data is already local or was pushed from master return value, False elif local_store.contains_data_owner(arg): @@ -418,12 +419,11 @@ def process_task_request(self, request): Same request if the task couldn`t be executed, otherwise ``None``. """ # Parse request - local_store = LocalObjectStore.get_instance() task = request["task"] # Remote function here is a data id so we have to retrieve it from the storage, # whereas actor method is already materialized in the worker loop. 
if is_data_id(task): - task = local_store.get(task) + task = get_data(task) args = request["args"] kwargs = request["kwargs"] output_ids = request["output"] From aa9928fd8e8755bbd6812cacbd425cd7406926fc Mon Sep 17 00:00:00 2001 From: Kirill Suvorov Date: Wed, 13 Dec 2023 07:41:03 -0700 Subject: [PATCH 4/6] Add class ObjectStore --- .../core/backends/mpi/core/controller/api.py | 14 ++-- .../backends/mpi/core/controller/common.py | 63 +------------- .../backends/mpi/core/local_object_store.py | 3 +- .../core/backends/mpi/core/object_store.py | 84 +++++++++++++++++++ .../backends/mpi/core/shared_object_store.py | 6 +- unidist/core/backends/mpi/core/worker/loop.py | 6 +- .../backends/mpi/core/worker/request_store.py | 9 +- .../backends/mpi/core/worker/task_store.py | 10 ++- 8 files changed, 116 insertions(+), 79 deletions(-) create mode 100644 unidist/core/backends/mpi/core/object_store.py diff --git a/unidist/core/backends/mpi/core/controller/api.py b/unidist/core/backends/mpi/core/controller/api.py index 655dcb80..8f13546a 100644 --- a/unidist/core/backends/mpi/core/controller/api.py +++ b/unidist/core/backends/mpi/core/controller/api.py @@ -19,6 +19,7 @@ ) from None from unidist.core.backends.mpi.core.serialization import serialize_complex_data +from unidist.core.backends.mpi.core.object_store import ObjectStore from unidist.core.backends.mpi.core.shared_object_store import SharedObjectStore from unidist.core.backends.mpi.core.local_object_store import LocalObjectStore from unidist.core.backends.mpi.core.controller.garbage_collector import ( @@ -28,8 +29,6 @@ request_worker_data, push_data, RoundRobin, - get_data, - contains_data, ) import unidist.core.backends.mpi.core.common as common import unidist.core.backends.mpi.core.communication as communication @@ -388,7 +387,6 @@ def put(data): data_id = local_store.generate_data_id(garbage_collector) serialized_data = serialize_complex_data(data) - # data is prepared for sending to another process, but is not saved to 
local storage if shared_store.is_allocated(): shared_store.put(data_id, serialized_data) else: @@ -413,17 +411,20 @@ def get(data_ids): object A Python object. """ + object_store = ObjectStore.get_instance() is_list = isinstance(data_ids, list) if not is_list: data_ids = [data_ids] - remote_data_ids = [data_id for data_id in data_ids if not contains_data(data_id)] + remote_data_ids = [ + data_id for data_id in data_ids if not object_store.contains_data(data_id) + ] # Remote data gets available in the local store inside `request_worker_data` if remote_data_ids: request_worker_data(remote_data_ids) logger.debug("GET {} ids".format(common.unwrapped_data_ids_list(data_ids))) - values = [get_data(data_id) for data_id in data_ids] + values = [object_store.get_data(data_id) for data_id in data_ids] # Initiate reference count based cleaup # if all the tasks were completed @@ -452,6 +453,7 @@ def wait(data_ids, num_returns=1): tuple List of data IDs that are ready and list of the remaining data IDs. 
""" + object_store = ObjectStore.get_instance() if not isinstance(data_ids, list): data_ids = [data_ids] # Since the controller should operate MpiDataID(s), @@ -463,7 +465,7 @@ def wait(data_ids, num_returns=1): ready = [] logger.debug("WAIT {} ids".format(common.unwrapped_data_ids_list(data_ids))) for data_id in not_ready.copy(): - if contains_data(data_id): + if object_store.contains_data(data_id): ready.append(data_id) not_ready.remove(data_id) pending_returns -= 1 diff --git a/unidist/core/backends/mpi/core/controller/common.py b/unidist/core/backends/mpi/core/controller/common.py index ef4980e2..6d6f3689 100644 --- a/unidist/core/backends/mpi/core/controller/common.py +++ b/unidist/core/backends/mpi/core/controller/common.py @@ -13,7 +13,6 @@ from unidist.core.backends.mpi.core.local_object_store import LocalObjectStore from unidist.core.backends.mpi.core.shared_object_store import SharedObjectStore from unidist.core.backends.mpi.core.serialization import ( - deserialize_complex_data, serialize_complex_data, ) @@ -141,10 +140,10 @@ def pull_data(comm, owner_rank=None): shared_store = SharedObjectStore.get_instance() data_id = info_package["id"] - if contains_data(data_id): + if local_store.contains(data_id): return { "id": data_id, - "data": get_data(data_id), + "data": local_store.get(data_id), } data = shared_store.get(data_id, owner_rank, info_package) @@ -414,61 +413,3 @@ def push_data(dest_rank, value, is_blocking_op=False): _push_data_owner(dest_rank, data_id) else: raise ValueError("Unknown DataID!") - - -def contains_data(data_id): - """ - Check if the data associated with `data_id` exists in the current process. - - Parameters - ---------- - data_id : unidist.core.backends.mpi.core.common.MpiDataID - An ID to data. - - Returns - ------- - bool - Return the status if an object exist in the current process. 
- """ - local_store = LocalObjectStore.get_instance() - shared_store = SharedObjectStore.get_instance() - return ( - local_store.contains(data_id) - or local_store.is_already_serialized(data_id) - or shared_store.contains(data_id) - ) - - -def get_data(data_id): - """ - Get data from any location in the current process. - - Parameters - ---------- - data_id : unidist.core.backends.mpi.core.common.MpiDataID - An ID to data. - - Returns - ------- - object - Return data associated with `data_id`. - """ - local_store = LocalObjectStore.get_instance() - shared_store = SharedObjectStore.get_instance() - - if local_store.contains(data_id): - return local_store.get(data_id) - - if local_store.is_already_serialized(data_id): - serialized_data = local_store.get_serialized_data(data_id) - value = deserialize_complex_data( - serialized_data["s_data"], - serialized_data["raw_buffers"], - serialized_data["buffer_count"], - ) - elif shared_store.contains(data_id): - value = shared_store.get(data_id) - else: - raise ValueError("The current data ID is not contained in the procces.") - local_store.put(data_id, value) - return value diff --git a/unidist/core/backends/mpi/core/local_object_store.py b/unidist/core/backends/mpi/core/local_object_store.py index 64428047..2fbaf404 100644 --- a/unidist/core/backends/mpi/core/local_object_store.py +++ b/unidist/core/backends/mpi/core/local_object_store.py @@ -277,8 +277,9 @@ def cache_serialized_data(self, data_id, data): data : object Serialized data to cache. """ - # Copying is necessary to avoid corruption of data obtained through out-of-band serialization, + # We make a copy to avoid data corruption obtained through out-of-band serialization, # and buffers are marked read-only to prevent them from being modified. + # `to_bytes()` call handles both points. 
data["raw_buffers"] = [buf.tobytes() for buf in data["raw_buffers"]] self._serialization_cache[data_id] = data self.maybe_update_data_id_map(data_id) diff --git a/unidist/core/backends/mpi/core/object_store.py b/unidist/core/backends/mpi/core/object_store.py new file mode 100644 index 00000000..c938a69d --- /dev/null +++ b/unidist/core/backends/mpi/core/object_store.py @@ -0,0 +1,84 @@ +from unidist.core.backends.mpi.core.local_object_store import LocalObjectStore +from unidist.core.backends.mpi.core.serialization import deserialize_complex_data +from unidist.core.backends.mpi.core.shared_object_store import SharedObjectStore + + +class ObjectStore: + """ + Class that combines checking and reciving data from all stores in a current process. + + Notes + ----- + The store checks for both deserialized and serialized data. + """ + + __instance = None + + @classmethod + def get_instance(cls): + """ + Get instance of ``ObjectStore``. + + Returns + ------- + ObjectStore + """ + if cls.__instance is None: + cls.__instance = ObjectStore() + return cls.__instance + + def contains_data(self, data_id): + """ + Check if the data associated with `data_id` exists in the current process. + + Parameters + ---------- + data_id : unidist.core.backends.mpi.core.common.MpiDataID + An ID to data. + + Returns + ------- + bool + Return the status if an object exist in the current process. + """ + local_store = LocalObjectStore.get_instance() + shared_store = SharedObjectStore.get_instance() + return ( + local_store.contains(data_id) + or local_store.is_already_serialized(data_id) + or shared_store.contains(data_id) + ) + + def get_data(self, data_id): + """ + Get data from any location in the current process. + + Parameters + ---------- + data_id : unidist.core.backends.mpi.core.common.MpiDataID + An ID to data. + + Returns + ------- + object + Return data associated with `data_id`. 
+        """
+        local_store = LocalObjectStore.get_instance()
+        shared_store = SharedObjectStore.get_instance()
+
+        if local_store.contains(data_id):
+            return local_store.get(data_id)
+
+        if local_store.is_already_serialized(data_id):
+            serialized_data = local_store.get_serialized_data(data_id)
+            value = deserialize_complex_data(
+                serialized_data["s_data"],
+                serialized_data["raw_buffers"],
+                serialized_data["buffer_count"],
+            )
+        elif shared_store.contains(data_id):
+            value = shared_store.get(data_id)
+        else:
+            raise ValueError("The current data ID is not contained in the process.")
+        local_store.put(data_id, value)
+        return value
diff --git a/unidist/core/backends/mpi/core/shared_object_store.py b/unidist/core/backends/mpi/core/shared_object_store.py
index d898e900..85050081 100644
--- a/unidist/core/backends/mpi/core/shared_object_store.py
+++ b/unidist/core/backends/mpi/core/shared_object_store.py
@@ -757,9 +757,11 @@ def get(self, data_id, owner_rank=None, shared_info=None):
             An ID to data.
         owner_rank : int, default: None
             The rank that sent the data.
-            This value is used to synchronize data in shared memory between different hosts if the value is defined.
+            This value is used to synchronize data in shared memory between different hosts
+            if the value is not ``None``.
         shared_info : dict, default: None
-            The necessary information to properly deserialize data from shared memory. If `shared_info` is None, the data already exists in shared memory.
+            The necessary information to properly deserialize data from shared memory.
+            If `shared_info` is ``None``, the data already exists in shared memory in the current process.
""" if shared_info is None: shared_info = self.get_shared_info(data_id) diff --git a/unidist/core/backends/mpi/core/worker/loop.py b/unidist/core/backends/mpi/core/worker/loop.py index 12d3a0d6..f90cdc24 100644 --- a/unidist/core/backends/mpi/core/worker/loop.py +++ b/unidist/core/backends/mpi/core/worker/loop.py @@ -16,11 +16,12 @@ import unidist.core.backends.mpi.core.common as common import unidist.core.backends.mpi.core.communication as communication +from unidist.core.backends.mpi.core.object_store import ObjectStore from unidist.core.backends.mpi.core.local_object_store import LocalObjectStore from unidist.core.backends.mpi.core.worker.request_store import RequestStore from unidist.core.backends.mpi.core.worker.task_store import TaskStore from unidist.core.backends.mpi.core.async_operations import AsyncOperations -from unidist.core.backends.mpi.core.controller.common import pull_data, get_data +from unidist.core.backends.mpi.core.controller.common import pull_data from unidist.core.backends.mpi.core.shared_object_store import SharedObjectStore # TODO: Find a way to move this after all imports @@ -85,6 +86,7 @@ async def worker_loop(): ``unidist.core.backends.mpi.core.common.Operations`` defines a set of supported operations. 
""" task_store = TaskStore.get_instance() + object_store = ObjectStore.get_instance() local_store = LocalObjectStore.get_instance() request_store = RequestStore.get_instance() async_operations = AsyncOperations.get_instance() @@ -185,7 +187,7 @@ async def worker_loop(): if not ready_to_shutdown_posted: # Prepare the data # Actor method here is a data id so we have to retrieve it from the storage - method_name = get_data(request["task"]) + method_name = object_store.get_data(request["task"]) handler = request["handler"] actor_method = getattr(actor_map[handler], method_name) request["task"] = actor_method diff --git a/unidist/core/backends/mpi/core/worker/request_store.py b/unidist/core/backends/mpi/core/worker/request_store.py index df8450cb..b44b61a5 100644 --- a/unidist/core/backends/mpi/core/worker/request_store.py +++ b/unidist/core/backends/mpi/core/worker/request_store.py @@ -6,7 +6,8 @@ import unidist.core.backends.mpi.core.common as common import unidist.core.backends.mpi.core.communication as communication -from unidist.core.backends.mpi.core.controller.common import push_data, contains_data +from unidist.core.backends.mpi.core.controller.common import push_data +from unidist.core.backends.mpi.core.object_store import ObjectStore mpi_state = communication.MPIState.get_instance() @@ -212,7 +213,8 @@ def process_wait_request(self, data_id): ----- Only ROOT rank is supported for now, therefore no rank argument needed. """ - if contains_data(data_id): + object_store = ObjectStore.get_instance() + if object_store.contains_data(data_id): # Executor wait just for signal # We use a blocking send here because the receiver is waiting for the result. communication.mpi_send_object( @@ -246,7 +248,8 @@ def process_get_request(self, source_rank, data_id, is_blocking_op=False): ----- Request is asynchronous, no wait for the data sending. 
""" - if contains_data(data_id): + object_store = ObjectStore.get_instance() + if object_store.contains_data(data_id): push_data( source_rank, data_id, diff --git a/unidist/core/backends/mpi/core/worker/task_store.py b/unidist/core/backends/mpi/core/worker/task_store.py index 7231459d..afb13bfb 100644 --- a/unidist/core/backends/mpi/core/worker/task_store.py +++ b/unidist/core/backends/mpi/core/worker/task_store.py @@ -10,8 +10,8 @@ from unidist.core.backends.common.data_id import is_data_id import unidist.core.backends.mpi.core.common as common import unidist.core.backends.mpi.core.communication as communication -from unidist.core.backends.mpi.core.controller.common import get_data, contains_data from unidist.core.backends.mpi.core.async_operations import AsyncOperations +from unidist.core.backends.mpi.core.object_store import ObjectStore from unidist.core.backends.mpi.core.local_object_store import LocalObjectStore from unidist.core.backends.mpi.core.shared_object_store import SharedObjectStore from unidist.core.backends.mpi.core.serialization import serialize_complex_data @@ -188,8 +188,9 @@ def unwrap_local_data_id(self, arg): """ if is_data_id(arg): local_store = LocalObjectStore.get_instance() - if contains_data(arg): - value = get_data(arg) + object_store = ObjectStore.get_instance() + if object_store.contains_data(arg): + value = object_store.get_data(arg) # Data is already local or was pushed from master return value, False elif local_store.contains_data_owner(arg): @@ -418,12 +419,13 @@ def process_task_request(self, request): dict or None Same request if the task couldn`t be executed, otherwise ``None``. """ + object_store = ObjectStore.get_instance() # Parse request task = request["task"] # Remote function here is a data id so we have to retrieve it from the storage, # whereas actor method is already materialized in the worker loop. 
if is_data_id(task): - task = get_data(task) + task = object_store.get_data(task) args = request["args"] kwargs = request["kwargs"] output_ids = request["output"] From 302369d33d15d3102d621742407937a054e9fffa Mon Sep 17 00:00:00 2001 From: Kirill Suvorov Date: Wed, 13 Dec 2023 08:18:31 -0700 Subject: [PATCH 5/6] Fixes --- unidist/core/backends/mpi/core/controller/api.py | 6 +++--- unidist/core/backends/mpi/core/object_store.py | 10 ++++++++-- unidist/core/backends/mpi/core/worker/loop.py | 2 +- unidist/core/backends/mpi/core/worker/request_store.py | 4 ++-- unidist/core/backends/mpi/core/worker/task_store.py | 6 +++--- 5 files changed, 17 insertions(+), 11 deletions(-) diff --git a/unidist/core/backends/mpi/core/controller/api.py b/unidist/core/backends/mpi/core/controller/api.py index 8f13546a..23523920 100644 --- a/unidist/core/backends/mpi/core/controller/api.py +++ b/unidist/core/backends/mpi/core/controller/api.py @@ -416,7 +416,7 @@ def get(data_ids): if not is_list: data_ids = [data_ids] remote_data_ids = [ - data_id for data_id in data_ids if not object_store.contains_data(data_id) + data_id for data_id in data_ids if not object_store.contains(data_id) ] # Remote data gets available in the local store inside `request_worker_data` if remote_data_ids: @@ -424,7 +424,7 @@ def get(data_ids): logger.debug("GET {} ids".format(common.unwrapped_data_ids_list(data_ids))) - values = [object_store.get_data(data_id) for data_id in data_ids] + values = [object_store.get(data_id) for data_id in data_ids] # Initiate reference count based cleaup # if all the tasks were completed @@ -465,7 +465,7 @@ def wait(data_ids, num_returns=1): ready = [] logger.debug("WAIT {} ids".format(common.unwrapped_data_ids_list(data_ids))) for data_id in not_ready.copy(): - if object_store.contains_data(data_id): + if object_store.contains(data_id): ready.append(data_id) not_ready.remove(data_id) pending_returns -= 1 diff --git a/unidist/core/backends/mpi/core/object_store.py 
b/unidist/core/backends/mpi/core/object_store.py index c938a69d..135dd7d1 100644 --- a/unidist/core/backends/mpi/core/object_store.py +++ b/unidist/core/backends/mpi/core/object_store.py @@ -1,3 +1,9 @@ +# Copyright (C) 2021-2023 Modin authors +# +# SPDX-License-Identifier: Apache-2.0 + +"""`ObjectStore` functionality.""" + from unidist.core.backends.mpi.core.local_object_store import LocalObjectStore from unidist.core.backends.mpi.core.serialization import deserialize_complex_data from unidist.core.backends.mpi.core.shared_object_store import SharedObjectStore @@ -27,7 +33,7 @@ def get_instance(cls): cls.__instance = ObjectStore() return cls.__instance - def contains_data(self, data_id): + def contains(self, data_id): """ Check if the data associated with `data_id` exists in the current process. @@ -49,7 +55,7 @@ def contains_data(self, data_id): or shared_store.contains(data_id) ) - def get_data(self, data_id): + def get(self, data_id): """ Get data from any location in the current process. 
diff --git a/unidist/core/backends/mpi/core/worker/loop.py b/unidist/core/backends/mpi/core/worker/loop.py index f90cdc24..5515ff56 100644 --- a/unidist/core/backends/mpi/core/worker/loop.py +++ b/unidist/core/backends/mpi/core/worker/loop.py @@ -187,7 +187,7 @@ async def worker_loop(): if not ready_to_shutdown_posted: # Prepare the data # Actor method here is a data id so we have to retrieve it from the storage - method_name = object_store.get_data(request["task"]) + method_name = object_store.get(request["task"]) handler = request["handler"] actor_method = getattr(actor_map[handler], method_name) request["task"] = actor_method diff --git a/unidist/core/backends/mpi/core/worker/request_store.py b/unidist/core/backends/mpi/core/worker/request_store.py index b44b61a5..89703cf7 100644 --- a/unidist/core/backends/mpi/core/worker/request_store.py +++ b/unidist/core/backends/mpi/core/worker/request_store.py @@ -214,7 +214,7 @@ def process_wait_request(self, data_id): Only ROOT rank is supported for now, therefore no rank argument needed. """ object_store = ObjectStore.get_instance() - if object_store.contains_data(data_id): + if object_store.contains(data_id): # Executor wait just for signal # We use a blocking send here because the receiver is waiting for the result. communication.mpi_send_object( @@ -249,7 +249,7 @@ def process_get_request(self, source_rank, data_id, is_blocking_op=False): Request is asynchronous, no wait for the data sending. 
""" object_store = ObjectStore.get_instance() - if object_store.contains_data(data_id): + if object_store.contains(data_id): push_data( source_rank, data_id, diff --git a/unidist/core/backends/mpi/core/worker/task_store.py b/unidist/core/backends/mpi/core/worker/task_store.py index afb13bfb..f9be5c50 100644 --- a/unidist/core/backends/mpi/core/worker/task_store.py +++ b/unidist/core/backends/mpi/core/worker/task_store.py @@ -189,8 +189,8 @@ def unwrap_local_data_id(self, arg): if is_data_id(arg): local_store = LocalObjectStore.get_instance() object_store = ObjectStore.get_instance() - if object_store.contains_data(arg): - value = object_store.get_data(arg) + if object_store.contains(arg): + value = object_store.get(arg) # Data is already local or was pushed from master return value, False elif local_store.contains_data_owner(arg): @@ -425,7 +425,7 @@ def process_task_request(self, request): # Remote function here is a data id so we have to retrieve it from the storage, # whereas actor method is already materialized in the worker loop. if is_data_id(task): - task = object_store.get_data(task) + task = object_store.get(task) args = request["args"] kwargs = request["kwargs"] output_ids = request["output"] From 8d5ae139a5af7abfa95f9c06d106ed54398a9a93 Mon Sep 17 00:00:00 2001 From: Iaroslav Igoshev Date: Wed, 13 Dec 2023 16:24:14 +0100 Subject: [PATCH 6/6] Update unidist/core/backends/mpi/core/object_store.py --- unidist/core/backends/mpi/core/object_store.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unidist/core/backends/mpi/core/object_store.py b/unidist/core/backends/mpi/core/object_store.py index 135dd7d1..9d5b2710 100644 --- a/unidist/core/backends/mpi/core/object_store.py +++ b/unidist/core/backends/mpi/core/object_store.py @@ -11,7 +11,7 @@ class ObjectStore: """ - Class that combines checking and reciving data from all stores in a current process. 
+ Class that combines checking and retrieving data from the shared and local stores in a current process. Notes -----