Skip to content

Commit

Permalink
Improve & document IR API
Browse files Browse the repository at this point in the history
  • Loading branch information
nx10 committed Sep 26, 2024
1 parent f7a826a commit 68a36ad
Show file tree
Hide file tree
Showing 5 changed files with 149 additions and 47 deletions.
2 changes: 1 addition & 1 deletion src/styx/backend/python/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def generate_static_metadata(
"""Generate the static metadata."""
metadata_symbol = scope.add_or_dodge(f"{python_screaming_snakify(interface.command.base.name)}_METADATA")

entries = {
entries: dict = {
"id": interface.uid,
"name": interface.command.base.name,
"package": interface.package.name,
Expand Down
4 changes: 2 additions & 2 deletions src/styx/backend/python/pycodegen/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def generate(self) -> LineBuffer:
if self.docstring_body:
docstring_linebroken = linebreak_paragraph(self.docstring_body, width=80 - 4)
else:
docstring_linebroken = ""
docstring_linebroken = [""]

buf.extend(
indent([
Expand All @@ -150,7 +150,7 @@ class PyDataClass(PyGen):
"""Python generate."""

name: str
docstring: str
docstring: str | None
fields: list[PyArg] = field(default_factory=list)
methods: list[PyFunc] = field(default_factory=list)
is_named_tuple: bool = False
Expand Down
6 changes: 3 additions & 3 deletions src/styx/backend/python/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,8 @@ def _val() -> tuple[str, bool]:
if isinstance(param.body, ir.Param.Bool):
as_list = (len(param.body.value_true) > 1) or (len(param.body.value_false) > 1)
if as_list:
value_true = param.body.value_true
value_false = param.body.value_false
value_true: str | list[str] | None = param.body.value_true
value_false: str | list[str] | None = param.body.value_false
else:
value_true = param.body.value_true[0] if len(param.body.value_true) > 0 else None
value_false = param.body.value_false[0] if len(param.body.value_false) > 0 else None
Expand Down Expand Up @@ -136,7 +136,7 @@ def param_py_default_value(param: ir.Param) -> str | None:
return "None"
if param.default_value is None:
return None
return as_py_literal(param.default_value)
return as_py_literal(param.default_value) # type: ignore


def param_py_var_is_set_by_user(
Expand Down
7 changes: 4 additions & 3 deletions src/styx/frontend/boutiques/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,8 @@ def _arg_elem_from_bt_elem(
assert choices is None or all([
isinstance(o, int) for o in choices
]), "value-choices must be all int for integer input"
assert constraints.value_min is None or isinstance(constraints.value_min, int)
assert constraints.value_max is None or isinstance(constraints.value_max, int)

return ir.Param(
base=dparam,
Expand Down Expand Up @@ -229,7 +231,7 @@ def _arg_elem_from_bt_elem(
return ir.Param(
base=dparam,
body=ir.Param.File(
resolve_parent=d.get("resolve-parent"),
resolve_parent=d.get("resolve-parent") is True,
),
list_=dlist,
nullable=input_type.is_optional,
Expand All @@ -240,7 +242,6 @@ def _arg_elem_from_bt_elem(
input_prefix = d.get("command-line-flag")
assert input_prefix is not None, "Flag type input must have command-line-flag"

dparam.prefix = []
return ir.Param(
base=dparam,
body=ir.Param.Bool(
Expand Down Expand Up @@ -390,7 +391,7 @@ def _collect_outputs(bt: dict, ir_id_lookup: dict[str, ir.IdType], id_counter: I
for bt_output in bt.get("output-files", []):
path_template = bt_output["path-template"]
destructed = destruct_template(path_template, ir_id_lookup)
output_sequence = [
output_sequence: list[str, ir.OutputParamReference] = [ # type: ignore # mypy is wrong
ir.OutputParamReference(
ref_id=x,
file_remove_suffixes=bt_output.get("path-template-stripped-extensions", []),
Expand Down
177 changes: 139 additions & 38 deletions src/styx/ir/core.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

import dataclasses
from dataclasses import dataclass
from typing import Any, Generator, Generic, Optional, TypeVar, Union
from typing import Any, Generator, Generic, Optional, TypeGuard, TypeVar, Union


@dataclass
Expand Down Expand Up @@ -67,7 +69,7 @@ class Output:
tokens: list[str | OutputParamReference] = dataclasses.field(default_factory=list)
"""List of tokens and/or parameter references. This is similar to a Python f-string."""

docs: Documentation | None = None
docs: Documentation = dataclasses.field(default_factory=Documentation)
"""Documentation for the output."""


Expand All @@ -90,7 +92,7 @@ class Base:
outputs: list[Output] = dataclasses.field(default_factory=list)
"""List of outputs associated with this parameter."""

docs: Documentation | None = None
docs: Documentation = dataclasses.field(default_factory=Documentation)
"""Documentation for the parameter."""

@dataclass
Expand All @@ -108,6 +110,7 @@ class List:

class SetToNone:
"""Represents a parameter that can be set to None."""

pass

@dataclass
Expand Down Expand Up @@ -143,6 +146,7 @@ class Float:
@dataclass
class String:
"""Represents string parameters."""

pass

@dataclass
Expand All @@ -159,15 +163,14 @@ class Struct:
name: str | None = None
"""Name of the struct."""

groups: list["ConditionalGroup"] = dataclasses.field(default_factory=list)
groups: list[ConditionalGroup] = dataclasses.field(default_factory=list)
"""List of conditional groups."""

docs: Documentation | None = None
"""Documentation for the struct."""

def iter_params(self) -> Generator["Param", Any, None]:
"""
Iterate over all parameters in the struct.
def iter_params(self) -> Generator[Param, Any, None]:
"""Iterate over all parameters in the struct.
Yields:
Each parameter in the struct.
Expand All @@ -179,22 +182,21 @@ def iter_params(self) -> Generator["Param", Any, None]:
class StructUnion:
"""Represents a union of struct parameters."""

alts: list["Param[Param.Struct]"] = dataclasses.field(default_factory=list)
alts: list[Param[Param.Struct]] = dataclasses.field(default_factory=list)
"""List of alternative struct parameters."""

def __init__(
self,
base: Base,
body: Union[Bool, Int, Float, String, File, Struct, StructUnion],
list_: Optional[List] = None,
nullable: bool = False,
choices: Optional[list[Union[str, int, float]]] = None,
default_value: Union[
bool, str, int, float, list[bool], list[str], list[int], list[float], SetToNone, None
] = None,
self,
base: Base,
body: T,
list_: Optional[List] = None,
nullable: bool = False,
choices: Optional[list[Union[str, int, float]]] = None,
default_value: Union[
bool, str, int, float, list[bool], list[str], list[int], list[float], type[SetToNone], None
] = None,
) -> None:
"""
Initialize a Param instance.
"""Initialize a Param instance.
Args:
base: Base parameter information.
Expand Down Expand Up @@ -232,8 +234,7 @@ def _check_base(self) -> None:
def _check_body_type(self) -> None:
"""Check if body is an instance of one of the allowed types."""
if not isinstance(
self.body,
(Param.Bool, Param.Int, Param.Float, Param.String, Param.File, Param.Struct, Param.StructUnion)
self.body, (Param.Bool, Param.Int, Param.Float, Param.String, Param.File, Param.Struct, Param.StructUnion)
):
raise TypeError(
"body must be an instance of "
Expand All @@ -257,7 +258,7 @@ def _check_choices(self) -> None:
raise TypeError("choices must be None or a list")
expected_type = self._get_expected_type()
if expected_type is not None and not all(isinstance(choice, expected_type) for choice in self.choices):
raise TypeError(f"All choices must be of type {expected_type.__name__}")
raise TypeError(f"All choices must be of type {' or '.join([e.__name__ for e in expected_type])}")

def _check_default_value(self) -> None:
"""Check if default_value is of the correct type."""
Expand All @@ -275,12 +276,12 @@ def _check_default_value(self) -> None:
if not isinstance(self.default_value, list):
raise TypeError("default_value must be a list when list_ is provided")
if not all(isinstance(item, expected_type) for item in self.default_value):
raise TypeError(f"All items in default_value must be of type {expected_type.__name__}")
raise TypeError(
f"All items in default_value must be of type {' or '.join([e.__name__ for e in expected_type])}"
)
else:
if not isinstance(self.default_value, expected_type):
if isinstance(expected_type, tuple):
raise TypeError(f"default_value must be of type {' or '.join([e.__name__ for e in expected_type])}")
raise TypeError(f"default_value must be of type {expected_type.__name__}")
raise TypeError(f"default_value must be of type {' or '.join([e.__name__ for e in expected_type])}")

def _check_constraints(self) -> None:
"""Check if all constraints are satisfied."""
Expand All @@ -290,10 +291,11 @@ def _check_constraints(self) -> None:
raise ValueError("min_value cannot be greater than max_value")

if (
self.default_value is not None
and self.default_value is not Param.SetToNone
and not isinstance(self.default_value, (list, Param.SetToNone))
self.default_value is not None
and self.default_value is not Param.SetToNone
and not isinstance(self.default_value, (list, Param.SetToNone))
):
assert isinstance(self.default_value, (int, float))
if self.body.min_value is not None and self.default_value < self.body.min_value:
raise ValueError(f"default_value cannot be less than min_value ({self.body.min_value})")
if self.body.max_value is not None and self.default_value > self.body.max_value:
Expand All @@ -314,22 +316,123 @@ def _check_constraints(self) -> None:
f"default_value list length cannot be greater than count_max ({self.list_.count_max})"
)

def _get_expected_type(self) -> type | tuple[type, ...] | None:
def _get_expected_type(self) -> tuple[type, ...] | None:
"""Get the expected type based on the body type."""
if isinstance(self.body, Param.Bool):
return bool
return (bool,)
elif isinstance(self.body, Param.Int):
return int
return (int,)
elif isinstance(self.body, Param.Float):
return float, int
elif isinstance(self.body, Param.String):
return str
return (str,)
elif isinstance(self.body, (Param.File, Param.Struct, Param.StructUnion)):
return None
else:
raise TypeError("Unknown body type")


# Unfortunately TypeGuards dont work as methods with implicit self


def is_bool(param: Param[Any]) -> TypeGuard[Param[Param.Bool]]:
"""Check if the parameter is a boolean type.
Args:
param: The parameter to check.
Returns:
True if the parameter is a boolean type, False otherwise.
This function can be used for type narrowing in conditional blocks.
"""
return isinstance(param.body, Param.Bool)


def is_int(param: Param[Any]) -> TypeGuard[Param[Param.Int]]:
"""Check if the parameter is an integer type.
Args:
param: The parameter to check.
Returns:
True if the parameter is an integer type, False otherwise.
This function can be used for type narrowing in conditional blocks.
"""
return isinstance(param.body, Param.Int)


def is_float(param: Param[Any]) -> TypeGuard[Param[Param.Float]]:
"""Check if the parameter is a float type.
Args:
param: The parameter to check.
Returns:
True if the parameter is a float type, False otherwise.
This function can be used for type narrowing in conditional blocks.
"""
return isinstance(param.body, Param.Float)


def is_string(param: Param[Any]) -> TypeGuard[Param[Param.String]]:
"""Check if the parameter is a string type.
Args:
param: The parameter to check.
Returns:
True if the parameter is a string type, False otherwise.
This function can be used for type narrowing in conditional blocks.
"""
return isinstance(param.body, Param.String)


def is_file(param: Param[Any]) -> TypeGuard[Param[Param.File]]:
"""Check if the parameter is a file type.
Args:
param: The parameter to check.
Returns:
True if the parameter is a file type, False otherwise.
This function can be used for type narrowing in conditional blocks.
"""
return isinstance(param.body, Param.File)


def is_struct(param: Param[Any]) -> TypeGuard[Param[Param.Struct]]:
"""Check if the parameter is a struct type.
Args:
param: The parameter to check.
Returns:
True if the parameter is a struct type, False otherwise.
This function can be used for type narrowing in conditional blocks.
"""
return isinstance(param.body, Param.Struct)


def is_struct_union(param: Param[Any]) -> TypeGuard[Param[Param.StructUnion]]:
"""Check if the parameter is a struct union type.
Args:
param: The parameter to check.
Returns:
True if the parameter is a struct union type, False otherwise.
This function can be used for type narrowing in conditional blocks.
"""
return isinstance(param.body, Param.StructUnion)


@dataclass
class Carg:
"""Represents command arguments."""
Expand All @@ -338,8 +441,7 @@ class Carg:
"""List of parameters or string tokens."""

def iter_params(self) -> Generator[Param, Any, None]:
"""
Iterate over all parameters in the command argument.
"""Iterate over all parameters in the command argument.
Yields:
Each parameter in the command argument.
Expand All @@ -357,8 +459,7 @@ class ConditionalGroup:
"""List of command arguments."""

def iter_params(self) -> Generator[Param, Any, None]:
"""
Iterate over all parameters in the conditional group.
"""Iterate over all parameters in the conditional group.
Yields:
Each parameter in the conditional group.
Expand All @@ -378,4 +479,4 @@ class Interface:
"""The package associated with this interface."""

command: Param[Param.Struct]
"""The command structure for this interface."""
"""The command structure for this interface."""

0 comments on commit 68a36ad

Please sign in to comment.