
Commit

Fix review
Retribution98 committed Jun 26, 2023
1 parent 1c7d528 commit 5010735
Showing 16 changed files with 76 additions and 171 deletions.
4 changes: 2 additions & 2 deletions docs/developer/architecture.rst
@@ -62,12 +62,12 @@ details just pick module you are interested in.
| │ └─── :doc:`remote_function </flow/unidist/core/backends/dask/remote_function>`
| ├───mpi
| | ├───core
| │ │ ├─── :doc:`common </flow/unidist/core/backends/mpi/core/async_operations>`
| │ │ ├─── :doc:`async_operations </flow/unidist/core/backends/mpi/core/async_operations>`
| │ │ ├─── :doc:`common </flow/unidist/core/backends/mpi/core/common>`
| │ │ ├─── :doc:`communication </flow/unidist/core/backends/mpi/core/communication>`
| │ │ ├─── :doc:`controller </flow/unidist/core/backends/mpi/core/controller>`
| │ │ ├─── :doc:`monitor </flow/unidist/core/backends/mpi/core/monitor>`
| │ │ ├─── :doc:`monitor </flow/unidist/core/backends/mpi/core/object_store>`
| │ │ ├─── :doc:`object_store </flow/unidist/core/backends/mpi/core/object_store>`
| │ │ ├─── :doc:`serialization </flow/unidist/core/backends/mpi/core/serialization>`
| │ │ └─── :doc:`worker </flow/unidist/core/backends/mpi/core/worker>`
| │ ├─── :doc:`actor </flow/unidist/core/backends/mpi/actor>`
8 changes: 4 additions & 4 deletions docs/flow/unidist/core/backends/mpi/core/async_operations.rst
@@ -5,11 +5,11 @@

:orphan:

Async Operations Storage
===============
Async Operations
""""""""""""""""

:py:class:`~unidist.core.backends.mpi.core.async_operations.AsyncOperations` stores ``MPI_Isend`` asynchronous handlers and holds
a reference to the data being sent to prolong its lifetime until the operation is completed.
API
===

.. autoclass:: unidist.core.backends.mpi.core.async_operations.AsyncOperations
:members:
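As background for the description above, a minimal illustrative sketch of the pattern (class and method names are placeholders, not the unidist implementation): each non-blocking send handle is stored together with a reference to its buffer, so the buffer cannot be collected before `MPI_Isend` completes.

```python
# Illustrative sketch only; not unidist's AsyncOperations.
class AsyncSendRegistry:
    def __init__(self):
        # Pairs of (request_handle, data_reference) for sends still in flight.
        self._in_flight = []

    def extend(self, handles, data):
        # Keeping `data` here prolongs its lifetime beyond the caller's scope.
        self._in_flight.extend((h, data) for h in handles)

    def check(self):
        # Drop entries whose sends have completed; with mpi4py the test would
        # be `handle.Test()` on the MPI_Isend request object.
        self._in_flight = [(h, d) for h, d in self._in_flight if not h.Test()]

    def finish(self):
        # Block until everything in flight is done, then release the buffers.
        for h, _ in self._in_flight:
            h.Wait()
        self._in_flight.clear()
```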
6 changes: 3 additions & 3 deletions docs/flow/unidist/core/backends/mpi/core/object_store.rst
@@ -5,10 +5,10 @@

:orphan:

Local Object Storage
====================
Object Store
============

MPI :py:class:`~unidist.core.backends.mpi.core.object_store.ObjectStore` stores the data for master process in a local dict.
MPI :py:class:`~unidist.core.backends.mpi.core.object_store.ObjectStore` stores the data either in shared memory or in a local dict of each process, depending on the data size.

API
===
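A rough sketch of the size-based split described above; the stores and the helper are placeholders, and only the 1 MiB threshold default (see `envvars.py` below) comes from this commit.

```python
import sys

SHARED_MEMORY_THRESHOLD = 1024**2  # 1 MiB, the new default below
local_store = {}   # stand-in for the per-process local dict
shared_store = {}  # stand-in for a shared-memory backed store


def put(data_id, data):
    if sys.getsizeof(data) > SHARED_MEMORY_THRESHOLD:
        shared_store[data_id] = data  # large objects go to shared memory
    else:
        local_store[data_id] = data   # small objects stay process-local
```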
18 changes: 0 additions & 18 deletions test.py

This file was deleted.

4 changes: 2 additions & 2 deletions unidist/config/__init__.py
@@ -17,7 +17,7 @@
IsMpiSpawnWorkers,
MpiHosts,
MpiPickleThreshold,
MpiSharingThreshold,
MpiSharedMemoryThreshold,
)
from .parameter import ValueSource

@@ -36,5 +36,5 @@
"MpiHosts",
"ValueSource",
"MpiPickleThreshold",
"MpiSharingThreshold",
"MpiSharedMemoryThreshold",
]
9 changes: 7 additions & 2 deletions unidist/config/backends/mpi/__init__.py
@@ -8,7 +8,7 @@
IsMpiSpawnWorkers,
MpiHosts,
MpiPickleThreshold,
MpiSharingThreshold,
MpiSharedMemoryThreshold,
)

__all__ = ["IsMpiSpawnWorkers", "MpiHosts", "MpiPickleThreshold", "MpiSharingThreshold"]
__all__ = [
"IsMpiSpawnWorkers",
"MpiHosts",
"MpiPickleThreshold",
"MpiSharedMemoryThreshold",
]
8 changes: 4 additions & 4 deletions unidist/config/backends/mpi/envvars.py
@@ -27,8 +27,8 @@ class MpiPickleThreshold(EnvironmentVariable, type=int):
varname = "UNIDIST_MPI_PICKLE_THRESHOLD"


class MpiSharingThreshold(EnvironmentVariable, type=int):
"""Minimum data size for sending with shared memory"""
class MpiSharedMemoryThreshold(EnvironmentVariable, type=int):
"""Minimum size of data to put into the shared memory."""

default = 1024 # 1 MiB
varname = "UNIDIST_MPI_SHARING_THRESHOLD"
default = 1024**2 # 1 MiB
varname = "UNIDIST_MPI_SHARED_MEMORY_THRESHOLD"
17 changes: 5 additions & 12 deletions unidist/core/backends/mpi/core/common.py
@@ -54,7 +54,7 @@ class Operation:
### --- Monitor operations --- ###
TASK_DONE = 9
GET_TASK_COUNT = 10
RESERVE_SHARING_MEMORY = 13
RESERVE_SHARED_MEMORY = 13
### --- Common operations --- ###
CANCEL = 11

@@ -115,13 +115,8 @@ def __del__(self):
"""Track object deletion by garbage collector."""
# We check for existence of `_gc` attribute because
# it might be deleted during serialization via `__getstate__`
if (
hasattr(self, "_gc")
and self._gc is not None
and self.base_data_id is not None
):
base_id = self.base_data_id()
self._gc.collect(base_id)
if hasattr(self, "_gc") and self._gc is not None:
self._gc.collect(self.base_data_id())

def __getstate__(self):
"""Remove a reference to garbage collector for correct `pickle` serialization."""
@@ -138,12 +133,10 @@ def base_data_id(self):
unidist.core.backends.common.data_id.DataID
Base ``DataID`` class object without garbage collector reference.
"""
if DataID is not None:
return DataID(self._id)
return None
return DataID(self._id)


def get_logger(logger_name, file_name, activate=True):
def get_logger(logger_name, file_name, activate=False):
"""
Configure logger and get its instance.
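The `__del__`/`__getstate__` interplay the simplification above relies on, restated as a small stand-alone sketch (names are placeholders, not unidist's classes): the ID wrapper notifies the distributed garbage collector when the interpreter drops its last reference, and strips the collector reference before pickling.

```python
# Illustrative stand-in for the pattern in common.py; not the real classes.
class TrackedID:
    def __init__(self, raw_id, gc):
        self._id = raw_id
        self._gc = gc

    def base_data_id(self):
        # Plain copy of the ID without the collector reference, safe to pickle.
        return self._id

    def __getstate__(self):
        state = self.__dict__.copy()
        state.pop("_gc", None)  # drop the collector before serialization
        return state

    def __del__(self):
        # `_gc` may be absent on a deserialized copy (removed by __getstate__).
        if getattr(self, "_gc", None) is not None:
            self._gc.collect(self.base_data_id())
```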
14 changes: 1 addition & 13 deletions unidist/core/backends/mpi/core/communication.py
@@ -246,18 +246,6 @@ class MPIRank:
FIRST_WORKER = 2


def init_shared_memory(comm, size):
info = MPI.Info.Create()
# info.Set("alloc_shared_noncontig", "true")
win = MPI.Win.Allocate_shared(size, MPI.BYTE.size, comm=comm, info=info)
win_helper = MPI.Win.Allocate_shared(
1 if size > 0 else 0, MPI.INT.size, comm=comm, info=info
)

shared_buffer, itemsize = win.Shared_query(MPIRank.MONITOR)
return shared_buffer, itemsize, win_helper


def reserve_shared_memory(comm, data_id, data, is_serialized=False):
if is_serialized:
s_data = data["s_data"]
@@ -274,7 +262,7 @@


def _send_reserve_operation_impl(comm, data_id, s_data, raw_buffers):
operation_type = common.Operation.RESERVE_SHARING_MEMORY
operation_type = common.Operation.RESERVE_SHARED_MEMORY
mpi_state = MPIState.get_instance()

operation_data = {
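The removed `init_shared_memory` helper above was built on mpi4py shared-memory windows; purely as background (this is not the code path that replaces it), the underlying pattern looks roughly like this and would be run under `mpiexec -n 2`:

```python
from mpi4py import MPI

comm = MPI.COMM_WORLD
# Only rank 0 backs the window with memory; other ranks map the same region.
size_bytes = 1024 if comm.rank == 0 else 0
win = MPI.Win.Allocate_shared(size_bytes, 1, comm=comm)
buf, itemsize = win.Shared_query(0)  # every rank gets a view of rank 0's memory
mem = memoryview(buf)

if comm.rank == 0:
    mem[:5] = b"hello"
comm.Barrier()
print(comm.rank, bytes(mem[:5]))  # both ranks see the same bytes
win.Free()
```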
10 changes: 6 additions & 4 deletions unidist/core/backends/mpi/core/controller/api.py
@@ -37,7 +37,7 @@
MpiHosts,
ValueSource,
MpiPickleThreshold,
MpiSharingThreshold,
MpiSharedMemoryThreshold,
)


@@ -155,8 +155,10 @@ def init():
py_str += [f"cfg.CpuCount.put({CpuCount.get()})"]
if MpiPickleThreshold.get_value_source() != ValueSource.DEFAULT:
py_str += [f"cfg.MpiPickleThreshold.put({MpiPickleThreshold.get()})"]
if MpiSharingThreshold.get_value_source() != ValueSource.DEFAULT:
py_str += [f"cfg.MpiSharingThreshold.put({MpiSharingThreshold.get()})"]
if MpiSharedMemoryThreshold.get_value_source() != ValueSource.DEFAULT:
py_str += [
f"cfg.MpiSharingThreshold.put({MpiSharedMemoryThreshold.get()})"
]
py_str += ["unidist.init()"]
py_str = "; ".join(py_str)
args += [py_str]
@@ -327,7 +329,7 @@ def put(data):
object_store = ObjectStore.get_instance()
data_id = object_store.generate_data_id(garbage_collector)
object_store.put(data_id, data)
if sys.getsizeof(data) > MpiSharingThreshold.get():
if sys.getsizeof(data) > MpiSharedMemoryThreshold.get():
put_to_shared_memory(data_id.base_data_id())

logger.debug("PUT {} id".format(data_id._id))
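A hedged usage sketch of what `put` gains here: with the MPI backend, an object whose `sys.getsizeof` exceeds `MpiSharedMemoryThreshold` is expected to take the shared-memory path, while small objects stay in the local object store; the NumPy dependency and the exact path taken are assumptions made only for illustration.

```python
import numpy as np
import unidist

unidist.init()

big = np.zeros(10_000_000, dtype=np.uint8)  # ~10 MB, well above the 1 MiB default
big_ref = unidist.put(big)                  # expected to go through shared memory

small_ref = unidist.put(42)                 # stays in the local object store
assert unidist.get(small_ref) == 42
```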
38 changes: 10 additions & 28 deletions unidist/core/backends/mpi/core/controller/common.py
@@ -3,15 +3,13 @@
# SPDX-License-Identifier: Apache-2.0

"""Common functionality related to `controller`."""
import sys
import itertools

from unidist.core.backends.common.data_id import is_data_id
import unidist.core.backends.mpi.core.common as common
import unidist.core.backends.mpi.core.communication as communication
from unidist.core.backends.mpi.core.async_operations import AsyncOperations
from unidist.core.backends.mpi.core.object_store import ObjectStore
from unidist.config import MpiSharingThreshold
from unidist.core.backends.mpi.core.shared_store import SharedStore

logger = common.get_logger("common", "common.log")
@@ -119,36 +117,24 @@ def get_complex_data(comm, owner_rank):
is_contained_in_shared_memory = SharedStore.get_instance().contains(
info_package["id"]
)
mpi_state = communication.MPIState.get_instance()
log_name = f'shared_store{mpi_state.rank}'
sh_logger = common.get_logger(log_name, f'log_name.log')


if not is_contained_in_shared_memory:
sh_buf = shared_store.get_shared_buffer(info_package["id"])
# recv serialized data to shared memory
owner_monitor = communication.MPIState.get_instance().get_monitor_by_worker_rank(owner_rank)
owner_monitor = (
communication.MPIState.get_instance().get_monitor_by_worker_rank(
owner_rank
)
)
communication.send_simple_operation(
comm,
operation_type=common.Operation.REQUEST_SHARED_DATA,
operation_data=info_package,
dest_rank=owner_monitor,
)
communication.mpi_recv_shared_buffer(
comm,
sh_buf,
owner_monitor
)
sh_logger.debug('\n\nGET SHARED BUFFER')
sh_logger.debug(sh_buf[:100].tobytes())
communication.mpi_recv_shared_buffer(comm, sh_buf, owner_monitor)
shared_store.put_service_info(info_package["id"])
try:
data = shared_store.get(info_package["id"])
except Exception as ex:
sh_logger.debug(info_package["id"])
sh_logger.debug(info_package)
sh_logger.debug(is_contained_in_shared_memory)
sh_logger.exception(ex)
raise
data = shared_store.get(info_package["id"])
return {
"id": info_package["id"],
"data": data,
@@ -281,7 +267,6 @@ def _push_shared_data(dest_rank, data_id, is_blocking_op):
operation_type = common.Operation.PUT_SHARED_DATA
async_operations = AsyncOperations.get_instance()
info_package = shared_store.get_data_shared_info(data_id)
logger.debug(info_package)
if is_blocking_op:
communication.mpi_send_object(mpi_state.comm, info_package, dest_rank)
else:
@@ -352,7 +337,7 @@ def push_data(dest_rank, value, is_blocking_op=False):
_push_local_data(dest_rank, data_id, is_blocking_op, is_serialized=True)
elif object_store.contains(data_id):
data = object_store.get(data_id)
if sys.getsizeof(data) > MpiSharingThreshold.get():
if shared_store.is_should_be_shared(data):
put_to_shared_memory(data_id)
_push_shared_data(dest_rank, data_id, is_blocking_op)
else:
@@ -375,10 +360,7 @@ def put_to_shared_memory(data_id):
# )
mpi_state = communication.MPIState.get_instance()
reservation_data, serialized_data = communication.reserve_shared_memory(
mpi_state.comm,
data_id,
operation_data,
is_serialized=False
mpi_state.comm, data_id, operation_data, is_serialized=False
)

shared_store.put(data_id, reservation_data, serialized_data)
12 changes: 2 additions & 10 deletions unidist/core/backends/mpi/core/monitor.py
@@ -20,10 +20,6 @@
mpi4py.rc(recv_mprobe=False, initialize=False)
from mpi4py import MPI # noqa: E402

mpi_state = communication.MPIState.get_instance()
log_name = f'monitor_{mpi_state.rank}'
logger = common.get_logger(log_name, f'{log_name}.log')


class TaskCounter:
__instance = None
@@ -173,7 +169,6 @@ def monitor_loop():
data_id_tracker = DataIDTracker.get_instance()
shared_store = SharedStore.get_instance()


shared_index = 0

while True:
@@ -190,9 +185,7 @@
elif operation_type == common.Operation.WAIT:
# TODO: WAIT request can be received from several workers,
# but not only from master. Handle this case when requested.
operation_data = communication.recv_simple_data(
mpi_state.comm, source_rank
)
operation_data = communication.recv_simple_data(mpi_state.comm, source_rank)
awaited_data_ids = operation_data["data_ids"]
num_returns = operation_data["num_returns"]
wait_handler.add_wait_request(awaited_data_ids, num_returns)
@@ -204,7 +197,7 @@
task_counter.task_counter,
source_rank,
)
elif operation_type == common.Operation.RESERVE_SHARING_MEMORY:
elif operation_type == common.Operation.RESERVE_SHARED_MEMORY:
request = communication.recv_simple_data(mpi_state.comm, source_rank)
first_index = shared_index
last_index = first_index + request["size"]
@@ -217,7 +210,6 @@
if not shared_store.contains_shared_info(info_package["id"]):
shared_store.put_shared_info(info_package["id"], info_package)
sh_buf = shared_store.get_shared_buffer(info_package["id"])
logger.debug(sh_buf[:100].tobytes())
communication.mpi_send_shared_buffer(
mpi_state.comm,
sh_buf,
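The `RESERVE_SHARED_MEMORY` branch above boils down to a cursor-based allocator; a self-contained restatement (the field names and the reply shape are assumptions based on the visible hunk):

```python
def reserve(state, request):
    # The monitor owns a single cursor into the shared window and hands out
    # non-overlapping [first_index, last_index) slices to requesting workers.
    first_index = state["shared_index"]
    last_index = first_index + request["size"]
    state["shared_index"] = last_index
    return {"id": request["id"], "first_index": first_index, "last_index": last_index}


state = {"shared_index": 0}
print(reserve(state, {"id": "data_0", "size": 2048}))  # slice [0, 2048)
print(reserve(state, {"id": "data_1", "size": 512}))   # slice [2048, 2560)
```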
37 changes: 0 additions & 37 deletions unidist/core/backends/mpi/core/monitor/shared_store.py

This file was deleted.

2 changes: 1 addition & 1 deletion unidist/core/backends/mpi/core/object_store.py
@@ -1,4 +1,4 @@
# Copyright (C) 2021-2022 Modin authors
# Copyright (C) 2021-2023 Modin authors
#
# SPDX-License-Identifier: Apache-2.0
