Skip to content

Commit

Permalink
debugged setting of state in split and combine
Browse files Browse the repository at this point in the history
  • Loading branch information
tclose committed Dec 10, 2024
1 parent cf7b331 commit c6b7cd7
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 109 deletions.
9 changes: 6 additions & 3 deletions pydra/design/tests/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,9 +369,12 @@ def MyTestWorkflow(a: list[int], b: list[float], c: float) -> list[float]:

wf = Workflow.construct(MyTestWorkflow(a=[1, 2, 3], b=[1.0, 10.0, 100.0], c=2.0))
assert wf["Mul"].splitter == ["Mul.x", "Mul.y"]
assert wf["Mul"].combiner == ["Mul.x"]
assert wf["Add"].lzout.out.splits == frozenset(["Mul.x"])
assert wf.outputs.out == LazyOutField(node=wf["Sum"], field="out", type=list[float])
assert wf["Mul"].combiner == []
assert wf["Add"].splitter == "_Mul"
assert wf["Add"].combiner == ["Mul.x"]
assert wf.outputs.out == LazyOutField(
node=wf["Sum"], field="out", type=list[float], type_checked=True
)


def test_workflow_split_after_access_fail():
Expand Down
25 changes: 22 additions & 3 deletions pydra/engine/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,28 @@ def __str__(self):
)

@property
def depth(self):
"""Return the number of uncombined splits of the state."""
return len(self.states_ind)
def depth(self) -> int:
    """Return the number of uncombined splits of the state.

    This is the number of nested state arrays to wrap around the type of lazy
    out fields.

    The splitter is read in reverse-Polish notation (``self.splitter_rpn``).
    Operands are accumulated until an operator is reached; every operator
    consumes *all* pending operands at once:

    * ``"."`` contributes a single split, and only when none of the pending
      operands is listed in the combiner
    * ``"*"`` contributes one split per pending operand that is not listed in
      the combiner

    Operands still pending after the last operator (e.g. a splitter over a
    single field) contribute one split each, again excluding combined fields.

    Returns
    -------
    int
        number of uncombined splits
    """
    # Treat a missing combiner as "nothing combined"
    combined = self.combiner or ()
    depth = 0
    stack = []
    for spl in self.splitter_rpn:
        if spl == ".":
            # all fields in the group share a single axis, which is removed
            # as soon as any one of them is combined
            depth += int(all(s not in combined for s in stack))
            stack = []
        elif spl == "*":
            # each field has its own axis, so count the ones that survive
            depth += sum(1 for s in stack if s not in combined)
            stack = []
        else:
            stack.append(spl)
    # Leftover operands (no trailing operator) must also honour the combiner,
    # otherwise a fully-combined single-field split is reported as depth 1
    return depth + sum(1 for s in stack if s not in combined)

@property
def splitter(self):
Expand Down
157 changes: 54 additions & 103 deletions pydra/engine/workflow/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from pydra.utils.typing import TypeParser, StateArray
from . import lazy
from ..specs import TaskSpec, Outputs
from ..helpers import ensure_list, attrs_values
from ..helpers import ensure_list, attrs_values, is_lazy
from .. import helpers_state as hlpst
from ..state import State

Expand Down Expand Up @@ -62,17 +62,16 @@ def __getattr__(self, name: str) -> ty.Any:
return getattr(self._node._spec, name)

def __setattr__(self, name: str, value: ty.Any) -> None:
if isinstance(value, lazy.LazyField):
# Save the current state for comparison later
prev_state = self._node.state
if value.node.state:
# Reset the state to allow the lazy field to be set
self._node._state = NOT_SET
setattr(self._node._spec, name, value)
if value.node.state and self._node.state != prev_state:
setattr(self._node._spec, name, value)
if is_lazy(value):
upstream_states = self._node._get_upstream_states()
if (
not self._node._state
or self._node._state.other_states != upstream_states
):
self._node._check_if_outputs_have_been_used(
f"cannot set {name!r} input to {value} because it changes the "
f"state of the node from {prev_state} to {value.node.state}"
f"state"
)

