Skip to content

Commit

Permalink
Merge pull request #67 from CederGroupHub/better-resource-manager
Browse files Browse the repository at this point in the history
Better resource manager
  • Loading branch information
odartsi authored May 8, 2024
2 parents fc5bf31 + dae2dcb commit 061f298
Show file tree
Hide file tree
Showing 8 changed files with 128 additions and 177 deletions.
26 changes: 11 additions & 15 deletions alab_management/lab_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,22 +120,18 @@ def request_resources(
resource_request=resource_request, timeout=timeout, priority=priority
)
request_id = result["request_id"]
timeout_error = result["timeout_error"]
if timeout_error:
raise TimeoutError
else:
devices = result["devices"]
sample_positions = result["sample_positions"]
devices = {
device_type: self._device_client.create_device_wrapper(device_name)
for device_type, device_name in devices.items()
} # type: ignore
self._task_view.update_status(
task_id=self.task_id, status=TaskStatus.RUNNING
)
yield devices, sample_positions
devices = result["devices"]
sample_positions = result["sample_positions"]
devices = {
device_type: self._device_client.create_device_wrapper(device_name)
for device_type, device_name in devices.items()
} # type: ignore
self._task_view.update_status(
task_id=self.task_id, status=TaskStatus.RUNNING
)
yield devices, sample_positions

self._resource_requester.release_resources(request_id=request_id)
self._resource_requester.release_resources(request_id=request_id)

