Skip to content

Commit

Permalink
Implemented waiting for all MongoDB updates
Browse files Browse the repository at this point in the history
Need transaction for future update
  • Loading branch information
bernardusrendy committed Apr 20, 2024
1 parent 4adb72b commit d32b92f
Show file tree
Hide file tree
Showing 8 changed files with 88 additions and 10 deletions.
27 changes: 23 additions & 4 deletions alab_management/device_view/device_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from datetime import datetime
from enum import Enum, auto, unique
from typing import Any, TypeVar, cast
import time

import pymongo # type: ignore
from bson import ObjectId # type: ignore
Expand Down Expand Up @@ -106,6 +107,9 @@ def sync_device_status(self):
required_status=None,
task_id=None,
)
# Wait until the device status has been updated to the target status
while self.get_status(device_name=device.name).name != status.name:
time.sleep(0.5)

def add_devices_to_db(self):
"""
Expand Down Expand Up @@ -285,7 +289,6 @@ def get_device(self, device_name: str) -> dict[str, Any]:
def get_status(self, device_name: str) -> DeviceTaskStatus:
"""Get device status by device name, if not found, raise ``ValueError``."""
device_entry = self.get_device(device_name=device_name)

return DeviceTaskStatus[device_entry["status"]]

def occupy_device(self, device: BaseDevice | str, task_id: ObjectId):
Expand All @@ -296,6 +299,10 @@ def occupy_device(self, device: BaseDevice | str, task_id: ObjectId):
target_status=DeviceTaskStatus.OCCUPIED,
task_id=task_id,
)
device_name = device.name if isinstance(device, BaseDevice) else device
# Wait until the device status has been updated to OCCUPIED
while self.get_status(device_name=device_name).name != "OCCUPIED":
time.sleep(0.5)

def get_devices_by_task(self, task_id: ObjectId | None) -> list[BaseDevice]:
"""Get devices given a task id (regardless of its status!)."""
Expand Down Expand Up @@ -332,6 +339,9 @@ def release_device(self, device_name: str):
{"name": device_name},
{"$set": update_dict},
)
# wait until the device status has been updated to IDLE
while self.get_status(device_name=device_name).name != "IDLE":
time.sleep(0.5)

def get_samples_on_device(self, device_name: str):
"""Get all samples on a device."""
Expand Down Expand Up @@ -396,6 +406,9 @@ def _update_status(
}
},
)
# wait until the device status has been updated to target_status
while self.get_status(device_name=device_name).name != target_status.name:
time.sleep(0.5)

def query_property(self, device_name: str, prop: str):
"""
Expand Down Expand Up @@ -437,15 +450,15 @@ def set_message(self, device_name: str, message: str):
{"name": device_name}, {"$set": {"message": message}}
)

def get_message(self, device_name: str):
def get_message(self, device_name: str) -> str:
"""Gets the current device message. Message is used to communicate device state with the user dashboard.
Args:
device_name (str): name of the device to set the message for
"""
return self.get_device(device_name=device_name)["message"]

def get_all_attributes(self, device_name: str):
def get_all_attributes(self, device_name: str) -> dict[str, Any]:
"""Returns the device attributes.
Args:
Expand All @@ -458,7 +471,7 @@ def get_all_attributes(self, device_name: str):
device = self.get_device(device_name=device_name)
return device["attributes"]

def get_attribute(self, device_name: str, attribute: str):
def get_attribute(self, device_name: str, attribute: str) -> Any:
"""Gets a device attribute. Attributes are used to store device-specific values in the database.
Args:
Expand Down Expand Up @@ -534,6 +547,9 @@ def pause_device(self, device_name: str):
}
},
)
# wait until the device pause status has been updated
while self.get_device(device_name=device_name)["pause_status"].name != new_pause_status:
time.sleep(0.5)

def unpause_device(self, device_name: str):
"""Unpause a device."""
Expand All @@ -556,6 +572,9 @@ def unpause_device(self, device_name: str):
{"name": device_name},
{"$set": update_dict},
)
# wait until the device pause status has been updated
while self.get_device(device_name=device_name)["pause_status"].name != "RELEASED":
time.sleep(0.5)

