Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/paper-rebuttal' into paper-rebuttal
Browse files Browse the repository at this point in the history
  • Loading branch information
idocx committed Jul 12, 2024
2 parents bed488e + 1ab093f commit 25291c3
Show file tree
Hide file tree
Showing 22 changed files with 516 additions and 156 deletions.
8 changes: 8 additions & 0 deletions alab_management/_default/config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,11 @@ email_password = " "
# the slack_bot_token is the token of the slack bot, you can get it from https://api.slack.com/apps
slack_bot_token = " "
slack_channel_id = " "

[large_result_storage]
# the default storage configuration for tasks that generate large results
# (>16 MB, cannot be contained in MongoDB)
# currently only gridfs is supported
# storage_type is defined by using LargeResult class located in alab_management/task_view/task.py
# you can override this default configuration by setting the storage_type in the task definition
default_storage_type = "gridfs"
18 changes: 9 additions & 9 deletions alab_management/experiment_view/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,18 @@
from bson import BSON, ObjectId # type: ignore
from pydantic import (
BaseModel,
constr, # pylint: disable=no-name-in-module
validator,
Field,
field_validator,
)


class _Sample(BaseModel):
name: constr(regex=r"^[^$.]+$") # type: ignore
name: str = Field(pattern=r"^[^$.]+$")
sample_id: str | None = None
tags: list[str]
metadata: dict[str, Any]

@validator("sample_id")
@field_validator("sample_id")
def if_provided_must_be_valid_objectid(cls, v):
if v is None:
return # management will autogenerate a valid objectid
Expand All @@ -29,7 +29,7 @@ def if_provided_must_be_valid_objectid(cls, v):
"set to {v}, which is not a valid ObjectId."
) from exc

@validator("metadata")
@field_validator("metadata")
def must_be_bsonable(cls, v):
"""If v is not None, we must confirm that it can be encoded to BSON."""
try:
Expand All @@ -49,7 +49,7 @@ class _Task(BaseModel):
samples: list[str]
task_id: str | None = None

@validator("task_id")
@field_validator("task_id")
def if_provided_must_be_valid_objectid(cls, v):
if v is None:
return # management will autogenerate a valid objectid
Expand All @@ -62,7 +62,7 @@ def if_provided_must_be_valid_objectid(cls, v):
"{v}, which is not a valid ObjectId."
) from exc

@validator("parameters")
@field_validator("parameters")
def must_be_bsonable(cls, v):
"""If v is not None, we must confirm that it can be encoded to BSON."""
try:
Expand All @@ -78,13 +78,13 @@ def must_be_bsonable(cls, v):
class InputExperiment(BaseModel):
"""This is the format that user should follow to write to experiment database."""

name: constr(regex=r"^[^$.]+$") # type: ignore
name: str = Field(pattern=r"^[^$.]+$")
samples: list[_Sample]
tasks: list[_Task]
tags: list[str]
metadata: dict[str, Any]

@validator("metadata")
@field_validator("metadata")
def must_be_bsonable(cls, v):
"""If v is not None, we must confirm that it can be encoded to BSON."""
try:
Expand Down
31 changes: 1 addition & 30 deletions alab_management/lab_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,14 @@
from typing import Any

from bson import ObjectId
from pydantic import root_validator
from pydantic.main import BaseModel

from alab_management.device_manager import DevicesClient
from alab_management.device_view.device import BaseDevice
from alab_management.experiment_view.experiment_view import ExperimentView
from alab_management.logger import DBLogger
from alab_management.resource_manager.resource_requester import ResourceRequester
from alab_management.sample_view.sample import Sample
from alab_management.sample_view.sample_view import SamplePositionRequest, SampleView
from alab_management.sample_view.sample_view import SampleView
from alab_management.task_view.task import BaseTask
from alab_management.task_view.task_enums import TaskPriority, TaskStatus
from alab_management.task_view.task_view import TaskView
Expand All @@ -31,33 +29,6 @@ class DeviceRunningException(Exception):
"""Raise when a task try to release a device that is still running."""


class ResourcesRequest(BaseModel):
"""
This class is used to validate the resource request. Each request should have a format of
[DeviceType: List of SamplePositionRequest].
See Also
--------
:py:class:`SamplePositionRequest <alab_management.sample_view.sample_view.SamplePositionRequest>`
"""

__root__: dict[type[BaseDevice] | None, list[SamplePositionRequest]] # type: ignore

@root_validator(pre=True, allow_reuse=True)
def preprocess(cls, values):
"""Preprocess the request to make sure the request is in the correct format."""
values = values["__root__"]
# if the sample position request is string, we will automatically add a number attribute = 1.
values = {
k: [
SamplePositionRequest.from_str(v_) if isinstance(v_, str) else v_
for v_ in v
]
for k, v in values.items()
}
return {"__root__": values}


class LabView:
"""
LabView is a wrapper over device view and sample view.
Expand Down
33 changes: 22 additions & 11 deletions alab_management/resource_manager/resource_requester.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@

import dill
from bson import ObjectId
from pydantic import BaseModel, root_validator
from pydantic import BaseModel, model_validator
from pydantic.root_model import RootModel

from alab_management.device_view.device import BaseDevice
from alab_management.device_view.device_view import DeviceView
Expand Down Expand Up @@ -41,7 +42,21 @@ class CombinedTimeoutError(TimeoutError, concurrent.futures.TimeoutError):
"""


