From 37decee7e6de401407fb1bd8d4be23c71058d0ce Mon Sep 17 00:00:00 2001 From: Florian Rupprecht Date: Mon, 20 May 2024 15:14:31 -0400 Subject: [PATCH] Improve output file codegen --- src/styx/compiler/compile/descriptor.py | 2 +- src/styx/compiler/compile/outputs.py | 82 ++++++++++++++++++------- 2 files changed, 62 insertions(+), 22 deletions(-) diff --git a/src/styx/compiler/compile/descriptor.py b/src/styx/compiler/compile/descriptor.py index 4665830..2e64371 100644 --- a/src/styx/compiler/compile/descriptor.py +++ b/src/styx/compiler/compile/descriptor.py @@ -53,7 +53,7 @@ def _generate_run_function( generate_command_line_args_building(command.input_command_line_template, symbols, func, inputs) # Outputs static definition - generate_outputs_definition(module, symbols, outputs) + generate_outputs_definition(module, symbols, outputs, inputs) # Outputs building code generate_output_building(func, scopes, symbols, outputs, inputs) diff --git a/src/styx/compiler/compile/outputs.py b/src/styx/compiler/compile/outputs.py index 1e1cf0f..3dd61cf 100644 --- a/src/styx/compiler/compile/outputs.py +++ b/src/styx/compiler/compile/outputs.py @@ -1,13 +1,27 @@ from styx.compiler.compile.common import SharedScopes, SharedSymbols +from styx.compiler.compile.inputs import codegen_var_is_set_by_user from styx.model.core import InputArgument, InputTypePrimitive, OutputArgument, WithSymbol from styx.pycodegen.core import PyFunc, PyModule, indent from styx.pycodegen.utils import as_py_literal, enbrace, enquote +def _find_output_dependencies( + output: WithSymbol[OutputArgument], + inputs: list[WithSymbol[InputArgument]], +) -> list[WithSymbol[InputArgument]]: + """Find the input dependencies for an output.""" + dependencies = [] + for input_ in inputs: + if input_.data.template_key in output.data.path_template: + dependencies.append(input_) + return dependencies + + def generate_outputs_definition( module: PyModule, symbols: SharedSymbols, outputs: list[WithSymbol[OutputArgument]], + inputs: list[WithSymbol[InputArgument]], ) -> None: """Generate the static output class definition.""" module.header.extend([ @@ -23,10 +37,16 @@ def generate_outputs_definition( ]), ]) for out in outputs: + deps = _find_output_dependencies(out, inputs) + if any([input_.data.type.is_optional for input_ in deps]): + out_type = "OutputPathType | None" + else: + out_type = "OutputPathType" + # Declaration module.header.extend( indent([ - f"{out.symbol}: OutputPathType", + f"{out.symbol}: {out_type}", f'"""{out.data.doc}"""', ]) ) @@ -61,33 +81,53 @@ def generate_output_building( for out in outputs: strip_extensions = out.data.stripped_file_extensions is not None + s_optional = ", optional=True" if out.data.optional else "" if out.data.path_template is not None: - s = out.data.path_template - for input_ in inputs: - if input_.data.template_key not in s: - continue + path_template = out.data.path_template + + input_dependencies = _find_output_dependencies(out, inputs) + + if len(input_dependencies) == 0: + # No substitutions needed + func.body.extend( + indent([f"{out.symbol}={symbols.execution}.output_file(f{enquote(path_template)}{s_optional}),"]) + ) + else: + for input_ in input_dependencies: + substitute = input_.symbol + + if input_.data.type.is_list: + raise Exception(f"Output path template replacements cannot be lists. ({input_.data.name})") - substitute = input_.symbol + if input_.data.type.primitive == InputTypePrimitive.File: + # Just use the stem of the file + # This is commonly used when output files 'inherit' the name of an input file + substitute = f"pathlib.Path({substitute}).stem" + elif (input_.data.type.primitive == InputTypePrimitive.Number) or ( + input_.data.type.primitive == InputTypePrimitive.Integer + ): + # Convert to string + substitute = f"str({substitute})" + elif input_.data.type.primitive != InputTypePrimitive.String: + raise Exception( + f"Unsupported input type {input_.data.type.primitive} " + f"for output path template of '{out.data.name}'." + ) - if input_.data.type.primitive == InputTypePrimitive.File: - # Just use the stem of the file - # This is commonly used when output files 'inherit' the name of an input file - substitute = f"pathlib.Path({substitute}).stem" - elif input_.data.type.primitive != InputTypePrimitive.String: - raise Exception( - f"Unsupported input type {input_.data.type.primitive} " - f"for output path template of '{out.data.name}'." - ) + if strip_extensions: + exts = as_py_literal(out.data.stripped_file_extensions, "'") + substitute = f"{py_rstrip_fun}({substitute}, {exts})" - if strip_extensions: - exts = as_py_literal(out.data.stripped_file_extensions, "'") - substitute = f"{py_rstrip_fun}({substitute}, {exts})" + path_template = path_template.replace(input_.data.template_key, enbrace(substitute)) - s = s.replace(input_.data.template_key, enbrace(substitute)) + resolved_output = f"{symbols.execution}.output_file(f{enquote(path_template)}{s_optional})" - s_optional = ", optional=True" if out.data.optional else "" + if any([input_.data.type.is_optional for input_ in input_dependencies]): + # Codegen: Condition: Is any variable in the segment set by the user? + condition = [codegen_var_is_set_by_user(i) for i in input_dependencies] + resolved_output = f"{resolved_output} if {' and '.join(condition)} else None" - func.body.extend(indent([f"{out.symbol}={symbols.execution}.output_file(f{enquote(s)}{s_optional}),"])) + func.body.extend(indent([f"{out.symbol}={resolved_output},"])) else: raise NotImplementedError