Skip to content

Commit

Permalink
Add cleanup shared memory
Browse files Browse the repository at this point in the history
  • Loading branch information
Retribution98 committed Jul 6, 2023
1 parent 3ded6a4 commit 76bd87c
Show file tree
Hide file tree
Showing 6 changed files with 258 additions and 118 deletions.
58 changes: 28 additions & 30 deletions unidist/core/backends/mpi/core/communication.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def is_monitor_process(self, rank=None):
rank = self.rank
return self.__host_rank_by_rank[rank] == MPIRank.MONITOR

def get_monitor_by_worker_rank(self, rank):
def get_monitor_by_worker_rank(self, rank=None):
"""
Get the monitor process rank for the host that includes this rank
Expand All @@ -228,6 +228,8 @@ def get_monitor_by_worker_rank(self, rank):
if self.host_comm is None:
return MPIRank.MONITOR

if rank is None:
rank = self.rank
host = self.__host_by_rank[rank]
if host is None:
raise ValueError("Unknown rank of workers")
Expand All @@ -247,52 +249,46 @@ def reserve_shared_memory(comm, data_id, data, is_serialized=False):
if is_serialized:
s_data = data["s_data"]
raw_buffers = data["raw_buffers"]

reservation_data = _send_reserve_operation_impl(
comm, data_id, s_data, raw_buffers
)

data_size = len(s_data) + sum([len(buf) for buf in raw_buffers])
reservation_data = send_reserve_operation(comm, data_id, data_size)
return reservation_data, None
else:
reservation_data, serialized_data = send_reserve_operation(comm, data_id, data)
serializer = ComplexDataSerializer()
# Main job
s_data = serializer.serialize(data)
        # Retrieve the metadata
raw_buffers = serializer.buffers
buffer_count = serializer.buffer_count

data_size = len(s_data) + sum([len(buf) for buf in raw_buffers])

reservation_data = send_reserve_operation(comm, data_id, data_size)
serialized_data = {
"s_data": s_data,
"raw_buffers": raw_buffers,
"buffer_count": buffer_count,
}

return reservation_data, serialized_data


def _send_reserve_operation_impl(comm, data_id, s_data, raw_buffers):
def send_reserve_operation(comm, data_id, data_size):
operation_type = common.Operation.RESERVE_SHARED_MEMORY
mpi_state = MPIState.get_instance()

operation_data = {
"id": data_id,
"size": len(s_data) + sum([len(buf) for buf in raw_buffers]),
"size": data_size,
}
# We use a blocking send here because we have to wait for
# completion of the communication, which is necessary for the pipeline to continue.
send_simple_operation(
comm,
operation_type,
operation_data,
mpi_state.get_monitor_by_worker_rank(MPIRank.ROOT),
mpi_state.get_monitor_by_worker_rank(),
)
firstIndex, lastIndex = mpi_busy_wait_recv(comm, MPIRank.MONITOR)
return {"firstIndex": firstIndex, "lastIndex": lastIndex}


def send_reserve_operation(comm, data_id, data):
serializer = ComplexDataSerializer()
# Main job
s_data = serializer.serialize(data)
    # Retrieve the metadata
raw_buffers = serializer.buffers
buffer_count = serializer.buffer_count

reservation_data = _send_reserve_operation_impl(comm, data_id, s_data, raw_buffers)

return reservation_data, {
"s_data": s_data,
"raw_buffers": raw_buffers,
"buffer_count": buffer_count,
}
return mpi_busy_wait_recv(comm, mpi_state.get_monitor_by_worker_rank())


# ------------------ #
Expand Down Expand Up @@ -324,7 +320,7 @@ def get_data_info(s_data_len, raw_buffers_lens, buffer_count):


def get_shared_info(
data_id, s_data_len, raw_buffers_lens, buffer_count, first_shared_index
data_id, s_data_len, raw_buffers_lens, buffer_count, first_shared_index, last_shared_index, service_index
):
info_package = {}
info_package["package_type"] = DataInfoType.SHARED_DATA
Expand All @@ -333,6 +329,8 @@ def get_shared_info(
info_package["raw_buffers_lens"] = raw_buffers_lens
info_package["buffer_count"] = buffer_count
info_package["first_shared_index"] = first_shared_index
info_package["last_shared_index"] = last_shared_index
info_package["service_index"] = service_index
return info_package


Expand Down
18 changes: 11 additions & 7 deletions unidist/core/backends/mpi/core/controller/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,14 +111,17 @@ def get_complex_data(comm, owner_rank):
object_store = ObjectStore.get_instance()
shared_store = SharedStore.get_instance()
info_package["id"] = object_store.get_unique_data_id(info_package["id"])
shared_store.put_shared_info(info_package["id"], info_package)

# check data in shared memory
is_contained_in_shared_memory = SharedStore.get_instance().contains(
info_package["id"]
)
if shared_store.check_serice_index(info_package["id"], info_package["service_index"]):
shared_store.put_shared_info(info_package["id"], info_package)
else:
shared_data_len = info_package["last_shared_index"] - info_package["first_shared_index"]
reservation_info = communication.send_reserve_operation(comm, info_package["id"], shared_data_len)
info_package["first_shared_index"] = reservation_info["first_index"]
info_package["service_index"] = reservation_info["service_index"]
shared_store.put_shared_info(info_package["id"], info_package)

if not is_contained_in_shared_memory:
sh_buf = shared_store.get_shared_buffer(info_package["id"])
# recv serialized data to shared memory
owner_monitor = (
Expand All @@ -133,13 +136,14 @@ def get_complex_data(comm, owner_rank):
dest_rank=owner_monitor,
)
communication.mpi_recv_shared_buffer(comm, sh_buf, owner_monitor)
shared_store.put_service_info(info_package["id"])
shared_store.put_service_info(info_package["service_index"], info_package["id"], info_package["first_shared_index"])

data = shared_store.get(info_package["id"])
return {
"id": info_package["id"],
"data": data,
}
if info_package["package_type"] == communication.DataInfoType.LOCAL_DATA:
elif info_package["package_type"] == communication.DataInfoType.LOCAL_DATA:
return communication.recv_complex_data(comm, owner_rank, info=info_package)
else:
raise ValueError("Unexpected package of data info!")
Expand Down
19 changes: 10 additions & 9 deletions unidist/core/backends/mpi/core/controller/garbage_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,15 +62,16 @@ def _send_cleanup_request(self, cleanup_list):
# Cache serialized list of data IDs
s_cleanup_list = SimpleDataSerializer().serialize_pickle(cleanup_list)
async_operations = AsyncOperations.get_instance()
for rank_id in mpi_state.workers:
if rank_id != mpi_state.rank:
h_list = communication.isend_serialized_operation(
mpi_state.comm,
common.Operation.CLEANUP,
s_cleanup_list,
rank_id,
)
async_operations.extend(h_list)
for host in mpi_state.topology:
for rank_id in mpi_state.topology[host]:
if not mpi_state.is_root_process(rank_id) and rank_id != mpi_state.rank:
h_list = communication.isend_serialized_operation(
mpi_state.comm,
common.Operation.CLEANUP,
s_cleanup_list,
rank_id,
)
async_operations.extend(h_list)

def increment_task_counter(self):
"""
Expand Down
22 changes: 15 additions & 7 deletions unidist/core/backends/mpi/core/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

