From 0f47d4af675060fa525d309733d2a27f14c24455 Mon Sep 17 00:00:00 2001 From: Florian Rupprecht Date: Thu, 30 May 2024 17:41:17 -0400 Subject: [PATCH] Implement sub-command outputs --- src/styx/compiler/compile/descriptor.py | 14 ++- src/styx/compiler/compile/outputs.py | 76 ++++++++++++--- src/styx/compiler/compile/subcommand.py | 119 +++++++++++++++++++----- src/styx/pycodegen/core.py | 17 +++- 4 files changed, 178 insertions(+), 48 deletions(-) diff --git a/src/styx/compiler/compile/descriptor.py b/src/styx/compiler/compile/descriptor.py index 5c7d80a..81cebed 100644 --- a/src/styx/compiler/compile/descriptor.py +++ b/src/styx/compiler/compile/descriptor.py @@ -3,7 +3,7 @@ from styx.compiler.compile.definitions import generate_definitions from styx.compiler.compile.inputs import build_input_arguments, generate_command_line_args_building from styx.compiler.compile.metadata import generate_static_metadata -from styx.compiler.compile.outputs import generate_output_building, generate_outputs_definition +from styx.compiler.compile.outputs import generate_output_building, generate_outputs_class from styx.compiler.compile.subcommand import generate_sub_command_classes from styx.compiler.settings import CompilerSettings from styx.model.core import Descriptor, InputArgument, OutputArgument, SubCommand, WithSymbol @@ -25,7 +25,9 @@ def _generate_run_function( outputs: list[WithSymbol[OutputArgument]], ) -> None: # Sub-command classes - sub_aliases, _ = generate_sub_command_classes(module, symbols, command, scopes.module) + sub_aliases, sub_sub_command_class_aliases, _ = generate_sub_command_classes( + module, symbols, command, scopes.module + ) # Function func = PyFunc( @@ -58,9 +60,13 @@ def _generate_run_function( generate_command_line_args_building(command.input_command_line_template, symbols, func, inputs) # Outputs static definition - generate_outputs_definition(module, symbols, outputs, inputs) + generate_outputs_class( + module, symbols.output_class, symbols.function, outputs, inputs, sub_sub_command_class_aliases + ) # Outputs building code - generate_output_building(func, scopes, symbols, outputs, inputs) + generate_output_building( + func, scopes.function, symbols.execution, symbols.output_class, symbols.ret, outputs, inputs + ) # Function body: Run and return func.body.extend([ diff --git a/src/styx/compiler/compile/outputs.py b/src/styx/compiler/compile/outputs.py index 3dd61cf..ab84219 100644 --- a/src/styx/compiler/compile/outputs.py +++ b/src/styx/compiler/compile/outputs.py @@ -1,7 +1,7 @@ -from styx.compiler.compile.common import SharedScopes, SharedSymbols from styx.compiler.compile.inputs import codegen_var_is_set_by_user from styx.model.core import InputArgument, InputTypePrimitive, OutputArgument, WithSymbol from styx.pycodegen.core import PyFunc, PyModule, indent +from styx.pycodegen.scope import Scope from styx.pycodegen.utils import as_py_literal, enbrace, enquote @@ -17,20 +17,22 @@ def _find_output_dependencies( return dependencies -def generate_outputs_definition( +def generate_outputs_class( module: PyModule, - symbols: SharedSymbols, + symbol_output_class: str, + symbol_parent_function: str, outputs: list[WithSymbol[OutputArgument]], inputs: list[WithSymbol[InputArgument]], + sub_command_output_class_aliases: dict[str, str], ) -> None: """Generate the static output class definition.""" module.header.extend([ "", "", - f"class {symbols.output_class}(typing.NamedTuple):", + f"class {symbol_output_class}(typing.NamedTuple):", *indent([ '"""', - f"Output object returned when calling `{symbols.function}(...)`.", + f"Output object returned when calling `{symbol_parent_function}(...)`.", '"""', "root: OutputPathType", '"""Output root folder. This is the root folder for all outputs."""', @@ -51,16 +53,48 @@ def generate_outputs_definition( ]) ) + for input_ in inputs: + if input_.data.type.primitive == InputTypePrimitive.SubCommand: + assert input_.data.sub_command is not None + module.header.extend( + indent([ + f"{input_.symbol}: {sub_command_output_class_aliases[input_.data.sub_command.internal_id]}", + '"""Subcommand outputs"""', + ]) + ) + + if input_.data.type.primitive == InputTypePrimitive.SubCommandUnion: + assert input_.data.sub_command_union is not None + + sub_commands = [ + sub_command_output_class_aliases[sub_command.internal_id] + for sub_command in input_.data.sub_command_union + ] + sub_commands_type = ", ".join(sub_commands) + sub_commands_type = f"typing.Union[{sub_commands_type}]" + + if input_.data.type.is_list: + sub_commands_type = f"typing.List[{sub_commands_type}]" + + module.header.extend( + indent([ + f"{input_.symbol}: {sub_commands_type}", + '"""Subcommand outputs"""', + ]) + ) + def generate_output_building( func: PyFunc, - scopes: SharedScopes, - symbols: SharedSymbols, + func_scope: Scope, + symbol_execution: str, + symbol_output_class: str, + symbol_return_var: str, outputs: list[WithSymbol[OutputArgument]], inputs: list[WithSymbol[InputArgument]], ) -> None: """Generate the output building code.""" - py_rstrip_fun = scopes.function.add_or_dodge("_rstrip") + py_rstrip_fun = func_scope.add_or_dodge("_rstrip") if any([out.data.stripped_file_extensions is not None for out in outputs]): func.body.extend([ f"def {py_rstrip_fun}(s, r):", @@ -74,10 +108,10 @@ def generate_output_building( ]), ]) - func.body.append(f"{symbols.ret} = {symbols.output_class}(") + func.body.append(f"{symbol_return_var} = {symbol_output_class}(") # Set root output path - func.body.extend(indent([f'root={symbols.execution}.output_file("."),'])) + func.body.extend(indent([f'root={symbol_execution}.output_file("."),'])) for out in outputs: strip_extensions = out.data.stripped_file_extensions is not None @@ -90,7 +124,7 @@ def generate_output_building( if len(input_dependencies) == 0: # No substitutions needed func.body.extend( - indent([f"{out.symbol}={symbols.execution}.output_file(f{enquote(path_template)}{s_optional}),"]) + indent([f"{out.symbol}={symbol_execution}.output_file(f{enquote(path_template)}{s_optional}),"]) ) else: for input_ in input_dependencies: @@ -100,9 +134,9 @@ def generate_output_building( raise Exception(f"Output path template replacements cannot be lists. ({input_.data.name})") if input_.data.type.primitive == InputTypePrimitive.File: - # Just use the stem of the file + # Just use the name of the file # This is commonly used when output files 'inherit' the name of an input file - substitute = f"pathlib.Path({substitute}).stem" + substitute = f"pathlib.Path({substitute}).name" elif (input_.data.type.primitive == InputTypePrimitive.Number) or ( input_.data.type.primitive == InputTypePrimitive.Integer ): @@ -120,7 +154,7 @@ def generate_output_building( path_template = path_template.replace(input_.data.template_key, enbrace(substitute)) - resolved_output = f"{symbols.execution}.output_file(f{enquote(path_template)}{s_optional})" + resolved_output = f"{symbol_execution}.output_file(f{enquote(path_template)}{s_optional})" if any([input_.data.type.is_optional for input_ in input_dependencies]): # Codegen: Condition: Is any variable in the segment set by the user? @@ -131,4 +165,18 @@ def generate_output_building( else: raise NotImplementedError + for input_ in inputs: + if (input_.data.type.primitive == InputTypePrimitive.SubCommand) or ( + input_.data.type.primitive == InputTypePrimitive.SubCommandUnion + ): + if input_.data.type.is_list: + func.body.extend( + indent([ + f"{input_.symbol}=" + f"[{input_.symbol}.outputs({symbol_execution}) for {input_.symbol} in {input_.symbol}]," + ]) + ) + else: + func.body.extend(indent([f"{input_.symbol}={input_.symbol}.outputs({symbol_execution}),"])) + func.body.extend([")"]) diff --git a/src/styx/compiler/compile/subcommand.py b/src/styx/compiler/compile/subcommand.py index 91b69ac..350f1fb 100644 --- a/src/styx/compiler/compile/subcommand.py +++ b/src/styx/compiler/compile/subcommand.py @@ -1,26 +1,36 @@ from styx.compiler.compile.common import SharedSymbols from styx.compiler.compile.constraints import generate_constraint_checks from styx.compiler.compile.inputs import build_input_arguments, generate_command_line_args_building -from styx.model.core import InputArgument, InputTypePrimitive, SubCommand, WithSymbol +from styx.compiler.compile.outputs import generate_output_building, generate_outputs_class +from styx.model.core import InputArgument, InputTypePrimitive, OutputArgument, SubCommand, WithSymbol from styx.pycodegen.core import PyArg, PyDataClass, PyFunc, PyModule, blank_before from styx.pycodegen.scope import Scope from styx.pycodegen.utils import python_pascalize, python_snakify -def _sub_command_class_name(parent_name: str, sub_command: SubCommand) -> str: +def _sub_command_class_name(sub_command: SubCommand) -> str: """Return the name of the sub-command class.""" - return python_pascalize(f"{parent_name}_{sub_command.name}") + return python_pascalize(f"{sub_command.name}") + + +def _sub_command_output_class_name(sub_command: SubCommand) -> str: + """Return the name of the sub-command output class.""" + return python_pascalize(f"{sub_command.name}_Outputs") def _generate_sub_command( module: PyModule, + scope_module: Scope, symbols: SharedSymbols, sub_command: SubCommand, + outputs: list[WithSymbol[OutputArgument]], inputs: list[WithSymbol[InputArgument]], aliases: dict[str, str], -) -> str: + sub_command_output_class_aliases: dict[str, str], +) -> tuple[str, str]: """Generate the static output class definition.""" - class_name = _sub_command_class_name(symbols.function, sub_command) + class_name = scope_module.add_or_dodge(_sub_command_class_name(sub_command)) + output_class_name = scope_module.add_or_dodge(_sub_command_output_class_name(sub_command)) module.exports.append(class_name) sub_command_class = PyDataClass( @@ -53,24 +63,53 @@ def _generate_sub_command( ]) sub_command_class.methods.append(run_method) + # Outputs method + + outputs_method = PyFunc( + name="outputs", + docstring_body="Collect output file paths.", + return_type=output_class_name, + return_descr=f"NamedTuple of outputs (described in `{output_class_name}`).", + args=[ + PyArg(name="self", type=None, default=None, docstring="The sub-command object."), + PyArg(name="execution", type="Execution", default=None, docstring="The execution object."), + ], + body=[], + ) + generate_outputs_class( + module, + output_class_name, + class_name + ".run", + outputs, + inputs_self, + sub_command_output_class_aliases, + ) + module.exports.append(output_class_name) + generate_output_building(outputs_method, Scope(), symbols.execution, output_class_name, "ret", outputs, inputs_self) + outputs_method.body.extend(["return ret"]) + sub_command_class.methods.append(outputs_method) + module.header.extend(blank_before(sub_command_class.generate(), 2)) if "import dataclasses" not in module.imports: module.imports.append("import dataclasses") - return class_name + return class_name, output_class_name def generate_sub_command_classes( module: PyModule, symbols: SharedSymbols, command: SubCommand, - scope: Scope, -) -> tuple[dict[str, str], list[WithSymbol[InputArgument]]]: + scope_module: Scope, +) -> tuple[dict[str, str], dict[str, str], list[WithSymbol[InputArgument]]]: """Build Python function arguments from input arguments.""" + # internal_id -> class_name aliases: dict[str, str] = {} + # subcommand.internal_id -> subcommand.outputs() class name + sub_command_output_class_aliases: dict[str, str] = {} - inputs_scope = Scope(parent=scope) - # outputs_scope = Scope(parent=scope) + inputs_scope = Scope(parent=scope_module) + outputs_scope = Scope(parent=scope_module) # Input symbols inputs: list[WithSymbol[InputArgument]] = [] @@ -78,29 +117,59 @@ def generate_sub_command_classes( py_symbol = inputs_scope.add_or_dodge(python_snakify(i.name)) inputs.append(WithSymbol(i, py_symbol)) - # Output symbols - # outputs: list[WithSymbol[OutputArgument]] = [] - # for output in command.outputs: - # py_symbol = outputs_scope.add_or_dodge(python_snakify(output.name)) - # outputs.append(WithSymbol(output, py_symbol)) - for input_ in inputs: if input_.data.type.primitive == InputTypePrimitive.SubCommand: assert input_.data.sub_command is not None sub_command = input_.data.sub_command - sub_aliases, sub_inputs = generate_sub_command_classes(module, symbols, sub_command, inputs_scope) + sub_aliases, sub_sub_command_output_class_aliases, sub_inputs = generate_sub_command_classes( + module, symbols, sub_command, inputs_scope + ) aliases.update(sub_aliases) - sub_command_type = _generate_sub_command(module, symbols, sub_command, sub_inputs, aliases) - if sub_command_type is not None: - aliases[sub_command.internal_id] = sub_command_type + sub_command_output_class_aliases.update(sub_sub_command_output_class_aliases) + + sub_outputs = [] + for output in sub_command.outputs: + py_symbol = outputs_scope.add_or_dodge(python_snakify(output.name)) + sub_outputs.append(WithSymbol(output, py_symbol)) + + sub_command_type, sub_command_output_type = _generate_sub_command( + module, + scope_module, + symbols, + sub_command, + sub_outputs, + sub_inputs, + aliases, + sub_command_output_class_aliases, + ) + aliases[sub_command.internal_id] = sub_command_type + sub_command_output_class_aliases[sub_command.internal_id] = sub_command_output_type if input_.data.type.primitive == InputTypePrimitive.SubCommandUnion: assert input_.data.sub_command_union is not None for sub_command in input_.data.sub_command_union: - sub_aliases, sub_inputs = generate_sub_command_classes(module, symbols, sub_command, inputs_scope) + sub_aliases, sub_sub_command_output_class_aliases, sub_inputs = generate_sub_command_classes( + module, symbols, sub_command, inputs_scope + ) aliases.update(sub_aliases) - sub_command_type = _generate_sub_command(module, symbols, sub_command, sub_inputs, aliases) - if sub_command_type is not None: - aliases[sub_command.internal_id] = sub_command_type + sub_command_output_class_aliases.update(sub_sub_command_output_class_aliases) + + sub_outputs = [] + for output in sub_command.outputs: + py_symbol = outputs_scope.add_or_dodge(python_snakify(output.name)) + sub_outputs.append(WithSymbol(output, py_symbol)) + + sub_command_type, sub_command_output_type = _generate_sub_command( + module, + scope_module, + symbols, + sub_command, + sub_outputs, + sub_inputs, + aliases, + sub_command_output_class_aliases, + ) + aliases[sub_command.internal_id] = sub_command_type + sub_command_output_class_aliases[sub_command.internal_id] = sub_command_output_type - return aliases, inputs + return aliases, sub_command_output_class_aliases, inputs diff --git a/src/styx/pycodegen/core.py b/src/styx/pycodegen/core.py index 25d4fa6..99fe797 100644 --- a/src/styx/pycodegen/core.py +++ b/src/styx/pycodegen/core.py @@ -31,9 +31,16 @@ def expand(text: str) -> LineBuffer: return text.splitlines() -def concat(line_buffers: list[LineBuffer]) -> LineBuffer: +def concat(line_buffers: list[LineBuffer], separator: LineBuffer | None = None) -> LineBuffer: """Concatenate multiple LineBuffers.""" - return [line for buf in line_buffers for line in buf] + if separator is None: + return sum(line_buffers, []) + ret = [] + for i, buf in enumerate(line_buffers): + if i > 0: + ret.extend(separator) + ret.extend(buf) + return ret def blank_before(lines: LineBuffer, blanks: int = 1) -> LineBuffer: @@ -142,7 +149,7 @@ def _arg_docstring(arg: PyArg) -> LineBuffer: return linebreak_paragraph(f'"""{arg.docstring}"""', width=80 - 4, first_line_width=80 - 4) args = concat([[f.declaration(), *_arg_docstring(f)] for f in self.fields]) - methods = concat([method.generate() for method in self.methods]) + methods = concat([method.generate() for method in self.methods], [""]) buf = [ "@dataclasses.dataclass", @@ -181,8 +188,8 @@ def generate(self) -> LineBuffer: return blank_after([ *comment([ - "This file was auto generated by styx", - "Do not edit this file directly", + "This file was auto generated by Styx.", + "Do not edit this file directly.", ]), *blank_before(self.imports), *blank_before(self.header),