Skip to content

Commit

Permalink
Add shared store
Browse files Browse the repository at this point in the history
  • Loading branch information
Ubuntu authored and Retribution98 committed Jun 26, 2023
1 parent 4f0c6c8 commit 1c7d528
Show file tree
Hide file tree
Showing 10 changed files with 416 additions and 173 deletions.
18 changes: 18 additions & 0 deletions test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import unidist
import numpy as np

# Start the unidist runtime before any remote function is declared or called.
unidist.init()


@unidist.remote
def f(arr):
    """Print the length of `arr` followed by the sum of its elements."""
    print(f"{len(arr)}: {arr.sum()}")


# Put the same large array into the object store 120 times and hand each
# resulting reference to a remote task, then block until the tasks finish.
refs = []
# `np.arange` builds the integer sequence directly instead of iterating
# over a Python `range` object element by element.
data = np.arange(100**3)
for _ in range(120):
    data_ref = unidist.put(data)
    refs.append(f.remote(data_ref))

unidist.wait(refs)
1 change: 1 addition & 0 deletions unidist/core/backends/mpi/core/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class Operation:
PUT_DATA = 3
PUT_OWNER = 4
PUT_SHARED_DATA = 12
REQUEST_SHARED_DATA = 14
WAIT = 5
ACTOR_CREATE = 6
ACTOR_EXECUTE = 7
Expand Down
11 changes: 10 additions & 1 deletion unidist/core/backends/mpi/core/communication.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,7 @@ def reserve_shared_memory(comm, data_id, data, is_serialized=False):

def _send_reserve_operation_impl(comm, data_id, s_data, raw_buffers):
operation_type = common.Operation.RESERVE_SHARING_MEMORY
mpi_state = MPIState.get_instance()

operation_data = {
"id": data_id,
Expand All @@ -286,7 +287,7 @@ def _send_reserve_operation_impl(comm, data_id, s_data, raw_buffers):
comm,
operation_type,
operation_data,
MPIRank.MONITOR,
mpi_state.get_monitor_by_worker_rank(MPIRank.ROOT),
)
firstIndex, lastIndex = mpi_busy_wait_recv(comm, MPIRank.MONITOR)
return {"firstIndex": firstIndex, "lastIndex": lastIndex}
Expand Down Expand Up @@ -423,6 +424,10 @@ def mpi_send_buffer(comm, buffer_size, buffer, dest_rank):
comm.Send([buffer, MPI.CHAR], dest=dest_rank)


def mpi_send_shared_buffer(comm, shared_buffer, dest_rank):
    """
    Send the contents of a buffer to another MPI rank.

    Parameters
    ----------
    comm : object
        MPI communicator object.
    shared_buffer : object
        Buffer whose contents are sent as ``MPI.CHAR`` elements.
        NOTE(review): presumably a view into the shared memory segment;
        confirm against the callers.
    dest_rank : int
        Rank of the process that receives the buffer.

    Returns
    -------
    object
        The value returned by ``comm.Send``.
    """
    message = [shared_buffer, MPI.CHAR]
    return comm.Send(message, dest=dest_rank)


def mpi_recv_buffer(comm, source_rank):
"""
Receive data buffer.
Expand All @@ -445,6 +450,10 @@ def mpi_recv_buffer(comm, source_rank):
return s_buffer


def mpi_recv_shared_buffer(comm, shared_buffer, source_rank):
    """
    Receive data from another MPI rank into the given buffer.

    The buffer is passed to ``comm.Recv`` as the receive target, so the
    incoming data lands in `shared_buffer` and nothing is returned.

    Parameters
    ----------
    comm : object
        MPI communicator object.
    shared_buffer : object
        Buffer that is filled with ``MPI.CHAR`` elements.
        NOTE(review): presumably a view into the shared memory segment;
        confirm against the callers.
    source_rank : int
        Rank of the process the data is received from.
    """
    target = [shared_buffer, MPI.CHAR]
    comm.Recv(target, source=source_rank)


def mpi_isend_buffer(comm, buffer_size, buffer, dest_rank):
"""
Send buffer object to another MPI rank in a non-blocking way.
Expand Down
33 changes: 5 additions & 28 deletions unidist/core/backends/mpi/core/controller/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
# SPDX-License-Identifier: Apache-2.0

"""High-level API of MPI backend."""
import os
import psutil
import sys
import atexit
import signal
import asyncio
from collections import defaultdict

from unidist.core.backends.mpi.core.shared_store import SharedStore

try:
import mpi4py
except ImportError:
Expand Down Expand Up @@ -212,31 +212,8 @@ def init():
if not is_mpi_initialized:
is_mpi_initialized = True

