Skip to content

Commit

Permalink
FIX-#343: Get rid of the hacks in the MPI backend
Browse files Browse the repository at this point in the history
Signed-off-by: Igoshev, Iaroslav <[email protected]>
  • Loading branch information
YarShev committed Oct 2, 2023
1 parent c5c10ab commit 4110619
Show file tree
Hide file tree
Showing 10 changed files with 237 additions and 296 deletions.
32 changes: 30 additions & 2 deletions unidist/config/backends/mpi/envvars.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,39 @@ class MpiHosts(EnvironmentVariable, type=ExactStr):


class MpiPickleThreshold(EnvironmentVariable, type=int):
"""Minimum buffer size for serialization with pickle 5 protocol."""
"""
Minimum buffer size for serialization with pickle 5 protocol.
Notes
-----
If the shared object store is enabled, ``MpiSharedObjectStoreThreshold`` takes
precedence over this configuration value and the threshold gets overridden.
It is done intentionally to prevent multiple copies when putting an object
into the local object store or into the shared object store.
The data copy happens once, during in-band serialization, depending on the threshold.
In some cases output of a remote task can take up the memory of the task arguments.
If those arguments are placed in the shared object store, this location should not be overwritten
while output is being used, otherwise the output value may be corrupted.
"""

default = 1024**2 // 4 # 0.25 MiB
varname = "UNIDIST_MPI_PICKLE_THRESHOLD"

@classmethod
def get(cls) -> int:
"""
Get minimum buffer size for serialization with pickle 5 protocol.
Returns
-------
int
"""
if MpiSharedObjectStore.get():
mpi_pickle_threshold = MpiSharedObjectStoreThreshold.get()
else:
mpi_pickle_threshold = super().get()
return mpi_pickle_threshold