def _sample_name_to_id(self, sample_name: str) -> ObjectId:
"""
Expand Down
37 changes: 24 additions & 13 deletions alab_management/resource_manager/resource_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import time
from datetime import datetime
from traceback import format_exc
from typing import Any, cast

import dill
Expand All @@ -20,8 +21,8 @@
from alab_management.sample_view.sample import SamplePosition
from alab_management.sample_view.sample_view import SamplePositionRequest, SampleView
from alab_management.task_view import TaskView
from alab_management.task_view.task_enums import TaskStatus
from alab_management.utils.data_objects import get_collection
from alab_management.task_view.task_enums import CancelingProgress, TaskStatus
from alab_management.utils.data_objects import DocumentNotUpdatedError, get_collection
from alab_management.utils.module_ops import load_definition


Expand Down Expand Up @@ -62,7 +63,7 @@ def handle_released_resources(self):
self._release_devices(devices)
self._release_sample_positions(sample_positions)
self.update_request_status(
request_id=request["_id"], status=RequestStatus.RELEASED
request_id=request["_id"], status=RequestStatus.RELEASED, original_status=RequestStatus.NEED_RELEASE
)

def handle_requested_resources(self):
Expand All @@ -83,12 +84,14 @@ def _handle_requested_resources(self, request_entry: dict[str, Any]):
task_id = request_entry["task_id"]

task_status = self.task_view.get_status(task_id=task_id)
if task_status != TaskStatus.REQUESTING_RESOURCES:
if (task_status != TaskStatus.REQUESTING_RESOURCES or
self.task_view.get_tasks_to_be_canceled(canceling_progress=CancelingProgress.WORKER_NOTIFIED)):
# this implies the Task has been cancelled or errored somewhere else in the chain -- we should
# not allocate any resources to the broken Task.
self.update_request_status(
request_id=request_entry["_id"],
status=RequestStatus.CANCELED,
original_status=RequestStatus.PENDING,
)
return

Expand Down Expand Up @@ -147,8 +150,10 @@ def _handle_requested_resources(self, request_entry: dict[str, Any]):

# in case some errors happen, we will raise the error in the task process instead of the main process
except Exception as error: # pylint: disable=broad-except
self._request_collection.update_one(
{"_id": request_entry["_id"]},
# we will store the error in the database for easier debugging
error.args = (format_exc(),)
returned_value = self._request_collection.update_one(
{"_id": request_entry["_id"], "status": RequestStatus.PENDING.name},
{
"$set": {
"status": RequestStatus.ERROR.name,
Expand All @@ -158,11 +163,16 @@ def _handle_requested_resources(self, request_entry: dict[str, Any]):
}
},
)
if returned_value.modified_count != 1:
raise DocumentNotUpdatedError(
f"Error updating request {request_entry['_id']}: cannot update the request status from PENDING "
f"to ERROR."
) from error
return

# if both devices and sample positions can be satisfied
self._request_collection.update_one(
{"_id": request_entry["_id"]},
returned_value = self._request_collection.update_one(
{"_id": request_entry["_id"], "status": RequestStatus.PENDING.name},
{
"$set": {
"assigned_devices": devices,
Expand All @@ -172,11 +182,12 @@ def _handle_requested_resources(self, request_entry: dict[str, Any]):
}
},
)
# label the resources as occupied
self._occupy_devices(devices=devices, task_id=task_id)
self._occupy_sample_positions(
sample_positions=sample_positions, task_id=task_id
)
if returned_value.modified_count == 1:
# label the resources as occupied
self._occupy_devices(devices=devices, task_id=task_id)
self._occupy_sample_positions(
sample_positions=sample_positions, task_id=task_id
)

def _occupy_devices(self, devices: dict[str, dict[str, Any]], task_id: ObjectId):
for device in devices.values():
Expand Down
148 changes: 40 additions & 108 deletions alab_management/resource_manager/resource_requester.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

import dill
from bson import ObjectId
from dramatiq_abort import Abort
from pydantic import BaseModel, root_validator

from alab_management.device_view.device import BaseDevice
Expand All @@ -21,7 +20,7 @@
from alab_management.sample_view.sample import SamplePosition
from alab_management.sample_view.sample_view import SamplePositionRequest
from alab_management.task_view import TaskPriority
from alab_management.utils.data_objects import get_collection
from alab_management.utils.data_objects import DocumentNotUpdatedError, get_collection

_SampleRequestDict = dict[str, int]
_ResourceRequestDict = dict[
Expand Down Expand Up @@ -91,16 +90,29 @@ class RequestMixin:
def __init__(self):
    """Bind ``self._request_collection`` to the ``"requests"`` collection."""
    # The handle is obtained through ``get_collection``; the request status
    # updates in this class are issued against it.
    self._request_collection = get_collection("requests")

def update_request_status(
    self,
    request_id: ObjectId,
    status: RequestStatus,
    original_status: RequestStatus | list[RequestStatus] | None = None,
):
    """Update the status of a request by request_id.

    Args:
        request_id: ``_id`` of the request document to update.
        status: The status to store; it is written by its ``name``.
        original_status: Optional precondition on the current status. A single
            status requires the document to be in exactly that status; a list
            requires it to be in any one of the listed statuses. ``None``
            (the default) applies no precondition.

    Returns:
        The result object returned by ``update_one``.

    Raises:
        DocumentNotUpdatedError: If ``update_one`` reports that no document
            was modified.
    """
    # Build the filter once; the status precondition is added only on request.
    query = {"_id": request_id}
    if original_status is None:
        reason = "because no document was modified."
    elif isinstance(original_status, list):
        allowed_names = [allowed.name for allowed in original_status]
        query["status"] = {"$in": allowed_names}
        reason = f"because it is not in any of [{', '.join(allowed_names)}] status."
    else:
        query["status"] = original_status.name
        reason = f"because it is not in {original_status.name} status."

    value_returned = self._request_collection.update_one(
        query, {"$set": {"status": status.name}}
    )
    if value_returned.modified_count == 0:
        # The message is assembled from values computed above, so it never
        # dereferences ``.name`` on ``None`` or on a list of statuses.
        raise DocumentNotUpdatedError(
            f"Request {request_id} was not updated to {status.name}, {reason}"
        )
    return value_returned

def get_request(self, request_id: ObjectId, **kwargs) -> dict[str, Any] | None:
Expand Down Expand Up @@ -219,46 +231,14 @@ def request_resources(
) # DB_ACCESS_OUTSIDE_VIEW
_id: ObjectId = cast(ObjectId, result.inserted_id)
self._waiting[_id] = {"f": f, "device_str_to_request": device_str_to_request}

try:
# wait for the request to be fulfilled
start_time = time.time()
while self.get_request(_id, projection=["status"])["status"] not in [
RequestStatus.FULFILLED.name,
RequestStatus.CANCELED.name,
RequestStatus.ERROR.name,
]:
if timeout is not None and time.time() - start_time > timeout:
raise TimeoutError
time.sleep(0.1)

result = f.result(timeout=None)
except (
TimeoutError
): # cancel the task if timeout, make sure it is not fulfilled
if (
self.get_request(_id, projection=["status"])["status"]
!= RequestStatus.FULFILLED.name
):
self.update_request_status(
request_id=_id, status=RequestStatus.CANCELED
)
# wait for the request status to be updated
while (self.get_request(_id, projection=["status"]))[
"status"
] != RequestStatus.CANCELED.name:
time.sleep(0.5)
return {"request_id": _id, "timeout_error": True}
else: # if the request is fulfilled, return the result normally, wrong timeout
result = f.result(timeout=None)
result = f.result(timeout=timeout)
return {
**self._post_process_requested_resource(
devices=result["devices"],
sample_positions=result["sample_positions"],
resource_request=resource_request,
),
"request_id": result["request_id"],
"timeout_error": False,
}

def release_resources(self, request_id: ObjectId):
Expand All @@ -269,35 +249,15 @@ def release_resources(self, request_id: ObjectId):
if ("assigned_devices" in request) or (
"assigned_sample_positions" in request
):
self.update_request_status(request_id, RequestStatus.NEED_RELEASE)
# wait for the request to be updated to NEED_RELEASE or have been released
while self.get_request(request_id, projection=["status"])["status"] in [
RequestStatus.CANCELED.name,
RequestStatus.ERROR.name,
]:
time.sleep(0.5)
self.update_request_status(request_id, RequestStatus.NEED_RELEASE, original_status=[
RequestStatus.CANCELED, RequestStatus.ERROR
])
else:
# If it doesn't have assigned resources, just leave it as CANCELED or ERROR
return
# For the requests that were fulfilled, definitely have assigned resources, release them
elif request["status"] == RequestStatus.FULFILLED.name:
self._request_collection.update_one(
{
"_id": request_id,
"status": RequestStatus.FULFILLED.name,
},
{
"$set": {
"status": RequestStatus.NEED_RELEASE.name,
}
},
)
# wait for the request to be updated to NEED_RELEASE or have been released
while (
self.get_request(request_id, projection=["status"])["status"]
== RequestStatus.FULFILLED.name
):
time.sleep(0.5)
self.update_request_status(request_id, RequestStatus.NEED_RELEASE, original_status=RequestStatus.FULFILLED)

# wait for the request to be released or canceled or errored during the release
while self.get_request(request_id, projection=["status"])["status"] not in [
Expand Down Expand Up @@ -340,38 +300,10 @@ def release_all_resources(self):
("assigned_devices" in request)
or ("assigned_sample_positions" in request)
):
self.update_request_status(request["_id"], RequestStatus.NEED_RELEASE)
self.update_request_status(request["_id"], RequestStatus.NEED_RELEASE,
original_status=[RequestStatus.CANCELED, RequestStatus.ERROR])
assigned_cancel_error_requests_id.append(request["_id"])

# For the requests that were PENDING, mark them as CANCELED, they don't have assigned resources
self._request_collection.update_many(
{
"task_id": self.task_id,
"status": RequestStatus.PENDING.name,
},
{
"$set": {
"status": RequestStatus.CANCELED.name,
}
},
)
# wait for all the requests to be updated
while any(
request["status"]
in [RequestStatus.FULFILLED.name, RequestStatus.PENDING.name]
for request in self.get_requests_by_task_id(self.task_id)
):
time.sleep(0.5)
if assigned_cancel_error_requests_id:
# wait for the requests to be updated to NEED_RELEASE or to have been released
while any(
request["status"]
in [RequestStatus.CANCELED.name, RequestStatus.ERROR.name]
for request in self.get_requests_by_task_id(self.task_id)
if request["_id"] in assigned_cancel_error_requests_id
):
time.sleep(0.5)

# wait for all the requests to be released or canceled or errored during the release
while any(
(
Expand Down Expand Up @@ -452,16 +384,22 @@ def _handle_canceled_request(self, request_id: ObjectId):
request: dict[str, Any] = self._waiting.pop(request_id)
f: Future = request["f"]

f.set_exception(Abort(f"Request {request_id} was canceled."))
# for the canceled request, we will return an empty result
# and wait for the abort to be handled by the task actor
f.set_result({
"devices": {},
"sample_positions": {},
"request_id": request_id,
})

@staticmethod
def _post_process_requested_resource(
devices: dict[type[BaseDevice], str],
devices: dict[type[BaseDevice] | str, str],
sample_positions: dict[str, list[str]],
resource_request: dict[str, list[dict[str, int | str]]],
resource_request: dict[str | type[BaseDevice] | None, dict[str, int]],
):
processed_sample_positions: dict[
type[BaseDevice] | None, dict[str, list[str]]
type[BaseDevice] | str | None, dict[str, list[str]]
] = {}

for device_request, sample_position_dict in resource_request.items():
Expand All @@ -483,12 +421,6 @@ def _post_process_requested_resource(
processed_sample_positions[device_request][prefix] = sample_positions[
reply_prefix
]
# {
# device_request: {
# prefix: sample_positions[prefix] for prefix in sample_position_dict
# }
# for device_request, sample_position_dict in resource_request.items()
# }
return {
"devices": devices,
"sample_positions": processed_sample_positions,
Expand Down
4 changes: 4 additions & 0 deletions alab_management/utils/data_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,3 +155,7 @@ def make_jsonable(obj):

get_completed_collection = _GetCompletedMongoCollection.get_collection
get_completed_lock = _GetCompletedMongoCollection.get_lock


class DocumentNotUpdatedError(Exception):
    """Exception raised when a document is not updated in the database.

    Callers raise this after an update whose result reports that the expected
    document was not modified, for example because a status precondition in
    the update filter did not match.
    """
7 changes: 5 additions & 2 deletions tests/fake_lab/tasks/infinite_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from alab_management.task_view.task import BaseTask

from ..devices.device_that_never_ends import DeviceThatNeverEnds # noqa: TID252


class InfiniteTask(BaseTask):
def __init__(self, samples: list[ObjectId], *args, **kwargs):
Expand All @@ -14,5 +16,6 @@ def __init__(self, samples: list[ObjectId], *args, **kwargs):
self.sample = samples[0]

def run(self):
while True:
pass
with self.lab_view.request_resources({DeviceThatNeverEnds: {}}) as (devices, _):
while True:
pass
1 change: 0 additions & 1 deletion tests/test_lab_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,6 @@ def test_request_resources(self):
"LOCKED",
self.sample_view.get_sample_position_status("furnace_table")[0].name,
)
time.sleep(0.5)
self.assertEqual("IDLE", self.device_view.get_status("furnace_1").name)
self.assertEqual("IDLE", self.device_view.get_status("dummy").name)

Expand Down
Loading

0 comments on commit 061f298

Please sign in to comment.