Skip to content

Commit

Permalink
implement coordinated termination for all processes
Browse files Browse the repository at this point in the history
  • Loading branch information
odartsi committed Jul 3, 2024
1 parent bbbb759 commit bf53041
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 41 deletions.
34 changes: 20 additions & 14 deletions alab_management/device_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
DeviceManager class, which will handle all the requests to run certain methods on the real device.
"""

import multiprocessing
import time
from collections.abc import Callable
from concurrent.futures import Future
Expand Down Expand Up @@ -108,7 +109,7 @@ class DeviceManager:
executes commands on the device drivers, as requested by the task process.
"""

def __init__(self, _check_status: bool = True):
def __init__(self, _check_status: bool = True, live_time: float | None = None, termination_event=None):
"""
Args:
_check_status: Check if the task currently occupied this device when
Expand All @@ -127,23 +128,28 @@ def __init__(self, _check_status: bool = True):
self._device_view = DeviceView(connect_to_devices=True)
self._check_status = _check_status
self.threads = []
self.live_time = live_time
self.termination_event = termination_event or multiprocessing.Event()

def run(self):
    """Listen on the device RPC queue and conduct the commands one by one.

    Declares the RPC queue, registers ``self.on_message`` as its consumer and
    then services the connection in short slices, so that the termination
    event and the ``live_time`` budget are re-checked between slices.

    The method returns when ``self.termination_event`` is set or, if
    ``self.live_time`` is not ``None``, once that many seconds have elapsed
    since the method started.
    """
    self.connection = get_rabbitmq_connection()
    # time.time must be *called*: keeping the function object here would make
    # ``time.time() - start`` raise TypeError as soon as live_time is set.
    start = time.time()

    with self.connection.channel() as channel:
        channel.queue_declare(
            queue=self._rpc_queue_name,
            auto_delete=True,
            exclusive=False,
        )
        channel.basic_consume(
            queue=self._rpc_queue_name,
            on_message_callback=self.on_message,
            auto_ack=False,
            consumer_tag=self._rpc_queue_name,
        )
        # channel.start_consuming() blocks until every consumer is cancelled,
        # so a surrounding while-condition would never be evaluated again.
        # Processing events with a time limit keeps message dispatch going
        # while still returning control to this loop about once per second.
        # NOTE(review): assumes get_rabbitmq_connection() returns a pika
        # BlockingConnection (process_data_events) -- confirm.
        while not self.termination_event.is_set() and (
            self.live_time is None or time.time() - start < self.live_time
        ):
            self.connection.process_data_events(time_limit=1)

def _execute_command_wrapper(
self,
Expand Down
9 changes: 7 additions & 2 deletions alab_management/experiment_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
done.
"""

import multiprocessing
import time
from typing import Any

Expand All @@ -23,7 +24,7 @@ class ExperimentManager:
and submit the experiment to executor and flag the completed experiments.
"""

def __init__(self):
def __init__(self, live_time: float | None = None, termination_event=None):
self.experiment_view = ExperimentView()
self.task_view = TaskView()
self.sample_view = SampleView()
Expand All @@ -36,6 +37,9 @@ def __init__(self):
if self.__copy_to_completed_db:
self.completed_experiment_view = CompletedExperimentView()

self.live_time = live_time
self.termination_event = termination_event or multiprocessing.Event()

def run(self):
"""Start the event loop."""
self.logger.system_log(
Expand All @@ -45,7 +49,8 @@ def run(self):
"type": "ExperimentManagerStarted",
},
)
while True:
start = time.time()
while not self.termination_event.is_set() and (self.live_time is None or time.time() - start < self.live_time):
self._loop()
time.sleep(1)

Expand Down
8 changes: 6 additions & 2 deletions alab_management/resource_manager/resource_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
which actually executes the tasks.
"""

