Skip to content

Commit

Permalink
lint,black,ruff
Browse files Browse the repository at this point in the history
  • Loading branch information
bernardusrendy committed Apr 20, 2024
1 parent 6a8b885 commit d7cff64
Show file tree
Hide file tree
Showing 6 changed files with 51 additions and 24 deletions.
20 changes: 14 additions & 6 deletions alab_management/device_view/device_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,11 +447,13 @@ def set_message(self, device_name: str, message: str):
self.get_device(device_name=device_name)
previous_update_time = self.get_device(device_name=device_name)["last_updated"]
self._device_collection.update_one(
{"name": device_name}, {"$set": {
"message": message,
"last_updated": datetime.now()}}
{"name": device_name},
{"$set": {"message": message, "last_updated": datetime.now()}},
)
while self.get_device(device_name=device_name)["last_updated"] == previous_update_time:
while (
self.get_device(device_name=device_name)["last_updated"]
== previous_update_time
):
time.sleep(0.5)

def get_message(self, device_name: str) -> str:
Expand Down Expand Up @@ -511,7 +513,10 @@ def set_all_attributes(self, device_name: str, attributes: dict):
}
},
)
while self.get_device(device_name=device_name)["last_updated"] == previous_update_time:
while (
self.get_device(device_name=device_name)["last_updated"]
== previous_update_time
):
time.sleep(0.5)

def set_attribute(self, device_name: str, attribute: str, value: Any):
Expand All @@ -534,7 +539,10 @@ def set_attribute(self, device_name: str, attribute: str, value: Any):
}
},
)
while self.get_device(device_name=device_name)["last_updated"] == previous_update_time:
while (
self.get_device(device_name=device_name)["last_updated"]
== previous_update_time
):
time.sleep(0.5)

def pause_device(self, device_name: str):
Expand Down
11 changes: 7 additions & 4 deletions alab_management/experiment_view/completed_experiment_view.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
"""A wrapper over the ``experiment`` class."""

import time
from typing import Any

from bson import ObjectId # type: ignore
import time

from alab_management.sample_view import CompletedSampleView
from alab_management.task_view import CompletedTaskView
Expand Down Expand Up @@ -48,9 +48,12 @@ def save_experiment(self, experiment_id: ObjectId):
update={"$set": experiment_dict},
)
# wait for the update to complete
while self._completed_experiment_collection.find_one(
{"_id": ObjectId(experiment_id)}
) is None:
while (
self._completed_experiment_collection.find_one(
{"_id": ObjectId(experiment_id)}
)
is None
):
time.sleep(0.5)
else:
self._completed_experiment_collection.insert_one(experiment_dict)
Expand Down
4 changes: 3 additions & 1 deletion alab_management/experiment_view/experiment_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,9 @@ def update_sample_task_id(
update = "not completed"
while update != "completed":
experiment = self.get_experiment(exp_id=exp_id)
updated_sample_ids = [sample["sample_id"] for sample in experiment["samples"]]
updated_sample_ids = [
sample["sample_id"] for sample in experiment["samples"]
]
updated_task_ids = [task["task_id"] for task in experiment["tasks"]]
if updated_sample_ids == sample_ids and updated_task_ids == task_ids:
update = "completed"
Expand Down
10 changes: 6 additions & 4 deletions alab_management/sample_view/completed_sample_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
saving samples to the completed database.
"""

from bson import ObjectId # type: ignore
import time

from bson import ObjectId # type: ignore

from alab_management.utils.data_objects import get_collection, get_completed_collection


Expand Down Expand Up @@ -41,9 +42,10 @@ def save_sample(self, sample_id: ObjectId):
else:
self._completed_sample_collection.insert_one(sample_dict)
# wait for the insert to complete
while self._completed_sample_collection.find_one(
{"_id": ObjectId(sample_id)}
) is None:
while (
self._completed_sample_collection.find_one({"_id": ObjectId(sample_id)})
is None
):
time.sleep(0.5)

def exists(self, sample_id: ObjectId | str) -> bool:
Expand Down
10 changes: 8 additions & 2 deletions alab_management/task_manager/task_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,8 +412,14 @@ def _handle_requested_resources(self, request_entry: dict[str, Any]):
!= "FULFILLED"
):
# handle if the request is cancelled or errored
if (self.get_request(request_entry["_id"], projection=["status"])["status"] == "CANCELED"
or self.get_request(request_entry["_id"], projection=["status"])["status"] == "ERROR"):
if (
self.get_request(request_entry["_id"], projection=["status"])["status"]
== "CANCELED"
or self.get_request(request_entry["_id"], projection=["status"])[
"status"
]
== "ERROR"
):
return
time.sleep(0.5)
# label the resources as occupied
Expand Down
20 changes: 13 additions & 7 deletions alab_management/task_view/task_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def get_task(self, task_id: ObjectId, encode: bool = False) -> dict[str, Any]:
encode: whether to encode the task using ``self.encode_task`` method
"""
task_id = ObjectId(task_id)

result = self._task_collection.find_one({"_id": task_id})

if result is None:
Expand Down Expand Up @@ -237,7 +237,10 @@ def update_status(self, task_id: ObjectId, status: TaskStatus):
},
)
# Wait until the status is updated
while self.get_task(task_id=next_task_id)["last_updated"] == previous_update_time:
while (
self.get_task(task_id=next_task_id)["last_updated"]
== previous_update_time
):
time.sleep(0.5)
self.try_to_mark_task_ready(
task_id=next_task_id
Expand Down Expand Up @@ -267,9 +270,7 @@ def update_subtask_status(
previous_update_time = self.get_task(task_id=task_id)["last_updated"]
self._task_collection.update_one(
{"_id": task_id},
{"$set": {
"subtasks": subtasks,
"last_updated": datetime.now()}},
{"$set": {"subtasks": subtasks, "last_updated": datetime.now()}},
)
# Wait until the status is updated
while self.get_task(task_id=task_id)["last_updated"] == previous_update_time:
Expand Down Expand Up @@ -430,7 +431,9 @@ def update_task_dependency(
for next_task in next_tasks:
if self.get_task(task_id=next_task) is None:
raise ValueError(f"Non-exist task id: {next_task}")
previous_update_time = self.get_task(task_id=task_id, encode=False)["last_updated"]
previous_update_time = self.get_task(task_id=task_id, encode=False)[
"last_updated"
]
self._task_collection.update_one(
{"_id": task_id},
{
Expand All @@ -444,7 +447,10 @@ def update_task_dependency(
},
)
# Wait until the status is updated
while self.get_task(task_id=task_id, encode=False)["last_updated"] == previous_update_time:
while (
self.get_task(task_id=task_id, encode=False)["last_updated"]
== previous_update_time
):
time.sleep(0.5)

def set_message(self, task_id: ObjectId, message: str):
Expand Down

0 comments on commit d7cff64

Please sign in to comment.