FIX-#407: Make data put into MPI object store immutable #409

Merged 6 commits on Dec 13, 2023
Changes from all commits
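
This PR makes the controller serialize data at `put` time and keep only immutable copies, so mutating the original object afterwards cannot leak into the object store. A minimal sketch of the guaranteed behavior, assuming the top-level `unidist` API (`init`/`put`/`get`) with the MPI backend; this snippet is illustrative and not taken from the PR's test suite:

import numpy as np
import unidist

unidist.init()

data = np.zeros(4)
data_id = unidist.put(data)    # the data is serialized and copied at put time

data[0] = 1.0                  # mutating the original object afterwards...

result = unidist.get(data_id)
assert result[0] == 0.0        # ...does not change what the store returns
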
14 changes: 6 additions & 8 deletions unidist/core/backends/mpi/core/controller/api.py
@@ -19,6 +19,7 @@
) from None

from unidist.core.backends.mpi.core.serialization import serialize_complex_data
from unidist.core.backends.mpi.core.object_store import ObjectStore
from unidist.core.backends.mpi.core.shared_object_store import SharedObjectStore
from unidist.core.backends.mpi.core.local_object_store import LocalObjectStore
from unidist.core.backends.mpi.core.controller.garbage_collector import (
@@ -386,7 +387,6 @@ def put(data):

data_id = local_store.generate_data_id(garbage_collector)
serialized_data = serialize_complex_data(data)
local_store.put(data_id, data)
if shared_store.is_allocated():
shared_store.put(data_id, serialized_data)
else:
@@ -411,21 +411,20 @@ def get(data_ids):
object
A Python object.
"""
local_store = LocalObjectStore.get_instance()

object_store = ObjectStore.get_instance()
is_list = isinstance(data_ids, list)
if not is_list:
data_ids = [data_ids]
remote_data_ids = [
data_id for data_id in data_ids if not local_store.contains(data_id)
data_id for data_id in data_ids if not object_store.contains(data_id)
]
# Remote data becomes available in the local store inside `request_worker_data`
if remote_data_ids:
request_worker_data(remote_data_ids)

logger.debug("GET {} ids".format(common.unwrapped_data_ids_list(data_ids)))

values = [local_store.get(data_id) for data_id in data_ids]
values = [object_store.get(data_id) for data_id in data_ids]

# Initiate reference count based cleanup
# if all the tasks were completed
@@ -454,6 +453,7 @@ def wait(data_ids, num_returns=1):
tuple
List of data IDs that are ready and list of the remaining data IDs.
"""
object_store = ObjectStore.get_instance()
if not isinstance(data_ids, list):
data_ids = [data_ids]
# Since the controller should operate MpiDataID(s),
@@ -463,11 +463,9 @@
not_ready = data_ids
pending_returns = num_returns
ready = []
local_store = LocalObjectStore.get_instance()

logger.debug("WAIT {} ids".format(common.unwrapped_data_ids_list(data_ids)))
for data_id in not_ready.copy():
if local_store.contains(data_id):
if object_store.contains(data_id):
ready.append(data_id)
not_ready.remove(data_id)
pending_returns -= 1
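
For reference, the controller-level `get` and `wait` shown above mirror the user-facing API. A small usage sketch, assuming the top-level `unidist` wrappers around these controller functions (not part of this diff):

import unidist

unidist.init()

ids = [unidist.put(i) for i in range(4)]

# `wait` returns a tuple: the data IDs that are ready and the remaining ones.
ready, not_ready = unidist.wait(ids, num_returns=2)

# `get` accepts either a single data ID or a list of them.
values = unidist.get(ready)    # list in, list of objects out
first = unidist.get(ids[0])    # single data ID in, a single object out
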
29 changes: 14 additions & 15 deletions unidist/core/backends/mpi/core/controller/common.py
@@ -12,7 +12,9 @@
from unidist.core.backends.mpi.core.async_operations import AsyncOperations
from unidist.core.backends.mpi.core.local_object_store import LocalObjectStore
from unidist.core.backends.mpi.core.shared_object_store import SharedObjectStore
from unidist.core.backends.mpi.core.serialization import serialize_complex_data
from unidist.core.backends.mpi.core.serialization import (
serialize_complex_data,
)


logger = common.get_logger("common", "common.log")
@@ -394,22 +396,19 @@ def push_data(dest_rank, value, is_blocking_op=False):
data_id = value
if shared_store.contains(data_id):
_push_shared_data(dest_rank, data_id, is_blocking_op)
elif local_store.is_already_serialized(data_id):
_push_local_data(dest_rank, data_id, is_blocking_op, is_serialized=True)
elif local_store.contains(data_id):
if local_store.is_already_serialized(data_id):
_push_local_data(dest_rank, data_id, is_blocking_op, is_serialized=True)
data = local_store.get(data_id)
serialized_data = serialize_complex_data(data)
if shared_store.is_allocated() and shared_store.should_be_shared(
serialized_data
):
shared_store.put(data_id, serialized_data)
_push_shared_data(dest_rank, data_id, is_blocking_op)
else:
data = local_store.get(data_id)
serialized_data = serialize_complex_data(data)
if shared_store.is_allocated() and shared_store.should_be_shared(
serialized_data
):
shared_store.put(data_id, serialized_data)
_push_shared_data(dest_rank, data_id, is_blocking_op)
else:
local_store.cache_serialized_data(data_id, serialized_data)
_push_local_data(
dest_rank, data_id, is_blocking_op, is_serialized=True
)
local_store.cache_serialized_data(data_id, serialized_data)
_push_local_data(dest_rank, data_id, is_blocking_op, is_serialized=True)
elif local_store.contains_data_owner(data_id):
_push_data_owner(dest_rank, data_id)
else:
4 changes: 4 additions & 0 deletions unidist/core/backends/mpi/core/local_object_store.py
@@ -277,6 +277,10 @@ def cache_serialized_data(self, data_id, data):
data : object
Serialized data to cache.
"""
# We copy the buffers to avoid corruption of data obtained through out-of-band serialization,
# and the copies are read-only so they cannot be modified.
# The `tobytes()` call handles both points.
data["raw_buffers"] = [buf.tobytes() for buf in data["raw_buffers"]]
self._serialization_cache[data_id] = data
self.maybe_update_data_id_map(data_id)
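
The comment above relies on `bytes` copies being detached from the original memory. A standalone illustration in plain Python (not unidist code; `source`, `view`, and `frozen` are hypothetical names) of why `tobytes()` both copies and freezes a buffer:

source = bytearray(b"hello")
view = memoryview(source)      # zero-copy view: shares memory with `source`
frozen = view.tobytes()        # immutable bytes object, a detached copy

source[0] = ord("H")           # mutate the original buffer

assert bytes(view) == b"Hello" # the view reflects the mutation
assert frozen == b"hello"      # the copy is unaffected and cannot be modified
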

90 changes: 90 additions & 0 deletions unidist/core/backends/mpi/core/object_store.py
@@ -0,0 +1,90 @@
# Copyright (C) 2021-2023 Modin authors
#
# SPDX-License-Identifier: Apache-2.0

"""`ObjectStore` functionality."""

from unidist.core.backends.mpi.core.local_object_store import LocalObjectStore
from unidist.core.backends.mpi.core.serialization import deserialize_complex_data
from unidist.core.backends.mpi.core.shared_object_store import SharedObjectStore


class ObjectStore:
"""
Class that combines checking and retrieving data from the shared and local stores in the current process.

Notes
-----
The store checks for both deserialized and serialized data.
"""

__instance = None

@classmethod
def get_instance(cls):
"""
Get instance of ``ObjectStore``.

Returns
-------
ObjectStore
"""
if cls.__instance is None:
cls.__instance = ObjectStore()
return cls.__instance

def contains(self, data_id):
"""
Check if the data associated with `data_id` exists in the current process.

Parameters
----------
data_id : unidist.core.backends.mpi.core.common.MpiDataID
An ID to data.

Returns
-------
bool
``True`` if the data exists in the current process, ``False`` otherwise.
"""
local_store = LocalObjectStore.get_instance()
shared_store = SharedObjectStore.get_instance()
return (
local_store.contains(data_id)
or local_store.is_already_serialized(data_id)
or shared_store.contains(data_id)
)

def get(self, data_id):
"""
Get data from any location in the current process.

Parameters
----------
data_id : unidist.core.backends.mpi.core.common.MpiDataID
An ID to data.

Returns
-------
object
Return data associated with `data_id`.
"""
local_store = LocalObjectStore.get_instance()
shared_store = SharedObjectStore.get_instance()

if local_store.contains(data_id):
return local_store.get(data_id)

if local_store.is_already_serialized(data_id):
serialized_data = local_store.get_serialized_data(data_id)
value = deserialize_complex_data(
serialized_data["s_data"],
serialized_data["raw_buffers"],
serialized_data["buffer_count"],
)
elif shared_store.contains(data_id):
value = shared_store.get(data_id)
else:
raise ValueError("The current data ID is not contained in the process.")
local_store.put(data_id, value)
return value
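
A sketch of how the new combined store is used by the controller and worker code in this PR; `data_id` is a placeholder ``MpiDataID`` and the lookup order follows the `get` method above:

from unidist.core.backends.mpi.core.object_store import ObjectStore

object_store = ObjectStore.get_instance()   # per-process singleton

if object_store.contains(data_id):
    # Checks the local store, the serialization cache, and shared memory;
    # a value found in serialized or shared form is deserialized and cached
    # in the local store before being returned.
    value = object_store.get(data_id)
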
98 changes: 55 additions & 43 deletions unidist/core/backends/mpi/core/shared_object_store.py
@@ -747,62 +747,74 @@ def put(self, data_id, serialized_data):
# put shared info
self._put_shared_info(data_id, shared_info)

def get(self, data_id, owner_rank, shared_info):
def get(self, data_id, owner_rank=None, shared_info=None):
"""
Get data from another worker using shared memory.

Parameters
----------
data_id : unidist.core.backends.mpi.core.common.MpiDataID
An ID to data.
owner_rank : int
owner_rank : int, default: None
The rank that sent the data.
shared_info : dict
This value is used to synchronize data in shared memory between different hosts
if the value is not ``None``.
shared_info : dict, default: None
The necessary information to properly deserialize data from shared memory.
If `shared_info` is ``None``, the data already exists in shared memory in the current process.
"""
mpi_state = communication.MPIState.get_instance()
s_data_len = shared_info["s_data_len"]
raw_buffers_len = shared_info["raw_buffers_len"]
service_index = shared_info["service_index"]
buffer_count = shared_info["buffer_count"]

# check data in shared memory
if not self._check_service_info(data_id, service_index):
# reserve shared memory
shared_data_len = s_data_len + sum([buf for buf in raw_buffers_len])
reservation_info = communication.send_reserve_operation(
mpi_state.global_comm, data_id, shared_data_len
)

service_index = reservation_info["service_index"]
# check if the worker should sync the shared buffer or it is being done by another worker
if reservation_info["is_first_request"]:
# synchronize shared buffer
self._sync_shared_memory_from_another_host(
mpi_state.global_comm,
data_id,
owner_rank,
reservation_info["first_index"],
reservation_info["last_index"],
service_index,
)
# put service info
self._put_service_info(
service_index, data_id, reservation_info["first_index"]
if shared_info is None:
shared_info = self.get_shared_info(data_id)
else:
mpi_state = communication.MPIState.get_instance()
s_data_len = shared_info["s_data_len"]
raw_buffers_len = shared_info["raw_buffers_len"]
service_index = shared_info["service_index"]
buffer_count = shared_info["buffer_count"]

# check data in shared memory
if not self._check_service_info(data_id, service_index):
# reserve shared memory
shared_data_len = s_data_len + sum([buf for buf in raw_buffers_len])
reservation_info = communication.send_reserve_operation(
mpi_state.global_comm, data_id, shared_data_len
)
else:
# wait while another worker synchronizes the shared buffer
while not self._check_service_info(data_id, service_index):
time.sleep(MpiBackoff.get())

# put shared info with updated data_id and service_index
shared_info = common.MetadataPackage.get_shared_info(
data_id, s_data_len, raw_buffers_len, buffer_count, service_index
)
self._put_shared_info(data_id, shared_info)
service_index = reservation_info["service_index"]
# check if the worker should sync the shared buffer or it is being done by another worker
if reservation_info["is_first_request"]:
# synchronize shared buffer
if owner_rank is None:
raise ValueError(
"The data is not in the host's shared memory and the data must be synchronized, "
+ "but the owner rank is not defined."
)

self._sync_shared_memory_from_another_host(
mpi_state.global_comm,
data_id,
owner_rank,
reservation_info["first_index"],
reservation_info["last_index"],
service_index,
)
# put service info
self._put_service_info(
service_index, data_id, reservation_info["first_index"]
)
else:
# wait while another worker synchronizes the shared buffer
while not self._check_service_info(data_id, service_index):
time.sleep(MpiBackoff.get())

# put shared info with updated data_id and service_index
shared_info = common.MetadataPackage.get_shared_info(
data_id, s_data_len, raw_buffers_len, buffer_count, service_index
)
self._put_shared_info(data_id, shared_info)

# increment ref
self._increment_ref_number(data_id, shared_info["service_index"])
# increment ref
self._increment_ref_number(data_id, shared_info["service_index"])

# read from shared buffer and deserialize
return self._read_from_shared_buffer(data_id, shared_info)
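
With the new defaults, `SharedObjectStore.get` can be called in two ways; a sketch based on the calls visible in this diff (`data_id`, `owner_rank`, and `shared_info` are placeholders):

shared_store = SharedObjectStore.get_instance()

# The data already exists in shared memory in the current process:
value = shared_store.get(data_id)

# The data lives on another host: pass the owner rank and the shared metadata
# so the shared buffer can be synchronized first.
value = shared_store.get(data_id, owner_rank=owner_rank, shared_info=shared_info)
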
4 changes: 3 additions & 1 deletion unidist/core/backends/mpi/core/worker/loop.py
@@ -16,6 +16,7 @@

import unidist.core.backends.mpi.core.common as common
import unidist.core.backends.mpi.core.communication as communication
from unidist.core.backends.mpi.core.object_store import ObjectStore
from unidist.core.backends.mpi.core.local_object_store import LocalObjectStore
from unidist.core.backends.mpi.core.worker.request_store import RequestStore
from unidist.core.backends.mpi.core.worker.task_store import TaskStore
@@ -85,6 +86,7 @@ async def worker_loop():
``unidist.core.backends.mpi.core.common.Operations`` defines a set of supported operations.
"""
task_store = TaskStore.get_instance()
object_store = ObjectStore.get_instance()
local_store = LocalObjectStore.get_instance()
request_store = RequestStore.get_instance()
async_operations = AsyncOperations.get_instance()
@@ -185,7 +187,7 @@ async def worker_loop():
if not ready_to_shutdown_posted:
# Prepare the data
# The actor method here is a data ID, so we have to retrieve it from the storage
method_name = local_store.get(request["task"])
method_name = object_store.get(request["task"])
handler = request["handler"]
actor_method = getattr(actor_map[handler], method_name)
request["task"] = actor_method
9 changes: 5 additions & 4 deletions unidist/core/backends/mpi/core/worker/request_store.py
@@ -6,8 +6,8 @@

import unidist.core.backends.mpi.core.common as common
import unidist.core.backends.mpi.core.communication as communication
from unidist.core.backends.mpi.core.local_object_store import LocalObjectStore
from unidist.core.backends.mpi.core.controller.common import push_data
from unidist.core.backends.mpi.core.object_store import ObjectStore


mpi_state = communication.MPIState.get_instance()
@@ -213,7 +213,8 @@ def process_wait_request(self, data_id):
-----
Only ROOT rank is supported for now, therefore no rank argument is needed.
"""
if LocalObjectStore.get_instance().contains(data_id):
object_store = ObjectStore.get_instance()
if object_store.contains(data_id):
# Executor wait just for signal
# We use a blocking send here because the receiver is waiting for the result.
communication.mpi_send_object(
@@ -247,8 +248,8 @@ def process_get_request(self, source_rank, data_id, is_blocking_op=False):
-----
Request is asynchronous, no wait for the data sending.
"""
local_store = LocalObjectStore.get_instance()
if local_store.contains(data_id):
object_store = ObjectStore.get_instance()
if object_store.contains(data_id):
push_data(
source_rank,
data_id,