import unidist.core.backends.mpi.core.common as common
import unidist.core.backends.mpi.core.communication as communication
from unidist.core.backends.mpi.core.shared_store import SharedStore
from unidist.core.backends.mpi.core.shared_store import SharedMemoryMahager, SharedStore

# TODO: Find a way to move this after all imports
mpi4py.rc(recv_mprobe=False, initialize=False)
Expand Down Expand Up @@ -165,14 +165,15 @@ def monitor_loop():
mpi_state = communication.MPIState.get_instance()
wait_handler = WaitHandler.get_instance()
data_id_tracker = DataIDTracker.get_instance()
shared_store = shared_store = SharedStore.get_instance()
shared_store = SharedStore.get_instance()

shared_index = 0
workers_ready_to_shutdown = []
shutdown_workers = False
# Once all workers excluding ``Root`` and ``Monitor`` ranks are ready to shutdown,
# ``Monitor` sends the shutdown signal to every worker, as well as notifies ``Root`` that
# it can exit the program.
shm_manager = SharedMemoryMahager(shared_store.shared_memory_size)

while True:
# Listen receive operation from any source
operation_type, source_rank = communication.mpi_recv_operation(mpi_state.comm)
Expand All @@ -199,11 +200,13 @@ def monitor_loop():
)
elif operation_type == common.Operation.RESERVE_SHARED_MEMORY:
request = communication.mpi_recv_object(mpi_state.comm, source_rank)
first_index = shared_index
last_index = first_index + request["size"]
shared_index = last_index
if request["id"] in shm_manager.deleted_ids:
communication.mpi_send_object(
mpi_state.comm, data=ValueError("This data was already deleted."), dest_rank=source_rank
)
reservation_info = shm_manager.put(request["id"], request["size"])
communication.mpi_send_object(
mpi_state.comm, data=(first_index, last_index), dest_rank=source_rank
mpi_state.comm, data=reservation_info, dest_rank=source_rank
)
elif operation_type == common.Operation.REQUEST_SHARED_DATA:
info_package = communication.mpi_recv_object(mpi_state.comm, source_rank)
Expand All @@ -215,6 +218,11 @@ def monitor_loop():
sh_buf,
dest_rank=source_rank,
)
elif operation_type == common.Operation.CLEANUP:
cleanup_list = communication.recv_serialized_data(
mpi_state.comm, source_rank
)
shm_manager.clear(cleanup_list)
elif operation_type == common.Operation.READY_TO_SHUTDOWN:
workers_ready_to_shutdown.append(source_rank)
shutdown_workers = (
Expand Down
1 change: 1 addition & 0 deletions unidist/core/backends/mpi/core/object_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ def clear(self, cleanup_list):
"""
for data_id in cleanup_list:
self._data_id_map.pop(data_id, None)
self._data_id_map.pop(data_id, None)

def generate_data_id(self, gc):
"""
Expand Down
Loading

0 comments on commit 76bd87c

Please sign in to comment.