diff --git a/alab_management/lab_view.py b/alab_management/lab_view.py index 01845b52..9daa30e7 100644 --- a/alab_management/lab_view.py +++ b/alab_management/lab_view.py @@ -354,6 +354,17 @@ def priority(self, priority: int): priority = priority.value self._priority = priority + def update_result(self, name: str, value: Any): + """ + Update a result of the task. This result will be saved in the task collection under `results.name` and can be + retrieved later. + + Args: name (str): name of the result (ie "diffraction pattern"). This will be used as the key in the results + dictionary. value (Any): value of the result. This can be a numpy array, a set, or any other + bson-serializable object (most standard Python types). + """ + self._task_view.update_result(task_id=self.task_id, name=name, value=value) + def request_cleanup(self): """Request cleanup of the task. This function will block until the task is cleaned up.""" all_reserved_sample_positions = self._sample_view.get_sample_positions_by_task( diff --git a/alab_management/task_view/task.py b/alab_management/task_view/task.py index b3c44c34..004e73a7 100644 --- a/alab_management/task_view/task.py +++ b/alab_management/task_view/task.py @@ -137,6 +137,89 @@ def result_specification(self) -> BaseModel: "The .result_specification method must be implemented by a subclass of BaseTask." ) + def update_result(self, key: str, value: Any): + """Attach a result to the task. This will be saved in the database and + can be accessed later. Subsequent calls to this function with the same + key will overwrite the previous value. + + Args: + key (str): The name of the result. + value (Any): The value of the result. + """ + if key not in self.result_specification: + raise ValueError( + f"Result key {key} is not included in the result specification for this task!" + ) + + # TODO type checking? + + if not self.__simulation: + self.lab_view.update_result(name=key, value=value) + + def export_result(self, key: str) -> dict: + """ + Creates a reference to a result generated by this Task. This + result can then be imported by another task. This is useful in + cases where tasks are chained together. For instance, the + diffraction results from a "PowderDiffraction" task could be + exported, then imported by a "RietveldRefinement" analysis task. + + Args: + key (str): The name of the result. + + Returns + ------- + Any: The value of the result. + """ + if key not in self.result_specification: + raise ValueError( + f"Result key {key} is not included in the result specification for this task!" + ) + + return ResultPointer(task_id=self.task_id, key=key).to_json() + + def import_result( + self, + pointer: ResultPointer | dict[str, Any], + allow_explicit_value: bool = False, + ) -> Any: + """ + Imports a result from another task. This is useful in cases where + tasks are chained together. For instance, the diffraction results from a + ``PowderDiffraction`` task could be exported, then imported by a "RietveldRefinement" + analysis task. + + Args: + pointer (Union[ResultPointer, Dict[str, Any]]): Either a ResultPointer object + or a dictionary with the same format as a ResultPointer. + allow_explicit_value (bool, optional): If true, users can pass values here instead + of pointers. If False, only ResultPointers are valid. Defaults to False. + + Raises + ------ + ValueError: If allow_explicit_value is False and the user passes a value instead of a pointer. + + Returns + ------- + Any: The value of the result. + """ + if isinstance(pointer, dict) and pointer.get("type", None) == "ResultPointer": + pointer = ResultPointer.from_json(pointer) + elif isinstance(pointer, ResultPointer): + pass # already in correct format + else: + if allow_explicit_value: + return pointer # user passed a specific value instead of a pointer + else: + raise ValueError( + f"Invalid pointer: {pointer}. This value was expected to be a pointer " + f"to an existing task result, but an explicit value was passed instead! " + f"If you want to allow explicit values, set allow_explicit_value=True." + ) + + reference_task = self.lab_view._task_view.get_task(task_id=pointer.task_id) + return reference_task["result"][pointer.key] + @priority.setter def priority(self, value: int | TaskPriority): if value < 0: