Add cleanup shared memory
Retribution98 committed Jul 11, 2023
1 parent 3ded6a4 commit de0e548
Showing 10 changed files with 298 additions and 130 deletions.
65 changes: 31 additions & 34 deletions unidist/core/backends/mpi/core/communication.py
@@ -138,13 +138,13 @@ def __init__(self, comm):
cluster_info = self.comm.allgather((self.host, self.rank, host_rank))

self.topology = defaultdict(dict)
self.__host_rank_by_rank = defaultdict(None)
self.__host_by_rank = defaultdict(None)
for host, rank, host_rank in cluster_info:
self.topology[host][host_rank] = rank
self.__host_rank_by_rank[rank] = host_rank
self.__host_by_rank[rank] = host

self.monitor_processes = [self.topology[host][MPIRank.MONITOR] for host in self.topology]

self.workers = []
for host in self.topology:
self.workers.extend(
@@ -209,9 +209,9 @@ def is_monitor_process(self, rank=None):
"""
if rank is None:
rank = self.rank
return self.__host_rank_by_rank[rank] == MPIRank.MONITOR
return rank in self.monitor_processes

def get_monitor_by_worker_rank(self, rank):
def get_monitor_by_worker_rank(self, rank=None):
"""
Get the monitor process rank for the host that includes this rank
@@ -228,6 +228,8 @@ def get_monitor_by_worker_rank(self, rank):
if self.host_comm is None:
return MPIRank.MONITOR

if rank is None:
rank = self.rank
host = self.__host_by_rank[rank]
if host is None:
raise ValueError("Unknown rank of workers")
@@ -247,52 +249,46 @@ def reserve_shared_memory(comm, data_id, data, is_serialized=False):
if is_serialized:
s_data = data["s_data"]
raw_buffers = data["raw_buffers"]

reservation_data = _send_reserve_operation_impl(
comm, data_id, s_data, raw_buffers
)

data_size = len(s_data) + sum([len(buf) for buf in raw_buffers])
reservation_data = send_reserve_operation(comm, data_id, data_size)
return reservation_data, None
else:
reservation_data, serialized_data = send_reserve_operation(comm, data_id, data)
serializer = ComplexDataSerializer()
# Main job
s_data = serializer.serialize(data)
# Retrieve the metadata
raw_buffers = serializer.buffers
buffer_count = serializer.buffer_count

data_size = len(s_data) + sum([len(buf) for buf in raw_buffers])

reservation_data = send_reserve_operation(comm, data_id, data_size)
serialized_data = {
"s_data": s_data,
"raw_buffers": raw_buffers,
"buffer_count": buffer_count,
}

return reservation_data, serialized_data


def _send_reserve_operation_impl(comm, data_id, s_data, raw_buffers):
def send_reserve_operation(comm, data_id, data_size):
operation_type = common.Operation.RESERVE_SHARED_MEMORY
mpi_state = MPIState.get_instance()

operation_data = {
"id": data_id,
"size": len(s_data) + sum([len(buf) for buf in raw_buffers]),
"size": data_size,
}
# We use a blocking send here because we have to wait for
# completion of the communication, which is necessary for the pipeline to continue.
send_simple_operation(
comm,
operation_type,
operation_data,
mpi_state.get_monitor_by_worker_rank(MPIRank.ROOT),
mpi_state.get_monitor_by_worker_rank(),
)
firstIndex, lastIndex = mpi_busy_wait_recv(comm, MPIRank.MONITOR)
return {"firstIndex": firstIndex, "lastIndex": lastIndex}


def send_reserve_operation(comm, data_id, data):
serializer = ComplexDataSerializer()
# Main job
s_data = serializer.serialize(data)
# Retrieve the metadata
raw_buffers = serializer.buffers
buffer_count = serializer.buffer_count

reservation_data = _send_reserve_operation_impl(comm, data_id, s_data, raw_buffers)

return reservation_data, {
"s_data": s_data,
"raw_buffers": raw_buffers,
"buffer_count": buffer_count,
}
return mpi_busy_wait_recv(comm, mpi_state.get_monitor_by_worker_rank())


# ------------------ #
@@ -324,15 +320,16 @@ def get_data_info(s_data_len, raw_buffers_lens, buffer_count):


def get_shared_info(
data_id, s_data_len, raw_buffers_lens, buffer_count, first_shared_index
s_data_len, raw_buffers_lens, buffer_count, first_shared_index, last_shared_index, service_index
):
info_package = {}
info_package["package_type"] = DataInfoType.SHARED_DATA
info_package["id"] = data_id
info_package["s_data_len"] = s_data_len
info_package["raw_buffers_lens"] = raw_buffers_lens
info_package["buffer_count"] = buffer_count
info_package["first_shared_index"] = first_shared_index
info_package["last_shared_index"] = last_shared_index
info_package["service_index"] = service_index
return info_package


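The reworked reserve_shared_memory above no longer hands the serialized buffers to the reservation call; it only reports their total size, and send_reserve_operation does the blocking round-trip to the monitor. A minimal, MPI-free sketch of how that size is derived, assuming a simplified stand-in for ComplexDataSerializer (the real serializer also extracts out-of-band raw buffers):

import pickle

def serialize(data):
    # Stand-in for ComplexDataSerializer.serialize(): the real implementation
    # returns a pickled frame plus a list of zero-copy raw buffers.
    s_data = pickle.dumps(data)
    raw_buffers = []  # out-of-band buffers (e.g. array payloads) would go here
    return s_data, raw_buffers

def shared_memory_size(s_data, raw_buffers):
    # This is the "size" field sent to the monitor with RESERVE_SHARED_MEMORY:
    # the serialized frame plus every raw buffer.
    return len(s_data) + sum(len(buf) for buf in raw_buffers)

s_data, raw_buffers = serialize({"values": list(range(1000))})
print(shared_memory_size(s_data, raw_buffers))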
2 changes: 1 addition & 1 deletion unidist/core/backends/mpi/core/controller/api.py
@@ -226,7 +226,7 @@ def init():
signal.signal(signal.SIGINT, _termination_handler)
return
elif mpi_state.is_monitor_process():
from unidist.core.backends.mpi.core.monitor import monitor_loop
from unidist.core.backends.mpi.core.monitor.loop import monitor_loop

monitor_loop()
# If the user executes a program in SPMD mode,
28 changes: 17 additions & 11 deletions unidist/core/backends/mpi/core/controller/common.py
@@ -110,16 +110,20 @@ def get_complex_data(comm, owner_rank):
if info_package["package_type"] == communication.DataInfoType.SHARED_DATA:
object_store = ObjectStore.get_instance()
shared_store = SharedStore.get_instance()
info_package["id"] = object_store.get_unique_data_id(info_package["id"])
shared_store.put_shared_info(info_package["id"], info_package)
data_id = info_package.pop("id", None)
data_id = object_store.get_unique_data_id(data_id)

# check data in shared memory
is_contained_in_shared_memory = SharedStore.get_instance().contains(
info_package["id"]
)
if shared_store.check_serice_index(data_id, info_package["service_index"]):
shared_store.put_shared_info(data_id, info_package)
else:
shared_data_len = info_package["last_shared_index"] - info_package["first_shared_index"]
reservation_info = communication.send_reserve_operation(comm, data_id, shared_data_len)
info_package["first_shared_index"] = reservation_info["first_index"]
info_package["service_index"] = reservation_info["service_index"]
shared_store.put_shared_info(data_id, info_package)

if not is_contained_in_shared_memory:
sh_buf = shared_store.get_shared_buffer(info_package["id"])
sh_buf = shared_store.get_shared_buffer(data_id)
# recv serialized data to shared memory
owner_monitor = (
communication.MPIState.get_instance().get_monitor_by_worker_rank(
@@ -133,13 +137,14 @@
dest_rank=owner_monitor,
)
communication.mpi_recv_shared_buffer(comm, sh_buf, owner_monitor)
shared_store.put_service_info(info_package["id"])
data = shared_store.get(info_package["id"])
shared_store.put_service_info(info_package["service_index"], data_id, info_package["first_shared_index"])

data = shared_store.get(data_id)
return {
"id": info_package["id"],
"id": data_id,
"data": data,
}
if info_package["package_type"] == communication.DataInfoType.LOCAL_DATA:
elif info_package["package_type"] == communication.DataInfoType.LOCAL_DATA:
return communication.recv_complex_data(comm, owner_rank, info=info_package)
else:
raise ValueError("Unexpected package of data info!")
@@ -267,6 +272,7 @@ def _push_shared_data(dest_rank, data_id, is_blocking_op):
operation_type = common.Operation.PUT_SHARED_DATA
async_operations = AsyncOperations.get_instance()
info_package = shared_store.get_data_shared_info(data_id)
info_package["id"] = data_id
if is_blocking_op:
communication.mpi_send_object(mpi_state.comm, info_package, dest_rank)
else:
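The shared-data branch of get_complex_data keys its behavior on the service index: when the sender's index still matches the local shared store, the existing reservation is reused; otherwise a new range of the same length is reserved from the monitor and the info package is patched before the buffer is received. A rough, MPI-free model of that decision, where index_matches stands in for shared_store.check_serice_index and reserve is a hypothetical stub for communication.send_reserve_operation:

def refresh_shared_info(info_package, index_matches, reserve):
    # reserve() models the monitor round-trip and is assumed to return a dict
    # with "first_index" and "service_index" keys.
    if index_matches:
        return info_package
    length = info_package["last_shared_index"] - info_package["first_shared_index"]
    reservation = reserve(length)
    info_package["first_shared_index"] = reservation["first_index"]
    info_package["service_index"] = reservation["service_index"]
    return info_package

# Example: the sender's slot was recycled, so a new 64-byte range is reserved.
stale = {"first_shared_index": 0, "last_shared_index": 64, "service_index": 3}
fresh = refresh_shared_info(stale, False, lambda n: {"first_index": 128, "service_index": 7})
print(fresh["first_shared_index"], fresh["service_index"])  # 128 7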
19 changes: 10 additions & 9 deletions unidist/core/backends/mpi/core/controller/garbage_collector.py
@@ -62,15 +62,16 @@ def _send_cleanup_request(self, cleanup_list):
# Cache serialized list of data IDs
s_cleanup_list = SimpleDataSerializer().serialize_pickle(cleanup_list)
async_operations = AsyncOperations.get_instance()
for rank_id in mpi_state.workers:
if rank_id != mpi_state.rank:
h_list = communication.isend_serialized_operation(
mpi_state.comm,
common.Operation.CLEANUP,
s_cleanup_list,
rank_id,
)
async_operations.extend(h_list)
for host in mpi_state.topology:
for rank_id in mpi_state.topology[host]:
if not mpi_state.is_root_process(rank_id) and rank_id != mpi_state.rank:
h_list = communication.isend_serialized_operation(
mpi_state.comm,
common.Operation.CLEANUP,
s_cleanup_list,
rank_id,
)
async_operations.extend(h_list)

def increment_task_counter(self):
"""
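The cleanup broadcast now walks the full cluster topology instead of the worker list, so each host's monitor also receives the CLEANUP operation and can release the corresponding shared-memory ranges; only the root process and the sender itself are skipped. A small runnable model of that recipient selection (the topology layout used here is an assumption):

from collections import defaultdict

# host -> {host_rank: global rank}, mirroring MPIState.topology (assumed layout)
topology = defaultdict(dict)
topology["host1"] = {0: 0, 1: 1, 2: 2}
topology["host2"] = {0: 3, 1: 4, 2: 5}

ROOT_RANK = 0
my_rank = 2

def cleanup_recipients(topology, root_rank, own_rank):
    # Every rank in the cluster gets the cleanup list except the root process
    # and the sender itself; this includes each host's monitor process.
    for host in topology:
        for rank in topology[host].values():
            if rank != root_rank and rank != own_rank:
                yield rank

print(sorted(cleanup_recipients(topology, ROOT_RANK, my_rank)))  # [1, 3, 4, 5]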
5 changes: 5 additions & 0 deletions unidist/core/backends/mpi/core/monitor/__init__.py
@@ -0,0 +1,5 @@
# Copyright (C) 2021-2023 Modin authors
#
# SPDX-License-Identifier: Apache-2.0

"""MPI backend functionality related to `monitor` concept."""
unidist/core/backends/mpi/core/monitor/loop.py
@@ -13,6 +13,7 @@

import unidist.core.backends.mpi.core.common as common
import unidist.core.backends.mpi.core.communication as communication
from unidist.core.backends.mpi.core.monitor.shared_memory_manager import SharedMemoryMahager
from unidist.core.backends.mpi.core.shared_store import SharedStore

# TODO: Find a way to move this after all imports
@@ -165,14 +166,15 @@ def monitor_loop():
mpi_state = communication.MPIState.get_instance()
wait_handler = WaitHandler.get_instance()
data_id_tracker = DataIDTracker.get_instance()
shared_store = shared_store = SharedStore.get_instance()
shared_store = SharedStore.get_instance()

shared_index = 0
workers_ready_to_shutdown = []
shutdown_workers = False
# Once all workers excluding ``Root`` and ``Monitor`` ranks are ready to shutdown,
# ``Monitor` sends the shutdown signal to every worker, as well as notifies ``Root`` that
# it can exit the program.
shm_manager = SharedMemoryMahager(shared_store.shared_memory_size)

while True:
# Listen receive operation from any source
operation_type, source_rank = communication.mpi_recv_operation(mpi_state.comm)
@@ -199,11 +201,13 @@
)
elif operation_type == common.Operation.RESERVE_SHARED_MEMORY:
request = communication.mpi_recv_object(mpi_state.comm, source_rank)
first_index = shared_index
last_index = first_index + request["size"]
shared_index = last_index
if request["id"] in shm_manager.deleted_ids:
communication.mpi_send_object(
mpi_state.comm, data=ValueError("This data was already deleted."), dest_rank=source_rank
)
reservation_info = shm_manager.put(request["id"], request["size"])
communication.mpi_send_object(
mpi_state.comm, data=(first_index, last_index), dest_rank=source_rank
mpi_state.comm, data=reservation_info, dest_rank=source_rank
)
elif operation_type == common.Operation.REQUEST_SHARED_DATA:
info_package = communication.mpi_recv_object(mpi_state.comm, source_rank)
@@ -215,6 +219,11 @@
sh_buf,
dest_rank=source_rank,
)
elif operation_type == common.Operation.CLEANUP:
cleanup_list = communication.recv_serialized_data(
mpi_state.comm, source_rank
)
shm_manager.clear(cleanup_list)
elif operation_type == common.Operation.READY_TO_SHUTDOWN:
workers_ready_to_shutdown.append(source_rank)
shutdown_workers = (
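The SharedMemoryMahager class imported above is not shown in this diff, so the following is only a guess at the minimal interface the monitor loop relies on: put() returns reservation info carrying a service index, clear() consumes a CLEANUP list, and deleted_ids lets the loop reject reservation requests for data that was already collected. It is sketched as a simple bump allocator under those assumptions:

class SimpleSharedMemoryManager:
    def __init__(self, size):
        self.size = size
        self.next_index = 0
        self.service_index = 0
        self.reservations = {}   # data_id -> (first_index, last_index)
        self.deleted_ids = set()

    def put(self, data_id, length):
        # Reserve a contiguous range and bump the service index so consumers
        # can detect stale reservation info.
        first = self.next_index
        last = first + length
        if last > self.size:
            raise MemoryError("shared memory exhausted")
        self.next_index = last
        self.service_index += 1
        self.reservations[data_id] = (first, last)
        return {
            "first_index": first,
            "last_index": last,
            "service_index": self.service_index,
        }

    def clear(self, cleanup_list):
        # Forget reservations for collected data IDs; deleted_ids lets later
        # reservation requests for the same IDs be refused.
        for data_id in cleanup_list:
            self.reservations.pop(data_id, None)
            self.deleted_ids.add(data_id)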