From 49aa9e84ef612218ae1eb4a4ae8ce6c0e9fbdddd Mon Sep 17 00:00:00 2001 From: Yuxing Fei Date: Wed, 3 Jul 2024 11:46:56 -0700 Subject: [PATCH] finish step 1. (To be tested with fixed launch_lab.py) --- alab_management/device_view/device_view.py | 41 ++++---- alab_management/lab_view.py | 2 +- alab_management/scripts/launch_lab.py | 110 ++++++++++++++------- tests/fake_lab/devices/robot_arm.py | 13 ++- tests/test_task_manager.py | 43 +++++++- 5 files changed, 149 insertions(+), 60 deletions(-) diff --git a/alab_management/device_view/device_view.py b/alab_management/device_view/device_view.py index e798a6f3..4ef6f149 100644 --- a/alab_management/device_view/device_view.py +++ b/alab_management/device_view/device_view.py @@ -3,14 +3,15 @@ import time from collections.abc import Collection from datetime import datetime -from enum import auto, Enum, unique -from typing import Any, cast, TypeVar +from enum import Enum, auto, unique +from typing import Any, TypeVar, cast import pymongo # type: ignore from bson import ObjectId # type: ignore from alab_management.sample_view import SamplePosition, SampleView from alab_management.utils.data_objects import get_collection, get_lock + from .device import BaseDevice, get_all_devices _DeviceType = TypeVar("_DeviceType", bound=BaseDevice) # pylint: disable=invalid-name @@ -161,12 +162,12 @@ def _clean_up_device_collection(self): self._device_collection.drop() def request_devices( - self, - task_id: ObjectId, - device_names_str: Collection[str] | None = None, - device_types_str: ( - Collection[str] | None - ) = None, # pylint: disable=unsubscriptable-object + self, + task_id: ObjectId, + device_names_str: Collection[str] | None = None, + device_types_str: ( + Collection[str] | None + ) = None, # pylint: disable=unsubscriptable-object ) -> dict[str, dict[str, str | bool]] | None: """ Request a list of device, this function will return the name of devices if all the requested device is ready. @@ -231,7 +232,7 @@ def request_devices( return idle_devices def get_available_devices( - self, device_str: str, type_or_name: str, task_id: ObjectId | None = None + self, device_str: str, type_or_name: str, task_id: ObjectId | None = None ) -> list[dict[str, str | bool]]: """ Given device type, it will return all the device with this type. @@ -311,8 +312,8 @@ def occupy_device(self, device: BaseDevice | str, task_id: ObjectId): device_name = device.name if isinstance(device, BaseDevice) else device # Wait until the device status has been updated to OCCUPIED while ( - self.get_status(device_name=device_name).name - != DeviceTaskStatus.OCCUPIED.name + self.get_status(device_name=device_name).name + != DeviceTaskStatus.OCCUPIED.name ): time.sleep(0.5) @@ -338,8 +339,8 @@ def release_device(self, device_name: str): } if ( - DevicePauseStatus[device_entry["pause_status"]] - == DevicePauseStatus.REQUESTED + DevicePauseStatus[device_entry["pause_status"]] + == DevicePauseStatus.REQUESTED ): update_dict.update( { @@ -367,11 +368,11 @@ def get_samples_on_device(self, device_name: str): return samples_per_position def _update_status( - self, - device: BaseDevice | str, - required_status: DeviceTaskStatus | list[DeviceTaskStatus] | None, - target_status: DeviceTaskStatus, - task_id: ObjectId | None, + self, + device: BaseDevice | str, + required_status: DeviceTaskStatus | list[DeviceTaskStatus] | None, + target_status: DeviceTaskStatus, + task_id: ObjectId | None, ): """ A method that check and update the status of a device. @@ -399,8 +400,8 @@ def _update_status( required_status = None if ( - required_status is not None - and DeviceTaskStatus[device_entry["status"]] not in required_status + required_status is not None + and DeviceTaskStatus[device_entry["status"]] not in required_status ): raise ValueError( f"Device's current status ({device_entry['status']}) is " diff --git a/alab_management/lab_view.py b/alab_management/lab_view.py index 23707111..893930c6 100644 --- a/alab_management/lab_view.py +++ b/alab_management/lab_view.py @@ -125,7 +125,7 @@ def request_resources( sample_positions = result["sample_positions"] self._task_view.update_status(task_id=self.task_id, status=TaskStatus.RUNNING) yield devices, sample_positions - + # todo: disconnect devices self._resource_requester.release_resources(request_id=request_id) def _sample_name_to_id(self, sample_name: str) -> ObjectId: diff --git a/alab_management/scripts/launch_lab.py b/alab_management/scripts/launch_lab.py index 1e998fda..d47f4b1d 100644 --- a/alab_management/scripts/launch_lab.py +++ b/alab_management/scripts/launch_lab.py @@ -5,33 +5,54 @@ import sys import time from threading import Thread + from gevent.pywsgi import WSGIServer # type: ignore -from multiprocessing import Process + with contextlib.suppress(RuntimeError): multiprocessing.set_start_method("spawn") +# Create a global termination event +termination_event = multiprocessing.Event() + + class RestartableProcess: """A class for creating processes that can be automatically restarted after failures.""" - def __init__(self, target_func, live_time=None): - self.target_func = target_func + + def __init__(self, target, args=(), live_time=None, termination_event=None): + self.target = target self.live_time = live_time self.process = None + self.args = args self.termination_event = termination_event or multiprocessing.Event() def run(self): start = time.time() - while not self.termination_event.is_set() and (self.live_time is None or time.time() - start < self.live_time): + while not self.termination_event.is_set() and ( + self.live_time is None or time.time() - start < self.live_time + ): try: process = multiprocessing.Process(target=self.target, args=self.args) process.start() process.join() # Wait for process to finish - # Check exit code and restart if needed - if self.process.exitcode == 0: - print(f"Process {self.process.name} exited normally. Restarting...") - else: - print(f"Process {self.process.name} exited with code {self.process.exitcode}.") - time.sleep(self.live_time or 0) # Restart after live_time or immediately if None + # Check exit code, handle errors, and restart if needed + if process.exitcode == 0: + print(f"Process {process.name} exited normally. Restarting...") + else: + print( + f"Process {process.name} exited with code {process.exitcode}." + ) + except Exception as e: + print(f"Error occurred while running process: {e}") + + # Check for termination before restarting + if self.termination_event.is_set(): + break + + time.sleep( + self.live_time or 0 + ) # Restart after live_time or immediately if None + def launch_dashboard(host: str, port: int, debug: bool = False): """Launch the dashboard alone.""" @@ -54,7 +75,9 @@ def launch_experiment_manager(): from alab_management.utils.module_ops import load_definition load_definition() - experiment_manager = ExperimentManager(live_time=3600, termination_event=termination_event) + experiment_manager = ExperimentManager( + live_time=3600, termination_event=termination_event + ) experiment_manager.run() @@ -68,23 +91,15 @@ def launch_task_manager(): task_launcher.run() -def launch_device_manager(): - """Launch the device manager.""" - from alab_management.device_manager import DeviceManager - from alab_management.utils.module_ops import load_definition - - load_definition() - device_manager = DeviceManager() - device_manager.run() - - def launch_resource_manager(): """Launch the resource manager.""" from alab_management.resource_manager.resource_manager import ResourceManager from alab_management.utils.module_ops import load_definition load_definition() - resource_manager = ResourceManager(live_time=3600, termination_event=termination_event) + resource_manager = ResourceManager( + live_time=3600, termination_event=termination_event + ) resource_manager.run() @@ -100,22 +115,49 @@ def launch_lab(host, port, debug): ) sys.exit(1) - # Create RestartableProcess objects for each process - dashboard_process = RestartableProcess(target=launch_dashboard, args=(host, port, debug), live_time=3600) # Restart every hour - experiment_manager_process = RestartableProcess(target=launch_experiment_manager) - task_launcher_process = RestartableProcess(target=launch_task_manager) - device_manager_process = RestartableProcess(target=launch_device_manager) - resource_manager_process = RestartableProcess(target=launch_resource_manager) + # Create RestartableProcess objects for each process with shared termination_event + dashboard_process = RestartableProcess( + target=launch_dashboard, + args=(host, port, debug), + live_time=3600, + termination_event=termination_event, + ) + experiment_manager_process = RestartableProcess( + target=launch_experiment_manager, + args=(host, port, debug), + live_time=3600, + termination_event=termination_event, + ) + task_launcher_process = RestartableProcess( + target=launch_task_manager, + args=(host, port, debug), + live_time=3600, + termination_event=termination_event, + ) + resource_manager_process = RestartableProcess( + target=launch_resource_manager, + args=(host, port, debug), + live_time=3600, + termination_event=termination_event, + ) + + # Start the processes in separate threads to allow termination event setting + processes = [ + dashboard_process, + experiment_manager_process, + task_launcher_process, + resource_manager_process, + ] - # Start the processes - dashboard_process.run() - experiment_manager_process.run() - task_launcher_process.run() - device_manager_process.run() - resource_manager_process.run() + threads = [] + for process in processes: + thread = Thread(target=process.run) + thread.start() + threads.append(thread) return threads + def terminate_all_processes(): """Set the termination event to stop all processes.""" termination_event.set() diff --git a/tests/fake_lab/devices/robot_arm.py b/tests/fake_lab/devices/robot_arm.py index 48e5cf03..d6b98f61 100644 --- a/tests/fake_lab/devices/robot_arm.py +++ b/tests/fake_lab/devices/robot_arm.py @@ -1,11 +1,16 @@ from typing import ClassVar +from alab_management import value_in_database from alab_management.device_view import BaseDevice from alab_management.sample_view import SamplePosition class RobotArm(BaseDevice): description: ClassVar[str] = "Fake robot arm" + ensure_single_creation = value_in_database("ensure_single_creation", 0) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) @property def sample_positions(self): @@ -26,7 +31,11 @@ def is_running(self) -> bool: return False def connect(self): - pass + if self.ensure_single_creation != 0: + raise ValueError("Robot arm already connected") + self.ensure_single_creation = 1 def disconnect(self): - pass + if self.ensure_single_creation != 1: + raise ValueError("Robot arm not connected") + self.ensure_single_creation = 0 diff --git a/tests/test_task_manager.py b/tests/test_task_manager.py index a29a152a..56216805 100644 --- a/tests/test_task_manager.py +++ b/tests/test_task_manager.py @@ -66,9 +66,9 @@ def tearDown(self) -> None: def test_task_requester(self): furnace_type = self.devices["furnace_1"].__class__ - # 1 + # 1A result = self.resource_requester.request_resources( - {furnace_type: {"inside": 1}}, timeout=4 + {furnace_type: {"inside": 1}}, timeout=4, return_device_instance=False ) _id = result.pop("request_id") self.assertDictEqual( @@ -95,11 +95,48 @@ def test_task_requester(self): (SamplePositionStatus.EMPTY, None), ) + # 1B + result = self.resource_requester.request_resources( + {furnace_type: {"inside": 1}}, timeout=4, return_device_instance=True + ) + result["devices"] = { + device_type: device.name + for device_type, device in result["devices"].items() + } + _id = result.pop("request_id") + self.assertDictEqual( + { + "devices": {furnace_type: "furnace_1"}, + "sample_positions": {furnace_type: {"inside": ["furnace_1/inside/1"]}}, + }, + result, + ) + self.assertEqual( + self.device_view.get_status("furnace_1"), DeviceTaskStatus.OCCUPIED + ) + self.assertEqual( + self.sample_view.get_sample_position_status("furnace_1/inside/1"), + (SamplePositionStatus.LOCKED, self.resource_requester.task_id), + ) + self.resource_requester.release_resources(_id) + time.sleep(0.5) + self.assertEqual( + self.device_view.get_status("furnace_1"), DeviceTaskStatus.IDLE + ) + self.assertEqual( + self.sample_view.get_sample_position_status("furnace_1/inside/1"), + (SamplePositionStatus.EMPTY, None), + ) + # 2 result = self.resource_requester.request_resources( - {furnace_type: {"inside": 1}}, timeout=4 + {furnace_type: {"inside": 1}}, timeout=4, return_device_instance=True ) _id = result.pop("request_id") + result["devices"] = { + device_type: device.name + for device_type, device in result["devices"].items() + } self.assertDictEqual( { "devices": {furnace_type: "furnace_1"},