class MpiBackoff(EnvironmentVariable, type=float):
"""
Expand Down Expand Up @@ -65,5 +93,5 @@ class MpiSharedObjectStoreMemory(EnvironmentVariable, type=int):
class MpiSharedObjectStoreThreshold(EnvironmentVariable, type=int):
"""Minimum size of data to put into the shared object store."""

default = 1024**2 # 1 MiB
default = 10 ** 5 # 100 KB
varname = "UNIDIST_MPI_SHARED_OBJECT_STORE_THRESHOLD"
7 changes: 6 additions & 1 deletion unidist/core/backends/mpi/core/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,12 +123,16 @@ class MetadataPackage(ImmutableDict):
SHARED_DATA = 1

@classmethod
def get_local_info(cls, s_data_len, raw_buffers_len, buffer_count):
def get_local_info(cls, data_id, s_data_len, raw_buffers_len, buffer_count):
"""
Get information package for sending local data.
Parameters
----------
data_id : unidist.core.backends.common.data_id.DataID
An ID to data.
Can be ``None`` to indicate a fake `data_id` to get full metadata package.
It is usually used when submitting a task or an actor for not yet serialized data.
s_data_len : int
Main buffer length.
raw_buffers_len : list
Expand All @@ -145,6 +149,7 @@ def get_local_info(cls, s_data_len, raw_buffers_len, buffer_count):
return MetadataPackage(
{
"package_type": MetadataPackage.LOCAL_DATA,
"id": data_id,
"s_data_len": s_data_len,
"raw_buffers_len": raw_buffers_len,
"buffer_count": buffer_count,
Expand Down
121 changes: 58 additions & 63 deletions unidist/core/backends/mpi/core/communication.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@

from unidist.config import MpiBackoff
from unidist.core.backends.mpi.core.serialization import (
ComplexDataSerializer,
SimpleDataSerializer,
serialize_complex_data,
deserialize_complex_data,
)
import unidist.core.backends.mpi.core.common as common

Expand Down Expand Up @@ -525,7 +526,7 @@ def mpi_busy_wait_recv(comm, source_rank):
# --------------------------------- #


def _send_complex_data_impl(comm, s_data, raw_buffers, buffer_count, dest_rank):
def _send_complex_data_impl(comm, s_data, raw_buffers, dest_rank, info_package):
"""
Send already serialized complex data.
Expand All @@ -537,21 +538,16 @@ def _send_complex_data_impl(comm, s_data, raw_buffers, buffer_count, dest_rank):
Serialized data as bytearray.
raw_buffers : list
Pickle buffers list, out-of-band data collected with pickle 5 protocol.
buffer_count : list
List of the number of buffers for each object
to be serialized/deserialized using the pickle 5 protocol.
See details in :py:class:`~unidist.core.backends.mpi.core.serialization.ComplexDataSerializer`.
dest_rank : int
Target MPI process to transfer data.
info_package : unidist.core.backends.mpi.core.common.MetadataPackage
Required information to deserialize data on a receiver side.
Notes
-----
The special tags are used for this communication, namely,
``common.MPITag.OBJECT`` and ``common.MPITag.BUFFER``.
"""
info_package = common.MetadataPackage.get_local_info(
len(s_data), [len(sbuf) for sbuf in raw_buffers], buffer_count
)
# wrap to dict for sending and correct deserialization of the object by the recipient
comm.send(dict(info_package), dest=dest_rank, tag=common.MPITag.OBJECT)
with pkl5._bigmpi as bigmpi:
Expand All @@ -577,44 +573,40 @@ def send_complex_data(comm, data, dest_rank, is_serialized=False):
Returns
-------
dict or None
dict
Serialized data for caching purpose.
Notes
-----
* ``None`` is returned if `data` is already serialized,
otherwise ``dict`` containing data serialized in this function.
* This blocking send is used when we have to wait for completion of the communication,
which is necessary for the pipeline to continue, or when the receiver is waiting for a result.
Otherwise, use non-blocking ``isend_complex_data``.
"""
serialized_data = None
if is_serialized:
s_data = data["s_data"]
raw_buffers = data["raw_buffers"]
buffer_count = data["buffer_count"]
# pop `data_id` out of the dict because it will be sent as part of the metadata package
data_id = data.pop("id")
serialized_data = data
else:
serializer = ComplexDataSerializer()
# Main job
s_data = serializer.serialize(data)
# Retrieve the metadata
raw_buffers = serializer.buffers
buffer_count = serializer.buffer_count

serialized_data = {
"s_data": s_data,
"raw_buffers": raw_buffers,
"buffer_count": buffer_count,
}
data_id = data["id"]
serialized_data = serialize_complex_data(data)
s_data = serialized_data["s_data"]
raw_buffers = serialized_data["raw_buffers"]
buffer_count = serialized_data["buffer_count"]

info_package = common.MetadataPackage.get_local_info(
data_id, len(s_data), [len(sbuf) for sbuf in raw_buffers], buffer_count
)
# MPI communication
_send_complex_data_impl(comm, s_data, raw_buffers, buffer_count, dest_rank)
_send_complex_data_impl(comm, s_data, raw_buffers, dest_rank, info_package)

# For caching purpose
return serialized_data


def _isend_complex_data_impl(comm, s_data, raw_buffers, buffer_count, dest_rank):
def _isend_complex_data_impl(comm, s_data, raw_buffers, dest_rank, info_package):
"""
Send serialized complex data.
Expand All @@ -628,12 +620,10 @@ def _isend_complex_data_impl(comm, s_data, raw_buffers, buffer_count, dest_rank)
Serialized msgpack data.
raw_buffers : list
A list of pickle buffers.
buffer_count : list
List of the number of buffers for each object
to be serialized/deserialized using the pickle 5 protocol.
See details in :py:class:`~unidist.core.backends.mpi.core.serialization.ComplexDataSerializer`.
dest_rank : int
Target MPI process to transfer data.
info_package : unidist.core.backends.mpi.core.common.MetadataPackage
Required information to deserialize data on a receiver side.
Returns
-------
Expand All @@ -646,13 +636,9 @@ def _isend_complex_data_impl(comm, s_data, raw_buffers, buffer_count, dest_rank)
``common.MPITag.OBJECT`` and ``common.MPITag.BUFFER``.
"""
handlers = []
info_package = common.MetadataPackage.get_local_info(
len(s_data), [len(sbuf) for sbuf in raw_buffers], buffer_count
)
# wrap to dict for sending and correct deserialization of the object by the recipient
h1 = comm.isend(dict(info_package), dest=dest_rank, tag=common.MPITag.OBJECT)
handlers.append((h1, None))

with pkl5._bigmpi as bigmpi:
h2 = comm.Isend(bigmpi(s_data), dest=dest_rank, tag=common.MPITag.BUFFER)
handlers.append((h2, s_data))
Expand All @@ -663,7 +649,7 @@ def _isend_complex_data_impl(comm, s_data, raw_buffers, buffer_count, dest_rank)
return handlers


def isend_complex_data(comm, data, dest_rank):
def isend_complex_data(comm, data, dest_rank, is_serialized=False):
"""
Send the data that consists of different user provided complex types, lambdas and buffers in a non-blocking way.
Expand All @@ -677,6 +663,8 @@ def isend_complex_data(comm, data, dest_rank):
Data object to send.
dest_rank : int
Target MPI process to transfer data.
is_serialized : bool, default: False
Whether `data` is already serialized.
Returns
-------
Expand All @@ -694,16 +682,28 @@ def isend_complex_data(comm, data, dest_rank):
* The special tags are used for this communication, namely,
``common.MPITag.OBJECT`` and ``common.MPITag.BUFFER``.
"""
serializer = ComplexDataSerializer()
# Main job
s_data = serializer.serialize(data)
# Retrieve the metadata
raw_buffers = serializer.buffers
buffer_count = serializer.buffer_count
if is_serialized:
s_data = data["s_data"]
raw_buffers = data["raw_buffers"]
buffer_count = data["buffer_count"]
# pop `data_id` out of the dict because it will be sent as part of the metadata package
data_id = data.pop("id")
else:
serialized_data = serialize_complex_data(data)
s_data = serialized_data["s_data"]
raw_buffers = serialized_data["raw_buffers"]
buffer_count = serialized_data["buffer_count"]
# Set fake `data_id` to get full metadata package below.
# This branch is usually used when submitting a task or an actor
# for not yet serialized data.
data_id = None

# Send message pack bytestring
info_package = common.MetadataPackage.get_local_info(
data_id, len(s_data), [len(sbuf) for sbuf in raw_buffers], buffer_count
)
# MPI communication
handlers = _isend_complex_data_impl(
comm, s_data, raw_buffers, buffer_count, dest_rank
comm, s_data, raw_buffers, dest_rank, info_package
)

return handlers, s_data, raw_buffers, buffer_count
Expand Down Expand Up @@ -742,11 +742,7 @@ def recv_complex_data(comm, source_rank, info_package):
for rbuf in raw_buffers:
comm.Recv(bigmpi(rbuf), source=source_rank, tag=common.MPITag.BUFFER)

# Set the necessary metadata for unpacking
deserializer = ComplexDataSerializer(raw_buffers, buffer_count)

# Start unpacking
return deserializer.deserialize(msgpack_buffer)
return deserialize_complex_data(msgpack_buffer, raw_buffers, buffer_count)


# ---------- #
Expand Down Expand Up @@ -845,14 +841,11 @@ def isend_complex_operation(
Returns
-------
dict and dict or dict and None
Async handlers and serialization data for caching purpose.
list and dict
Async handlers list and serialization data dict for caching purpose.
Notes
-----
* Function always returns a ``dict`` containing async handlers to the sent MPI operations.
In addition, ``None`` is returned if `operation_data` is already serialized,
otherwise ``dict`` containing data serialized in this function.
* The special tags are used for this communication, namely,
``common.MPITag.OPERATION``, ``common.MPITag.OBJECT`` and ``common.MPITag.BUFFER``.
"""
Expand All @@ -867,24 +860,26 @@ def isend_complex_operation(
s_data = operation_data["s_data"]
raw_buffers = operation_data["raw_buffers"]
buffer_count = operation_data["buffer_count"]

# pop `data_id` out of the dict because it will be sent as part of the metadata package
data_id = operation_data.pop("id")
info_package = common.MetadataPackage.get_local_info(
data_id, len(s_data), [len(sbuf) for sbuf in raw_buffers], buffer_count
)
h2_list = _isend_complex_data_impl(
comm, s_data, raw_buffers, buffer_count, dest_rank
comm, s_data, raw_buffers, dest_rank, info_package
)
handlers.extend(h2_list)

return handlers, None
else:
# Serialize and send the data
h2_list, s_data, raw_buffers, buffer_count = isend_complex_data(
comm, operation_data, dest_rank
)
handlers.extend(h2_list)
return handlers, {
"s_data": s_data,
"raw_buffers": raw_buffers,
"buffer_count": buffer_count,
}
return handlers, {
"s_data": s_data,
"raw_buffers": raw_buffers,
"buffer_count": buffer_count,
}


def isend_serialized_operation(comm, operation_type, operation_data, dest_rank):
Expand Down
2 changes: 2 additions & 0 deletions unidist/core/backends/mpi/core/controller/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def __call__(self, *args, num_returns=1, **kwargs):
operation_type,
operation_data,
self._actor._owner_rank,
is_serialized=False,
)
async_operations.extend(h_list)
return output_id
Expand Down Expand Up @@ -126,6 +127,7 @@ def __init__(self, cls, *args, owner_rank=None, handler_id=None, **kwargs):
operation_type,
operation_data,
self._owner_rank,
is_serialized=False,
)
async_operations.extend(h_list)

Expand Down
12 changes: 9 additions & 3 deletions unidist/core/backends/mpi/core/controller/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"Missing dependency 'mpi4py'. Use pip or conda to install it."
) from None

