Skip to content

Commit

Permalink
Drain workers as walltime expiry nears (#3063)
Browse files Browse the repository at this point in the history
This is an implementation of the first part (the "easy part") of issue #3059.

It adds a `drain_period` parameter to HighThroughputExecutor specifying a "not quite as long as walltime" duration, after which workers will drain themselves: they will continue with existing tasks but ask the interchange not to send any more.

When they are drained, the worker pools will exit immediately.
  • Loading branch information
benclifford authored Mar 12, 2024
1 parent 822f060 commit 920852f
Show file tree
Hide file tree
Showing 5 changed files with 155 additions and 6 deletions.
12 changes: 12 additions & 0 deletions parsl/executors/high_throughput/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
"--hb_period={heartbeat_period} "
"{address_probe_timeout_string} "
"--hb_threshold={heartbeat_threshold} "
"--drain_period={drain_period} "
"--cpu-affinity {cpu_affinity} "
"{enable_mpi_mode} "
"--mpi-launcher={mpi_launcher} "
Expand Down Expand Up @@ -201,6 +202,14 @@ class HighThroughputExecutor(BlockProviderExecutor, RepresentationMixin):
Timeout period to be used by the executor components in milliseconds. Increasing poll_periods
trades performance for cpu efficiency. Default: 10ms
drain_period : int
The number of seconds after start when workers will begin to drain
and then exit. Set this to a time that is slightly less than the
maximum walltime of batch jobs to avoid killing tasks while they
execute. For example, you could set this to the walltime minus a grace
period for the batch job to start the workers, minus the expected
maximum length of an individual task.
worker_logdir_root : string
In case of a remote file system, specify the path to where logs will be kept.
Expand Down Expand Up @@ -240,6 +249,7 @@ def __init__(self,
prefetch_capacity: int = 0,
heartbeat_threshold: int = 120,
heartbeat_period: int = 30,
drain_period: Optional[int] = None,
poll_period: int = 10,
address_probe_timeout: Optional[int] = None,
worker_logdir_root: Optional[str] = None,
Expand Down Expand Up @@ -303,6 +313,7 @@ def __init__(self,
self.interchange_port_range = interchange_port_range
self.heartbeat_threshold = heartbeat_threshold
self.heartbeat_period = heartbeat_period
self.drain_period = drain_period
self.poll_period = poll_period
self.run_dir = '.'
self.worker_logdir_root = worker_logdir_root
Expand Down Expand Up @@ -376,6 +387,7 @@ def initialize_scaling(self):
nodes_per_block=self.provider.nodes_per_block,
heartbeat_period=self.heartbeat_period,
heartbeat_threshold=self.heartbeat_threshold,
drain_period=self.drain_period,
poll_period=self.poll_period,
cert_dir=self.cert_dir,
logdir=self.worker_logdir,
Expand Down
26 changes: 24 additions & 2 deletions parsl/executors/high_throughput/interchange.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@


PKL_HEARTBEAT_CODE = pickle.dumps((2 ** 32) - 1)
PKL_DRAINED_CODE = pickle.dumps((2 ** 32) - 2)

LOGGER_NAME = "interchange"
logger = logging.getLogger(LOGGER_NAME)
Expand Down Expand Up @@ -308,7 +309,8 @@ def _command_server(self) -> NoReturn:
'worker_count': m['worker_count'],
'tasks': len(m['tasks']),
'idle_duration': idle_duration,
'active': m['active']}
'active': m['active'],
'draining': m['draining']}
reply.append(resp)

elif command_req.startswith("HOLD_WORKER"):
Expand Down Expand Up @@ -385,6 +387,7 @@ def start(self) -> None:
self.process_task_outgoing_incoming(interesting_managers, hub_channel, kill_event)
self.process_results_incoming(interesting_managers, hub_channel)
self.expire_bad_managers(interesting_managers, hub_channel)
self.expire_drained_managers(interesting_managers, hub_channel)
self.process_tasks_to_send(interesting_managers)

self.zmq_context.destroy()
Expand Down Expand Up @@ -431,6 +434,7 @@ def process_task_outgoing_incoming(
'max_capacity': 0,
'worker_count': 0,
'active': True,
'draining': False,
'tasks': []}
self.connected_block_history.append(msg['block_id'])

Expand Down Expand Up @@ -469,10 +473,28 @@ def process_task_outgoing_incoming(
self._ready_managers[manager_id]['last_heartbeat'] = time.time()
logger.debug("Manager {!r} sent heartbeat via tasks connection".format(manager_id))
self.task_outgoing.send_multipart([manager_id, b'', PKL_HEARTBEAT_CODE])
elif msg['type'] == 'drain':
self._ready_managers[manager_id]['draining'] = True
logger.debug(f"Manager {manager_id!r} requested drain")
else:
logger.error(f"Unexpected message type received from manager: {msg['type']}")
logger.debug("leaving task_outgoing section")

def expire_drained_managers(self, interesting_managers: Set[bytes], hub_channel: Optional[zmq.Socket]) -> None:
    """Retire every interesting manager that has finished draining.

    A manager counts as drained when its record is flagged ``draining`` and
    its list of outstanding tasks is empty. Each such manager is sent
    ``PKL_DRAINED_CODE``, dropped from both ``interesting_managers`` and
    ``self._ready_managers``, marked inactive, and reported through
    ``self._send_monitoring_info``.

    Parameters
    ----------
    interesting_managers : Set[bytes]
        Manager ids to examine; drained ids are removed from this set in place.
    hub_channel : zmq.Socket | None
        Passed through unchanged to ``self._send_monitoring_info``.
    """
    # Walk a snapshot, because the set is mutated inside the loop.
    for manager_id in tuple(interesting_managers):
        # NOTE(review): only managers present in interesting_managers are
        # examined. Presumably a draining manager is always in that set
        # because it has outstanding capacity - confirm.
        record = self._ready_managers[manager_id]

        if not record['draining'] or len(record['tasks']) != 0:
            continue

        logger.info(f"Manager {manager_id!r} is drained - sending drained message to manager")
        self.task_outgoing.send_multipart([manager_id, b'', PKL_DRAINED_CODE])
        interesting_managers.remove(manager_id)
        self._ready_managers.pop(manager_id)

        # The record is already unregistered; flag it inactive for the
        # monitoring message only.
        record['active'] = False
        self._send_monitoring_info(hub_channel, record)

def process_tasks_to_send(self, interesting_managers: Set[bytes]) -> None:
# Check if there are tasks that could be sent to managers

Expand All @@ -490,7 +512,7 @@ def process_tasks_to_send(self, interesting_managers: Set[bytes]) -> None:
tasks_inflight = len(m['tasks'])
real_capacity = m['max_capacity'] - tasks_inflight

if (real_capacity and m['active']):
if (real_capacity and m['active'] and not m['draining']):
tasks = self.get_tasks(real_capacity)
if tasks:
self.task_outgoing.send_multipart([manager_id, b'', pickle.dumps(tasks)])
Expand Down
1 change: 1 addition & 0 deletions parsl/executors/high_throughput/manager_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ class ManagerRecord(TypedDict, total=False):
worker_count: int
max_capacity: int
active: bool
draining: bool
hostname: str
last_heartbeat: float
idle_since: Optional[float]
Expand Down
44 changes: 40 additions & 4 deletions parsl/executors/high_throughput/process_worker_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from parsl.executors.high_throughput.mpi_prefix_composer import compose_all, VALID_LAUNCHERS

HEARTBEAT_CODE = (2 ** 32) - 1
DRAINED_CODE = (2 ** 32) - 2


class Manager:
Expand Down Expand Up @@ -73,7 +74,8 @@ def __init__(self, *,
enable_mpi_mode: bool = False,
mpi_launcher: str = "mpiexec",
available_accelerators: Sequence[str],
cert_dir: Optional[str]):
cert_dir: Optional[str],
drain_period: Optional[int]):
"""
Parameters
----------
Expand Down Expand Up @@ -138,6 +140,9 @@ def __init__(self, *,
cert_dir : str | None
Path to the certificate directory.
drain_period: int | None
Number of seconds after the manager starts at which it will request a drain. If None, the manager never requests a drain. TODO: could be a nicer timespec involving m,s,h qualifiers for user friendliness?
"""

logger.info("Manager initializing")
Expand Down Expand Up @@ -227,6 +232,14 @@ def __init__(self, *,
self.heartbeat_period = heartbeat_period
self.heartbeat_threshold = heartbeat_threshold
self.poll_period = poll_period

self.drain_time: float
if drain_period:
self.drain_time = self._start_time + drain_period
logger.info(f"Will request drain at {self.drain_time}")
else:
self.drain_time = float('inf')

self.cpu_affinity = cpu_affinity

# Define accelerator available, adjust worker count accordingly
Expand Down Expand Up @@ -262,10 +275,19 @@ def heartbeat_to_incoming(self):
""" Send heartbeat to the incoming task queue
"""
msg = {'type': 'heartbeat'}
# don't need to dumps and encode this every time - could do as a global on import?
b_msg = json.dumps(msg).encode('utf-8')
self.task_incoming.send(b_msg)
logger.debug("Sent heartbeat")

def drain_to_incoming(self):
    """ Send a drain request to the incoming task queue

    The request is the JSON object ``{"type": "drain"}``, UTF-8 encoded and
    sent on ``self.task_incoming``.
    """
    b_msg = json.dumps({'type': 'drain'}).encode('utf-8')
    self.task_incoming.send(b_msg)
    logger.debug("Sent drain")

@wrap_with_logs
def pull_tasks(self, kill_event):
""" Pull tasks from the incoming tasks zmq pipe onto the internal
Expand Down Expand Up @@ -298,6 +320,7 @@ def pull_tasks(self, kill_event):
# time here are correctly copy-pasted from the relevant if
# statements.
next_interesting_event_time = min(last_beat + self.heartbeat_period,
self.drain_time,
last_interchange_contact + self.heartbeat_threshold)
try:
pending_task_count = self.pending_task_queue.qsize()
Expand All @@ -312,6 +335,14 @@ def pull_tasks(self, kill_event):
self.heartbeat_to_incoming()
last_beat = time.time()

if time.time() > self.drain_time:
    logger.info("Requesting drain")
    self.drain_to_incoming()
    # Move the deadline to the far future rather than to None: drain_time is
    # fed into the min() that computes the next wake-up time at the top of
    # this loop, and min() raises TypeError when given None alongside
    # floats. +inf also stops the drain request being re-sent on later
    # iterations.
    self.drain_time = float('inf')
    # This will start the pool draining...
    # Drained exit behaviour does not happen here. It will be
    # driven by the interchange sending a DRAINED_CODE message.

poll_duration_s = max(0, next_interesting_event_time - time.time())
socks = dict(poller.poll(timeout=poll_duration_s * 1000))

Expand All @@ -322,7 +353,9 @@ def pull_tasks(self, kill_event):

if tasks == HEARTBEAT_CODE:
logger.debug("Got heartbeat from interchange")

elif tasks == DRAINED_CODE:
logger.info("Got fulled drained message from interchange - setting kill flag")
kill_event.set()
else:
task_recv_counter += len(tasks)
logger.debug("Got executor tasks: {}, cumulative count of tasks: {}".format([t['task_id'] for t in tasks], task_recv_counter))
Expand Down Expand Up @@ -490,9 +523,8 @@ def start(self):
self._worker_watchdog_thread.start()
self._monitoring_handler_thread.start()

logger.info("Loop start")
logger.info("Manager threads started")

# TODO : Add mechanism in this loop to stop the worker pool
# This might need a multiprocessing event to signal back.
self._kill_event.wait()
logger.critical("Received kill event, terminating worker processes")
Expand Down Expand Up @@ -804,6 +836,8 @@ def start_file_logger(filename, rank, name='parsl', level=logging.DEBUG, format_
help="Heartbeat period in seconds. Uses manager default unless set")
parser.add_argument("--hb_threshold", default=120,
help="Heartbeat threshold in seconds. Uses manager default unless set")
# The executor's launch command always passes --drain_period and renders an
# unset value as the literal string "None"; the Manager construction further
# down only recognises that string as "no drain". Defaulting to the same
# string keeps a launch that omits the flag from reaching int(None).
parser.add_argument("--drain_period", default="None",
                    help="Drain this pool after specified number of seconds. By default, does not drain.")
parser.add_argument("--address_probe_timeout", default=30,
help="Timeout to probe for viable address to interchange. Default: 30s")
parser.add_argument("--poll", default=10,
Expand Down Expand Up @@ -856,6 +890,7 @@ def strategyorlist(s: str):
logger.info("Prefetch capacity: {}".format(args.prefetch_capacity))
logger.info("Heartbeat threshold: {}".format(args.hb_threshold))
logger.info("Heartbeat period: {}".format(args.hb_period))
logger.info("Drain period: {}".format(args.drain_period))
logger.info("CPU affinity: {}".format(args.cpu_affinity))
logger.info("Accelerators: {}".format(" ".join(args.available_accelerators)))
logger.info("enable_mpi_mode: {}".format(args.enable_mpi_mode))
Expand All @@ -876,6 +911,7 @@ def strategyorlist(s: str):
prefetch_capacity=int(args.prefetch_capacity),
heartbeat_threshold=int(args.hb_threshold),
heartbeat_period=int(args.hb_period),
drain_period=None if args.drain_period == "None" else int(args.drain_period),
poll_period=int(args.poll),
cpu_affinity=args.cpu_affinity,
enable_mpi_mode=args.enable_mpi_mode,
Expand Down
78 changes: 78 additions & 0 deletions parsl/tests/test_htex/test_drain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import parsl
import pytest
import time

from parsl.providers import LocalProvider
from parsl.channels import LocalChannel
from parsl.launchers import SimpleLauncher

from parsl.config import Config
from parsl.executors import HighThroughputExecutor

# This constant scales the durations in this test around the
# expected drain period: the drain period is TIME_CONST seconds,
# and the single executed task will last twice that many
# seconds.
TIME_CONST = 1


def local_config():
    """Return the Config used by this test.

    It holds a single HighThroughputExecutor labelled ``htex_local`` whose
    ``drain_period`` is ``TIME_CONST`` seconds, backed by a LocalProvider
    that starts one initial block, with ``strategy='none'``.
    """
    provider = LocalProvider(
        channel=LocalChannel(),
        init_blocks=1,
        min_blocks=0,
        max_blocks=0,
        launcher=SimpleLauncher(),
    )
    executor = HighThroughputExecutor(
        label="htex_local",
        drain_period=TIME_CONST,
        worker_debug=True,
        cores_per_worker=1,
        encrypted=True,
        provider=provider,
    )
    return Config(executors=[executor], strategy='none')


@parsl.python_app
def f(n):
    """Sleep for ``n`` seconds."""
    # Import kept local to the function body, as in the original.
    from time import sleep
    sleep(n)


@pytest.mark.local
def test_drain(try_assert):
    """Check that a manager reports draining after the drain period and
    disappears from the connected managers once its task has finished.
    """
    executor = parsl.dfk().executors['htex_local']

    def one_manager_connected():
        return len(executor.connected_managers()) == 1

    def first_manager_draining():
        return executor.connected_managers()[0]['draining']

    def no_managers_connected():
        return len(executor.connected_managers()) == 0

    # Wait until a block is running and its manager has connected.
    try_assert(one_manager_connected)

    first = executor.connected_managers()[0]
    assert first['active'], "The manager should be active"
    assert not first['draining'], "The manager should not be draining"

    # One task lasting twice the drain period, so it is still running when
    # the drain request is made.
    task = f(TIME_CONST * 2)

    time.sleep(TIME_CONST)

    # The draining flag should show up very soon after the delay above.
    try_assert(first_manager_draining, timeout_ms=500)

    # Draining must not interrupt the task that is already executing.
    assert not task.done(), "The test task should still be running"

    task.result()

    # With strategy='none' the manager's disappearance should come from
    # draining, although its absence from connected_managers() does not by
    # itself say why. As above, this should happen very soon after the task
    # ends.
    try_assert(no_managers_connected, timeout_ms=500)

0 comments on commit 920852f

Please sign in to comment.