import multiprocessing
import time
from datetime import datetime
from traceback import format_exc
Expand Down Expand Up @@ -34,7 +35,7 @@ class ResourceManager(RequestMixin):
(2) handle all the resource requests
"""

def __init__(self):
def __init__(self, live_time: float | None = None, termination_event=None):
load_definition()
self.task_view = TaskView()
self.sample_view = SampleView()
Expand All @@ -44,10 +45,13 @@ def __init__(self):
self.logger = DBLogger(task_id=None)
super().__init__()
time.sleep(1) # allow some time for other modules to launch
self.live_time = live_time
self.termination_event = termination_event or multiprocessing.Event()

def run(self):
    """Start the loop.

    Calls ``self._loop()`` repeatedly, pausing up to 0.5 seconds between
    iterations, until ``self.termination_event`` is set or, when
    ``self.live_time`` is not ``None``, until that many seconds have passed
    since the method started.
    """
    start = time.time()
    while not self.termination_event.is_set() and (
        self.live_time is None or time.time() - start < self.live_time
    ):
        self._loop()
        # Event.wait serves as the inter-iteration pause but returns at once
        # when the event is set, so a termination request is not delayed by
        # the remainder of a fixed sleep.
        self.termination_event.wait(0.5)

Expand Down
48 changes: 27 additions & 21 deletions alab_management/scripts/launch_lab.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
with contextlib.suppress(RuntimeError):
multiprocessing.set_start_method("spawn")

# Create a global termination event
termination_event = multiprocessing.Event()
class RestartableProcess:
"""A class for creating processes that can be automatically restarted after failures."""

Expand Down Expand Up @@ -62,7 +64,7 @@ def launch_experiment_manager():
from alab_management.utils.module_ops import load_definition

load_definition()
experiment_manager = ExperimentManager()
experiment_manager = ExperimentManager(live_time=3600, termination_event=termination_event)
experiment_manager.run()


Expand All @@ -72,7 +74,7 @@ def launch_task_manager():
from alab_management.utils.module_ops import load_definition

load_definition()
task_launcher = TaskManager(live_time=3600)
task_launcher = TaskManager(live_time=3600, termination_event=termination_event)
task_launcher.run()


Expand All @@ -82,7 +84,7 @@ def launch_device_manager():
from alab_management.utils.module_ops import load_definition

load_definition()
device_manager = DeviceManager()
device_manager = DeviceManager(live_time=3600, termination_event=termination_event)
device_manager.run()


Expand All @@ -92,7 +94,7 @@ def launch_resource_manager():
from alab_management.utils.module_ops import load_definition

load_definition()
resource_manager = ResourceManager()
resource_manager = ResourceManager(live_time=3600, termination_event=termination_event)
resource_manager.run()


Expand All @@ -108,20 +110,24 @@ def launch_lab(host, port, debug):
)
sys.exit(1)

# Create RestartableProcess objects for each process
dashboard_process = RestartableProcess(target=launch_dashboard, args=(host, port, debug), live_time=3600) # Restart every hour
experiment_manager_process = RestartableProcess(target=launch_experiment_manager, args=(host, port, debug), live_time=3600)
task_launcher_process = RestartableProcess(target=launch_task_manager, args=(host, port, debug), live_time=3600)
device_manager_process = RestartableProcess(target=launch_device_manager, args=(host, port, debug), live_time=3600)
resource_manager_process = RestartableProcess(target=launch_resource_manager, args=(host, port, debug), live_time=3600)

# Start the processes
dashboard_process.run()
experiment_manager_process.run()
task_launcher_process.run()
device_manager_process.run()
resource_manager_process.run()

"""With RestartableProcess, each process is designed to handle restarts automatically.
So, there's no need to worry about the program exiting before background tasks finish -
they will be restarted by RestartableProcess if necessary."""
# Create RestartableProcess objects for each process with shared termination_event
dashboard_process = RestartableProcess(target=launch_dashboard, args=(host, port, debug), live_time=3600, termination_event=termination_event)
experiment_manager_process = RestartableProcess(target=launch_experiment_manager, args=(host, port, debug), live_time=3600, termination_event=termination_event)
task_launcher_process = RestartableProcess(target=launch_task_manager, args=(host, port, debug), live_time=3600, termination_event=termination_event)
device_manager_process = RestartableProcess(target=launch_device_manager, args=(host, port, debug), live_time=3600, termination_event=termination_event)
resource_manager_process = RestartableProcess(target=launch_resource_manager, args=(host, port, debug), live_time=3600, termination_event=termination_event)

# Start the processes in separate threads to allow termination event setting
processes = [dashboard_process, experiment_manager_process, task_launcher_process, device_manager_process, resource_manager_process]

threads = []
for process in processes:
thread = Thread(target=process.run)
thread.start()
threads.append(thread)

return threads

def terminate_all_processes():
    """Set the termination event to stop all processes.

    Sets the module-level ``termination_event``; the manager ``run`` loops
    that were given this event stop once they observe it as set.

    NOTE(review): this module selects the "spawn" start method and the
    ``launch_*`` targets read ``termination_event`` as a module global.
    Presumably a spawned child re-imports the module and builds its own
    Event, in which case setting it here would not reach the children
    unless the event is handed over through the process arguments --
    confirm against RestartableProcess.
    """
    termination_event.set()
6 changes: 4 additions & 2 deletions alab_management/task_manager/task_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
which actually executes the tasks.
"""

import multiprocessing
import time

from dramatiq_abort import abort, abort_requested
Expand All @@ -22,18 +23,19 @@ class TaskManager:
(2) handle all the resource requests
"""

def __init__(self, live_time: float | None = None, termination_event=None):
    """Set up the views and logger and record the shutdown controls.

    Args:
        live_time: Number of seconds after which ``run`` stops looping;
            ``None`` means no time limit.
        termination_event: Event that ``run`` polls to know when to stop.
            A new ``multiprocessing.Event`` is created when omitted.
    """
    load_definition()
    self.task_view = TaskView()
    self.logger = DBLogger(task_id=None)
    super().__init__()
    time.sleep(1)  # allow some time for other modules to launch
    self.live_time = live_time
    # Test for None explicitly so a caller-supplied event object is never
    # replaced on the basis of its truthiness.
    if termination_event is None:
        termination_event = multiprocessing.Event()
    self.termination_event = termination_event

def run(self):
    """Start the loop.

    Calls ``self._loop()`` repeatedly, pausing up to one second between
    iterations, until ``self.termination_event`` is set or, when
    ``self.live_time`` is not ``None``, until that many seconds have passed
    since the method started.
    """
    start = time.time()
    while not self.termination_event.is_set() and (
        self.live_time is None or time.time() - start < self.live_time
    ):
        self._loop()
        # Event.wait serves as the inter-iteration pause but returns at once
        # when the event is set, so a termination request is not delayed by
        # the remainder of a fixed sleep.
        self.termination_event.wait(1)

Expand Down

0 comments on commit bf53041

Please sign in to comment.