From 41106199ea0da94a50aee0125fa33ab5ca066be3 Mon Sep 17 00:00:00 2001
From: "Igoshev, Iaroslav"
Date: Sat, 16 Sep 2023 10:02:44 +0200
Subject: [PATCH] FIX-#343: Get rid of the hacks in the MPI backend

Signed-off-by: Igoshev, Iaroslav
---
 unidist/config/backends/mpi/envvars.py        |  32 ++-
 unidist/core/backends/mpi/core/common.py      |   7 +-
 .../core/backends/mpi/core/communication.py   | 121 +++++------
 .../backends/mpi/core/controller/actor.py     |   2 +
 .../core/backends/mpi/core/controller/api.py  |  12 +-
 .../backends/mpi/core/controller/common.py    |  26 ++-
 .../core/backends/mpi/core/serialization.py   |  68 +++++-
 .../backends/mpi/core/shared_object_store.py  |  54 ++---
 unidist/core/backends/mpi/core/worker/loop.py |   6 +-
 .../backends/mpi/core/worker/task_store.py    | 205 +++---------------
 10 files changed, 237 insertions(+), 296 deletions(-)

diff --git a/unidist/config/backends/mpi/envvars.py b/unidist/config/backends/mpi/envvars.py
index 070606da..5bbfb333 100644
--- a/unidist/config/backends/mpi/envvars.py
+++ b/unidist/config/backends/mpi/envvars.py
@@ -21,11 +21,39 @@ class MpiHosts(EnvironmentVariable, type=ExactStr):
 
 
 class MpiPickleThreshold(EnvironmentVariable, type=int):
-    """Minimum buffer size for serialization with pickle 5 protocol."""
+    """
+    Minimum buffer size for serialization with pickle 5 protocol.
+
+    Notes
+    -----
+    If the shared object store is enabled, ``MpiSharedObjectStoreThreshold`` takes
+    precedence over this configuration value and the threshold gets overridden.
+    It is done intentionally to prevent multiple copies when putting an object
+    into the local object store or into the shared object store.
+    A data copy happens once during in-band serialization, depending on the threshold.
+    In some cases the output of a remote task can take up the memory of the task arguments.
+    If those arguments are placed in the shared object store, this location should not be overwritten
+    while the output is being used; otherwise, the output value may be corrupted.
+    """
 
     default = 1024**2 // 4  # 0.25 MiB
     varname = "UNIDIST_MPI_PICKLE_THRESHOLD"
 
+    @classmethod
+    def get(cls) -> int:
+        """
+        Get minimum buffer size for serialization with pickle 5 protocol.
+
+        Returns
+        -------
+        int
+        """
+        if MpiSharedObjectStore.get():
+            mpi_pickle_threshold = MpiSharedObjectStoreThreshold.get()
+        else:
+            mpi_pickle_threshold = super().get()
+        return mpi_pickle_threshold
+
 
 class MpiBackoff(EnvironmentVariable, type=float):
     """
@@ -65,5 +93,5 @@ class MpiSharedObjectStoreMemory(EnvironmentVariable, type=int):
 class MpiSharedObjectStoreThreshold(EnvironmentVariable, type=int):
     """Minimum size of data to put into the shared object store."""
 
-    default = 1024**2  # 1 MiB
+    default = 10 ** 5  # 100 KB
     varname = "UNIDIST_MPI_SHARED_OBJECT_STORE_THRESHOLD"
diff --git a/unidist/core/backends/mpi/core/common.py b/unidist/core/backends/mpi/core/common.py
index 7dfd090e..2b075301 100755
--- a/unidist/core/backends/mpi/core/common.py
+++ b/unidist/core/backends/mpi/core/common.py
@@ -123,12 +123,16 @@ class MetadataPackage(ImmutableDict):
     SHARED_DATA = 1
 
     @classmethod
-    def get_local_info(cls, s_data_len, raw_buffers_len, buffer_count):
+    def get_local_info(cls, data_id, s_data_len, raw_buffers_len, buffer_count):
         """
         Get information package for sending local data.
 
         Parameters
         ----------
+        data_id : unidist.core.backends.common.data_id.DataID
+            An ID to data.
+            Can be ``None`` to indicate a fake `data_id` in order to get a full metadata package.
+            This is usually the case when submitting a task or an actor with not yet serialized data.
         s_data_len : int
             Main buffer length.
         raw_buffers_len : list
@@ -145,6 +149,7 @@ def get_local_info(cls, s_data_len, raw_buffers_len, buffer_count):
         return MetadataPackage(
             {
                 "package_type": MetadataPackage.LOCAL_DATA,
+                "id": data_id,
                 "s_data_len": s_data_len,
                 "raw_buffers_len": raw_buffers_len,
                 "buffer_count": buffer_count,
diff --git a/unidist/core/backends/mpi/core/communication.py b/unidist/core/backends/mpi/core/communication.py
index 08dfb4f9..661e2e13 100755
--- a/unidist/core/backends/mpi/core/communication.py
+++ b/unidist/core/backends/mpi/core/communication.py
@@ -17,8 +17,9 @@
 from unidist.config import MpiBackoff
 from unidist.core.backends.mpi.core.serialization import (
-    ComplexDataSerializer,
     SimpleDataSerializer,
+    serialize_complex_data,
+    deserialize_complex_data,
 )
 import unidist.core.backends.mpi.core.common as common
@@ -525,7 +526,7 @@ def mpi_busy_wait_recv(comm, source_rank):
 # --------------------------------- #
 
 
-def _send_complex_data_impl(comm, s_data, raw_buffers, buffer_count, dest_rank):
+def _send_complex_data_impl(comm, s_data, raw_buffers, dest_rank, info_package):
     """
     Send already serialized complex data.
 
@@ -537,21 +538,16 @@
         Serialized data as bytearray.
     raw_buffers : list
         Pickle buffers list, out-of-band data collected with pickle 5 protocol.
-    buffer_count : list
-        List of the number of buffers for each object
-        to be serialized/deserialized using the pickle 5 protocol.
-        See details in :py:class:`~unidist.core.backends.mpi.core.serialization.ComplexDataSerializer`.
     dest_rank : int
         Target MPI process to transfer data.
+    info_package : unidist.core.backends.mpi.core.common.MetadataPackage
+        Required information to deserialize data on the receiver side.
 
     Notes
     -----
     The special tags are used for this communication, namely,
     ``common.MPITag.OBJECT`` and ``common.MPITag.BUFFER``.
     """
-    info_package = common.MetadataPackage.get_local_info(
-        len(s_data), [len(sbuf) for sbuf in raw_buffers], buffer_count
-    )
     # wrap to dict for sending and correct deserialization of the object by the recipient
     comm.send(dict(info_package), dest=dest_rank, tag=common.MPITag.OBJECT)
     with pkl5._bigmpi as bigmpi:
@@ -577,44 +573,40 @@ def send_complex_data(comm, data, dest_rank, is_serialized=False):
 
     Returns
     -------
-    dict or None
+    dict
         Serialized data for caching purpose.
 
     Notes
     -----
-    * ``None`` is returned if `data` is already serialized,
-      otherwise ``dict`` containing data serialized in this function.
     * This blocking send is used when we have to wait for completion of the communication,
       which is necessary for the pipeline to continue, or when the receiver is waiting for a result.
       Otherwise, use non-blocking ``isend_complex_data``.
""" - serialized_data = None if is_serialized: s_data = data["s_data"] raw_buffers = data["raw_buffers"] buffer_count = data["buffer_count"] + # pop `data_id` out of the dict because it will be send as part of metadata package + data_id = data.pop("id") + serialized_data = data else: - serializer = ComplexDataSerializer() - # Main job - s_data = serializer.serialize(data) - # Retrive the metadata - raw_buffers = serializer.buffers - buffer_count = serializer.buffer_count - - serialized_data = { - "s_data": s_data, - "raw_buffers": raw_buffers, - "buffer_count": buffer_count, - } + data_id = data["id"] + serialized_data = serialize_complex_data(data) + s_data = serialized_data["s_data"] + raw_buffers = serialized_data["raw_buffers"] + buffer_count = serialized_data["buffer_count"] + info_package = common.MetadataPackage.get_local_info( + data_id, len(s_data), [len(sbuf) for sbuf in raw_buffers], buffer_count + ) # MPI communication - _send_complex_data_impl(comm, s_data, raw_buffers, buffer_count, dest_rank) + _send_complex_data_impl(comm, s_data, raw_buffers, dest_rank, info_package) # For caching purpose return serialized_data -def _isend_complex_data_impl(comm, s_data, raw_buffers, buffer_count, dest_rank): +def _isend_complex_data_impl(comm, s_data, raw_buffers, dest_rank, info_package): """ Send serialized complex data. @@ -628,12 +620,10 @@ def _isend_complex_data_impl(comm, s_data, raw_buffers, buffer_count, dest_rank) A serialized msgpack data. raw_buffers : list A list of pickle buffers. - buffer_count : list - List of the number of buffers for each object - to be serialized/deserialized using the pickle 5 protocol. - See details in :py:class:`~unidist.core.backends.mpi.core.serialization.ComplexDataSerializer`. dest_rank : int Target MPI process to transfer data. + info_package : unidist.core.backends.mpi.core.common.MetadataPackage + Required information to deserialize data on a receiver side. Returns ------- @@ -646,13 +636,9 @@ def _isend_complex_data_impl(comm, s_data, raw_buffers, buffer_count, dest_rank) ``common.MPITag.OBJECT`` and ``common.MPITag.BUFFER``. """ handlers = [] - info_package = common.MetadataPackage.get_local_info( - len(s_data), [len(sbuf) for sbuf in raw_buffers], buffer_count - ) # wrap to dict for sending and correct deserialization of the object by the recipient h1 = comm.isend(dict(info_package), dest=dest_rank, tag=common.MPITag.OBJECT) handlers.append((h1, None)) - with pkl5._bigmpi as bigmpi: h2 = comm.Isend(bigmpi(s_data), dest=dest_rank, tag=common.MPITag.BUFFER) handlers.append((h2, s_data)) @@ -663,7 +649,7 @@ def _isend_complex_data_impl(comm, s_data, raw_buffers, buffer_count, dest_rank) return handlers -def isend_complex_data(comm, data, dest_rank): +def isend_complex_data(comm, data, dest_rank, is_serialized=False): """ Send the data that consists of different user provided complex types, lambdas and buffers in a non-blocking way. @@ -677,6 +663,8 @@ def isend_complex_data(comm, data, dest_rank): Data object to send. dest_rank : int Target MPI process to transfer data. + is_serialized : bool, default: False + `operation_data` is already serialized or not. Returns ------- @@ -694,16 +682,28 @@ def isend_complex_data(comm, data, dest_rank): * The special tags are used for this communication, namely, ``common.MPITag.OBJECT`` and ``common.MPITag.BUFFER``. 
""" - serializer = ComplexDataSerializer() - # Main job - s_data = serializer.serialize(data) - # Retrive the metadata - raw_buffers = serializer.buffers - buffer_count = serializer.buffer_count + if is_serialized: + s_data = data["s_data"] + raw_buffers = data["raw_buffers"] + buffer_count = data["buffer_count"] + # pop `data_id` out of the dict because it will be send as part of metadata package + data_id = data.pop("id") + else: + serialized_data = serialize_complex_data(data) + s_data = serialized_data["s_data"] + raw_buffers = serialized_data["raw_buffers"] + buffer_count = serialized_data["buffer_count"] + # Set fake `data_id` to get full metadata package below. + # This branch is usually used when submitting a task or an actor + # for not yet serialized data. + data_id = None - # Send message pack bytestring + info_package = common.MetadataPackage.get_local_info( + data_id, len(s_data), [len(sbuf) for sbuf in raw_buffers], buffer_count + ) + # MPI communication handlers = _isend_complex_data_impl( - comm, s_data, raw_buffers, buffer_count, dest_rank + comm, s_data, raw_buffers, dest_rank, info_package ) return handlers, s_data, raw_buffers, buffer_count @@ -742,11 +742,7 @@ def recv_complex_data(comm, source_rank, info_package): for rbuf in raw_buffers: comm.Recv(bigmpi(rbuf), source=source_rank, tag=common.MPITag.BUFFER) - # Set the necessary metadata for unpacking - deserializer = ComplexDataSerializer(raw_buffers, buffer_count) - - # Start unpacking - return deserializer.deserialize(msgpack_buffer) + return deserialize_complex_data(msgpack_buffer, raw_buffers, buffer_count) # ---------- # @@ -845,14 +841,11 @@ def isend_complex_operation( Returns ------- - dict and dict or dict and None - Async handlers and serialization data for caching purpose. + list and dict + Async handlers list and serialization data dict for caching purpose. Notes ----- - * Function always returns a ``dict`` containing async handlers to the sent MPI operations. - In addition, ``None`` is returned if `operation_data` is already serialized, - otherwise ``dict`` containing data serialized in this function. * The special tags are used for this communication, namely, ``common.MPITag.OPERATION``, ``common.MPITag.OBJECT`` and ``common.MPITag.BUFFER``. 
""" @@ -867,24 +860,26 @@ def isend_complex_operation( s_data = operation_data["s_data"] raw_buffers = operation_data["raw_buffers"] buffer_count = operation_data["buffer_count"] - + # pop `data_id` out of the dict because it will be send as part of metadata package + data_id = operation_data.pop("id") + info_package = common.MetadataPackage.get_local_info( + data_id, len(s_data), [len(sbuf) for sbuf in raw_buffers], buffer_count + ) h2_list = _isend_complex_data_impl( - comm, s_data, raw_buffers, buffer_count, dest_rank + comm, s_data, raw_buffers, dest_rank, info_package ) handlers.extend(h2_list) - - return handlers, None else: # Serialize and send the data h2_list, s_data, raw_buffers, buffer_count = isend_complex_data( comm, operation_data, dest_rank ) handlers.extend(h2_list) - return handlers, { - "s_data": s_data, - "raw_buffers": raw_buffers, - "buffer_count": buffer_count, - } + return handlers, { + "s_data": s_data, + "raw_buffers": raw_buffers, + "buffer_count": buffer_count, + } def isend_serialized_operation(comm, operation_type, operation_data, dest_rank): diff --git a/unidist/core/backends/mpi/core/controller/actor.py b/unidist/core/backends/mpi/core/controller/actor.py index 88925d84..c33b0c25 100644 --- a/unidist/core/backends/mpi/core/controller/actor.py +++ b/unidist/core/backends/mpi/core/controller/actor.py @@ -62,6 +62,7 @@ def __call__(self, *args, num_returns=1, **kwargs): operation_type, operation_data, self._actor._owner_rank, + is_serialized=False, ) async_operations.extend(h_list) return output_id @@ -126,6 +127,7 @@ def __init__(self, cls, *args, owner_rank=None, handler_id=None, **kwargs): operation_type, operation_data, self._owner_rank, + is_serialized=False, ) async_operations.extend(h_list) diff --git a/unidist/core/backends/mpi/core/controller/api.py b/unidist/core/backends/mpi/core/controller/api.py index 40f6c62e..18c95bdf 100644 --- a/unidist/core/backends/mpi/core/controller/api.py +++ b/unidist/core/backends/mpi/core/controller/api.py @@ -18,6 +18,7 @@ "Missing dependency 'mpi4py'. Use pip or conda to install it." 
) from None +from unidist.core.backends.mpi.core.serialization import serialize_complex_data from unidist.core.backends.mpi.core.shared_object_store import SharedObjectStore from unidist.core.backends.mpi.core.local_object_store import LocalObjectStore from unidist.core.backends.mpi.core.controller.garbage_collector import ( @@ -369,10 +370,13 @@ def put(data): """ local_store = LocalObjectStore.get_instance() shared_store = SharedObjectStore.get_instance() + data_id = local_store.generate_data_id(garbage_collector) + serialized_data = serialize_complex_data(data) + local_store.cache_serialized_data(data_id, serialized_data) local_store.put(data_id, data) if shared_store.is_allocated(): - shared_store.put(data_id, data) + shared_store.put(data_id, serialized_data) logger.debug("PUT {} id".format(data_id._id)) @@ -545,13 +549,14 @@ def submit(task, *args, num_returns=1, **kwargs): unwrapped_args = [common.unwrap_data_ids(arg) for arg in args] unwrapped_kwargs = {k: common.unwrap_data_ids(v) for k, v in kwargs.items()} - push_data(dest_rank, common.master_data_ids_to_base(task)) + task_base_id = common.master_data_ids_to_base(task) + push_data(dest_rank, task_base_id) push_data(dest_rank, unwrapped_args) push_data(dest_rank, unwrapped_kwargs) operation_type = common.Operation.EXECUTE operation_data = { - "task": task, + "task": task_base_id, "args": unwrapped_args, "kwargs": unwrapped_kwargs, "output": common.master_data_ids_to_base(output_ids), @@ -562,6 +567,7 @@ def submit(task, *args, num_returns=1, **kwargs): operation_type, operation_data, dest_rank, + is_serialized=False, ) async_operations.extend(h_list) diff --git a/unidist/core/backends/mpi/core/controller/common.py b/unidist/core/backends/mpi/core/controller/common.py index ddf77b97..8c84c70a 100644 --- a/unidist/core/backends/mpi/core/controller/common.py +++ b/unidist/core/backends/mpi/core/controller/common.py @@ -12,6 +12,7 @@ from unidist.core.backends.mpi.core.async_operations import AsyncOperations from unidist.core.backends.mpi.core.local_object_store import LocalObjectStore from unidist.core.backends.mpi.core.shared_object_store import SharedObjectStore +from unidist.core.backends.mpi.core.serialization import serialize_complex_data logger = common.get_logger("common", "common.log") @@ -142,9 +143,15 @@ def pull_data(comm, owner_rank): "data": data, } elif info_package["package_type"] == common.MetadataPackage.LOCAL_DATA: - return communication.recv_complex_data( + local_store = LocalObjectStore.get_instance() + data_id = local_store.get_unique_data_id(info_package["id"]) + data = communication.recv_complex_data( comm, owner_rank, info_package=info_package ) + return { + "id": data_id, + "data": data, + } else: raise ValueError("Unexpected package of data info!") @@ -197,7 +204,6 @@ def request_worker_data(data_id): # Caching the result, check the protocol correctness here local_store.put(data_id, data) - return data @@ -229,6 +235,8 @@ def _push_local_data(dest_rank, data_id, is_blocking_op, is_serialized): # Push the local master data to the target worker directly if is_serialized: operation_data = local_store.get_serialized_data(data_id) + # Insert `data_id` to get full metadata package further + operation_data["id"] = data_id else: operation_data = { "id": data_id, @@ -241,7 +249,7 @@ def _push_local_data(dest_rank, data_id, is_blocking_op, is_serialized): mpi_state.comm, operation_data, dest_rank, - is_serialized, + is_serialized=is_serialized, ) else: h_list, serialized_data = communication.isend_complex_operation( 
@@ -249,11 +257,11 @@
             operation_type,
             operation_data,
             dest_rank,
-            is_serialized,
+            is_serialized=is_serialized,
         )
         async_operations.extend(h_list)
 
-        if not is_serialized:
+        if not is_serialized or not local_store.is_already_serialized(data_id):
             local_store.cache_serialized_data(data_id, serialized_data)
 
     # Remember pushed id
@@ -358,12 +366,14 @@ def push_data(dest_rank, value, is_blocking_op=False):
             _push_local_data(dest_rank, data_id, is_blocking_op, is_serialized=True)
         else:
             data = local_store.get(data_id)
-            if shared_store.should_be_shared(data):
-                shared_store.put(data_id, data)
+            serialized_data = serialize_complex_data(data)
+            if shared_store.is_allocated() and shared_store.should_be_shared(serialized_data):
+                shared_store.put(data_id, serialized_data)
                 _push_shared_data(dest_rank, data_id, is_blocking_op)
             else:
+                local_store.cache_serialized_data(data_id, serialized_data)
                 _push_local_data(
-                    dest_rank, data_id, is_blocking_op, is_serialized=False
+                    dest_rank, data_id, is_blocking_op, is_serialized=True
                 )
     elif local_store.contains_data_owner(data_id):
         _push_data_owner(dest_rank, data_id)
diff --git a/unidist/core/backends/mpi/core/serialization.py b/unidist/core/backends/mpi/core/serialization.py
index 84cea777..ba18d031 100755
--- a/unidist/core/backends/mpi/core/serialization.py
+++ b/unidist/core/backends/mpi/core/serialization.py
@@ -219,6 +219,11 @@ def serialize(self, data):
         data : object
             Data to serialize.
 
+        Returns
+        -------
+        bytes
+            Serialized data.
+
         Notes
         -----
         Uses msgpack, cloudpickle and pickle libraries.
@@ -242,7 +247,9 @@ def _decode_custom(self, obj):
             return pkl.loads(obj["as_bytes"])
         elif "__pickle5_custom__" in obj:
             frame = pkl.loads(obj["as_bytes"], buffers=self.buffers)
-            del self.buffers[: self.buffer_count.pop(0)]
+            # check if there are out-of-band buffers
+            if self.buffer_count:
+                del self.buffers[: self.buffer_count.pop(0)]
             return frame
         else:
             return obj
@@ -256,6 +263,11 @@ def deserialize(self, s_data):
         s_data : bytearray
             Data to de-serialize.
 
+        Returns
+        -------
+        object
+            Deserialized data.
+
         Notes
         -----
         Uses msgpack, cloudpickle and pickle libraries.
@@ -340,3 +352,57 @@ def deserialize_pickle(self, data):
             Original reconstructed object.
         """
         return pkl.loads(data)
+
+
+def serialize_complex_data(data):
+    """
+    Serialize data using ``ComplexDataSerializer``.
+
+    Parameters
+    ----------
+    data : object
+        Data to serialize.
+
+    Returns
+    -------
+    dict
+        Dictionary containing the serialized main buffer (``s_data``),
+        out-of-band buffers (``raw_buffers``) and buffer counts (``buffer_count``).
+
+    Notes
+    -----
+    Uses msgpack, cloudpickle and pickle libraries.
+    """
+    serializer = ComplexDataSerializer()
+    s_data = serializer.serialize(data)
+    serialized_data = {
+        "s_data": s_data,
+        "raw_buffers": serializer.buffers,
+        "buffer_count": serializer.buffer_count,
+    }
+    return serialized_data
+
+
+def deserialize_complex_data(s_data, raw_buffers, buffer_count):
+    """
+    Deserialize data based on the passed in information.
+
+    Parameters
+    ----------
+    s_data : bytearray
+        Serialized msgpack data.
+    raw_buffers : list
+        A list of ``PickleBuffer`` objects for data decoding.
+    buffer_count : list
+        List of the number of buffers for each object
+        to be deserialized using the pickle 5 protocol.
+
+    Returns
+    -------
+    object
+        Deserialized data.
+
+    Notes
+    -----
+    Uses msgpack, cloudpickle and pickle libraries.
+ """ + deserializer = ComplexDataSerializer(raw_buffers, buffer_count) + return deserializer.deserialize(s_data) diff --git a/unidist/core/backends/mpi/core/shared_object_store.py b/unidist/core/backends/mpi/core/shared_object_store.py index 15b1323b..b8923b42 100644 --- a/unidist/core/backends/mpi/core/shared_object_store.py +++ b/unidist/core/backends/mpi/core/shared_object_store.py @@ -27,7 +27,9 @@ MpiBackoff, ) from unidist.core.backends.mpi.core import common, communication -from unidist.core.backends.mpi.core.serialization import ComplexDataSerializer +from unidist.core.backends.mpi.core.serialization import ( + deserialize_complex_data, +) # TODO: Find a way to move this after all imports mpi4py.rc(recv_mprobe=False, initialize=False) @@ -385,11 +387,7 @@ def _read_from_shared_buffer(self, data_id, shared_info): ) prev_last_index = raw_last_index - # Set the necessary metadata for unpacking - deserializer = ComplexDataSerializer(raw_buffers, buffer_count) - - # Start unpacking - data = deserializer.deserialize(s_data) + data = deserialize_complex_data(s_data, raw_buffers, buffer_count) self.logger.debug( f"Rank {communication.MPIState.get_instance().global_rank}: Get {data_id} from {first_index} to {prev_last_index}. Service index: {service_index}" ) @@ -530,31 +528,16 @@ def should_be_shared(self, data): Parameters ---------- - data : object - Any data needed to be sent to another process. + data : dict + Serialized data to check its size. Returns ------- bool Return the ``True`` status if data should be sent using shared memory. """ - if self.shared_buffer is None: - return False - - size = sys.getsizeof(data) - - # sys.getsizeof may return incorrect data size as - # it doesn't fully assume the whole structure of an object - # so we manually compute the data size for np.array here. - try: - import numpy as np - - if isinstance(data, np.ndarray): - size = data.size * data.dtype.itemsize - except ImportError: - pass - - return size > MpiSharedObjectStoreThreshold.get() + data_size = len(data["s_data"]) + sum([len(buf) for buf in data["raw_buffers"]]) + return data_size > MpiSharedObjectStoreThreshold.get() def contains(self, data_id): """ @@ -670,30 +653,21 @@ def delete_service_info(self, data_id, service_index): ) raise RuntimeError("Unexpected data_id for cleanup shared memory") - def put(self, data_id, data): + def put(self, data_id, serialized_data): """ Put data into shared memory. Parameters ---------- data_id : unidist.core.backends.common.data_id.DataID - data : object - The current data. + serialized_data : dict + Serialized data to put into the storage. 
""" mpi_state = communication.MPIState.get_instance() - # serialize data - serializer = ComplexDataSerializer() - s_data = serializer.serialize(data) - raw_buffers = serializer.buffers - buffer_count = serializer.buffer_count - data_size = len(s_data) + sum([len(buf) for buf in raw_buffers]) - serialized_data = { - "s_data": s_data, - "raw_buffers": raw_buffers, - "buffer_count": buffer_count, - } - + data_size = len(serialized_data["s_data"]) + sum( + [len(buf) for buf in serialized_data["raw_buffers"]] + ) # reserve shared memory reservation_data = communication.send_reserve_operation( mpi_state.comm, data_id, data_size diff --git a/unidist/core/backends/mpi/core/worker/loop.py b/unidist/core/backends/mpi/core/worker/loop.py index 76b2257c..60a89d39 100644 --- a/unidist/core/backends/mpi/core/worker/loop.py +++ b/unidist/core/backends/mpi/core/worker/loop.py @@ -105,7 +105,7 @@ async def worker_loop(): # Proceed the request if operation_type == common.Operation.EXECUTE: - request = pull_data(mpi_state.comm, source_rank) + request = pull_data(mpi_state.comm, source_rank)["data"] if not ready_to_shutdown_posted: # Execute the task if possible pending_request = task_store.process_task_request(request) @@ -173,7 +173,7 @@ async def worker_loop(): request_store.process_wait_request(request["id"]) elif operation_type == common.Operation.ACTOR_CREATE: - request = pull_data(mpi_state.comm, source_rank) + request = pull_data(mpi_state.comm, source_rank)["data"] if not ready_to_shutdown_posted: cls = request["class"] args = request["args"] @@ -182,7 +182,7 @@ async def worker_loop(): actor_map[handler] = cls(*args, **kwargs) elif operation_type == common.Operation.ACTOR_EXECUTE: - request = pull_data(mpi_state.comm, source_rank) + request = pull_data(mpi_state.comm, source_rank)["data"] if not ready_to_shutdown_posted: # Prepare the data # Actor method here is a data id so we have to retrieve it from the storage diff --git a/unidist/core/backends/mpi/core/worker/task_store.py b/unidist/core/backends/mpi/core/worker/task_store.py index 713b54a7..7772b6ed 100644 --- a/unidist/core/backends/mpi/core/worker/task_store.py +++ b/unidist/core/backends/mpi/core/worker/task_store.py @@ -6,7 +6,6 @@ import functools import inspect import time -import weakref from unidist.core.backends.common.data_id import is_data_id import unidist.core.backends.mpi.core.common as common @@ -14,6 +13,7 @@ from unidist.core.backends.mpi.core.async_operations import AsyncOperations from unidist.core.backends.mpi.core.local_object_store import LocalObjectStore from unidist.core.backends.mpi.core.shared_object_store import SharedObjectStore +from unidist.core.backends.mpi.core.serialization import serialize_complex_data from unidist.core.backends.mpi.core.worker.request_store import RequestStore mpi_state = communication.MPIState.get_instance() @@ -39,11 +39,6 @@ def __init__(self): self.event_loop = asyncio.get_event_loop() # Started async tasks self.background_tasks = set() - # In some cases output of tasks can take up the memory of task arguments. - # If these arguments are placed in shared memory, this location should not be overwritten - # while output is being used, otherwise the output value may be corrupted. 
- # {output weak ref: list of argument strong ref} - self.output_depends = weakref.WeakKeyDictionary() @classmethod def get_instance(cls): @@ -169,44 +164,6 @@ def request_worker_data(self, dest_rank, data_id): # Save request in order to prevent massive communication during pending task checks RequestStore.get_instance().put(data_id, dest_rank, RequestStore.DATA) - def check_local_data_id(self, arg): - """ - Inspect argument if the ID is available in the local object store. - - If the local object store doesn't contain this data ID, - request the data from another worker. - - Parameters - ---------- - arg : object or unidist.core.backends.common.data_id.DataID - Data ID or object to inspect. - - Returns - ------- - tuple - Same value and special flag. - - Notes - ----- - If the data ID could not be resolved, the function returns ``True``. - """ - if is_data_id(arg): - local_store = LocalObjectStore.get_instance() - arg = local_store.get_unique_data_id(arg) - if local_store.contains(arg): - return arg, False - elif local_store.contains_data_owner(arg): - if not RequestStore.get_instance().is_data_already_requested(arg): - # Request the data from an owner worker - owner_rank = local_store.get_data_owner(arg) - if owner_rank != communication.MPIState.get_instance().global_rank: - self.request_worker_data(owner_rank, arg) - return arg, True - else: - raise ValueError("DataID is missing!") - else: - return arg, False - def unwrap_local_data_id(self, arg): """ Inspect argument and get the ID associated data from the local object store if available. @@ -247,25 +204,6 @@ def unwrap_local_data_id(self, arg): else: return arg, False - def check_output_depends(self, data_ids, depends_id): - local_store = LocalObjectStore.get_instance() - if isinstance(data_ids, (list, tuple)): - for data_id in data_ids: - # the local store may not contain the data id yet - # if a remote function is a coroutine - if local_store.contains(data_id): - value = local_store.get(data_id) - if check_data_out_of_band(value): - self.output_depends[data_id] = depends_id - - else: - # the local store may not contain the data id yet - # if a remote function is a coroutine - if local_store.contains(data_ids): - value = local_store.get(data_ids) - if check_data_out_of_band(value): - self.output_depends[data_ids] = depends_id - def execute_received_task(self, output_data_ids, task, args, kwargs): """ Execute a task/actor-task and handle results. @@ -288,6 +226,9 @@ def execute_received_task(self, output_data_ids, task, args, kwargs): local_store = LocalObjectStore.get_instance() shared_store = SharedObjectStore.get_instance() completed_data_ids = [] + # Note that if a task is coroutine, + # the local store or the shared store will contain output data + # only once the task is complete. 
if inspect.iscoroutinefunction(task): async def execute(): @@ -333,14 +274,22 @@ async def execute(): for idx, (output_id, value) in enumerate( zip(output_data_ids, output_values) ): + serialized_data = serialize_complex_data(value) + local_store.cache_serialized_data( + output_id, serialized_data + ) local_store.put(output_id, value) - if shared_store.should_be_shared(value): - shared_store.put(output_id, value) + if shared_store.is_allocated() and shared_store.should_be_shared(serialized_data): + shared_store.put(output_id, serialized_data) completed_data_ids[idx] = output_id else: + serialized_data = serialize_complex_data(output_values) + local_store.cache_serialized_data( + output_data_ids, serialized_data + ) local_store.put(output_data_ids, output_values) - if shared_store.should_be_shared(output_values): - shared_store.put(output_data_ids, output_values) + if shared_store.is_allocated() and shared_store.should_be_shared(serialized_data): + shared_store.put(output_data_ids, serialized_data) completed_data_ids = [output_data_ids] RequestStore.get_instance().check_pending_get_requests(output_data_ids) @@ -406,14 +355,22 @@ async def execute(): for idx, (output_id, value) in enumerate( zip(output_data_ids, output_values) ): + serialized_data = serialize_complex_data(value) + local_store.cache_serialized_data( + output_id, serialized_data + ) local_store.put(output_id, value) - if shared_store.should_be_shared(value): - shared_store.put(output_id, value) + if shared_store.is_allocated() and shared_store.should_be_shared(serialized_data): + shared_store.put(output_id, serialized_data) completed_data_ids[idx] = output_id else: + serialized_data = serialize_complex_data(output_values) + local_store.cache_serialized_data( + output_data_ids, serialized_data + ) local_store.put(output_data_ids, output_values) - if shared_store.should_be_shared(output_values): - shared_store.put(output_data_ids, output_values) + if shared_store.is_allocated() and shared_store.should_be_shared(serialized_data): + shared_store.put(output_data_ids, serialized_data) completed_data_ids = [output_data_ids] RequestStore.get_instance().check_pending_get_requests(output_data_ids) # Monitor the task execution. 
@@ -473,9 +430,9 @@ def process_task_request(self, request): ) # DataID -> real data - args, is_pending = common.materialize_data_ids(args, self.check_local_data_id) + args, is_pending = common.materialize_data_ids(args, self.unwrap_local_data_id) kwargs, is_kw_pending = common.materialize_data_ids( - kwargs, self.check_local_data_id + kwargs, self.unwrap_local_data_id ) w_logger.debug("Is pending - {}".format(is_pending)) @@ -485,22 +442,7 @@ def process_task_request(self, request): request["kwargs"] = kwargs return request else: - args, is_pending = common.materialize_data_ids( - args, self.unwrap_local_data_id - ) - kwargs, is_kw_pending = common.materialize_data_ids( - kwargs, self.unwrap_local_data_id - ) - self.execute_received_task(output_ids, task, args, kwargs) - - shared_depends = [ - arg - for arg in request["args"] - if is_data_id(arg) and shared_store.contains(arg) - ] - if shared_depends: - self.check_output_depends(output_ids, shared_depends) if output_ids is not None: RequestStore.get_instance().check_pending_get_requests(output_ids) RequestStore.get_instance().check_pending_wait_requests(output_ids) @@ -508,90 +450,3 @@ def process_task_request(self, request): def __del__(self): self.event_loop.close() - - -################################################# -# Check if the data owns the memory it is using # -################################################# - - -def check_ndarray(ndarray): - """ - Check if the `np.ndarray` doesn't owns the memory it is using. - - Returns - ------- - bool - ``True`` if the `np.ndarray` doesn't own the memory it is using, ``False`` otherwise. - """ - return not ndarray.flags.owndata - - -def check_pandas_index(df_index): - """ - Check if the `pd.Index` doesn't owns the memory it is using. - - Returns - ------- - bool - ``True`` if the `pd.Index` doesn't own the memory it is using, ``False`` otherwise. - """ - if df_index._is_multi: - if any( - check_ndarray(df_index.get_level_values(i)._data) - for i in range(df_index.nlevels) - ): - return True - else: - if check_ndarray(df_index._data): - return True - return False - - -def check_data_out_of_band(value): - """ - Check if the data doesn't owns the memory it is using. - - Returns - ------- - bool - ``True`` if the data doesn't own the memory it is using, ``False`` otherwise. - - Notes - ----- - Only validation for `np.ndarray`, `pd.Dataframe` and `pd.Series` is currently supported. - """ - # check numpy - try: - import numpy as np - - if isinstance(value, np.ndarray): - if check_ndarray(value): - return True - return False - except ImportError: - pass - - # check pandas - try: - import pandas as pd - - if isinstance(value, pd.DataFrame): - if any(block.is_view for block in value._mgr.blocks) or any( - check_pandas_index(df_index) - for df_index in [value.index, value.columns] - ): - return True - return False - - if isinstance(value, pd.Series): - if any(block.is_view for block in value._mgr.blocks) or check_pandas_index( - value.index - ): - return True - return False - - # TODO: Add like blocks for other pandas classes - except ImportError: - pass - return False
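Reviewer's note (illustration only, not part of the patch): the new serialize_complex_data() / deserialize_complex_data() helpers centralize the pattern used throughout this change, a msgpack/pickle main frame ("s_data") plus out-of-band pickle 5 buffers ("raw_buffers"), and the patched should_be_shared() now measures exactly that serialized form against MpiSharedObjectStoreThreshold (default lowered to 10 ** 5 bytes). The sketch below shows the same round trip using plain pickle protocol 5 only; the helper names in it are hypothetical and it is not unidist API.

    import pickle


    def serialize_out_of_band(data):
        # Collect out-of-band buffers as memoryviews instead of copying them into
        # the main pickle frame (mirrors the "s_data" / "raw_buffers" split that
        # ComplexDataSerializer produces).
        raw_buffers = []
        s_data = pickle.dumps(
            data, protocol=5, buffer_callback=lambda buf: raw_buffers.append(buf.raw())
        )
        return {"s_data": s_data, "raw_buffers": raw_buffers}


    def deserialize_out_of_band(package):
        # Reattach the out-of-band buffers collected during serialization.
        return pickle.loads(package["s_data"], buffers=package["raw_buffers"])


    payload = {"values": pickle.PickleBuffer(b"x" * 10**6)}
    package = serialize_out_of_band(payload)

    # Size check mirroring the patched SharedObjectStore.should_be_shared():
    # the main frame plus all out-of-band buffers, compared against the threshold.
    data_size = len(package["s_data"]) + sum(len(buf) for buf in package["raw_buffers"])
    print(data_size > 10**5)  # True for this ~1 MB payload

    restored = deserialize_out_of_band(package)
    assert bytes(restored["values"]) == b"x" * 10**6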