Skip to content

Commit

Permalink
Extra checks
Browse files Browse the repository at this point in the history
  • Loading branch information
bernardusrendy committed Apr 20, 2024
1 parent e4e5930 commit 6a8b885
Show file tree
Hide file tree
Showing 8 changed files with 83 additions and 19 deletions.
16 changes: 12 additions & 4 deletions alab_management/device_view/device_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,10 +445,14 @@ def set_message(self, device_name: str, message: str):
message (str): message to be set
"""
self.get_device(device_name=device_name)

previous_update_time = self.get_device(device_name=device_name)["last_updated"]
self._device_collection.update_one(
{"name": device_name}, {"$set": {"message": message}}
{"name": device_name}, {"$set": {
"message": message,
"last_updated": datetime.now()}}
)
while self.get_device(device_name=device_name)["last_updated"] == previous_update_time:
time.sleep(0.5)

def get_message(self, device_name: str) -> str:
"""Gets the current device message. Message is used to communicate device state with the user dashboard.
Expand Down Expand Up @@ -497,7 +501,7 @@ def set_all_attributes(self, device_name: str, attributes: dict):
attributes (dict): attributes to be set
"""
self.get_device(device_name=device_name)

previous_update_time = self.get_device(device_name=device_name)["last_updated"]
self._device_collection.update_one(
{"name": device_name},
{
Expand All @@ -507,6 +511,8 @@ def set_all_attributes(self, device_name: str, attributes: dict):
}
},
)
while self.get_device(device_name=device_name)["last_updated"] == previous_update_time:
time.sleep(0.5)

def set_attribute(self, device_name: str, attribute: str, value: Any):
"""Sets a device attribute. Attributes are used to store device-specific values in the database.
Expand All @@ -518,7 +524,7 @@ def set_attribute(self, device_name: str, attribute: str, value: Any):
"""
attributes = self.get_all_attributes(device_name=device_name)
attributes[attribute] = value

previous_update_time = self.get_device(device_name=device_name)["last_updated"]
self._device_collection.update_one(
{"name": device_name},
{
Expand All @@ -528,6 +534,8 @@ def set_attribute(self, device_name: str, attribute: str, value: Any):
}
},
)
while self.get_device(device_name=device_name)["last_updated"] == previous_update_time:
time.sleep(0.5)

def pause_device(self, device_name: str):
"""Request pause for a specific device."""
Expand Down
6 changes: 6 additions & 0 deletions alab_management/experiment_view/completed_experiment_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Any

from bson import ObjectId # type: ignore
import time

from alab_management.sample_view import CompletedSampleView
from alab_management.task_view import CompletedTaskView
Expand Down Expand Up @@ -46,6 +47,11 @@ def save_experiment(self, experiment_id: ObjectId):
filter={"_id": ObjectId(experiment_id)},
update={"$set": experiment_dict},
)
# wait for the update to complete
while self._completed_experiment_collection.find_one(
{"_id": ObjectId(experiment_id)}
) is None:
time.sleep(0.5)
else:
self._completed_experiment_collection.insert_one(experiment_dict)

Expand Down
6 changes: 3 additions & 3 deletions alab_management/experiment_view/experiment_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,9 +169,9 @@ def update_sample_task_id(
update = "not completed"
while update != "completed":
experiment = self.get_experiment(exp_id=exp_id)
if all(
sample["sample_id"] is not None for sample in experiment["samples"]
) and all(task["task_id"] is not None for task in experiment["tasks"]):
updated_sample_ids = [sample["sample_id"] for sample in experiment["samples"]]
updated_task_ids = [task["task_id"] for task in experiment["tasks"]]
if updated_sample_ids == sample_ids and updated_task_ids == task_ids:
update = "completed"
time.sleep(0.5)

Expand Down
6 changes: 6 additions & 0 deletions alab_management/sample_view/completed_sample_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""

from bson import ObjectId # type: ignore
import time

from alab_management.utils.data_objects import get_collection, get_completed_collection

Expand Down Expand Up @@ -39,6 +40,11 @@ def save_sample(self, sample_id: ObjectId):
)
else:
self._completed_sample_collection.insert_one(sample_dict)
# wait for the insert to complete
while self._completed_sample_collection.find_one(
{"_id": ObjectId(sample_id)}
) is None:
time.sleep(0.5)

def exists(self, sample_id: ObjectId | str) -> bool:
"""Check if a sample exists in the database.
Expand Down
13 changes: 10 additions & 3 deletions alab_management/sample_view/sample_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,9 @@ def add_sample_positions_to_db(
if parent_device_name:
new_entry["parent_device"] = parent_device_name
self._sample_positions_collection.insert_one(new_entry)
# Wait until the sample position is created
while self.get_sample_position(name) is None:
time.sleep(0.5)

def clean_up_sample_position_collection(self):
"""Drop the sample position collection."""
Expand Down Expand Up @@ -373,7 +376,6 @@ def create_sample(
f"Unsupported sample name: {name}. "
f"Sample name should not contain '.' or '$'"
)

entry = {
"name": name,
"tags": tags or [],
Expand All @@ -392,7 +394,9 @@ def create_sample(
entry["_id"] = sample_id

result = self._sample_collection.insert_one(entry)

# Wait until the sample is created
while not self.exists(result.inserted_id):
time.sleep(0.5)
return cast(ObjectId, result.inserted_id)

def get_sample(self, sample_id: ObjectId) -> Sample:
Expand Down Expand Up @@ -449,11 +453,14 @@ def update_sample_metadata(self, sample_id: ObjectId, metadata: dict[str, Any]):

update_dict = {f"metadata.{k}": v for k, v in metadata.items()}
update_dict["last_updated"] = datetime.now()

previous_update_time = result["last_updated"]
self._sample_collection.update_one(
{"_id": sample_id},
{"$set": update_dict},
)
# Wait until the metadata is updated
while self.get_sample(sample_id).last_updated == previous_update_time:
time.sleep(0.5)

def move_sample(self, sample_id: ObjectId, position: str | None):
"""Update the sample with new position."""
Expand Down
2 changes: 1 addition & 1 deletion alab_management/scripts/launch_lab.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def launch_lab(host, port, debug):
task_launcher_thread.start()

while True:
time.sleep(1)
time.sleep(0.001)
if not experiment_manager_thread.is_alive():
sys.exit(1001)

Expand Down
7 changes: 6 additions & 1 deletion alab_management/task_manager/task_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,11 +405,16 @@ def _handle_requested_resources(self, request_entry: dict[str, Any]):
}
},
)
# Wait until the status of the request is updated in the database
# Wait until the status of the request is updated in the database,
# TODO: This process seems to be slow somehow
while (
self.get_request(request_entry["_id"], projection=["status"])["status"]
!= "FULFILLED"
):
# handle if the request is cancelled or errored
if (self.get_request(request_entry["_id"], projection=["status"])["status"] == "CANCELED"
or self.get_request(request_entry["_id"], projection=["status"])["status"] == "ERROR"):
return
time.sleep(0.5)
# label the resources as occupied
self._occupy_devices(devices=devices, task_id=task_id)
Expand Down
46 changes: 39 additions & 7 deletions alab_management/task_view/task_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ def create_task(
if isinstance(task_id, ObjectId):
entry["_id"] = task_id
result = self._task_collection.insert_one(entry)
# Wait until the task is inserted
while not self.exists(task_id=cast(ObjectId, result.inserted_id)):
time.sleep(0.5)

return cast(ObjectId, result.inserted_id)

Expand All @@ -100,7 +103,7 @@ def create_subtask(
"last_updated": datetime.now(),
}
)

