Skip to content

Commit

Permalink
fixed up lazy out splitting
Browse files Browse the repository at this point in the history
  • Loading branch information
tclose committed Dec 5, 2024
1 parent 03e6951 commit 384e57d
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 27 deletions.
24 changes: 0 additions & 24 deletions pydra/engine/workflow/lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
T = ty.TypeVar("T")

TypeOrAny = ty.Union[type, ty.Any]
Splitter = ty.Union[str, ty.Tuple[str, ...]]


@attrs.define(auto_attribs=True, kw_only=True)
Expand Down Expand Up @@ -100,29 +99,6 @@ def cast(self, new_type: TypeOrAny) -> Self:

# # def combine(self, combiner: str | list[str]) -> Self:

# @classmethod
# def normalize_splitter(
# cls, splitter: Splitter, strip_previous: bool = True
# ) -> ty.Tuple[ty.Tuple[str, ...], ...]:
# """Converts the splitter spec into a consistent tuple[tuple[str, ...], ...] form
# used in LazyFields"""
# if isinstance(splitter, str):
# splitter = (splitter,)
# if isinstance(splitter, tuple):
# splitter = (splitter,) # type: ignore
# else:
# assert isinstance(splitter, list)
# # convert to frozenset to differentiate from tuple, yet still be hashable
# # (NB: order of fields in list splitters aren't relevant)
# splitter = tuple((s,) if isinstance(s, str) else s for s in splitter)
# # Strip out fields starting with "_" designating splits in upstream nodes
# if strip_previous:
# stripped = tuple(
# tuple(f for f in i if not f.startswith("_")) for i in splitter
# )
# splitter = tuple(s for s in stripped if s) # type: ignore
# return splitter # type: ignore

def _apply_cast(self, value):
"""\"Casts\" the value from the retrieved type if a cast has been applied to
the lazy-field"""
Expand Down
29 changes: 26 additions & 3 deletions pydra/engine/workflow/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@


OutputType = ty.TypeVar("OutputType", bound=OutputsSpec)
Splitter = ty.Union[str, ty.Tuple[str, ...]]


@attrs.define
Expand Down Expand Up @@ -186,15 +187,14 @@ def split(
if not self._state or splitter != self._state.splitter:
self._set_state(splitter)
# Wrap types of lazy outputs in StateArray types
split_depth = len(lazy.LazyField.normalize_splitter(splitter))
split_depth = len(self._normalize_splitter(splitter))
outpt_lf: lazy.LazyOutField
for outpt_lf in attrs.asdict(self.lzout, recurse=False).values():
assert not outpt_lf.type_checked
outpt_type = outpt_lf.type
for d in range(split_depth):
outpt_type = StateArray[outpt_type]
outpt_lf.type = outpt_type
outpt_lf.splits = frozenset(iter(self._state.splitter))
return self

def combine(
Expand Down Expand Up @@ -250,7 +250,7 @@ def combine(
else: # self.state and not self.state.combiner
self._set_state(splitter=self._state.splitter, combiner=combiner)
# Wrap types of lazy outputs in StateArray types
norm_splitter = lazy.LazyField.normalize_splitter(self._state.splitter)
norm_splitter = self._normalize_splitter(self._state.splitter)
remaining_splits = [
s for s in norm_splitter if not any(c in s for c in combiner)
]
Expand Down Expand Up @@ -332,6 +332,29 @@ def _check_if_outputs_have_been_used(self, msg):
+ msg
)

@classmethod
def _normalize_splitter(
cls, splitter: Splitter, strip_previous: bool = True
) -> ty.Tuple[ty.Tuple[str, ...], ...]:
"""Converts the splitter spec into a consistent tuple[tuple[str, ...], ...] form
used in LazyFields"""
if isinstance(splitter, str):
splitter = (splitter,)
if isinstance(splitter, tuple):
splitter = (splitter,) # type: ignore
else:
assert isinstance(splitter, list)
# convert to frozenset to differentiate from tuple, yet still be hashable
# (NB: order of fields in list splitters aren't relevant)
splitter = tuple((s,) if isinstance(s, str) else s for s in splitter)
# Strip out fields starting with "_" designating splits in upstream nodes
if strip_previous:
stripped = tuple(
tuple(f for f in i if not f.startswith("_")) for i in splitter
)
splitter = tuple(s for s in stripped if s) # type: ignore
return splitter # type: ignore


@attrs.define(auto_attribs=False)
class Workflow(ty.Generic[OutputType]):
Expand Down

0 comments on commit 384e57d

Please sign in to comment.