From 03e69517813f422be9205eac9718a54de9411d07 Mon Sep 17 00:00:00 2001 From: Tom Close Date: Thu, 5 Dec 2024 13:57:04 +1100 Subject: [PATCH] resolved gnarly circular imports --- pydra/design/__init__.py | 4 +- pydra/design/base.py | 156 +- pydra/design/python.py | 2 +- pydra/design/shell.py | 62 +- pydra/design/tests/test_python.py | 3 +- pydra/design/tests/test_shell.py | 6 +- pydra/design/tests/test_workflow.py | 58 +- pydra/design/workflow.py | 36 +- pydra/engine/__init__.py | 2 - pydra/engine/audit.py | 9 +- pydra/engine/core.py | 328 +-- pydra/engine/helpers.py | 67 +- pydra/engine/specs.py | 1761 ++++++++--------- pydra/engine/state.py | 9 +- pydra/engine/task.py | 12 +- pydra/engine/workflow/__init__.py | 0 pydra/engine/workflow/base.py | 178 ++ pydra/engine/workflow/lazy.py | 250 +++ .../engine/{workflow.py => workflow/node.py} | 109 +- pydra/utils/typing.py | 51 +- 20 files changed, 1653 insertions(+), 1450 deletions(-) create mode 100644 pydra/engine/workflow/__init__.py create mode 100644 pydra/engine/workflow/base.py create mode 100644 pydra/engine/workflow/lazy.py rename pydra/engine/{workflow.py => workflow/node.py} (83%) diff --git a/pydra/design/__init__.py b/pydra/design/__init__.py index 9b86627949..0cfe94caa9 100644 --- a/pydra/design/__init__.py +++ b/pydra/design/__init__.py @@ -1,6 +1,6 @@ -from .base import TaskSpec, list_fields from . import python from . import shell +from . import workflow -__all__ = ["TaskSpec", "list_fields", "python", "shell"] +__all__ = ["python", "shell", "workflow"] diff --git a/pydra/design/base.py b/pydra/design/base.py index 64e0ab8510..0fbf79ac82 100644 --- a/pydra/design/base.py +++ b/pydra/design/base.py @@ -5,33 +5,35 @@ import enum from pathlib import Path from copy import copy -from typing_extensions import Self import attrs.validators from attrs.converters import default_if_none from fileformats.generic import File from pydra.utils.typing import TypeParser, is_optional, is_fileset_or_union - -# from pydra.utils.misc import get_undefined_symbols -from pydra.engine.helpers import from_list_if_single, ensure_list -from pydra.engine.specs import ( - LazyField, +from pydra.engine.helpers import ( + from_list_if_single, + ensure_list, + PYDRA_ATTR_METADATA, + list_fields, +) +from pydra.utils.typing import ( MultiInputObj, MultiInputFile, MultiOutputObj, MultiOutputFile, ) -from pydra.engine.core import Task, AuditFlag +from pydra.engine.workflow.lazy import LazyField +if ty.TYPE_CHECKING: + from pydra.engine.specs import OutputsSpec + from pydra.engine.core import Task + __all__ = [ "Field", "Arg", "Out", - "TaskSpec", - "OutputsSpec", "ensure_field_objects", "make_task_spec", - "list_fields", ] RESERVED_OUTPUT_NAMES = ("split", "combine") @@ -154,120 +156,6 @@ class Out(Field): pass -class OutputsSpec: - """Base class for all output specifications""" - - def split( - self, - splitter: ty.Union[str, ty.List[str], ty.Tuple[str, ...], None] = None, - /, - overwrite: bool = False, - cont_dim: ty.Optional[dict] = None, - **inputs, - ) -> Self: - """ - Run this task parametrically over lists of split inputs. - - Parameters - ---------- - splitter : str or list[str] or tuple[str] or None - the fields which to split over. If splitting over multiple fields, lists of - fields are interpreted as outer-products and tuples inner-products. If None, - then the fields to split are taken from the keyword-arg names. 
- overwrite : bool, optional - whether to overwrite an existing split on the node, by default False - cont_dim : dict, optional - Container dimensions for specific inputs, used in the splitter. - If input name is not in cont_dim, it is assumed that the input values has - a container dimension of 1, so only the most outer dim will be used for splitting. - **inputs - fields to split over, will automatically be wrapped in a StateArray object - and passed to the node inputs - - Returns - ------- - self : TaskBase - a reference to the task - """ - self._node.split(splitter, overwrite=overwrite, cont_dim=cont_dim, **inputs) - return self - - def combine( - self, - combiner: ty.Union[ty.List[str], str], - overwrite: bool = False, # **kwargs - ) -> Self: - """ - Combine inputs parameterized by one or more previous tasks. - - Parameters - ---------- - combiner : list[str] or str - the field or list of inputs to be combined (i.e. not left split) after the - task has been run - overwrite : bool - whether to overwrite an existing combiner on the node - **kwargs : dict[str, Any] - values for the task that will be "combined" before they are provided to the - node - - Returns - ------- - self : Self - a reference to the outputs object - """ - self._node.combine(combiner, overwrite=overwrite) - return self - - -OutputType = ty.TypeVar("OutputType", bound=OutputsSpec) - - -class TaskSpec(ty.Generic[OutputType]): - """Base class for all task specifications""" - - Task: ty.Type[Task] - - def __call__( - self, - name: str | None = None, - audit_flags: AuditFlag = AuditFlag.NONE, - cache_dir=None, - cache_locations=None, - inputs: ty.Text | File | dict[str, ty.Any] | None = None, - cont_dim=None, - messenger_args=None, - messengers=None, - rerun=False, - **kwargs, - ): - self._check_for_unset_values() - task = self.Task( - self, - name=name, - audit_flags=audit_flags, - cache_dir=cache_dir, - cache_locations=cache_locations, - inputs=inputs, - cont_dim=cont_dim, - messenger_args=messenger_args, - messengers=messengers, - rerun=rerun, - ) - return task(**kwargs) - - def _check_for_unset_values(self): - if unset := [ - k - for k, v in attrs.asdict(self, recurse=False).items() - if v is attrs.NOTHING - ]: - raise ValueError( - f"The following values {unset} in the {self!r} interface need to be set " - "before the workflow can be constructed" - ) - - def extract_fields_from_class( klass: type, arg_type: type[Arg], @@ -352,7 +240,7 @@ def get_fields(klass, field_type, auto_attribs, helps) -> dict[str, Field]: def make_task_spec( - task_type: type[Task], + task_type: type["Task"], inputs: dict[str, Arg], outputs: dict[str, Out], klass: type | None = None, @@ -389,6 +277,8 @@ def make_task_spec( klass : type The class created using the attrs package """ + from pydra.engine.specs import TaskSpec + if name is None and klass is not None: name = klass.__name__ outputs_klass = make_outputs_spec(outputs, outputs_bases, name) @@ -457,7 +347,7 @@ def make_task_spec( def make_outputs_spec( outputs: dict[str, Out], bases: ty.Sequence[type], spec_name: str -) -> type[OutputsSpec]: +) -> type["OutputsSpec"]: """Create an outputs specification class and its outputs specification class from the output fields provided to the decorator/function. 
@@ -478,6 +368,8 @@ def make_outputs_spec( klass : type The class created using the attrs package """ + from pydra.engine.specs import OutputsSpec + if not any(issubclass(b, OutputsSpec) for b in bases): outputs_bases = bases + (OutputsSpec,) if reserved_names := [n for n in outputs if n in RESERVED_OUTPUT_NAMES]: @@ -880,16 +772,6 @@ def split_block(string: str) -> ty.Generator[str, None, None]: yield block.strip() -def list_fields(interface: TaskSpec) -> list[Field]: - if not attrs.has(interface): - return [] - return [ - f.metadata[PYDRA_ATTR_METADATA] - for f in attrs.fields(interface) - if PYDRA_ATTR_METADATA in f.metadata - ] - - def check_explicit_fields_are_none(klass, inputs, outputs): if inputs is not None: raise ValueError( @@ -918,5 +800,3 @@ def nothing_factory(): white_space_re = re.compile(r"\s+") - -PYDRA_ATTR_METADATA = "__PYDRA_METADATA__" diff --git a/pydra/design/python.py b/pydra/design/python.py index 154c5f8a6d..b25d36e010 100644 --- a/pydra/design/python.py +++ b/pydra/design/python.py @@ -2,12 +2,12 @@ import inspect import attrs from pydra.engine.task import FunctionTask +from pydra.engine.specs import TaskSpec from .base import ( Arg, Out, ensure_field_objects, make_task_spec, - TaskSpec, parse_doc_string, extract_function_inputs_and_outputs, check_explicit_fields_are_none, diff --git a/pydra/design/shell.py b/pydra/design/shell.py index 714f67e80f..21d5d435c9 100644 --- a/pydra/design/shell.py +++ b/pydra/design/shell.py @@ -11,20 +11,20 @@ from fileformats.core import from_mime from fileformats import generic from fileformats.core.exceptions import FormatRecognitionError +from pydra.engine.specs import TaskSpec from .base import ( Arg, Out, check_explicit_fields_are_none, extract_fields_from_class, ensure_field_objects, - TaskSpec, make_task_spec, EMPTY, ) -from pydra.utils.typing import is_fileset_or_union -from pydra.engine.specs import MultiInputObj +from pydra.utils.typing import is_fileset_or_union, MultiInputObj from pydra.engine.task import ShellCommandTask + __all__ = ["arg", "out", "outarg", "define"] @@ -180,7 +180,6 @@ class outarg(Out, arg): If provided, the field is treated also as an output field and it is added to the output spec. The template can use other fields, e.g. {file1}. Used in order to create an output specification. - """ path_template: str | None = attrs.field(default=None) @@ -204,7 +203,35 @@ def define( auto_attribs: bool = True, name: str | None = None, ) -> TaskSpec: - """Create a shell command interface + """Create a task specification for a shell command. Can be used either as a decorator on + the "canonical" dataclass-form of a task specification or as a function that takes a + "shell-command template string" of the form + + ``` + shell.define("command --output ") + ``` + + Fields are inferred from the template if not provided. In the template, inputs are + specified with `` and outputs with ``. + + ``` + my_command + ``` + + The types of the fields can be specified using their MIME like (see fileformats.core.from_mime), e.g. + + ``` + my_command + ``` + + The template can also specify options with `-` or `--` followed by the option name + and arguments with ``. The type is optional and will default to + `generic/fs-object` if not provided for arguments and `field/text` for + options. The file-formats namespace can be dropped for generic and field formats, e.g. 
+ + ``` + another-command --output + ``` Parameters ---------- @@ -221,6 +248,11 @@ def define( as they appear in the template name: str | None The name of the returned class + + Returns + ------- + TaskSpec + The interface for the shell command """ def make( @@ -331,9 +363,10 @@ def parse_command_line_template( outputs: list[str | Out] | dict[str, Out | type] | None = None, ) -> ty.Tuple[str, dict[str, Arg | type], dict[str, Out | type]]: """Parses a command line template into a name and input and output fields. Fields - are inferred from the template if not provided, where inputs are specified with `` - and outputs with ``. The types of the fields can be specified using their - MIME like (see fileformats.core.from_mime), e.g. + are inferred from the template if not explicitly provided. + + In the template, inputs are specified with `` and outputs with ``. + The types of the fields can be specified using their MIME like (see fileformats.core.from_mime), e.g. ``` my_command @@ -345,7 +378,7 @@ def parse_command_line_template( options. The file-formats namespace can be dropped for generic and field formats, e.g. ``` - another-command --output + another-command --output ``` Parameters @@ -365,6 +398,13 @@ def parse_command_line_template( The input fields of the command line template outputs : dict[str, Out | type] The output fields of the command line template + + Raises + ------ + ValueError + If an unknown token is found in the command line template + TypeError + If an unknown type is found in the command line template """ if isinstance(inputs, list): inputs = {arg.name: arg for arg in inputs} @@ -437,9 +477,9 @@ def from_type_str(type_str) -> type: try: type_ = from_mime(f"generic/{tp}") except FormatRecognitionError: - raise ValueError( + raise TypeError( f"Found unknown type, {tp!r}, in command template: {template!r}" - ) + ) from None types.append(type_) if len(types) == 2 and types[1] == "...": type_ = MultiInputObj[types[0]] diff --git a/pydra/design/tests/test_python.py b/pydra/design/tests/test_python.py index 8939539d58..54dcd0fda4 100644 --- a/pydra/design/tests/test_python.py +++ b/pydra/design/tests/test_python.py @@ -3,7 +3,8 @@ from decimal import Decimal import attrs import pytest -from pydra.design import list_fields, TaskSpec +from pydra.engine.helpers import list_fields +from pydra.engine.specs import TaskSpec from pydra.design import python from pydra.engine.task import FunctionTask diff --git a/pydra/design/tests/test_shell.py b/pydra/design/tests/test_shell.py index cf9ea7db8a..6d4dc3cac5 100644 --- a/pydra/design/tests/test_shell.py +++ b/pydra/design/tests/test_shell.py @@ -3,10 +3,12 @@ import attrs import pytest import cloudpickle as cp -from pydra.design import shell, TaskSpec, list_fields +from pydra.design import shell +from pydra.engine.helpers import list_fields +from pydra.engine.specs import TaskSpec from fileformats.generic import File, Directory, FsObject from fileformats import text, image -from pydra.engine.specs import MultiInputObj +from pydra.utils.typing import MultiInputObj def test_interface_template(): diff --git a/pydra/design/tests/test_workflow.py b/pydra/design/tests/test_workflow.py index 1502ce4f6a..9311ddb601 100644 --- a/pydra/design/tests/test_workflow.py +++ b/pydra/design/tests/test_workflow.py @@ -1,10 +1,12 @@ from operator import attrgetter import pytest import attrs -from pydra.engine.workflow import Workflow -from pydra.engine.specs import LazyInField, LazyOutField +from pydra.engine.workflow.base import Workflow +from 
pydra.engine.workflow.lazy import LazyInField, LazyOutField import typing as ty -from pydra.design import shell, python, workflow, list_fields, TaskSpec +from pydra.design import shell, python, workflow +from pydra.engine.helpers import list_fields +from pydra.engine.specs import TaskSpec from fileformats import video, image @@ -44,7 +46,7 @@ def MyTestWorkflow(a, b): wf = Workflow.construct(workflow_spec) assert wf.inputs.a == 1 assert wf.inputs.b == 2.0 - assert wf.outputs.out == LazyOutField(name="Mul", field="out", type=ty.Any) + assert wf.outputs.out == LazyOutField(node=wf["Mul"], field="out", type=ty.Any) # Nodes are named after the specs by default assert list(wf.node_names) == ["Add", "Mul"] @@ -107,7 +109,7 @@ def MyTestShellWorkflow( assert wf.inputs.input_video == input_video assert wf.inputs.watermark == watermark assert wf.outputs.output_video == LazyOutField( - name="resize", field="out_video", type=video.Mp4 + node=wf["resize"], field="out_video", type=video.Mp4 ) assert list(wf.node_names) == ["add_watermark", "resize"] @@ -168,7 +170,7 @@ class Outputs: wf = Workflow.construct(workflow_spec) assert wf.inputs.a == 1 assert wf.inputs.b == 2.0 - assert wf.outputs.out == LazyOutField(name="Mul", field="out", type=ty.Any) + assert wf.outputs.out == LazyOutField(node=wf["Mul"], field="out", type=ty.Any) # Nodes are named after the specs by default assert list(wf.node_names) == ["Add", "Mul"] @@ -218,10 +220,10 @@ def MyTestShellWorkflow( ) wf = Workflow.construct(workflow_spec) assert wf["add_watermark"].inputs.in_video == LazyInField( - field="input_video", type=video.Mp4 + node=wf, field="input_video", type=video.Mp4 ) assert wf["add_watermark"].inputs.watermark == LazyInField( - field="watermark", type=image.Png + node=wf, field="watermark", type=image.Png ) @@ -277,9 +279,9 @@ def MyTestWorkflow(a: int, b: float) -> tuple[float, float]: wf = Workflow.construct(workflow_spec) assert wf.inputs.a == 1 assert wf.inputs.b == 2.0 - assert wf.outputs.out1 == LazyOutField(name="Mul", field="out", type=float) + assert wf.outputs.out1 == LazyOutField(node=wf["Mul"], field="out", type=float) assert wf.outputs.out2 == LazyOutField( - name="division", field="divided", type=ty.Any + node=wf["division"], field="divided", type=ty.Any ) assert list(wf.node_names) == ["addition", "Mul", "division"] @@ -323,12 +325,12 @@ def MyTestWorkflow(a: int, b: float): wf = Workflow.construct(workflow_spec) assert wf.inputs.a == 1 assert wf.inputs.b == 2.0 - assert wf.outputs.out1 == LazyOutField(name="Mul", field="out", type=ty.Any) - assert wf.outputs.out2 == LazyOutField(name="Add", field="out", type=ty.Any) + assert wf.outputs.out1 == LazyOutField(node=wf["Mul"], field="out", type=ty.Any) + assert wf.outputs.out2 == LazyOutField(node=wf["Add"], field="out", type=ty.Any) assert list(wf.node_names) == ["Add", "Mul"] -def test_workflow_split_combine(): +def test_workflow_split_combine1(): @python.define def Mul(x: float, y: float) -> float: @@ -347,7 +349,35 @@ def MyTestWorkflow(a: list[int], b: list[float]) -> list[float]: wf = Workflow.construct(MyTestWorkflow(a=[1, 2, 3], b=[1.0, 10.0, 100.0])) assert wf["Mul"].splitter == ["Mul.x", "Mul.y"] assert wf["Mul"].combiner == ["Mul.x"] - assert wf.outputs.out == LazyOutField(name="Sum", field="out", type=list[float]) + assert wf.outputs.out == LazyOutField(node=wf["Sum"], field="out", type=list[float]) + + +def test_workflow_split_combine2(): + + @python.define + def Mul(x: float, y: float) -> float: + return x * y + + @python.define + def Add(x: float, 
y: float) -> float: + return x + y + + @python.define + def Sum(x: list[float]) -> float: + return sum(x) + + @workflow.define + def MyTestWorkflow(a: list[int], b: list[float], c: float) -> list[float]: + mul = workflow.add(Mul()).split(x=a, y=b) + add = workflow.add(Add(x=mul.out, y=c)).combine("Mul.x") + sum = workflow.add(Sum(x=add.out)) + return sum.out + + wf = Workflow.construct(MyTestWorkflow(a=[1, 2, 3], b=[1.0, 10.0, 100.0], c=2.0)) + assert wf["Mul"].splitter == ["Mul.x", "Mul.y"] + assert wf["Mul"].combiner == ["Mul.x"] + assert wf["Add"].lzout.out.splits == frozenset(["Mul.x"]) + assert wf.outputs.out == LazyOutField(node=wf["Sum"], field="out", type=list[float]) def test_workflow_split_after_access_fail(): diff --git a/pydra/design/workflow.py b/pydra/design/workflow.py index bf804944ed..75ac13197f 100644 --- a/pydra/design/workflow.py +++ b/pydra/design/workflow.py @@ -2,18 +2,18 @@ import inspect import attrs from pydra.engine.core import WorkflowTask -from pydra.engine.workflow import Workflow +from pydra.engine.workflow.base import Workflow from .base import ( Arg, Out, ensure_field_objects, make_task_spec, - TaskSpec, parse_doc_string, extract_function_inputs_and_outputs, check_explicit_fields_are_none, extract_fields_from_class, ) +from pydra.engine.specs import TaskSpec __all__ = ["define", "add", "this", "arg", "out"] @@ -94,7 +94,8 @@ def define( auto_attribs: bool = True, ) -> TaskSpec: """ - Create an interface for a function or a class. + Create an interface for a function or a class. Can be used either as a decorator on + a constructor function or the "canonical" dataclass-form of a task specification. Parameters ---------- @@ -106,6 +107,11 @@ def define( The outputs of the function or class. auto_attribs : bool Whether to use auto_attribs mode when creating the class. + + Returns + ------- + TaskSpec + The interface for the function or class. """ if lazy is None: lazy = [] @@ -170,10 +176,30 @@ def make(wrapped: ty.Callable | type) -> TaskSpec: def this() -> Workflow: - """Get the workflow currently being constructed.""" + """Get the workflow currently being constructed. + + Returns + ------- + Workflow + The workflow currently being constructed. 
+ """ return Workflow.under_construction def add(task_spec: TaskSpec[OutputType], name: str = None) -> OutputType: - """Add a task to the current workflow.""" + """Add a node to the workflow currently being constructed + + Parameters + ---------- + task_spec : TaskSpec + The specification of the task to add to the workflow as a node + name : str, optional + The name of the node, by default it will be the name of the task specification + class + + Returns + ------- + OutputType + The outputs specification of the node + """ return this().add(task_spec, name=name) diff --git a/pydra/engine/__init__.py b/pydra/engine/__init__.py index 6fbd7a0063..24ada3c366 100644 --- a/pydra/engine/__init__.py +++ b/pydra/engine/__init__.py @@ -1,12 +1,10 @@ """The core of the workflow engine.""" -from .submitter import Submitter import __main__ import logging from ._version import __version__ __all__ = [ - "Submitter", "logger", "check_latest_version", ] diff --git a/pydra/engine/audit.py b/pydra/engine/audit.py index 2db771da65..8d5695e4e4 100644 --- a/pydra/engine/audit.py +++ b/pydra/engine/audit.py @@ -4,10 +4,8 @@ import json import attr from pydra.utils.messenger import send_message, make_message, gen_uuid, now, AuditFlag -from pydra.utils.hash import hash_function -from .helpers import ensure_list, gather_runtime_info -from .specs import attr_fields from fileformats.core import FileSet +from pydra.utils.hash import hash_function try: import importlib_resources @@ -36,6 +34,8 @@ def __init__(self, audit_flags, messengers, messenger_args, develop=None): If True, the local context.jsonld file is used, otherwise the one from github is used. """ + from .helpers import ensure_list + self.audit_flags = audit_flags self.messengers = ensure_list(messengers) self.messenger_args = messenger_args @@ -93,6 +93,8 @@ def monitor(self): def finalize_audit(self, result): """End auditing.""" if self.audit_check(AuditFlag.RESOURCE): + from .helpers import gather_runtime_info + self.resource_monitor.stop() result.runtime = gather_runtime_info(self.resource_monitor.fname) if self.audit_check(AuditFlag.PROV): @@ -178,6 +180,7 @@ def audit_check(self, flag): def audit_task(self, task): import subprocess as sp + from .helpers import attr_fields label = task.name diff --git a/pydra/engine/core.py b/pydra/engine/core.py index 13123cb02a..4607e23f71 100644 --- a/pydra/engine/core.py +++ b/pydra/engine/core.py @@ -21,20 +21,18 @@ from . import helpers_state as hlpst from .specs import ( File, - BaseSpec, + # BaseSpec, RuntimeSpec, Result, - SpecInfo, - LazyIn, - LazyOut, - LazyField, + # SpecInfo, + # LazyIn, + # LazyOut, TaskHook, - attr_fields, - StateArray, ) from .helpers import ( # make_klass, create_checksum, + attr_fields, print_help, load_result, save, @@ -592,140 +590,6 @@ def _collect_outputs(self, output_dir): ) return attr.evolve(output, **self.output_, **other_output) - def split( - self, - splitter: ty.Union[str, ty.List[str], ty.Tuple[str, ...], None] = None, - overwrite: bool = False, - cont_dim: ty.Optional[dict] = None, - **inputs, - ): - """ - Run this task parametrically over lists of split inputs. - - Parameters - ---------- - splitter : str or list[str] or tuple[str] or None - the fields which to split over. If splitting over multiple fields, lists of - fields are interpreted as outer-products and tuples inner-products. If None, - then the fields to split are taken from the keyword-arg names. 
- overwrite : bool, optional - whether to overwrite an existing split on the node, by default False - cont_dim : dict, optional - Container dimensions for specific inputs, used in the splitter. - If input name is not in cont_dim, it is assumed that the input values has - a container dimension of 1, so only the most outer dim will be used for splitting. - **split_inputs - fields to split over, will automatically be wrapped in a StateArray object - and passed to the node inputs - - Returns - ------- - self : TaskBase - a reference to the task - """ - if self._lzout: - raise RuntimeError( - f"Cannot split {self} as its output interface has already been accessed" - ) - if splitter is None and inputs: - splitter = list(inputs) - elif splitter: - missing = set(hlpst.unwrap_splitter(splitter)) - set(inputs) - missing = [m for m in missing if not m.startswith("_")] - if missing: - raise ValueError( - f"Split is missing values for the following fields {list(missing)}" - ) - splitter = hlpst.add_name_splitter(splitter, self.name) - # if user want to update the splitter, overwrite has to be True - if self.state and not overwrite and self.state.splitter != splitter: - raise Exception( - "splitter has been already set, " - "if you want to overwrite it - use overwrite=True" - ) - if cont_dim: - for key, vel in cont_dim.items(): - self._cont_dim[f"{self.name}.{key}"] = vel - if inputs: - new_inputs = {} - split_inputs = set( - f"{self.name}.{n}" if "." not in n else n - for n in hlpst.unwrap_splitter(splitter) - if not n.startswith("_") - ) - for inpt_name, inpt_val in inputs.items(): - new_val: ty.Any - if f"{self.name}.{inpt_name}" in split_inputs: # type: ignore - if isinstance(inpt_val, LazyField): - new_val = inpt_val.split(splitter) - elif isinstance(inpt_val, ty.Iterable) and not isinstance( - inpt_val, (ty.Mapping, str) - ): - new_val = StateArray(inpt_val) - else: - raise TypeError( - f"Could not split {inpt_val} as it is not a sequence type" - ) - else: - new_val = inpt_val - new_inputs[inpt_name] = new_val - self.inputs = attr.evolve(self.inputs, **new_inputs) - if not self.state or splitter != self.state.splitter: - self.set_state(splitter) - return self - - def combine( - self, - combiner: ty.Union[ty.List[str], str], - overwrite: bool = False, # **kwargs - ): - """ - Combine inputs parameterized by one or more previous tasks. 
- - Parameters - ---------- - combiner : list[str] or str - the - overwrite : bool - whether to overwrite an existing combiner on the node - **kwargs : dict[str, Any] - values for the task that will be "combined" before they are provided to the - node - - Returns - ------- - self : TaskBase - a reference to the task - """ - if self._lzout: - raise RuntimeError( - f"Cannot combine {self} as its output interface has already been " - "accessed" - ) - if not isinstance(combiner, (str, list)): - raise Exception("combiner has to be a string or a list") - combiner = hlpst.add_name_combiner(ensure_list(combiner), self.name) - if ( - self.state - and self.state.combiner - and combiner != self.state.combiner - and not overwrite - ): - raise Exception( - "combiner has been already set, " - "if you want to overwrite it - use overwrite=True" - ) - if not self.state: - self.split(splitter=None) - # a task can have a combiner without a splitter - # if is connected to one with a splitter; - # self.fut_combiner will be used later as a combiner - self.fut_combiner = combiner - else: # self.state and not self.state.combiner - self.combiner = combiner - self.set_state(splitter=self.state.splitter, combiner=self.combiner) - return self - def _extract_input_el(self, inputs, inp_nm, ind): """ Extracting element of the inputs taking into account @@ -955,13 +819,11 @@ def _check_for_hash_changes(self): def _sanitize_spec( - spec: ty.Union[ - SpecInfo, ty.List[str], ty.Dict[str, ty.Type[ty.Any]], BaseSpec, None - ], + spec: ty.Union[ty.List[str], ty.Dict[str, ty.Type[ty.Any]], None], wf_name: str, spec_name: str, allow_empty: bool = False, -) -> SpecInfo: +): """Makes sure the provided input specifications are valid. If the input specification is a list of strings, this will @@ -1040,14 +902,12 @@ def __init__( cache_dir=None, cache_locations=None, input_spec: ty.Optional[ - ty.Union[ty.List[ty.Text], ty.Dict[ty.Text, ty.Type[ty.Any]], SpecInfo] + ty.Union[ty.List[ty.Text], ty.Dict[ty.Text, ty.Type[ty.Any]]] ] = None, cont_dim=None, messenger_args=None, messengers=None, - output_spec: ty.Optional[ - ty.Union[ty.List[str], ty.Dict[str, type], SpecInfo, BaseSpec] - ] = None, + output_spec: ty.Optional[ty.Union[ty.List[str], ty.Dict[str, type]]] = None, rerun=False, propagate_rerun=True, **kwargs, @@ -1338,91 +1198,91 @@ async def _run_task(self, submitter, rerun=False, environment=None): # at this point Workflow is stateless so this should be fine await submitter.expand_workflow(self, rerun=rerun) - def set_output( - self, - connections: ty.Union[ - ty.Tuple[str, LazyField], ty.List[ty.Tuple[str, LazyField]] - ], - ): - """ - Set outputs of the workflow by linking them with lazy outputs of tasks - - Parameters - ---------- - connections : tuple[str, LazyField] or list[tuple[str, LazyField]] or None - single or list of tuples linking the name of the output to a lazy output - of a task in the workflow. 
- """ - from pydra.utils.typing import TypeParser - - if self._connections is None: - self._connections = [] - if isinstance(connections, tuple) and len(connections) == 2: - new_connections = [connections] - elif isinstance(connections, list) and all( - [len(el) == 2 for el in connections] - ): - new_connections = connections - elif isinstance(connections, dict): - new_connections = list(connections.items()) - else: - raise TypeError( - "Connections can be a 2-elements tuple, a list of these tuples, or dictionary" - ) - # checking if a new output name is already in the connections - connection_names = [name for name, _ in self._connections] - if self.output_spec: - output_types = {a.name: a.type for a in attr.fields(self.interface.Outputs)} - else: - output_types = {} - # Check for type matches with explicitly defined outputs - conflicting = [] - type_mismatches = [] - for conn_name, lazy_field in new_connections: - if conn_name in connection_names: - conflicting.append(conn_name) - try: - output_type = output_types[conn_name] - except KeyError: - pass - else: - if not TypeParser.matches_type(lazy_field.type, output_type): - type_mismatches.append((conn_name, output_type, lazy_field.type)) - if conflicting: - raise ValueError(f"the output names {conflicting} are already set") - if type_mismatches: - raise TypeError( - f"the types of the following outputs of {self} don't match their declared types: " - + ", ".join( - f"{n} (expected: {ex}, provided: {p})" - for n, ex, p in type_mismatches - ) - ) - self._connections += new_connections - fields = [] - for con in self._connections: - wf_out_nm, lf = con - task_nm, task_out_nm = lf.name, lf.field - if task_out_nm == "all_": - help_string = f"all outputs from {task_nm}" - fields.append((wf_out_nm, dict, {"help_string": help_string})) - else: - from pydra.utils.typing import TypeParser - - # getting information about the output field from the task output_spec - # providing proper type and some help string - task_output_spec = getattr(self, task_nm).output_spec - out_fld = attr.fields_dict(task_output_spec)[task_out_nm] - help_string = ( - f"{out_fld.metadata.get('help_string', '')} (from {task_nm})" - ) - if TypeParser.get_origin(lf.type) is StateArray: - type_ = TypeParser.get_item_type(lf.type) - else: - type_ = lf.type - fields.append((wf_out_nm, type_, {"help_string": help_string})) - self.output_spec = SpecInfo(name="Output", fields=fields, bases=(BaseSpec,)) - logger.info("Added %s to %s", self.output_spec, self) + # def set_output( + # self, + # connections: ty.Union[ + # ty.Tuple[str, LazyField], ty.List[ty.Tuple[str, LazyField]] + # ], + # ): + # """ + # Set outputs of the workflow by linking them with lazy outputs of tasks + + # Parameters + # ---------- + # connections : tuple[str, LazyField] or list[tuple[str, LazyField]] or None + # single or list of tuples linking the name of the output to a lazy output + # of a task in the workflow. 
+ # """ + # from pydra.utils.typing import TypeParser + + # if self._connections is None: + # self._connections = [] + # if isinstance(connections, tuple) and len(connections) == 2: + # new_connections = [connections] + # elif isinstance(connections, list) and all( + # [len(el) == 2 for el in connections] + # ): + # new_connections = connections + # elif isinstance(connections, dict): + # new_connections = list(connections.items()) + # else: + # raise TypeError( + # "Connections can be a 2-elements tuple, a list of these tuples, or dictionary" + # ) + # # checking if a new output name is already in the connections + # connection_names = [name for name, _ in self._connections] + # if self.output_spec: + # output_types = {a.name: a.type for a in attr.fields(self.interface.Outputs)} + # else: + # output_types = {} + # # Check for type matches with explicitly defined outputs + # conflicting = [] + # type_mismatches = [] + # for conn_name, lazy_field in new_connections: + # if conn_name in connection_names: + # conflicting.append(conn_name) + # try: + # output_type = output_types[conn_name] + # except KeyError: + # pass + # else: + # if not TypeParser.matches_type(lazy_field.type, output_type): + # type_mismatches.append((conn_name, output_type, lazy_field.type)) + # if conflicting: + # raise ValueError(f"the output names {conflicting} are already set") + # if type_mismatches: + # raise TypeError( + # f"the types of the following outputs of {self} don't match their declared types: " + # + ", ".join( + # f"{n} (expected: {ex}, provided: {p})" + # for n, ex, p in type_mismatches + # ) + # ) + # self._connections += new_connections + # fields = [] + # for con in self._connections: + # wf_out_nm, lf = con + # task_nm, task_out_nm = lf.name, lf.field + # if task_out_nm == "all_": + # help_string = f"all outputs from {task_nm}" + # fields.append((wf_out_nm, dict, {"help_string": help_string})) + # else: + # from pydra.utils.typing import TypeParser + + # # getting information about the output field from the task output_spec + # # providing proper type and some help string + # task_output_spec = getattr(self, task_nm).output_spec + # out_fld = attr.fields_dict(task_output_spec)[task_out_nm] + # help_string = ( + # f"{out_fld.metadata.get('help_string', '')} (from {task_nm})" + # ) + # if TypeParser.get_origin(lf.type) is StateArray: + # type_ = TypeParser.get_item_type(lf.type) + # else: + # type_ = lf.type + # fields.append((wf_out_nm, type_, {"help_string": help_string})) + # self.output_spec = SpecInfo(name="Output", fields=fields, bases=(BaseSpec,)) + # logger.info("Added %s to %s", self.output_spec, self) def _collect_outputs(self): output_klass = self.interface.Outputs diff --git a/pydra/engine/helpers.py b/pydra/engine/helpers.py index f443a6fe69..92efc9de53 100644 --- a/pydra/engine/helpers.py +++ b/pydra/engine/helpers.py @@ -7,29 +7,47 @@ import sys from uuid import uuid4 import getpass +import typing as ty import subprocess as sp import re from time import strftime from traceback import format_exception -import attr +import attrs from filelock import SoftFileLock, Timeout import cloudpickle as cp -from .specs import ( - Runtime, - attr_fields, - Result, - LazyField, -) from .helpers_file import copy_nested_files from fileformats.core import FileSet +if ty.TYPE_CHECKING: + from .specs import TaskSpec + from pydra.design.base import Field + + +PYDRA_ATTR_METADATA = "__PYDRA_METADATA__" + + +def attr_fields(spec, exclude_names=()): + return [field for field in spec.__attrs_attrs__ if 
field.name not in exclude_names] + + +def list_fields(interface: "TaskSpec") -> list["Field"]: + if not attrs.has(interface): + return [] + return [ + f.metadata[PYDRA_ATTR_METADATA] + for f in attrs.fields(interface) + if PYDRA_ATTR_METADATA in f.metadata + ] + # from .specs import MultiInputFile, MultiInputObj, MultiOutputObj, MultiOutputFile def from_list_if_single(obj): """Converts a list to a single item if it is of length == 1""" - if obj is attr.NOTHING: + from pydra.engine.workflow.lazy import LazyField + + if obj is attrs.NOTHING: return obj if isinstance(obj, LazyField): return obj @@ -42,11 +60,11 @@ def from_list_if_single(obj): def print_help(obj): """Visit a task object and print its input/output interface.""" lines = [f"Help for {obj.__class__.__name__}"] - if attr.fields(obj.interface): + if attrs.fields(obj.interface): lines += ["Input Parameters:"] - for f in attr.fields(obj.interface): + for f in attrs.fields(obj.interface): default = "" - if f.default != attr.NOTHING and not f.name.startswith("_"): + if f.default != attrs.NOTHING and not f.name.startswith("_"): default = f" (default: {f.default})" try: name = f.type.__name__ @@ -54,9 +72,9 @@ def print_help(obj): name = str(f.type) lines += [f"- {f.name}: {name}{default}"] output_klass = obj.interface.Outputs - if attr.fields(output_klass): + if attrs.fields(output_klass): lines += ["Output Parameters:"] - for f in attr.fields(output_klass): + for f in attrs.fields(output_klass): try: name = f.type.__name__ except AttributeError: @@ -154,6 +172,8 @@ def gather_runtime_info(fname): A runtime object containing the collected information. """ + from .specs import Runtime + runtime = Runtime(rss_peak_gb=None, vms_peak_gb=None, cpu_peak_percent=None) # Read .prof file in and set runtime values @@ -370,9 +390,9 @@ def get_open_loop(): # TODO # """ -# current_output_spec_names = [f.name for f in attr.fields(interface.Outputs)] +# current_output_spec_names = [f.name for f in attrs.fields(interface.Outputs)] # new_fields = [] -# for fld in attr.fields(interface): +# for fld in attrs.fields(interface): # if "output_file_template" in fld.metadata: # if "output_field_name" in fld.metadata: # field_name = fld.metadata["output_field_name"] @@ -382,7 +402,7 @@ def get_open_loop(): # if field_name not in current_output_spec_names: # # TODO: should probably remove some of the keys # new_fields.append( -# (field_name, attr.ib(type=File, metadata=fld.metadata)) +# (field_name, attrs.field(type=File, metadata=fld.metadata)) # ) # output_spec.fields += new_fields # return output_spec @@ -423,6 +443,9 @@ def load_and_run( loading a task from a pickle file, settings proper input and running the task """ + + from .specs import Result + try: task = load_task(task_pkl=task_pkl, ind=ind) except Exception: @@ -470,7 +493,7 @@ def load_task(task_pkl, ind=None): task = cp.loads(task_pkl.read_bytes()) if ind is not None: ind_inputs = task.get_input_el(ind) - task.inputs = attr.evolve(task.inputs, **ind_inputs) + task.inputs = attrs.evolve(task.inputs, **ind_inputs) task._pre_split = True task.state = None # resetting uid for task @@ -540,7 +563,7 @@ async def __aexit__(self, exc_type, exc_value, traceback): return None -def parse_copyfile(fld: attr.Attribute, default_collation=FileSet.CopyCollation.any): +def parse_copyfile(fld: attrs.Attribute, default_collation=FileSet.CopyCollation.any): """Gets the copy mode from the 'copyfile' value from a field attribute""" copyfile = fld.metadata.get("copyfile", FileSet.CopyMode.any) if isinstance(copyfile, 
tuple): @@ -580,7 +603,7 @@ def parse_format_string(fmtstr): identifier = r"[a-zA-Z_]\w*" attribute = rf"\.{identifier}" item = r"\[\w+\]" - # Example: var.attr[key][0].attr2 (capture "var") + # Example: var.attrs[key][0].attr2 (capture "var") field_with_lookups = ( f"({identifier})(?:{attribute}|{item})*" # Capture only the keyword ) @@ -614,8 +637,10 @@ def ensure_list(obj, tuple2list=False): [5.0] """ - if obj is attr.NOTHING: - return attr.NOTHING + from pydra.engine.workflow.lazy import LazyField + + if obj is attrs.NOTHING: + return attrs.NOTHING if obj is None: return [] # list or numpy.array (this might need some extra flag in case an array has to be converted) diff --git a/pydra/engine/specs.py b/pydra/engine/specs.py index f1430aad41..0cdc4f07f2 100644 --- a/pydra/engine/specs.py +++ b/pydra/engine/specs.py @@ -2,198 +2,166 @@ from pathlib import Path import typing as ty -import inspect -import re -import os -from copy import copy -from glob import glob -import attr -from typing_extensions import Self -from fileformats.core import FileSet -from fileformats.generic import ( - File, - Directory, -) -import pydra.engine -from .helpers_file import template_update_single -from pydra.utils.hash import hash_function, Cache - -# from pydra.utils.misc import add_exc_note - - -T = ty.TypeVar("T") - - -def attr_fields(spec, exclude_names=()): - return [field for field in spec.__attrs_attrs__ if field.name not in exclude_names] - - -# These are special types that are checked for in the construction of input/output specs -# and special converters inserted into the attrs fields. - - -class MultiInputObj(list, ty.Generic[T]): - pass - - -MultiInputFile = MultiInputObj[File] - -# Since we can't create a NewType from a type union, we add a dummy type to the union -# so we can detect the MultiOutput in the input/output spec creation -class MultiOutputType: - pass +# import inspect +# import re +# import os +from pydra.engine.audit import AuditFlag +# from glob import glob +import attrs +from typing_extensions import Self -MultiOutputObj = ty.Union[list, object, MultiOutputType] -MultiOutputFile = ty.Union[File, ty.List[File], MultiOutputType] - -OUTPUT_TEMPLATE_TYPES = ( - Path, - ty.List[Path], - ty.Union[Path, bool], - ty.Union[ty.List[Path], bool], - ty.List[ty.List[Path]], +# from fileformats.core import FileSet +from fileformats.generic import ( + File, + # Directory, ) +from .helpers import attr_fields +# from .helpers_file import template_update_single +# from pydra.utils.hash import hash_function, Cache -@attr.s(auto_attribs=True, kw_only=True) -class SpecInfo: - """Base data structure for metadata of specifications.""" - - name: str - """A name for the specification.""" - fields: ty.List[ty.Tuple] = attr.ib(factory=list) - """List of names of fields (can be inputs or outputs).""" - bases: ty.Sequence[ty.Type["BaseSpec"]] = attr.ib(factory=tuple) - """Keeps track of specification inheritance. 
- Should be a tuple containing at least one BaseSpec """ - - -@attr.s(auto_attribs=True, kw_only=True) -class BaseSpec: - """The base dataclass specs for all inputs and outputs.""" - - def collect_additional_outputs(self, inputs, output_dir, outputs): - """Get additional outputs.""" - return {} - - @property - def hash(self): - hsh, self._hashes = self._compute_hashes() - return hsh - - def hash_changes(self): - """Detects any changes in the hashed values between the current inputs and the - previously calculated values""" - _, new_hashes = self._compute_hashes() - return [k for k, v in new_hashes.items() if v != self._hashes[k]] - - def _compute_hashes(self) -> ty.Tuple[bytes, ty.Dict[str, bytes]]: - """Compute a basic hash for any given set of fields.""" - inp_dict = {} - for field in attr_fields( - self, exclude_names=("_graph_checksums", "bindings", "files_hash") - ): - if field.metadata.get("output_file_template"): - continue - # removing values that are not set from hash calculation - if getattr(self, field.name) is attr.NOTHING: - continue - if "container_path" in field.metadata: - continue - inp_dict[field.name] = getattr(self, field.name) - hash_cache = Cache() - field_hashes = { - k: hash_function(v, cache=hash_cache) for k, v in inp_dict.items() - } - if hasattr(self, "_graph_checksums"): - field_hashes["_graph_checksums"] = self._graph_checksums - return hash_function(sorted(field_hashes.items())), field_hashes - - def retrieve_values(self, wf, state_index: ty.Optional[int] = None): - """Get values contained by this spec.""" - retrieved_values = {} - for field in attr_fields(self): - value = getattr(self, field.name) - if isinstance(value, LazyField): - retrieved_values[field.name] = value.get_value( - wf, state_index=state_index - ) - for field, val in retrieved_values.items(): - setattr(self, field, val) - - def check_fields_input_spec(self): - """ - Check fields from input spec based on the medatada. +# from pydra.utils.misc import add_exc_note - e.g., if xor, requires are fulfilled, if value provided when mandatory. - """ - fields = attr_fields(self) - - for field in fields: - field_is_mandatory = bool(field.metadata.get("mandatory")) - field_is_unset = getattr(self, field.name) is attr.NOTHING - - if field_is_unset and not field_is_mandatory: - continue - - # Collect alternative fields associated with this field. - alternative_fields = { - name: getattr(self, name) is not attr.NOTHING - for name in field.metadata.get("xor", []) - if name != field.name - } - alternatives_are_set = any(alternative_fields.values()) - - # Raise error if no field in mandatory alternative group is set. - if field_is_unset: - if alternatives_are_set: - continue - message = f"{field.name} is mandatory and unset." - if alternative_fields: - raise AttributeError( - message[:-1] - + f", but no alternative provided by {list(alternative_fields)}." - ) - else: - raise AttributeError(message) - - # Raise error if multiple alternatives are set. - elif alternatives_are_set: - set_alternative_fields = [ - name for name, is_set in alternative_fields.items() if is_set - ] - raise AttributeError( - f"{field.name} is mutually exclusive with {set_alternative_fields}" - ) - - # Collect required fields associated with this field. - required_fields = { - name: getattr(self, name) is not attr.NOTHING - for name in field.metadata.get("requires", []) - if name != field.name - } - - # Raise error if any required field is unset. 
- if not all(required_fields.values()): - unset_required_fields = [ - name for name, is_set in required_fields.items() if not is_set - ] - raise AttributeError(f"{field.name} requires {unset_required_fields}") - - def check_metadata(self): - """Check contained metadata.""" - - def template_update(self): - """Update template.""" - - def copyfile_input(self, output_dir): - """Copy the file pointed by a :class:`File` input.""" - - -@attr.s(auto_attribs=True, kw_only=True) +# @attrs.define(auto_attribs=True, kw_only=True) +# class SpecInfo: +# """Base data structure for metadata of specifications.""" + +# name: str +# """A name for the specification.""" +# fields: ty.List[ty.Tuple] = attrs.field(factory=list) +# """List of names of fields (can be inputs or outputs).""" +# bases: ty.Sequence[ty.Type["BaseSpec"]] = attrs.field(factory=tuple) +# """Keeps track of specification inheritance. +# Should be a tuple containing at least one BaseSpec """ + + +# @attrs.define(auto_attribs=True, kw_only=True) +# class BaseSpec: +# """The base dataclass specs for all inputs and outputs.""" + +# def collect_additional_outputs(self, inputs, output_dir, outputs): +# """Get additional outputs.""" +# return {} + +# @property +# def hash(self): +# hsh, self._hashes = self._compute_hashes() +# return hsh + +# def hash_changes(self): +# """Detects any changes in the hashed values between the current inputs and the +# previously calculated values""" +# _, new_hashes = self._compute_hashes() +# return [k for k, v in new_hashes.items() if v != self._hashes[k]] + +# def _compute_hashes(self) -> ty.Tuple[bytes, ty.Dict[str, bytes]]: +# """Compute a basic hash for any given set of fields.""" +# inp_dict = {} +# for field in attr_fields( +# self, exclude_names=("_graph_checksums", "bindings", "files_hash") +# ): +# if field.metadata.get("output_file_template"): +# continue +# # removing values that are not set from hash calculation +# if getattr(self, field.name) is attrs.NOTHING: +# continue +# if "container_path" in field.metadata: +# continue +# inp_dict[field.name] = getattr(self, field.name) +# hash_cache = Cache() +# field_hashes = { +# k: hash_function(v, cache=hash_cache) for k, v in inp_dict.items() +# } +# if hasattr(self, "_graph_checksums"): +# field_hashes["_graph_checksums"] = self._graph_checksums +# return hash_function(sorted(field_hashes.items())), field_hashes + +# def retrieve_values(self, wf, state_index: ty.Optional[int] = None): +# """Get values contained by this spec.""" +# retrieved_values = {} +# for field in attr_fields(self): +# value = getattr(self, field.name) +# if isinstance(value, LazyField): +# retrieved_values[field.name] = value.get_value( +# wf, state_index=state_index +# ) +# for field, val in retrieved_values.items(): +# setattr(self, field, val) + +# def check_fields_input_spec(self): +# """ +# Check fields from input spec based on the medatada. + +# e.g., if xor, requires are fulfilled, if value provided when mandatory. + +# """ +# fields = attr_fields(self) + +# for field in fields: +# field_is_mandatory = bool(field.metadata.get("mandatory")) +# field_is_unset = getattr(self, field.name) is attrs.NOTHING + +# if field_is_unset and not field_is_mandatory: +# continue + +# # Collect alternative fields associated with this field. 
+# alternative_fields = { +# name: getattr(self, name) is not attrs.NOTHING +# for name in field.metadata.get("xor", []) +# if name != field.name +# } +# alternatives_are_set = any(alternative_fields.values()) + +# # Raise error if no field in mandatory alternative group is set. +# if field_is_unset: +# if alternatives_are_set: +# continue +# message = f"{field.name} is mandatory and unset." +# if alternative_fields: +# raise AttributeError( +# message[:-1] +# + f", but no alternative provided by {list(alternative_fields)}." +# ) +# else: +# raise AttributeError(message) + +# # Raise error if multiple alternatives are set. +# elif alternatives_are_set: +# set_alternative_fields = [ +# name for name, is_set in alternative_fields.items() if is_set +# ] +# raise AttributeError( +# f"{field.name} is mutually exclusive with {set_alternative_fields}" +# ) + +# # Collect required fields associated with this field. +# required_fields = { +# name: getattr(self, name) is not attrs.NOTHING +# for name in field.metadata.get("requires", []) +# if name != field.name +# } + +# # Raise error if any required field is unset. +# if not all(required_fields.values()): +# unset_required_fields = [ +# name for name, is_set in required_fields.items() if not is_set +# ] +# raise AttributeError(f"{field.name} requires {unset_required_fields}") + +# def check_metadata(self): +# """Check contained metadata.""" + +# def template_update(self): +# """Update template.""" + +# def copyfile_input(self, output_dir): +# """Copy the file pointed by a :class:`File` input.""" + + +@attrs.define(auto_attribs=True, kw_only=True) class Runtime: """Represent run time metadata.""" @@ -205,7 +173,7 @@ class Runtime: """Peak in cpu consumption.""" -@attr.s(auto_attribs=True, kw_only=True) +@attrs.define(auto_attribs=True, kw_only=True) class Result: """Metadata regarding the outputs of processing.""" @@ -218,15 +186,15 @@ def __getstate__(self): if state["output"] is not None: fields = tuple((el.name, el.type) for el in attr_fields(state["output"])) state["output_spec"] = (state["output"].__class__.__name__, fields) - state["output"] = attr.asdict(state["output"], recurse=False) + state["output"] = attrs.asdict(state["output"], recurse=False) return state def __setstate__(self, state): if "output_spec" in state: spec = list(state["output_spec"]) del state["output_spec"] - klass = attr.make_class( - spec[0], {k: attr.ib(type=v) for k, v in list(spec[1])} + klass = attrs.make_class( + spec[0], {k: attrs.field(type=v) for k, v in list(spec[1])} ) state["output"] = klass(**state["output"]) self.__dict__.update(state) @@ -240,12 +208,12 @@ def get_output_field(self, field_name): Name of field in LazyField object """ if field_name == "all_": - return attr.asdict(self.output, recurse=False) + return attrs.asdict(self.output, recurse=False) else: return getattr(self.output, field_name) -@attr.s(auto_attribs=True, kw_only=True) +@attrs.define(auto_attribs=True, kw_only=True) class RuntimeSpec: """ Specification for a task. @@ -270,812 +238,679 @@ class RuntimeSpec: network: bool = False -@attr.s(auto_attribs=True, kw_only=True) -class FunctionSpec(BaseSpec): - """Specification for a process invoked from a shell.""" - - def check_metadata(self): - """ - Check the metadata for fields in input_spec and fields. - - Also sets the default values when available and needed. 
- - """ - supported_keys = { - "allowed_values", - "copyfile", - "help_string", - "mandatory", - # "readonly", #likely not needed - # "output_field_name", #likely not needed - # "output_file_template", #likely not needed - "requires", - "keep_extension", - "xor", - "sep", - } - for fld in attr_fields(self, exclude_names=("_func", "_graph_checksums")): - mdata = fld.metadata - # checking keys from metadata - if set(mdata.keys()) - supported_keys: - raise AttributeError( - f"only these keys are supported {supported_keys}, but " - f"{set(mdata.keys()) - supported_keys} provided" - ) - # checking if the help string is provided (required field) - if "help_string" not in mdata: - raise AttributeError(f"{fld.name} doesn't have help_string field") - # not allowing for default if the field is mandatory - if not fld.default == attr.NOTHING and mdata.get("mandatory"): - raise AttributeError( - f"default value ({fld.default!r}) should not be set when the field " - f"('{fld.name}') in {self}) is mandatory" - ) - # setting default if value not provided and default is available - if getattr(self, fld.name) is None: - if not fld.default == attr.NOTHING: - setattr(self, fld.name, fld.default) - - -@attr.s(auto_attribs=True, kw_only=True) -class ShellSpec(BaseSpec): - """Specification for a process invoked from a shell.""" - - executable: ty.Union[str, ty.List[str]] = attr.ib( - metadata={ - "help_string": "the first part of the command, can be a string, " - "e.g. 'ls', or a list, e.g. ['ls', '-l', 'dirname']" - } - ) - args: ty.Union[str, ty.List[str], None] = attr.ib( - None, - metadata={ - "help_string": "the last part of the command, can be a string, " - "e.g. , or a list" - }, - ) - - def retrieve_values(self, wf, state_index=None): - """Parse output results.""" - temp_values = {} - for field in attr_fields(self): - # retrieving values that do not have templates - if not field.metadata.get("output_file_template"): - value = getattr(self, field.name) - if isinstance(value, LazyField): - temp_values[field.name] = value.get_value( - wf, state_index=state_index - ) - for field, val in temp_values.items(): - value = path_to_string(value) - setattr(self, field, val) - - def check_metadata(self): - """ - Check the metadata for fields in input_spec and fields. - - Also sets the default values when available and needed. 
- - """ - from pydra.utils.typing import TypeParser - - supported_keys = { - "allowed_values", - "argstr", - "container_path", - "copyfile", - "help_string", - "mandatory", - "readonly", - "output_field_name", - "output_file_template", - "position", - "requires", - "keep_extension", - "xor", - "sep", - "formatter", - "_output_type", - } - - for fld in attr_fields(self, exclude_names=("_func", "_graph_checksums")): - mdata = fld.metadata - # checking keys from metadata - if set(mdata.keys()) - supported_keys: - raise AttributeError( - f"only these keys are supported {supported_keys}, but " - f"{set(mdata.keys()) - supported_keys} provided for '{fld.name}' " - f"field in {self}" - ) - # checking if the help string is provided (required field) - if "help_string" not in mdata: - raise AttributeError( - f"{fld.name} doesn't have help_string field in {self}" - ) - # assuming that fields with output_file_template shouldn't have default - if mdata.get("output_file_template"): - if not any( - TypeParser.matches_type(fld.type, t) for t in OUTPUT_TEMPLATE_TYPES - ): - raise TypeError( - f"Type of '{fld.name}' should be one of {OUTPUT_TEMPLATE_TYPES} " - f"(not {fld.type}) because it has a value for output_file_template " - f"({mdata['output_file_template']!r})" - ) - if fld.default not in [attr.NOTHING, True, False]: - raise AttributeError( - f"default value ({fld.default!r}) should not be set together with " - f"output_file_template ({mdata['output_file_template']!r}) for " - f"'{fld.name}' field in {self}" - ) - # not allowing for default if the field is mandatory - if not fld.default == attr.NOTHING and mdata.get("mandatory"): - raise AttributeError( - f"default value ({fld.default!r}) should not be set when the field " - f"('{fld.name}') in {self}) is mandatory" - ) - # setting default if value not provided and default is available - if getattr(self, fld.name) is None: - if not fld.default == attr.NOTHING: - setattr(self, fld.name, fld.default) - - -@attr.s(auto_attribs=True, kw_only=True) -class ShellOutSpec: - """Output specification of a generic shell process.""" - - return_code: int - """The process' exit code.""" - stdout: str - """The process' standard output.""" - stderr: str - """The process' standard input.""" - - def collect_additional_outputs(self, inputs, output_dir, outputs): - from pydra.utils.typing import TypeParser - - """Collect additional outputs from shelltask output_spec.""" - additional_out = {} - for fld in attr_fields(self, exclude_names=("return_code", "stdout", "stderr")): - if not TypeParser.is_subclass( - fld.type, - ( - os.PathLike, - MultiOutputObj, - int, - float, - bool, - str, - list, - ), - ): - raise TypeError( - f"Support for {fld.type} type, required for '{fld.name}' in {self}, " - "has not been implemented in collect_additional_output" - ) - # assuming that field should have either default or metadata, but not both - input_value = getattr(inputs, fld.name, attr.NOTHING) - if fld.metadata and "callable" in fld.metadata: - fld_out = self._field_metadata(fld, inputs, output_dir, outputs) - elif fld.type in [int, float, bool, str, list]: - raise AttributeError(f"{fld.type} has to have a callable in metadata") - elif input_value: # Map input value through to output - fld_out = input_value - elif fld.default != attr.NOTHING: - fld_out = self._field_defaultvalue(fld, output_dir) - else: - raise AttributeError("File has to have default value or metadata") - if TypeParser.contains_type(FileSet, fld.type): - label = f"output field '{fld.name}' of {self}" - fld_out = 
TypeParser(fld.type, label=label).coerce(fld_out) - additional_out[fld.name] = fld_out - return additional_out - - def generated_output_names(self, inputs, output_dir): - """Returns a list of all outputs that will be generated by the task. - Takes into account the task input and the requires list for the output fields. - TODO: should be in all Output specs? - """ - # checking the input (if all mandatory fields are provided, etc.) - inputs.check_fields_input_spec() - output_names = ["return_code", "stdout", "stderr"] - for fld in attr_fields(self, exclude_names=("return_code", "stdout", "stderr")): - if fld.type not in [File, MultiOutputFile, Directory]: - raise Exception("not implemented (collect_additional_output)") - # assuming that field should have either default or metadata, but not both - if ( - fld.default in (None, attr.NOTHING) and not fld.metadata - ): # TODO: is it right? - raise AttributeError("File has to have default value or metadata") - elif fld.default != attr.NOTHING: - output_names.append(fld.name) - elif ( - fld.metadata - and self._field_metadata( - fld, inputs, output_dir, outputs=None, check_existance=False - ) - != attr.NOTHING - ): - output_names.append(fld.name) - return output_names - - def _field_defaultvalue(self, fld, output_dir): - """Collect output file if the default value specified.""" - if not isinstance(fld.default, (str, Path)): - raise AttributeError( - f"{fld.name} is a File, so default value " - f"should be a string or a Path, " - f"{fld.default} provided" - ) - default = fld.default - if isinstance(default, str): - default = Path(default) - - default = output_dir / default - if "*" not in str(default): - if default.exists(): - return default - else: - raise AttributeError(f"file {default} does not exist") - else: - all_files = [Path(el) for el in glob(str(default.expanduser()))] - if len(all_files) > 1: - return all_files - elif len(all_files) == 1: - return all_files[0] - else: - raise AttributeError(f"no file matches {default.name}") - - def _field_metadata( - self, fld, inputs, output_dir, outputs=None, check_existance=True - ): - """Collect output file if metadata specified.""" - if self._check_requires(fld, inputs) is False: - return attr.NOTHING - - if "value" in fld.metadata: - return output_dir / fld.metadata["value"] - # this block is only run if "output_file_template" is provided in output_spec - # if the field is set in input_spec with output_file_template, - # than the field already should have value - elif "output_file_template" in fld.metadata: - value = template_update_single( - fld, inputs=inputs, output_dir=output_dir, spec_type="output" - ) - - if fld.type is MultiOutputFile and type(value) is list: - # TODO: how to deal with mandatory list outputs - ret = [] - for val in value: - val = Path(val) - if check_existance and not val.exists(): - ret.append(attr.NOTHING) - else: - ret.append(val) - return ret - else: - val = Path(value) - # checking if the file exists - if check_existance and not val.exists(): - # if mandatory raise exception - if "mandatory" in fld.metadata: - if fld.metadata["mandatory"]: - raise Exception( - f"mandatory output for variable {fld.name} does not exist" - ) - return attr.NOTHING - return val - elif "callable" in fld.metadata: - callable_ = fld.metadata["callable"] - if isinstance(callable_, staticmethod): - # In case callable is defined as a static method, - # retrieve the function wrapped in the descriptor. 
- callable_ = callable_.__func__ - call_args = inspect.getfullargspec(callable_) - call_args_val = {} - for argnm in call_args.args: - if argnm == "field": - call_args_val[argnm] = fld - elif argnm == "output_dir": - call_args_val[argnm] = output_dir - elif argnm == "inputs": - call_args_val[argnm] = inputs - elif argnm == "stdout": - call_args_val[argnm] = outputs["stdout"] - elif argnm == "stderr": - call_args_val[argnm] = outputs["stderr"] - else: - try: - call_args_val[argnm] = getattr(inputs, argnm) - except AttributeError: - raise AttributeError( - f"arguments of the callable function from {fld.name} " - f"has to be in inputs or be field or output_dir, " - f"but {argnm} is used" - ) - return callable_(**call_args_val) - else: - raise Exception( - f"Metadata for '{fld.name}', does not not contain any of the required fields " - f'("callable", "output_file_template" or "value"): {fld.metadata}.' - ) +# @attrs.define(auto_attribs=True, kw_only=True) +# class FunctionSpec(BaseSpec): +# """Specification for a process invoked from a shell.""" + +# def check_metadata(self): +# """ +# Check the metadata for fields in input_spec and fields. + +# Also sets the default values when available and needed. + +# """ +# supported_keys = { +# "allowed_values", +# "copyfile", +# "help_string", +# "mandatory", +# # "readonly", #likely not needed +# # "output_field_name", #likely not needed +# # "output_file_template", #likely not needed +# "requires", +# "keep_extension", +# "xor", +# "sep", +# } +# for fld in attr_fields(self, exclude_names=("_func", "_graph_checksums")): +# mdata = fld.metadata +# # checking keys from metadata +# if set(mdata.keys()) - supported_keys: +# raise AttributeError( +# f"only these keys are supported {supported_keys}, but " +# f"{set(mdata.keys()) - supported_keys} provided" +# ) +# # checking if the help string is provided (required field) +# if "help_string" not in mdata: +# raise AttributeError(f"{fld.name} doesn't have help_string field") +# # not allowing for default if the field is mandatory +# if not fld.default == attrs.NOTHING and mdata.get("mandatory"): +# raise AttributeError( +# f"default value ({fld.default!r}) should not be set when the field " +# f"('{fld.name}') in {self}) is mandatory" +# ) +# # setting default if value not provided and default is available +# if getattr(self, fld.name) is None: +# if not fld.default == attrs.NOTHING: +# setattr(self, fld.name, fld.default) + + +# @attrs.define(auto_attribs=True, kw_only=True) +# class ShellSpec(BaseSpec): +# """Specification for a process invoked from a shell.""" + +# executable: ty.Union[str, ty.List[str]] = attrs.field( +# metadata={ +# "help_string": "the first part of the command, can be a string, " +# "e.g. 'ls', or a list, e.g. ['ls', '-l', 'dirname']" +# } +# ) +# args: ty.Union[str, ty.List[str], None] = attrs.field( +# default=None, +# metadata={ +# "help_string": "the last part of the command, can be a string, " +# "e.g. 
, or a list" +# }, +# ) + +# def retrieve_values(self, wf, state_index=None): +# """Parse output results.""" +# temp_values = {} +# for field in attr_fields(self): +# # retrieving values that do not have templates +# if not field.metadata.get("output_file_template"): +# value = getattr(self, field.name) +# if isinstance(value, LazyField): +# temp_values[field.name] = value.get_value( +# wf, state_index=state_index +# ) +# for field, val in temp_values.items(): +# value = path_to_string(value) +# setattr(self, field, val) + +# def check_metadata(self): +# """ +# Check the metadata for fields in input_spec and fields. + +# Also sets the default values when available and needed. + +# """ +# from pydra.utils.typing import TypeParser + +# supported_keys = { +# "allowed_values", +# "argstr", +# "container_path", +# "copyfile", +# "help_string", +# "mandatory", +# "readonly", +# "output_field_name", +# "output_file_template", +# "position", +# "requires", +# "keep_extension", +# "xor", +# "sep", +# "formatter", +# "_output_type", +# } + +# for fld in attr_fields(self, exclude_names=("_func", "_graph_checksums")): +# mdata = fld.metadata +# # checking keys from metadata +# if set(mdata.keys()) - supported_keys: +# raise AttributeError( +# f"only these keys are supported {supported_keys}, but " +# f"{set(mdata.keys()) - supported_keys} provided for '{fld.name}' " +# f"field in {self}" +# ) +# # checking if the help string is provided (required field) +# if "help_string" not in mdata: +# raise AttributeError( +# f"{fld.name} doesn't have help_string field in {self}" +# ) +# # assuming that fields with output_file_template shouldn't have default +# if mdata.get("output_file_template"): +# if not any( +# TypeParser.matches_type(fld.type, t) for t in OUTPUT_TEMPLATE_TYPES +# ): +# raise TypeError( +# f"Type of '{fld.name}' should be one of {OUTPUT_TEMPLATE_TYPES} " +# f"(not {fld.type}) because it has a value for output_file_template " +# f"({mdata['output_file_template']!r})" +# ) +# if fld.default not in [attrs.NOTHING, True, False]: +# raise AttributeError( +# f"default value ({fld.default!r}) should not be set together with " +# f"output_file_template ({mdata['output_file_template']!r}) for " +# f"'{fld.name}' field in {self}" +# ) +# # not allowing for default if the field is mandatory +# if not fld.default == attrs.NOTHING and mdata.get("mandatory"): +# raise AttributeError( +# f"default value ({fld.default!r}) should not be set when the field " +# f"('{fld.name}') in {self}) is mandatory" +# ) +# # setting default if value not provided and default is available +# if getattr(self, fld.name) is None: +# if not fld.default == attrs.NOTHING: +# setattr(self, fld.name, fld.default) + + +# @attrs.define(auto_attribs=True, kw_only=True) +# class ShellOutSpec: +# """Output specification of a generic shell process.""" + +# return_code: int +# """The process' exit code.""" +# stdout: str +# """The process' standard output.""" +# stderr: str +# """The process' standard input.""" + +# def collect_additional_outputs(self, inputs, output_dir, outputs): +# from pydra.utils.typing import TypeParser + +# """Collect additional outputs from shelltask output_spec.""" +# additional_out = {} +# for fld in attr_fields(self, exclude_names=("return_code", "stdout", "stderr")): +# if not TypeParser.is_subclass( +# fld.type, +# ( +# os.PathLike, +# MultiOutputObj, +# int, +# float, +# bool, +# str, +# list, +# ), +# ): +# raise TypeError( +# f"Support for {fld.type} type, required for '{fld.name}' in {self}, " +# "has not 
been implemented in collect_additional_output" +# ) +# # assuming that field should have either default or metadata, but not both +# input_value = getattr(inputs, fld.name, attrs.NOTHING) +# if fld.metadata and "callable" in fld.metadata: +# fld_out = self._field_metadata(fld, inputs, output_dir, outputs) +# elif fld.type in [int, float, bool, str, list]: +# raise AttributeError(f"{fld.type} has to have a callable in metadata") +# elif input_value: # Map input value through to output +# fld_out = input_value +# elif fld.default != attrs.NOTHING: +# fld_out = self._field_defaultvalue(fld, output_dir) +# else: +# raise AttributeError("File has to have default value or metadata") +# if TypeParser.contains_type(FileSet, fld.type): +# label = f"output field '{fld.name}' of {self}" +# fld_out = TypeParser(fld.type, label=label).coerce(fld_out) +# additional_out[fld.name] = fld_out +# return additional_out + +# def generated_output_names(self, inputs, output_dir): +# """Returns a list of all outputs that will be generated by the task. +# Takes into account the task input and the requires list for the output fields. +# TODO: should be in all Output specs? +# """ +# # checking the input (if all mandatory fields are provided, etc.) +# inputs.check_fields_input_spec() +# output_names = ["return_code", "stdout", "stderr"] +# for fld in attr_fields(self, exclude_names=("return_code", "stdout", "stderr")): +# if fld.type not in [File, MultiOutputFile, Directory]: +# raise Exception("not implemented (collect_additional_output)") +# # assuming that field should have either default or metadata, but not both +# if ( +# fld.default in (None, attrs.NOTHING) and not fld.metadata +# ): # TODO: is it right? +# raise AttributeError("File has to have default value or metadata") +# elif fld.default != attrs.NOTHING: +# output_names.append(fld.name) +# elif ( +# fld.metadata +# and self._field_metadata( +# fld, inputs, output_dir, outputs=None, check_existance=False +# ) +# != attrs.NOTHING +# ): +# output_names.append(fld.name) +# return output_names + +# def _field_defaultvalue(self, fld, output_dir): +# """Collect output file if the default value specified.""" +# if not isinstance(fld.default, (str, Path)): +# raise AttributeError( +# f"{fld.name} is a File, so default value " +# f"should be a string or a Path, " +# f"{fld.default} provided" +# ) +# default = fld.default +# if isinstance(default, str): +# default = Path(default) + +# default = output_dir / default +# if "*" not in str(default): +# if default.exists(): +# return default +# else: +# raise AttributeError(f"file {default} does not exist") +# else: +# all_files = [Path(el) for el in glob(str(default.expanduser()))] +# if len(all_files) > 1: +# return all_files +# elif len(all_files) == 1: +# return all_files[0] +# else: +# raise AttributeError(f"no file matches {default.name}") + +# def _field_metadata( +# self, fld, inputs, output_dir, outputs=None, check_existance=True +# ): +# """Collect output file if metadata specified.""" +# if self._check_requires(fld, inputs) is False: +# return attrs.NOTHING + +# if "value" in fld.metadata: +# return output_dir / fld.metadata["value"] +# # this block is only run if "output_file_template" is provided in output_spec +# # if the field is set in input_spec with output_file_template, +# # than the field already should have value +# elif "output_file_template" in fld.metadata: +# value = template_update_single( +# fld, inputs=inputs, output_dir=output_dir, spec_type="output" +# ) + +# if fld.type is MultiOutputFile 
and type(value) is list: +# # TODO: how to deal with mandatory list outputs +# ret = [] +# for val in value: +# val = Path(val) +# if check_existance and not val.exists(): +# ret.append(attrs.NOTHING) +# else: +# ret.append(val) +# return ret +# else: +# val = Path(value) +# # checking if the file exists +# if check_existance and not val.exists(): +# # if mandatory raise exception +# if "mandatory" in fld.metadata: +# if fld.metadata["mandatory"]: +# raise Exception( +# f"mandatory output for variable {fld.name} does not exist" +# ) +# return attrs.NOTHING +# return val +# elif "callable" in fld.metadata: +# callable_ = fld.metadata["callable"] +# if isinstance(callable_, staticmethod): +# # In case callable is defined as a static method, +# # retrieve the function wrapped in the descriptor. +# callable_ = callable_.__func__ +# call_args = inspect.getfullargspec(callable_) +# call_args_val = {} +# for argnm in call_args.args: +# if argnm == "field": +# call_args_val[argnm] = fld +# elif argnm == "output_dir": +# call_args_val[argnm] = output_dir +# elif argnm == "inputs": +# call_args_val[argnm] = inputs +# elif argnm == "stdout": +# call_args_val[argnm] = outputs["stdout"] +# elif argnm == "stderr": +# call_args_val[argnm] = outputs["stderr"] +# else: +# try: +# call_args_val[argnm] = getattr(inputs, argnm) +# except AttributeError: +# raise AttributeError( +# f"arguments of the callable function from {fld.name} " +# f"has to be in inputs or be field or output_dir, " +# f"but {argnm} is used" +# ) +# return callable_(**call_args_val) +# else: +# raise Exception( +# f"Metadata for '{fld.name}', does not not contain any of the required fields " +# f'("callable", "output_file_template" or "value"): {fld.metadata}.' +# ) + +# def _check_requires(self, fld, inputs): +# """checking if all fields from the requires and template are set in the input +# if requires is a list of list, checking if at least one list has all elements set +# """ +# from .helpers import ensure_list + +# if "requires" in fld.metadata: +# # if requires is a list of list it is treated as el[0] OR el[1] OR... 
+# required_fields = ensure_list(fld.metadata["requires"]) +# if all([isinstance(el, list) for el in required_fields]): +# field_required_OR = required_fields +# # if requires is a list of tuples/strings - I'm creating a 1-el nested list +# elif all([isinstance(el, (str, tuple)) for el in required_fields]): +# field_required_OR = [required_fields] +# else: +# raise Exception( +# f"requires field can be a list of list, or a list " +# f"of strings/tuples, but {fld.metadata['requires']} " +# f"provided for {fld.name}" +# ) +# else: +# field_required_OR = [[]] + +# for field_required in field_required_OR: +# # if the output has output_file_template field, +# # adding all input fields from the template to requires +# if "output_file_template" in fld.metadata: +# template = fld.metadata["output_file_template"] +# # if a template is a function it has to be run first with the inputs as the only arg +# if callable(template): +# template = template(inputs) +# inp_fields = re.findall(r"{\w+}", template) +# field_required += [ +# el[1:-1] for el in inp_fields if el[1:-1] not in field_required +# ] + +# # it's a flag, of the field from the list is not in input it will be changed to False +# required_found = True +# for field_required in field_required_OR: +# required_found = True +# # checking if the input fields from requires have set values +# for inp in field_required: +# if isinstance(inp, str): # name of the input field +# if not hasattr(inputs, inp): +# raise Exception( +# f"{inp} is not a valid input field, can't be used in requires" +# ) +# elif getattr(inputs, inp) in [attrs.NOTHING, None]: +# required_found = False +# break +# elif isinstance(inp, tuple): # (name, allowed values) +# inp, allowed_val = inp[0], ensure_list(inp[1]) +# if not hasattr(inputs, inp): +# raise Exception( +# f"{inp} is not a valid input field, can't be used in requires" +# ) +# elif getattr(inputs, inp) not in allowed_val: +# required_found = False +# break +# else: +# raise Exception( +# f"each element of the requires element should be a string or a tuple, " +# f"but {inp} is found in {field_required}" +# ) +# # if the specific list from field_required_OR has all elements set, no need to check more +# if required_found: +# break + +# if required_found: +# return True +# else: +# return False + + +# @attrs.define +# class LazyInterface: +# _task: "core.Task" = attrs.field() +# _attr_type: str + +# def __getattr__(self, name): +# if name in ("_task", "_attr_type", "_field_names"): +# raise AttributeError(f"{name} hasn't been set yet") +# if name not in self._field_names: +# raise AttributeError( +# f"Task '{self._task.name}' has no {self._attr_type} attribute '{name}', " +# "available: '" + "', '".join(self._field_names) + "'" +# ) +# type_ = self._get_type(name) +# splits = self._get_task_splits() +# combines = self._get_task_combines() +# if combines and self._attr_type == "output": +# # Add in any scalar splits referencing upstream splits, i.e. 
"_myupstreamtask", +# # "_myarbitrarytask" +# combined_upstreams = set() +# if self._task.state: +# for scalar in LazyField.normalize_splitter( +# self._task.state.splitter, strip_previous=False +# ): +# for field in scalar: +# if field.startswith("_"): +# node_name = field[1:] +# if any(c.split(".")[0] == node_name for c in combines): +# combines.update( +# f for f in scalar if not f.startswith("_") +# ) +# combined_upstreams.update( +# f[1:] for f in scalar if f.startswith("_") +# ) +# if combines: +# # Wrap type in list which holds the combined items +# type_ = ty.List[type_] +# # Iterate through splits to remove any splits which are removed by the +# # combiner +# for splitter in copy(splits): +# remaining = tuple( +# s +# for s in splitter +# if not any( +# (x in combines or x.split(".")[0] in combined_upstreams) +# for x in s +# ) +# ) +# if remaining != splitter: +# splits.remove(splitter) +# if remaining: +# splits.add(remaining) +# # Wrap the type in a nested StateArray type +# if splits: +# type_ = StateArray[type_] +# lf_klass = LazyInField if self._attr_type == "input" else LazyOutField +# return lf_klass[type_]( +# name=self._task.name, +# field=name, +# type=type_, +# splits=splits, +# ) + +# def _get_task_splits(self) -> ty.Set[ty.Tuple[ty.Tuple[str, ...], ...]]: +# """Returns the states over which the inputs of the task are split""" +# splitter = self._task.state.splitter if self._task.state else None +# splits = set() +# if splitter: +# # Ensure that splits is of tuple[tuple[str, ...], ...] form +# splitter = LazyField.normalize_splitter(splitter) +# if splitter: +# splits.add(splitter) +# for inpt in attrs.asdict(self._task.inputs, recurse=False).values(): +# if isinstance(inpt, LazyField): +# splits.update(inpt.splits) +# return splits + +# def _get_task_combines(self) -> ty.Set[ty.Union[str, ty.Tuple[str, ...]]]: +# """Returns the states over which the outputs of the task are combined""" +# combiner = ( +# self._task.state.combiner +# if self._task.state is not None +# else getattr(self._task, "fut_combiner", None) +# ) +# return set(combiner) if combiner else set() + + +# class LazyIn(LazyInterface): +# _attr_type = "input" + +# def _get_type(self, name): +# attr = next(t for n, t in self._task.input_spec.fields if n == name) +# if attr is None: +# return ty.Any +# elif inspect.isclass(attr): +# return attr +# else: +# return attr.type + +# @property +# def _field_names(self): +# return [field[0] for field in self._task.input_spec.fields] + + +# class LazyOut(LazyInterface): +# _attr_type = "output" + +# def _get_type(self, name): +# try: +# type_ = next(f[1] for f in self._task.output_spec.fields if f[0] == name) +# except StopIteration: +# type_ = ty.Any +# else: +# if not inspect.isclass(type_): +# try: +# type_ = type_.type # attrs _CountingAttribute +# except AttributeError: +# pass # typing._SpecialForm +# return type_ + +# @property +# def _field_names(self): +# return self._task.output_names + ["all_"] - def _check_requires(self, fld, inputs): - """checking if all fields from the requires and template are set in the input - if requires is a list of list, checking if at least one list has all elements set - """ - from .helpers import ensure_list - - if "requires" in fld.metadata: - # if requires is a list of list it is treated as el[0] OR el[1] OR... 
- required_fields = ensure_list(fld.metadata["requires"]) - if all([isinstance(el, list) for el in required_fields]): - field_required_OR = required_fields - # if requires is a list of tuples/strings - I'm creating a 1-el nested list - elif all([isinstance(el, (str, tuple)) for el in required_fields]): - field_required_OR = [required_fields] - else: - raise Exception( - f"requires field can be a list of list, or a list " - f"of strings/tuples, but {fld.metadata['requires']} " - f"provided for {fld.name}" - ) - else: - field_required_OR = [[]] - - for field_required in field_required_OR: - # if the output has output_file_template field, - # adding all input fields from the template to requires - if "output_file_template" in fld.metadata: - template = fld.metadata["output_file_template"] - # if a template is a function it has to be run first with the inputs as the only arg - if callable(template): - template = template(inputs) - inp_fields = re.findall(r"{\w+}", template) - field_required += [ - el[1:-1] for el in inp_fields if el[1:-1] not in field_required - ] - - # it's a flag, of the field from the list is not in input it will be changed to False - required_found = True - for field_required in field_required_OR: - required_found = True - # checking if the input fields from requires have set values - for inp in field_required: - if isinstance(inp, str): # name of the input field - if not hasattr(inputs, inp): - raise Exception( - f"{inp} is not a valid input field, can't be used in requires" - ) - elif getattr(inputs, inp) in [attr.NOTHING, None]: - required_found = False - break - elif isinstance(inp, tuple): # (name, allowed values) - inp, allowed_val = inp[0], ensure_list(inp[1]) - if not hasattr(inputs, inp): - raise Exception( - f"{inp} is not a valid input field, can't be used in requires" - ) - elif getattr(inputs, inp) not in allowed_val: - required_found = False - break - else: - raise Exception( - f"each element of the requires element should be a string or a tuple, " - f"but {inp} is found in {field_required}" - ) - # if the specific list from field_required_OR has all elements set, no need to check more - if required_found: - break - - if required_found: - return True - else: - return False +def donothing(*args, **kwargs): + return None -@attr.s -class LazyInterface: - _task: "core.Task" = attr.ib() - _attr_type: str - def __getattr__(self, name): - if name in ("_task", "_attr_type", "_field_names"): - raise AttributeError(f"{name} hasn't been set yet") - if name not in self._field_names: - raise AttributeError( - f"Task '{self._task.name}' has no {self._attr_type} attribute '{name}', " - "available: '" + "', '".join(self._field_names) + "'" - ) - type_ = self._get_type(name) - splits = self._get_task_splits() - combines = self._get_task_combines() - if combines and self._attr_type == "output": - # Add in any scalar splits referencing upstream splits, i.e. 
"_myupstreamtask", - # "_myarbitrarytask" - combined_upstreams = set() - if self._task.state: - for scalar in LazyField.normalize_splitter( - self._task.state.splitter, strip_previous=False - ): - for field in scalar: - if field.startswith("_"): - node_name = field[1:] - if any(c.split(".")[0] == node_name for c in combines): - combines.update( - f for f in scalar if not f.startswith("_") - ) - combined_upstreams.update( - f[1:] for f in scalar if f.startswith("_") - ) - if combines: - # Wrap type in list which holds the combined items - type_ = ty.List[type_] - # Iterate through splits to remove any splits which are removed by the - # combiner - for splitter in copy(splits): - remaining = tuple( - s - for s in splitter - if not any( - (x in combines or x.split(".")[0] in combined_upstreams) - for x in s - ) - ) - if remaining != splitter: - splits.remove(splitter) - if remaining: - splits.add(remaining) - # Wrap the type in a nested StateArray type - if splits: - type_ = StateArray[type_] - lf_klass = LazyInField if self._attr_type == "input" else LazyOutField - return lf_klass[type_]( - name=self._task.name, - field=name, - type=type_, - splits=splits, - ) +@attrs.define(auto_attribs=True, kw_only=True) +class TaskHook: + """Callable task hooks.""" - def _get_task_splits(self) -> ty.Set[ty.Tuple[ty.Tuple[str, ...], ...]]: - """Returns the states over which the inputs of the task are split""" - splitter = self._task.state.splitter if self._task.state else None - splits = set() - if splitter: - # Ensure that splits is of tuple[tuple[str, ...], ...] form - splitter = LazyField.normalize_splitter(splitter) - if splitter: - splits.add(splitter) - for inpt in attr.asdict(self._task.inputs, recurse=False).values(): - if isinstance(inpt, LazyField): - splits.update(inpt.splits) - return splits - - def _get_task_combines(self) -> ty.Set[ty.Union[str, ty.Tuple[str, ...]]]: - """Returns the states over which the outputs of the task are combined""" - combiner = ( - self._task.state.combiner - if self._task.state is not None - else getattr(self._task, "fut_combiner", None) - ) - return set(combiner) if combiner else set() + pre_run_task: ty.Callable = donothing + post_run_task: ty.Callable = donothing + pre_run: ty.Callable = donothing + post_run: ty.Callable = donothing + def __setattr__(self, attr, val): + if attr not in ["pre_run_task", "post_run_task", "pre_run", "post_run"]: + raise AttributeError("Cannot set unknown hook") + super().__setattr__(attr, val) -class LazyIn(LazyInterface): - _attr_type = "input" + def reset(self): + for val in ["pre_run_task", "post_run_task", "pre_run", "post_run"]: + setattr(self, val, donothing) - def _get_type(self, name): - attr = next(t for n, t in self._task.input_spec.fields if n == name) - if attr is None: - return ty.Any - elif inspect.isclass(attr): - return attr - else: - return attr.type - @property - def _field_names(self): - return [field[0] for field in self._task.input_spec.fields] +def path_to_string(value): + """Convert paths to strings.""" + if isinstance(value, Path): + value = str(value) + elif isinstance(value, list) and len(value) and isinstance(value[0], Path): + value = [str(val) for val in value] + return value -class LazyOut(LazyInterface): - _attr_type = "output" +class OutputsSpec: + """Base class for all output specifications""" - def _get_type(self, name): - try: - type_ = next(f[1] for f in self._task.output_spec.fields if f[0] == name) - except StopIteration: - type_ = ty.Any - else: - if not inspect.isclass(type_): - try: - type_ = 
type_.type # attrs _CountingAttribute - except AttributeError: - pass # typing._SpecialForm - return type_ - - @property - def _field_names(self): - return self._task.output_names + ["all_"] - - -TypeOrAny = ty.Union[ty.Type[T], ty.Any] -Splitter = ty.Union[str, ty.Tuple[str, ...]] - - -@attr.s(auto_attribs=True, kw_only=True) -class LazyField(ty.Generic[T]): - """Lazy fields implement promises.""" - - name: str - field: str - type: TypeOrAny - # Set of splitters that have been applied to the lazy field. Note that the splitter - # specifications are transformed to a tuple[tuple[str, ...], ...] form where the - # outer tuple is the outer product, the inner tuple are inner products (where either - # product can be of length==1) - splits: ty.FrozenSet[ty.Tuple[ty.Tuple[str, ...], ...]] = attr.field( - factory=frozenset, converter=frozenset - ) - cast_from: ty.Optional[ty.Type[ty.Any]] = None - # type_checked will be set to False after it is created but defaults to True here for - # ease of testing - type_checked: bool = True - - def __bytes_repr__(self, cache): - yield type(self).__name__.encode() - yield self.name.encode() - yield self.field.encode() - - def cast(self, new_type: TypeOrAny) -> Self: - """ "casts" the lazy field to a new type + def split( + self, + splitter: ty.Union[str, ty.List[str], ty.Tuple[str, ...], None] = None, + /, + overwrite: bool = False, + cont_dim: ty.Optional[dict] = None, + **inputs, + ) -> Self: + """ + Run this task parametrically over lists of split inputs. Parameters ---------- - new_type : type - the type to cast the lazy-field to + splitter : str or list[str] or tuple[str] or None + the fields which to split over. If splitting over multiple fields, lists of + fields are interpreted as outer-products and tuples inner-products. If None, + then the fields to split are taken from the keyword-arg names. + overwrite : bool, optional + whether to overwrite an existing split on the node, by default False + cont_dim : dict, optional + Container dimensions for specific inputs, used in the splitter. + If input name is not in cont_dim, it is assumed that the input values has + a container dimension of 1, so only the most outer dim will be used for splitting. + **inputs + fields to split over, will automatically be wrapped in a StateArray object + and passed to the node inputs Returns ------- - cast_field : LazyField - a copy of the lazy field with the new type + self : TaskBase + a reference to the task """ - return type(self)[new_type]( - name=self.name, - field=self.field, - type=new_type, - splits=self.splits, - cast_from=self.cast_from if self.cast_from else self.type, - ) - - def split(self, splitter: Splitter) -> Self: - """ "Splits" the lazy field over an array of nodes by replacing the sequence type - of the lazy field with StateArray to signify that it will be "split" across - - Parameters - ---------- - splitter : str or ty.Tuple[str, ...] 
or ty.List[str] - the splitter to append to the list of splitters + self._node.split(splitter, overwrite=overwrite, cont_dim=cont_dim, **inputs) + return self + + def combine( + self, + combiner: ty.Union[ty.List[str], str], + overwrite: bool = False, # **kwargs + ) -> Self: """ - from pydra.utils.typing import ( - TypeParser, - ) # pylint: disable=import-outside-toplevel - - splits = self.splits | set([LazyField.normalize_splitter(splitter)]) - # Check to see whether the field has already been split over the given splitter - if splits == self.splits: - return self - - # Modify the type of the lazy field to include the split across a state-array - inner_type, prev_split_depth = TypeParser.strip_splits(self.type) - assert prev_split_depth <= 1 - if inner_type is ty.Any: - type_ = StateArray[ty.Any] - elif TypeParser.matches_type(inner_type, list): - item_type = TypeParser.get_item_type(inner_type) - type_ = StateArray[item_type] - else: - raise TypeError( - f"Cannot split non-sequence field {self} of type {inner_type}" - ) - if prev_split_depth: - type_ = StateArray[type_] - return type(self)[type_]( - name=self.name, - field=self.field, - type=type_, - splits=splits, - ) - - # def combine(self, combiner: str | list[str]) -> Self: - - @classmethod - def normalize_splitter( - cls, splitter: Splitter, strip_previous: bool = True - ) -> ty.Tuple[ty.Tuple[str, ...], ...]: - """Converts the splitter spec into a consistent tuple[tuple[str, ...], ...] form - used in LazyFields""" - if isinstance(splitter, str): - splitter = (splitter,) - if isinstance(splitter, tuple): - splitter = (splitter,) # type: ignore - else: - assert isinstance(splitter, list) - # convert to frozenset to differentiate from tuple, yet still be hashable - # (NB: order of fields in list splitters aren't relevant) - splitter = tuple((s,) if isinstance(s, str) else s for s in splitter) - # Strip out fields starting with "_" designating splits in upstream nodes - if strip_previous: - stripped = tuple( - tuple(f for f in i if not f.startswith("_")) for i in splitter - ) - splitter = tuple(s for s in stripped if s) # type: ignore - return splitter # type: ignore - - def _apply_cast(self, value): - """\"Casts\" the value from the retrieved type if a cast has been applied to - the lazy-field""" - from pydra.utils.typing import TypeParser - - if self.cast_from: - assert TypeParser.matches(value, self.cast_from) - value = self.type(value) - return value - - -@attr.s(auto_attribs=True, kw_only=True) -class LazyInField(LazyField[T]): - - name: str = None - attr_type = "input" - - def get_value( - self, wf: "pydra.engine.workflow.Workflow", state_index: ty.Optional[int] = None - ) -> ty.Any: - """Return the value of a lazy field. + Combine inputs parameterized by one or more previous tasks. Parameters ---------- - wf : Workflow - the workflow the lazy field references - state_index : int, optional - the state index of the field to access + combiner : list[str] or str + the field or list of inputs to be combined (i.e. 
not left split) after the + task has been run + overwrite : bool + whether to overwrite an existing combiner on the node + **kwargs : dict[str, Any] + values for the task that will be "combined" before they are provided to the + node Returns ------- - value : Any - the resolved value of the lazy-field + self : Self + a reference to the outputs object """ - from pydra.utils.typing import ( - TypeParser, - ) # pylint: disable=import-outside-toplevel - - value = getattr(wf.inputs, self.field) - if TypeParser.is_subclass(self.type, StateArray) and not wf._pre_split: - _, split_depth = TypeParser.strip_splits(self.type) - - def apply_splits(obj, depth): - if depth < 1: - return obj - return StateArray[self.type](apply_splits(i, depth - 1) for i in obj) + self._node.combine(combiner, overwrite=overwrite) + return self - value = apply_splits(value, split_depth) - value = self._apply_cast(value) - return value +OutputType = ty.TypeVar("OutputType", bound=OutputsSpec) -class LazyOutField(LazyField[T]): - attr_type = "output" - def get_value( - self, wf: "pydra.Workflow", state_index: ty.Optional[int] = None - ) -> ty.Any: - """Return the value of a lazy field. +class TaskSpec(ty.Generic[OutputType]): + """Base class for all task specifications""" - Parameters - ---------- - wf : Workflow - the workflow the lazy field references - state_index : int, optional - the state index of the field to access + Task: "ty.Type[core.Task]" - Returns - ------- - value : Any - the resolved value of the lazy-field - """ - from pydra.utils.typing import ( - TypeParser, - ) # pylint: disable=import-outside-toplevel - - node = getattr(wf, self.name) - result = node.result(state_index=state_index) - if result is None: - raise RuntimeError( - f"Could not find results of '{node.name}' node in a sub-directory " - f"named '{node.checksum}' in any of the cache locations.\n" - + "\n".join(str(p) for p in set(node.cache_locations)) - + f"\n\nThis is likely due to hash changes in '{self.name}' node inputs. " - f"Current values and hashes: {node.inputs}, " - f"{node.inputs._hashes}\n\n" - "Set loglevel to 'debug' in order to track hash changes " - "throughout the execution of the workflow.\n\n " - "These issues may have been caused by `bytes_repr()` methods " - "that don't return stable hash values for specific object " - "types across multiple processes (see bytes_repr() " - '"singledispatch "function in pydra/utils/hash.py).' - "You may need to write specific `bytes_repr()` " - "implementations (see `pydra.utils.hash.register_serializer`) or a " - "`__bytes_repr__()` dunder methods to handle one or more types in " - "your interface inputs." 
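
# The TaskSpec.__call__ added just below builds the spec's Task and runs it,
# but only after _check_for_unset_values has confirmed that no input field is
# still attrs.NOTHING. A minimal, self-contained sketch of that attrs idiom
# (stand-in names, not pydra's real classes):

import attrs

@attrs.define
class Params:  # stand-in for a generated task specification
    x: int
    y: int = 2

p = Params(x=attrs.NOTHING)  # mimics an interface field that was never set
unset = [
    k for k, v in attrs.asdict(p, recurse=False).items() if v is attrs.NOTHING
]
assert unset == ["x"]  # such a spec is rejected before the Task is constructed
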
+ def __call__( + self, + name: str | None = None, + audit_flags: AuditFlag = AuditFlag.NONE, + cache_dir=None, + cache_locations=None, + inputs: ty.Text | File | dict[str, ty.Any] | None = None, + cont_dim=None, + messenger_args=None, + messengers=None, + rerun=False, + **kwargs, + ): + self._check_for_unset_values() + task = self.Task( + self, + name=name, + audit_flags=audit_flags, + cache_dir=cache_dir, + cache_locations=cache_locations, + inputs=inputs, + cont_dim=cont_dim, + messenger_args=messenger_args, + messengers=messengers, + rerun=rerun, + ) + return task(**kwargs) + + def _check_for_unset_values(self): + if unset := [ + k + for k, v in attrs.asdict(self, recurse=False).items() + if v is attrs.NOTHING + ]: + raise ValueError( + f"The following values {unset} in the {self!r} interface need to be set " + "before the workflow can be constructed" ) - _, split_depth = TypeParser.strip_splits(self.type) - - def get_nested_results(res, depth: int): - if isinstance(res, list): - if not depth: - val = [r.get_output_field(self.field) for r in res] - else: - val = StateArray[self.type]( - get_nested_results(res=r, depth=depth - 1) for r in res - ) - else: - if res.errored: - raise ValueError( - f"Cannot retrieve value for {self.field} from {self.name} as " - "the node errored" - ) - val = res.get_output_field(self.field) - if depth and not wf._pre_split: - assert isinstance(val, ty.Sequence) and not isinstance(val, str) - val = StateArray[self.type](val) - return val - - value = get_nested_results(result, depth=split_depth) - value = self._apply_cast(value) - return value - - -class StateArray(ty.List[T]): - """an array of values from, or to be split over in an array of nodes (see TaskBase.split()), - multiple nodes of the same task. Used in type-checking to differentiate between list - types and values for multiple nodes - """ - - def __repr__(self): - return f"{type(self).__name__}(" + ", ".join(repr(i) for i in self) + ")" - - -def donothing(*args, **kwargs): - return None - - -@attr.s(auto_attribs=True, kw_only=True) -class TaskHook: - """Callable task hooks.""" - - pre_run_task: ty.Callable = donothing - post_run_task: ty.Callable = donothing - pre_run: ty.Callable = donothing - post_run: ty.Callable = donothing - - def __setattr__(self, attr, val): - if attr not in ["pre_run_task", "post_run_task", "pre_run", "post_run"]: - raise AttributeError("Cannot set unknown hook") - super().__setattr__(attr, val) - - def reset(self): - for val in ["pre_run_task", "post_run_task", "pre_run", "post_run"]: - setattr(self, val, donothing) - - -def path_to_string(value): - """Convert paths to strings.""" - if isinstance(value, Path): - value = str(value) - elif isinstance(value, list) and len(value) and isinstance(value[0], Path): - value = [str(val) for val in value] - return value -from . import core # noqa +from pydra.engine import core # noqa: E402 diff --git a/pydra/engine/state.py b/pydra/engine/state.py index befbf86b9d..ffaddf3f3f 100644 --- a/pydra/engine/state.py +++ b/pydra/engine/state.py @@ -3,10 +3,11 @@ from copy import deepcopy import itertools from functools import reduce - +import attrs from . 
import helpers_state as hlpst from .helpers import ensure_list -from .specs import BaseSpec + +# from .specs import BaseSpec # TODO: move to State op = {".": zip, "*": itertools.product} @@ -763,8 +764,8 @@ def prepare_states(self, inputs, cont_dim=None): self.cont_dim = cont_dim else: self.cont_dim = {} - if isinstance(inputs, BaseSpec): - self.inputs = hlpst.inputs_types_to_dict(self.name, inputs) + if attrs.has(inputs): + self.inputs = attrs.asdict(inputs, recurse=False) else: self.inputs = inputs if self.other_states: diff --git a/pydra/engine/task.py b/pydra/engine/task.py index b9317602bd..68731f47bc 100644 --- a/pydra/engine/task.py +++ b/pydra/engine/task.py @@ -55,10 +55,11 @@ from .core import Task, is_lazy from pydra.utils.messenger import AuditFlag from .specs import ( - BaseSpec, - SpecInfo, + # BaseSpec, + # SpecInfo, # ShellSpec, # ShellOutSpec, + TaskSpec, attr_fields, ) from .helpers import ( @@ -78,16 +79,14 @@ class FunctionTask(Task): def __init__( self, - func: ty.Callable, + spec: TaskSpec, audit_flags: AuditFlag = AuditFlag.NONE, cache_dir=None, cache_locations=None, - input_spec: ty.Optional[ty.Union[SpecInfo, BaseSpec]] = None, cont_dim=None, messenger_args=None, messengers=None, name=None, - output_spec: ty.Optional[ty.Union[SpecInfo, BaseSpec]] = None, rerun=False, **kwargs, ): @@ -226,14 +225,13 @@ class ShellCommandTask(Task): def __init__( self, + spec: TaskSpec, audit_flags: AuditFlag = AuditFlag.NONE, cache_dir=None, - input_spec: ty.Optional[SpecInfo] = None, cont_dim=None, messenger_args=None, messengers=None, name=None, - output_spec: ty.Optional[SpecInfo] = None, rerun=False, strip=False, environment=Native(), diff --git a/pydra/engine/workflow/__init__.py b/pydra/engine/workflow/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/pydra/engine/workflow/base.py b/pydra/engine/workflow/base.py new file mode 100644 index 0000000000..5026109111 --- /dev/null +++ b/pydra/engine/workflow/base.py @@ -0,0 +1,178 @@ +import typing as ty +from copy import copy +from operator import itemgetter +from typing_extensions import Self +import attrs +from pydra.engine.helpers import list_fields +from pydra.engine.specs import TaskSpec, OutputsSpec +from .lazy import LazyInField +from pydra.utils.hash import hash_function +from pydra.utils.typing import TypeParser, StateArray +from .node import Node + + +OutputType = ty.TypeVar("OutputType", bound=OutputsSpec) + + +@attrs.define(auto_attribs=False) +class Workflow(ty.Generic[OutputType]): + """A workflow, constructed from a workflow specification + + Parameters + ---------- + name : str + The name of the workflow + inputs : TaskSpec + The input specification of the workflow + outputs : TaskSpec + The output specification of the workflow + """ + + name: str = attrs.field() + inputs: TaskSpec[OutputType] = attrs.field() + outputs: OutputType = attrs.field() + _nodes: dict[str, Node] = attrs.field(factory=dict) + + @classmethod + def construct( + cls, + spec: TaskSpec[OutputType], + ) -> Self: + """Construct a workflow from a specification, caching the constructed worklow""" + + lazy_inputs = [f for f in list_fields(type(spec)) if f.lazy] + + # Create a cache key by hashing all the non-lazy input values in the spec + # and use this to store the constructed workflow in case it is reused or nested + # and split over within another workflow + lazy_input_names = {f.name for f in lazy_inputs} + non_lazy_vals = tuple( + sorted( + ( + i + for i in attrs.asdict(spec, recurse=False).items() + if i[0] not in 
lazy_input_names + ), + key=itemgetter(0), + ) + ) + hash_key = hash_function(non_lazy_vals) + if hash_key in cls._constructed: + return cls._constructed[hash_key] + + # Initialise the outputs of the workflow + outputs = spec.Outputs( + **{f.name: attrs.NOTHING for f in attrs.fields(spec.Outputs)} + ) + + # Initialise the lzin fields + lazy_spec = copy(spec) + wf = cls.under_construction = Workflow( + name=type(spec).__name__, + inputs=lazy_spec, + outputs=outputs, + ) + for lzy_inpt in lazy_inputs: + setattr( + lazy_spec, + lzy_inpt.name, + LazyInField( + node=wf, + field=lzy_inpt.name, + type=lzy_inpt.type, + ), + ) + + input_values = attrs.asdict(lazy_spec, recurse=False) + constructor = input_values.pop("constructor") + cls._under_construction = wf + try: + # Call the user defined constructor to set the outputs + output_lazy_fields = constructor(**input_values) + # Check to see whether any mandatory inputs are not set + for node in wf.nodes: + node.inputs._check_for_unset_values() + # Check that the outputs are set correctly, either directly by the constructor + # or via returned values that can be zipped with the output names + if output_lazy_fields: + if not isinstance(output_lazy_fields, (list, tuple)): + output_lazy_fields = [output_lazy_fields] + output_fields = list_fields(spec.Outputs) + if len(output_lazy_fields) != len(output_fields): + raise ValueError( + f"Expected {len(output_fields)} outputs, got " + f"{len(output_lazy_fields)} ({output_lazy_fields})" + ) + for outpt, outpt_lf in zip(output_fields, output_lazy_fields): + if TypeParser.get_origin(outpt_lf.type) is StateArray: + # Automatically combine any uncombined state arrays into lists + tp, _ = TypeParser.strip_splits(outpt_lf.type) + outpt_lf.type = list[tp] + outpt_lf.splits = frozenset() + setattr(outputs, outpt.name, outpt_lf) + else: + if unset_outputs := [ + a + for a, v in attrs.asdict(outputs, recurse=False).items() + if v is attrs.NOTHING + ]: + raise ValueError( + f"Expected outputs {unset_outputs} to be set by the " + f"constructor of {wf!r}" + ) + finally: + cls._under_construction = None + + cls._constructed[hash_key] = wf + + return wf + + def add(self, task_spec: TaskSpec[OutputType], name=None) -> OutputType: + """Add a node to the workflow + + Parameters + ---------- + task_spec : TaskSpec + The specification of the task to add to the workflow as a node + name : str, optional + The name of the node, by default it will be the name of the task specification + class + + Returns + ------- + OutputType + The outputs specification of the node + """ + if name is None: + name = type(task_spec).__name__ + if name in self._nodes: + raise ValueError(f"Node with name {name!r} already exists in the workflow") + node = Node[OutputType](name=name, spec=task_spec, workflow=self) + self._nodes[name] = node + return node.lzout + + def __getitem__(self, key: str) -> Node: + return self._nodes[key] + + @property + def nodes(self) -> ty.Iterable[Node]: + return self._nodes.values() + + @property + def node_names(self) -> list[str]: + return list(self._nodes) + + @property + @classmethod + def under_construction(cls) -> "Workflow[ty.Any]": + if cls._under_construction is None: + raise ValueError( + "pydra.design.workflow.this() can only be called from within a workflow " + "constructor function (see 'pydra.design.workflow.define')" + ) + return cls._under_construction + + # Used to store the workflow that is currently being constructed + _under_construction: "Workflow[ty.Any]" = None + # Used to cache the constructed 
workflows by their hashed input values + _constructed: dict[int, "Workflow[ty.Any]"] = {} diff --git a/pydra/engine/workflow/lazy.py b/pydra/engine/workflow/lazy.py new file mode 100644 index 0000000000..f9d7bbddbb --- /dev/null +++ b/pydra/engine/workflow/lazy.py @@ -0,0 +1,250 @@ +import typing as ty +from typing_extensions import Self +import attrs +from pydra.utils.typing import StateArray +from . import node + +if ty.TYPE_CHECKING: + from .base import Workflow + + +T = ty.TypeVar("T") + +TypeOrAny = ty.Union[type, ty.Any] +Splitter = ty.Union[str, ty.Tuple[str, ...]] + + +@attrs.define(auto_attribs=True, kw_only=True) +class LazyField(ty.Generic[T]): + """Lazy fields implement promises.""" + + node: node.Node + field: str + type: TypeOrAny + # Set of splitters that have been applied to the lazy field. Note that the splitter + # specifications are transformed to a tuple[tuple[str, ...], ...] form where the + # outer tuple is the outer product, the inner tuple are inner products (where either + # product can be of length==1) + splits: ty.FrozenSet[ty.Tuple[ty.Tuple[str, ...], ...]] = attrs.field( + factory=frozenset, converter=frozenset + ) + cast_from: ty.Optional[ty.Type[ty.Any]] = None + # type_checked will be set to False after it is created but defaults to True here for + # ease of testing + type_checked: bool = True + + def __bytes_repr__(self, cache): + yield type(self).__name__.encode() + yield self.name.encode() + yield self.field.encode() + + def cast(self, new_type: TypeOrAny) -> Self: + """ "casts" the lazy field to a new type + + Parameters + ---------- + new_type : type + the type to cast the lazy-field to + + Returns + ------- + cast_field : LazyField + a copy of the lazy field with the new type + """ + return type(self)[new_type]( + name=self.name, + field=self.field, + type=new_type, + splits=self.splits, + cast_from=self.cast_from if self.cast_from else self.type, + ) + + # def split(self, splitter: Splitter) -> Self: + # """ "Splits" the lazy field over an array of nodes by replacing the sequence type + # of the lazy field with StateArray to signify that it will be "split" across + + # Parameters + # ---------- + # splitter : str or ty.Tuple[str, ...] or ty.List[str] + # the splitter to append to the list of splitters + # """ + # from pydra.utils.typing import ( + # TypeParser, + # ) # pylint: disable=import-outside-toplevel + + # splits = self.splits | set([LazyField.normalize_splitter(splitter)]) + # # Check to see whether the field has already been split over the given splitter + # if splits == self.splits: + # return self + + # # Modify the type of the lazy field to include the split across a state-array + # inner_type, prev_split_depth = TypeParser.strip_splits(self.type) + # assert prev_split_depth <= 1 + # if inner_type is ty.Any: + # type_ = StateArray[ty.Any] + # elif TypeParser.matches_type(inner_type, list): + # item_type = TypeParser.get_item_type(inner_type) + # type_ = StateArray[item_type] + # else: + # raise TypeError( + # f"Cannot split non-sequence field {self} of type {inner_type}" + # ) + # if prev_split_depth: + # type_ = StateArray[type_] + # return type(self)[type_]( + # name=self.name, + # field=self.field, + # type=type_, + # splits=splits, + # ) + + # # def combine(self, combiner: str | list[str]) -> Self: + + # @classmethod + # def normalize_splitter( + # cls, splitter: Splitter, strip_previous: bool = True + # ) -> ty.Tuple[ty.Tuple[str, ...], ...]: + # """Converts the splitter spec into a consistent tuple[tuple[str, ...], ...] 
form + # used in LazyFields""" + # if isinstance(splitter, str): + # splitter = (splitter,) + # if isinstance(splitter, tuple): + # splitter = (splitter,) # type: ignore + # else: + # assert isinstance(splitter, list) + # # convert to frozenset to differentiate from tuple, yet still be hashable + # # (NB: order of fields in list splitters aren't relevant) + # splitter = tuple((s,) if isinstance(s, str) else s for s in splitter) + # # Strip out fields starting with "_" designating splits in upstream nodes + # if strip_previous: + # stripped = tuple( + # tuple(f for f in i if not f.startswith("_")) for i in splitter + # ) + # splitter = tuple(s for s in stripped if s) # type: ignore + # return splitter # type: ignore + + def _apply_cast(self, value): + """\"Casts\" the value from the retrieved type if a cast has been applied to + the lazy-field""" + from pydra.utils.typing import TypeParser + + if self.cast_from: + assert TypeParser.matches(value, self.cast_from) + value = self.type(value) + return value + + +@attrs.define(auto_attribs=True, kw_only=True) +class LazyInField(LazyField[T]): + + attr_type = "input" + + def __eq__(self, other): + return ( + isinstance(other, LazyInField) + and self.field == other.field + and self.type == other.type + and self.splits == other.splits + ) + + def get_value(self, wf: "Workflow", state_index: ty.Optional[int] = None) -> ty.Any: + """Return the value of a lazy field. + + Parameters + ---------- + wf : Workflow + the workflow the lazy field references + state_index : int, optional + the state index of the field to access + + Returns + ------- + value : Any + the resolved value of the lazy-field + """ + from pydra.utils.typing import ( + TypeParser, + ) # pylint: disable=import-outside-toplevel + + value = getattr(wf.inputs, self.field) + if TypeParser.is_subclass(self.type, StateArray) and not wf._pre_split: + _, split_depth = TypeParser.strip_splits(self.type) + + def apply_splits(obj, depth): + if depth < 1: + return obj + return StateArray[self.type](apply_splits(i, depth - 1) for i in obj) + + value = apply_splits(value, split_depth) + value = self._apply_cast(value) + return value + + +class LazyOutField(LazyField[T]): + attr_type = "output" + + def get_value(self, wf: "Workflow", state_index: ty.Optional[int] = None) -> ty.Any: + """Return the value of a lazy field. + + Parameters + ---------- + wf : Workflow + the workflow the lazy field references + state_index : int, optional + the state index of the field to access + + Returns + ------- + value : Any + the resolved value of the lazy-field + """ + from pydra.utils.typing import ( + TypeParser, + ) # pylint: disable=import-outside-toplevel + + node = getattr(wf, self.name) + result = node.result(state_index=state_index) + if result is None: + raise RuntimeError( + f"Could not find results of '{node.name}' node in a sub-directory " + f"named '{node.checksum}' in any of the cache locations.\n" + + "\n".join(str(p) for p in set(node.cache_locations)) + + f"\n\nThis is likely due to hash changes in '{self.name}' node inputs. " + f"Current values and hashes: {node.inputs}, " + f"{node.inputs._hashes}\n\n" + "Set loglevel to 'debug' in order to track hash changes " + "throughout the execution of the workflow.\n\n " + "These issues may have been caused by `bytes_repr()` methods " + "that don't return stable hash values for specific object " + "types across multiple processes (see bytes_repr() " + '"singledispatch "function in pydra/utils/hash.py).' 
+ "You may need to write specific `bytes_repr()` " + "implementations (see `pydra.utils.hash.register_serializer`) or a " + "`__bytes_repr__()` dunder methods to handle one or more types in " + "your interface inputs." + ) + _, split_depth = TypeParser.strip_splits(self.type) + + def get_nested_results(res, depth: int): + if isinstance(res, list): + if not depth: + val = [r.get_output_field(self.field) for r in res] + else: + val = StateArray[self.type]( + get_nested_results(res=r, depth=depth - 1) for r in res + ) + else: + if res.errored: + raise ValueError( + f"Cannot retrieve value for {self.field} from {self.name} as " + "the node errored" + ) + val = res.get_output_field(self.field) + if depth and not wf._pre_split: + assert isinstance(val, ty.Sequence) and not isinstance(val, str) + val = StateArray[self.type](val) + return val + + value = get_nested_results(result, depth=split_depth) + value = self._apply_cast(value) + return value diff --git a/pydra/engine/workflow.py b/pydra/engine/workflow/node.py similarity index 83% rename from pydra/engine/workflow.py rename to pydra/engine/workflow/node.py index 475c60d338..373410b30d 100644 --- a/pydra/engine/workflow.py +++ b/pydra/engine/workflow/node.py @@ -3,13 +3,13 @@ from operator import itemgetter from typing_extensions import Self import attrs -from pydra.design.base import list_fields, TaskSpec, OutputsSpec -from pydra.engine.specs import LazyField, LazyInField, LazyOutField, StateArray from pydra.utils.hash import hash_function -from pydra.utils.typing import TypeParser -from . import helpers_state as hlpst -from .helpers import ensure_list -from . import state +from pydra.utils.typing import TypeParser, StateArray +from . import lazy +from ..specs import TaskSpec, OutputsSpec +from .. import helpers_state as hlpst +from ..helpers import ensure_list, list_fields +from .. import state OutputType = ty.TypeVar("OutputType", bound=OutputsSpec) @@ -28,22 +28,61 @@ class Node(ty.Generic[OutputType]): """ name: str - inputs: TaskSpec[OutputType] - _workflow: "Workflow" = None - _lzout: OutputType | None = None - _state: state.State | None = None - _cont_dim: dict[str, int] | None = ( - None # QUESTION: should this be included in the state? + _spec: TaskSpec[OutputType] + _workflow: "Workflow" = attrs.field(default=None, eq=False, hash=False) + _lzout: OutputType | None = attrs.field( + init=False, default=None, eq=False, hash=False ) + _state: state.State | None = attrs.field(init=False, default=None) + _cont_dim: dict[str, int] | None = attrs.field( + init=False, default=None + ) # QUESTION: should this be included in the state? 
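
# The nested Inputs class added just below wraps the node's TaskSpec so that
# attribute reads pass straight through, while assigning a *lazy* value is
# refused once the node already has state (i.e. after split()/combine()),
# because that would silently change downstream state. A self-contained sketch
# of the same guard pattern with stand-in names (not pydra's actual classes):

class LazyValue:  # stand-in for a lazy-field connection
    pass


class InputsProxy:
    def __init__(self, spec, is_locked):
        object.__setattr__(self, "_spec", spec)
        object.__setattr__(self, "_is_locked", is_locked)

    def __getattr__(self, name):
        # reads are simply forwarded to the wrapped spec
        return getattr(self._spec, name)

    def __setattr__(self, name, value):
        # block new lazy connections once downstream state is fixed
        if isinstance(value, LazyValue) and self._is_locked():
            raise AttributeError(
                "Cannot set inputs on a node that has been split or combined"
            )
        setattr(self._spec, name, value)
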
+
+ class Inputs:
+ """A wrapper that controls access to a node's inputs, so that lazy fields which
+ would change the downstream state cannot be set after the node has been split or
+ combined, or its outputs have been accessed
+ """
+
+ _node: "Node"
+
+ def __init__(self, node: "Node") -> None:
+ super().__setattr__("_node", node)
+
+ def __getattr__(self, name: str) -> ty.Any:
+ return getattr(self._node._spec, name)
+
+ def __setattr__(self, name: str, value: ty.Any) -> None:
+ if isinstance(value, lazy.LazyField):
+ if self._node.state:
+
+ raise AttributeError(
+ "Cannot set inputs on a node that has been split or combined"
+ )
+ setattr(self._node._spec, name, value)
+
+ @property
+ def state(self):
+ return self._state
+
+ @property
+ def inputs(self) -> Inputs:
+ return self.Inputs(self)
+
+ @property
+ def input_values(self) -> tuple[tuple[str, ty.Any], ...]:
+ return tuple(attrs.asdict(self._spec, recurse=False).items())

@property
def lzout(self) -> OutputType:
"""The output spec of the node populated with lazy fields"""
+ from pydra.engine.helpers import list_fields
+
if self._lzout is not None:
return self._lzout
combined_splitter = set()
- for inpt_name, inpt_val in attrs.asdict(self.inputs, recurse=False).items():
- if isinstance(inpt_val, LazyField):
+ for inpt_name, inpt_val in self.input_values:
+ if isinstance(inpt_val, lazy.LazyField):
combined_splitter.update(inpt_val.splits)
lazy_fields = {}
for field in list_fields(self.inputs.Outputs):
@@ -52,8 +91,8 @@ def lzout(self) -> OutputType:
# over state values
for _ in range(len(combined_splitter)):
type_ = StateArray[type_]
- lazy_fields[field.name] = LazyOutField(
- name=self.name,
+ lazy_fields[field.name] = lazy.LazyOutField(
+ node=self,
field=field.name,
type=type_,
splits=frozenset(iter(combined_splitter)),
@@ -99,7 +138,7 @@ def split(
self : TaskSpec
a reference to the task
"""
- self._check_if_outputs_have_been_used()
+ self._check_if_outputs_have_been_used("the node cannot be split or combined")
if splitter is None and inputs:
splitter = list(inputs)
elif splitter:
@@ -129,7 +168,7 @@ def split(
for inpt_name, inpt_val in inputs.items():
new_val: ty.Any
if f"{self.name}.{inpt_name}" in split_inputs: # type: ignore
- if isinstance(inpt_val, LazyField):
+ if isinstance(inpt_val, lazy.LazyField):
new_val = inpt_val.split(splitter)
elif isinstance(inpt_val, ty.Iterable) and not isinstance(
inpt_val, (ty.Mapping, str)
@@ -143,12 +182,12 @@ def split(
new_val = inpt_val
new_inputs[inpt_name] = new_val
# Update the inputs with the new split values
- self.inputs = attrs.evolve(self.inputs, **new_inputs)
+ self._spec = attrs.evolve(self._spec, **new_inputs)
if not self._state or splitter != self._state.splitter:
self._set_state(splitter)
# Wrap types of lazy outputs in StateArray types
- split_depth = len(LazyField.normalize_splitter(splitter))
- outpt_lf: LazyOutField
+ split_depth = len(lazy.LazyField.normalize_splitter(splitter))
+ outpt_lf: lazy.LazyOutField
for outpt_lf in attrs.asdict(self.lzout, recurse=False).values():
assert not outpt_lf.type_checked
outpt_type = outpt_lf.type
@@ -185,12 +224,10 @@ def combine(
if not isinstance(combiner, (str, list)):
raise Exception("combiner has to be a string or a list")
combiner = hlpst.add_name_combiner(ensure_list(combiner), self.name)
- if not_split := [
- c for c in combiner if not any(c in s for s in self._state.splitter)
- ]:
+ if not_split := [c for c in combiner if not any(c in s for s in self.splitter)]:
raise ValueError(
f"Combiner fields {not_split} for Node {self.name!r} are 
not in the " - f"splitter fields {self._state.splitter}" + f"splitter fields {self.splitter}" ) if ( self._state @@ -207,16 +244,18 @@ def combine( # a task can have a combiner without a splitter # if is connected to one with a splitter; # self.fut_combiner will be used later as a combiner - self._state.fut_combiner = combiner + self._state.fut_combiner = ( + combiner # QUESTION: why separate combiner and fut_combiner? + ) else: # self.state and not self.state.combiner self._set_state(splitter=self._state.splitter, combiner=combiner) # Wrap types of lazy outputs in StateArray types - norm_splitter = LazyField.normalize_splitter(self._state.splitter) + norm_splitter = lazy.LazyField.normalize_splitter(self._state.splitter) remaining_splits = [ s for s in norm_splitter if not any(c in s for c in combiner) ] combine_depth = len(norm_splitter) - len(remaining_splits) - outpt_lf: LazyOutField + outpt_lf: lazy.LazyOutField for outpt_lf in attrs.asdict(self.lzout, recurse=False).values(): assert not outpt_lf.type_checked outpt_type, split_depth = TypeParser.strip_splits(outpt_lf.type) @@ -270,16 +309,16 @@ def cont_dim(self, cont_dim): @property def splitter(self): if not self._state: - return None + return () return self._state.splitter @property def combiner(self): if not self._state: - return None + return () return self._state.combiner - def _check_if_outputs_have_been_used(self): + def _check_if_outputs_have_been_used(self, msg): used = [] if self._lzout: for outpt_name, outpt_val in attrs.asdict( @@ -289,8 +328,8 @@ def _check_if_outputs_have_been_used(self): used.append(outpt_name) if used: raise RuntimeError( - f"Outputs {used} of {self} have already been accessed and therefore cannot " - "be split or combined" + f"Outputs {used} of {self} have already been accessed and therefore " + + msg ) @@ -356,7 +395,7 @@ def construct( setattr( lazy_spec, lzy_inpt.name, - LazyInField( + lazy.LazyInField( field=lzy_inpt.name, type=lzy_inpt.type, ), @@ -436,5 +475,7 @@ def under_construction(cls) -> "Workflow[ty.Any]": ) return cls._under_construction + # Used to store the workflow that is currently being constructed _under_construction: "Workflow[ty.Any]" = None + # Used to cache the constructed workflows by their hashed input values _constructed: dict[int, "Workflow[ty.Any]"] = {} diff --git a/pydra/utils/typing.py b/pydra/utils/typing.py index c9f1b9b592..decdc81e0f 100644 --- a/pydra/utils/typing.py +++ b/pydra/utils/typing.py @@ -8,14 +8,8 @@ import typing as ty import logging import attr -from pydra.engine.specs import ( - LazyField, - StateArray, - MultiInputObj, - MultiOutputObj, -) from pydra.utils import add_exc_note -from fileformats import field, core +from fileformats import field, core, generic try: from typing import get_origin, get_args @@ -46,6 +40,45 @@ TypeOrAny = ty.Union[type, ty.Any] +# These are special types that are checked for in the construction of input/output specs +# and special converters inserted into the attrs fields. 
+
+
+class MultiInputObj(list, ty.Generic[T]):
+ pass
+
+
+MultiInputFile = MultiInputObj[generic.File]
+
+
+# Since we can't create a NewType from a type union, we add a dummy type to the union
+# so we can detect the MultiOutput in the input/output spec creation
+class MultiOutputType:
+ pass
+
+
+MultiOutputObj = ty.Union[list, object, MultiOutputType]
+MultiOutputFile = ty.Union[generic.File, ty.List[generic.File], MultiOutputType]
+
+OUTPUT_TEMPLATE_TYPES = (
+ Path,
+ ty.List[Path],
+ ty.Union[Path, bool],
+ ty.Union[ty.List[Path], bool],
+ ty.List[ty.List[Path]],
+)
+
+
+class StateArray(ty.List[T]):
+ """An array of values drawn from, or to be split over, multiple nodes of the same
+ task (see TaskBase.split()). Used in type-checking to differentiate between
+ ordinary list types and values destined for multiple nodes
+ """
+
+ def __repr__(self):
+ return f"{type(self).__name__}(" + ", ".join(repr(i) for i in self) + ")"
+
+
class TypeParser(ty.Generic[T]):
"""A callable which can be used as a converter for attrs.fields to check
whether an object or LazyField matches the specified field type, or can be
@@ -159,7 +192,7 @@ def expand_pattern(t):
self.superclass_auto_cast = superclass_auto_cast
self.match_any_of_union = match_any_of_union

- def __call__(self, obj: ty.Any) -> ty.Union[T, LazyField[T]]:
+ def __call__(self, obj: ty.Any) -> T:
"""Attempts to coerce the object to the specified type, unless the value is
a LazyField where the type of the field is just checked instead or an
attrs.NOTHING where it is simply returned.
@@ -180,6 +213,8 @@ def __call__(self, obj: ty.Any) -> ty.Union[T, LazyField[T]]:
if the coercion is not possible, or not specified by the
`coercible`/`not_coercible` parameters, then a TypeError is raised
"""
+ from pydra.engine.workflow.lazy import LazyField
+
coerced: T
if obj is attr.NOTHING:
coerced = attr.NOTHING # type: ignore[assignment]
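
A minimal usage sketch (not part of the diff) of the special types that now live in pydra/utils/typing.py after this refactor; it assumes TypeParser's default sequence-to-sequence coercion rules, and the example values are hypothetical:

import typing as ty
from pydra.utils.typing import MultiInputObj, StateArray, TypeParser

# MultiInputObj is a plain list subclass, so it can be built from any iterable;
# wrapping of single values is left to the attrs converters in the spec-creation code.
paths = MultiInputObj(["a.txt", "b.txt"])
assert isinstance(paths, list)

# StateArray marks "one value per split node" as opposed to an ordinary list;
# the __repr__ defined in the diff above makes the distinction visible.
arr = StateArray[int]([1, 2, 3])
print(arr)  # StateArray(1, 2, 3)

# TypeParser instances are callables used as attrs converters; after this patch
# __call__ is annotated to return the coerced value itself rather than
# ty.Union[T, LazyField[T]], with LazyField handling deferred to a local import.
to_int_list = TypeParser(ty.List[int])
print(to_int_list((1, 2, 3)))  # [1, 2, 3], a tuple coerced to the declared list type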