Skip to content

Commit

Permalink
tracking tasks outside the condition
Browse files Browse the repository at this point in the history
  • Loading branch information
arunjose696 committed Apr 19, 2023
1 parent 9a3ec5a commit 6c6c0f2
Show file tree
Hide file tree
Showing 5 changed files with 5 additions and 70 deletions.
3 changes: 1 addition & 2 deletions unidist/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
RayObjectStoreMemory,
)
from .backends.dask import DaskMemoryLimit, IsDaskCluster, DaskSchedulerAddress
from .backends.mpi import IsMpiSpawnWorkers, MpiHosts, MpiPickleThreshold, BackOff
from .backends.mpi import IsMpiSpawnWorkers, MpiHosts, MpiPickleThreshold
from .parameter import ValueSource

__all__ = [
Expand All @@ -31,5 +31,4 @@
"MpiHosts",
"ValueSource",
"MpiPickleThreshold",
"BackOff",
]
4 changes: 2 additions & 2 deletions unidist/config/backends/mpi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@

"""Config entities specific for MPI backend which can be used for unidist behavior tuning."""

from .envvars import IsMpiSpawnWorkers, MpiHosts, MpiPickleThreshold, BackOff
from .envvars import IsMpiSpawnWorkers, MpiHosts, MpiPickleThreshold

__all__ = ["IsMpiSpawnWorkers", "MpiHosts", "MpiPickleThreshold", "BackOff"]
__all__ = ["IsMpiSpawnWorkers", "MpiHosts", "MpiPickleThreshold"]
7 changes: 0 additions & 7 deletions unidist/config/backends/mpi/envvars.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,3 @@ class MpiPickleThreshold(EnvironmentVariable, type=int):

default = 1024**2 // 4 # 0.25 MiB
varname = "UNIDIST_MPI_PICKLE_THRESHOLD"


class BackOff(EnvironmentVariable, type=int):
"""Backoff value for sleeping background threads when thread idle"""

default = 0.001
varname = "BackOff"
7 changes: 2 additions & 5 deletions unidist/core/backends/mpi/core/communication.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def mpi_send_object(comm, data, dest_rank):
comm.send(data, dest=dest_rank)


def mpi_isend_object(comm, data, dest_rank, tag=0):
def mpi_isend_object(comm, data, dest_rank):
"""
Send Python object to another MPI rank in a non-blocking way.
Expand All @@ -193,16 +193,13 @@ def mpi_isend_object(comm, data, dest_rank, tag=0):
Data to send.
dest_rank : int
Target MPI process to transfer data.
tag : int
To recieve only data with a label.
Used when background thread polls for data with a specific label.
Returns
-------
object
A handler to MPI_Isend communication result.
"""
return comm.isend(data, dest=dest_rank, tag=tag)
return comm.isend(data, dest=dest_rank)


def mpi_send_buffer(comm, buffer_size, buffer, dest_rank):
Expand Down
54 changes: 0 additions & 54 deletions unidist/core/backends/mpi/core/controller/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,7 @@
import atexit
import signal
import asyncio
import time
from collections import defaultdict
import threading

try:
import mpi4py
Expand All @@ -37,7 +35,6 @@
MpiHosts,
ValueSource,
MpiPickleThreshold,
BackOff,
)


Expand All @@ -52,49 +49,6 @@
topology = dict()
# The global variable is responsible for if MPI backend has already been initialized
is_mpi_initialized = False
# List is used to keep keep track of th threads started so they could be later joined
threads = []
BACKOFF = BackOff.get_value_source()
# The global variable acts as a flag which when set true the function executing in background thread stops
exit_flag = False


class Backoff:
def __init__(self, seconds=BACKOFF):
self.tval = 0.0
self.tmax = max(float(seconds), 0.0)
self.tmin = self.tmax / (1 << 10)

def reset(self):
self.tval = 0.0

def sleep(self):
time.sleep(self.tval)
self.tval = min(self.tmax, max(self.tmin, self.tval * 2))


class Poller(threading.Thread):
def __init__(self, thread_id, name, comm):
threading.Thread.__init__(self, daemon=True)
self.thread_id = thread_id
self.name = name
self.comm = comm

def run(self):
poll_tasks_completed(self.name, self.comm)


def poll_tasks_completed(threadName, comm):
global exit_flag
scheduler = Scheduler.get_instance()
backoff = Backoff()
while not exit_flag:
if comm.iprobe(source=communication.MPIRank.MONITOR, tag=1):
task_completed_rank = comm.recv(source=communication.MPIRank.MONITOR, tag=1)
scheduler.decrement_tasks_on_worker(task_completed_rank)
backoff.reset()
else:
backoff.sleep()


def init():
Expand Down Expand Up @@ -172,10 +126,6 @@ def init():
mpi_state = communication.MPIState.get_instance(
comm, comm.Get_rank(), comm.Get_size()
)
# if rank == 0 and not threads and parent_comm == MPI.COMM_NULL:
# thread = Poller(1, "Thread_Poll_Tasks", comm)
# thread.start()
# threads.append(thread)

global topology
if not topology:
Expand Down Expand Up @@ -229,11 +179,7 @@ def shutdown():
-----
Sends cancelation operation to all workers and monitor processes.
"""
global exit_flag, threads
exit_flag = True
mpi_state = communication.MPIState.get_instance()
for thread in threads:
thread.join()
# Send shutdown commands to all ranks
for rank_id in range(communication.MPIRank.MONITOR, mpi_state.world_size):
# We use a blocking send here because we have to wait for
Expand Down

0 comments on commit 6c6c0f2

Please sign in to comment.