@property
Expand All @@ -86,36 +85,8 @@ def state(self):
"""
if self._state is not NOT_SET:
return self._state
upstream_states = self._upstream_states()
if upstream_states:
state = State(
self.name,
splitter=None,
other_states=upstream_states,
combiner=None,
)
else:
state = None
self._state = state
return state

def _upstream_states(self):
"""Get the states of the upstream nodes that are connected to this node

Iterates over ``self.input_values`` and, for every input that is a
``lazy.LazyOutField`` whose source node has a truthy state, records that
state together with the names of this node's inputs that are fed by it.

Returns
-------
dict
maps the upstream node's name to a tuple of ``(state, input_names)``, where
``input_names`` is the list of this node's inputs connected to that node
"""
upstream_states = {}
for inpt_name, val in self.input_values:
if isinstance(val, lazy.LazyOutField) and val.node.state:
node: Node = val.node
# variables that are part of inner splitters should be treated as containers
# NOTE(review): the key joins the upstream node's name with *this* node's
# input name -- confirm that is the intended lookup into the splitter
if node.state and f"{node.name}.{inpt_name}" in node.state.splitter:
node._inner_cont_dim[f"{node.name}.{inpt_name}"] = 1
# adding task_name: (task.state, [a field from the connection])
if node.name not in upstream_states:
upstream_states[node.name] = (node.state, [inpt_name])
else:
# if the task already exists in upstream_states,
# the additional field name is appended to its list of fields
upstream_states[node.name][1].append(inpt_name)
return upstream_states
self._set_state(other_states=self._get_upstream_states())
return self._state

@property
def input_values(self) -> tuple[tuple[str, ty.Any]]:
Expand Down Expand Up @@ -222,8 +193,7 @@ def split(
new_inputs[inpt_name] = new_val
# Update the inputs with the new split values
self._spec = attrs.evolve(self._spec, **new_inputs)
if not self._state or splitter != self._state.splitter:
self._set_state(splitter)
self._set_state(splitter=splitter)
# Wrap types of lazy outputs in StateArray types
self._wrap_lzout_types_in_state_arrays()
return self
Expand Down Expand Up @@ -272,39 +242,10 @@ def combine(
"combiner has been already set, "
"if you want to overwrite it - use overwrite=True"
)
if not self._state:
self.split(splitter=None)
# a task can have a combiner without a splitter
# if is connected to one with a splitter;
# self.fut_combiner will be used later as a combiner
self._state.fut_combiner = (
combiner # QUESTION: why separate combiner and fut_combiner?
)
else: # self.state and not self.state.combiner
self._set_state(splitter=self._state.splitter, combiner=combiner)
self._set_state(combiner=combiner)
self._wrap_lzout_types_in_state_arrays()
return self

def _set_state(self, splitter, combiner=None):
"""
Set a particular state on this task.
Parameters
----------
splitter : str | list[str] | tuple[str]
the fields which to split over. If splitting over multiple fields, lists of
fields are interpreted as outer-products and tuples inner-products. If None,
the state of the node is cleared (set to None) and the combiner is ignored.
combiner : list[str] | str, optional
the field or list of inputs to be combined (i.e. not left split) after the
task has been run
Returns
-------
State or None
the newly assigned state, or None when no splitter was given
"""
# A state is only created when there is something to split over
if splitter is not None:
self._state = State(name=self.name, splitter=splitter, combiner=combiner)
else:
# no splitter: drop any state, the combiner argument is not used
self._state = None
return self._state

@property
def cont_dim(self):
# adding inner_cont_dim to the general container_dimension provided by the users
Expand Down Expand Up @@ -353,40 +294,50 @@ def _wrap_lzout_types_in_state_arrays(self) -> None:
if not self.state:
return
outpt_lf: lazy.LazyOutField
remaining_splits = []
for split in self.state.splitter:
if isinstance(split, str):
if split not in self.state.combiner:
remaining_splits.append(split)
elif all(s not in self.state.combiner for s in split):
remaining_splits.append(split)
state_depth = len(remaining_splits)
for outpt_lf in attrs_values(self.lzout).values():
assert not outpt_lf.type_checked
type_, _ = TypeParser.strip_splits(outpt_lf.type)
for _ in range(state_depth):
for _ in range(self._state.depth):
type_ = StateArray[type_]
outpt_lf.type = type_

# @classmethod
# def _normalize_splitter(
# cls, splitter: Splitter, strip_previous: bool = True
# ) -> ty.Tuple[ty.Tuple[str, ...], ...]:
# """Converts the splitter spec into a consistent tuple[tuple[str, ...], ...] form
# used in LazyFields"""
# if isinstance(splitter, str):
# splitter = (splitter,)
# if isinstance(splitter, tuple):
# splitter = (splitter,) # type: ignore
# else:
# assert isinstance(splitter, list)
# # convert to frozenset to differentiate from tuple, yet still be hashable
# # (NB: order of fields in list splitters aren't relevant)
# splitter = tuple((s,) if isinstance(s, str) else s for s in splitter)
# # Strip out fields starting with "_" designating splits in upstream nodes
# if strip_previous:
# stripped = tuple(
# tuple(f for f in i if not f.startswith("_")) for i in splitter
# )
# splitter = tuple(s for s in stripped if s) # type: ignore
# return splitter # type: ignore
def _set_state(
    self,
    splitter: list[str] | tuple[str, ...] | None = None,
    combiner: list[str] | None = None,
    other_states: dict[str, tuple["State", list[str]]] | None = None,
) -> None:
    """Rebuild the state of the node from the given components.

    Any component passed as None is taken from the existing state (its
    ``current_splitter``, ``current_combiner`` and ``other_states``), provided
    a state has already been set on the node.

    Parameters
    ----------
    splitter : list[str] | tuple[str, ...], optional
        the splitter to use for the new state
    combiner : list[str], optional
        the combiner to use for the new state
    other_states : dict[str, tuple[State, list[str]]], optional
        the upstream states to use for the new state
    """
    existing = self._state
    if existing not in (NOT_SET, None):
        # Inherit whatever was not explicitly provided from the current state
        splitter = existing.current_splitter if splitter is None else splitter
        combiner = existing.current_combiner if combiner is None else combiner
        other_states = (
            existing.other_states if other_states is None else other_states
        )
    if splitter or combiner or other_states:
        self._state = State(
            self.name,
            splitter=splitter,
            other_states=other_states,
            combiner=combiner,
        )
    else:
        # Nothing to split over, combine, or inherit: the node is stateless
        self._state = None

def _get_upstream_states(self) -> dict[str, tuple["State", list[str]]]:
    """Collect the states of the upstream nodes feeding into this node.

    Returns
    -------
    dict[str, tuple[State, list[str]]]
        maps the name of each upstream node that has a state to a tuple of
        that state and the list of this node's inputs connected to it
    """
    collected = {}
    for field_name, field_val in self.input_values:
        # Only lazy outputs of nodes that carry a state are of interest
        if not isinstance(field_val, lazy.LazyOutField):
            continue
        if not field_val.node.state:
            continue
        upstream: Node = field_val.node
        # variables that are part of inner splitters are treated as containers
        if (
            upstream.state
            and f"{upstream.name}.{field_name}" in upstream.state.splitter
        ):
            upstream._inner_cont_dim[f"{upstream.name}.{field_name}"] = 1
        entry = collected.get(upstream.name)
        if entry is None:
            # first connection to this upstream node: (state, [field name])
            collected[upstream.name] = (upstream.state, [field_name])
        else:
            # node already recorded: just register the additional field
            entry[1].append(field_name)
    return collected

0 comments on commit c6b7cd7

Please sign in to comment.