diff --git a/unidist/core/backends/mpi/core/controller/api.py b/unidist/core/backends/mpi/core/controller/api.py
index e573aafd..23523920 100644
--- a/unidist/core/backends/mpi/core/controller/api.py
+++ b/unidist/core/backends/mpi/core/controller/api.py
@@ -19,6 +19,7 @@
 ) from None
 
 from unidist.core.backends.mpi.core.serialization import serialize_complex_data
+from unidist.core.backends.mpi.core.object_store import ObjectStore
 from unidist.core.backends.mpi.core.shared_object_store import SharedObjectStore
 from unidist.core.backends.mpi.core.local_object_store import LocalObjectStore
 from unidist.core.backends.mpi.core.controller.garbage_collector import (
@@ -386,7 +387,6 @@ def put(data):
     data_id = local_store.generate_data_id(garbage_collector)
     serialized_data = serialize_complex_data(data)
 
-    local_store.put(data_id, data)
     if shared_store.is_allocated():
         shared_store.put(data_id, serialized_data)
     else:
@@ -411,13 +411,12 @@ def get(data_ids):
     object
         A Python object.
     """
-    local_store = LocalObjectStore.get_instance()
-
+    object_store = ObjectStore.get_instance()
     is_list = isinstance(data_ids, list)
     if not is_list:
         data_ids = [data_ids]
     remote_data_ids = [
-        data_id for data_id in data_ids if not local_store.contains(data_id)
+        data_id for data_id in data_ids if not object_store.contains(data_id)
    ]
     # Remote data gets available in the local store inside `request_worker_data`
     if remote_data_ids:
@@ -425,7 +424,7 @@ def get(data_ids):
 
     logger.debug("GET {} ids".format(common.unwrapped_data_ids_list(data_ids)))
 
-    values = [local_store.get(data_id) for data_id in data_ids]
+    values = [object_store.get(data_id) for data_id in data_ids]
 
     # Initiate reference count based cleaup
     # if all the tasks were completed
@@ -454,6 +453,7 @@ def wait(data_ids, num_returns=1):
     tuple
         List of data IDs that are ready and list of the remaining data IDs.
     """
+    object_store = ObjectStore.get_instance()
     if not isinstance(data_ids, list):
         data_ids = [data_ids]
     # Since the controller should operate MpiDataID(s),
@@ -463,11 +463,9 @@ def wait(data_ids, num_returns=1):
     not_ready = data_ids
     pending_returns = num_returns
     ready = []
-    local_store = LocalObjectStore.get_instance()
-
     logger.debug("WAIT {} ids".format(common.unwrapped_data_ids_list(data_ids)))
     for data_id in not_ready.copy():
-        if local_store.contains(data_id):
+        if object_store.contains(data_id):
             ready.append(data_id)
             not_ready.remove(data_id)
             pending_returns -= 1
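
Note: the key behavioral change in the `put` hunk is that the controller no longer caches the live Python object; the data is serialized eagerly and only the serialized form is kept (in shared memory or in the local serialization cache). A minimal controller-side sketch of the resulting round trip, assuming an initialized MPI backend; `unidist.init`, `put`, `get`, and `wait` are the public API, and the backend selection shown in the comment is environment-specific:

    import unidist

    unidist.init()  # assumes the MPI backend is selected, e.g. via UNIDIST_BACKEND=mpi

    data = [1, 2, 3]
    ref = unidist.put(data)   # serialized immediately; the live object is no longer cached
    data[0] = 111             # mutating the original no longer leaks into the store
    assert unidist.get(ref) == [1, 2, 3]
    ready, not_ready = unidist.wait([ref], num_returns=1)
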
""" + object_store = ObjectStore.get_instance() if not isinstance(data_ids, list): data_ids = [data_ids] # Since the controller should operate MpiDataID(s), @@ -463,11 +463,9 @@ def wait(data_ids, num_returns=1): not_ready = data_ids pending_returns = num_returns ready = [] - local_store = LocalObjectStore.get_instance() - logger.debug("WAIT {} ids".format(common.unwrapped_data_ids_list(data_ids))) for data_id in not_ready.copy(): - if local_store.contains(data_id): + if object_store.contains(data_id): ready.append(data_id) not_ready.remove(data_id) pending_returns -= 1 diff --git a/unidist/core/backends/mpi/core/controller/common.py b/unidist/core/backends/mpi/core/controller/common.py index 0646fd86..6d6f3689 100644 --- a/unidist/core/backends/mpi/core/controller/common.py +++ b/unidist/core/backends/mpi/core/controller/common.py @@ -12,7 +12,9 @@ from unidist.core.backends.mpi.core.async_operations import AsyncOperations from unidist.core.backends.mpi.core.local_object_store import LocalObjectStore from unidist.core.backends.mpi.core.shared_object_store import SharedObjectStore -from unidist.core.backends.mpi.core.serialization import serialize_complex_data +from unidist.core.backends.mpi.core.serialization import ( + serialize_complex_data, +) logger = common.get_logger("common", "common.log") @@ -394,22 +396,19 @@ def push_data(dest_rank, value, is_blocking_op=False): data_id = value if shared_store.contains(data_id): _push_shared_data(dest_rank, data_id, is_blocking_op) + elif local_store.is_already_serialized(data_id): + _push_local_data(dest_rank, data_id, is_blocking_op, is_serialized=True) elif local_store.contains(data_id): - if local_store.is_already_serialized(data_id): - _push_local_data(dest_rank, data_id, is_blocking_op, is_serialized=True) + data = local_store.get(data_id) + serialized_data = serialize_complex_data(data) + if shared_store.is_allocated() and shared_store.should_be_shared( + serialized_data + ): + shared_store.put(data_id, serialized_data) + _push_shared_data(dest_rank, data_id, is_blocking_op) else: - data = local_store.get(data_id) - serialized_data = serialize_complex_data(data) - if shared_store.is_allocated() and shared_store.should_be_shared( - serialized_data - ): - shared_store.put(data_id, serialized_data) - _push_shared_data(dest_rank, data_id, is_blocking_op) - else: - local_store.cache_serialized_data(data_id, serialized_data) - _push_local_data( - dest_rank, data_id, is_blocking_op, is_serialized=True - ) + local_store.cache_serialized_data(data_id, serialized_data) + _push_local_data(dest_rank, data_id, is_blocking_op, is_serialized=True) elif local_store.contains_data_owner(data_id): _push_data_owner(dest_rank, data_id) else: diff --git a/unidist/core/backends/mpi/core/local_object_store.py b/unidist/core/backends/mpi/core/local_object_store.py index bc80be5c..2fbaf404 100644 --- a/unidist/core/backends/mpi/core/local_object_store.py +++ b/unidist/core/backends/mpi/core/local_object_store.py @@ -277,6 +277,10 @@ def cache_serialized_data(self, data_id, data): data : object Serialized data to cache. """ + # We make a copy to avoid data corruption obtained through out-of-band serialization, + # and buffers are marked read-only to prevent them from being modified. + # `to_bytes()` call handles both points. 
+ data["raw_buffers"] = [buf.tobytes() for buf in data["raw_buffers"]] self._serialization_cache[data_id] = data self.maybe_update_data_id_map(data_id) diff --git a/unidist/core/backends/mpi/core/object_store.py b/unidist/core/backends/mpi/core/object_store.py new file mode 100644 index 00000000..9d5b2710 --- /dev/null +++ b/unidist/core/backends/mpi/core/object_store.py @@ -0,0 +1,90 @@ +# Copyright (C) 2021-2023 Modin authors +# +# SPDX-License-Identifier: Apache-2.0 + +"""`ObjectStore` functionality.""" + +from unidist.core.backends.mpi.core.local_object_store import LocalObjectStore +from unidist.core.backends.mpi.core.serialization import deserialize_complex_data +from unidist.core.backends.mpi.core.shared_object_store import SharedObjectStore + + +class ObjectStore: + """ + Class that combines checking and retrieving data from the shared and local stores in a current process. + + Notes + ----- + The store checks for both deserialized and serialized data. + """ + + __instance = None + + @classmethod + def get_instance(cls): + """ + Get instance of ``ObjectStore``. + + Returns + ------- + ObjectStore + """ + if cls.__instance is None: + cls.__instance = ObjectStore() + return cls.__instance + + def contains(self, data_id): + """ + Check if the data associated with `data_id` exists in the current process. + + Parameters + ---------- + data_id : unidist.core.backends.mpi.core.common.MpiDataID + An ID to data. + + Returns + ------- + bool + Return the status if an object exist in the current process. + """ + local_store = LocalObjectStore.get_instance() + shared_store = SharedObjectStore.get_instance() + return ( + local_store.contains(data_id) + or local_store.is_already_serialized(data_id) + or shared_store.contains(data_id) + ) + + def get(self, data_id): + """ + Get data from any location in the current process. + + Parameters + ---------- + data_id : unidist.core.backends.mpi.core.common.MpiDataID + An ID to data. + + Returns + ------- + object + Return data associated with `data_id`. + """ + local_store = LocalObjectStore.get_instance() + shared_store = SharedObjectStore.get_instance() + + if local_store.contains(data_id): + return local_store.get(data_id) + + if local_store.is_already_serialized(data_id): + serialized_data = local_store.get_serialized_data(data_id) + value = deserialize_complex_data( + serialized_data["s_data"], + serialized_data["raw_buffers"], + serialized_data["buffer_count"], + ) + elif shared_store.contains(data_id): + value = shared_store.get(data_id) + else: + raise ValueError("The current data ID is not contained in the procces.") + local_store.put(data_id, value) + return value diff --git a/unidist/core/backends/mpi/core/shared_object_store.py b/unidist/core/backends/mpi/core/shared_object_store.py index c3f76405..85050081 100644 --- a/unidist/core/backends/mpi/core/shared_object_store.py +++ b/unidist/core/backends/mpi/core/shared_object_store.py @@ -747,7 +747,7 @@ def put(self, data_id, serialized_data): # put shared info self._put_shared_info(data_id, shared_info) - def get(self, data_id, owner_rank, shared_info): + def get(self, data_id, owner_rank=None, shared_info=None): """ Get data from another worker using shared memory. @@ -755,54 +755,66 @@ def get(self, data_id, owner_rank, shared_info): ---------- data_id : unidist.core.backends.mpi.core.common.MpiDataID An ID to data. - owner_rank : int + owner_rank : int, default: None The rank that sent the data. 
diff --git a/unidist/core/backends/mpi/core/shared_object_store.py b/unidist/core/backends/mpi/core/shared_object_store.py
index c3f76405..85050081 100644
--- a/unidist/core/backends/mpi/core/shared_object_store.py
+++ b/unidist/core/backends/mpi/core/shared_object_store.py
@@ -747,7 +747,7 @@ def put(self, data_id, serialized_data):
         # put shared info
         self._put_shared_info(data_id, shared_info)
 
-    def get(self, data_id, owner_rank, shared_info):
+    def get(self, data_id, owner_rank=None, shared_info=None):
         """
         Get data from another worker using shared memory.
 
@@ -755,54 +755,66 @@ def get(self, data_id, owner_rank, shared_info):
         ----------
         data_id : unidist.core.backends.mpi.core.common.MpiDataID
             An ID to data.
-        owner_rank : int
+        owner_rank : int, default: None
             The rank that sent the data.
-        shared_info : dict
+            The value is used to synchronize data in shared memory between different hosts
+            if it is not ``None``.
+        shared_info : dict, default: None
             The necessary information to properly deserialize data from shared memory.
+            If `shared_info` is ``None``, the data already exists in shared memory in the current process.
         """
 
-        mpi_state = communication.MPIState.get_instance()
-        s_data_len = shared_info["s_data_len"]
-        raw_buffers_len = shared_info["raw_buffers_len"]
-        service_index = shared_info["service_index"]
-        buffer_count = shared_info["buffer_count"]
-
-        # check data in shared memory
-        if not self._check_service_info(data_id, service_index):
-            # reserve shared memory
-            shared_data_len = s_data_len + sum([buf for buf in raw_buffers_len])
-            reservation_info = communication.send_reserve_operation(
-                mpi_state.global_comm, data_id, shared_data_len
-            )
-
-            service_index = reservation_info["service_index"]
-            # check if worker should sync shared buffer or it is doing by another worker
-            if reservation_info["is_first_request"]:
-                # syncronize shared buffer
-                self._sync_shared_memory_from_another_host(
-                    mpi_state.global_comm,
-                    data_id,
-                    owner_rank,
-                    reservation_info["first_index"],
-                    reservation_info["last_index"],
-                    service_index,
-                )
-                # put service info
-                self._put_service_info(
-                    service_index, data_id, reservation_info["first_index"]
+        if shared_info is None:
+            shared_info = self.get_shared_info(data_id)
+        else:
+            mpi_state = communication.MPIState.get_instance()
+            s_data_len = shared_info["s_data_len"]
+            raw_buffers_len = shared_info["raw_buffers_len"]
+            service_index = shared_info["service_index"]
+            buffer_count = shared_info["buffer_count"]
+
+            # check data in shared memory
+            if not self._check_service_info(data_id, service_index):
+                # reserve shared memory
+                shared_data_len = s_data_len + sum([buf for buf in raw_buffers_len])
+                reservation_info = communication.send_reserve_operation(
+                    mpi_state.global_comm, data_id, shared_data_len
                 )
-            else:
-                # wait while another worker syncronize shared buffer
-                while not self._check_service_info(data_id, service_index):
-                    time.sleep(MpiBackoff.get())
-
-        # put shared info with updated data_id and service_index
-        shared_info = common.MetadataPackage.get_shared_info(
-            data_id, s_data_len, raw_buffers_len, buffer_count, service_index
-        )
-        self._put_shared_info(data_id, shared_info)
+                service_index = reservation_info["service_index"]
+                # check if worker should sync shared buffer or it is doing by another worker
+                if reservation_info["is_first_request"]:
+                    # synchronize shared buffer
+                    if owner_rank is None:
+                        raise ValueError(
+                            "The data is not in the host's shared memory and must be synchronized, "
+                            + "but the owner rank is not defined."
+                        )
+
+                    self._sync_shared_memory_from_another_host(
+                        mpi_state.global_comm,
+                        data_id,
+                        owner_rank,
+                        reservation_info["first_index"],
+                        reservation_info["last_index"],
+                        service_index,
+                    )
+                    # put service info
+                    self._put_service_info(
+                        service_index, data_id, reservation_info["first_index"]
+                    )
+                else:
+                    # wait while another worker synchronizes the shared buffer
+                    while not self._check_service_info(data_id, service_index):
+                        time.sleep(MpiBackoff.get())
+
+            # put shared info with updated data_id and service_index
+            shared_info = common.MetadataPackage.get_shared_info(
+                data_id, s_data_len, raw_buffers_len, buffer_count, service_index
+            )
+            self._put_shared_info(data_id, shared_info)
 
-        # increment ref
-        self._increment_ref_number(data_id, shared_info["service_index"])
+            # increment ref
+            self._increment_ref_number(data_id, shared_info["service_index"])
 
         # read from shared buffer and deserialized
         return self._read_from_shared_buffer(data_id, shared_info)
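
Note: with both parameters now defaulting to ``None``, `SharedObjectStore.get` has two call shapes; a sketch, where the identifiers are placeholders for values a worker already holds:

    shared_store = SharedObjectStore.get_instance()

    # Cross-host path, as before: the sender supplies the owner rank and the metadata
    value = shared_store.get(data_id, owner_rank=owner_rank, shared_info=shared_info)

    # New intra-host path: the data is already in this host's shared memory,
    # so the metadata is recovered via self.get_shared_info(data_id)
    value = shared_store.get(data_id)
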
diff --git a/unidist/core/backends/mpi/core/worker/loop.py b/unidist/core/backends/mpi/core/worker/loop.py
index 32d6eda7..5515ff56 100644
--- a/unidist/core/backends/mpi/core/worker/loop.py
+++ b/unidist/core/backends/mpi/core/worker/loop.py
@@ -16,6 +16,7 @@
 
 import unidist.core.backends.mpi.core.common as common
 import unidist.core.backends.mpi.core.communication as communication
+from unidist.core.backends.mpi.core.object_store import ObjectStore
 from unidist.core.backends.mpi.core.local_object_store import LocalObjectStore
 from unidist.core.backends.mpi.core.worker.request_store import RequestStore
 from unidist.core.backends.mpi.core.worker.task_store import TaskStore
@@ -85,6 +86,7 @@ async def worker_loop():
     ``unidist.core.backends.mpi.core.common.Operations`` defines a set of supported operations.
     """
     task_store = TaskStore.get_instance()
+    object_store = ObjectStore.get_instance()
    local_store = LocalObjectStore.get_instance()
     request_store = RequestStore.get_instance()
     async_operations = AsyncOperations.get_instance()
@@ -185,7 +187,7 @@ async def worker_loop():
             if not ready_to_shutdown_posted:
                 # Prepare the data
                 # Actor method here is a data id so we have to retrieve it from the storage
-                method_name = local_store.get(request["task"])
+                method_name = object_store.get(request["task"])
                 handler = request["handler"]
                 actor_method = getattr(actor_map[handler], method_name)
                 request["task"] = actor_method
diff --git a/unidist/core/backends/mpi/core/worker/request_store.py b/unidist/core/backends/mpi/core/worker/request_store.py
index 47f2dc9c..89703cf7 100644
--- a/unidist/core/backends/mpi/core/worker/request_store.py
+++ b/unidist/core/backends/mpi/core/worker/request_store.py
@@ -6,8 +6,8 @@
 
 import unidist.core.backends.mpi.core.common as common
 import unidist.core.backends.mpi.core.communication as communication
-from unidist.core.backends.mpi.core.local_object_store import LocalObjectStore
 from unidist.core.backends.mpi.core.controller.common import push_data
+from unidist.core.backends.mpi.core.object_store import ObjectStore
 
 mpi_state = communication.MPIState.get_instance()
 
@@ -213,7 +213,8 @@ def process_wait_request(self, data_id):
         -----
         Only ROOT rank is supported for now, therefore no rank argument needed.
         """
-        if LocalObjectStore.get_instance().contains(data_id):
+        object_store = ObjectStore.get_instance()
+        if object_store.contains(data_id):
             # Executor wait just for signal
             # We use a blocking send here because the receiver is waiting for the result.
             communication.mpi_send_object(
@@ -247,8 +248,8 @@ def process_get_request(self, source_rank, data_id, is_blocking_op=False):
         -----
         Request is asynchronous, no wait for the data sending.
         """
-        local_store = LocalObjectStore.get_instance()
-        if local_store.contains(data_id):
+        object_store = ObjectStore.get_instance()
+        if object_store.contains(data_id):
             push_data(
                 source_rank,
                 data_id,
diff --git a/unidist/core/backends/mpi/core/worker/task_store.py b/unidist/core/backends/mpi/core/worker/task_store.py
index 4e8b7bc1..f9be5c50 100644
--- a/unidist/core/backends/mpi/core/worker/task_store.py
+++ b/unidist/core/backends/mpi/core/worker/task_store.py
@@ -11,6 +11,7 @@
 import unidist.core.backends.mpi.core.common as common
 import unidist.core.backends.mpi.core.communication as communication
 from unidist.core.backends.mpi.core.async_operations import AsyncOperations
+from unidist.core.backends.mpi.core.object_store import ObjectStore
 from unidist.core.backends.mpi.core.local_object_store import LocalObjectStore
 from unidist.core.backends.mpi.core.shared_object_store import SharedObjectStore
 from unidist.core.backends.mpi.core.serialization import serialize_complex_data
@@ -187,8 +188,9 @@ def unwrap_local_data_id(self, arg):
         """
         if is_data_id(arg):
             local_store = LocalObjectStore.get_instance()
-            if local_store.contains(arg):
-                value = LocalObjectStore.get_instance().get(arg)
+            object_store = ObjectStore.get_instance()
+            if object_store.contains(arg):
+                value = object_store.get(arg)
                 # Data is already local or was pushed from master
                 return value, False
             elif local_store.contains_data_owner(arg):
@@ -417,13 +419,13 @@ def process_task_request(self, request):
         dict or None
             Same request if the task couldn`t be executed, otherwise ``None``.
         """
+        object_store = ObjectStore.get_instance()
         # Parse request
-        local_store = LocalObjectStore.get_instance()
         task = request["task"]
         # Remote function here is a data id so we have to retrieve it from the storage,
         # whereas actor method is already materialized in the worker loop.
         if is_data_id(task):
-            task = local_store.get(task)
+            task = object_store.get(task)
         args = request["args"]
         kwargs = request["kwargs"]
         output_ids = request["output"]
diff --git a/unidist/test/test_task.py b/unidist/test/test_task.py
index c53b8f11..1cb346d8 100644
--- a/unidist/test/test_task.py
+++ b/unidist/test/test_task.py
@@ -122,3 +122,15 @@ def f(params):
         }
 
     assert_equal(f.remote(data), data)
+
+
+@pytest.mark.xfail(
+    Backend.get() == BackendName.PYSEQ,
+    reason="PUT using PYSEQ does not provide immutable data",
+)
+def test_data_immutability():
+    data = [1, 2, 3]
+    object_ref = unidist.put(data)
+
+    data[0] = 111
+    assert_equal(object_ref, [1, 2, 3])
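
Note: the new test pins down the copy semantics introduced by the `put` change. The PYSEQ xfail exists because, per its reason string, that backend does not provide immutable data on PUT; a toy illustration of the failure mode being excluded (not unidist code):

    # A put() that stores a reference rather than a serialized copy
    store = {}

    def put(obj):
        ref = id(obj)
        store[ref] = obj  # no copy: the store aliases the caller's object
        return ref

    data = [1, 2, 3]
    ref = put(data)
    data[0] = 111
    assert store[ref] == [111, 2, 3]  # the mutation leaks through; a serializing put() keeps [1, 2, 3]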