From 201f9736d537b27779ce5a8c8ea0c06a341477fc Mon Sep 17 00:00:00 2001 From: Alexandru Fikl Date: Fri, 24 Jun 2022 20:08:22 +0300 Subject: [PATCH] allow more type annotations in dataclass_array_container --- arraycontext/container/dataclass.py | 44 ++++++++++++----------------- test/test_utils.py | 38 +++++++++++-------------- 2 files changed, 34 insertions(+), 48 deletions(-) diff --git a/arraycontext/container/dataclass.py b/arraycontext/container/dataclass.py index edbb4506..0d2c3791 100644 --- a/arraycontext/container/dataclass.py +++ b/arraycontext/container/dataclass.py @@ -51,47 +51,39 @@ def dataclass_array_container(cls: type) -> type: Attributes that are not array containers are allowed. In order to decide whether an attribute is an array container, the declared attribute type - is checked by the criteria from :func:`is_array_container_type`. + is checked by the criteria from :func:`is_array_container_type`. This + includes some support for type annotations: + + * a :class:`typing.Union` of array containers is considered an array container. + * other type annotations, e.g. :class:`typing.Optional`, are not considered + array containers, even if they wrap one. """ from dataclasses import is_dataclass, Field assert is_dataclass(cls) - def is_array_field(f: Field) -> bool: + def is_array_type(tp: type) -> bool: from arraycontext import Array + return tp is Array or is_array_container_type(tp) - origin = get_origin(f.type) - if origin is Union: - if not all( - arg is Array or is_array_container_type(arg) - for arg in get_args(f.type)): - raise TypeError( - f"Field '{f.name}' union contains non-array container " - "arguments. All arguments must be array containers.") - else: - return True - + def is_array_field(f: Field) -> bool: if __debug__: if not f.init: raise ValueError( - f"'init=False' field not allowed: '{f.name}'") + f"Fields with 'init=False' not allowed: '{f.name}'") if isinstance(f.type, str): raise TypeError( - f"string annotation on field '{f.name}' not supported") + f"String annotation on field '{f.name}' not supported") - from typing import _SpecialForm - if isinstance(f.type, _SpecialForm): - # NOTE: anything except a Union is not allowed - raise TypeError( - f"typing annotation not supported on field '{f.name}': " - f"'{f.type!r}'") + origin = get_origin(f.type) + if origin is Union: + return all(is_array_type(arg) for arg in get_args(f.type)) - if not isinstance(f.type, type): - raise TypeError( - f"field '{f.name}' not an instance of 'type': " - f"'{f.type!r}'") + from typing import _GenericAlias, _SpecialForm + if isinstance(f.type, (_GenericAlias, _SpecialForm)): + return False - return f.type is Array or is_array_container_type(f.type) + return is_array_type(f.type) from pytools import partition array_fields, non_array_fields = partition(is_array_field, fields(cls)) diff --git a/test/test_utils.py b/test/test_utils.py index 7a12ad27..60898381 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -47,7 +47,6 @@ def test_pt_actx_key_stringification_uniqueness(): # {{{ test_dataclass_array_container def test_dataclass_array_container(): - from typing import Optional from dataclasses import dataclass, field from arraycontext import dataclass_array_container @@ -64,19 +63,6 @@ class ArrayContainerWithStringTypes: # }}} - # {{{ optional fields - - @dataclass - class ArrayContainerWithOptional: - x: np.ndarray - y: Optional[np.ndarray] - - with pytest.raises(TypeError): - # NOTE: cannot have wrapped annotations (here by `Optional`) - dataclass_array_container(ArrayContainerWithOptional) - - # }}} - # {{{ field(init=False) @dataclass @@ -106,36 +92,44 @@ class ArrayContainerWithArray: # }}} -# {{{ test_dataclass_container_unions +# {{{ test_dataclass_container_type_annotations -def test_dataclass_container_unions(): +def test_dataclass_container_type_annotations(): from dataclasses import dataclass from arraycontext import dataclass_array_container - from typing import Union + from typing import Optional, Tuple, Union from arraycontext import Array # {{{ union fields + @dataclass_array_container @dataclass class ArrayContainerWithUnion: x: np.ndarray y: Union[np.ndarray, Array] - dataclass_array_container(ArrayContainerWithUnion) - # }}} # {{{ non-container union + @dataclass_array_container @dataclass class ArrayContainerWithWrongUnion: x: np.ndarray y: Union[np.ndarray, float] - with pytest.raises(TypeError): - # NOTE: float is not an ArrayContainer, so y should fail - dataclass_array_container(ArrayContainerWithWrongUnion) + # }}} + + # {{{ optional and other fields + + @dataclass_array_container + @dataclass + class ArrayContainerWithAnnotations: + x: np.ndarray + y: Tuple[float, float] + z: Optional[np.ndarray] + w: str # }}}