Skip to content

Commit

Permalink
Improve output file codegen
Browse files Browse the repository at this point in the history
  • Loading branch information
nx10 committed May 20, 2024
1 parent 9af78f3 commit 37decee
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 22 deletions.
2 changes: 1 addition & 1 deletion src/styx/compiler/compile/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def _generate_run_function(
generate_command_line_args_building(command.input_command_line_template, symbols, func, inputs)

# Outputs static definition
generate_outputs_definition(module, symbols, outputs)
generate_outputs_definition(module, symbols, outputs, inputs)
# Outputs building code
generate_output_building(func, scopes, symbols, outputs, inputs)

Expand Down
82 changes: 61 additions & 21 deletions src/styx/compiler/compile/outputs.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,27 @@
from styx.compiler.compile.common import SharedScopes, SharedSymbols
from styx.compiler.compile.inputs import codegen_var_is_set_by_user
from styx.model.core import InputArgument, InputTypePrimitive, OutputArgument, WithSymbol
from styx.pycodegen.core import PyFunc, PyModule, indent
from styx.pycodegen.utils import as_py_literal, enbrace, enquote


def _find_output_dependencies(
output: WithSymbol[OutputArgument],
inputs: list[WithSymbol[InputArgument]],
) -> list[WithSymbol[InputArgument]]:
"""Find the input dependencies for an output."""
dependencies = []
for input_ in inputs:
if input_.data.template_key in output.data.path_template:
dependencies.append(input_)
return dependencies


def generate_outputs_definition(
module: PyModule,
symbols: SharedSymbols,
outputs: list[WithSymbol[OutputArgument]],
inputs: list[WithSymbol[InputArgument]],
) -> None:
"""Generate the static output class definition."""
module.header.extend([
Expand All @@ -23,10 +37,16 @@ def generate_outputs_definition(
]),
])
for out in outputs:
deps = _find_output_dependencies(out, inputs)
if any([input_.data.type.is_optional for input_ in deps]):
out_type = "OutputPathType | None"
else:
out_type = "OutputPathType"

# Declaration
module.header.extend(
indent([
f"{out.symbol}: OutputPathType",
f"{out.symbol}: {out_type}",
f'"""{out.data.doc}"""',
])
)
Expand Down Expand Up @@ -61,33 +81,53 @@ def generate_output_building(

for out in outputs:
strip_extensions = out.data.stripped_file_extensions is not None
s_optional = ", optional=True" if out.data.optional else ""
if out.data.path_template is not None:
s = out.data.path_template
for input_ in inputs:
if input_.data.template_key not in s:
continue
path_template = out.data.path_template

input_dependencies = _find_output_dependencies(out, inputs)

if len(input_dependencies) == 0:
# No substitutions needed
func.body.extend(
indent([f"{out.symbol}={symbols.execution}.output_file(f{enquote(path_template)}{s_optional}),"])
)
else:
for input_ in input_dependencies:
substitute = input_.symbol

if input_.data.type.is_list:
raise Exception(f"Output path template replacements cannot be lists. ({input_.data.name})")

substitute = input_.symbol
if input_.data.type.primitive == InputTypePrimitive.File:
# Just use the stem of the file
# This is commonly used when output files 'inherit' the name of an input file
substitute = f"pathlib.Path({substitute}).stem"
elif (input_.data.type.primitive == InputTypePrimitive.Number) or (
input_.data.type.primitive == InputTypePrimitive.Integer
):
# Convert to string
substitute = f"str({substitute})"
elif input_.data.type.primitive != InputTypePrimitive.String:
raise Exception(
f"Unsupported input type {input_.data.type.primitive} "
f"for output path template of '{out.data.name}'."
)

if input_.data.type.primitive == InputTypePrimitive.File:
# Just use the stem of the file
# This is commonly used when output files 'inherit' the name of an input file
substitute = f"pathlib.Path({substitute}).stem"
elif input_.data.type.primitive != InputTypePrimitive.String:
raise Exception(
f"Unsupported input type {input_.data.type.primitive} "
f"for output path template of '{out.data.name}'."
)
if strip_extensions:
exts = as_py_literal(out.data.stripped_file_extensions, "'")
substitute = f"{py_rstrip_fun}({substitute}, {exts})"

if strip_extensions:
exts = as_py_literal(out.data.stripped_file_extensions, "'")
substitute = f"{py_rstrip_fun}({substitute}, {exts})"
path_template = path_template.replace(input_.data.template_key, enbrace(substitute))

s = s.replace(input_.data.template_key, enbrace(substitute))
resolved_output = f"{symbols.execution}.output_file(f{enquote(path_template)}{s_optional})"

s_optional = ", optional=True" if out.data.optional else ""
if any([input_.data.type.is_optional for input_ in input_dependencies]):
# Codegen: Condition: Is any variable in the segment set by the user?
condition = [codegen_var_is_set_by_user(i) for i in input_dependencies]
resolved_output = f"{resolved_output} if {' and '.join(condition)} else None"

func.body.extend(indent([f"{out.symbol}={symbols.execution}.output_file(f{enquote(s)}{s_optional}),"]))
func.body.extend(indent([f"{out.symbol}={resolved_output},"]))
else:
raise NotImplementedError

Expand Down

0 comments on commit 37decee

Please sign in to comment.