class ResourcesRequest(BaseModel):
class DeviceRequest(BaseModel):
"""Pydantic model for device request."""

identifier: str
content: str


class ResourceRequestItem(BaseModel):
"""Pydantic model for resource request item."""

device: DeviceRequest
sample_positions: list[SamplePositionRequest]


class ResourcesRequest(RootModel):
"""
This class is used to validate the resource request. Each request should have a format of
[
Expand All @@ -66,15 +81,11 @@ class ResourcesRequest(BaseModel):
:py:class:`SamplePositionRequest <alab_management.sample_view.sample_view.SamplePositionRequest>`
"""

__root__: list[
dict[str, list[dict[str, str | int]] | dict[str, str]]
] # type: ignore
root: list[ResourceRequestItem]

@root_validator(pre=True, allow_reuse=True)
@model_validator(mode="before")
def preprocess(cls, values):
"""Preprocess the request."""
values = values["__root__"]

new_values = []
for request_dict in values:
if request_dict["device"]["identifier"] not in [
Expand All @@ -94,7 +105,7 @@ def preprocess(cls, values):
"sample_positions": request_dict["sample_positions"],
}
)
return {"__root__": new_values}
return new_values


class RequestMixin:
Expand Down Expand Up @@ -237,8 +248,8 @@ def request_resources(
)

if not isinstance(formatted_resource_request, ResourcesRequest):
formatted_resource_request = ResourcesRequest(__root__=formatted_resource_request) # type: ignore
formatted_resource_request = formatted_resource_request.dict()["__root__"]
formatted_resource_request = ResourcesRequest(root=formatted_resource_request) # type: ignore
formatted_resource_request = formatted_resource_request.model_dump(mode="json")

result = self._request_collection.insert_one(
{
Expand Down
8 changes: 3 additions & 5 deletions alab_management/sample_view/sample_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import pymongo # type: ignore
from bson import ObjectId # type: ignore
from pydantic import BaseModel, conint
from pydantic import BaseModel, ConfigDict, conint

from alab_management.utils.data_objects import get_collection, get_lock

Expand All @@ -23,10 +23,8 @@ class SamplePositionRequest(BaseModel):
the number you request. By default, the number is set to be 1.
"""

class Config:
"""raise error when extra kwargs."""

extra = "forbid"
# raise error when extra kwargs are passed
model_config = ConfigDict(extra="forbid")

prefix: str
number: conint(ge=0) = 1 # type: ignore
Expand Down
48 changes: 31 additions & 17 deletions alab_management/task_actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from alab_management.sample_view import SampleView
from alab_management.task_view import BaseTask, TaskStatus, TaskView
from alab_management.task_view.task import LargeResult
from alab_management.utils.data_objects import make_bsonable
from alab_management.utils.middleware import register_abortable_middleware
from alab_management.utils.module_ops import load_definition

Expand Down Expand Up @@ -170,6 +171,7 @@ def run_task(task_id_str: str):
# assume that all field are replaced by the value if the result is a pydantic model
# convert pydantic model to dict
dict_result = result.model_dump(mode="python")
bsonable_value = make_bsonable(dict_result)
for key, value in dict_result.items():
task_view.update_result(task_id=task_id, name=key, value=value)
elif isinstance(result, dict):
Expand All @@ -189,25 +191,37 @@ def run_task(task_id_str: str):
result = task_view.get_task(task_id=task_id)["result"]
if isinstance(result, dict):
try:
encoded_result = task.result_specification(**result)
# if it is consistent, check if any field is a LargeResult
# if so, ensure that it is stored properly
for key, value in encoded_result.items():
if (
isinstance(value, LargeResult)
and not value.check_if_stored()
):
try:
value.store()
except Exception as e:
# if storing fails, log the error and continue
print(
f"WARNING: Failed to store LargeResult {key} for task_id {task_id_str}: {e}"
)
except Exception as e:
model = task.result_specification
encoded_result = model(**result)
# if it is consistent, check which fields are LargeResults
# if so, ensure that they are stored properly
for key, value in dict(encoded_result).items():
if isinstance(value, LargeResult):
if not value.check_if_stored():
try:
# get storage type from the config file
value.store()
# update the LargeResult entry in the MongoDB for the corresponding field in the task result
value_as_dict = value.model_dump(mode="python")
# ensure bson serializable
bsonable_value = make_bsonable(value_as_dict)
task_view.update_result(
task_id=task_id, name=key, value=bsonable_value
)
except Exception:
# if storing fails, log the error and continue
print(
f"WARNING: Failed to store LargeResult {key} for task_id {task_id_str}."
f"{format_exc()}"
)
else:
pass
except Exception:
print(
f"WARNING: Task result for task_id {task_id_str} is inconsistent with the task result specification: {e}"
f"WARNING: Task result for task_id {task_id_str} is inconsistent with the task result specification."
f"{format_exc()}"
)
print()
else:
print(
f"WARNING: Task result for task_id {task_id_str} is not a dictionary, but a {type(result)}."
Expand Down
Loading

0 comments on commit 25291c3

Please sign in to comment.