Skip to content

Commit

Permalink
Implement sub-command outputs
Browse files Browse the repository at this point in the history
  • Loading branch information
nx10 committed May 30, 2024
1 parent ea88b86 commit 0f47d4a
Show file tree
Hide file tree
Showing 4 changed files with 178 additions and 48 deletions.
14 changes: 10 additions & 4 deletions src/styx/compiler/compile/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from styx.compiler.compile.definitions import generate_definitions
from styx.compiler.compile.inputs import build_input_arguments, generate_command_line_args_building
from styx.compiler.compile.metadata import generate_static_metadata
from styx.compiler.compile.outputs import generate_output_building, generate_outputs_definition
from styx.compiler.compile.outputs import generate_output_building, generate_outputs_class
from styx.compiler.compile.subcommand import generate_sub_command_classes
from styx.compiler.settings import CompilerSettings
from styx.model.core import Descriptor, InputArgument, OutputArgument, SubCommand, WithSymbol
Expand All @@ -25,7 +25,9 @@ def _generate_run_function(
outputs: list[WithSymbol[OutputArgument]],
) -> None:
# Sub-command classes
sub_aliases, _ = generate_sub_command_classes(module, symbols, command, scopes.module)
sub_aliases, sub_sub_command_class_aliases, _ = generate_sub_command_classes(
module, symbols, command, scopes.module
)

