Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better task manager #65

Merged
merged 28 commits into from
May 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ jobs:
- name: Set up environment
run: |
pip install --upgrade pip
pip install --quiet .
pip install '.[dev]'
- name: Set up pyright
run: |
npm install -g pyright
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/page.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ jobs:
with:
python-version: '3.10'
- name: Set up dependencies
run: pip install --quiet .
run: pip install '.[docs]'
- name: Compile sphinx
working-directory: .
run: |
Expand Down
2 changes: 1 addition & 1 deletion alab_management/dashboard/routes/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def cancel_experiment(exp_id: str):
# tasks = experiment_view.get_experiment(exp_id)["tasks"]

for task in tasks:
task_view.mark_task_as_cancelling(task["task_id"])
task_view.mark_task_as_canceling(task["task_id"])
except Exception as e:
return {"status": "error", "reason": e.args[0]}, 400
else:
Expand Down
2 changes: 1 addition & 1 deletion alab_management/dashboard/routes/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def cancel_task(task_id: str):
"""API to cancel a task."""
try:
task_id_obj: ObjectId = ObjectId(task_id)
task_view.mark_task_as_cancelling(task_id_obj)
task_view.mark_task_as_canceling(task_id_obj)

return {"status": "success"}
except Exception as exception:
Expand Down
27 changes: 12 additions & 15 deletions alab_management/experiment_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
done.
"""

import time
from typing import Any

from .config import AlabOSConfig
Expand All @@ -31,7 +30,7 @@ def __init__(self):

config = AlabOSConfig()
self.__copy_to_completed_db = (
"mongodb_completed" in config
"mongodb_completed" in config
) # if this is not defined in the config, assume this feature is not being used.
if self.__copy_to_completed_db:
self.completed_experiment_view = CompletedExperimentView()
Expand All @@ -47,7 +46,7 @@ def run(self):
)
while True:
self._loop()
time.sleep(1)
# time.sleep()

def _loop(self):
self.handle_pending_experiments()
Expand Down Expand Up @@ -92,10 +91,9 @@ def _handle_pending_experiment(self, experiment: dict[str, Any]):
},
)
if task_graph.has_cycle():
raise ValueError(
"Detect cycle in task graph, which is supposed "
"to be a DAG (directed acyclic graph)."
)
self.experiment_view.update_experiment_status(experiment["_id"], ExperimentStatus.ERROR)
print(f"Experiment ({experiment['_id']}) has a cycle in the graph.")
return

# create samples in the sample database
sample_ids = {
Expand Down Expand Up @@ -158,14 +156,13 @@ def mark_completed_experiments(self):

# if all the tasks of an experiment have been finished
if all(
self.task_view.get_status(task_id=task_id)
in {
TaskStatus.COMPLETED,
TaskStatus.ERROR,
TaskStatus.CANCELLED,
TaskStatus.STOPPED,
odartsi marked this conversation as resolved.
Show resolved Hide resolved
}
for task_id in task_ids
self.task_view.get_status(task_id=task_id)
in {
TaskStatus.COMPLETED,
TaskStatus.ERROR,
TaskStatus.CANCELLED,
}
for task_id in task_ids
):
self.experiment_view.update_experiment_status(
exp_id=experiment["_id"], status=ExperimentStatus.COMPLETED
Expand Down
7 changes: 3 additions & 4 deletions alab_management/experiment_view/experiment_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@
from .completed_experiment_view import CompletedExperimentView
from .experiment import InputExperiment

completed_experiment_view = CompletedExperimentView()


class ExperimentStatus(Enum):
"""
Expand All @@ -39,6 +37,7 @@ def __init__(self):
self._experiment_collection = get_collection("experiment")
self.sample_view = SampleView()
self.task_view = TaskView()
self.completed_experiment_view = CompletedExperimentView()