from unidist.core.backends.mpi.core.serialization import serialize_complex_data
from unidist.core.backends.mpi.core.shared_object_store import SharedObjectStore
from unidist.core.backends.mpi.core.local_object_store import LocalObjectStore
from unidist.core.backends.mpi.core.controller.garbage_collector import (
Expand Down Expand Up @@ -369,10 +370,13 @@ def put(data):
"""
local_store = LocalObjectStore.get_instance()
shared_store = SharedObjectStore.get_instance()

data_id = local_store.generate_data_id(garbage_collector)
serialized_data = serialize_complex_data(data)
local_store.cache_serialized_data(data_id, serialized_data)
local_store.put(data_id, data)
if shared_store.is_allocated():
shared_store.put(data_id, data)
shared_store.put(data_id, serialized_data)

logger.debug("PUT {} id".format(data_id._id))

Expand Down Expand Up @@ -545,13 +549,14 @@ def submit(task, *args, num_returns=1, **kwargs):
unwrapped_args = [common.unwrap_data_ids(arg) for arg in args]
unwrapped_kwargs = {k: common.unwrap_data_ids(v) for k, v in kwargs.items()}

push_data(dest_rank, common.master_data_ids_to_base(task))
task_base_id = common.master_data_ids_to_base(task)
push_data(dest_rank, task_base_id)
push_data(dest_rank, unwrapped_args)
push_data(dest_rank, unwrapped_kwargs)

operation_type = common.Operation.EXECUTE
operation_data = {
"task": task,
"task": task_base_id,
"args": unwrapped_args,
"kwargs": unwrapped_kwargs,
"output": common.master_data_ids_to_base(output_ids),
Expand All @@ -562,6 +567,7 @@ def submit(task, *args, num_returns=1, **kwargs):
operation_type,
operation_data,
dest_rank,
is_serialized=False,
)
async_operations.extend(h_list)

Expand Down
Loading

0 comments on commit 4110619

Please sign in to comment.