From 68a36ad54e802cbe79bdf5569e36e6d7d3602e44 Mon Sep 17 00:00:00 2001 From: Florian Rupprecht Date: Wed, 25 Sep 2024 21:50:48 -0400 Subject: [PATCH] Improve & document IR API --- src/styx/backend/python/metadata.py | 2 +- src/styx/backend/python/pycodegen/core.py | 4 +- src/styx/backend/python/utils.py | 6 +- src/styx/frontend/boutiques/core.py | 7 +- src/styx/ir/core.py | 177 +++++++++++++++++----- 5 files changed, 149 insertions(+), 47 deletions(-) diff --git a/src/styx/backend/python/metadata.py b/src/styx/backend/python/metadata.py index 82af799..73c4a67 100644 --- a/src/styx/backend/python/metadata.py +++ b/src/styx/backend/python/metadata.py @@ -12,7 +12,7 @@ def generate_static_metadata( """Generate the static metadata.""" metadata_symbol = scope.add_or_dodge(f"{python_screaming_snakify(interface.command.base.name)}_METADATA") - entries = { + entries: dict = { "id": interface.uid, "name": interface.command.base.name, "package": interface.package.name, diff --git a/src/styx/backend/python/pycodegen/core.py b/src/styx/backend/python/pycodegen/core.py index d533591..755356c 100644 --- a/src/styx/backend/python/pycodegen/core.py +++ b/src/styx/backend/python/pycodegen/core.py @@ -123,7 +123,7 @@ def generate(self) -> LineBuffer: if self.docstring_body: docstring_linebroken = linebreak_paragraph(self.docstring_body, width=80 - 4) else: - docstring_linebroken = "" + docstring_linebroken = [""] buf.extend( indent([ @@ -150,7 +150,7 @@ class PyDataClass(PyGen): """Python generate.""" name: str - docstring: str + docstring: str | None fields: list[PyArg] = field(default_factory=list) methods: list[PyFunc] = field(default_factory=list) is_named_tuple: bool = False diff --git a/src/styx/backend/python/utils.py b/src/styx/backend/python/utils.py index 13b47ea..fddd06d 100644 --- a/src/styx/backend/python/utils.py +++ b/src/styx/backend/python/utils.py @@ -78,8 +78,8 @@ def _val() -> tuple[str, bool]: if isinstance(param.body, ir.Param.Bool): as_list = (len(param.body.value_true) > 1) or (len(param.body.value_false) > 1) if as_list: - value_true = param.body.value_true - value_false = param.body.value_false + value_true: str | list[str] | None = param.body.value_true + value_false: str | list[str] | None = param.body.value_false else: value_true = param.body.value_true[0] if len(param.body.value_true) > 0 else None value_false = param.body.value_false[0] if len(param.body.value_false) > 0 else None @@ -136,7 +136,7 @@ def param_py_default_value(param: ir.Param) -> str | None: return "None" if param.default_value is None: return None - return as_py_literal(param.default_value) + return as_py_literal(param.default_value) # type: ignore def param_py_var_is_set_by_user( diff --git a/src/styx/frontend/boutiques/core.py b/src/styx/frontend/boutiques/core.py index 6eb1320..b4838ba 100644 --- a/src/styx/frontend/boutiques/core.py +++ b/src/styx/frontend/boutiques/core.py @@ -196,6 +196,8 @@ def _arg_elem_from_bt_elem( assert choices is None or all([ isinstance(o, int) for o in choices ]), "value-choices must be all int for integer input" + assert constraints.value_min is None or isinstance(constraints.value_min, int) + assert constraints.value_max is None or isinstance(constraints.value_max, int) return ir.Param( base=dparam, @@ -229,7 +231,7 @@ def _arg_elem_from_bt_elem( return ir.Param( base=dparam, body=ir.Param.File( - resolve_parent=d.get("resolve-parent"), + resolve_parent=d.get("resolve-parent") is True, ), list_=dlist, nullable=input_type.is_optional, @@ -240,7 +242,6 @@ def _arg_elem_from_bt_elem( input_prefix = d.get("command-line-flag") assert input_prefix is not None, "Flag type input must have command-line-flag" - dparam.prefix = [] return ir.Param( base=dparam, body=ir.Param.Bool( @@ -390,7 +391,7 @@ def _collect_outputs(bt: dict, ir_id_lookup: dict[str, ir.IdType], id_counter: I for bt_output in bt.get("output-files", []): path_template = bt_output["path-template"] destructed = destruct_template(path_template, ir_id_lookup) - output_sequence = [ + output_sequence: list[str, ir.OutputParamReference] = [ # type: ignore # mypy is wrong ir.OutputParamReference( ref_id=x, file_remove_suffixes=bt_output.get("path-template-stripped-extensions", []), diff --git a/src/styx/ir/core.py b/src/styx/ir/core.py index 0193213..bf409a5 100644 --- a/src/styx/ir/core.py +++ b/src/styx/ir/core.py @@ -1,6 +1,8 @@ +from __future__ import annotations + import dataclasses from dataclasses import dataclass -from typing import Any, Generator, Generic, Optional, TypeVar, Union +from typing import Any, Generator, Generic, Optional, TypeGuard, TypeVar, Union @dataclass @@ -67,7 +69,7 @@ class Output: tokens: list[str | OutputParamReference] = dataclasses.field(default_factory=list) """List of tokens and/or parameter references. This is similar to a Python f-string.""" - docs: Documentation | None = None + docs: Documentation = dataclasses.field(default_factory=Documentation) """Documentation for the output.""" @@ -90,7 +92,7 @@ class Base: outputs: list[Output] = dataclasses.field(default_factory=list) """List of outputs associated with this parameter.""" - docs: Documentation | None = None + docs: Documentation = dataclasses.field(default_factory=Documentation) """Documentation for the parameter.""" @dataclass @@ -108,6 +110,7 @@ class List: class SetToNone: """Represents a parameter that can be set to None.""" + pass @dataclass @@ -143,6 +146,7 @@ class Float: @dataclass class String: """Represents string parameters.""" + pass @dataclass @@ -159,15 +163,14 @@ class Struct: name: str | None = None """Name of the struct.""" - groups: list["ConditionalGroup"] = dataclasses.field(default_factory=list) + groups: list[ConditionalGroup] = dataclasses.field(default_factory=list) """List of conditional groups.""" docs: Documentation | None = None """Documentation for the struct.""" - def iter_params(self) -> Generator["Param", Any, None]: - """ - Iterate over all parameters in the struct. + def iter_params(self) -> Generator[Param, Any, None]: + """Iterate over all parameters in the struct. Yields: Each parameter in the struct. @@ -179,22 +182,21 @@ def iter_params(self) -> Generator["Param", Any, None]: class StructUnion: """Represents a union of struct parameters.""" - alts: list["Param[Param.Struct]"] = dataclasses.field(default_factory=list) + alts: list[Param[Param.Struct]] = dataclasses.field(default_factory=list) """List of alternative struct parameters.""" def __init__( - self, - base: Base, - body: Union[Bool, Int, Float, String, File, Struct, StructUnion], - list_: Optional[List] = None, - nullable: bool = False, - choices: Optional[list[Union[str, int, float]]] = None, - default_value: Union[ - bool, str, int, float, list[bool], list[str], list[int], list[float], SetToNone, None - ] = None, + self, + base: Base, + body: T, + list_: Optional[List] = None, + nullable: bool = False, + choices: Optional[list[Union[str, int, float]]] = None, + default_value: Union[ + bool, str, int, float, list[bool], list[str], list[int], list[float], type[SetToNone], None + ] = None, ) -> None: - """ - Initialize a Param instance. + """Initialize a Param instance. Args: base: Base parameter information. @@ -232,8 +234,7 @@ def _check_base(self) -> None: def _check_body_type(self) -> None: """Check if body is an instance of one of the allowed types.""" if not isinstance( - self.body, - (Param.Bool, Param.Int, Param.Float, Param.String, Param.File, Param.Struct, Param.StructUnion) + self.body, (Param.Bool, Param.Int, Param.Float, Param.String, Param.File, Param.Struct, Param.StructUnion) ): raise TypeError( "body must be an instance of " @@ -257,7 +258,7 @@ def _check_choices(self) -> None: raise TypeError("choices must be None or a list") expected_type = self._get_expected_type() if expected_type is not None and not all(isinstance(choice, expected_type) for choice in self.choices): - raise TypeError(f"All choices must be of type {expected_type.__name__}") + raise TypeError(f"All choices must be of type {' or '.join([e.__name__ for e in expected_type])}") def _check_default_value(self) -> None: """Check if default_value is of the correct type.""" @@ -275,12 +276,12 @@ def _check_default_value(self) -> None: if not isinstance(self.default_value, list): raise TypeError("default_value must be a list when list_ is provided") if not all(isinstance(item, expected_type) for item in self.default_value): - raise TypeError(f"All items in default_value must be of type {expected_type.__name__}") + raise TypeError( + f"All items in default_value must be of type {' or '.join([e.__name__ for e in expected_type])}" + ) else: if not isinstance(self.default_value, expected_type): - if isinstance(expected_type, tuple): - raise TypeError(f"default_value must be of type {' or '.join([e.__name__ for e in expected_type])}") - raise TypeError(f"default_value must be of type {expected_type.__name__}") + raise TypeError(f"default_value must be of type {' or '.join([e.__name__ for e in expected_type])}") def _check_constraints(self) -> None: """Check if all constraints are satisfied.""" @@ -290,10 +291,11 @@ def _check_constraints(self) -> None: raise ValueError("min_value cannot be greater than max_value") if ( - self.default_value is not None - and self.default_value is not Param.SetToNone - and not isinstance(self.default_value, (list, Param.SetToNone)) + self.default_value is not None + and self.default_value is not Param.SetToNone + and not isinstance(self.default_value, (list, Param.SetToNone)) ): + assert isinstance(self.default_value, (int, float)) if self.body.min_value is not None and self.default_value < self.body.min_value: raise ValueError(f"default_value cannot be less than min_value ({self.body.min_value})") if self.body.max_value is not None and self.default_value > self.body.max_value: @@ -314,22 +316,123 @@ def _check_constraints(self) -> None: f"default_value list length cannot be greater than count_max ({self.list_.count_max})" ) - def _get_expected_type(self) -> type | tuple[type, ...] | None: + def _get_expected_type(self) -> tuple[type, ...] | None: """Get the expected type based on the body type.""" if isinstance(self.body, Param.Bool): - return bool + return (bool,) elif isinstance(self.body, Param.Int): - return int + return (int,) elif isinstance(self.body, Param.Float): return float, int elif isinstance(self.body, Param.String): - return str + return (str,) elif isinstance(self.body, (Param.File, Param.Struct, Param.StructUnion)): return None else: raise TypeError("Unknown body type") +# Unfortunately TypeGuards dont work as methods with implicit self + + +def is_bool(param: Param[Any]) -> TypeGuard[Param[Param.Bool]]: + """Check if the parameter is a boolean type. + + Args: + param: The parameter to check. + + Returns: + True if the parameter is a boolean type, False otherwise. + + This function can be used for type narrowing in conditional blocks. + """ + return isinstance(param.body, Param.Bool) + + +def is_int(param: Param[Any]) -> TypeGuard[Param[Param.Int]]: + """Check if the parameter is an integer type. + + Args: + param: The parameter to check. + + Returns: + True if the parameter is an integer type, False otherwise. + + This function can be used for type narrowing in conditional blocks. + """ + return isinstance(param.body, Param.Int) + + +def is_float(param: Param[Any]) -> TypeGuard[Param[Param.Float]]: + """Check if the parameter is a float type. + + Args: + param: The parameter to check. + + Returns: + True if the parameter is a float type, False otherwise. + + This function can be used for type narrowing in conditional blocks. + """ + return isinstance(param.body, Param.Float) + + +def is_string(param: Param[Any]) -> TypeGuard[Param[Param.String]]: + """Check if the parameter is a string type. + + Args: + param: The parameter to check. + + Returns: + True if the parameter is a string type, False otherwise. + + This function can be used for type narrowing in conditional blocks. + """ + return isinstance(param.body, Param.String) + + +def is_file(param: Param[Any]) -> TypeGuard[Param[Param.File]]: + """Check if the parameter is a file type. + + Args: + param: The parameter to check. + + Returns: + True if the parameter is a file type, False otherwise. + + This function can be used for type narrowing in conditional blocks. + """ + return isinstance(param.body, Param.File) + + +def is_struct(param: Param[Any]) -> TypeGuard[Param[Param.Struct]]: + """Check if the parameter is a struct type. + + Args: + param: The parameter to check. + + Returns: + True if the parameter is a struct type, False otherwise. + + This function can be used for type narrowing in conditional blocks. + """ + return isinstance(param.body, Param.Struct) + + +def is_struct_union(param: Param[Any]) -> TypeGuard[Param[Param.StructUnion]]: + """Check if the parameter is a struct union type. + + Args: + param: The parameter to check. + + Returns: + True if the parameter is a struct union type, False otherwise. + + This function can be used for type narrowing in conditional blocks. + """ + return isinstance(param.body, Param.StructUnion) + + @dataclass class Carg: """Represents command arguments.""" @@ -338,8 +441,7 @@ class Carg: """List of parameters or string tokens.""" def iter_params(self) -> Generator[Param, Any, None]: - """ - Iterate over all parameters in the command argument. + """Iterate over all parameters in the command argument. Yields: Each parameter in the command argument. @@ -357,8 +459,7 @@ class ConditionalGroup: """List of command arguments.""" def iter_params(self) -> Generator[Param, Any, None]: - """ - Iterate over all parameters in the conditional group. + """Iterate over all parameters in the conditional group. Yields: Each parameter in the conditional group. @@ -378,4 +479,4 @@ class Interface: """The package associated with this interface.""" command: Param[Param.Struct] - """The command structure for this interface.""" \ No newline at end of file + """The command structure for this interface."""