Skip to content

Commit

Permalink
shell tasks now execute
Browse files Browse the repository at this point in the history
  • Loading branch information
tclose committed Dec 13, 2024
1 parent 000140e commit 5e898af
Show file tree
Hide file tree
Showing 18 changed files with 660 additions and 604 deletions.
28 changes: 20 additions & 8 deletions pydra/design/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import attrs.validators
from attrs.converters import default_if_none
from fileformats.generic import File
from pydra.utils.typing import TypeParser, is_optional, is_fileset_or_union
from pydra.utils.typing import TypeParser, is_optional, is_fileset_or_union, is_type
from pydra.engine.helpers import (
from_list_if_single,
ensure_list,
Expand Down Expand Up @@ -52,11 +52,6 @@ def __bool__(self):
EMPTY = _Empty.EMPTY # To provide a blank placeholder for the default field


def is_type(_, __, val: ty.Any) -> bool:
"""check that the value is a type or generic"""
return inspect.isclass(val) or ty.get_origin(val)


def convert_default_value(value: ty.Any, self_: "Field") -> ty.Any:
"""Ensure the default value has been coerced into the correct type"""
if value is EMPTY:
Expand Down Expand Up @@ -400,6 +395,10 @@ def make_task_spec(

if name is None and klass is not None:
name = klass.__name__
if reserved_names := [n for n in inputs if n in spec_type.RESERVED_FIELD_NAMES]:
raise ValueError(
f"{reserved_names} are reserved and cannot be used for {spec_type} field names"
)
outputs_klass = make_outputs_spec(out_type, outputs, outputs_bases, name)
if klass is None or not issubclass(klass, spec_type):
if name is None:
Expand Down Expand Up @@ -503,7 +502,7 @@ def make_outputs_spec(
outputs_bases = bases + (spec_type,)
if reserved_names := [n for n in outputs if n in spec_type.RESERVED_FIELD_NAMES]:
raise ValueError(
f"{reserved_names} are reserved and cannot be used for output field names"
f"{reserved_names} are reserved and cannot be used for {spec_type} field names"
)
# Add in any fields in base classes that haven't already been converted into attrs
# fields (e.g. stdout, stderr and return_code)
Expand Down Expand Up @@ -585,12 +584,25 @@ def ensure_field_objects(
arg.name = input_name
if not arg.help_string:
arg.help_string = input_helps.get(input_name, "")
else:
elif is_type(arg):
inputs[input_name] = arg_type(
type=arg,
name=input_name,
help_string=input_helps.get(input_name, ""),
)
elif isinstance(arg, dict):
arg_kwds = copy(arg)
if "help_string" not in arg_kwds:
arg_kwds["help_string"] = input_helps.get(input_name, "")
inputs[input_name] = arg_type(
name=input_name,
**arg_kwds,
)
else:
raise ValueError(
f"Input {input_name} must be an instance of {Arg}, a type, or a dictionary "
f" of keyword arguments to pass to {Arg}, not {arg}"
)

for output_name, out in list(outputs.items()):
if isinstance(out, Out):
Expand Down
1 change: 1 addition & 0 deletions pydra/design/shell.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,7 @@ def make(
argstr="",
position=0,
default=executable,
validator=attrs.validators.min_len(1),
help_string=EXECUTABLE_HELP_STRING,
)

Expand Down
19 changes: 14 additions & 5 deletions pydra/design/tests/test_shell.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def test_interface_template():
assert sorted_fields(SampleInterface) == [
shell.arg(
name="executable",
validator=attrs.validators.min_len(1),
default="cp",
type=str | ty.Sequence[str],
position=0,
Expand Down Expand Up @@ -80,6 +81,7 @@ def test_interface_template_w_types_and_path_template_ext():
assert sorted_fields(SampleInterface) == [
shell.arg(
name="executable",
validator=attrs.validators.min_len(1),
default="trim-png",
type=str | ty.Sequence[str],
position=0,
Expand Down Expand Up @@ -119,6 +121,7 @@ def test_interface_template_w_modify():
assert sorted_fields(SampleInterface) == [
shell.arg(
name="executable",
validator=attrs.validators.min_len(1),
default="trim-png",
type=str | ty.Sequence[str],
position=0,
Expand Down Expand Up @@ -176,6 +179,7 @@ def test_interface_template_more_complex():
assert sorted_fields(SampleInterface) == [
shell.arg(
name="executable",
validator=attrs.validators.min_len(1),
default="cp",
type=str | ty.Sequence[str],
position=0,
Expand Down Expand Up @@ -273,6 +277,7 @@ def test_interface_template_with_overrides_and_optionals():
== [
shell.arg(
name="executable",
validator=attrs.validators.min_len(1),
default="cp",
type=str | ty.Sequence[str],
position=0,
Expand Down Expand Up @@ -347,6 +352,7 @@ def test_interface_template_with_defaults():
assert sorted_fields(SampleInterface) == [
shell.arg(
name="executable",
validator=attrs.validators.min_len(1),
default="cp",
type=str | ty.Sequence[str],
position=0,
Expand Down Expand Up @@ -414,6 +420,7 @@ def test_interface_template_with_type_overrides():
assert sorted_fields(SampleInterface) == [
shell.arg(
name="executable",
validator=attrs.validators.min_len(1),
default="cp",
type=str | ty.Sequence[str],
position=0,
Expand Down Expand Up @@ -545,6 +552,7 @@ class Outputs(ShellOutputs):
type=bool,
help_string="Show complete date in long format",
argstr="-T",
default=False,
requires=["long_format"],
xor=["date_format_str"],
),
Expand Down Expand Up @@ -606,7 +614,7 @@ def test_shell_pickle_roundtrip(Ls, tmp_path):
assert RereadLs is Ls


@pytest.mark.xfail(reason="Still need to update tasks to use new shell interface")
# @pytest.mark.xfail(reason="Still need to update tasks to use new shell interface")
def test_shell_run(Ls, tmp_path):
Path.touch(tmp_path / "a")
Path.touch(tmp_path / "b")
Expand All @@ -615,16 +623,16 @@ def test_shell_run(Ls, tmp_path):
ls = Ls(directory=tmp_path, long_format=True)

# Test cmdline
assert ls.inputs.directory == tmp_path
assert not ls.inputs.hidden
assert ls.inputs.long_format
assert ls.directory == Directory(tmp_path)
assert not ls.hidden
assert ls.long_format
assert ls.cmdline == f"ls -l {tmp_path}"

# Drop Long format flag to make output simpler
ls = Ls(directory=tmp_path)
result = ls()

assert result.output.entries == ["a", "b", "c"]
assert sorted(result.output.entries) == ["a", "b", "c"]


@pytest.fixture(params=["static", "dynamic"])
Expand Down Expand Up @@ -721,6 +729,7 @@ class Outputs:
assert sorted_fields(A) == [
shell.arg(
name="executable",
validator=attrs.validators.min_len(1),
default="cp",
type=str | ty.Sequence[str],
argstr="",
Expand Down
8 changes: 1 addition & 7 deletions pydra/design/tests/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,14 +93,8 @@ def MyTestShellWorkflow(
)
output_video = workflow.add(
shell.define(
"HandBrakeCLI -i <in_video> -o <out|out_video> "
"HandBrakeCLI -i <in_video:video/mp4> -o <out|out_video:video/mp4> "
"--width <width:int> --height <height:int>",
# By default any input/output specified with a flag (e.g. -i <in_video>)
# is considered optional, i.e. of type `FsObject | None`, and therefore
# won't be used by default. By overriding this with non-optional types,
# the fields are specified as being required.
inputs={"in_video": video.Mp4},
outputs={"out_video": video.Mp4},
)(in_video=add_watermark.out_video, width=1280, height=720),
name="resize",
).out_video
Expand Down
10 changes: 5 additions & 5 deletions pydra/engine/boutiques.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,21 +182,21 @@ def _command_args_single(self, state_ind=None, index=None):
"""Get command line arguments for a single state"""
input_filepath = self._bosh_invocation_file(state_ind=state_ind, index=index)
cmd_list = (
self.inputs.executable
self.spec.executable
+ [str(self.bosh_file), input_filepath]
+ self.inputs.args
+ self.spec.args
+ self.bindings
)
return cmd_list

def _bosh_invocation_file(self, state_ind=None, index=None):
"""creating bosh invocation file - json file with inputs values"""
input_json = {}
for f in attrs_fields(self.inputs, exclude_names=("executable", "args")):
for f in attrs_fields(self.spec, exclude_names=("executable", "args")):
if self.state and f"{self.name}.{f.name}" in state_ind:
value = getattr(self.inputs, f.name)[state_ind[f"{self.name}.{f.name}"]]
value = getattr(self.spec, f.name)[state_ind[f"{self.name}.{f.name}"]]
else:
value = getattr(self.inputs, f.name)
value = getattr(self.spec, f.name)
# adding to the json file if specified by the user
if value is not attr.NOTHING and value != "NOTHING":
if is_local_file(f):
Expand Down
Loading

0 comments on commit 5e898af

Please sign in to comment.