diff --git a/alab_management/_default/config.toml b/alab_management/_default/config.toml index 31f90d70..f7db3124 100644 --- a/alab_management/_default/config.toml +++ b/alab_management/_default/config.toml @@ -33,3 +33,11 @@ email_password = " " # the slack_bot_token is the token of the slack bot, you can get it from https://api.slack.com/apps slack_bot_token = " " slack_channel_id = " " + +[large_result_storage] +# the default storage configuration for tasks that generate large results +# (>16 MB, cannot be contained in MongoDB) +# currently only gridfs is supported +# storage_type is defined by using LargeResult class located in alab_management/task_view/task.py +# you can override this default configuration by setting the storage_type in the task definition +default_storage_type = "gridfs" diff --git a/alab_management/experiment_view/experiment.py b/alab_management/experiment_view/experiment.py index 337f4694..53e7ea52 100644 --- a/alab_management/experiment_view/experiment.py +++ b/alab_management/experiment_view/experiment.py @@ -5,18 +5,18 @@ from bson import BSON, ObjectId # type: ignore from pydantic import ( BaseModel, - constr, # pylint: disable=no-name-in-module - validator, + Field, + field_validator, ) class _Sample(BaseModel): - name: constr(regex=r"^[^$.]+$") # type: ignore + name: str = Field(pattern=r"^[^$.]+$") sample_id: str | None = None tags: list[str] metadata: dict[str, Any] - @validator("sample_id") + @field_validator("sample_id") def if_provided_must_be_valid_objectid(cls, v): if v is None: return # management will autogenerate a valid objectid @@ -29,7 +29,7 @@ def if_provided_must_be_valid_objectid(cls, v): "set to {v}, which is not a valid ObjectId." ) from exc - @validator("metadata") + @field_validator("metadata") def must_be_bsonable(cls, v): """If v is not None, we must confirm that it can be encoded to BSON.""" try: @@ -49,7 +49,7 @@ class _Task(BaseModel): samples: list[str] task_id: str | None = None - @validator("task_id") + @field_validator("task_id") def if_provided_must_be_valid_objectid(cls, v): if v is None: return # management will autogenerate a valid objectid @@ -62,7 +62,7 @@ def if_provided_must_be_valid_objectid(cls, v): "{v}, which is not a valid ObjectId." 
) from exc - @validator("parameters") + @field_validator("parameters") def must_be_bsonable(cls, v): """If v is not None, we must confirm that it can be encoded to BSON.""" try: @@ -78,13 +78,13 @@ def must_be_bsonable(cls, v): class InputExperiment(BaseModel): """This is the format that user should follow to write to experiment database.""" - name: constr(regex=r"^[^$.]+$") # type: ignore + name: str = Field(pattern=r"^[^$.]+$") samples: list[_Sample] tasks: list[_Task] tags: list[str] metadata: dict[str, Any] - @validator("metadata") + @field_validator("metadata") def must_be_bsonable(cls, v): """If v is not None, we must confirm that it can be encoded to BSON.""" try: diff --git a/alab_management/lab_view.py b/alab_management/lab_view.py index b6a05bfd..d607016e 100644 --- a/alab_management/lab_view.py +++ b/alab_management/lab_view.py @@ -11,8 +11,6 @@ from typing import Any from bson import ObjectId -from pydantic import root_validator -from pydantic.main import BaseModel from alab_management.device_manager import DevicesClient from alab_management.device_view.device import BaseDevice @@ -20,7 +18,7 @@ from alab_management.logger import DBLogger from alab_management.resource_manager.resource_requester import ResourceRequester from alab_management.sample_view.sample import Sample -from alab_management.sample_view.sample_view import SamplePositionRequest, SampleView +from alab_management.sample_view.sample_view import SampleView from alab_management.task_view.task import BaseTask from alab_management.task_view.task_enums import TaskPriority, TaskStatus from alab_management.task_view.task_view import TaskView @@ -31,33 +29,6 @@ class DeviceRunningException(Exception): """Raise when a task try to release a device that is still running.""" -class ResourcesRequest(BaseModel): - """ - This class is used to validate the resource request. Each request should have a format of - [DeviceType: List of SamplePositionRequest]. - - See Also - -------- - :py:class:`SamplePositionRequest ` - """ - - __root__: dict[type[BaseDevice] | None, list[SamplePositionRequest]] # type: ignore - - @root_validator(pre=True, allow_reuse=True) - def preprocess(cls, values): - """Preprocess the request to make sure the request is in the correct format.""" - values = values["__root__"] - # if the sample position request is string, we will automatically add a number attribute = 1. - values = { - k: [ - SamplePositionRequest.from_str(v_) if isinstance(v_, str) else v_ - for v_ in v - ] - for k, v in values.items() - } - return {"__root__": values} - - class LabView: """ LabView is a wrapper over device view and sample view. 
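Note: the removal of the `__root__`-based ResourcesRequest above and the rewrite in resource_manager/resource_requester.py below follow the same pydantic v1 to v2 migration pattern (a custom `__root__` field with `@root_validator` becomes a `RootModel` subclass with `@model_validator`). The sketch below is illustrative only, assuming pydantic 2.x; the class and field names are generic stand-ins, not AlabOS APIs.

# Illustrative sketch of the __root__ -> RootModel migration pattern (pydantic 2.x assumed).
from pydantic import BaseModel, RootModel, model_validator


class PositionRequest(BaseModel):
    prefix: str
    number: int = 1


class RequestList(RootModel):
    # pydantic v1: `__root__: list[...]` + @root_validator(pre=True)
    # pydantic v2: `root: list[...]` + @model_validator(mode="before")
    root: list[PositionRequest]

    @model_validator(mode="before")
    @classmethod
    def coerce_strings(cls, values):
        # Accept a bare string as shorthand for a request with number=1.
        return [{"prefix": v, "number": 1} if isinstance(v, str) else v for v in values]


req_list = RequestList.model_validate(["furnace_table", {"prefix": "rack", "number": 2}])
print(req_list.model_dump())  # v1's `.dict()["__root__"]` becomes `.model_dump()`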
diff --git a/alab_management/resource_manager/resource_requester.py b/alab_management/resource_manager/resource_requester.py index 066d07fc..451e371b 100644 --- a/alab_management/resource_manager/resource_requester.py +++ b/alab_management/resource_manager/resource_requester.py @@ -13,7 +13,8 @@ import dill from bson import ObjectId -from pydantic import BaseModel, root_validator +from pydantic import BaseModel, model_validator +from pydantic.root_model import RootModel from alab_management.device_view.device import BaseDevice from alab_management.device_view.device_view import DeviceView @@ -41,7 +42,21 @@ class CombinedTimeoutError(TimeoutError, concurrent.futures.TimeoutError): """ -class ResourcesRequest(BaseModel): +class DeviceRequest(BaseModel): + """Pydantic model for device request.""" + + identifier: str + content: str + + +class ResourceRequestItem(BaseModel): + """Pydantic model for resource request item.""" + + device: DeviceRequest + sample_positions: list[SamplePositionRequest] + + +class ResourcesRequest(RootModel): """ This class is used to validate the resource request. Each request should have a format of [ @@ -66,15 +81,11 @@ class ResourcesRequest(BaseModel): :py:class:`SamplePositionRequest ` """ - __root__: list[ - dict[str, list[dict[str, str | int]] | dict[str, str]] - ] # type: ignore + root: list[ResourceRequestItem] - @root_validator(pre=True, allow_reuse=True) + @model_validator(mode="before") def preprocess(cls, values): """Preprocess the request.""" - values = values["__root__"] - new_values = [] for request_dict in values: if request_dict["device"]["identifier"] not in [ @@ -94,7 +105,7 @@ def preprocess(cls, values): "sample_positions": request_dict["sample_positions"], } ) - return {"__root__": new_values} + return new_values class RequestMixin: @@ -237,8 +248,8 @@ def request_resources( ) if not isinstance(formatted_resource_request, ResourcesRequest): - formatted_resource_request = ResourcesRequest(__root__=formatted_resource_request) # type: ignore - formatted_resource_request = formatted_resource_request.dict()["__root__"] + formatted_resource_request = ResourcesRequest(root=formatted_resource_request) # type: ignore + formatted_resource_request = formatted_resource_request.model_dump(mode="json") result = self._request_collection.insert_one( { diff --git a/alab_management/sample_view/sample_view.py b/alab_management/sample_view/sample_view.py index 62313835..38330610 100644 --- a/alab_management/sample_view/sample_view.py +++ b/alab_management/sample_view/sample_view.py @@ -8,7 +8,7 @@ import pymongo # type: ignore from bson import ObjectId # type: ignore -from pydantic import BaseModel, conint +from pydantic import BaseModel, ConfigDict, conint from alab_management.utils.data_objects import get_collection, get_lock @@ -23,10 +23,8 @@ class SamplePositionRequest(BaseModel): the number you request. By default, the number is set to be 1. 
""" - class Config: - """raise error when extra kwargs.""" - - extra = "forbid" + # raise error when extra kwargs are passed + model_config = ConfigDict(extra="forbid") prefix: str number: conint(ge=0) = 1 # type: ignore diff --git a/alab_management/task_actor.py b/alab_management/task_actor.py index 79984632..ec19504c 100644 --- a/alab_management/task_actor.py +++ b/alab_management/task_actor.py @@ -15,6 +15,7 @@ from alab_management.sample_view import SampleView from alab_management.task_view import BaseTask, TaskStatus, TaskView from alab_management.task_view.task import LargeResult +from alab_management.utils.data_objects import make_bsonable from alab_management.utils.middleware import register_abortable_middleware from alab_management.utils.module_ops import load_definition @@ -170,6 +171,7 @@ def run_task(task_id_str: str): # assume that all field are replaced by the value if the result is a pydantic model # convert pydantic model to dict dict_result = result.model_dump(mode="python") + bsonable_value = make_bsonable(dict_result) for key, value in dict_result.items(): task_view.update_result(task_id=task_id, name=key, value=value) elif isinstance(result, dict): @@ -189,25 +191,37 @@ def run_task(task_id_str: str): result = task_view.get_task(task_id=task_id)["result"] if isinstance(result, dict): try: - encoded_result = task.result_specification(**result) - # if it is consistent, check if any field is a LargeResult - # if so, ensure that it is stored properly - for key, value in encoded_result.items(): - if ( - isinstance(value, LargeResult) - and not value.check_if_stored() - ): - try: - value.store() - except Exception as e: - # if storing fails, log the error and continue - print( - f"WARNING: Failed to store LargeResult {key} for task_id {task_id_str}: {e}" - ) - except Exception as e: + model = task.result_specification + encoded_result = model(**result) + # if it is consistent, check which fields are LargeResults + # if so, ensure that they are stored properly + for key, value in dict(encoded_result).items(): + if isinstance(value, LargeResult): + if not value.check_if_stored(): + try: + # get storage type from the config file + value.store() + # update the LargeResult entry in the MongoDB for the corresponding field in the task result + value_as_dict = value.model_dump(mode="python") + # ensure bson serializable + bsonable_value = make_bsonable(value_as_dict) + task_view.update_result( + task_id=task_id, name=key, value=bsonable_value + ) + except Exception: + # if storing fails, log the error and continue + print( + f"WARNING: Failed to store LargeResult {key} for task_id {task_id_str}." + f"{format_exc()}" + ) + else: + pass + except Exception: print( - f"WARNING: Task result for task_id {task_id_str} is inconsistent with the task result specification: {e}" + f"WARNING: Task result for task_id {task_id_str} is inconsistent with the task result specification." + f"{format_exc()}" ) + print() else: print( f"WARNING: Task result for task_id {task_id_str} is not a dictionary, but a {type(result)}." 
diff --git a/alab_management/task_view/task.py b/alab_management/task_view/task.py index 241042de..2efcfc17 100644 --- a/alab_management/task_view/task.py +++ b/alab_management/task_view/task.py @@ -1,27 +1,29 @@ """Define the base class of task, which will be used for defining more tasks.""" import inspect +import time from abc import ABC, abstractmethod from inspect import getfullargspec from pathlib import Path from typing import TYPE_CHECKING, Any, Optional import gridfs -from bson import BSON -from bson.errors import InvalidBSON from bson.objectid import ObjectId -from pydantic import BaseModel, root_validator +from pydantic import BaseModel, ConfigDict, model_validator from alab_management.builders.experimentbuilder import ExperimentBuilder from alab_management.builders.samplebuilder import SampleBuilder from alab_management.config import AlabOSConfig -from alab_management.device_view.device import BaseDevice from alab_management.task_view.task_enums import TaskPriority from alab_management.utils.data_objects import _GetMongoCollection if TYPE_CHECKING: + from alab_management.device_view.device import BaseDevice from alab_management.lab_view import LabView +config = AlabOSConfig() +default_storage_type = str(config["large_result_storage"]["default_storage_type"]) + class LargeResult(BaseModel): """ @@ -29,17 +31,19 @@ class LargeResult(BaseModel): Stored in either gridFS or other filesystems (Cloud AWS S3, etc.). """ - storage_type: str + storage_type: str = default_storage_type # The path to the local file, used for uploading - local_path: str | Path | None + local_path: str | Path | None = None # The identifier of the file in the storage system, can be a path or a key (e.g., gridfs id) # Obtained after storing the file, used for retrieving - identifier: str | None + identifier: str | ObjectId | None = None # alternative to local path, used for uploading, local path has higher priority - file_like_data: Any | None + file_like_data: Any | None = None + + model_config = ConfigDict(arbitrary_types_allowed=True) # if file_like_data is provided check if it has .read() method - @root_validator(pre=True) + @model_validator(mode="before") def check_file_like_data(cls, values): """Check if file_like_data has a .read() method.""" file_like_data = values.get("file_like_data") @@ -48,11 +52,20 @@ def check_file_like_data(cls, values): return values def store(self): - """Store the large result in the storage system.""" + """ + Store the large result in the storage system. + This method should block until the result is confirmed to be stored. + This method should have a timeout regardless of the storage system to not block indefinitely. + """ if self.storage_type == "gridfs": _GetMongoCollection.init() config = AlabOSConfig() - db = _GetMongoCollection.client.get_database(config["general"]["name"]) + if config.is_sim_mode(): + db = _GetMongoCollection.client.get_database( + config["general"]["name"] + "_sim" + ) + else: + db = _GetMongoCollection.client.get_database(config["general"]["name"]) fs = gridfs.GridFS(db) if self.local_path: with open(self.local_path, "rb") as file: @@ -64,16 +77,35 @@ def store(self): "Either local_path or serializable_data must be provided for storing in gridfs." 
) self.identifier = file_id + # check if the file is stored, wait until it is stored for maximum 10 seconds + for _ in range(10): + if fs.exists(file_id): + return + time.sleep(1) + raise ValueError(f"File with identifier {file_id} failed to be stored.") else: raise ValueError("Only gridfs storage is supported for now.") def retrieve(self): """Retrieve the large result from the storage system.""" if self.storage_type == "gridfs": + if self.identifier is None: + raise ValueError( + "Identifier is not provided for retrieving from gridfs." + ) _GetMongoCollection.init() config = AlabOSConfig() - db = _GetMongoCollection.client.get_database(config["general"]["name"]) + if config.is_sim_mode(): + db = _GetMongoCollection.client.get_database( + config["general"]["name"] + "_sim" + ) + else: + db = _GetMongoCollection.client.get_database(config["general"]["name"]) fs = gridfs.GridFS(db) + if fs.get(self.identifier) is None: + raise ValueError( + f"File with identifier {self.identifier} does not exist." + ) return fs.get(self.identifier).read() else: raise ValueError("Only gridfs storage is supported for now.") @@ -81,9 +113,16 @@ def retrieve(self): def check_if_stored(self): """Check if the large result is stored in the storage system.""" if self.storage_type == "gridfs": + if self.identifier is None: + return False _GetMongoCollection.init() config = AlabOSConfig() - db = _GetMongoCollection.client.get_database(config["general"]["name"]) + if config.is_sim_mode(): + db = _GetMongoCollection.client.get_database( + config["general"]["name"] + "_sim" + ) + else: + db = _GetMongoCollection.client.get_database(config["general"]["name"]) fs = gridfs.GridFS(db) return fs.exists(self.identifier) else: @@ -177,6 +216,7 @@ def priority(self) -> int: return 0 return self.lab_view._resource_requester.priority + @property def result_specification(self) -> BaseModel | None: """ Returns a pydantic model describing the results to be generated by this task. @@ -194,63 +234,68 @@ def result_specification(self) -> BaseModel | None: """ return None - def update_result(self, key: str, value: Any): - """Attach a result to the task. This will be saved in the database and - can be accessed later. Subsequent calls to this function with the same - key will overwrite the previous value. - - Args: - key (str): The name of the result. - value (Any): The value of the result. - """ - if key not in self.result_specification: - raise ValueError( - f"Result key {key} is not included in the result specification for this task!" - ) - - # check if value is bson serializable - try: - BSON.encode({key: value}) - except (InvalidBSON, TypeError) as e: - raise ValueError( - f"Value {value} for key {key} is not BSON serializable!" - ) from e - - if not self.__offline: - self.lab_view.update_result(name=key, value=value) - else: - raise ValueError("Cannot update a result for an offline task!") - - def export_result(self, encode: bool = False) -> type[BaseModel]: - """ - Export all data from the result. - - Args - ---- - encode: bool, optional. If True, the result will be encoded into the Pydantic model defined in result_specification. - If False, the raw result will be returned. - - Returns - ------- - BaseModel: A Pydantic model describing the results generated by this task. - """ - if self._is_taskid_generated: - raise ValueError( - "Cannot export a result from a task with an automatically generated task_id!" 
- ) - if not self.__offline: - results = self.lab_view._task_view.get_task(task_id=self.task_id)["results"] - if encode: - try: - return self.result_specification(**results) - except Exception as e: - raise ValueError( - f"Task result is inconsistent with the task result specification: {e}" - ) from e - else: - return results - else: - raise ValueError("Cannot export a result from an offline task!") + # TODO: Delete these two methods because task_view is better suited + # since both require DB access + + # def update_result(self, key: str, value: Any): + # """Attach a result to the task. This will be saved in the database and + # can be accessed later. Subsequent calls to this function with the same + # key will overwrite the previous value. + + # Args: + # key (str): The name of the result. + # value (Any): The value of the result. + # """ + # if key not in self.result_specification: + # raise ValueError( + # f"Result key {key} is not included in the result specification for this task!" + # ) + + # # check if value is bson serializable + # try: + # BSON.encode({key: value}) + # except (InvalidBSON, TypeError) as e: + # raise ValueError( + # f"Value {value} for key {key} is not BSON serializable!" + # ) from e + + # if not self.__offline: + # self.lab_view.update_result(name=key, value=value) + # else: + # raise ValueError("Cannot update a result for an offline task!") + + # def export_result(self, encode: bool = False) -> type[BaseModel]: + # """ + # Export all data from the result. + + # Args + # ---- + # encode: bool, optional. If True, the result will be encoded into the Pydantic model defined in result_specification. + # If False, the raw result will be returned. + + # Returns + # ------- + # BaseModel: A Pydantic model describing the results generated by this task. + # """ + # if self._is_taskid_generated: + # raise ValueError( + # "Cannot export a result from a task with an automatically generated task_id!" + # ) + # if not self.__offline: + # results = self.lab_view._task_view.get_task(task_id=self.task_id)["result"] + # if encode: + # try: + # model=self.result_specification + # return model(**results) + # except Exception as e: + # raise ValueError( + # f"Task result is inconsistent with the task result specification." + # f"{format_exc()}" + # ) from e + # else: + # return results + # else: + # raise ValueError("Cannot export a result from an offline task!") @priority.setter def priority(self, value: int | TaskPriority): @@ -377,7 +422,7 @@ def add_to( _task_registry: dict[str, type[BaseTask]] = {} -SUPPORTED_SAMPLE_POSITIONS_TYPE = dict[type[BaseDevice] | str | None, str | list[str]] +SUPPORTED_SAMPLE_POSITIONS_TYPE = dict[type["BaseDevice"] | str | None, str | list[str]] _reroute_task_registry: list[ dict[str, type[BaseTask] | SUPPORTED_SAMPLE_POSITIONS_TYPE] ] = [] diff --git a/alab_management/task_view/task_view.py b/alab_management/task_view/task_view.py index 441957a7..bb6f40a6 100644 --- a/alab_management/task_view/task_view.py +++ b/alab_management/task_view/task_view.py @@ -260,9 +260,11 @@ def update_result( """ Update result to completed job. - Args: task_id: the id of task to be updated name: the name of the result to be updated. If ``None``, - will update the entire ``result`` field. Otherwise, will update the field ``result.name``. value: the value - to be stored. This must be bson-encodable (ie can be written into MongoDB!) + Args: + task_id: the id of the task to be updated.
+ name: the name of the result to be updated. If ``None``, will update the entire ``result`` field. + Otherwise, will update the field ``result.name``. + value: the value to be stored. This must be bson-encodable (i.e. can be written into MongoDB!) """ _ = self.get_task( task_id=task_id diff --git a/alab_management/utils/data_objects.py b/alab_management/utils/data_objects.py index 4dffd800..5546792a 100644 --- a/alab_management/utils/data_objects.py +++ b/alab_management/utils/data_objects.py @@ -5,6 +5,7 @@ from abc import ABC, abstractmethod from datetime import datetime from enum import Enum +from pathlib import Path import numpy as np import pika @@ -115,6 +116,8 @@ def make_bsonable(obj): elif isinstance(obj, str): with contextlib.suppress(Exception): obj = ObjectId(obj) + elif isinstance(obj, Path): + obj = str(obj) return obj diff --git a/pyproject.toml b/pyproject.toml index 4e0b454a..0e7bc40c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,7 +33,7 @@ dependencies = [ "toml>=0.10.1", "pymongo>=3.12.3", "flask>=2.2.2", - "pydantic==1.10.2", + "pydantic==2.8.2", "click", "gevent>=21.8.0", "monty>=2022.9.9", diff --git a/tests/fake_lab/__init__.py b/tests/fake_lab/__init__.py index 55b30fd2..ed0b4690 100644 --- a/tests/fake_lab/__init__.py +++ b/tests/fake_lab/__init__.py @@ -16,6 +16,7 @@ from .tasks.infinite_task import InfiniteTask from .tasks.moving import Moving from .tasks.starting import Starting +from .tasks.take_picture import TakePicture, TakePictureMissingResult add_device(Furnace(name="furnace_1")) add_device(Furnace(name="furnace_2")) @@ -49,3 +50,5 @@ add_task(ErrorHandlingUnrecoverable) add_task(ErrorHandlingRecoverable) add_task(InfiniteTask) +add_task(TakePicture) +add_task(TakePictureMissingResult) diff --git a/tests/fake_lab/config.toml b/tests/fake_lab/config.toml index 307aa652..957192db 100644 --- a/tests/fake_lab/config.toml +++ b/tests/fake_lab/config.toml @@ -24,3 +24,7 @@ host = 'localhost' port = 27017 username = '' password = '' + +[large_result_storage] +default_storage_type = "gridfs" + diff --git a/tests/fake_lab/devices/robot_arm.py b/tests/fake_lab/devices/robot_arm.py index 48e5cf03..10064038 100644 --- a/tests/fake_lab/devices/robot_arm.py +++ b/tests/fake_lab/devices/robot_arm.py @@ -1,3 +1,5 @@ +from importlib import util +from pathlib import Path from typing import ClassVar from alab_management.device_view import BaseDevice @@ -22,6 +24,16 @@ def emergent_stop(self): def run_program(self, program): pass + def get_most_recent_picture_location(self): + return ( + Path( + util.find_spec("alab_management").origin.split("__init__.py")[0] + ).parent + / "tests" + / "fake_lab" + / "large_file_example.zip" + ) + def is_running(self) -> bool: return False diff --git a/tests/fake_lab/large_file_example.zip b/tests/fake_lab/large_file_example.zip new file mode 100644 index 00000000..ca5a1e9c Binary files /dev/null and b/tests/fake_lab/large_file_example.zip differ diff --git a/tests/fake_lab/tasks/ending.py b/tests/fake_lab/tasks/ending.py index 7c77599b..41c284e5 100644 --- a/tests/fake_lab/tasks/ending.py +++ b/tests/fake_lab/tasks/ending.py @@ -4,7 +4,7 @@ class Ending(BaseTask): - def __init__(self, samples: list[ObjectId], *args, **kwargs): + def __init__(self, samples: list[str | ObjectId], *args, **kwargs): super().__init__(samples=samples,
*args, **kwargs) self.sample = samples[0] diff --git a/tests/fake_lab/tasks/error_handling_task.py b/tests/fake_lab/tasks/error_handling_task.py index eb8c6b96..6008d9d0 100644 --- a/tests/fake_lab/tasks/error_handling_task.py +++ b/tests/fake_lab/tasks/error_handling_task.py @@ -6,7 +6,7 @@ class ErrorHandlingUnrecoverable(BaseTask): - def __init__(self, samples: list[ObjectId], *args, **kwargs): + def __init__(self, samples: list[str | ObjectId], *args, **kwargs): super().__init__(samples=samples, *args, **kwargs) self.sample = samples[0] @@ -20,7 +20,7 @@ def run(self): class ErrorHandlingRecoverable(BaseTask): - def __init__(self, samples: list[ObjectId], *args, **kwargs): + def __init__(self, samples: list[str | ObjectId], *args, **kwargs): super().__init__(samples=samples, *args, **kwargs) self.sample = samples[0] diff --git a/tests/fake_lab/tasks/heating.py b/tests/fake_lab/tasks/heating.py index d140eb71..54547293 100644 --- a/tests/fake_lab/tasks/heating.py +++ b/tests/fake_lab/tasks/heating.py @@ -11,7 +11,7 @@ class Heating(BaseTask): def __init__( self, - samples: list[ObjectId], + samples: list[str | ObjectId], setpoints: list[tuple[float, float]], *args, **kwargs, @@ -19,7 +19,7 @@ def __init__( """Heating task. Args: - samples (list[ObjectId]): List of sample ids. + samples (list[str|ObjectId]): List of sample names or ids. setpoints (list[tuple[float, float]]): List of setpoints as the heating profile. Since it is a fake lab, the setpoints are just a list of tuples, each tuple contains two float numbers, the first number is the temperature, and the second number is the time in seconds. diff --git a/tests/fake_lab/tasks/infinite_task.py b/tests/fake_lab/tasks/infinite_task.py index abb700f0..d7fdb5b1 100644 --- a/tests/fake_lab/tasks/infinite_task.py +++ b/tests/fake_lab/tasks/infinite_task.py @@ -6,11 +6,11 @@ class InfiniteTask(BaseTask): - def __init__(self, samples: list[ObjectId], *args, **kwargs): + def __init__(self, samples: list[str | ObjectId], *args, **kwargs): """Infinite task. Args: - samples (list[ObjectId]): List of sample ids. + samples (list[str|ObjectId]): List of sample names or sample IDs. 
""" super().__init__(samples=samples, *args, **kwargs) self.sample = samples[0] diff --git a/tests/fake_lab/tasks/moving.py b/tests/fake_lab/tasks/moving.py index 9fbd0eb3..4648a829 100644 --- a/tests/fake_lab/tasks/moving.py +++ b/tests/fake_lab/tasks/moving.py @@ -8,7 +8,7 @@ class Moving(BaseTask): - def __init__(self, samples: list[ObjectId], dest: str, *args, **kwargs): + def __init__(self, samples: list[str | ObjectId], dest: str, *args, **kwargs): super().__init__(samples=samples, *args, **kwargs) self.sample = samples[0] self.dest = dest diff --git a/tests/fake_lab/tasks/starting.py b/tests/fake_lab/tasks/starting.py index 545fefc9..cea86069 100644 --- a/tests/fake_lab/tasks/starting.py +++ b/tests/fake_lab/tasks/starting.py @@ -6,7 +6,7 @@ class Starting(BaseTask): - def __init__(self, samples: list[ObjectId], dest: str, *args, **kwargs): + def __init__(self, samples: list[str | ObjectId], dest: str, *args, **kwargs): super().__init__(samples=samples, *args, **kwargs) self.sample = samples[0] self.dest = dest diff --git a/tests/fake_lab/tasks/take_picture.py b/tests/fake_lab/tasks/take_picture.py new file mode 100644 index 00000000..a534e6c7 --- /dev/null +++ b/tests/fake_lab/tasks/take_picture.py @@ -0,0 +1,68 @@ +import datetime + +from bson import ObjectId +from pydantic import BaseModel, ConfigDict, Field, field_validator + +from alab_management.task_view.task import BaseTask, LargeResult + +from ..devices.robot_arm import RobotArm # noqa + + +class TakePictureResult(BaseModel): + sample_name: str | None = None + sample_id: ObjectId | None = None + picture: LargeResult + timestamp: datetime.datetime = Field(default_factory=datetime.datetime.now) + + model_config = ConfigDict(arbitrary_types_allowed=True) + + @field_validator("timestamp", mode="before") + def validate_timestamp(cls, v): + if isinstance(v, str): + return datetime.datetime.fromisoformat(v) + return v + + +class TakePicture(BaseTask): + def __init__(self, samples: list[str | ObjectId], *args, **kwargs): + super().__init__(samples=samples, *args, **kwargs) + self.sample = samples[0] + + @property + def result_specification(self) -> BaseModel: + return TakePictureResult + + def run(self): + with self.lab_view.request_resources({RobotArm: {}}) as ( + devices, + sample_positions, + ): + robot_arm: RobotArm = devices[RobotArm] + robot_arm.run_program("take_picture.urp") + sample_id = self.lab_view.get_sample(self.sample).sample_id + picture_location = robot_arm.get_most_recent_picture_location() + picture = LargeResult(local_path=picture_location) + # by doing this, checking is done during running. + # this can lead to task being marked as failed if result does not meet the specification. 
+ return TakePictureResult( + sample_id=sample_id, picture=picture, timestamp=datetime.datetime.now() + ) + + +class TakePictureMissingResult(BaseTask): + def __init__(self, samples: list[str | ObjectId], *args, **kwargs): + super().__init__(samples=samples, *args, **kwargs) + self.sample = samples[0] + + @property + def result_specification(self) -> BaseModel: + return TakePictureResult + + def run(self): + with self.lab_view.request_resources({RobotArm: {}}) as ( + devices, + sample_positions, + ): + robot_arm: RobotArm = devices[RobotArm] + robot_arm.run_program("take_picture.urp") + return {"timestamp": datetime.datetime.now()} diff --git a/tests/test_task_actor.py b/tests/test_task_actor.py new file mode 100644 index 00000000..6abfd8f7 --- /dev/null +++ b/tests/test_task_actor.py @@ -0,0 +1,219 @@ +import subprocess +import time +import unittest +from importlib import util +from pathlib import Path + +import requests +from bson import ObjectId + +from alab_management.experiment_view import ExperimentView +from alab_management.scripts.cleanup_lab import cleanup_lab +from alab_management.scripts.setup_lab import setup_lab +from alab_management.task_view import TaskView +from alab_management.task_view.task import LargeResult + +SUBMISSION_API = "http://127.0.0.1:8896/api/experiment/submit" + + +class TestTaskActor(unittest.TestCase): + def setUp(self) -> None: + time.sleep(0.5) + cleanup_lab( + all_collections=True, + _force_i_know_its_dangerous=True, + sim_mode=True, + database_name="Alab_sim", + user_confirmation="y", + ) + setup_lab() + self.task_view = TaskView() + self.experiment_view = ExperimentView() + self.main_process = subprocess.Popen( + ["alabos", "launch", "--port", "8896"], shell=False + ) + self.worker_process = subprocess.Popen( + ["alabos", "launch_worker", "--processes", "8", "--threads", "16"], + shell=False, + ) + time.sleep(2) # waiting for starting up + + if self.main_process.poll() is not None: + raise RuntimeError("Main process failed to start") + if self.worker_process.poll() is not None: + raise RuntimeError("Worker process failed to start") + + def tearDown(self) -> None: + self.main_process.terminate() + self.worker_process.terminate() + time.sleep(5) + cleanup_lab( + all_collections=True, + _force_i_know_its_dangerous=True, + sim_mode=True, + database_name="Alab_sim", + user_confirmation="y", + ) + + def test_experiment_with_large_result(self): + # clean up lab and task collection + cleanup_lab( + all_collections=True, + _force_i_know_its_dangerous=True, + sim_mode=True, + database_name="Alab_sim", + user_confirmation="y", + ) + # setup lab + setup_lab() + + # run an experiment with large result + def compose_exp(exp_name): + return { + "name": exp_name, + "tags": [], + "metadata": {}, + "samples": [{"name": "test_sample", "tags": [], "metadata": {}}], + "tasks": [ + { + "type": "Starting", + "prev_tasks": [], + "parameters": { + "dest": "furnace_temp", + }, + "samples": ["test_sample"], + }, + { + "type": "TakePicture", + "prev_tasks": [0], + "parameters": {}, + "samples": ["test_sample"], + }, + ], + } + + exp_name = "Experiment with large result" + experiment = compose_exp(exp_name) + resp = requests.post(SUBMISSION_API, json=experiment) + resp_json = resp.json() + exp_id = ObjectId(resp_json["data"]["exp_id"]) + self.assertTrue("success", resp_json["status"]) + time.sleep(15) + # check if large result is stored successfully in database and can be retrieved + ## get the experiment + experiment = self.experiment_view.get_experiment(exp_id) + ## get the 
task + tasks = experiment["tasks"] + ## find the task with type "TakePicture" + task_id = next( + task["task_id"] for task in tasks if task["type"] == "TakePicture" + ) + task = self.task_view.get_task(task_id) + ## check if the result is stored correctly + self.assertTrue(task["result"]["picture"]["local_path"] is not None) + self.assertTrue(task["result"]["picture"]["storage_type"] == "gridfs") + self.assertTrue(task["result"]["picture"]["identifier"] is not None) + self.assertTrue(task["result"]["picture"]["file_like_data"] is None) + # try to retrieve the large result + self.assertTrue(LargeResult(**task["result"]["picture"]).check_if_stored()) + # read the zip file + file_path = ( + Path( + util.find_spec("alab_management").origin.split("__init__.py")[0] + ).parent + / "tests" + / "fake_lab" + / "large_file_example.zip" + ) + with open(file_path, "rb") as f: + file_content = f.read() + self.assertEqual( + LargeResult(**task["result"]["picture"]).retrieve(), file_content + ) + # clean up lab and task collection + cleanup_lab( + all_collections=True, + _force_i_know_its_dangerous=True, + sim_mode=True, + database_name="Alab_sim", + user_confirmation="y", + ) + # setup lab + setup_lab() + pass + + def test_incorrect_schema(self): + # clean up lab and task collection + cleanup_lab( + all_collections=True, + _force_i_know_its_dangerous=True, + sim_mode=True, + database_name="Alab_sim", + user_confirmation="y", + ) + # setup lab + setup_lab() + + # check if the result is not consistent with the schema + ## run an experiment with large result + def compose_exp(exp_name): + return { + "name": exp_name, + "tags": [], + "metadata": {}, + "samples": [{"name": "test_sample", "tags": [], "metadata": {}}], + "tasks": [ + { + "type": "Starting", + "prev_tasks": [], + "parameters": { + "dest": "furnace_temp", + }, + "samples": ["test_sample"], + }, + { + "type": "TakePictureMissingResult", + "prev_tasks": [0], + "parameters": {}, + "samples": ["test_sample"], + }, + ], + } + + exp_name = "Experiment with missing result" + experiment = compose_exp(exp_name) + resp = requests.post(SUBMISSION_API, json=experiment) + resp_json = resp.json() + exp_id = ObjectId(resp_json["data"]["exp_id"]) + self.assertEqual("success", resp_json["status"]) + time.sleep(15) + # check no exception is raised if the schema is not correct + ## get the experiment + experiment = self.experiment_view.get_experiment(exp_id) + ## get the task + tasks = experiment["tasks"] + ## find the task with type "TakePictureMissingResult" + task_id = next( + ( + task["task_id"] + for task in tasks + if task["type"] == "TakePictureMissingResult" + ) + ) + task = self.task_view.get_task(task_id) + ## check if the result is still stored correctly + self.assertTrue(task["result"]["timestamp"] is not None) + self.assertTrue(set(task["result"].keys()) == {"timestamp"}) + ## check that the task is completed + self.assertTrue(task["status"] == "COMPLETED") + # clean up lab and task collection + cleanup_lab( + all_collections=True, + _force_i_know_its_dangerous=True, + sim_mode=True, + database_name="Alab_sim", + user_confirmation="y", + ) + # setup lab + setup_lab() + pass
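A hedged sketch of how the [large_result_storage] default added in config.toml at the top of this diff interacts with LargeResult. The "s3" storage_type is hypothetical and used purely for illustration; as the diff notes, only "gridfs" is implemented today, so calling store() on the second object would raise.

from alab_management.task_view.task import LargeResult

picture = LargeResult(local_path="/tmp/sample_image.zip")  # uses default_storage_type ("gridfs")
future = LargeResult(storage_type="s3", local_path="/tmp/sample_image.zip")  # explicit override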
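A hedged usage sketch for reading a stored large result back after an experiment finishes, mirroring the round trip exercised in test_task_actor.py above; `task_id` is assumed to be the ObjectId of a completed TakePicture task.

from alab_management.task_view import TaskView
from alab_management.task_view.task import LargeResult

task_view = TaskView()
task = task_view.get_task(task_id=task_id)  # task_id assumed known
picture = LargeResult(**task["result"]["picture"])
assert picture.check_if_stored()
raw_bytes = picture.retrieve()  # bytes read back from GridFS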