Skip to content

Commit

Permalink
finish step 1. (To be tested with fixed launch_lab.py)
Browse files Browse the repository at this point in the history
  • Loading branch information
idocx committed Jul 3, 2024
1 parent fdb59e3 commit 49aa9e8
Show file tree
Hide file tree
Showing 5 changed files with 149 additions and 60 deletions.
41 changes: 21 additions & 20 deletions alab_management/device_view/device_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
import time
from collections.abc import Collection
from datetime import datetime
from enum import auto, Enum, unique
from typing import Any, cast, TypeVar
from enum import Enum, auto, unique
from typing import Any, TypeVar, cast

import pymongo # type: ignore
from bson import ObjectId # type: ignore

from alab_management.sample_view import SamplePosition, SampleView
from alab_management.utils.data_objects import get_collection, get_lock

from .device import BaseDevice, get_all_devices

_DeviceType = TypeVar("_DeviceType", bound=BaseDevice) # pylint: disable=invalid-name
Expand Down Expand Up @@ -161,12 +162,12 @@ def _clean_up_device_collection(self):
self._device_collection.drop()

def request_devices(
self,
task_id: ObjectId,
device_names_str: Collection[str] | None = None,
device_types_str: (
Collection[str] | None
) = None, # pylint: disable=unsubscriptable-object
self,
task_id: ObjectId,
device_names_str: Collection[str] | None = None,
device_types_str: (
Collection[str] | None
) = None, # pylint: disable=unsubscriptable-object
) -> dict[str, dict[str, str | bool]] | None:
"""
Request a list of device, this function will return the name of devices if all the requested device is ready.
Expand Down Expand Up @@ -231,7 +232,7 @@ def request_devices(
return idle_devices

def get_available_devices(
self, device_str: str, type_or_name: str, task_id: ObjectId | None = None
self, device_str: str, type_or_name: str, task_id: ObjectId | None = None
) -> list[dict[str, str | bool]]:
"""
Given device type, it will return all the device with this type.
Expand Down Expand Up @@ -311,8 +312,8 @@ def occupy_device(self, device: BaseDevice | str, task_id: ObjectId):
device_name = device.name if isinstance(device, BaseDevice) else device
# Wait until the device status has been updated to OCCUPIED
while (
self.get_status(device_name=device_name).name
!= DeviceTaskStatus.OCCUPIED.name
self.get_status(device_name=device_name).name
!= DeviceTaskStatus.OCCUPIED.name
):
time.sleep(0.5)

Expand All @@ -338,8 +339,8 @@ def release_device(self, device_name: str):
}