# Function
func = PyFunc(
Expand Down Expand Up @@ -58,9 +60,13 @@ def _generate_run_function(
generate_command_line_args_building(command.input_command_line_template, symbols, func, inputs)

# Outputs static definition
generate_outputs_definition(module, symbols, outputs, inputs)
generate_outputs_class(
module, symbols.output_class, symbols.function, outputs, inputs, sub_sub_command_class_aliases
)
# Outputs building code
generate_output_building(func, scopes, symbols, outputs, inputs)
generate_output_building(
func, scopes.function, symbols.execution, symbols.output_class, symbols.ret, outputs, inputs
)

# Function body: Run and return
func.body.extend([
Expand Down
76 changes: 62 additions & 14 deletions src/styx/compiler/compile/outputs.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from styx.compiler.compile.common import SharedScopes, SharedSymbols
from styx.compiler.compile.inputs import codegen_var_is_set_by_user
from styx.model.core import InputArgument, InputTypePrimitive, OutputArgument, WithSymbol
from styx.pycodegen.core import PyFunc, PyModule, indent
from styx.pycodegen.scope import Scope
from styx.pycodegen.utils import as_py_literal, enbrace, enquote


Expand All @@ -17,20 +17,22 @@ def _find_output_dependencies(
return dependencies


def generate_outputs_definition(
def generate_outputs_class(
module: PyModule,
symbols: SharedSymbols,
symbol_output_class: str,
symbol_parent_function: str,
outputs: list[WithSymbol[OutputArgument]],
inputs: list[WithSymbol[InputArgument]],
sub_command_output_class_aliases: dict[str, str],
) -> None:
"""Generate the static output class definition."""
module.header.extend([
"",
"",
f"class {symbols.output_class}(typing.NamedTuple):",
f"class {symbol_output_class}(typing.NamedTuple):",
*indent([
'"""',
f"Output object returned when calling `{symbols.function}(...)`.",
f"Output object returned when calling `{symbol_parent_function}(...)`.",
'"""',
"root: OutputPathType",
'"""Output root folder. This is the root folder for all outputs."""',
Expand All @@ -51,16 +53,48 @@ def generate_outputs_definition(
])
)

for input_ in inputs:
if input_.data.type.primitive == InputTypePrimitive.SubCommand:
assert input_.data.sub_command is not None
module.header.extend(
indent([
f"{input_.symbol}: {sub_command_output_class_aliases[input_.data.sub_command.internal_id]}",
'"""Subcommand outputs"""',
])
)

if input_.data.type.primitive == InputTypePrimitive.SubCommandUnion:
assert input_.data.sub_command_union is not None

sub_commands = [
sub_command_output_class_aliases[sub_command.internal_id]
for sub_command in input_.data.sub_command_union
]
sub_commands_type = ", ".join(sub_commands)
sub_commands_type = f"typing.Union[{sub_commands_type}]"

if input_.data.type.is_list:
sub_commands_type = f"typing.List[{sub_commands_type}]"

module.header.extend(
indent([
f"{input_.symbol}: {sub_commands_type}",
'"""Subcommand outputs"""',
])
)


def generate_output_building(
func: PyFunc,
scopes: SharedScopes,
symbols: SharedSymbols,
func_scope: Scope,
symbol_execution: str,
symbol_output_class: str,
symbol_return_var: str,
outputs: list[WithSymbol[OutputArgument]],
inputs: list[WithSymbol[InputArgument]],
) -> None:
"""Generate the output building code."""
py_rstrip_fun = scopes.function.add_or_dodge("_rstrip")
py_rstrip_fun = func_scope.add_or_dodge("_rstrip")
if any([out.data.stripped_file_extensions is not None for out in outputs]):
func.body.extend([
f"def {py_rstrip_fun}(s, r):",
Expand All @@ -74,10 +108,10 @@ def generate_output_building(
]),
])

func.body.append(f"{symbols.ret} = {symbols.output_class}(")
func.body.append(f"{symbol_return_var} = {symbol_output_class}(")

# Set root output path
func.body.extend(indent([f'root={symbols.execution}.output_file("."),']))
func.body.extend(indent([f'root={symbol_execution}.output_file("."),']))

for out in outputs:
strip_extensions = out.data.stripped_file_extensions is not None
Expand All @@ -90,7 +124,7 @@ def generate_output_building(
if len(input_dependencies) == 0:
# No substitutions needed
func.body.extend(
indent([f"{out.symbol}={symbols.execution}.output_file(f{enquote(path_template)}{s_optional}),"])
indent([f"{out.symbol}={symbol_execution}.output_file(f{enquote(path_template)}{s_optional}),"])
)
else:
for input_ in input_dependencies:
Expand All @@ -100,9 +134,9 @@ def generate_output_building(
raise Exception(f"Output path template replacements cannot be lists. ({input_.data.name})")

if input_.data.type.primitive == InputTypePrimitive.File:
# Just use the stem of the file
# Just use the name of the file
# This is commonly used when output files 'inherit' the name of an input file
substitute = f"pathlib.Path({substitute}).stem"
substitute = f"pathlib.Path({substitute}).name"
elif (input_.data.type.primitive == InputTypePrimitive.Number) or (
input_.data.type.primitive == InputTypePrimitive.Integer
):
Expand All @@ -120,7 +154,7 @@ def generate_output_building(

path_template = path_template.replace(input_.data.template_key, enbrace(substitute))

resolved_output = f"{symbols.execution}.output_file(f{enquote(path_template)}{s_optional})"
resolved_output = f"{symbol_execution}.output_file(f{enquote(path_template)}{s_optional})"

if any([input_.data.type.is_optional for input_ in input_dependencies]):
# Codegen: Condition: Is any variable in the segment set by the user?
Expand All @@ -131,4 +165,18 @@ def generate_output_building(
else:
raise NotImplementedError

for input_ in inputs:
if (input_.data.type.primitive == InputTypePrimitive.SubCommand) or (
input_.data.type.primitive == InputTypePrimitive.SubCommandUnion
):
if input_.data.type.is_list:
func.body.extend(
indent([
f"{input_.symbol}="
f"[{input_.symbol}.outputs({symbol_execution}) for {input_.symbol} in {input_.symbol}],"
])
)
else:
func.body.extend(indent([f"{input_.symbol}={input_.symbol}.outputs({symbol_execution}),"]))

func.body.extend([")"])
119 changes: 94 additions & 25 deletions src/styx/compiler/compile/subcommand.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,36 @@
from styx.compiler.compile.common import SharedSymbols
from styx.compiler.compile.constraints import generate_constraint_checks
from styx.compiler.compile.inputs import build_input_arguments, generate_command_line_args_building
from styx.model.core import InputArgument, InputTypePrimitive, SubCommand, WithSymbol
from styx.compiler.compile.outputs import generate_output_building, generate_outputs_class
from styx.model.core import InputArgument, InputTypePrimitive, OutputArgument, SubCommand, WithSymbol
from styx.pycodegen.core import PyArg, PyDataClass, PyFunc, PyModule, blank_before
from styx.pycodegen.scope import Scope
from styx.pycodegen.utils import python_pascalize, python_snakify


def _sub_command_class_name(parent_name: str, sub_command: SubCommand) -> str:
def _sub_command_class_name(sub_command: SubCommand) -> str:
"""Return the name of the sub-command class."""
return python_pascalize(f"{parent_name}_{sub_command.name}")
return python_pascalize(f"{sub_command.name}")


def _sub_command_output_class_name(sub_command: SubCommand) -> str:
"""Return the name of the sub-command output class."""
return python_pascalize(f"{sub_command.name}_Outputs")


def _generate_sub_command(
module: PyModule,
scope_module: Scope,
symbols: SharedSymbols,
sub_command: SubCommand,
outputs: list[WithSymbol[OutputArgument]],
inputs: list[WithSymbol[InputArgument]],
aliases: dict[str, str],
) -> str:
sub_command_output_class_aliases: dict[str, str],
) -> tuple[str, str]:
"""Generate the static output class definition."""
class_name = _sub_command_class_name(symbols.function, sub_command)
class_name = scope_module.add_or_dodge(_sub_command_class_name(sub_command))
output_class_name = scope_module.add_or_dodge(_sub_command_output_class_name(sub_command))

module.exports.append(class_name)
sub_command_class = PyDataClass(
Expand Down Expand Up @@ -53,54 +63,113 @@ def _generate_sub_command(
])
sub_command_class.methods.append(run_method)

# Outputs method

outputs_method = PyFunc(
name="outputs",
docstring_body="Collect output file paths.",
return_type=output_class_name,
return_descr=f"NamedTuple of outputs (described in `{output_class_name}`).",
args=[
PyArg(name="self", type=None, default=None, docstring="The sub-command object."),
PyArg(name="execution", type="Execution", default=None, docstring="The execution object."),
],
body=[],
)
generate_outputs_class(
module,
output_class_name,
class_name + ".run",
outputs,
inputs_self,
sub_command_output_class_aliases,
)
module.exports.append(output_class_name)
generate_output_building(outputs_method, Scope(), symbols.execution, output_class_name, "ret", outputs, inputs_self)
outputs_method.body.extend(["return ret"])
sub_command_class.methods.append(outputs_method)

module.header.extend(blank_before(sub_command_class.generate(), 2))
if "import dataclasses" not in module.imports:
module.imports.append("import dataclasses")

return class_name
return class_name, output_class_name


def generate_sub_command_classes(
module: PyModule,
symbols: SharedSymbols,
command: SubCommand,
scope: Scope,
) -> tuple[dict[str, str], list[WithSymbol[InputArgument]]]:
scope_module: Scope,
) -> tuple[dict[str, str], dict[str, str], list[WithSymbol[InputArgument]]]:
"""Build Python function arguments from input arguments."""
# internal_id -> class_name
aliases: dict[str, str] = {}
# subcommand.internal_id -> subcommand.outputs() class name
sub_command_output_class_aliases: dict[str, str] = {}

inputs_scope = Scope(parent=scope)
# outputs_scope = Scope(parent=scope)
inputs_scope = Scope(parent=scope_module)
outputs_scope = Scope(parent=scope_module)

# Input symbols
inputs: list[WithSymbol[InputArgument]] = []
for i in command.inputs:
py_symbol = inputs_scope.add_or_dodge(python_snakify(i.name))
inputs.append(WithSymbol(i, py_symbol))

# Output symbols
# outputs: list[WithSymbol[OutputArgument]] = []
# for output in command.outputs:
# py_symbol = outputs_scope.add_or_dodge(python_snakify(output.name))
# outputs.append(WithSymbol(output, py_symbol))

for input_ in inputs:
if input_.data.type.primitive == InputTypePrimitive.SubCommand:
assert input_.data.sub_command is not None
sub_command = input_.data.sub_command
sub_aliases, sub_inputs = generate_sub_command_classes(module, symbols, sub_command, inputs_scope)
sub_aliases, sub_sub_command_output_class_aliases, sub_inputs = generate_sub_command_classes(
module, symbols, sub_command, inputs_scope
)
aliases.update(sub_aliases)
sub_command_type = _generate_sub_command(module, symbols, sub_command, sub_inputs, aliases)
if sub_command_type is not None:
aliases[sub_command.internal_id] = sub_command_type
sub_command_output_class_aliases.update(sub_sub_command_output_class_aliases)

sub_outputs = []
for output in sub_command.outputs:
py_symbol = outputs_scope.add_or_dodge(python_snakify(output.name))
sub_outputs.append(WithSymbol(output, py_symbol))

sub_command_type, sub_command_output_type = _generate_sub_command(
module,
scope_module,
symbols,
sub_command,
sub_outputs,
sub_inputs,
aliases,
sub_command_output_class_aliases,
)
aliases[sub_command.internal_id] = sub_command_type
sub_command_output_class_aliases[sub_command.internal_id] = sub_command_output_type

if input_.data.type.primitive == InputTypePrimitive.SubCommandUnion:
assert input_.data.sub_command_union is not None
for sub_command in input_.data.sub_command_union:
sub_aliases, sub_inputs = generate_sub_command_classes(module, symbols, sub_command, inputs_scope)
sub_aliases, sub_sub_command_output_class_aliases, sub_inputs = generate_sub_command_classes(
module, symbols, sub_command, inputs_scope
)
aliases.update(sub_aliases)
sub_command_type = _generate_sub_command(module, symbols, sub_command, sub_inputs, aliases)
if sub_command_type is not None:
aliases[sub_command.internal_id] = sub_command_type
sub_command_output_class_aliases.update(sub_sub_command_output_class_aliases)

sub_outputs = []
for output in sub_command.outputs:
py_symbol = outputs_scope.add_or_dodge(python_snakify(output.name))
sub_outputs.append(WithSymbol(output, py_symbol))

sub_command_type, sub_command_output_type = _generate_sub_command(
module,
scope_module,
symbols,
sub_command,
sub_outputs,
sub_inputs,
aliases,
sub_command_output_class_aliases,
)
aliases[sub_command.internal_id] = sub_command_type
sub_command_output_class_aliases[sub_command.internal_id] = sub_command_output_type

return aliases, inputs
return aliases, sub_command_output_class_aliases, inputs
Loading

0 comments on commit 0f47d4a

Please sign in to comment.