def __exit__(self, exc_type, exc_value, traceback):
"""Disconnect from all devices when exiting the context manager."""
Expand Down
13 changes: 13 additions & 0 deletions alab_management/experiment_view/experiment_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from datetime import datetime
from enum import Enum, auto
from typing import Any, cast
import time

from bson import ObjectId # type: ignore

Expand Down Expand Up @@ -125,6 +126,9 @@ def update_experiment_status(self, exp_id: ObjectId, status: ExperimentStatus):
{"_id": exp_id},
{"$set": update_dict},
)
# Wait until experiment status is updated in the database
while self.get_experiment(exp_id=exp_id)["status"] != status.name:
time.sleep(0.5)

def update_sample_task_id(
self, exp_id, sample_ids: list[ObjectId], task_ids: list[ObjectId]
Expand Down Expand Up @@ -161,6 +165,15 @@ def update_sample_task_id(
}
},
)
# Wait until the sample and task id's are updated
update = "not completed"
while update != "completed":
experiment = self.get_experiment(exp_id=exp_id)
if all(
sample["sample_id"] is not None for sample in experiment["samples"]
) and all(task["task_id"] is not None for task in experiment["tasks"]):
update = "completed"
time.sleep(0.5)

def get_experiment_by_task_id(self, task_id: ObjectId) -> dict[str, Any] | None:
"""Get an experiment that contains a task with the given task_id."""
Expand Down
21 changes: 21 additions & 0 deletions alab_management/sample_view/sample_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from datetime import datetime
from enum import Enum, auto
from typing import Any, cast
import time

import pymongo # type: ignore
from bson import ObjectId # type: ignore
Expand Down Expand Up @@ -244,6 +245,13 @@ def is_unoccupied_position(self, position: str) -> bool:
self.get_sample_position_status(position)[0]
is not SamplePositionStatus.OCCUPIED
)

def is_locked_position(self, position: str) -> bool:
"""Tell if a sample position is locked or not."""
sample_position = self.get_sample_position(position=position)
if sample_position is None:
raise ValueError(f"Invalid sample position: {position}")
return sample_position["task_id"] is not None

def get_available_sample_position(
self, task_id: ObjectId, position_prefix: str
Expand Down Expand Up @@ -310,6 +318,10 @@ def lock_sample_position(self, task_id: ObjectId, position: str):
}
},
)
# Wait until the position is locked successfully
while not self.is_locked_position(position):
time.sleep(0.5)


def release_sample_position(self, position: str):
"""Unlock a sample position."""
Expand All @@ -324,6 +336,9 @@ def release_sample_position(self, position: str):
}
},
)
# Wait until the position is released successfully
while self.is_locked_position(position):
time.sleep(0.5)

def get_sample_positions_by_task(self, task_id: ObjectId | None) -> list[str]:
"""Get the list of sample positions that is locked by a task (given task id)."""
Expand Down Expand Up @@ -423,6 +438,9 @@ def update_sample_task_id(self, sample_id: ObjectId, task_id: ObjectId | None):
}
},
)
# Wait until the task id is updated
while self.get_sample(sample_id).task_id != task_id:
time.sleep(0.5)

def update_sample_metadata(self, sample_id: ObjectId, metadata: dict[str, Any]):
"""Update the metadata for a sample. This adds new metadata or updates existing metadata."""
Expand Down Expand Up @@ -461,6 +479,9 @@ def move_sample(self, sample_id: ObjectId, position: str | None):
}
},
)
# Wait until the position is updated
while self.get_sample(sample_id).position != position:
time.sleep(0.5)