previous_update_time = task["last_updated"]
self._task_collection.update_one(
{"_id": task_id},
{
Expand All @@ -110,6 +113,9 @@ def create_subtask(
}
},
)
# Wait until the subtask is inserted
while self.get_task(task_id=task_id)["last_updated"] == previous_update_time:
time.sleep(0.5)
return subtask_id

def get_task(self, task_id: ObjectId, encode: bool = False) -> dict[str, Any]:
Expand All @@ -121,7 +127,7 @@ def get_task(self, task_id: ObjectId, encode: bool = False) -> dict[str, Any]:
encode: whether to encode the task using ``self.encode_task`` method
"""
task_id = ObjectId(task_id)

result = self._task_collection.find_one({"_id": task_id})

if result is None:
Expand Down Expand Up @@ -217,6 +223,7 @@ def update_status(self, task_id: ObjectId, status: TaskStatus):
task_id=next_task_id, status=TaskStatus.CANCELLED
)
else:
previous_update_time = next_task["last_updated"]
self._task_collection.update_one(
{"_id": next_task_id},
{
Expand All @@ -229,6 +236,9 @@ def update_status(self, task_id: ObjectId, status: TaskStatus):
},
},
)
# Wait until the status is updated
while self.get_task(task_id=next_task_id)["last_updated"] == previous_update_time:
time.sleep(0.5)
self.try_to_mark_task_ready(
task_id=next_task_id
) # in case it was only waiting on task we just cancelled
Expand All @@ -254,11 +264,16 @@ def update_subtask_status(
raise ValueError(
f"No subtask found with id: {subtask_id} within task: {task_id}"
)

previous_update_time = self.get_task(task_id=task_id)["last_updated"]
self._task_collection.update_one(
{"_id": task_id},
{"$set": {"subtasks": subtasks}},
{"$set": {
"subtasks": subtasks,
"last_updated": datetime.now()}},
)
# Wait until the status is updated
while self.get_task(task_id=task_id)["last_updated"] == previous_update_time:
time.sleep(0.5)

def update_result(
self, task_id: ObjectId, name: str | None = None, value: Any = None
Expand All @@ -278,7 +293,7 @@ def update_result(
# raise ValueError("Must provide a value to update result with!")

update_path = "result" if name is None else f"result.{name}"

previous_update_time = self.get_task(task_id=task_id)["last_updated"]
self._task_collection.update_one(
{"_id": task_id},
{
Expand All @@ -288,6 +303,9 @@ def update_result(
}
},
)
# Wait until the status is updated
while self.get_task(task_id=task_id)["last_updated"] == previous_update_time:
time.sleep(0.5)

def update_subtask_result(
self, task_id: ObjectId, subtask_id: ObjectId, result: Any
Expand All @@ -314,7 +332,7 @@ def update_subtask_result(
raise ValueError(
f"No subtask found with id: {subtask_id} within task: {task_id}"
)

previous_update_time = self.get_task(task_id=task_id)["last_updated"]
self._task_collection.update_one(
{"_id": task_id},
{
Expand All @@ -324,6 +342,9 @@ def update_subtask_result(
}
},
)
# Wait until the status is updated
while self.get_task(task_id=task_id)["last_updated"] == previous_update_time:
time.sleep(0.5)

def try_to_mark_task_ready(self, task_id: ObjectId):
"""
Expand Down Expand Up @@ -409,7 +430,7 @@ def update_task_dependency(
for next_task in next_tasks:
if self.get_task(task_id=next_task) is None:
raise ValueError(f"Non-exist task id: {next_task}")

previous_update_time = self.get_task(task_id=task_id, encode=False)["last_updated"]
self._task_collection.update_one(
{"_id": task_id},
{
Expand All @@ -422,9 +443,13 @@ def update_task_dependency(
},
},
)
# Wait until the status is updated
while self.get_task(task_id=task_id, encode=False)["last_updated"] == previous_update_time:
time.sleep(0.5)

def set_message(self, task_id: ObjectId, message: str):
"""Set message for one task. This is displayed on the dashboard."""
previous_update_time = self.get_task(task_id=task_id)["last_updated"]
self._task_collection.update_one(
{"_id": task_id},
{
Expand All @@ -434,6 +459,9 @@ def set_message(self, task_id: ObjectId, message: str):
}
},
)
# Wait until the status is updated
while self.get_task(task_id=task_id)["last_updated"] == previous_update_time:
time.sleep(0.5)

def set_task_actor_id(self, task_id: ObjectId, message_id: str):
"""
Expand All @@ -443,6 +471,7 @@ def set_task_actor_id(self, task_id: ObjectId, message_id: str):
task_id: the task id of the task
message_id: a uid generated by dramatiq (message_id)
"""
previous_update_time = self.get_task(task_id=task_id)["last_updated"]
self._task_collection.update_one(
{"_id": task_id},
{
Expand All @@ -452,6 +481,9 @@ def set_task_actor_id(self, task_id: ObjectId, message_id: str):
}
},
)
# Wait until the status is updated
while self.get_task(task_id=task_id)["last_updated"] == previous_update_time:
time.sleep(0.5)

def mark_task_as_cancelling(self, task_id: ObjectId):
"""
Expand Down

0 comments on commit 6a8b885

Please sign in to comment.