Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

User input request with note #74

Merged
merged 4 commits into from
Jun 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 12 additions & 9 deletions alab_management/experiment_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
tasks and samples and mark the finished tasks in the database when it is
done.
"""

import time
from typing import Any

Expand All @@ -30,7 +31,7 @@ def __init__(self):

config = AlabOSConfig()
self.__copy_to_completed_db = (
"mongodb_completed" in config
"mongodb_completed" in config
) # if this is not defined in the config, assume it this feature is not being used.
if self.__copy_to_completed_db:
self.completed_experiment_view = CompletedExperimentView()
Expand Down Expand Up @@ -91,7 +92,9 @@ def _handle_pending_experiment(self, experiment: dict[str, Any]):
},
)
if task_graph.has_cycle():
self.experiment_view.update_experiment_status(experiment["_id"], ExperimentStatus.ERROR)
self.experiment_view.update_experiment_status(
experiment["_id"], ExperimentStatus.ERROR
)
print(f"Experiment ({experiment['_id']}) has a cycle in the graph.")
return

Expand Down Expand Up @@ -156,13 +159,13 @@ def mark_completed_experiments(self):

# if all the tasks of an experiment have been finished
if all(
self.task_view.get_status(task_id=task_id)
in {
TaskStatus.COMPLETED,
TaskStatus.ERROR,
TaskStatus.CANCELLED,
}
for task_id in task_ids
self.task_view.get_status(task_id=task_id)
in {
TaskStatus.COMPLETED,
TaskStatus.ERROR,
TaskStatus.CANCELLED,
}
for task_id in task_ids
):
self.experiment_view.update_experiment_status(
exp_id=experiment["_id"], status=ExperimentStatus.COMPLETED
Expand Down
16 changes: 12 additions & 4 deletions alab_management/lab_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from alab_management.task_view.task import BaseTask
from alab_management.task_view.task_enums import TaskPriority, TaskStatus
from alab_management.task_view.task_view import TaskView
from alab_management.user_input import request_user_input
from alab_management.user_input import request_user_input, request_user_input_with_note


class DeviceRunningException(Exception):
Expand Down Expand Up @@ -126,9 +126,7 @@ def request_resources(
device_type: self._device_client.create_device_wrapper(device_name)
for device_type, device_name in devices.items()
} # type: ignore
self._task_view.update_status(
task_id=self.task_id, status=TaskStatus.RUNNING
)
self._task_view.update_status(task_id=self.task_id, status=TaskStatus.RUNNING)
yield devices, sample_positions

self._resource_requester.release_resources(request_id=request_id)
Expand Down Expand Up @@ -335,6 +333,16 @@ def request_user_input(self, prompt: str, options: list[str]) -> str:
"""
return request_user_input(task_id=self.task_id, prompt=prompt, options=options)

def request_user_input_with_note(
self, prompt: str, options: list[str]
) -> tuple[str, str]:
"""Request user input from the user. This function will block until the user inputs something. Returns the
value returned by the user and the note.
"""
return request_user_input_with_note(
task_id=self.task_id, prompt=prompt, options=options
)

@property
def priority(self) -> int:
"""Get the priority of the task."""
Expand Down
17 changes: 11 additions & 6 deletions alab_management/resource_manager/resource_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,9 @@ def handle_released_resources(self):
self._release_devices(devices)
self._release_sample_positions(sample_positions)
self.update_request_status(
request_id=request["_id"], status=RequestStatus.RELEASED, original_status=RequestStatus.NEED_RELEASE
request_id=request["_id"],
status=RequestStatus.RELEASED,
original_status=RequestStatus.NEED_RELEASE,
)

def handle_requested_resources(self):
Expand All @@ -84,9 +86,12 @@ def _handle_requested_resources(self, request_entry: dict[str, Any]):
task_id = request_entry["task_id"]

task_status = self.task_view.get_status(task_id=task_id)
if (task_status != TaskStatus.REQUESTING_RESOURCES or
task_id in {task["task_id"] for task in self.task_view.get_tasks_to_be_canceled(
canceling_progress=CancelingProgress.WORKER_NOTIFIED)}):
if task_status != TaskStatus.REQUESTING_RESOURCES or task_id in {
task["task_id"]
for task in self.task_view.get_tasks_to_be_canceled(
canceling_progress=CancelingProgress.WORKER_NOTIFIED
)
}:
# this implies the Task has been cancelled or errored somewhere else in the chain -- we should
# not allocate any resources to the broken Task.
self.update_request_status(
Expand Down Expand Up @@ -197,7 +202,7 @@ def _occupy_devices(self, devices: dict[str, dict[str, Any]], task_id: ObjectId)
)

def _occupy_sample_positions(
self, sample_positions: dict[str, list[dict[str, Any]]], task_id: ObjectId
self, sample_positions: dict[str, list[dict[str, Any]]], task_id: ObjectId
):
for sample_positions_ in sample_positions.values():
for sample_position_ in sample_positions_:
Expand All @@ -211,7 +216,7 @@ def _release_devices(self, devices: dict[str, dict[str, Any]]):
self.device_view.release_device(device["name"])

