Skip to content

Commit

Permalink
allow more type annotations in dataclass_array_container
Browse files Browse the repository at this point in the history
  • Loading branch information
alexfikl committed Jun 26, 2022
1 parent ff1cd0c commit 4ba3383
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 74 deletions.
133 changes: 81 additions & 52 deletions arraycontext/container/dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,75 +30,44 @@
THE SOFTWARE.
"""

from typing import Union, get_args
from typing import Tuple, Union, get_args
try:
# NOTE: only available in python >= 3.8
from typing import get_origin
except ImportError:
from typing_extensions import get_origin

from dataclasses import fields
from dataclasses import is_dataclass, fields, Field
from arraycontext.container import is_array_container_type


# {{{ dataclass containers

def dataclass_array_container(cls: type) -> type:
"""A class decorator that makes the class to which it is applied an
:class:`ArrayContainer` by registering appropriate implementations of
:func:`serialize_container` and :func:`deserialize_container`.
*cls* must be a :func:`~dataclasses.dataclass`.
def is_array_type(tp: type) -> bool:
from arraycontext import Array
return tp is Array or is_array_container_type(tp)

Attributes that are not array containers are allowed. In order to decide
whether an attribute is an array container, the declared attribute type
is checked by the criteria from :func:`is_array_container_type`.
"""
from dataclasses import is_dataclass, Field
assert is_dataclass(cls)

def is_array_field(f: Field) -> bool:
from arraycontext import Array
def inject_container_serialization(
cls: type,
array_fields: Tuple[Field, ...],
non_array_fields: Tuple[Field, ...],
) -> type:
"""Implements :func:`~arraycontext.serialize_container` and
:func:`~arraycontext.deserialize_container` for the given class *cls*.
origin = get_origin(f.type)
if origin is Union:
if not all(
arg is Array or is_array_container_type(arg)
for arg in get_args(f.type)):
raise TypeError(
f"Field '{f.name}' union contains non-array container "
"arguments. All arguments must be array containers.")
else:
return True
This function modifies *cls* in place, so the returned value is the same
object with additional functionality.
if __debug__:
if not f.init:
raise ValueError(
f"'init=False' field not allowed: '{f.name}'")
:arg array_fields: fields of the given dataclass *cls* which are considered
array containers and should be serialized.
:arg non_array_fields: remaining fields of the dataclass *cls* which are
copied over from the template array in deserialization.
if isinstance(f.type, str):
raise TypeError(
f"string annotation on field '{f.name}' not supported")

from typing import _SpecialForm
if isinstance(f.type, _SpecialForm):
# NOTE: anything except a Union is not allowed
raise TypeError(
f"typing annotation not supported on field '{f.name}': "
f"'{f.type!r}'")

if not isinstance(f.type, type):
raise TypeError(
f"field '{f.name}' not an instance of 'type': "
f"'{f.type!r}'")

return f.type is Array or is_array_container_type(f.type)

from pytools import partition
array_fields, non_array_fields = partition(is_array_field, fields(cls))
:returns: the input class *cls*.
"""

if not array_fields:
raise ValueError(f"'{cls}' must have fields with array container type "
"in order to use the 'dataclass_array_container' decorator")
assert is_dataclass(cls)

serialize_expr = ", ".join(
f"({f.name!r}, ary.{f.name})" for f in array_fields)
Expand Down Expand Up @@ -153,6 +122,66 @@ def _deserialize_init_arrays_code_{lower_cls_name}(

return cls


def dataclass_array_container(cls: type) -> type:
"""A class decorator that makes the class to which it is applied an
:class:`ArrayContainer` by registering appropriate implementations of
:func:`serialize_container` and :func:`deserialize_container`.
*cls* must be a :func:`~dataclasses.dataclass`.
Attributes that are not array containers are allowed. In order to decide
whether an attribute is an array container, the declared attribute type
is checked by the criteria from :func:`is_array_container_type`. This
includes some support for type annotations:
* a :class:`typing.Union` of array containers is considered an array container.
* other type annotations, e.g. :class:`typing.Optional`, are not considered
array containers, even if they wrap one.
"""
assert is_dataclass(cls)

def is_array_field(f: Field) -> bool:
if __debug__:
if not f.init:
raise ValueError(
f"Fields with 'init=False' not allowed: '{f.name}'")

if isinstance(f.type, str):
raise TypeError(
f"String annotation on field '{f.name}' not supported")

# NOTE: unions of array containers are treated seprately to allow
# * unions of only array containers, e.g. Union[np.ndarray, Array], as
# they can work seamlessly with arithmetic and traversal.
# * `Optional[ArrayContainer]` is not allowed, since `None` is not
# handled by `with_container_arithmetic`, which is the common case
# for current container usage.
#
# Other type annotations, e.g. `Tuple[Container, Container]`, are also
# not allowed, as they do not work with `with_container_arithmetic`.
#
# This is not set in stone, but mostly driven by current usage!

origin = get_origin(f.type)
if origin is Union:
# NOTE: `Optional` is caught in here as an alias for `Union[Anon, type]`
return all(is_array_type(arg) for arg in get_args(f.type))

from typing import _GenericAlias, _SpecialForm # type: ignore[attr-defined]
if isinstance(f.type, (_GenericAlias, _SpecialForm)):
return False

return is_array_type(f.type)

from pytools import partition
array_fields, non_array_fields = partition(is_array_field, fields(cls))

if not array_fields:
raise ValueError(f"'{cls}' must have fields with array container type "
"in order to use the 'dataclass_array_container' decorator")

return inject_container_serialization(cls, array_fields, non_array_fields)

# }}}

# vim: foldmethod=marker
38 changes: 16 additions & 22 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ def test_pt_actx_key_stringification_uniqueness():
# {{{ test_dataclass_array_container

def test_dataclass_array_container():
from typing import Optional
from dataclasses import dataclass, field
from arraycontext import dataclass_array_container

Expand All @@ -64,19 +63,6 @@ class ArrayContainerWithStringTypes:

# }}}

# {{{ optional fields

@dataclass
class ArrayContainerWithOptional:
x: np.ndarray
y: Optional[np.ndarray]

with pytest.raises(TypeError):
# NOTE: cannot have wrapped annotations (here by `Optional`)
dataclass_array_container(ArrayContainerWithOptional)

# }}}

# {{{ field(init=False)

@dataclass
Expand Down Expand Up @@ -106,36 +92,44 @@ class ArrayContainerWithArray:
# }}}


# {{{ test_dataclass_container_unions
# {{{ test_dataclass_container_type_annotations

def test_dataclass_container_unions():
def test_dataclass_container_type_annotations():
from dataclasses import dataclass
from arraycontext import dataclass_array_container

from typing import Union
from typing import Optional, Tuple, Union
from arraycontext import Array

# {{{ union fields

@dataclass_array_container
@dataclass
class ArrayContainerWithUnion:
x: np.ndarray
y: Union[np.ndarray, Array]

dataclass_array_container(ArrayContainerWithUnion)

# }}}

# {{{ non-container union

@dataclass_array_container
@dataclass
class ArrayContainerWithWrongUnion:
x: np.ndarray
y: Union[np.ndarray, float]

with pytest.raises(TypeError):
# NOTE: float is not an ArrayContainer, so y should fail
dataclass_array_container(ArrayContainerWithWrongUnion)
# }}}

# {{{ optional and other fields

@dataclass_array_container
@dataclass
class ArrayContainerWithAnnotations:
x: np.ndarray
y: Tuple[float, float]
z: Optional[np.ndarray]
w: str

# }}}

Expand Down

0 comments on commit 4ba3383

Please sign in to comment.