Skip to content

Commit

Permalink
split dataclass_array_container for easier modification
Browse files Browse the repository at this point in the history
  • Loading branch information
alexfikl authored and inducer committed Jun 28, 2022
1 parent ff1cd0c commit b686676
Showing 1 changed file with 56 additions and 13 deletions.
69 changes: 56 additions & 13 deletions arraycontext/container/dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,19 +30,24 @@
THE SOFTWARE.
"""

from typing import Union, get_args
from typing import Tuple, Union, get_args
try:
# NOTE: only available in python >= 3.8
from typing import get_origin
except ImportError:
from typing_extensions import get_origin

from dataclasses import fields
from dataclasses import Field, is_dataclass, fields
from arraycontext.container import is_array_container_type


# {{{ dataclass containers

def is_array_type(tp: type) -> bool:
from arraycontext import Array
return tp is Array or is_array_container_type(tp)


def dataclass_array_container(cls: type) -> type:
"""A class decorator that makes the class to which it is applied an
:class:`ArrayContainer` by registering appropriate implementations of
Expand All @@ -51,24 +56,37 @@ def dataclass_array_container(cls: type) -> type:
Attributes that are not array containers are allowed. In order to decide
whether an attribute is an array container, the declared attribute type
is checked by the criteria from :func:`is_array_container_type`.
is checked by the criteria from :func:`is_array_container_type`. This
includes some support for type annotations:
* a :class:`typing.Union` of array containers is considered an array container.
* other type annotations, e.g. :class:`typing.Optional`, are not considered
array containers, even if they wrap one.
"""
from dataclasses import is_dataclass, Field

assert is_dataclass(cls)

def is_array_field(f: Field) -> bool:
from arraycontext import Array
# NOTE: unions of array containers are treated separately to handle
# unions of only array containers, e.g. `Union[np.ndarray, Array]`, as
# they can work seamlessly with arithmetic and traversal.
#
# `Optional[ArrayContainer]` is not allowed, since `None` is not
# handled by `with_container_arithmetic`, which is the common case
# for current container usage. Other type annotations, e.g.
# `Tuple[Container, Container]`, are also not allowed, as they do not
# work with `with_container_arithmetic`.
#
# This is not set in stone, but mostly driven by current usage!

origin = get_origin(f.type)
if origin is Union:
if not all(
arg is Array or is_array_container_type(arg)
for arg in get_args(f.type)):
if all(is_array_type(arg) for arg in get_args(f.type)):
return True
else:
raise TypeError(
f"Field '{f.name}' union contains non-array container "
"arguments. All arguments must be array containers.")
else:
return True

if __debug__:
if not f.init:
Expand All @@ -79,8 +97,12 @@ def is_array_field(f: Field) -> bool:
raise TypeError(
f"string annotation on field '{f.name}' not supported")

from typing import _SpecialForm
if isinstance(f.type, _SpecialForm):
# NOTE:
# * `_BaseGenericAlias` catches `List`, `Tuple`, etc.
# * `_SpecialForm` catches `Any`, `Literal`, etc.
from typing import ( # type: ignore[attr-defined]
_BaseGenericAlias, _SpecialForm)
if isinstance(f.type, (_BaseGenericAlias, _SpecialForm)):
# NOTE: anything except a Union is not allowed
raise TypeError(
f"typing annotation not supported on field '{f.name}': "
Expand All @@ -91,7 +113,7 @@ def is_array_field(f: Field) -> bool:
f"field '{f.name}' not an instance of 'type': "
f"'{f.type!r}'")

return f.type is Array or is_array_container_type(f.type)
return is_array_type(f.type)

from pytools import partition
array_fields, non_array_fields = partition(is_array_field, fields(cls))
Expand All @@ -100,6 +122,27 @@ def is_array_field(f: Field) -> bool:
raise ValueError(f"'{cls}' must have fields with array container type "
"in order to use the 'dataclass_array_container' decorator")

return inject_dataclass_serialization(cls, array_fields, non_array_fields)


def inject_dataclass_serialization(
cls: type,
array_fields: Tuple[Field, ...],
non_array_fields: Tuple[Field, ...]) -> type:
"""Implements :func:`~arraycontext.serialize_container` and
:func:`~arraycontext.deserialize_container` for the given dataclass *cls*.
This function modifies *cls* in place, so the returned value is the same
object with additional functionality.
:arg array_fields: fields of the given dataclass *cls* which are considered
array containers and should be serialized.
:arg non_array_fields: remaining fields of the dataclass *cls* which are
copied over from the template array in deserialization.
"""

assert is_dataclass(cls)

serialize_expr = ", ".join(
f"({f.name!r}, ary.{f.name})" for f in array_fields)
template_kwargs = ", ".join(
Expand Down

0 comments on commit b686676

Please sign in to comment.