def _release_sample_positions(
self, sample_positions: dict[str, list[dict[str, Any]]]
self, sample_positions: dict[str, list[dict[str, Any]]]
):
for sample_positions_ in sample_positions.values():
for sample_position in sample_positions_:
Expand Down
91 changes: 52 additions & 39 deletions alab_management/resource_manager/resource_requester.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
TaskLauncher is the core module of the system,
which actually executes the tasks.
"""

import concurrent
import time
from concurrent.futures import Future
Expand Down Expand Up @@ -102,13 +103,20 @@ class RequestMixin:
def __init__(self):
self._request_collection = get_collection("requests")

def update_request_status(self, request_id: ObjectId, status: RequestStatus,
original_status: RequestStatus | list[RequestStatus] = None):
def update_request_status(
self,
request_id: ObjectId,
status: RequestStatus,
original_status: RequestStatus | list[RequestStatus] = None,
):
"""Update the status of a request by request_id."""
if original_status is not None:
if isinstance(original_status, list):
value_returned = self._request_collection.update_one(
{"_id": request_id, "status": {"$in": [status.name for status in original_status]}},
{
"_id": request_id,
"status": {"$in": [status.name for status in original_status]},
},
{"$set": {"status": status.name}},
)
else:
Expand Down Expand Up @@ -153,8 +161,8 @@ class ResourceRequester(RequestMixin):
"""

def __init__(
self,
task_id: ObjectId,
self,
task_id: ObjectId,
):
self._request_collection = get_collection("requests")
self._waiting: dict[ObjectId, dict[str, Any]] = {}
Expand All @@ -180,10 +188,10 @@ def __close__(self):
__del__ = __close__

def request_resources(
self,
resource_request: _ResourceRequestDict,
timeout: float | None = None,
priority: TaskPriority | int | None = None,
self,
resource_request: _ResourceRequestDict,
timeout: float | None = None,
priority: TaskPriority | int | None = None,
) -> dict[str, Any]:
"""
Request lab resources.
Expand Down Expand Up @@ -247,14 +255,10 @@ def request_resources(
result = self.get_concurrent_result(f, timeout=timeout)
except concurrent.futures.TimeoutError as e:
# if the request is not fulfilled, cancel it to make sure the resources are released
request = self._request_collection.find_one_and_update({
"_id": _id,
"status": {"$ne": RequestStatus.FULFILLED.name}
}, {
"$set": {
"status": RequestStatus.CANCELED.name
}
})
request = self._request_collection.find_one_and_update(
{"_id": _id, "status": {"$ne": RequestStatus.FULFILLED.name}},
{"$set": {"status": RequestStatus.CANCELED.name}},
)
if request is not None:
raise CombinedTimeoutError(
f"Request {result.inserted_id} timed out after {timeout} seconds."
Expand Down Expand Up @@ -292,17 +296,23 @@ def release_resources(self, request_id: ObjectId):
request = self.get_request(request_id)
if request["status"] in [RequestStatus.CANCELED.name, RequestStatus.ERROR.name]:
if ("assigned_devices" in request) or (
"assigned_sample_positions" in request
"assigned_sample_positions" in request
):
self.update_request_status(request_id, RequestStatus.NEED_RELEASE, original_status=[
RequestStatus.CANCELED, RequestStatus.ERROR
])
self.update_request_status(
request_id,
RequestStatus.NEED_RELEASE,
original_status=[RequestStatus.CANCELED, RequestStatus.ERROR],
)
else:
# If it doesn't have assigned resources, just leave it as CANCELED or ERROR
return
# For the requests that were fulfilled, definitely have assigned resources, release them
elif request["status"] == RequestStatus.FULFILLED.name:
self.update_request_status(request_id, RequestStatus.NEED_RELEASE, original_status=RequestStatus.FULFILLED)
self.update_request_status(
request_id,
RequestStatus.NEED_RELEASE,
original_status=RequestStatus.FULFILLED,
)

# wait for the request to be released or canceled or errored during the release
while self.get_request(request_id, projection=["status"])["status"] not in [
Expand Down Expand Up @@ -342,24 +352,27 @@ def release_all_resources(self):
RequestStatus.CANCELED.name,
RequestStatus.ERROR.name,
] and (
("assigned_devices" in request)
or ("assigned_sample_positions" in request)
("assigned_devices" in request)
or ("assigned_sample_positions" in request)
):
self.update_request_status(request["_id"], RequestStatus.NEED_RELEASE,
original_status=[RequestStatus.CANCELED, RequestStatus.ERROR])
self.update_request_status(
request["_id"],
RequestStatus.NEED_RELEASE,
original_status=[RequestStatus.CANCELED, RequestStatus.ERROR],
)
assigned_cancel_error_requests_id.append(request["_id"])

# wait for all the requests to be released or canceled or errored during the release
while any(
(
request["status"]
not in [
RequestStatus.RELEASED.name,
RequestStatus.CANCELED.name,
RequestStatus.ERROR.name,
]
)
for request in self.get_requests_by_task_id(self.task_id)
(
request["status"]
not in [
RequestStatus.RELEASED.name,
RequestStatus.CANCELED.name,
RequestStatus.ERROR.name,
]
)
for request in self.get_requests_by_task_id(self.task_id)
):
time.sleep(0.5)

Expand Down Expand Up @@ -435,9 +448,9 @@ def _handle_canceled_request(self, request_id: ObjectId):

@staticmethod
def _post_process_requested_resource(
devices: dict[type[BaseDevice] | str, str],
sample_positions: dict[str, list[str]],
resource_request: dict[str | type[BaseDevice] | None, dict[str, int]],
devices: dict[type[BaseDevice] | str, str],
sample_positions: dict[str, list[str]],
resource_request: dict[str | type[BaseDevice] | None, dict[str, int]],
):
processed_sample_positions: dict[
type[BaseDevice] | str | None, dict[str, list[str]]
Expand All @@ -456,7 +469,7 @@ def _post_process_requested_resource(
f"{devices[device_request]}{SamplePosition.SEPARATOR}"
)
if not reply_prefix.startswith(
device_prefix
device_prefix
): # dont extra prepend for nested requests
reply_prefix = device_prefix + reply_prefix
processed_sample_positions[device_request][prefix] = sample_positions[
Expand Down
1 change: 1 addition & 0 deletions alab_management/scripts/launch_worker.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Launch Dramatiq worker to submit tasks."""

