diff --git a/arraycontext/container/dataclass.py b/arraycontext/container/dataclass.py index edbb4506..4f60abd2 100644 --- a/arraycontext/container/dataclass.py +++ b/arraycontext/container/dataclass.py @@ -30,19 +30,24 @@ THE SOFTWARE. """ -from typing import Union, get_args +from typing import Tuple, Union, get_args try: # NOTE: only available in python >= 3.8 from typing import get_origin except ImportError: from typing_extensions import get_origin -from dataclasses import fields +from dataclasses import Field, is_dataclass, fields from arraycontext.container import is_array_container_type # {{{ dataclass containers +def is_array_type(tp: type) -> bool: + from arraycontext import Array + return tp is Array or is_array_container_type(tp) + + def dataclass_array_container(cls: type) -> type: """A class decorator that makes the class to which it is applied an :class:`ArrayContainer` by registering appropriate implementations of @@ -51,24 +56,37 @@ def dataclass_array_container(cls: type) -> type: Attributes that are not array containers are allowed. In order to decide whether an attribute is an array container, the declared attribute type - is checked by the criteria from :func:`is_array_container_type`. + is checked by the criteria from :func:`is_array_container_type`. This + includes some support for type annotations: + + * a :class:`typing.Union` of array containers is considered an array container. + * other type annotations, e.g. :class:`typing.Optional`, are not considered + array containers, even if they wrap one. """ - from dataclasses import is_dataclass, Field + assert is_dataclass(cls) def is_array_field(f: Field) -> bool: - from arraycontext import Array + # NOTE: unions of array containers are treated separately to handle + # unions of only array containers, e.g. `Union[np.ndarray, Array]`, as + # they can work seamlessly with arithmetic and traversal. + # + # `Optional[ArrayContainer]` is not allowed, since `None` is not + # handled by `with_container_arithmetic`, which is the common case + # for current container usage. Other type annotations, e.g. + # `Tuple[Container, Container]`, are also not allowed, as they do not + # work with `with_container_arithmetic`. + # + # This is not set in stone, but mostly driven by current usage! origin = get_origin(f.type) if origin is Union: - if not all( - arg is Array or is_array_container_type(arg) - for arg in get_args(f.type)): + if all(is_array_type(arg) for arg in get_args(f.type)): + return True + else: raise TypeError( f"Field '{f.name}' union contains non-array container " "arguments. All arguments must be array containers.") - else: - return True if __debug__: if not f.init: @@ -79,8 +97,12 @@ def is_array_field(f: Field) -> bool: raise TypeError( f"string annotation on field '{f.name}' not supported") - from typing import _SpecialForm - if isinstance(f.type, _SpecialForm): + # NOTE: + # * `_BaseGenericAlias` catches `List`, `Tuple`, etc. + # * `_SpecialForm` catches `Any`, `Literal`, etc. + from typing import ( # type: ignore[attr-defined] + _BaseGenericAlias, _SpecialForm) + if isinstance(f.type, (_BaseGenericAlias, _SpecialForm)): # NOTE: anything except a Union is not allowed raise TypeError( f"typing annotation not supported on field '{f.name}': " @@ -91,7 +113,7 @@ def is_array_field(f: Field) -> bool: f"field '{f.name}' not an instance of 'type': " f"'{f.type!r}'") - return f.type is Array or is_array_container_type(f.type) + return is_array_type(f.type) from pytools import partition array_fields, non_array_fields = partition(is_array_field, fields(cls)) @@ -100,6 +122,27 @@ def is_array_field(f: Field) -> bool: raise ValueError(f"'{cls}' must have fields with array container type " "in order to use the 'dataclass_array_container' decorator") + return inject_dataclass_serialization(cls, array_fields, non_array_fields) + + +def inject_dataclass_serialization( + cls: type, + array_fields: Tuple[Field, ...], + non_array_fields: Tuple[Field, ...]) -> type: + """Implements :func:`~arraycontext.serialize_container` and + :func:`~arraycontext.deserialize_container` for the given dataclass *cls*. + + This function modifies *cls* in place, so the returned value is the same + object with additional functionality. + + :arg array_fields: fields of the given dataclass *cls* which are considered + array containers and should be serialized. + :arg non_array_fields: remaining fields of the dataclass *cls* which are + copied over from the template array in deserialization. + """ + + assert is_dataclass(cls) + serialize_expr = ", ".join( f"({f.name!r}, ary.{f.name})" for f in array_fields) template_kwargs = ", ".join(