def exists(self, sample_id: ObjectId | str) -> bool:
"""Check if a sample exists in the database.
Expand Down
2 changes: 1 addition & 1 deletion alab_management/scripts/launch_lab.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def launch_lab(host, port, debug):
task_launcher_thread.start()

while True:
time.sleep(2)
time.sleep(1)
if not experiment_manager_thread.is_alive():
sys.exit(1001)

Expand Down
17 changes: 12 additions & 5 deletions alab_management/task_manager/resource_requester.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,15 @@ def __init__(self):

def update_request_status(self, request_id: ObjectId, status: RequestStatus):
"""Update the status of a request by request_id."""
return self._request_collection.update_one(
value_returned = self._request_collection.update_one(
{"_id": request_id}, {"$set": {"status": status.name}}
)
# wait for the request to be updated
while self.get_request(request_id, projection=["status"])["status"] != status.name:
time.sleep(0.5)
return value_returned

def get_request(self, request_id: ObjectId, **kwargs):
def get_request(self, request_id: ObjectId, **kwargs) -> dict[str, Any] | None:
"""Get a request by request_id."""
return self._request_collection.find_one(
{"_id": request_id}, **kwargs
Expand Down Expand Up @@ -216,6 +220,11 @@ def request_resources(
result = f.result(timeout=timeout)
except TimeoutError: # cancel the task if timeout
self.update_request_status(request_id=_id, status=RequestStatus.CANCELED)
# wait for the request status to be updated
while (self.get_request(_id, projection=["status"]))[
"status"
] != "CANCELED":
time.sleep(0.5)
raise
return {
**self._post_process_requested_resource(
Expand All @@ -241,9 +250,7 @@ def release_resources(self, request_id: ObjectId) -> bool:
)

# wait for the request to be released
while (self.get_request(request_id, projection=["status"]))[
"status"
] == RequestStatus.NEED_RELEASE.name:
while self.get_request(request_id, projection=["status"])["status"] != "NEED_RELEASE":
time.sleep(0.5)

return result.modified_count == 1
Expand Down
11 changes: 11 additions & 0 deletions alab_management/task_manager/task_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,10 @@ def _handle_requested_resources(self, request_entry: dict[str, Any]):
}
},
)
# wait until the parsed_sample_positions_request is updated in the database
while self.get_request(request_entry["_id"], projection=["parsed_sample_positions_request"])["parsed_sample_positions_request"] is None:
time.sleep(0.5)

sample_positions = self.sample_view.request_sample_positions(
task_id=task_id, sample_positions=parsed_sample_positions_request
)
Expand All @@ -375,6 +379,9 @@ def _handle_requested_resources(self, request_entry: dict[str, Any]):
}
},
)
while self.get_request(request_entry["_id"], projection=["status"])[
"status"].name != "ERROR":
time.sleep(0.5)
return

# if both devices and sample positions can be satisfied
Expand All @@ -389,6 +396,10 @@ def _handle_requested_resources(self, request_entry: dict[str, Any]):
}
},
)
# Wait until the status of the request is updated in the database
while self.get_request(request_entry["_id"], projection=["status"])[
"status"] != "FULFILLED":
time.sleep(0.5)
# label the resources as occupied
self._occupy_devices(devices=devices, task_id=task_id)
self._occupy_sample_positions(
Expand Down
4 changes: 4 additions & 0 deletions alab_management/task_view/task_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""

from datetime import datetime
import time
from typing import Any, cast

from bson import ObjectId
Expand Down Expand Up @@ -178,6 +179,9 @@ def update_status(self, task_id: ObjectId, status: TaskStatus):
{"_id": task_id},
{"$set": update_dict},
)
# Wait until the status is updated
while self.get_status(task_id=task_id).name != status.name:
time.sleep(0.5)

if status is TaskStatus.COMPLETED:
# try to figure out tasks that is READY
Expand Down
3 changes: 3 additions & 0 deletions alab_management/user_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@ def update_request_status(self, request_id: ObjectId, response: str, note: str):
}
},
)
# Wait until the status is updated
while self.get_request(request_id)["status"] != UserRequestStatus.FULLFILLED.value:
time.sleep(1)

def retrieve_user_input(self, request_id: ObjectId) -> str:
"""
Expand Down

0 comments on commit d32b92f

Please sign in to comment.