Skip to content

Commit

Permalink
FIX-#340: Fix hangs at low UNIDIST_MPI_SHARED_OBJECT_STORE_THRESHOLD
Browse files Browse the repository at this point in the history
Signed-off-by: Kirill Suvorov <[email protected]>
  • Loading branch information
Retribution98 committed Nov 3, 2023
1 parent 8f19490 commit e51e133
Show file tree
Hide file tree
Showing 6 changed files with 92 additions and 18 deletions.
2 changes: 2 additions & 0 deletions docs/flow/unidist/config.rst
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ Unidist Configuration Settings List
+-------------------------------+-------------------------------------------+--------------------------------------------------------------------------+
| MpiSharedObjectStoreMemory | UNIDIST_MPI_SHARED_OBJECT_STORE_MEMORY | How many bytes of memory to start the shared object store with |
+-------------------------------+-------------------------------------------+--------------------------------------------------------------------------+
| MpiSharedServiceMemory | UNIDIST_MPI_SHARED_SERVICE_MEMORY | How many bytes of memory to start the shared service memory with |
+-------------------------------+-------------------------------------------+--------------------------------------------------------------------------+
| MpiSharedObjectStoreThreshold | UNIDIST_MPI_SHARED_OBJECT_STORE_THRESHOLD | Minimum size of data to put into the shared object store |
+-------------------------------+-------------------------------------------+--------------------------------------------------------------------------+
| MpiRuntimeEnv | Only the config API is available | Runtime environment for MPI worker processes |
Expand Down
2 changes: 2 additions & 0 deletions unidist/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
MpiLog,
MpiSharedObjectStore,
MpiSharedObjectStoreMemory,
MpiSharedServiceMemory,
MpiSharedObjectStoreThreshold,
MpiRuntimeEnv,
)
Expand All @@ -45,6 +46,7 @@
"MpiLog",
"MpiSharedObjectStore",
"MpiSharedObjectStoreMemory",
"MpiSharedServiceMemory",
"MpiSharedObjectStoreThreshold",
"MpiRuntimeEnv",
]
2 changes: 2 additions & 0 deletions unidist/config/backends/mpi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
MpiLog,
MpiSharedObjectStore,
MpiSharedObjectStoreMemory,
MpiSharedServiceMemory,
MpiSharedObjectStoreThreshold,
MpiRuntimeEnv,
)
Expand All @@ -24,6 +25,7 @@
"MpiLog",
"MpiSharedObjectStore",
"MpiSharedObjectStoreMemory",
"MpiSharedServiceMemory",
"MpiSharedObjectStoreThreshold",
"MpiRuntimeEnv",
]
6 changes: 6 additions & 0 deletions unidist/config/backends/mpi/envvars.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,12 @@ class MpiSharedObjectStoreMemory(EnvironmentVariable, type=int):
varname = "UNIDIST_MPI_SHARED_OBJECT_STORE_MEMORY"


class MpiSharedServiceMemory(EnvironmentVariable, type=int):
"""How many bytes of memory to start the shared service memory with."""

varname = "UNIDIST_MPI_SHARED_SERVICE_MEMORY"


class MpiSharedObjectStoreThreshold(EnvironmentVariable, type=int):
"""Minimum size of data to put into the shared object store."""