def create_experiment(self, experiment: InputExperiment) -> ObjectId:
"""
Expand Down Expand Up @@ -103,7 +102,7 @@ def get_experiment(self, exp_id: ObjectId) -> dict[str, Any] | None:
experiment = self._experiment_collection.find_one({"_id": exp_id})
if experiment is None:
try:
experiment = completed_experiment_view.get_experiment(
experiment = self.completed_experiment_view.get_experiment(
experiment_id=exp_id
)
except ValueError:
Expand Down Expand Up @@ -167,7 +166,7 @@ def get_experiment_by_task_id(self, task_id: ObjectId) -> dict[str, Any] | None:
experiment = self._experiment_collection.find_one({"tasks.task_id": task_id})
if experiment is None:
raise ValueError(f"Cannot find experiment containing task_id: {task_id}")
return experiment
return dict(experiment)

def get_experiment_by_sample_id(self, sample_id: ObjectId) -> dict[str, Any] | None:
"""Get an experiment that contains a sample with the given sample_id."""
Expand Down
2 changes: 1 addition & 1 deletion alab_management/lab_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
from alab_management.device_view.device import BaseDevice
from alab_management.experiment_view.experiment_view import ExperimentView
from alab_management.logger import DBLogger
from alab_management.resource_manager.resource_requester import ResourceRequester
from alab_management.sample_view.sample import Sample
from alab_management.sample_view.sample_view import SamplePositionRequest, SampleView
from alab_management.task_manager.resource_requester import ResourceRequester
from alab_management.task_view.task import BaseTask
from alab_management.task_view.task_enums import TaskPriority, TaskStatus
from alab_management.task_view.task_view import TaskView
Expand Down
1 change: 1 addition & 0 deletions alab_management/resource_manager/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""This file contains the resource manager module."""
207 changes: 207 additions & 0 deletions alab_management/resource_manager/resource_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
"""
TaskLauncher is the core module of the system,
which actually executes the tasks.
"""

import time
from datetime import datetime
from typing import Any, cast

import dill
from bson import ObjectId

from alab_management.device_view.device_view import DeviceView
from alab_management.logger import DBLogger
from alab_management.resource_manager.enums import _EXTRA_REQUEST
from alab_management.resource_manager.resource_requester import (
RequestMixin,
RequestStatus,
)
from alab_management.sample_view.sample import SamplePosition
from alab_management.sample_view.sample_view import SamplePositionRequest, SampleView
from alab_management.task_view import TaskView
from alab_management.task_view.task_enums import TaskStatus
from alab_management.utils.data_objects import get_collection
from alab_management.utils.module_ops import load_definition