from alab_management.task_manager.task_manager import TaskManager


Expand Down
4 changes: 3 additions & 1 deletion alab_management/task_manager/task_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,5 +152,7 @@ def handle_tasks_to_be_canceled(self):
self.task_view.update_canceling_progress(
task_id=task_entry["task_id"],
canceling_progress=CancelingProgress.WORKER_NOTIFIED,
original_progress=CancelingProgress[task_entry["canceling_progress"]],
original_progress=CancelingProgress[
task_entry["canceling_progress"]
],
)
51 changes: 51 additions & 0 deletions alab_management/user_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,29 @@ def get_all_pending_requests(self) -> list:
self._input_collection.find({"status": UserRequestStatus.PENDING.value}),
)

def retrieve_user_input_with_note(self, request_id: ObjectId) -> tuple[str, str]:
"""
Retrive response from user for a given request. Blocks until request is marked as completed.

Returns the user response, which is one of a list of options
"""
status = UserRequestStatus.PENDING
try:
while status == UserRequestStatus.PENDING:
request = self._input_collection.find_one({"_id": request_id})
if request is None:
raise ValueError(
f"User input request id {request_id} does not exist!"
)
status = UserRequestStatus(request["status"])
time.sleep(0.5)
except: # noqa: E722
self._input_collection.update_one(
{"_id": request_id}, {"$set": {"status": UserRequestStatus.ERROR.name}}
)
raise
return request["response"], request["note"]


def request_user_input(
task_id: ObjectId | None,
Expand Down Expand Up @@ -185,3 +208,31 @@ def request_maintenance_input(prompt: str, options: list[str]):
maintenance=True,
category="Maintenance",
)


def request_user_input_with_note(
task_id: ObjectId | None,
prompt: str,
options: list[str],
maintenance: bool = False,
category: str = "Unknown Category",
) -> tuple[str, str]:
"""
Request user input through the dashboard. Blocks until response is given.

task_id (ObjectId): task id requesting user input
prompt (str): prompt to give user
options (List[str]): response options to give user
maintenance (bool): if true, mark this as a request for overall system maintenance

Returns user response as string.
"""
user_input_view = UserInputView()
request_id = user_input_view.insert_request(
task_id=task_id,
prompt=prompt,
options=options,
maintenance=maintenance,
category=category,
)
return user_input_view.retrieve_user_input_with_note(request_id=request_id)
6 changes: 4 additions & 2 deletions alab_management/utils/data_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@ def init(cls):
)
sim_mode_flag = AlabOSConfig().is_sim_mode()
# force to enable sim mode, just in case
cls.db = cls.client[AlabOSConfig()["general"]["name"] + ("_sim" * sim_mode_flag)]
cls.db = cls.client[
AlabOSConfig()["general"]["name"] + ("_sim" * sim_mode_flag)
]


class _GetCompletedMongoCollection(_BaseGetMongoCollection):
Expand All @@ -76,7 +78,7 @@ def init(cls):
if sim_mode_flag:
cls.db = cls.client[
AlabOSConfig()["general"]["name"] + "(completed)" + "_sim"
]
]
else:
cls.db = cls.client[AlabOSConfig()["general"]["name"] + "(completed)"]
# type: ignore # pylint: disable=unsubscriptable-object
Expand Down
Loading
Loading