virtual_memory = psutil.virtual_memory().total
if mpi_state.rank == communication.MPIRank.MONITOR:
if sys.platform.startswith("linux"):
shm_fd = os.open("/dev/shm", os.O_RDONLY)
try:
shm_stats = os.fstatvfs(shm_fd)
system_memory = shm_stats.f_bsize * shm_stats.f_bavail
if system_memory / (virtual_memory / 2) < 0.99:
print(
f"The size of /dev/shm is too small ({system_memory} bytes). The required size "
+ f"at least half of RAM ({virtual_memory // 2} bytes). Please, delete files in /dev/shm or "
+ "increase size of /dev/shm with --shm-size in Docker."
)
finally:
os.close(shm_fd)
else:
system_memory = virtual_memory

# use only 95% because other memory need for local worker storages
shared_memory = int(system_memory * 0.95)
else:
shared_memory = 0
# experimentary for 07 server
# 4800374938 - 73728 * mpi_state.world_size
ObjectStore.get_instance().init_shared_memory(comm, shared_memory)
# Initialize shared memory
SharedStore.get_instance()

if mpi_state.is_root_process():
atexit.register(_termination_handler)
Expand Down Expand Up @@ -456,7 +433,7 @@ def wait(data_ids, num_returns=1):
operation_data,
root_monitor,
)
data = communication.recv_simple_operation(
data = communication.recv_simple_data(
mpi_state.comm,
root_monitor,
)
Expand Down
79 changes: 57 additions & 22 deletions unidist/core/backends/mpi/core/controller/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from unidist.core.backends.mpi.core.async_operations import AsyncOperations
from unidist.core.backends.mpi.core.object_store import ObjectStore
from unidist.config import MpiSharingThreshold
from unidist.core.backends.mpi.core.shared_store import SharedStore

logger = common.get_logger("common", "common.log")

Expand Down Expand Up @@ -110,13 +111,49 @@ def get_complex_data(comm, owner_rank):
info_package = communication.recv_simple_data(comm, owner_rank)
if info_package["package_type"] == communication.DataInfoType.SHARED_DATA:
object_store = ObjectStore.get_instance()
shared_store = SharedStore.get_instance()
info_package["id"] = object_store.get_unique_data_id(info_package["id"])
object_store.put_shared_info(info_package["id"], info_package)
shared_store.put_shared_info(info_package["id"], info_package)

# check data in shared memory
is_contained_in_shared_memory = SharedStore.get_instance().contains(
info_package["id"]
)
mpi_state = communication.MPIState.get_instance()
log_name = f'shared_store{mpi_state.rank}'
sh_logger = common.get_logger(log_name, f'log_name.log')

if not is_contained_in_shared_memory:
sh_buf = shared_store.get_shared_buffer(info_package["id"])
# recv serialized data to shared memory
owner_monitor = communication.MPIState.get_instance().get_monitor_by_worker_rank(owner_rank)
communication.send_simple_operation(
comm,
operation_type=common.Operation.REQUEST_SHARED_DATA,
operation_data=info_package,
dest_rank=owner_monitor,
)
communication.mpi_recv_shared_buffer(
comm,
sh_buf,
owner_monitor
)
sh_logger.debug('\n\nGET SHARED BUFFER')
sh_logger.debug(sh_buf[:100].tobytes())
shared_store.put_service_info(info_package["id"])
try:
data = shared_store.get(info_package["id"])
except Exception as ex:
sh_logger.debug(info_package["id"])
sh_logger.debug(info_package)
sh_logger.debug(is_contained_in_shared_memory)
sh_logger.exception(ex)
raise
return {
"id": info_package["id"],
"data": object_store.get_shared_data(info_package["id"]),
"data": data,
}
if info_package["package_type"] == communication.DataInfoType.LOCAL_DATA:
elif info_package["package_type"] == communication.DataInfoType.LOCAL_DATA:
return communication.recv_complex_buffers(comm, owner_rank, info_package)
else:
raise ValueError("Unexpected package of data info!")
Expand Down Expand Up @@ -238,11 +275,13 @@ def _push_shared_data(dest_rank, data_id, is_blocking_op):
value : unidist.core.backends.mpi.core.common.MasterDataID
An ID to data.
"""
object_store = ObjectStore.get_instance()
mpi_state = communication.MPIState.get_instance()
shared_store = SharedStore.get_instance()
mpi_state = communication.MPIState.get_instance()
operation_type = common.Operation.PUT_SHARED_DATA
info_package = object_store.get_data_shared_info(data_id)
async_operations = AsyncOperations.get_instance()
info_package = shared_store.get_data_shared_info(data_id)
logger.debug(info_package)
if is_blocking_op:
communication.mpi_send_object(mpi_state.comm, info_package, dest_rank)
else:
Expand Down Expand Up @@ -297,6 +336,7 @@ def push_data(dest_rank, value, is_blocking_op=False):
Arguments to be sent.
"""
object_store = ObjectStore.get_instance()
shared_store = SharedStore.get_instance()

if isinstance(value, (list, tuple)):
for v in value:
Expand All @@ -306,7 +346,7 @@ def push_data(dest_rank, value, is_blocking_op=False):
push_data(dest_rank, v)
elif is_data_id(value):
data_id = value
if object_store.contains_shared_memory(data_id):
if shared_store.contains_shared_info(data_id):
_push_shared_data(dest_rank, data_id, is_blocking_op)
elif object_store.is_already_serialized(data_id):
_push_local_data(dest_rank, data_id, is_blocking_op, is_serialized=True)
Expand All @@ -327,23 +367,18 @@ def push_data(dest_rank, value, is_blocking_op=False):

def put_to_shared_memory(data_id):
object_store = ObjectStore.get_instance()
shared_store = SharedStore.get_instance()

operation_data = object_store.get(data_id)
reservation_data, serialized_data = object_store.reserve_shared_memory(
operation_data
)
# mpi_state = communication.MPIState.get_instance()
# reservation_data, serialized_data = communication.reserve_shared_memory(
# mpi_state.comm,
# data_id,
# operation_data,
# is_serialized=False
# reservation_data, serialized_data = shared_store.reserve_shared_memory(
# operation_data
# )

s_data_len, buffer_lens, buffer_count, first_index = object_store.put_shared_memory(
data_id, reservation_data, serialized_data
)
sharing_info = communication.get_shared_info(
data_id, s_data_len, buffer_lens, buffer_count, first_index
mpi_state = communication.MPIState.get_instance()
reservation_data, serialized_data = communication.reserve_shared_memory(
mpi_state.comm,
data_id,
operation_data,
is_serialized=False
)
object_store.put_shared_info(data_id, sharing_info)

shared_store.put(data_id, reservation_data, serialized_data)
22 changes: 20 additions & 2 deletions unidist/core/backends/mpi/core/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,16 @@
import unidist.core.backends.mpi.core.common as common
import unidist.core.backends.mpi.core.communication as communication
from unidist.core.backends.mpi.core.async_operations import AsyncOperations
from unidist.core.backends.mpi.core.shared_store import SharedStore

# TODO: Find a way to move this after all imports
mpi4py.rc(recv_mprobe=False, initialize=False)
from mpi4py import MPI # noqa: E402

mpi_state = communication.MPIState.get_instance()
log_name = f'monitor_{mpi_state.rank}'
logger = common.get_logger(log_name, f'{log_name}.log')


class TaskCounter:
__instance = None
Expand Down Expand Up @@ -166,6 +171,8 @@ def monitor_loop():
async_operations = AsyncOperations.get_instance()
wait_handler = WaitHandler.get_instance()
data_id_tracker = DataIDTracker.get_instance()
shared_store = shared_store = SharedStore.get_instance()


shared_index = 0

Expand All @@ -175,15 +182,15 @@ def monitor_loop():
# Proceed the request
if operation_type == common.Operation.TASK_DONE:
task_counter.increment()
output_data_ids = communication.recv_simple_operation(
output_data_ids = communication.recv_simple_data(
mpi_state.comm, source_rank
)
data_id_tracker.add_to_completed(output_data_ids)
wait_handler.process_wait_requests()
elif operation_type == common.Operation.WAIT:
# TODO: WAIT request can be received from several workers,
# but not only from master. Handle this case when requested.
operation_data = communication.recv_simple_operation(
operation_data = communication.recv_simple_data(
mpi_state.comm, source_rank
)
awaited_data_ids = operation_data["data_ids"]
Expand All @@ -205,6 +212,17 @@ def monitor_loop():
communication.mpi_send_object(
mpi_state.comm, data=(first_index, last_index), dest_rank=source_rank
)
elif operation_type == common.Operation.REQUEST_SHARED_DATA:
info_package = communication.recv_simple_data(mpi_state.comm, source_rank)
if not shared_store.contains_shared_info(info_package["id"]):
shared_store.put_shared_info(info_package["id"], info_package)
sh_buf = shared_store.get_shared_buffer(info_package["id"])
logger.debug(sh_buf[:100].tobytes())
communication.mpi_send_shared_buffer(
mpi_state.comm,
sh_buf,
dest_rank=source_rank,
)
elif operation_type == common.Operation.CANCEL:
async_operations.finish()
if not MPI.Is_finalized():
Expand Down
Loading

0 comments on commit 1c7d528

Please sign in to comment.