diff --git a/unidist/core/backends/mpi/core/controller/api.py b/unidist/core/backends/mpi/core/controller/api.py index 655dcb80..8f13546a 100644 --- a/unidist/core/backends/mpi/core/controller/api.py +++ b/unidist/core/backends/mpi/core/controller/api.py @@ -19,6 +19,7 @@ ) from None from unidist.core.backends.mpi.core.serialization import serialize_complex_data +from unidist.core.backends.mpi.core.object_store import ObjectStore from unidist.core.backends.mpi.core.shared_object_store import SharedObjectStore from unidist.core.backends.mpi.core.local_object_store import LocalObjectStore from unidist.core.backends.mpi.core.controller.garbage_collector import ( @@ -28,8 +29,6 @@ request_worker_data, push_data, RoundRobin, - get_data, - contains_data, ) import unidist.core.backends.mpi.core.common as common import unidist.core.backends.mpi.core.communication as communication @@ -388,7 +387,6 @@ def put(data): data_id = local_store.generate_data_id(garbage_collector) serialized_data = serialize_complex_data(data) - # data is prepared for sending to another process, but is not saved to local storage if shared_store.is_allocated(): shared_store.put(data_id, serialized_data) else: @@ -413,17 +411,20 @@ def get(data_ids): object A Python object. 
""" + object_store = ObjectStore.get_instance() is_list = isinstance(data_ids, list) if not is_list: data_ids = [data_ids] - remote_data_ids = [data_id for data_id in data_ids if not contains_data(data_id)] + remote_data_ids = [ + data_id for data_id in data_ids if not object_store.contains_data(data_id) + ] # Remote data gets available in the local store inside `request_worker_data` if remote_data_ids: request_worker_data(remote_data_ids) logger.debug("GET {} ids".format(common.unwrapped_data_ids_list(data_ids))) - values = [get_data(data_id) for data_id in data_ids] + values = [object_store.get_data(data_id) for data_id in data_ids] # Initiate reference count based cleaup # if all the tasks were completed @@ -452,6 +453,7 @@ def wait(data_ids, num_returns=1): tuple List of data IDs that are ready and list of the remaining data IDs. """ + object_store = ObjectStore.get_instance() if not isinstance(data_ids, list): data_ids = [data_ids] # Since the controller should operate MpiDataID(s), @@ -463,7 +465,7 @@ def wait(data_ids, num_returns=1): ready = [] logger.debug("WAIT {} ids".format(common.unwrapped_data_ids_list(data_ids))) for data_id in not_ready.copy(): - if contains_data(data_id): + if object_store.contains_data(data_id): ready.append(data_id) not_ready.remove(data_id) pending_returns -= 1 diff --git a/unidist/core/backends/mpi/core/controller/common.py b/unidist/core/backends/mpi/core/controller/common.py index ef4980e2..6d6f3689 100644 --- a/unidist/core/backends/mpi/core/controller/common.py +++ b/unidist/core/backends/mpi/core/controller/common.py @@ -13,7 +13,6 @@ from unidist.core.backends.mpi.core.local_object_store import LocalObjectStore from unidist.core.backends.mpi.core.shared_object_store import SharedObjectStore from unidist.core.backends.mpi.core.serialization import ( - deserialize_complex_data, serialize_complex_data, ) @@ -141,10 +140,10 @@ def pull_data(comm, owner_rank=None): shared_store = SharedObjectStore.get_instance() data_id = 
info_package["id"] - if contains_data(data_id): + if local_store.contains(data_id): return { "id": data_id, - "data": get_data(data_id), + "data": local_store.get(data_id), } data = shared_store.get(data_id, owner_rank, info_package) @@ -414,61 +413,3 @@ def push_data(dest_rank, value, is_blocking_op=False): _push_data_owner(dest_rank, data_id) else: raise ValueError("Unknown DataID!") - - -def contains_data(data_id): - """ - Check if the data associated with `data_id` exists in the current process. - - Parameters - ---------- - data_id : unidist.core.backends.mpi.core.common.MpiDataID - An ID to data. - - Returns - ------- - bool - Return the status if an object exist in the current process. - """ - local_store = LocalObjectStore.get_instance() - shared_store = SharedObjectStore.get_instance() - return ( - local_store.contains(data_id) - or local_store.is_already_serialized(data_id) - or shared_store.contains(data_id) - ) - - -def get_data(data_id): - """ - Get data from any location in the current process. - - Parameters - ---------- - data_id : unidist.core.backends.mpi.core.common.MpiDataID - An ID to data. - - Returns - ------- - object - Return data associated with `data_id`. 
- """ - local_store = LocalObjectStore.get_instance() - shared_store = SharedObjectStore.get_instance() - - if local_store.contains(data_id): - return local_store.get(data_id) - - if local_store.is_already_serialized(data_id): - serialized_data = local_store.get_serialized_data(data_id) - value = deserialize_complex_data( - serialized_data["s_data"], - serialized_data["raw_buffers"], - serialized_data["buffer_count"], - ) - elif shared_store.contains(data_id): - value = shared_store.get(data_id) - else: - raise ValueError("The current data ID is not contained in the procces.") - local_store.put(data_id, value) - return value diff --git a/unidist/core/backends/mpi/core/local_object_store.py b/unidist/core/backends/mpi/core/local_object_store.py index 64428047..2fbaf404 100644 --- a/unidist/core/backends/mpi/core/local_object_store.py +++ b/unidist/core/backends/mpi/core/local_object_store.py @@ -277,8 +277,9 @@ def cache_serialized_data(self, data_id, data): data : object Serialized data to cache. """ - # Copying is necessary to avoid corruption of data obtained through out-of-band serialization, + # We make a copy to avoid data corruption obtained through out-of-band serialization, # and buffers are marked read-only to prevent them from being modified. + # `to_bytes()` call handles both points. 
data["raw_buffers"] = [buf.tobytes() for buf in data["raw_buffers"]] self._serialization_cache[data_id] = data self.maybe_update_data_id_map(data_id) diff --git a/unidist/core/backends/mpi/core/object_store.py b/unidist/core/backends/mpi/core/object_store.py new file mode 100644 index 00000000..be07494e --- /dev/null +++ b/unidist/core/backends/mpi/core/object_store.py @@ -0,0 +1,84 @@ +from unidist.core.backends.mpi.core.local_object_store import LocalObjectStore +from unidist.core.backends.mpi.core.serialization import deserialize_complex_data +from unidist.core.backends.mpi.core.shared_object_store import SharedObjectStore + + +class ObjectStore: + """ + Class that combines checking and receiving data from all stores in a current process. + + Notes + ----- + The store checks for both deserialized and serialized data. + """ + + __instance = None + + @classmethod + def get_instance(cls): + """ + Get instance of ``ObjectStore``. + + Returns + ------- + ObjectStore + """ + if cls.__instance is None: + cls.__instance = ObjectStore() + return cls.__instance + + def contains_data(self, data_id): + """ + Check if the data associated with `data_id` exists in the current process. + + Parameters + ---------- + data_id : unidist.core.backends.mpi.core.common.MpiDataID + An ID to data. + + Returns + ------- + bool + Return the status if an object exists in the current process. + """ + local_store = LocalObjectStore.get_instance() + shared_store = SharedObjectStore.get_instance() + return ( + local_store.contains(data_id) + or local_store.is_already_serialized(data_id) + or shared_store.contains(data_id) + ) + + def get_data(self, data_id): + """ + Get data from any location in the current process. + + Parameters + ---------- + data_id : unidist.core.backends.mpi.core.common.MpiDataID + An ID to data. + + Returns + ------- + object + Return data associated with `data_id`. 
+ """ + local_store = LocalObjectStore.get_instance() + shared_store = SharedObjectStore.get_instance() + + if local_store.contains(data_id): + return local_store.get(data_id) + + if local_store.is_already_serialized(data_id): + serialized_data = local_store.get_serialized_data(data_id) + value = deserialize_complex_data( + serialized_data["s_data"], + serialized_data["raw_buffers"], + serialized_data["buffer_count"], + ) + elif shared_store.contains(data_id): + value = shared_store.get(data_id) + else: + raise ValueError("The current data ID is not contained in the process.") + local_store.put(data_id, value) + return value diff --git a/unidist/core/backends/mpi/core/shared_object_store.py b/unidist/core/backends/mpi/core/shared_object_store.py index d898e900..85050081 100644 --- a/unidist/core/backends/mpi/core/shared_object_store.py +++ b/unidist/core/backends/mpi/core/shared_object_store.py @@ -757,9 +757,11 @@ def get(self, data_id, owner_rank=None, shared_info=None): An ID to data. owner_rank : int, default: None The rank that sent the data. - This value is used to synchronize data in shared memory between different hosts if the value is defined. + This value is used to synchronize data in shared memory between different hosts + if the value is not ``None``. shared_info : dict, default: None - The necessary information to properly deserialize data from shared memory. If `shared_info` is None, the data already exists in shared memory. + The necessary information to properly deserialize data from shared memory. + If `shared_info` is ``None``, the data already exists in shared memory in the current process. 
""" if shared_info is None: shared_info = self.get_shared_info(data_id) diff --git a/unidist/core/backends/mpi/core/worker/loop.py b/unidist/core/backends/mpi/core/worker/loop.py index 12d3a0d6..f90cdc24 100644 --- a/unidist/core/backends/mpi/core/worker/loop.py +++ b/unidist/core/backends/mpi/core/worker/loop.py @@ -16,11 +16,12 @@ import unidist.core.backends.mpi.core.common as common import unidist.core.backends.mpi.core.communication as communication +from unidist.core.backends.mpi.core.object_store import ObjectStore from unidist.core.backends.mpi.core.local_object_store import LocalObjectStore from unidist.core.backends.mpi.core.worker.request_store import RequestStore from unidist.core.backends.mpi.core.worker.task_store import TaskStore from unidist.core.backends.mpi.core.async_operations import AsyncOperations -from unidist.core.backends.mpi.core.controller.common import pull_data, get_data +from unidist.core.backends.mpi.core.controller.common import pull_data from unidist.core.backends.mpi.core.shared_object_store import SharedObjectStore # TODO: Find a way to move this after all imports @@ -85,6 +86,7 @@ async def worker_loop(): ``unidist.core.backends.mpi.core.common.Operations`` defines a set of supported operations. 
""" task_store = TaskStore.get_instance() + object_store = ObjectStore.get_instance() local_store = LocalObjectStore.get_instance() request_store = RequestStore.get_instance() async_operations = AsyncOperations.get_instance() @@ -185,7 +187,7 @@ async def worker_loop(): if not ready_to_shutdown_posted: # Prepare the data # Actor method here is a data id so we have to retrieve it from the storage - method_name = get_data(request["task"]) + method_name = object_store.get_data(request["task"]) handler = request["handler"] actor_method = getattr(actor_map[handler], method_name) request["task"] = actor_method diff --git a/unidist/core/backends/mpi/core/worker/request_store.py b/unidist/core/backends/mpi/core/worker/request_store.py index df8450cb..a1f5ad11 100644 --- a/unidist/core/backends/mpi/core/worker/request_store.py +++ b/unidist/core/backends/mpi/core/worker/request_store.py @@ -7,6 +7,7 @@ import unidist.core.backends.mpi.core.common as common import unidist.core.backends.mpi.core.communication as communication -from unidist.core.backends.mpi.core.controller.common import push_data, contains_data +from unidist.core.backends.mpi.core.controller.common import push_data +from unidist.core.backends.mpi.core.object_store import ObjectStore mpi_state = communication.MPIState.get_instance() @@ -212,7 +213,8 @@ def process_wait_request(self, data_id): ----- Only ROOT rank is supported for now, therefore no rank argument needed. """ - if contains_data(data_id): + object_store = ObjectStore.get_instance() + if object_store.contains_data(data_id): # Executor wait just for signal # We use a blocking send here because the receiver is waiting for the result. 
communication.mpi_send_object( diff --git a/unidist/core/backends/mpi/core/worker/task_store.py b/unidist/core/backends/mpi/core/worker/task_store.py index 7231459d..afb13bfb 100644 --- a/unidist/core/backends/mpi/core/worker/task_store.py +++ b/unidist/core/backends/mpi/core/worker/task_store.py @@ -10,8 +10,8 @@ from unidist.core.backends.common.data_id import is_data_id import unidist.core.backends.mpi.core.common as common import unidist.core.backends.mpi.core.communication as communication -from unidist.core.backends.mpi.core.controller.common import get_data, contains_data from unidist.core.backends.mpi.core.async_operations import AsyncOperations +from unidist.core.backends.mpi.core.object_store import ObjectStore from unidist.core.backends.mpi.core.local_object_store import LocalObjectStore from unidist.core.backends.mpi.core.shared_object_store import SharedObjectStore from unidist.core.backends.mpi.core.serialization import serialize_complex_data @@ -188,8 +188,9 @@ def unwrap_local_data_id(self, arg): """ if is_data_id(arg): local_store = LocalObjectStore.get_instance() - if contains_data(arg): - value = get_data(arg) + object_store = ObjectStore.get_instance() + if object_store.contains_data(arg): + value = object_store.get_data(arg) # Data is already local or was pushed from master return value, False elif local_store.contains_data_owner(arg): @@ -418,12 +419,13 @@ def process_task_request(self, request): dict or None Same request if the task couldn`t be executed, otherwise ``None``. """ + object_store = ObjectStore.get_instance() # Parse request task = request["task"] # Remote function here is a data id so we have to retrieve it from the storage, # whereas actor method is already materialized in the worker loop. if is_data_id(task): - task = get_data(task) + task = object_store.get_data(task) args = request["args"] kwargs = request["kwargs"] output_ids = request["output"]