class ResourceManager(RequestMixin):
    """
    Allocate and release lab resources on behalf of tasks.

    Every iteration of the loop will

    (1) release the devices and sample positions of each request whose
        status is ``NEED_RELEASE``,
    (2) try to assign devices and sample positions to each ``PENDING``
        request, highest priority first.
    """

    def __init__(self):
        """Load the lab definition and create the views used by the manager."""
        load_definition()
        self.task_view = TaskView()
        self.sample_view = SampleView()
        self.device_view = DeviceView()
        self._request_collection = get_collection("requests")

        self.logger = DBLogger(task_id=None)
        super().__init__()
        time.sleep(1)  # allow some time for other modules to launch

    def run(self):
        """Start the loop; this method never returns."""
        # NOTE(review): there is no pause between iterations, so the request
        # collection is polled continuously -- confirm this is intended.
        while True:
            self._loop()

    def _loop(self):
        """Run one iteration: release first, then try to fulfil requests."""
        self.handle_released_resources()
        self.handle_requested_resources()

    def handle_released_resources(self):
        """
        Release the resources of every request in ``NEED_RELEASE`` status.

        The assigned devices and sample positions stored on the request are
        released, then the request is marked as ``RELEASED``.
        """
        for request in self.get_requests_by_status(RequestStatus.NEED_RELEASE):
            devices = request["assigned_devices"]
            sample_positions = request["assigned_sample_positions"]
            self._release_devices(devices)
            self._release_sample_positions(sample_positions)
            self.update_request_status(
                request_id=request["_id"], status=RequestStatus.RELEASED
            )

    def handle_requested_resources(self):
        """
        Check if there are any requests that are in PENDING status. If so,
        try to assign the resources to it.

        Requests are processed by descending ``priority``; requests with the
        same priority are processed oldest ``submitted_at`` first.
        """
        requests = list(self.get_requests_by_status(RequestStatus.PENDING))
        # prioritize the oldest requests at the highest priority value.
        # list.sort is stable (also with reverse=True), so the second sort
        # keeps the submitted_at order among requests of equal priority.
        requests.sort(key=lambda x: x["submitted_at"])
        requests.sort(key=lambda x: x["priority"], reverse=True)
        for request in requests:
            self._handle_requested_resources(request)

    def _handle_requested_resources(self, request_entry: dict[str, Any]):
        """
        Try to fulfil a single pending resource request.

        The request is left untouched (to be retried on a later iteration)
        when the devices or the sample positions are not available yet. It is
        marked ``CANCELED`` when its task is no longer requesting resources,
        ``ERROR`` when anything raises while resolving it, and ``FULFILLED``
        when both devices and sample positions could be assigned.

        Args:
            request_entry: the request document; the keys ``_id``,
                ``task_id`` and ``request`` are read here.
        """
        try:
            resource_request = request_entry["request"]
            task_id = request_entry["task_id"]

            task_status = self.task_view.get_status(task_id=task_id)
            if task_status != TaskStatus.REQUESTING_RESOURCES:
                # this implies the Task has been cancelled or errored somewhere else in the chain -- we should
                # not allocate any resources to the broken Task.
                self.update_request_status(
                    request_id=request_entry["_id"],
                    status=RequestStatus.CANCELED,
                )
                return

            devices = self.device_view.request_devices(
                task_id=task_id,
                device_names_str=[
                    entry["device"]["content"]
                    for entry in resource_request
                    if entry["device"]["identifier"] == "name"
                ],
                device_types_str=[
                    entry["device"]["content"]
                    for entry in resource_request
                    if entry["device"]["identifier"] == "type"
                ],
            )
            # some devices are not available now
            # the request cannot be fulfilled
            if devices is None:
                return

            # replace device placeholder in sample position request
            # and make it into a single list
            parsed_sample_positions_request = []
            for request in resource_request:
                if request["device"]["identifier"] == _EXTRA_REQUEST:
                    # positions that are not attached to any device
                    device_prefix = ""
                else:
                    device_name = devices[request["device"]["content"]]["name"]
                    device_prefix = f"{device_name}{SamplePosition.SEPARATOR}"

                for pos in request["sample_positions"]:
                    prefix = pos["prefix"]
                    # if this is a nested resource request, lets not prepend the device name twice.
                    if not prefix.startswith(device_prefix):
                        prefix = device_prefix + prefix
                    parsed_sample_positions_request.append(
                        SamplePositionRequest(prefix=prefix, number=pos["number"])
                    )

            # store the resolved sample position request on the request document
            self._request_collection.update_one(
                {"_id": request_entry["_id"]},
                {
                    "$set": {
                        "parsed_sample_positions_request": [
                            dict(spr) for spr in parsed_sample_positions_request
                        ]
                    }
                },
            )
            sample_positions = self.sample_view.request_sample_positions(
                task_id=task_id, sample_positions=parsed_sample_positions_request
            )
            # some sample positions are not available now
            if sample_positions is None:
                return

        # in case some errors happen, we will raise the error in the task process instead of the main process
        except Exception as error:  # pylint: disable=broad-except
            self._request_collection.update_one(
                {"_id": request_entry["_id"]},
                {
                    "$set": {
                        "status": RequestStatus.ERROR.name,
                        # the exception is serialized with dill and stored on the request
                        "error": dill.dumps(error),
                        "assigned_devices": None,
                        "assigned_sample_positions": None,
                    }
                },
            )
            return

        # if both devices and sample positions can be satisfied
        self._request_collection.update_one(
            {"_id": request_entry["_id"]},
            {
                "$set": {
                    "assigned_devices": devices,
                    "assigned_sample_positions": sample_positions,
                    "status": RequestStatus.FULFILLED.name,
                    "fulfilled_at": datetime.now(),
                }
            },
        )
        # label the resources as occupied
        self._occupy_devices(devices=devices, task_id=task_id)
        self._occupy_sample_positions(
            sample_positions=sample_positions, task_id=task_id
        )

    def _occupy_devices(self, devices: dict[str, dict[str, Any]], task_id: ObjectId):
        """Mark every assigned device as occupied by ``task_id``."""
        for device in devices.values():
            self.device_view.occupy_device(
                device=cast(str, device["name"]), task_id=task_id
            )

    def _occupy_sample_positions(
        self, sample_positions: dict[str, list[dict[str, Any]]], task_id: ObjectId
    ):
        """Lock every assigned sample position for ``task_id``."""
        for sample_positions_ in sample_positions.values():
            for sample_position_ in sample_positions_:
                self.sample_view.lock_sample_position(
                    task_id, cast(str, sample_position_["name"])
                )

    def _release_devices(self, devices: dict[str, dict[str, Any]]):
        """Release the devices whose ``need_release`` flag is truthy."""
        for device in devices.values():
            if device["need_release"]:
                self.device_view.release_device(device["name"])

    def _release_sample_positions(
        self, sample_positions: dict[str, list[dict[str, Any]]]
    ):
        """Release the sample positions whose ``need_release`` flag is truthy."""
        for sample_positions_ in sample_positions.values():
            for sample_position in sample_positions_:
                if sample_position["need_release"]:
                    self.sample_view.release_sample_position(sample_position["name"])
Loading
Loading