if (
DevicePauseStatus[device_entry["pause_status"]]
== DevicePauseStatus.REQUESTED
DevicePauseStatus[device_entry["pause_status"]]
== DevicePauseStatus.REQUESTED
):
update_dict.update(
{
Expand Down Expand Up @@ -367,11 +368,11 @@ def get_samples_on_device(self, device_name: str):
return samples_per_position

def _update_status(
self,
device: BaseDevice | str,
required_status: DeviceTaskStatus | list[DeviceTaskStatus] | None,
target_status: DeviceTaskStatus,
task_id: ObjectId | None,
self,
device: BaseDevice | str,
required_status: DeviceTaskStatus | list[DeviceTaskStatus] | None,
target_status: DeviceTaskStatus,
task_id: ObjectId | None,
):
"""
A method that check and update the status of a device.
Expand Down Expand Up @@ -399,8 +400,8 @@ def _update_status(
required_status = None

if (
required_status is not None
and DeviceTaskStatus[device_entry["status"]] not in required_status
required_status is not None
and DeviceTaskStatus[device_entry["status"]] not in required_status
):
raise ValueError(
f"Device's current status ({device_entry['status']}) is "
Expand Down
2 changes: 1 addition & 1 deletion alab_management/lab_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def request_resources(
sample_positions = result["sample_positions"]
self._task_view.update_status(task_id=self.task_id, status=TaskStatus.RUNNING)
yield devices, sample_positions

# todo: disconnect devices
self._resource_requester.release_resources(request_id=request_id)

def _sample_name_to_id(self, sample_name: str) -> ObjectId:
Expand Down
110 changes: 76 additions & 34 deletions alab_management/scripts/launch_lab.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,33 +5,54 @@
import sys
import time
from threading import Thread

from gevent.pywsgi import WSGIServer # type: ignore
from multiprocessing import Process

with contextlib.suppress(RuntimeError):
multiprocessing.set_start_method("spawn")

# Create a global termination event
termination_event = multiprocessing.Event()


class RestartableProcess:
    """Run a target callable in a child process and re-launch it when it exits.

    Args:
        target: callable executed inside each child process.
        args: positional arguments forwarded to ``target``.
        live_time: bounds how long ``run`` keeps looping (compared against
            ``time.time()`` deltas, so seconds) and is also used as the delay
            between two launches. ``None`` means no bound and no delay.
        termination_event: event polled to stop the loop. A private
            ``multiprocessing.Event`` is created when omitted.
    """

    def __init__(self, target, args=(), live_time=None, termination_event=None):
        self.target = target
        self.args = args
        self.live_time = live_time
        # Handle of the most recently launched child; None until run() spawns one.
        self.process = None
        self.termination_event = (
            multiprocessing.Event() if termination_event is None else termination_event
        )

    def run(self):
        """Launch the target repeatedly until terminated or ``live_time`` elapses.

        Blocks the calling thread: each iteration starts one child process,
        joins it, reports its exit code, then waits before the next launch.
        Errors raised while starting or joining the child are printed and do
        not stop the loop.
        """
        start = time.time()
        while not self.termination_event.is_set() and (
            self.live_time is None or time.time() - start < self.live_time
        ):
            try:
                process = multiprocessing.Process(target=self.target, args=self.args)
                # Publish the handle so callers can inspect the current child.
                self.process = process
                process.start()
                process.join()  # Wait for the child to finish

                if process.exitcode == 0:
                    print(f"Process {process.name} exited normally. Restarting...")
                else:
                    print(
                        f"Process {process.name} exited with code {process.exitcode}."
                    )
            except Exception as e:  # best effort: keep the supervisor loop alive
                print(f"Error occurred while running process: {e}")

            # Interruptible delay: wait() returns True as soon as termination is
            # requested, so shutdown is not held up for a full live_time.
            # NOTE(review): delaying by live_time means the loop deadline above
            # has usually passed by the next check, so a restart only happens
            # when live_time is None -- confirm this is the intended policy.
            if self.termination_event.wait(self.live_time or 0):
                break


def launch_dashboard(host: str, port: int, debug: bool = False):
"""Launch the dashboard alone."""
Expand All @@ -54,7 +75,9 @@ def launch_experiment_manager():
from alab_management.utils.module_ops import load_definition

load_definition()
experiment_manager = ExperimentManager(live_time=3600, termination_event=termination_event)
experiment_manager = ExperimentManager(
live_time=3600, termination_event=termination_event
)
experiment_manager.run()


Expand All @@ -68,23 +91,15 @@ def launch_task_manager():
task_launcher.run()


def launch_device_manager():
    """Start the device manager in the current process.

    Calls ``load_definition()`` first, then builds a ``DeviceManager`` and
    hands control to its ``run()`` method.
    """
    # Imported lazily so the packages are only loaded in the process that runs this.
    from alab_management.utils.module_ops import load_definition
    from alab_management.device_manager import DeviceManager

    load_definition()
    DeviceManager().run()


def launch_resource_manager():
    """Launch the resource manager.

    Calls ``load_definition()`` and then runs a ``ResourceManager`` configured
    with ``live_time=3600`` (presumably seconds -- confirm against
    ``ResourceManager``) and the module-level ``termination_event``.
    """
    from alab_management.resource_manager.resource_manager import ResourceManager
    from alab_management.utils.module_ops import load_definition

    load_definition()
    # Construct the manager exactly once before running it.
    resource_manager = ResourceManager(
        live_time=3600, termination_event=termination_event
    )
    resource_manager.run()


Expand All @@ -100,22 +115,49 @@ def launch_lab(host, port, debug):
)
sys.exit(1)

# Create RestartableProcess objects for each process
dashboard_process = RestartableProcess(target=launch_dashboard, args=(host, port, debug), live_time=3600) # Restart every hour
experiment_manager_process = RestartableProcess(target=launch_experiment_manager)
task_launcher_process = RestartableProcess(target=launch_task_manager)
device_manager_process = RestartableProcess(target=launch_device_manager)
resource_manager_process = RestartableProcess(target=launch_resource_manager)
# Create RestartableProcess objects for each process with shared termination_event
dashboard_process = RestartableProcess(
target=launch_dashboard,
args=(host, port, debug),
live_time=3600,
termination_event=termination_event,
)
experiment_manager_process = RestartableProcess(
target=launch_experiment_manager,
args=(host, port, debug),
live_time=3600,
termination_event=termination_event,
)
task_launcher_process = RestartableProcess(
target=launch_task_manager,
args=(host, port, debug),
live_time=3600,
termination_event=termination_event,
)
resource_manager_process = RestartableProcess(
target=launch_resource_manager,
args=(host, port, debug),
live_time=3600,
termination_event=termination_event,
)

# Start the processes in separate threads to allow termination event setting
processes = [
dashboard_process,
experiment_manager_process,
task_launcher_process,
resource_manager_process,
]

# Start the processes
dashboard_process.run()
experiment_manager_process.run()
task_launcher_process.run()
device_manager_process.run()
resource_manager_process.run()
threads = []
for process in processes:
thread = Thread(target=process.run)
thread.start()
threads.append(thread)

return threads


def terminate_all_processes():
    """Set the termination event to stop all processes.

    Sets the module-level ``termination_event``. Every ``RestartableProcess``
    that was given this event checks it before launching its target again, so
    no further restarts happen once it is set.
    """
    termination_event.set()
13 changes: 11 additions & 2 deletions tests/fake_lab/devices/robot_arm.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
from typing import ClassVar

from alab_management import value_in_database
from alab_management.device_view import BaseDevice
from alab_management.sample_view import SamplePosition


class RobotArm(BaseDevice):
description: ClassVar[str] = "Fake robot arm"
ensure_single_creation = value_in_database("ensure_single_creation", 0)

def __init__(self, *args, **kwargs):
    """Create the fake robot arm, forwarding all arguments to the parent initializer."""
    super().__init__(*args, **kwargs)

@property
def sample_positions(self):
Expand All @@ -26,7 +31,11 @@ def is_running(self) -> bool:
return False

def connect(self):
    """Mark the fake robot arm as connected.

    Flips ``ensure_single_creation`` from 0 to 1 so that a second ``connect``
    without an intervening ``disconnect`` is detected.

    Raises:
        ValueError: if ``ensure_single_creation`` is not 0.
    """
    if self.ensure_single_creation != 0:
        raise ValueError("Robot arm already connected")
    self.ensure_single_creation = 1

def disconnect(self):
    """Mark the fake robot arm as disconnected.

    Flips ``ensure_single_creation`` from 1 back to 0, allowing a later
    ``connect`` to succeed.

    Raises:
        ValueError: if ``ensure_single_creation`` is not 1.
    """
    if self.ensure_single_creation != 1:
        raise ValueError("Robot arm not connected")
    self.ensure_single_creation = 0
43 changes: 40 additions & 3 deletions tests/test_task_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,9 @@ def tearDown(self) -> None:
def test_task_requester(self):
furnace_type = self.devices["furnace_1"].__class__

# 1
# 1A
result = self.resource_requester.request_resources(
{furnace_type: {"inside": 1}}, timeout=4
{furnace_type: {"inside": 1}}, timeout=4, return_device_instance=False
)
_id = result.pop("request_id")
self.assertDictEqual(
Expand All @@ -95,11 +95,48 @@ def test_task_requester(self):
(SamplePositionStatus.EMPTY, None),
)

# 1B
result = self.resource_requester.request_resources(
{furnace_type: {"inside": 1}}, timeout=4, return_device_instance=True
)
result["devices"] = {
device_type: device.name
for device_type, device in result["devices"].items()
}
_id = result.pop("request_id")
self.assertDictEqual(
{
"devices": {furnace_type: "furnace_1"},
"sample_positions": {furnace_type: {"inside": ["furnace_1/inside/1"]}},
},
result,
)
self.assertEqual(
self.device_view.get_status("furnace_1"), DeviceTaskStatus.OCCUPIED
)
self.assertEqual(
self.sample_view.get_sample_position_status("furnace_1/inside/1"),
(SamplePositionStatus.LOCKED, self.resource_requester.task_id),
)
self.resource_requester.release_resources(_id)
time.sleep(0.5)
self.assertEqual(
self.device_view.get_status("furnace_1"), DeviceTaskStatus.IDLE
)
self.assertEqual(
self.sample_view.get_sample_position_status("furnace_1/inside/1"),
(SamplePositionStatus.EMPTY, None),
)

# 2
result = self.resource_requester.request_resources(
{furnace_type: {"inside": 1}}, timeout=4
{furnace_type: {"inside": 1}}, timeout=4, return_device_instance=True
)
_id = result.pop("request_id")
result["devices"] = {
device_type: device.name
for device_type, device in result["devices"].items()
}
self.assertDictEqual(
{
"devices": {furnace_type: "furnace_1"},
Expand Down

0 comments on commit 49aa9e8

Please sign in to comment.