From 3e0e9b5b42cc1f0d4874fc9197bb04b822143f0e Mon Sep 17 00:00:00 2001 From: Yuxing Fei Date: Mon, 6 May 2024 22:10:54 -0700 Subject: [PATCH 1/3] temp commit --- alab_management/lab_view.py | 26 ++- .../resource_manager/resource_manager.py | 32 ++-- .../resource_manager/resource_requester.py | 148 +++++------------- alab_management/utils/data_objects.py | 4 + tests/fake_lab/tasks/infinite_task.py | 7 +- tests/test_launch.py | 77 +++++---- tests/test_task_manager.py | 5 - 7 files changed, 125 insertions(+), 174 deletions(-) diff --git a/alab_management/lab_view.py b/alab_management/lab_view.py index 7c7010ee..340e4e1e 100644 --- a/alab_management/lab_view.py +++ b/alab_management/lab_view.py @@ -120,22 +120,18 @@ def request_resources( resource_request=resource_request, timeout=timeout, priority=priority ) request_id = result["request_id"] - timeout_error = result["timeout_error"] - if timeout_error: - raise TimeoutError - else: - devices = result["devices"] - sample_positions = result["sample_positions"] - devices = { - device_type: self._device_client.create_device_wrapper(device_name) - for device_type, device_name in devices.items() - } # type: ignore - self._task_view.update_status( - task_id=self.task_id, status=TaskStatus.RUNNING - ) - yield devices, sample_positions + devices = result["devices"] + sample_positions = result["sample_positions"] + devices = { + device_type: self._device_client.create_device_wrapper(device_name) + for device_type, device_name in devices.items() + } # type: ignore + self._task_view.update_status( + task_id=self.task_id, status=TaskStatus.RUNNING + ) + yield devices, sample_positions - self._resource_requester.release_resources(request_id=request_id) + self._resource_requester.release_resources(request_id=request_id) def _sample_name_to_id(self, sample_name: str) -> ObjectId: """ diff --git a/alab_management/resource_manager/resource_manager.py b/alab_management/resource_manager/resource_manager.py index f3765366..a3aa908d 100644 --- a/alab_management/resource_manager/resource_manager.py +++ b/alab_management/resource_manager/resource_manager.py @@ -5,6 +5,7 @@ import time from datetime import datetime +from traceback import format_exc from typing import Any, cast import dill @@ -21,7 +22,7 @@ from alab_management.sample_view.sample_view import SamplePositionRequest, SampleView from alab_management.task_view import TaskView from alab_management.task_view.task_enums import TaskStatus -from alab_management.utils.data_objects import get_collection +from alab_management.utils.data_objects import DocumentNotUpdatedError, get_collection from alab_management.utils.module_ops import load_definition @@ -62,7 +63,7 @@ def handle_released_resources(self): self._release_devices(devices) self._release_sample_positions(sample_positions) self.update_request_status( - request_id=request["_id"], status=RequestStatus.RELEASED + request_id=request["_id"], status=RequestStatus.RELEASED, original_status=RequestStatus.NEED_RELEASE ) def handle_requested_resources(self): @@ -89,6 +90,7 @@ def _handle_requested_resources(self, request_entry: dict[str, Any]): self.update_request_status( request_id=request_entry["_id"], status=RequestStatus.CANCELED, + original_status=RequestStatus.PENDING, ) return @@ -147,8 +149,10 @@ def _handle_requested_resources(self, request_entry: dict[str, Any]): # in case some errors happen, we will raise the error in the task process instead of the main process except Exception as error: # pylint: disable=broad-except - self._request_collection.update_one( - {"_id": request_entry["_id"]}, + # we will store the error in the database for easier debugging + error.args = (format_exc(),) + returned_value = self._request_collection.update_one( + {"_id": request_entry["_id"], "status": RequestStatus.PENDING.name}, { "$set": { "status": RequestStatus.ERROR.name, @@ -158,11 +162,16 @@ def _handle_requested_resources(self, request_entry: dict[str, Any]): } }, ) + if returned_value.modified_count != 1: + raise DocumentNotUpdatedError( + f"Error updating request {request_entry['_id']}: cannot update the request status from PENDING " + f"to ERROR." + ) from error return # if both devices and sample positions can be satisfied - self._request_collection.update_one( - {"_id": request_entry["_id"]}, + returned_value = self._request_collection.update_one( + {"_id": request_entry["_id"], "status": RequestStatus.PENDING.name}, { "$set": { "assigned_devices": devices, @@ -172,11 +181,12 @@ def _handle_requested_resources(self, request_entry: dict[str, Any]): } }, ) - # label the resources as occupied - self._occupy_devices(devices=devices, task_id=task_id) - self._occupy_sample_positions( - sample_positions=sample_positions, task_id=task_id - ) + if returned_value.modified_count == 1: + # label the resources as occupied + self._occupy_devices(devices=devices, task_id=task_id) + self._occupy_sample_positions( + sample_positions=sample_positions, task_id=task_id + ) def _occupy_devices(self, devices: dict[str, dict[str, Any]], task_id: ObjectId): for device in devices.values(): diff --git a/alab_management/resource_manager/resource_requester.py b/alab_management/resource_manager/resource_requester.py index 06f1decb..2988b5c5 100644 --- a/alab_management/resource_manager/resource_requester.py +++ b/alab_management/resource_manager/resource_requester.py @@ -12,7 +12,6 @@ import dill from bson import ObjectId -from dramatiq_abort import Abort from pydantic import BaseModel, root_validator from alab_management.device_view.device import BaseDevice @@ -21,7 +20,7 @@ from alab_management.sample_view.sample import SamplePosition from alab_management.sample_view.sample_view import SamplePositionRequest from alab_management.task_view import TaskPriority -from alab_management.utils.data_objects import get_collection +from alab_management.utils.data_objects import DocumentNotUpdatedError, get_collection _SampleRequestDict = dict[str, int] _ResourceRequestDict = dict[ @@ -91,16 +90,29 @@ class RequestMixin: def __init__(self): self._request_collection = get_collection("requests") - def update_request_status(self, request_id: ObjectId, status: RequestStatus): + def update_request_status(self, request_id: ObjectId, status: RequestStatus, + original_status: RequestStatus | list[RequestStatus] = None): """Update the status of a request by request_id.""" - value_returned = self._request_collection.update_one( - {"_id": request_id}, {"$set": {"status": status.name}} - ) - # wait for the request to be updated - while ( - self.get_request(request_id, projection=["status"])["status"] != status.name - ): - time.sleep(0.5) + if original_status is not None: + if isinstance(original_status, list): + value_returned = self._request_collection.update_one( + {"_id": request_id, "status": {"$in": [status.name for status in original_status]}}, + {"$set": {"status": status.name}}, + ) + else: + value_returned = self._request_collection.update_one( + {"_id": request_id, "status": original_status.name}, + {"$set": {"status": status.name}}, + ) + else: + value_returned = self._request_collection.update_one( + {"_id": request_id}, {"$set": {"status": status.name}} + ) + if value_returned.modified_count == 0: + raise DocumentNotUpdatedError( + f"Request {request_id} was not updated to {status.name}, " + f"because it is not in {original_status.name} status." + ) return value_returned def get_request(self, request_id: ObjectId, **kwargs) -> dict[str, Any] | None: @@ -219,38 +231,7 @@ def request_resources( ) # DB_ACCESS_OUTSIDE_VIEW _id: ObjectId = cast(ObjectId, result.inserted_id) self._waiting[_id] = {"f": f, "device_str_to_request": device_str_to_request} - - try: - # wait for the request to be fulfilled - start_time = time.time() - while self.get_request(_id, projection=["status"])["status"] not in [ - RequestStatus.FULFILLED.name, - RequestStatus.CANCELED.name, - RequestStatus.ERROR.name, - ]: - if timeout is not None and time.time() - start_time > timeout: - raise TimeoutError - time.sleep(0.1) - - result = f.result(timeout=None) - except ( - TimeoutError - ): # cancel the task if timeout, make sure it is not fulfilled - if ( - self.get_request(_id, projection=["status"])["status"] - != RequestStatus.FULFILLED.name - ): - self.update_request_status( - request_id=_id, status=RequestStatus.CANCELED - ) - # wait for the request status to be updated - while (self.get_request(_id, projection=["status"]))[ - "status" - ] != RequestStatus.CANCELED.name: - time.sleep(0.5) - return {"request_id": _id, "timeout_error": True} - else: # if the request is fulfilled, return the result normally, wrong timeout - result = f.result(timeout=None) + result = f.result(timeout=timeout) return { **self._post_process_requested_resource( devices=result["devices"], @@ -258,7 +239,6 @@ def request_resources( resource_request=resource_request, ), "request_id": result["request_id"], - "timeout_error": False, } def release_resources(self, request_id: ObjectId): @@ -269,35 +249,15 @@ def release_resources(self, request_id: ObjectId): if ("assigned_devices" in request) or ( "assigned_sample_positions" in request ): - self.update_request_status(request_id, RequestStatus.NEED_RELEASE) - # wait for the request to be updated to NEED_RELEASE or have been released - while self.get_request(request_id, projection=["status"])["status"] in [ - RequestStatus.CANCELED.name, - RequestStatus.ERROR.name, - ]: - time.sleep(0.5) + self.update_request_status(request_id, RequestStatus.NEED_RELEASE, original_status=[ + RequestStatus.CANCELED, RequestStatus.ERROR + ]) else: # If it doesn't have assigned resources, just leave it as CANCELED or ERROR return # For the requests that were fulfilled, definitely have assigned resources, release them elif request["status"] == RequestStatus.FULFILLED.name: - self._request_collection.update_one( - { - "_id": request_id, - "status": RequestStatus.FULFILLED.name, - }, - { - "$set": { - "status": RequestStatus.NEED_RELEASE.name, - } - }, - ) - # wait for the request to be updated to NEED_RELEASE or have been released - while ( - self.get_request(request_id, projection=["status"])["status"] - == RequestStatus.FULFILLED.name - ): - time.sleep(0.5) + self.update_request_status(request_id, RequestStatus.NEED_RELEASE, original_status=RequestStatus.FULFILLED) # wait for the request to be released or canceled or errored during the release while self.get_request(request_id, projection=["status"])["status"] not in [ @@ -340,38 +300,10 @@ def release_all_resources(self): ("assigned_devices" in request) or ("assigned_sample_positions" in request) ): - self.update_request_status(request["_id"], RequestStatus.NEED_RELEASE) + self.update_request_status(request["_id"], RequestStatus.NEED_RELEASE, + original_status=[RequestStatus.CANCELED, RequestStatus.ERROR]) assigned_cancel_error_requests_id.append(request["_id"]) - # For the requests that were PENDING, mark them as CANCELED, they don't have assigned resources - self._request_collection.update_many( - { - "task_id": self.task_id, - "status": RequestStatus.PENDING.name, - }, - { - "$set": { - "status": RequestStatus.CANCELED.name, - } - }, - ) - # wait for all the requests to be updated - while any( - request["status"] - in [RequestStatus.FULFILLED.name, RequestStatus.PENDING.name] - for request in self.get_requests_by_task_id(self.task_id) - ): - time.sleep(0.5) - if assigned_cancel_error_requests_id: - # wait for the requests to be updated to updated to NEED_RELEASE or have been released - while any( - request["status"] - in [RequestStatus.CANCELED.name, RequestStatus.ERROR.name] - for request in self.get_requests_by_task_id(self.task_id) - if request["_id"] in assigned_cancel_error_requests_id - ): - time.sleep(0.5) - # wait for all the requests to be released or canceled or errored during the release while any( ( @@ -452,16 +384,22 @@ def _handle_canceled_request(self, request_id: ObjectId): request: dict[str, Any] = self._waiting.pop(request_id) f: Future = request["f"] - f.set_exception(Abort(f"Request {request_id} was canceled.")) + # for the canceled request, we will return an empty result + # and wait for the abort to be handled by the task actor + f.set_result({ + "devices": {}, + "sample_positions": {}, + "request_id": request_id, + }) @staticmethod def _post_process_requested_resource( - devices: dict[type[BaseDevice], str], + devices: dict[type[BaseDevice] | str, str], sample_positions: dict[str, list[str]], - resource_request: dict[str, list[dict[str, int | str]]], + resource_request: dict[str | type[BaseDevice] | None, dict[str, int]], ): processed_sample_positions: dict[ - type[BaseDevice] | None, dict[str, list[str]] + type[BaseDevice] | str | None, dict[str, list[str]] ] = {} for device_request, sample_position_dict in resource_request.items(): @@ -483,12 +421,6 @@ def _post_process_requested_resource( processed_sample_positions[device_request][prefix] = sample_positions[ reply_prefix ] - # { - # device_request: { - # prefix: sample_positions[prefix] for prefix in sample_position_dict - # } - # for device_request, sample_position_dict in resource_request.items() - # } return { "devices": devices, "sample_positions": processed_sample_positions, diff --git a/alab_management/utils/data_objects.py b/alab_management/utils/data_objects.py index 52a7ee1f..3d89d040 100644 --- a/alab_management/utils/data_objects.py +++ b/alab_management/utils/data_objects.py @@ -155,3 +155,7 @@ def make_jsonable(obj): get_completed_collection = _GetCompletedMongoCollection.get_collection get_completed_lock = _GetCompletedMongoCollection.get_lock + + +class DocumentNotUpdatedError(Exception): + """Exception raised when a document is not updated in the database.""" diff --git a/tests/fake_lab/tasks/infinite_task.py b/tests/fake_lab/tasks/infinite_task.py index 7e4f9c6f..abb700f0 100644 --- a/tests/fake_lab/tasks/infinite_task.py +++ b/tests/fake_lab/tasks/infinite_task.py @@ -2,6 +2,8 @@ from alab_management.task_view.task import BaseTask +from ..devices.device_that_never_ends import DeviceThatNeverEnds # noqa: TID252 + class InfiniteTask(BaseTask): def __init__(self, samples: list[ObjectId], *args, **kwargs): @@ -14,5 +16,6 @@ def __init__(self, samples: list[ObjectId], *args, **kwargs): self.sample = samples[0] def run(self): - while True: - pass + with self.lab_view.request_resources({DeviceThatNeverEnds: {}}) as (devices, _): + while True: + pass diff --git a/tests/test_launch.py b/tests/test_launch.py index 8ccc467c..f7199dbc 100644 --- a/tests/test_launch.py +++ b/tests/test_launch.py @@ -196,7 +196,7 @@ def compose_exp(exp_name): "type": "Starting", "prev_tasks": [], "parameters": { - "dest": "furnace_table", + "dest": "furnace_temp", }, "samples": ["test_sample"], }, @@ -209,40 +209,51 @@ def compose_exp(exp_name): ], } - experiment = compose_exp("Experiment with cancel") - resp = requests.post( - SUBMISSION_API, json=experiment - ) - resp_json = resp.json() - exp_id = ObjectId(resp_json["data"]["exp_id"]) - self.assertTrue("success", resp_json["status"]) + exp_ids = {} + for exp_name in ["Experiment with cancel when running", "Experiment with cancel when requesting resources"]: + experiment = compose_exp(exp_name) + resp = requests.post( + SUBMISSION_API, json=experiment + ) + resp_json = resp.json() + exp_id = ObjectId(resp_json["data"]["exp_id"]) + exp_ids[exp_name] = exp_id + self.assertTrue("success", resp_json["status"]) + time.sleep(2) - time.sleep(10) - self.assertEqual( - "RUNNING", self.experiment_view.get_experiment(exp_id)["status"] - ) + time.sleep(15) + for exp_id in exp_ids.values(): + self.assertEqual( + "RUNNING", self.experiment_view.get_experiment(exp_id)["status"] + ) - resp = requests.get( - f"http://127.0.0.1:8896/api/experiment/cancel/{exp_id!s}", - ) - self.assertEqual("success", resp.json()["status"]) - time.sleep(10) + for exp_name in ["Experiment with cancel when requesting resources", "Experiment with cancel when running"]: + exp_id = exp_ids[exp_name] + resp = requests.get( + f"http://127.0.0.1:8896/api/experiment/cancel/{exp_id!s}", + ) + self.assertEqual("success", resp.json()["status"]) + time.sleep(10) - pending_user_input = requests.get("http://127.0.0.1:8896/api/userinput/pending").json() - self.assertEqual(len(pending_user_input["pending_requests"].get(str(exp_id), [])), 1) - request_id = pending_user_input["pending_requests"][str(exp_id)][0]["id"] - # acknowledge the request - resp = requests.post( - "http://127.0.0.1:8896/api/userinput/submit", - json={ - "request_id": request_id, - "response": "OK", - "note": "dummy", - }, - ) - self.assertEqual("success", resp.json()["status"]) + pending_user_input = requests.get("http://127.0.0.1:8896/api/userinput/pending").json() + self.assertEqual(len(pending_user_input["pending_requests"].get(str(exp_id), [])), 1) + request_id = pending_user_input["pending_requests"][str(exp_id)][0]["id"] + request_prompt = pending_user_input["pending_requests"][str(exp_id)][0]["prompt"] + self.assertIn("dramatiq_abort.abort_manager.Abort", request_prompt) + # acknowledge the request + resp = requests.post( + "http://127.0.0.1:8896/api/userinput/submit", + json={ + "request_id": request_id, + "response": "OK", + "note": "dummy", + }, + ) + self.assertEqual("success", resp.json()["status"]) time.sleep(10) - self.assertEqual( - "COMPLETED", self.experiment_view.get_experiment(exp_id)["status"] - ) + + for exp_id in exp_ids.values(): + self.assertEqual( + "COMPLETED", self.experiment_view.get_experiment(exp_id)["status"] + ) diff --git a/tests/test_task_manager.py b/tests/test_task_manager.py index 5256ce94..a29a152a 100644 --- a/tests/test_task_manager.py +++ b/tests/test_task_manager.py @@ -75,7 +75,6 @@ def test_task_requester(self): { "devices": {furnace_type: "furnace_1"}, "sample_positions": {furnace_type: {"inside": ["furnace_1/inside/1"]}}, - "timeout_error": False, }, result, ) @@ -105,7 +104,6 @@ def test_task_requester(self): { "devices": {furnace_type: "furnace_1"}, "sample_positions": {furnace_type: {"inside": ["furnace_1/inside/1"]}}, - "timeout_error": False, }, result, ) @@ -134,7 +132,6 @@ def test_task_requester(self): { "devices": {furnace_type: "furnace_1"}, "sample_positions": {furnace_type: {"inside": ["furnace_1/inside/1"]}}, - "timeout_error": False, }, result, ) @@ -159,7 +156,6 @@ def test_task_requester(self): ] } }, - "timeout_error": False, }, result, ) @@ -189,7 +185,6 @@ def test_task_requester(self): ] }, }, - "timeout_error": False, }, result, ) From ca4e025b1a40081a5910cfa70a6c15fb74639b99 Mon Sep 17 00:00:00 2001 From: Yuxing Fei Date: Mon, 6 May 2024 22:54:05 -0700 Subject: [PATCH 2/3] temp --- alab_management/resource_manager/resource_manager.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/alab_management/resource_manager/resource_manager.py b/alab_management/resource_manager/resource_manager.py index a3aa908d..f9124f2e 100644 --- a/alab_management/resource_manager/resource_manager.py +++ b/alab_management/resource_manager/resource_manager.py @@ -21,7 +21,7 @@ from alab_management.sample_view.sample import SamplePosition from alab_management.sample_view.sample_view import SamplePositionRequest, SampleView from alab_management.task_view import TaskView -from alab_management.task_view.task_enums import TaskStatus +from alab_management.task_view.task_enums import CancelingProgress, TaskStatus from alab_management.utils.data_objects import DocumentNotUpdatedError, get_collection from alab_management.utils.module_ops import load_definition @@ -84,7 +84,8 @@ def _handle_requested_resources(self, request_entry: dict[str, Any]): task_id = request_entry["task_id"] task_status = self.task_view.get_status(task_id=task_id) - if task_status != TaskStatus.REQUESTING_RESOURCES: + if (task_status != TaskStatus.REQUESTING_RESOURCES or + self.task_view.get_tasks_to_be_canceled(canceling_progress=CancelingProgress.WORKER_NOTIFIED)): # this implies the Task has been cancelled or errored somewhere else in the chain -- we should # not allocate any resources to the broken Task. self.update_request_status( From dae2dcbe115639f682c6c840f81cbb9e5c0c088d Mon Sep 17 00:00:00 2001 From: Yuxing Fei Date: Mon, 6 May 2024 23:41:25 -0700 Subject: [PATCH 3/3] remove the timeout to better reflect the real case --- tests/test_lab_view.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_lab_view.py b/tests/test_lab_view.py index 9f2a3212..736c923d 100644 --- a/tests/test_lab_view.py +++ b/tests/test_lab_view.py @@ -99,7 +99,6 @@ def test_request_resources(self): "LOCKED", self.sample_view.get_sample_position_status("furnace_table")[0].name, ) - time.sleep(0.5) self.assertEqual("IDLE", self.device_view.get_status("furnace_1").name) self.assertEqual("IDLE", self.device_view.get_status("dummy").name)