Skip to content

Commit

Permalink
allow more type annotations in dataclass_array_container
Browse files Browse the repository at this point in the history
  • Loading branch information
alexfikl committed Jun 24, 2022
1 parent 80813d7 commit 201f973
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 48 deletions.
44 changes: 18 additions & 26 deletions arraycontext/container/dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,47 +51,39 @@ def dataclass_array_container(cls: type) -> type:
Attributes that are not array containers are allowed. In order to decide
whether an attribute is an array container, the declared attribute type
is checked by the criteria from :func:`is_array_container_type`.
is checked by the criteria from :func:`is_array_container_type`. This
includes some support for type annotations:
* a :class:`typing.Union` of array containers is considered an array container.
* other type annotations, e.g. :class:`typing.Optional`, are not considered
array containers, even if they wrap one.
"""
from dataclasses import is_dataclass, Field
assert is_dataclass(cls)

def is_array_field(f: Field) -> bool:
def is_array_type(tp: type) -> bool:
from arraycontext import Array
return tp is Array or is_array_container_type(tp)

origin = get_origin(f.type)
if origin is Union:
if not all(
arg is Array or is_array_container_type(arg)
for arg in get_args(f.type)):
raise TypeError(
f"Field '{f.name}' union contains non-array container "
"arguments. All arguments must be array containers.")
else:
return True

def is_array_field(f: Field) -> bool:
if __debug__:
if not f.init:
raise ValueError(
f"'init=False' field not allowed: '{f.name}'")
f"Fields with 'init=False' not allowed: '{f.name}'")

if isinstance(f.type, str):
raise TypeError(
f"string annotation on field '{f.name}' not supported")
f"String annotation on field '{f.name}' not supported")

from typing import _SpecialForm
if isinstance(f.type, _SpecialForm):
# NOTE: anything except a Union is not allowed
raise TypeError(
f"typing annotation not supported on field '{f.name}': "
f"'{f.type!r}'")
origin = get_origin(f.type)
if origin is Union:
return all(is_array_type(arg) for arg in get_args(f.type))

if not isinstance(f.type, type):
raise TypeError(
f"field '{f.name}' not an instance of 'type': "
f"'{f.type!r}'")
from typing import _GenericAlias, _SpecialForm
if isinstance(f.type, (_GenericAlias, _SpecialForm)):
return False

return f.type is Array or is_array_container_type(f.type)
return is_array_type(f.type)

from pytools import partition
array_fields, non_array_fields = partition(is_array_field, fields(cls))
Expand Down
38 changes: 16 additions & 22 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ def test_pt_actx_key_stringification_uniqueness():
# {{{ test_dataclass_array_container

def test_dataclass_array_container():
from typing import Optional
from dataclasses import dataclass, field
from arraycontext import dataclass_array_container

Expand All @@ -64,19 +63,6 @@ class ArrayContainerWithStringTypes:

# }}}

# {{{ optional fields

@dataclass
class ArrayContainerWithOptional:
x: np.ndarray
y: Optional[np.ndarray]

with pytest.raises(TypeError):
# NOTE: cannot have wrapped annotations (here by `Optional`)
dataclass_array_container(ArrayContainerWithOptional)

# }}}

# {{{ field(init=False)

@dataclass
Expand Down Expand Up @@ -106,36 +92,44 @@ class ArrayContainerWithArray:
# }}}


# {{{ test_dataclass_container_unions
# {{{ test_dataclass_container_type_annotations

def test_dataclass_container_unions():
def test_dataclass_container_type_annotations():
from dataclasses import dataclass
from arraycontext import dataclass_array_container

from typing import Union
from typing import Optional, Tuple, Union
from arraycontext import Array

# {{{ union fields

@dataclass_array_container
@dataclass
class ArrayContainerWithUnion:
x: np.ndarray
y: Union[np.ndarray, Array]

dataclass_array_container(ArrayContainerWithUnion)

# }}}

# {{{ non-container union

@dataclass_array_container
@dataclass
class ArrayContainerWithWrongUnion:
x: np.ndarray
y: Union[np.ndarray, float]

with pytest.raises(TypeError):
# NOTE: float is not an ArrayContainer, so y should fail
dataclass_array_container(ArrayContainerWithWrongUnion)
# }}}

# {{{ optional and other fields

@dataclass_array_container
@dataclass
class ArrayContainerWithAnnotations:
x: np.ndarray
y: Tuple[float, float]
z: Optional[np.ndarray]
w: str

# }}}

Expand Down

0 comments on commit 201f973

Please sign in to comment.