Expand Down
5 changes: 5 additions & 0 deletions unidist/core/backends/mpi/core/controller/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
MpiLog,
MpiSharedObjectStore,
MpiSharedObjectStoreMemory,
MpiSharedServiceMemory,
MpiSharedObjectStoreThreshold,
MpiRuntimeEnv,
)
Expand Down Expand Up @@ -170,6 +171,10 @@ def init():
py_str += [
f"cfg.MpiSharedObjectStoreMemory.put({MpiSharedObjectStoreMemory.get()})"
]
if MpiSharedServiceMemory.get_value_source() != ValueSource.DEFAULT:
py_str += [
f"cfg.MpiSharedServiceMemory.put({MpiSharedServiceMemory.get()})"
]
if MpiSharedObjectStoreThreshold.get_value_source() != ValueSource.DEFAULT:
py_str += [
f"cfg.MpiSharedObjectStoreThreshold.put({MpiSharedObjectStoreThreshold.get()})"
Expand Down
93 changes: 75 additions & 18 deletions unidist/core/backends/mpi/core/shared_object_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from unidist.config.backends.mpi.envvars import (
MpiSharedObjectStoreMemory,
MpiSharedServiceMemory,
MpiSharedObjectStoreThreshold,
MpiBackoff,
)
Expand Down Expand Up @@ -126,10 +127,6 @@ def _get_allowed_memory_size(self):
int
The number of bytes available to allocate shared memory.
"""
shared_object_store_memory = MpiSharedObjectStoreMemory.get()
if shared_object_store_memory is not None:
return shared_object_store_memory

virtual_memory = psutil.virtual_memory().total
if sys.platform.startswith("linux"):
shm_fd = os.open("/dev/shm", os.O_RDONLY)
Expand All @@ -154,30 +151,90 @@ def _allocate_shared_memory(self):
"""
mpi_state = communication.MPIState.get_instance()

# Use only 95% of available shared memory because
# the rest is needed for intermediate shared buffers
# handled by MPI itself for communication of small messages.
# Shared memory is allocated only once by the monitor process.
self.shared_memory_size = (
int(self._get_allowed_memory_size() * 0.95)
if mpi_state.is_monitor_process()
else 0
)
shared_object_store_memory = MpiSharedObjectStoreMemory.get()
shared_service_memory = MpiSharedServiceMemory.get()
allowed_memory_size = int(self._get_allowed_memory_size() * 0.95)

if shared_object_store_memory is not None:
if shared_service_memory is not None:
self.shared_memory_size = shared_object_store_memory
self.service_memory_size = shared_service_memory
else:
self.shared_memory_size = shared_object_store_memory
# To avoid division by 0
if MpiSharedObjectStoreThreshold.get() > 0:
self.service_memory_size = min(
# allowd memory for serivce buffer
allowed_memory_size - self.shared_memory_size,
# maximum amount of memory required for the service buffer
(self.shared_memory_size // MpiSharedObjectStoreThreshold.get())
* (self.INFO_SIZE * MPI.LONG.size),
)
else:
self.service_memory_size = (
allowed_memory_size - self.shared_memory_size
)
else:
if shared_service_memory is not None:
self.service_memory_size = shared_service_memory
self.shared_memory_size = allowed_memory_size - self.service_memory_size
else:
A = allowed_memory_size
B = MpiSharedObjectStoreThreshold.get()
C = self.INFO_SIZE * MPI.LONG.size
# "x" is shared_memory_size
# "y" is service_memory_size

# requirements:
# x + y = A
# y = min[ (x/B) * C, 0.01 * A ]

# calculation results:
# if B > 99 * C:
# x = (A * B) / (B + C)
# y = (A * C) / (B + C)
# else:
# x = 0.99 * A
# y = 0.01 * A

if B > 99 * C:
self.shared_memory_size = (A * B) // (B + C)
self.service_memory_size = (A * C) // (B + C)
else:
self.shared_memory_size = int(0.99 * A)
self.service_memory_size = int(0.01 * A)

if self.shared_memory_size > allowed_memory_size:
raise ValueError(
"Memory for shared object storage cannot be allocated because the value set to `MpiSharedObjectStoreMemory` exceeds the available memory."
)

if self.service_memory_size > allowed_memory_size:
raise ValueError(
"Memory for shared service storage cannot be allocated because the value set to `MpiSharedServiceMemory` exceeds the available memory."
)

if self.shared_memory_size + self.service_memory_size > allowed_memory_size:
raise ValueError(
"The sum of the `MpiSharedObjectStoreMemory` and `MpiSharedServiceMemory` values is greater than the amount of memory that exists."
)

# Shared memory is allocated only once by the monitor process.
info = MPI.Info.Create()
info.Set("alloc_shared_noncontig", "true")
self.win = MPI.Win.Allocate_shared(
self.shared_memory_size * MPI.BYTE.size,
self.shared_memory_size * MPI.BYTE.size
if mpi_state.is_monitor_process()
else 0,
MPI.BYTE.size,
comm=mpi_state.host_comm,
info=info,
)
self.shared_buffer, _ = self.win.Shared_query(communication.MPIRank.MONITOR)

# Service shared memory is allocated only once by the monitor process
self.service_info_max_count = (
self.shared_memory_size // MpiSharedObjectStoreThreshold.get()
) * self.INFO_SIZE
self.service_info_max_count = self.service_memory_size // (
self.INFO_SIZE * MPI.BYTE.size
)
self.service_win = MPI.Win.Allocate_shared(
self.service_info_max_count * MPI.LONG.size
if mpi_state.is_monitor_process()
Expand Down

0 comments on commit e51e133

Please sign in to comment.