diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py index 1d0efb36..4e0ba830 100644 --- a/arraycontext/__init__.py +++ b/arraycontext/__init__.py @@ -32,6 +32,8 @@ ArrayContainer, ArrayContainerT, NotAnArrayContainerError, + SerializationKey, + SerializedContainer, deserialize_container, get_container_context_opt, get_container_context_recursively, @@ -41,7 +43,9 @@ register_multivector_as_array_container, serialize_container, ) -from .container.arithmetic import with_container_arithmetic +from .container.arithmetic import ( + with_container_arithmetic, +) from .container.dataclass import dataclass_array_container from .container.traversal import ( flat_size_and_dtype, @@ -78,6 +82,7 @@ tag_axes, ) from .impl.jax import EagerJAXArrayContext +from .impl.numpy import NumpyArrayContext from .impl.pyopencl import PyOpenCLArrayContext from .impl.pytato import PytatoJAXArrayContext, PytatoPyOpenCLArrayContext from .loopy import make_loopy_program @@ -91,7 +96,6 @@ __all__ = ( - "Array", "Array", "ArrayContainer", "ArrayContainerT", @@ -105,14 +109,16 @@ "EagerJAXArrayContext", "ElementwiseMapKernelTag", "NotAnArrayContainerError", + "NumpyArrayContext", "PyOpenCLArrayContext", "PytatoJAXArrayContext", "PytatoPyOpenCLArrayContext", "PytestArrayContextFactory", "PytestPyOpenCLArrayContextFactory", "Scalar", - "Scalar", "ScalarLike", + "SerializationKey", + "SerializedContainer", "dataclass_array_container", "deserialize_container", "flat_size_and_dtype", @@ -146,7 +152,7 @@ "to_numpy", "unflatten", "with_array_context", - "with_container_arithmetic" + "with_container_arithmetic", ) diff --git a/arraycontext/container/__init__.py b/arraycontext/container/__init__.py index ea20a5ac..38a23412 100644 --- a/arraycontext/container/__init__.py +++ b/arraycontext/container/__init__.py @@ -12,6 +12,9 @@ Serialization/deserialization ----------------------------- + +.. autoclass:: SerializationKey +.. autoclass:: SerializedContainer .. autofunction:: is_array_container_type .. autofunction:: serialize_container .. autofunction:: deserialize_container @@ -39,6 +42,14 @@ .. class:: ArrayOrContainerT :canonical: arraycontext.ArrayOrContainerT + +.. class:: SerializationKey + + :canonical: arraycontext.SerializationKey + +.. class:: SerializedContainer + + :canonical: arraycontext.SerializedContainer """ from __future__ import annotations @@ -69,12 +80,23 @@ """ from functools import singledispatch -from typing import TYPE_CHECKING, Any, Iterable, Optional, Protocol, Tuple, TypeVar +from typing import ( + TYPE_CHECKING, + Any, + Hashable, + Iterable, + Optional, + Protocol, + Sequence, + Tuple, + TypeVar, +) # For use in singledispatch type annotations, because sphinx can't figure out # what 'np' is. import numpy import numpy as np +from typing_extensions import TypeAlias from arraycontext.context import ArrayContext @@ -142,23 +164,27 @@ class NotAnArrayContainerError(TypeError): """:class:`TypeError` subclass raised when an array container is expected.""" +SerializationKey: TypeAlias = Hashable +SerializedContainer: TypeAlias = Sequence[Tuple[SerializationKey, "ArrayOrContainer"]] + + @singledispatch def serialize_container( - ary: ArrayContainer) -> Iterable[Tuple[Any, ArrayOrContainer]]: - r"""Serialize the array container into an iterable over its components. + ary: ArrayContainer) -> SerializedContainer: + r"""Serialize the array container into a sequence over its components. The order of the components and their identifiers are entirely under the control of the container class. 
However, the order is required to be deterministic, i.e. two calls to :func:`serialize_container` on array containers of the same types with the same number of - sub-arrays must result in an iterable with the keys in the same + sub-arrays must result in a sequence with the keys in the same order. If *ary* is mutable, the serialization function is not required to ensure that the serialization result reflects the array state at the time of the call to :func:`serialize_container`. - :returns: an :class:`Iterable` of 2-tuples where the first + :returns: a :class:`Sequence` of 2-tuples where the first entry is an identifier for the component and the second entry is an array-like component of the :class:`ArrayContainer`. Components can themselves be :class:`ArrayContainer`\ s, allowing @@ -172,13 +198,13 @@ def serialize_container( @singledispatch def deserialize_container( template: ArrayContainerT, - iterable: Iterable[Tuple[Any, Any]]) -> ArrayContainerT: - """Deserialize an iterable into an array container. + serialized: SerializedContainer) -> ArrayContainerT: + """Deserialize a sequence into an array container following a *template*. :param template: an instance of an existing object that can be used to aid in the deserialization. For a similar choice see :attr:`~numpy.class.__array_finalize__`. - :param iterable: an iterable that mirrors the output of + :param serialized: a sequence that mirrors the output of :meth:`serialize_container`. """ raise NotAnArrayContainerError( @@ -218,7 +244,11 @@ def is_array_container(ary: Any) -> bool: "cheaper option, see is_array_container_type.", DeprecationWarning, stacklevel=2) return (serialize_container.dispatch(ary.__class__) - is not serialize_container.__wrapped__) # type:ignore[attr-defined] + is not serialize_container.__wrapped__ # type:ignore[attr-defined] + # numpy values with scalar elements aren't array containers + and not (isinstance(ary, np.ndarray) + and ary.dtype.kind != "O") + ) @singledispatch @@ -238,7 +268,7 @@ def get_container_context_opt(ary: ArrayContainer) -> Optional[ArrayContext]: @serialize_container.register(np.ndarray) def _serialize_ndarray_container( - ary: numpy.ndarray) -> Iterable[Tuple[Any, ArrayOrContainer]]: + ary: numpy.ndarray) -> SerializedContainer: if ary.dtype.char != "O": raise NotAnArrayContainerError( f"cannot serialize '{type(ary).__name__}' with dtype '{ary.dtype}'") @@ -252,20 +282,20 @@ def _serialize_ndarray_container( for j in range(ary.shape[1]) ] else: - return np.ndenumerate(ary) + return list(np.ndenumerate(ary)) @deserialize_container.register(np.ndarray) # https://github.com/python/mypy/issues/13040 def _deserialize_ndarray_container( # type: ignore[misc] template: numpy.ndarray, - iterable: Iterable[Tuple[Any, ArrayOrContainer]]) -> numpy.ndarray: + serialized: SerializedContainer) -> numpy.ndarray: # disallow subclasses assert type(template) is np.ndarray assert template.dtype.char == "O" result = type(template)(template.shape, dtype=object) - for i, subary in iterable: + for i, subary in serialized: result[i] = subary return result diff --git a/arraycontext/container/arithmetic.py b/arraycontext/container/arithmetic.py index 2ef5ddc9..9366b260 100644 --- a/arraycontext/container/arithmetic.py +++ b/arraycontext/container/arithmetic.py @@ -1,12 +1,13 @@ # mypy: disallow-untyped-defs +from __future__ import annotations -""" + +__doc__ = """ .. currentmodule:: arraycontext + .. 
autofunction:: with_container_arithmetic """ -import enum - __copyright__ = """ Copyright (C) 2020-1 University of Illinois Board of Trustees @@ -32,7 +33,8 @@ THE SOFTWARE. """ -from typing import Any, Callable, Optional, Tuple, Type, TypeVar, Union +import enum +from typing import Any, Callable, Optional, Tuple, TypeVar, Union from warnings import warn import numpy as np @@ -98,8 +100,8 @@ def _format_unary_op_str(op_str: str, arg1: Union[Tuple[str, ...], str]) -> str: def _format_binary_op_str(op_str: str, - arg1: Union[Tuple[str, ...], str], - arg2: Union[Tuple[str, ...], str]) -> str: + arg1: Union[Tuple[str, str], str], + arg2: Union[Tuple[str, str], str]) -> str: if isinstance(arg1, tuple) and isinstance(arg2, tuple): import sys if sys.version_info >= (3, 10): @@ -126,45 +128,71 @@ def _format_binary_op_str(op_str: str, return op_str.format(arg1, arg2) -class _FailSafe: +class NumpyObjectArrayMetaclass(type): + def __instancecheck__(cls, instance: Any) -> bool: + return isinstance(instance, np.ndarray) and instance.dtype == object + + +class NumpyObjectArray(metaclass=NumpyObjectArrayMetaclass): + pass + + +class ComplainingNumpyNonObjectArrayMetaclass(type): + def __instancecheck__(cls, instance: Any) -> bool: + if isinstance(instance, np.ndarray) and instance.dtype != object: + # Example usage site: + # https://github.com/illinois-ceesd/mirgecom/blob/f5d0d97c41e8c8a05546b1d1a6a2979ec8ea3554/mirgecom/inviscid.py#L148-L149 + # where normal is passed in by test_lfr_flux as a 'custom-made' + # numpy array of dtype float64. + warn( + "Broadcasting container against non-object numpy array. " + "This was never documented to work and will now stop working in " + "2025. Convert the array to an object array to preserve the " + "current semantics.", DeprecationWarning, stacklevel=3) + return True + else: + return False + + +class ComplainingNumpyNonObjectArray(metaclass=ComplainingNumpyNonObjectArrayMetaclass): pass def with_container_arithmetic( - *, - bcast_number: bool = True, - _bcast_actx_array_type: Optional[bool] = None, - bcast_obj_array: Optional[bool] = None, - bcast_numpy_array: bool = False, - bcast_container_types: Optional[Tuple[type, ...]] = None, - arithmetic: bool = True, - matmul: bool = False, - bitwise: bool = False, - shift: bool = False, - _cls_has_array_context_attr: Optional[bool] = None, - eq_comparison: Optional[bool] = None, - rel_comparison: Optional[bool] = None) -> Callable[[type], type]: + *, + number_bcasts_across: Optional[bool] = None, + bcasts_across_obj_array: Optional[bool] = None, + container_types_bcast_across: Optional[Tuple[type, ...]] = None, + arithmetic: bool = True, + matmul: bool = False, + bitwise: bool = False, + shift: bool = False, + _cls_has_array_context_attr: Optional[bool] = None, + eq_comparison: Optional[bool] = None, + rel_comparison: Optional[bool] = None, + + # deprecated: + bcast_number: Optional[bool] = None, + bcast_obj_array: Optional[bool] = None, + bcast_numpy_array: bool = False, + _bcast_actx_array_type: Optional[bool] = None, + bcast_container_types: Optional[Tuple[type, ...]] = None, + ) -> Callable[[type], type]: """A class decorator that implements built-in operators for array containers by propagating the operations to the elements of the container. - :arg bcast_number: If *True*, numbers broadcast over the container + :arg number_bcasts_across: If *True*, numbers broadcast over the container (with the container as the 'outer' structure). 
- :arg _bcast_actx_array_type: If *True*, instances of base array types of the - container's array context are broadcasted over the container. Can be - *True* only if the container has *_cls_has_array_context_attr* set. - Defaulted to *bcast_number* if *_cls_has_array_context_attr* is set, - else *False*. - :arg bcast_obj_array: If *True*, :mod:`numpy` object arrays broadcast over - the container. (with the container as the 'inner' structure) - :arg bcast_numpy_array: If *True*, any :class:`numpy.ndarray` will broadcast - over the container. (with the container as the 'inner' structure) - If this is set to *True*, *bcast_obj_array* must also be *True*. - :arg bcast_container_types: A sequence of container types that will broadcast - over this container (with this container as the 'outer' structure). + :arg bcasts_across_obj_array: If *True*, this container will be broadcast + across :mod:`numpy` object arrays + (with the object array as the 'outer' structure). + Add :class:`numpy.ndarray` to *container_types_bcast_across* to achieve + the 'reverse' broadcasting. + :arg container_types_bcast_across: A sequence of container types that will broadcast + across this container, with this container as the 'outer' structure. :class:`numpy.ndarray` is permitted to be part of this sequence to - indicate that, in such broadcasting situations, this container should - be the 'outer' structure. In this case, *bcast_obj_array* - (and consequently *bcast_numpy_array*) must be *False*. + indicate that object arrays (and *only* object arrays) will be broadcast. + In this case, *bcasts_across_obj_array* must be *False*. :arg arithmetic: Implement the conventional arithmetic operators, including ``**``, :func:`divmod`, and ``//``. Also includes ``+`` and ``-`` as well as :func:`abs`. @@ -206,13 +234,18 @@ def _deserialize_init_arrays_code(cls, tmpl_instance_name, args): should nest "outside" :func:dataclass_array_container`. """ - # {{{ handle inputs + # Hard-won design lessons: + # + # - Anything that special-cases np.ndarray by type is broken by design because: + # - np.ndarray is an array context array. + # - numpy object arrays can be array containers. + # Using NumpyObjectArray and NumpyNonObjectArray *may* be better? + # They're new, so there is no operational experience with them. + # + # - Broadcast rules are hard to change once established, particularly + # because one cannot grep for their use. - if bcast_obj_array is None: - raise TypeError("bcast_obj_array must be specified") - - if rel_comparison is None: - raise TypeError("rel_comparison must be specified") + # {{{ handle inputs if rel_comparison and eq_comparison is None: eq_comparison = True @@ -220,27 +253,104 @@ def _deserialize_init_arrays_code(cls, tmpl_instance_name, args): if eq_comparison is None: raise TypeError("eq_comparison must be specified") - if not bcast_obj_array and bcast_numpy_array: + # {{{ handle bcast_number + + if bcast_number is not None: + if number_bcasts_across is not None: + raise TypeError( + "may specify at most one of 'bcast_number' and " + "'number_bcasts_across'") + + warn("'bcast_number' is deprecated and will be unsupported from 2025. 
" + "Use 'number_bcasts_across', with equivalent meaning.", + DeprecationWarning, stacklevel=2) + number_bcasts_across = bcast_number + else: + if number_bcasts_across is None: + number_bcasts_across = True + + del bcast_number + + # }}} + + # {{{ handle bcast_obj_array + + if bcast_obj_array is not None: + if bcasts_across_obj_array is not None: + raise TypeError( + "may specify at most one of 'bcast_obj_array' and " + "'bcasts_across_obj_array'") + + warn("'bcast_obj_array' is deprecated and will be unsupported from 2025. " + "Use 'bcasts_across_obj_array', with equivalent meaning.", + DeprecationWarning, stacklevel=2) + bcasts_across_obj_array = bcast_obj_array + else: + if bcasts_across_obj_array is None: + raise TypeError("bcasts_across_obj_array must be specified") + + del bcast_obj_array + + # }}} + + # {{{ handle bcast_container_types + + if bcast_container_types is not None: + if container_types_bcast_across is not None: + raise TypeError( + "may specify at most one of 'bcast_container_types' and " + "'container_types_bcast_across'") + + warn("'bcast_container_types' is deprecated and will be unsupported from 2025. " + "Use 'container_types_bcast_across', with equivalent meaning.", + DeprecationWarning, stacklevel=2) + container_types_bcast_across = bcast_container_types + else: + if container_types_bcast_across is None: + container_types_bcast_across = () + + del bcast_container_types + + # }}} + + if rel_comparison is None: + raise TypeError("rel_comparison must be specified") + + if bcast_numpy_array: + warn("'bcast_numpy_array=True' is deprecated and will be unsupported" + " from 2025.", DeprecationWarning, stacklevel=2) + + if _bcast_actx_array_type: + raise ValueError("'bcast_numpy_array' and '_bcast_actx_array_type'" + " cannot be both set.") + + if not bcasts_across_obj_array and bcast_numpy_array: raise TypeError("bcast_obj_array must be set if bcast_numpy_array is") if bcast_numpy_array: def numpy_pred(name: str) -> str: - return f"isinstance({name}, np.ndarray)" - elif bcast_obj_array: + return f"is_numpy_array({name})" + elif bcasts_across_obj_array: def numpy_pred(name: str) -> str: return f"isinstance({name}, np.ndarray) and {name}.dtype.char == 'O'" else: def numpy_pred(name: str) -> str: return "False" # optimized away - if bcast_container_types is None: - bcast_container_types = () - bcast_container_types_count = len(bcast_container_types) - - if np.ndarray in bcast_container_types and bcast_obj_array: + if np.ndarray in container_types_bcast_across and bcasts_across_obj_array: raise ValueError("If numpy.ndarray is part of bcast_container_types, " "bcast_obj_array must be False.") + numpy_check_types: list[type] = [NumpyObjectArray, ComplainingNumpyNonObjectArray] + container_types_bcast_across = tuple( + new_ct + for old_ct in container_types_bcast_across + for new_ct in + (numpy_check_types + if old_ct is np.ndarray + else [old_ct]) + ) + desired_op_classes = set() if arithmetic: desired_op_classes.add(_OpClass.ARITHMETIC) @@ -258,34 +368,33 @@ def numpy_pred(name: str) -> str: # }}} def wrap(cls: Any) -> Any: - cls_has_array_context_attr: Optional[Union[bool, Type[_FailSafe]]] = \ - _cls_has_array_context_attr - bcast_actx_array_type: Optional[Union[bool, Type[_FailSafe]]] = \ - _bcast_actx_array_type + if not hasattr(cls, "__array_ufunc__"): + warn(f"{cls} does not have __array_ufunc__ set. " + "This will cause numpy to attempt broadcasting, in a way that " + "is likely undesired. 
" + f"To avoid this, set __array_ufunc__ = None in {cls}.", + stacklevel=2) + + cls_has_array_context_attr: bool | None = _cls_has_array_context_attr + bcast_actx_array_type: bool | None = _bcast_actx_array_type if cls_has_array_context_attr is None: if hasattr(cls, "array_context"): - cls_has_array_context_attr = _FailSafe - warn(f"{cls} has an 'array_context' attribute, but it does not " - "set '_cls_has_array_context_attr' to 'True' when calling " - "'with_container_arithmetic'. This is being interpreted " - "as 'array_context' being permitted to fail. Tolerating " - "these failures comes at a substantial cost. It is " - "deprecated and will stop working in 2023. " - "Having a working 'array_context' attribute is desirable " - "to enable arithmetic with other array types supported " - "by the array context. " - f"If '{cls.__name__}.array_context' will not fail, pass " + raise TypeError( + f"{cls} has an 'array_context' attribute, but it does not " + "set '_cls_has_array_context_attr' to *True* when calling " + "with_container_arithmetic. This is being interpreted " + "as '.array_context' being permitted to fail " + "with an exception, which is no longer allowed. " + f"If {cls.__name__}.array_context will not fail, pass " "'_cls_has_array_context_attr=True'. " "If you do not want container arithmetic to make " "use of the array context, set " - "'_cls_has_array_context_attr=False'.", - stacklevel=2) + "'_cls_has_array_context_attr=False'.") if bcast_actx_array_type is None: if cls_has_array_context_attr: - if bcast_number: - # copy over _FailSafe if present + if number_bcasts_across: bcast_actx_array_type = cls_has_array_context_attr else: bcast_actx_array_type = False @@ -294,6 +403,30 @@ def wrap(cls: Any) -> Any: raise TypeError("_bcast_actx_array_type can be True only if " "_cls_has_array_context_attr is set.") + if bcast_actx_array_type: + if _bcast_actx_array_type: + warn( + f"Broadcasting array context array types across {cls} " + "has been explicitly " + "enabled. As of 2025, this will stop working. " + "There is no replacement as of right now. " + "See the discussion in " + "https://github.com/inducer/arraycontext/pull/190. " + "To opt out now (and avoid this warning), " + "pass _bcast_actx_array_type=False. ", + DeprecationWarning, stacklevel=2) + else: + warn( + f"Broadcasting array context array types across {cls} " + "has been implicitly " + "enabled. As of 2025, this will no longer work. " + "There is no replacement as of right now. " + "See the discussion in " + "https://github.com/inducer/arraycontext/pull/190. " + "To opt out now (and avoid this warning), " + "pass _bcast_actx_array_type=False.", + DeprecationWarning, stacklevel=2) + if (not hasattr(cls, "_serialize_init_arrays_code") or not hasattr(cls, "_deserialize_init_arrays_code")): raise TypeError(f"class '{cls.__name__}' must provide serialization " @@ -302,20 +435,12 @@ def wrap(cls: Any) -> Any: "'_deserialize_init_arrays_code'. 
If this is a dataclass, " "use the 'dataclass_array_container' decorator first.") - if cls_has_array_context_attr is _FailSafe: - def actx_getter_code(arg: str) -> str: - return f"_get_actx({arg})" - else: - def actx_getter_code(arg: str) -> str: - return f"{arg}.array_context" - from pytools.codegen import CodeGenerator, Indentation gen = CodeGenerator() - gen(""" + gen(f""" from numbers import Number import numpy as np - from arraycontext import ( - ArrayContainer, get_container_context_recursively) + from arraycontext import ArrayContainer from warnings import warn def _raise_if_actx_none(actx): @@ -324,56 +449,34 @@ def _raise_if_actx_none(actx): "cannot be operated upon") return actx - def _get_actx(ary): - try: - return ary.array_context - except Exception as e: - warn(f"Accessing '{type(ary).__name__}.array_context' failed " - f"({type(e)}: {e}). This should not happen and is " - "deprecated. " - "Please fix the implementation of " - f"'{type(ary).__name__}.array_context' " - "and then set _cls_has_array_context_attr=True when " - "calling with_container_arithmetic to avoid the run time " - "cost of the check that gave you this warning. " - "Using expensive recovery for now.", - DeprecationWarning, stacklevel=3) - - return get_container_context_recursively(ary) - - def _get_actx_array_types_failsafe(ary): - try: - actx = ary.array_context - except Exception as e: - warn(f"Accessing '{type(ary).__name__}.array_context' failed " - f"({type(e)}: {e}). This should not happen and is " - "deprecated. " - "Please fix the implementation of " - f"'{type(ary).__name__}.array_context' " - "and then set _cls_has_array_context_attr=True when " - "calling with_container_arithmetic to avoid the run time " - "cost of the check that gave you this warning. " - "Using expensive recovery for now.", - DeprecationWarning, stacklevel=3) - - actx = get_container_context_recursively(ary) - - if actx is None: - return () + def is_numpy_array(arg): + if isinstance(arg, np.ndarray): + if arg.dtype != "O": + warn("Operand is a non-object numpy array, " + "and the broadcasting behavior of this array container " + "({cls}) " + "is influenced by this because of its use of " + "the deprecated bcast_numpy_array. This broadcasting " + "behavior will change in 2025. 
If you would like the " + "broadcasting behavior to stay the same, make sure " + "to convert the passed numpy array to an " + "object array.", + DeprecationWarning, stacklevel=3) + return True + else: + return False - return actx.array_types """) gen("") - if bcast_container_types: - for i, bct in enumerate(bcast_container_types): + if container_types_bcast_across: + for i, bct in enumerate(container_types_bcast_across): gen(f"from {bct.__module__} import {bct.__qualname__} as _bctype{i}") gen("") - outer_bcast_type_names = tuple([ - f"_bctype{i}" for i in range(bcast_container_types_count) - ]) - if bcast_number: - outer_bcast_type_names += ("Number",) + container_type_names_bcast_across = tuple( + f"_bctype{i}" for i in range(len(container_types_bcast_across))) + if number_bcasts_across: + container_type_names_bcast_across += ("Number",) def same_key(k1: T, k2: T) -> T: assert k1 == k2 @@ -385,9 +488,14 @@ def tup_str(t: Tuple[str, ...]) -> str: else: return "({},)".format(", ".join(t)) - gen(f"cls._outer_bcast_types = {tup_str(outer_bcast_type_names)}") + gen(f"cls._outer_bcast_types = {tup_str(container_type_names_bcast_across)}") + gen("cls._container_types_bcast_across = " + f"{tup_str(container_type_names_bcast_across)}") + gen(f"cls._bcast_numpy_array = {bcast_numpy_array}") - gen(f"cls._bcast_obj_array = {bcast_obj_array}") + + gen(f"cls._bcast_obj_array = {bcasts_across_obj_array}") + gen(f"cls._bcasts_across_obj_array = {bcasts_across_obj_array}") gen("") # {{{ unary operators @@ -431,8 +539,6 @@ def {fname}(arg1): continue - # {{{ "forward" binary operators - zip_init_args = cls._deserialize_init_arrays_code("arg1", { same_key(key_arg1, key_arg2): _format_binary_op_str(op_str, expr_arg1, expr_arg2) @@ -440,20 +546,27 @@ def {fname}(arg1): cls._serialize_init_arrays_code("arg1").items(), cls._serialize_init_arrays_code("arg2").items()) }) - bcast_same_cls_init_args = cls._deserialize_init_arrays_code("arg1", { + bcast_init_args_arg1_is_outer = cls._deserialize_init_arrays_code("arg1", { key_arg1: _format_binary_op_str(op_str, expr_arg1, "arg2") for key_arg1, expr_arg1 in cls._serialize_init_arrays_code("arg1").items() }) + bcast_init_args_arg2_is_outer = cls._deserialize_init_arrays_code("arg2", { + key_arg2: _format_binary_op_str(op_str, "arg1", expr_arg2) + for key_arg2, expr_arg2 in + cls._serialize_init_arrays_code("arg2").items() + }) + + # {{{ "forward" binary operators gen(f"def {fname}(arg1, arg2):") with Indentation(gen): gen("if arg2.__class__ is cls:") with Indentation(gen): if __debug__ and cls_has_array_context_attr: - gen(f""" - arg1_actx = {actx_getter_code("arg1")} - arg2_actx = {actx_getter_code("arg2")} + gen(""" + arg1_actx = arg1.array_context + arg2_actx = arg2.array_context if arg1_actx is not arg2_actx: msg = ("array contexts of both arguments " "must match") @@ -469,31 +582,41 @@ def {fname}(arg1): raise ValueError(msg)""") gen(f"return cls({zip_init_args})") - if bcast_actx_array_type is _FailSafe: - bcast_actx_ary_types: Tuple[str, ...] = ( - "*_get_actx_array_types_failsafe(arg1)",) - elif bcast_actx_array_type: + if bcast_actx_array_type: if __debug__: - bcast_actx_ary_types = ( + bcast_actx_ary_types: tuple[str, ...] 
= ( "*_raise_if_actx_none(" - f"{actx_getter_code('arg1')}).array_types",) + "arg1.array_context).array_types",) else: bcast_actx_ary_types = ( - f"*{actx_getter_code('arg1')}.array_types",) + "*arg1.array_context.array_types",) else: bcast_actx_ary_types = () gen(f""" - if {bool(outer_bcast_type_names)}: # optimized away - if isinstance(arg2, - {tup_str(outer_bcast_type_names - + bcast_actx_ary_types)}): - return cls({bcast_same_cls_init_args}) if {numpy_pred("arg2")}: result = np.empty_like(arg2, dtype=object) for i in np.ndindex(arg2.shape): result[i] = {op_str.format("arg1", "arg2[i]")} return result + + if {bool(container_type_names_bcast_across)}: # optimized away + if isinstance(arg2, + {tup_str(container_type_names_bcast_across + + bcast_actx_ary_types)}): + if __debug__: + if isinstance(arg2, {tup_str(bcast_actx_ary_types)}): + warn("Broadcasting {cls} over array " + f"context array type {{type(arg2)}} is deprecated " + "and will no longer work in 2025. " + "There is no replacement as of right now. " + "See the discussion in " + "https://github.com/inducer/arraycontext/" + "pull/190. ", + DeprecationWarning, stacklevel=2) + + return cls({bcast_init_args_arg1_is_outer}) + return NotImplemented """) gen(f"cls.__{dunder_name}__ = {fname}") @@ -505,24 +628,15 @@ def {fname}(arg1): if reversible: fname = f"_{cls.__name__.lower()}_r{dunder_name}" - bcast_init_args = cls._deserialize_init_arrays_code("arg2", { - key_arg2: _format_binary_op_str( - op_str, "arg1", expr_arg2) - for key_arg2, expr_arg2 in - cls._serialize_init_arrays_code("arg2").items() - }) - - if bcast_actx_array_type is _FailSafe: - bcast_actx_ary_types = ( - "*_get_actx_array_types_failsafe(arg2)",) - elif bcast_actx_array_type: + + if bcast_actx_array_type: if __debug__: bcast_actx_ary_types = ( "*_raise_if_actx_none(" - f"{actx_getter_code('arg2')}).array_types",) + "arg2.array_context).array_types",) else: bcast_actx_ary_types = ( - f"*{actx_getter_code('arg2')}.array_types",) + "*arg2.array_context.array_types",) else: bcast_actx_ary_types = () @@ -530,16 +644,30 @@ def {fname}(arg1): def {fname}(arg2, arg1): # assert other.__cls__ is not cls - if {bool(outer_bcast_type_names)}: # optimized away - if isinstance(arg1, - {tup_str(outer_bcast_type_names - + bcast_actx_ary_types)}): - return cls({bcast_init_args}) if {numpy_pred("arg1")}: result = np.empty_like(arg1, dtype=object) for i in np.ndindex(arg1.shape): result[i] = {op_str.format("arg1[i]", "arg2")} return result + if {bool(container_type_names_bcast_across)}: # optimized away + if isinstance(arg1, + {tup_str(container_type_names_bcast_across + + bcast_actx_ary_types)}): + if __debug__: + if isinstance(arg1, + {tup_str(bcast_actx_ary_types)}): + warn("Broadcasting {cls} over array " + f"context array type {{type(arg1)}} " + "is deprecated " + "and will no longer work in 2025." + "There is no replacement as of right now. " + "See the discussion in " + "https://github.com/inducer/arraycontext/" + "pull/190. 
", + DeprecationWarning, stacklevel=2) + + return cls({bcast_init_args_arg2_is_outer}) + return NotImplemented cls.__r{dunder_name}__ = {fname}""") diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py index 4a60a8f9..a7547df2 100644 --- a/arraycontext/container/traversal.py +++ b/arraycontext/container/traversal.py @@ -43,6 +43,8 @@ from __future__ import annotations +from arraycontext.container.arithmetic import NumpyObjectArray + __copyright__ = """ Copyright (C) 2020-1 University of Illinois Board of Trustees @@ -77,6 +79,7 @@ from arraycontext.container import ( ArrayContainer, NotAnArrayContainerError, + SerializationKey, deserialize_container, get_container_context_recursively_opt, serialize_container, @@ -373,12 +376,9 @@ def wrapper(*args: Any) -> Any: # {{{ keyed array container traversal -KeyType = Any - - def keyed_map_array_container( f: Callable[ - [KeyType, ArrayOrContainer], + [SerializationKey, ArrayOrContainer], ArrayOrContainer], ary: ArrayOrContainer) -> ArrayOrContainer: r"""Applies *f* to all components of an :class:`ArrayContainer`. @@ -403,7 +403,7 @@ def keyed_map_array_container( def rec_keyed_map_array_container( - f: Callable[[Tuple[KeyType, ...], ArrayT], ArrayT], + f: Callable[[Tuple[SerializationKey, ...], ArrayT], ArrayT], ary: ArrayOrContainer) -> ArrayOrContainer: """ Works similarly to :func:`rec_map_array_container`, except that *f* also @@ -412,7 +412,7 @@ def rec_keyed_map_array_container( the current array. """ - def rec(keys: Tuple[Union[str, int], ...], + def rec(keys: Tuple[SerializationKey, ...], _ary: ArrayOrContainerT) -> ArrayOrContainerT: try: iterable = serialize_container(_ary) @@ -949,8 +949,7 @@ def outer(a: Any, b: Any) -> Any: Tweaks the behavior of :func:`numpy.outer` to return a lower-dimensional object if either/both of *a* and *b* are scalars (whereas :func:`numpy.outer` always returns a matrix). Here the definition of "scalar" includes - all non-array-container types and any scalar-like array container types - (including non-object numpy arrays). + all non-array-container types and any scalar-like array container types. If *a* and *b* are both array containers, the result will have the same type as *a*. If both are array containers and neither is an object array, they must @@ -966,14 +965,21 @@ def treat_as_scalar(x: Any) -> bool: return ( not isinstance(x, np.ndarray) # This condition is whether "ndarrays should broadcast inside x". - and np.ndarray not in x.__class__._outer_bcast_types) + and NumpyObjectArray not in x.__class__._outer_bcast_types) + + a_is_ndarray = isinstance(a, np.ndarray) + b_is_ndarray = isinstance(b, np.ndarray) + + if a_is_ndarray and a.dtype != object: + raise TypeError("passing a non-object numpy array is not allowed") + if b_is_ndarray and b.dtype != object: + raise TypeError("passing a non-object numpy array is not allowed") if treat_as_scalar(a) or treat_as_scalar(b): return a*b - # After this point, "isinstance(o, ndarray)" means o is an object array. - elif isinstance(a, np.ndarray) and isinstance(b, np.ndarray): + elif a_is_ndarray and b_is_ndarray: return np.outer(a, b) - elif isinstance(a, np.ndarray) or isinstance(b, np.ndarray): + elif a_is_ndarray or b_is_ndarray: return map_array_container(lambda x: outer(x, b), a) else: if type(a) is not type(b): diff --git a/arraycontext/context.py b/arraycontext/context.py index 8b42bca7..30f58cb1 100644 --- a/arraycontext/context.py +++ b/arraycontext/context.py @@ -75,8 +75,8 @@ .. 
currentmodule:: arraycontext -The interface of an array context ---------------------------------- +The :class:`ArrayContext` Interface +----------------------------------- .. autoclass:: ArrayContext @@ -217,7 +217,11 @@ def size(self) -> int: def dtype(self) -> "np.dtype[Any]": ... - def __getitem__(self, index: Union[slice, int]) -> "Array": + # Covering all the possible index variations is hard and (kind of) futile. + # If you'd like to see how, try changing the Any to + # AxisIndex = slice | int | "Array" + # Index = AxisIndex |tuple[AxisIndex] + def __getitem__(self, index: Any) -> "Array": ... @@ -274,8 +278,10 @@ class ArrayContext(ABC): A :class:`tuple` of types that are the valid array classes the context can operate on. However, it is not necessary that *all* the - :class:`ArrayContext`\ 's operations would be legal for the types in - *array_types*. + :class:`ArrayContext`\ 's operations are legal for the types in + *array_types*. Note that this tuple is *only* intended for use + with :func:`isinstance`. Other uses are not allowed. This allows + for 'types' with overridden :meth:`class.__instancecheck__`. .. automethod:: freeze .. automethod:: thaw @@ -297,12 +303,6 @@ def _get_fake_numpy_namespace(self) -> Any: def __hash__(self) -> int: raise TypeError(f"unhashable type: '{type(self).__name__}'") - @abstractmethod - def empty(self, - shape: Union[int, Tuple[int, ...]], - dtype: "np.dtype[Any]") -> Array: - pass - def zeros(self, shape: Union[int, Tuple[int, ...]], dtype: "np.dtype[Any]") -> Array: @@ -312,20 +312,6 @@ def zeros(self, return self.np.zeros(shape, dtype) - def empty_like(self, ary: Array) -> Array: - warn(f"{type(self).__name__}.empty_like is deprecated and will stop " - "working in 2023. Prefer actx.np.zeros_like instead.", - DeprecationWarning, stacklevel=2) - - return self.empty(shape=ary.shape, dtype=ary.dtype) - - def zeros_like(self, ary: Array) -> Array: - warn(f"{type(self).__name__}.zeros_like is deprecated and will stop " - "working in 2023. Use actx.np.zeros_like instead.", - DeprecationWarning, stacklevel=2) - - return self.zeros(shape=ary.shape, dtype=ary.dtype) - @abstractmethod def from_numpy(self, array: NumpyOrContainerOrScalar @@ -353,7 +339,7 @@ def to_numpy(self, @abstractmethod def call_loopy(self, - program: "loopy.TranslationUnit", + t_unit: "loopy.TranslationUnit", **kwargs: Any) -> Dict[str, Array]: """Execute the :mod:`loopy` program *program* on the arguments *kwargs*. @@ -489,7 +475,7 @@ def einsum(self, :return: the output of the einsum :mod:`loopy` program """ if arg_names is None: - arg_names = tuple([f"arg{i}" for i in range(len(args))]) + arg_names = tuple(f"arg{i}" for i in range(len(args))) prg = self._get_einsum_prg(spec, arg_names, tagged) out_ary = self.call_loopy( diff --git a/arraycontext/impl/jax/__init__.py b/arraycontext/impl/jax/__init__.py index 03045419..26cb9db5 100644 --- a/arraycontext/impl/jax/__init__.py +++ b/arraycontext/impl/jax/__init__.py @@ -87,38 +87,6 @@ def _wrapper(ary): # {{{ ArrayContext interface - def empty(self, shape, dtype): - from warnings import warn - warn(f"{type(self).__name__}.empty is deprecated and will stop " - "working in 2023. 
Prefer actx.np.zeros instead.", - DeprecationWarning, stacklevel=2) - - import jax.numpy as jnp - return jnp.empty(shape=shape, dtype=dtype) - - def zeros(self, shape, dtype): - import jax.numpy as jnp - return jnp.zeros(shape=shape, dtype=dtype) - - def empty_like(self, ary): - from warnings import warn - warn(f"{type(self).__name__}.empty_like is deprecated and will stop " - "working in 2023. Prefer actx.np.zeros_like instead.", - DeprecationWarning, stacklevel=2) - - def _empty_like(array): - return self.empty(array.shape, array.dtype) - - return self._rec_map_container(_empty_like, ary) - - def zeros_like(self, ary): - from warnings import warn - warn(f"{type(self).__name__}.zeros_like is deprecated and will stop " - "working in 2023. Use actx.np.zeros_like instead.", - DeprecationWarning, stacklevel=2) - - return self.np.zeros_like(ary) - def from_numpy(self, array): def _from_numpy(ary): import jax diff --git a/arraycontext/impl/jax/fake_numpy.py b/arraycontext/impl/jax/fake_numpy.py index 3fc5f2e6..bc9481e3 100644 --- a/arraycontext/impl/jax/fake_numpy.py +++ b/arraycontext/impl/jax/fake_numpy.py @@ -27,12 +27,16 @@ import jax.numpy as jnp -from arraycontext.container import NotAnArrayContainerError, serialize_container +from arraycontext.container import ( + NotAnArrayContainerError, + serialize_container, +) from arraycontext.container.traversal import ( rec_map_array_container, rec_map_reduce_array_container, rec_multimap_array_container, ) +from arraycontext.context import Array, ArrayOrContainer from arraycontext.fake_numpy import BaseFakeNumpyLinalgNamespace, BaseFakeNumpyNamespace @@ -156,29 +160,35 @@ def any(self, a): return rec_map_reduce_array_container( partial(reduce, jnp.logical_or), jnp.any, a) - def array_equal(self, a, b): + def array_equal(self, a: ArrayOrContainer, b: ArrayOrContainer) -> Array: actx = self._array_context # NOTE: not all backends support `bool` properly, so use `int8` instead - true = actx.from_numpy(np.int8(True)) - false = actx.from_numpy(np.int8(False)) + true_ary = actx.from_numpy(np.int8(True)) + false_ary = actx.from_numpy(np.int8(False)) def rec_equal(x, y): if type(x) is not type(y): - return false + return false_ary try: - iterable = zip(serialize_container(x), serialize_container(y)) + serialized_x = serialize_container(x) + serialized_y = serialize_container(y) except NotAnArrayContainerError: if x.shape != y.shape: - return false + return false_ary else: return jnp.all(jnp.equal(x, y)) else: + if len(serialized_x) != len(serialized_y): + return false_ary return reduce( jnp.logical_and, - [rec_equal(x_i, y_i) for (_, x_i), (_, y_i) in iterable], - true) + [(true_ary if kx_i == ky_i else false_ary) + and rec_equal(x_i, y_i) + for (kx_i, x_i), (ky_i, y_i) + in zip(serialized_x, serialized_y)], + true_ary) return rec_equal(a, b) diff --git a/arraycontext/impl/numpy/__init__.py b/arraycontext/impl/numpy/__init__.py new file mode 100644 index 00000000..f8ba95e3 --- /dev/null +++ b/arraycontext/impl/numpy/__init__.py @@ -0,0 +1,168 @@ +from __future__ import annotations + + +""" +.. currentmodule:: arraycontext + +A :mod:`numpy`-based array context. + +.. 
autoclass:: NumpyArrayContext +""" + +__copyright__ = """ +Copyright (C) 2021 University of Illinois Board of Trustees +""" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + +from typing import Any + +import numpy as np + +import loopy as lp +from pytools.tag import ToTagSetConvertible + +from arraycontext.container.traversal import rec_map_array_container, with_array_context +from arraycontext.context import ( + Array, + ArrayContext, + ArrayOrContainerOrScalar, + ArrayOrContainerOrScalarT, + NumpyOrContainerOrScalar, + UntransformedCodeWarning, +) + + +class NumpyNonObjectArrayMetaclass(type): + def __instancecheck__(cls, instance: Any) -> bool: + return isinstance(instance, np.ndarray) and instance.dtype != object + + +class NumpyNonObjectArray(metaclass=NumpyNonObjectArrayMetaclass): + pass + + +class NumpyArrayContext(ArrayContext): + """ + A :class:`ArrayContext` that uses :class:`numpy.ndarray` to represent arrays. + + .. automethod:: __init__ + """ + + _loopy_transform_cache: dict[lp.TranslationUnit, lp.ExecutorBase] + + def __init__(self) -> None: + super().__init__() + self._loopy_transform_cache = {} + + array_types = (NumpyNonObjectArray,) + + def _get_fake_numpy_namespace(self): + from .fake_numpy import NumpyFakeNumpyNamespace + return NumpyFakeNumpyNamespace(self) + + # {{{ ArrayContext interface + + def clone(self): + return type(self)() + + def from_numpy(self, + array: NumpyOrContainerOrScalar + ) -> ArrayOrContainerOrScalar: + return array + + def to_numpy(self, + array: ArrayOrContainerOrScalar + ) -> NumpyOrContainerOrScalar: + return array + + def call_loopy( + self, + t_unit: lp.TranslationUnit, **kwargs: Any + ) -> dict[str, Array]: + t_unit = t_unit.copy(target=lp.ExecutableCTarget()) + try: + executor = self._loopy_transform_cache[t_unit] + except KeyError: + executor = self.transform_loopy_program(t_unit).executor() + self._loopy_transform_cache[t_unit] = executor + + _, result = executor(**kwargs) + + return result + + def freeze(self, array): + def _freeze(ary): + return ary + + return with_array_context(rec_map_array_container(_freeze, array), actx=None) + + def thaw(self, array): + def _thaw(ary): + return ary + + return with_array_context(rec_map_array_container(_thaw, array), actx=self) + + # }}} + + def transform_loopy_program(self, t_unit): + from warnings import warn + warn("Using the base " + f"{type(self).__name__}.transform_loopy_program " + "to transform a translation unit. 
" + "This is a no-op and will result in unoptimized C code for" + "the requested optimization, all in a single statement." + "This will work, but is unlikely to be performant." + f"Instead, subclass {type(self).__name__} and implement " + "the specific transform logic required to transform the program " + "for your package or application. Check higher-level packages " + "(e.g. meshmode), which may already have subclasses you may want " + "to build on.", + UntransformedCodeWarning, stacklevel=2) + + return t_unit + + def tag(self, + tags: ToTagSetConvertible, + array: ArrayOrContainerOrScalarT) -> ArrayOrContainerOrScalarT: + # Numpy doesn't support tagging + return array + + def tag_axis(self, + iaxis: int, tags: ToTagSetConvertible, + array: ArrayOrContainerOrScalarT) -> ArrayOrContainerOrScalarT: + # Numpy doesn't support tagging + return array + + def einsum(self, spec, *args, arg_names=None, tagged=()): + return np.einsum(spec, *args) + + @property + def permits_inplace_modification(self): + return True + + @property + def supports_nonscalar_broadcasting(self): + return True + + @property + def permits_advanced_indexing(self): + return True diff --git a/arraycontext/impl/numpy/fake_numpy.py b/arraycontext/impl/numpy/fake_numpy.py new file mode 100644 index 00000000..b305717e --- /dev/null +++ b/arraycontext/impl/numpy/fake_numpy.py @@ -0,0 +1,169 @@ +__copyright__ = """ +Copyright (C) 2021 University of Illinois Board of Trustees +""" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" +from functools import partial, reduce + +import numpy as np + +from arraycontext.container import NotAnArrayContainerError, serialize_container +from arraycontext.container.traversal import ( + rec_map_array_container, + rec_map_reduce_array_container, + rec_multimap_array_container, + rec_multimap_reduce_array_container, +) +from arraycontext.context import Array, ArrayOrContainer +from arraycontext.fake_numpy import ( + BaseFakeNumpyLinalgNamespace, + BaseFakeNumpyNamespace, +) + + +class NumpyFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace): + # Everything is implemented in the base class for now. + pass + + +_NUMPY_UFUNCS = frozenset({"concatenate", "reshape", "transpose", + "ones_like", "where", + *BaseFakeNumpyNamespace._numpy_math_functions + }) + + +class NumpyFakeNumpyNamespace(BaseFakeNumpyNamespace): + """ + A :mod:`numpy` mimic for :class:`NumpyArrayContext`. 
+ """ + def _get_fake_numpy_linalg_namespace(self): + return NumpyFakeNumpyLinalgNamespace(self._array_context) + + def zeros(self, shape, dtype): + return np.zeros(shape, dtype) + + def __getattr__(self, name): + + if name in _NUMPY_UFUNCS: + from functools import partial + return partial(rec_multimap_array_container, + getattr(np, name)) + + raise AttributeError(name) + + def sum(self, a, axis=None, dtype=None): + return rec_map_reduce_array_container(sum, partial(np.sum, + axis=axis, + dtype=dtype), + a) + + def min(self, a, axis=None): + return rec_map_reduce_array_container( + partial(reduce, np.minimum), partial(np.amin, axis=axis), a) + + def max(self, a, axis=None): + return rec_map_reduce_array_container( + partial(reduce, np.maximum), partial(np.amax, axis=axis), a) + + def stack(self, arrays, axis=0): + return rec_multimap_array_container( + lambda *args: np.stack(arrays=args, axis=axis), + *arrays) + + def broadcast_to(self, array, shape): + return rec_map_array_container(partial(np.broadcast_to, shape=shape), array) + + # {{{ relational operators + + def equal(self, x, y): + return rec_multimap_array_container(np.equal, x, y) + + def not_equal(self, x, y): + return rec_multimap_array_container(np.not_equal, x, y) + + def greater(self, x, y): + return rec_multimap_array_container(np.greater, x, y) + + def greater_equal(self, x, y): + return rec_multimap_array_container(np.greater_equal, x, y) + + def less(self, x, y): + return rec_multimap_array_container(np.less, x, y) + + def less_equal(self, x, y): + return rec_multimap_array_container(np.less_equal, x, y) + + # }}} + + def ravel(self, a, order="C"): + return rec_map_array_container(partial(np.ravel, order=order), a) + + def vdot(self, x, y): + return rec_multimap_reduce_array_container(sum, np.vdot, x, y) + + def any(self, a): + return rec_map_reduce_array_container(partial(reduce, np.logical_or), + lambda subary: np.any(subary), a) + + def all(self, a): + return rec_map_reduce_array_container(partial(reduce, np.logical_and), + lambda subary: np.all(subary), a) + + def array_equal(self, a: ArrayOrContainer, b: ArrayOrContainer) -> Array: + false_ary = np.array(False) + true_ary = np.array(True) + if type(a) is not type(b): + return false_ary + + try: + serialized_x = serialize_container(a) + serialized_y = serialize_container(b) + except NotAnArrayContainerError: + assert isinstance(a, np.ndarray) + assert isinstance(b, np.ndarray) + return np.array(np.array_equal(a, b)) + else: + if len(serialized_x) != len(serialized_y): + return false_ary + return reduce( + np.logical_and, + [(true_ary if kx_i == ky_i else false_ary) + and self.array_equal(x_i, y_i) + for (kx_i, x_i), (ky_i, y_i) + in zip(serialized_x, serialized_y)], + true_ary) + + def arange(self, *args, **kwargs): + return np.arange(*args, **kwargs) + + def linspace(self, *args, **kwargs): + return np.linspace(*args, **kwargs) + + def zeros_like(self, ary): + return rec_map_array_container(np.zeros_like, ary) + + def reshape(self, a, newshape, order="C"): + return rec_map_array_container( + lambda ary: ary.reshape(newshape, order=order), + a) + + +# vim: fdm=marker diff --git a/arraycontext/impl/pyopencl/__init__.py b/arraycontext/impl/pyopencl/__init__.py index 9be77a44..de188cbd 100644 --- a/arraycontext/impl/pyopencl/__init__.py +++ b/arraycontext/impl/pyopencl/__init__.py @@ -198,41 +198,6 @@ def _wrapper(ary): # {{{ ArrayContext interface - def empty(self, shape, dtype): - from warnings import warn - warn(f"{type(self).__name__}.empty is deprecated and will stop 
" - "working in 2023. Prefer actx.np.zeros instead.", - DeprecationWarning, stacklevel=2) - - import arraycontext.impl.pyopencl.taggable_cl_array as tga - return tga.empty(self.queue, shape, dtype, allocator=self.allocator) - - def zeros(self, shape, dtype): - import arraycontext.impl.pyopencl.taggable_cl_array as tga - return tga.zeros(self.queue, shape, dtype, allocator=self.allocator) - - def empty_like(self, ary): - from warnings import warn - warn(f"{type(self).__name__}.empty_like is deprecated and will stop " - "working in 2023. Prefer actx.np.zeros_like instead.", - DeprecationWarning, stacklevel=2) - - import arraycontext.impl.pyopencl.taggable_cl_array as tga - - def _empty_like(array): - return tga.empty(self.queue, array.shape, array.dtype, - allocator=self.allocator, axes=array.axes, tags=array.tags) - - return self._rec_map_container(_empty_like, ary) - - def zeros_like(self, ary): - from warnings import warn - warn(f"{type(self).__name__}.zeros_like is deprecated and will stop " - "working in 2023. Use actx.np.zeros_like instead.", - DeprecationWarning, stacklevel=2) - - return self.np.zeros_like(ary) - def from_numpy(self, array): import arraycontext.impl.pyopencl.taggable_cl_array as tga diff --git a/arraycontext/impl/pyopencl/fake_numpy.py b/arraycontext/impl/pyopencl/fake_numpy.py index 59be99e8..848870a9 100644 --- a/arraycontext/impl/pyopencl/fake_numpy.py +++ b/arraycontext/impl/pyopencl/fake_numpy.py @@ -38,6 +38,7 @@ rec_multimap_array_container, rec_multimap_reduce_array_container, ) +from arraycontext.context import Array, ArrayOrContainer from arraycontext.fake_numpy import BaseFakeNumpyLinalgNamespace from arraycontext.impl.pyopencl.taggable_cl_array import TaggableCLArray from arraycontext.loopy import LoopyBasedFakeNumpyNamespace @@ -215,30 +216,40 @@ def _any(ary): result = result.get()[()] return result - def array_equal(self, a, b): + def array_equal(self, a: ArrayOrContainer, b: ArrayOrContainer) -> Array: actx = self._array_context queue = actx.queue # NOTE: pyopencl doesn't like `bool` much, so use `int8` instead - true = actx.from_numpy(np.int8(True)) - false = actx.from_numpy(np.int8(False)) + true_ary = actx.from_numpy(np.int8(True)) + false_ary = actx.from_numpy(np.int8(False)) - def rec_equal(x, y): + def rec_equal(x: ArrayOrContainer, y: ArrayOrContainer) -> cl_array.Array: if type(x) is not type(y): - return false + return false_ary try: - iterable = zip(serialize_container(x), serialize_container(y)) + serialized_x = serialize_container(x) + serialized_y = serialize_container(y) except NotAnArrayContainerError: + assert isinstance(x, cl_array.Array) + assert isinstance(y, cl_array.Array) + if x.shape != y.shape: - return false + return false_ary else: return (x == y).all() else: + if len(serialized_x) != len(serialized_y): + return false_ary + return reduce( partial(cl_array.minimum, queue=queue), - [rec_equal(x_i, y_i)for (_, x_i), (_, y_i) in iterable], - true) + [(true_ary if kx_i == ky_i else false_ary) + and rec_equal(x_i, y_i) + for (kx_i, x_i), (ky_i, y_i) + in zip(serialized_x, serialized_y)], + true_ary) result = rec_equal(a, b) if not self._array_context._force_device_scalars: diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index 8737e5fa..099738a9 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -171,22 +171,6 @@ def _frozen_array_types(self) -> Tuple[Type, ...]: Returns valid frozen array types for the array context. 
""" - # {{{ ArrayContext interface - - def empty(self, shape, dtype): - raise NotImplementedError( - f"{type(self).__name__}.empty is not supported") - - def zeros(self, shape, dtype): - import pytato as pt - return pt.zeros(shape, dtype) - - def empty_like(self, ary): - raise NotImplementedError( - f"{type(self).__name__}.empty_like is not supported") - - # }}} - # {{{ compilation def transform_dag(self, dag: pytato.DictOfNamedArrays @@ -210,7 +194,7 @@ def transform_loopy_program(self, t_unit: lp.TranslationUnit) -> lp.TranslationU "to transform a translation unit. " "This is a no-op and will result in unoptimized C code for" "the requested optimization, all in a single statement." - "This will work, but is unlikely to be performatn." + "This will work, but is unlikely to be performant." f"Instead, subclass {type(self).__name__} and implement " "the specific transform logic required to transform the program " "for your package or application. Check higher-level packages " @@ -380,14 +364,6 @@ def _wrapper(ary): # {{{ ArrayContext interface - def zeros_like(self, ary): - from warnings import warn - warn(f"{type(self).__name__}.zeros_like is deprecated and will stop " - "working in 2023. Use actx.np.zeros_like instead.", - DeprecationWarning, stacklevel=2) - - return self.np.zeros_like(ary) - def from_numpy(self, array): import pytato as pt @@ -776,14 +752,6 @@ def _wrapper(ary): # {{{ ArrayContext interface - def zeros_like(self, ary): - from warnings import warn - warn(f"{type(self).__name__}.zeros_like is deprecated and will stop " - "working in 2023. Use actx.np.zeros_like instead.", - DeprecationWarning, stacklevel=2) - - return self.np.zeros_like(ary) - def from_numpy(self, array): import jax import pytato as pt diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py index 523ef3a4..54d2cbb8 100644 --- a/arraycontext/impl/pytato/compile.py +++ b/arraycontext/impl/pytato/compile.py @@ -214,10 +214,10 @@ def _get_f_placeholder_args(arg, kw, arg_id_to_name, actx): :attr:`BaseLazilyCompilingFunctionCaller.f`. """ if np.isscalar(arg): - name = arg_id_to_name[(kw,)] + name = arg_id_to_name[kw,] return pt.make_placeholder(name, (), np.dtype(type(arg))) elif isinstance(arg, pt.Array): - name = arg_id_to_name[(kw,)] + name = arg_id_to_name[kw,] # Transform the DAG to give metadata inference a chance to do its job arg = _to_input_for_compiled(arg, actx) return pt.make_placeholder(name, arg.shape, arg.dtype, @@ -225,7 +225,8 @@ def _get_f_placeholder_args(arg, kw, arg_id_to_name, actx): tags=arg.tags) elif is_array_container_type(arg.__class__): def _rec_to_placeholder(keys, ary): - name = arg_id_to_name[(kw, *keys)] + index = (kw, *keys) + name = arg_id_to_name[index] # Transform the DAG to give metadata inference a chance to do its job ary = _to_input_for_compiled(ary, actx) return pt.make_placeholder(name, diff --git a/arraycontext/impl/pytato/fake_numpy.py b/arraycontext/impl/pytato/fake_numpy.py index d3d018d6..c6508e3a 100644 --- a/arraycontext/impl/pytato/fake_numpy.py +++ b/arraycontext/impl/pytato/fake_numpy.py @@ -22,7 +22,7 @@ THE SOFTWARE. 
""" from functools import partial, reduce -from typing import Any +from typing import Any, cast import numpy as np @@ -34,6 +34,7 @@ rec_map_reduce_array_container, rec_multimap_array_container, ) +from arraycontext.context import Array, ArrayOrContainer from arraycontext.fake_numpy import BaseFakeNumpyLinalgNamespace from arraycontext.loopy import LoopyBasedFakeNumpyNamespace @@ -171,31 +172,41 @@ def any(self, a): partial(reduce, pt.logical_or), lambda subary: pt.any(subary), a) - def array_equal(self, a, b): + def array_equal(self, a: ArrayOrContainer, b: ArrayOrContainer) -> Array: actx = self._array_context # NOTE: not all backends support `bool` properly, so use `int8` instead - true = actx.from_numpy(np.int8(True)) - false = actx.from_numpy(np.int8(False)) + true_ary = actx.from_numpy(np.int8(True)) + false_ary = actx.from_numpy(np.int8(False)) - def rec_equal(x, y): + def rec_equal(x: ArrayOrContainer, y: ArrayOrContainer) -> pt.Array: if type(x) is not type(y): - return false + return false_ary try: - iterable = zip(serialize_container(x), serialize_container(y)) + serialized_x = serialize_container(x) + serialized_y = serialize_container(y) except NotAnArrayContainerError: + assert isinstance(x, pt.Array) + assert isinstance(y, pt.Array) + if x.shape != y.shape: - return false + return false_ary else: - return pt.all(pt.equal(x, y)) + return pt.all(cast(pt.Array, pt.equal(x, y))) else: + if len(serialized_x) != len(serialized_y): + return false_ary + return reduce( pt.logical_and, - [rec_equal(x_i, y_i) for (_, x_i), (_, y_i) in iterable], - true) + [(true_ary if kx_i == ky_i else false_ary) + and rec_equal(x_i, y_i) + for (kx_i, x_i), (ky_i, y_i) + in zip(serialized_x, serialized_y)], + true_ary) - return rec_equal(a, b) + return cast(Array, rec_equal(a, b)) # }}} diff --git a/arraycontext/impl/pytato/utils.py b/arraycontext/impl/pytato/utils.py index 0af54204..a5582d18 100644 --- a/arraycontext/impl/pytato/utils.py +++ b/arraycontext/impl/pytato/utils.py @@ -23,7 +23,7 @@ """ -from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Set, Tuple +from typing import TYPE_CHECKING, Any, Dict, Mapping, Set, Tuple from pytato.array import ( AbstractResultWithNamedArrays, @@ -118,7 +118,7 @@ def __init__(self, limit_arg_size_nbytes: int) -> None: self.limit_arg_size_nbytes = limit_arg_size_nbytes @memoize_method - def get_loopy_target(self) -> Optional["lp.PyOpenCLTarget"]: + def get_loopy_target(self) -> "lp.PyOpenCLTarget": from loopy import PyOpenCLTarget return PyOpenCLTarget(limit_arg_size_nbytes=self.limit_arg_size_nbytes) diff --git a/arraycontext/loopy.py b/arraycontext/loopy.py index dc5d84f4..1bee3eb0 100644 --- a/arraycontext/loopy.py +++ b/arraycontext/loopy.py @@ -82,8 +82,8 @@ def _get_scalar_func_loopy_program(actx, c_name, nargs, naxes): def get(c_name, nargs, naxes): from pymbolic import var - var_names = ["i%d" % i for i in range(naxes)] - size_names = ["n%d" % i for i in range(naxes)] + var_names = [f"i{i}" for i in range(naxes)] + size_names = [f"n{i}" for i in range(naxes)] subscript = tuple(var(vname) for vname in var_names) from islpy import make_zero_and_vars v = make_zero_and_vars(var_names, params=size_names) @@ -103,12 +103,12 @@ def get(c_name, nargs, naxes): lp.Assignment( var("out")[subscript], var(c_name)(*[ - var("inp%d" % i)[subscript] for i in range(nargs)])) + var(f"inp{i}")[subscript] for i in range(nargs)])) ], [ lp.GlobalArg("out", dtype=None, shape=lp.auto, offset=lp.auto)] + [ - lp.GlobalArg("inp%d" % i, + lp.GlobalArg(f"inp{i}", 
dtype=None, shape=lp.auto, offset=lp.auto) for i in range(nargs)] + [...], name=f"actx_special_{c_name}", @@ -142,7 +142,7 @@ def loopy_implemented_elwise_func(*args): prg = _get_scalar_func_loopy_program(actx, c_name, nargs=len(args), naxes=len(args[0].shape)) outputs = actx.call_loopy(prg, - **{"inp%d" % i: arg for i, arg in enumerate(args)}) + **{f"inp{i}": arg for i, arg in enumerate(args)}) return outputs["out"] if name in self._c_to_numpy_arc_functions: diff --git a/arraycontext/pytest.py b/arraycontext/pytest.py index 8a1e0274..088c7e3e 100644 --- a/arraycontext/pytest.py +++ b/arraycontext/pytest.py @@ -34,6 +34,7 @@ from typing import Any, Callable, Dict, Sequence, Type, Union +from arraycontext import NumpyArrayContext from arraycontext.context import ArrayContext @@ -221,6 +222,26 @@ def __str__(self): return "" +# {{{ _PytestArrayContextFactory + +class _NumpyArrayContextForTests(NumpyArrayContext): + def transform_loopy_program(self, t_unit): + return t_unit + + +class _PytestNumpyArrayContextFactory(PytestArrayContextFactory): + def __init__(self, *args, **kwargs): + super().__init__() + + def __call__(self): + return _NumpyArrayContextForTests() + + def __str__(self): + return "" + +# }}} + + _ARRAY_CONTEXT_FACTORY_REGISTRY: \ Dict[str, Type[PytestArrayContextFactory]] = { "pyopencl": _PytestPyOpenCLArrayContextFactoryWithClass, @@ -229,6 +250,7 @@ def __str__(self): "pytato:pyopencl": _PytestPytatoPyOpenCLArrayContextFactory, "pytato:jax": _PytestPytatoJaxArrayContextFactory, "eagerjax": _PytestEagerJaxArrayContextFactory, + "numpy": _PytestNumpyArrayContextFactory, } @@ -374,7 +396,7 @@ def inner(metafunc): # NOTE: sorts the args so that parallel pytest works arg_value_tuples = sorted([ - tuple([arg_dict[name] for name in arg_names]) + tuple(arg_dict[name] for name in arg_names) for arg_dict in arg_values_with_actx ], key=lambda x: str(x)) diff --git a/arraycontext/version.py b/arraycontext/version.py index a1da82c1..31baea05 100644 --- a/arraycontext/version.py +++ b/arraycontext/version.py @@ -8,7 +8,7 @@ def _parse_version(version: str) -> Tuple[Tuple[int, ...], str]: m = re.match("^([0-9.]+)([a-z0-9]*?)$", VERSION_TEXT) assert m is not None - return tuple([int(nr) for nr in m.group(1).split(".")]), m.group(2) + return tuple(int(nr) for nr in m.group(1).split(".")), m.group(2) VERSION_TEXT = metadata.version("arraycontext") diff --git a/doc/array_context.rst b/doc/array_context.rst index 85a6cc44..db1182e2 100644 --- a/doc/array_context.rst +++ b/doc/array_context.rst @@ -4,43 +4,3 @@ The Array Context Abstraction .. automodule:: arraycontext .. automodule:: arraycontext.context - -Implementations of the Array Context Abstraction -================================================ - -.. - When adding a new array context here, make sure to also add it to and run - ``` - doc/make_numpy_coverage_table.py - ``` - to update the coverage table below! - -Array context based on :mod:`pyopencl.array` --------------------------------------------- - -.. automodule:: arraycontext.impl.pyopencl - - -Lazy/Deferred evaluation array context based on :mod:`pytato` -------------------------------------------------------------- - -.. automodule:: arraycontext.impl.pytato - - -Array context based on :mod:`jax.numpy` ---------------------------------------- - -.. automodule:: arraycontext.impl.jax - -.. _numpy-coverage: - -:mod:`numpy` coverage ---------------------- - -This is a list of functionality implemented by :attr:`arraycontext.ArrayContext.np`. - -.. 
note:: - - Only functions and methods that have at least one implementation are listed. - -.. include:: numpy_coverage.rst diff --git a/doc/implementations.rst b/doc/implementations.rst new file mode 100644 index 00000000..4023e37c --- /dev/null +++ b/doc/implementations.rst @@ -0,0 +1,45 @@ +Implementations of the Array Context Abstraction +================================================ + +.. + When adding a new array context here, make sure to also add it to and run + ``` + doc/make_numpy_coverage_table.py + ``` + to update the coverage table below! + +Array context based on :mod:`numpy` +-------------------------------------------- + +.. automodule:: arraycontext.impl.numpy + +Array context based on :mod:`pyopencl.array` +-------------------------------------------- + +.. automodule:: arraycontext.impl.pyopencl + + +Lazy/Deferred evaluation array context based on :mod:`pytato` +------------------------------------------------------------- + +.. automodule:: arraycontext.impl.pytato + + +Array context based on :mod:`jax.numpy` +--------------------------------------- + +.. automodule:: arraycontext.impl.jax + +.. _numpy-coverage: + +:mod:`numpy` coverage +--------------------- + +This is a list of functionality implemented by :attr:`arraycontext.ArrayContext.np`. + +.. note:: + + Only functions and methods that have at least one implementation are listed. + +.. include:: numpy_coverage.rst + diff --git a/doc/index.rst b/doc/index.rst index 48fd25bb..d3f9854b 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -53,6 +53,7 @@ Contents .. toctree:: array_context + implementations container other misc diff --git a/doc/upload-docs.sh b/doc/upload-docs.sh index 176c9fbd..02615433 100755 --- a/doc/upload-docs.sh +++ b/doc/upload-docs.sh @@ -1,3 +1,3 @@ #! 
/bin/sh -rsync --verbose --archive --delete _build/html/* doc-upload:doc/arraycontext +rsync --verbose --archive --delete _build/html/ doc-upload:doc/arraycontext diff --git a/pyproject.toml b/pyproject.toml index ca64c70a..d971ae20 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,9 @@ dependencies = [ "immutabledict>=4.1", "numpy", "pytools>=2024.1.3", + + # for TypeAlias + "typing-extensions>=4; python_version<'3.10'", ] [project.optional-dependencies] diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 3f06156b..47d83903 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -22,6 +22,7 @@ import logging from dataclasses import dataclass +from functools import partial from typing import Union import numpy as np @@ -34,6 +35,7 @@ ArrayContainer, ArrayContext, EagerJAXArrayContext, + NumpyArrayContext, PyOpenCLArrayContext, PytatoPyOpenCLArrayContext, dataclass_array_container, @@ -46,6 +48,7 @@ ) from arraycontext.pytest import ( _PytestEagerJaxArrayContextFactory, + _PytestNumpyArrayContextFactory, _PytestPyOpenCLArrayContextFactoryWithClass, _PytestPytatoJaxArrayContextFactory, _PytestPytatoPyOpenCLArrayContextFactory, @@ -97,6 +100,7 @@ class _PytatoPyOpenCLArrayContextForTestsFactory( _PytatoPyOpenCLArrayContextForTestsFactory, _PytestEagerJaxArrayContextFactory, _PytestPytatoJaxArrayContextFactory, + _PytestNumpyArrayContextFactory, ]) @@ -113,11 +117,11 @@ def _acf(): # {{{ stand-in DOFArray implementation @with_container_arithmetic( - bcast_obj_array=True, - bcast_numpy_array=True, + bcasts_across_obj_array=True, bitwise=True, rel_comparison=True, - _cls_has_array_context_attr=True) + _cls_has_array_context_attr=True, + _bcast_actx_array_type=False) class DOFArray: def __init__(self, actx, data): if not (actx is None or isinstance(actx, ArrayContext)): @@ -129,7 +133,8 @@ def __init__(self, actx, data): self.array_context = actx self.data = data - __array_priority__ = 10 + # prevent numpy broadcasting + __array_ufunc__ = None def __bool__(self): if len(self) == 1 and self.data[0].size == 1: @@ -165,11 +170,11 @@ def size(self): @property def real(self): - return DOFArray(self.array_context, tuple([subary.real for subary in self])) + return DOFArray(self.array_context, tuple(subary.real for subary in self)) @property def imag(self): - return DOFArray(self.array_context, tuple([subary.imag for subary in self])) + return DOFArray(self.array_context, tuple(subary.imag for subary in self)) @serialize_container.register(DOFArray) @@ -203,9 +208,10 @@ def _with_actx_dofarray(ary: DOFArray, actx: ArrayContext) -> DOFArray: # type: # {{{ nested containers -@with_container_arithmetic(bcast_obj_array=False, +@with_container_arithmetic(bcasts_across_obj_array=False, eq_comparison=False, rel_comparison=False, - _cls_has_array_context_attr=True) + _cls_has_array_context_attr=True, + _bcast_actx_array_type=False) @dataclass_array_container @dataclass(frozen=True) class MyContainer: @@ -214,6 +220,8 @@ class MyContainer: momentum: np.ndarray enthalpy: Union[DOFArray, np.ndarray] + __array_ufunc__ = None + @property def array_context(self): if isinstance(self.mass, np.ndarray): @@ -223,11 +231,12 @@ def array_context(self): @with_container_arithmetic( - bcast_obj_array=False, + bcasts_across_obj_array=False, bcast_container_types=(DOFArray, np.ndarray), matmul=True, rel_comparison=True, - _cls_has_array_context_attr=True) + _cls_has_array_context_attr=True, + _bcast_actx_array_type=False) @dataclass_array_container @dataclass(frozen=True) class 
MyContainerDOFBcast: @@ -249,9 +258,8 @@ def _get_test_containers(actx, ambient_dim=2, shapes=50_000): if isinstance(shapes, (Number, tuple)): shapes = [shapes] - x = DOFArray(actx, tuple([ - actx.from_numpy(randn(shape, np.float64)) - for shape in shapes])) + x = DOFArray(actx, tuple(actx.from_numpy(randn(shape, np.float64)) + for shape in shapes)) # pylint: disable=unexpected-keyword-arg, no-value-for-parameter dataclass_of_dofs = MyContainer( @@ -934,8 +942,6 @@ def test_container_arithmetic(actx_factory): def _check_allclose(f, arg1, arg2, atol=5.0e-14): assert np.linalg.norm(actx.to_numpy(f(arg1) - arg2)) < atol - from functools import partial - from arraycontext import rec_multimap_array_container for ary in [ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs]: rec_multimap_array_container( @@ -1074,13 +1080,11 @@ def test_flatten_array_container(actx_factory, shapes): if isinstance(shapes, (int, tuple)): shapes = [shapes] - ary = DOFArray(actx, tuple([ - actx.from_numpy(randn(shape, np.float64)) - for shape in shapes])) + ary = DOFArray(actx, tuple(actx.from_numpy(randn(shape, np.float64)) + for shape in shapes)) - template = DOFArray(actx, tuple([ - actx.from_numpy(randn(shape, np.complex128)) - for shape in shapes])) + template = DOFArray(actx, tuple(actx.from_numpy(randn(shape, np.complex128)) + for shape in shapes)) flat = flatten(ary, actx) ary_roundtrip = unflatten(template, flat, actx, strict=False) @@ -1111,9 +1115,10 @@ def test_flatten_array_container_failure(actx_factory): ary = _get_test_containers(actx, shapes=512)[0] flat_ary = _checked_flatten(ary, actx) - with pytest.raises(TypeError): - # cannot unflatten from a numpy array - unflatten(ary, actx.to_numpy(flat_ary), actx) + if not isinstance(actx, NumpyArrayContext): + with pytest.raises(TypeError): + # cannot unflatten from a numpy array (except for numpy actx) + unflatten(ary, actx.to_numpy(flat_ary), actx) with pytest.raises(ValueError): # cannot unflatten non-flat arrays @@ -1169,16 +1174,17 @@ def test_numpy_conversion(actx_factory): assert np.allclose(ac.mass, ac_roundtrip.mass) assert np.allclose(ac.momentum[0], ac_roundtrip.momentum[0]) - from dataclasses import replace - ac_with_cl = replace(ac, enthalpy=ac_actx.mass) - with pytest.raises(TypeError): - actx.from_numpy(ac_with_cl) + if not isinstance(actx, NumpyArrayContext): + from dataclasses import replace + ac_with_cl = replace(ac, enthalpy=ac_actx.mass) + with pytest.raises(TypeError): + actx.from_numpy(ac_with_cl) - with pytest.raises(TypeError): - actx.from_numpy(ac_actx) + with pytest.raises(TypeError): + actx.from_numpy(ac_actx) - with pytest.raises(TypeError): - actx.to_numpy(ac) + with pytest.raises(TypeError): + actx.to_numpy(ac) # }}} @@ -1219,7 +1225,7 @@ def test_norm_ord_none(actx_factory, ndim): # {{{ test_actx_compile helpers -@with_container_arithmetic(bcast_obj_array=True, rel_comparison=True) +@with_container_arithmetic(bcasts_across_obj_array=True, rel_comparison=True) @dataclass_array_container @dataclass(frozen=True) class Velocity2D: @@ -1346,56 +1352,39 @@ def test_container_equality(actx_factory): # }}} -# {{{ test_leaf_array_type_broadcasting +# {{{ test_no_leaf_array_type_broadcasting @with_container_arithmetic( - bcast_obj_array=True, - bcast_numpy_array=True, + bcasts_across_obj_array=True, rel_comparison=True, - _cls_has_array_context_attr=True) + _cls_has_array_context_attr=True, + _bcast_actx_array_type=False) @dataclass_array_container @dataclass(frozen=True) class Foo: u: DOFArray + # prevent numpy arithmetic from taking 
precedence + __array_ufunc__ = None + @property def array_context(self): return self.u.array_context -def test_leaf_array_type_broadcasting(actx_factory): - # test support for https://github.com/inducer/arraycontext/issues/49 +def test_no_leaf_array_type_broadcasting(actx_factory): + # test lack of support for https://github.com/inducer/arraycontext/issues/49 actx = actx_factory() - foo = Foo(DOFArray(actx, (actx.np.zeros(3, dtype=np.float64) + 41, ))) - bar = foo + 4 - baz = foo + actx.from_numpy(4*np.ones((3, ))) - qux = actx.from_numpy(4*np.ones((3, ))) + foo - - np.testing.assert_allclose(actx.to_numpy(bar.u[0]), - actx.to_numpy(baz.u[0])) - - np.testing.assert_allclose(actx.to_numpy(bar.u[0]), - actx.to_numpy(qux.u[0])) + dof_ary = DOFArray(actx, (actx.np.zeros(3, dtype=np.float64) + 41, )) + foo = Foo(dof_ary) - def _actx_allows_scalar_broadcast(actx): - if not isinstance(actx, PyOpenCLArrayContext): - return True - else: - import pyopencl as cl - - # See https://github.com/inducer/pyopencl/issues/498 - return cl.version.VERSION > (2021, 2, 5) - - if _actx_allows_scalar_broadcast(actx): - quux = foo + actx.from_numpy(np.array(4)) - quuz = actx.from_numpy(np.array(4)) + foo - - np.testing.assert_allclose(actx.to_numpy(bar.u[0]), - actx.to_numpy(quux.u[0])) + actx_ary = actx.from_numpy(4*np.ones((3, ))) + with pytest.raises(TypeError): + foo + actx_ary - np.testing.assert_allclose(actx.to_numpy(bar.u[0]), - actx.to_numpy(quuz.u[0])) + with pytest.raises(TypeError): + foo + actx.from_numpy(np.array(4)) # }}} @@ -1465,21 +1454,12 @@ def equal(a, b): b_bcast_dc_of_dofs.momentum), enthalpy=a_bcast_dc_of_dofs.enthalpy*b_bcast_dc_of_dofs.enthalpy)) - # Non-object numpy arrays should be treated as scalars - ary_of_floats = np.ones(len(b_bcast_dc_of_dofs.mass)) - assert equal( - outer(ary_of_floats, b_bcast_dc_of_dofs), - ary_of_floats*b_bcast_dc_of_dofs) - assert equal( - outer(a_bcast_dc_of_dofs, ary_of_floats), - a_bcast_dc_of_dofs*ary_of_floats) - # }}} # {{{ test_array_container_with_numpy -@with_container_arithmetic(bcast_obj_array=True, rel_comparison=True) +@with_container_arithmetic(bcasts_across_obj_array=True, rel_comparison=True) @dataclass_array_container @dataclass(frozen=True) class ArrayContainerWithNumpy: @@ -1581,8 +1561,8 @@ def test_to_numpy_on_frozen_arrays(actx_factory): def test_tagging(actx_factory): actx = actx_factory() - if isinstance(actx, EagerJAXArrayContext): - pytest.skip("Eager JAX has no tagging support") + if isinstance(actx, (NumpyArrayContext, EagerJAXArrayContext)): + pytest.skip(f"{type(actx)} has no tagging support") from pytools.tag import Tag
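# ------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the patch above): the hunks in
# this patch rename ``bcast_obj_array`` to ``bcasts_across_obj_array``, rely
# on ``__array_ufunc__ = None`` to keep numpy from broadcasting into array
# containers, and register a "numpy" pytest factory for the new
# NumpyArrayContext. A minimal downstream container written against the
# updated spelling might look like the following; the names ``MyState``,
# ``mass`` and ``momentum`` are made up for the example, and the string keys
# passed to the pytest helper are assumed to resolve through the factory
# registry shown in arraycontext/pytest.py.

from dataclasses import dataclass

import numpy as np

from arraycontext import (
    dataclass_array_container,
    pytest_generate_tests_for_array_contexts,
    with_container_arithmetic,
)


@with_container_arithmetic(bcasts_across_obj_array=True, rel_comparison=True)
@dataclass_array_container
@dataclass(frozen=True)
class MyState:
    mass: np.ndarray
    momentum: np.ndarray

    # keep numpy ufuncs from taking precedence over the container's arithmetic
    __array_ufunc__ = None


# run the test module against the new numpy-backed context as well as the
# existing PyOpenCL one
pytest_generate_tests = pytest_generate_tests_for_array_contexts(
    ["numpy", "pyopencl"])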