From 14d44766a6d1333fe53637fe98bb18e3a429c72a Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Sun, 25 Aug 2024 14:36:23 -0500 Subject: [PATCH 01/32] Fix some newly-flagged UP031 issues --- arraycontext/loopy.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/arraycontext/loopy.py b/arraycontext/loopy.py index dc5d84f4..1bee3eb0 100644 --- a/arraycontext/loopy.py +++ b/arraycontext/loopy.py @@ -82,8 +82,8 @@ def _get_scalar_func_loopy_program(actx, c_name, nargs, naxes): def get(c_name, nargs, naxes): from pymbolic import var - var_names = ["i%d" % i for i in range(naxes)] - size_names = ["n%d" % i for i in range(naxes)] + var_names = [f"i{i}" for i in range(naxes)] + size_names = [f"n{i}" for i in range(naxes)] subscript = tuple(var(vname) for vname in var_names) from islpy import make_zero_and_vars v = make_zero_and_vars(var_names, params=size_names) @@ -103,12 +103,12 @@ def get(c_name, nargs, naxes): lp.Assignment( var("out")[subscript], var(c_name)(*[ - var("inp%d" % i)[subscript] for i in range(nargs)])) + var(f"inp{i}")[subscript] for i in range(nargs)])) ], [ lp.GlobalArg("out", dtype=None, shape=lp.auto, offset=lp.auto)] + [ - lp.GlobalArg("inp%d" % i, + lp.GlobalArg(f"inp{i}", dtype=None, shape=lp.auto, offset=lp.auto) for i in range(nargs)] + [...], name=f"actx_special_{c_name}", @@ -142,7 +142,7 @@ def loopy_implemented_elwise_func(*args): prg = _get_scalar_func_loopy_program(actx, c_name, nargs=len(args), naxes=len(args[0].shape)) outputs = actx.call_loopy(prg, - **{"inp%d" % i: arg for i, arg in enumerate(args)}) + **{f"inp{i}": arg for i, arg in enumerate(args)}) return outputs["out"] if name in self._c_to_numpy_arc_functions: From ee96ff5c2c078f55103b4e8865f2a0ff22ed25a0 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Mon, 5 Aug 2024 11:11:59 -0500 Subject: [PATCH 02/32] Drop deprecated actx.{empty,zeros}{,_like} --- arraycontext/context.py | 20 --------------- arraycontext/impl/jax/__init__.py | 32 ----------------------- arraycontext/impl/pyopencl/__init__.py | 35 -------------------------- arraycontext/impl/pytato/__init__.py | 32 ----------------------- 4 files changed, 119 deletions(-) diff --git a/arraycontext/context.py b/arraycontext/context.py index 8b42bca7..0640e5a1 100644 --- a/arraycontext/context.py +++ b/arraycontext/context.py @@ -297,12 +297,6 @@ def _get_fake_numpy_namespace(self) -> Any: def __hash__(self) -> int: raise TypeError(f"unhashable type: '{type(self).__name__}'") - @abstractmethod - def empty(self, - shape: Union[int, Tuple[int, ...]], - dtype: "np.dtype[Any]") -> Array: - pass - def zeros(self, shape: Union[int, Tuple[int, ...]], dtype: "np.dtype[Any]") -> Array: @@ -312,20 +306,6 @@ def zeros(self, return self.np.zeros(shape, dtype) - def empty_like(self, ary: Array) -> Array: - warn(f"{type(self).__name__}.empty_like is deprecated and will stop " - "working in 2023. Prefer actx.np.zeros_like instead.", - DeprecationWarning, stacklevel=2) - - return self.empty(shape=ary.shape, dtype=ary.dtype) - - def zeros_like(self, ary: Array) -> Array: - warn(f"{type(self).__name__}.zeros_like is deprecated and will stop " - "working in 2023. Use actx.np.zeros_like instead.", - DeprecationWarning, stacklevel=2) - - return self.zeros(shape=ary.shape, dtype=ary.dtype) - @abstractmethod def from_numpy(self, array: NumpyOrContainerOrScalar diff --git a/arraycontext/impl/jax/__init__.py b/arraycontext/impl/jax/__init__.py index 03045419..26cb9db5 100644 --- a/arraycontext/impl/jax/__init__.py +++ b/arraycontext/impl/jax/__init__.py @@ -87,38 +87,6 @@ def _wrapper(ary): # {{{ ArrayContext interface - def empty(self, shape, dtype): - from warnings import warn - warn(f"{type(self).__name__}.empty is deprecated and will stop " - "working in 2023. Prefer actx.np.zeros instead.", - DeprecationWarning, stacklevel=2) - - import jax.numpy as jnp - return jnp.empty(shape=shape, dtype=dtype) - - def zeros(self, shape, dtype): - import jax.numpy as jnp - return jnp.zeros(shape=shape, dtype=dtype) - - def empty_like(self, ary): - from warnings import warn - warn(f"{type(self).__name__}.empty_like is deprecated and will stop " - "working in 2023. Prefer actx.np.zeros_like instead.", - DeprecationWarning, stacklevel=2) - - def _empty_like(array): - return self.empty(array.shape, array.dtype) - - return self._rec_map_container(_empty_like, ary) - - def zeros_like(self, ary): - from warnings import warn - warn(f"{type(self).__name__}.zeros_like is deprecated and will stop " - "working in 2023. Use actx.np.zeros_like instead.", - DeprecationWarning, stacklevel=2) - - return self.np.zeros_like(ary) - def from_numpy(self, array): def _from_numpy(ary): import jax diff --git a/arraycontext/impl/pyopencl/__init__.py b/arraycontext/impl/pyopencl/__init__.py index 9be77a44..de188cbd 100644 --- a/arraycontext/impl/pyopencl/__init__.py +++ b/arraycontext/impl/pyopencl/__init__.py @@ -198,41 +198,6 @@ def _wrapper(ary): # {{{ ArrayContext interface - def empty(self, shape, dtype): - from warnings import warn - warn(f"{type(self).__name__}.empty is deprecated and will stop " - "working in 2023. Prefer actx.np.zeros instead.", - DeprecationWarning, stacklevel=2) - - import arraycontext.impl.pyopencl.taggable_cl_array as tga - return tga.empty(self.queue, shape, dtype, allocator=self.allocator) - - def zeros(self, shape, dtype): - import arraycontext.impl.pyopencl.taggable_cl_array as tga - return tga.zeros(self.queue, shape, dtype, allocator=self.allocator) - - def empty_like(self, ary): - from warnings import warn - warn(f"{type(self).__name__}.empty_like is deprecated and will stop " - "working in 2023. Prefer actx.np.zeros_like instead.", - DeprecationWarning, stacklevel=2) - - import arraycontext.impl.pyopencl.taggable_cl_array as tga - - def _empty_like(array): - return tga.empty(self.queue, array.shape, array.dtype, - allocator=self.allocator, axes=array.axes, tags=array.tags) - - return self._rec_map_container(_empty_like, ary) - - def zeros_like(self, ary): - from warnings import warn - warn(f"{type(self).__name__}.zeros_like is deprecated and will stop " - "working in 2023. Use actx.np.zeros_like instead.", - DeprecationWarning, stacklevel=2) - - return self.np.zeros_like(ary) - def from_numpy(self, array): import arraycontext.impl.pyopencl.taggable_cl_array as tga diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index 8737e5fa..5ece78e9 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -171,22 +171,6 @@ def _frozen_array_types(self) -> Tuple[Type, ...]: Returns valid frozen array types for the array context. """ - # {{{ ArrayContext interface - - def empty(self, shape, dtype): - raise NotImplementedError( - f"{type(self).__name__}.empty is not supported") - - def zeros(self, shape, dtype): - import pytato as pt - return pt.zeros(shape, dtype) - - def empty_like(self, ary): - raise NotImplementedError( - f"{type(self).__name__}.empty_like is not supported") - - # }}} - # {{{ compilation def transform_dag(self, dag: pytato.DictOfNamedArrays @@ -380,14 +364,6 @@ def _wrapper(ary): # {{{ ArrayContext interface - def zeros_like(self, ary): - from warnings import warn - warn(f"{type(self).__name__}.zeros_like is deprecated and will stop " - "working in 2023. Use actx.np.zeros_like instead.", - DeprecationWarning, stacklevel=2) - - return self.np.zeros_like(ary) - def from_numpy(self, array): import pytato as pt @@ -776,14 +752,6 @@ def _wrapper(ary): # {{{ ArrayContext interface - def zeros_like(self, ary): - from warnings import warn - warn(f"{type(self).__name__}.zeros_like is deprecated and will stop " - "working in 2023. Use actx.np.zeros_like instead.", - DeprecationWarning, stacklevel=2) - - return self.np.zeros_like(ary) - def from_numpy(self, array): import jax import pytato as pt From 0c24aad6f1050a98e3a8af5e3256a49f1db9551b Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Tue, 6 Aug 2024 12:31:42 -0500 Subject: [PATCH 03/32] Fix a return type in ArgSizeLimitingPytatoLoopyPyOpenCLTarget --- arraycontext/impl/pytato/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/arraycontext/impl/pytato/utils.py b/arraycontext/impl/pytato/utils.py index 0af54204..a5582d18 100644 --- a/arraycontext/impl/pytato/utils.py +++ b/arraycontext/impl/pytato/utils.py @@ -23,7 +23,7 @@ """ -from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Set, Tuple +from typing import TYPE_CHECKING, Any, Dict, Mapping, Set, Tuple from pytato.array import ( AbstractResultWithNamedArrays, @@ -118,7 +118,7 @@ def __init__(self, limit_arg_size_nbytes: int) -> None: self.limit_arg_size_nbytes = limit_arg_size_nbytes @memoize_method - def get_loopy_target(self) -> Optional["lp.PyOpenCLTarget"]: + def get_loopy_target(self) -> "lp.PyOpenCLTarget": from loopy import PyOpenCLTarget return PyOpenCLTarget(limit_arg_size_nbytes=self.limit_arg_size_nbytes) From 02ab097bd4c7373bada3a9fc230fdb87b34cfc83 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Mon, 5 Aug 2024 10:52:00 -0500 Subject: [PATCH 04/32] Separate doc page for actx abstraction from doc page for implementations --- doc/array_context.rst | 40 ---------------------------------------- doc/implementations.rst | 40 ++++++++++++++++++++++++++++++++++++++++ doc/index.rst | 1 + 3 files changed, 41 insertions(+), 40 deletions(-) create mode 100644 doc/implementations.rst diff --git a/doc/array_context.rst b/doc/array_context.rst index 85a6cc44..db1182e2 100644 --- a/doc/array_context.rst +++ b/doc/array_context.rst @@ -4,43 +4,3 @@ The Array Context Abstraction .. automodule:: arraycontext .. automodule:: arraycontext.context - -Implementations of the Array Context Abstraction -================================================ - -.. - When adding a new array context here, make sure to also add it to and run - ``` - doc/make_numpy_coverage_table.py - ``` - to update the coverage table below! - -Array context based on :mod:`pyopencl.array` --------------------------------------------- - -.. automodule:: arraycontext.impl.pyopencl - - -Lazy/Deferred evaluation array context based on :mod:`pytato` -------------------------------------------------------------- - -.. automodule:: arraycontext.impl.pytato - - -Array context based on :mod:`jax.numpy` ---------------------------------------- - -.. automodule:: arraycontext.impl.jax - -.. _numpy-coverage: - -:mod:`numpy` coverage ---------------------- - -This is a list of functionality implemented by :attr:`arraycontext.ArrayContext.np`. - -.. note:: - - Only functions and methods that have at least one implementation are listed. - -.. include:: numpy_coverage.rst diff --git a/doc/implementations.rst b/doc/implementations.rst new file mode 100644 index 00000000..db35cca9 --- /dev/null +++ b/doc/implementations.rst @@ -0,0 +1,40 @@ +Implementations of the Array Context Abstraction +================================================ + +.. + When adding a new array context here, make sure to also add it to and run + ``` + doc/make_numpy_coverage_table.py + ``` + to update the coverage table below! + +Array context based on :mod:`pyopencl.array` +-------------------------------------------- + +.. automodule:: arraycontext.impl.pyopencl + + +Lazy/Deferred evaluation array context based on :mod:`pytato` +------------------------------------------------------------- + +.. automodule:: arraycontext.impl.pytato + + +Array context based on :mod:`jax.numpy` +--------------------------------------- + +.. automodule:: arraycontext.impl.jax + +.. _numpy-coverage: + +:mod:`numpy` coverage +--------------------- + +This is a list of functionality implemented by :attr:`arraycontext.ArrayContext.np`. + +.. note:: + + Only functions and methods that have at least one implementation are listed. + +.. include:: numpy_coverage.rst + diff --git a/doc/index.rst b/doc/index.rst index 48fd25bb..d3f9854b 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -53,6 +53,7 @@ Contents .. toctree:: array_context + implementations container other misc From c8884899825f41e66d723ef6b366367c564cf0ef Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Wed, 31 Jul 2024 13:48:08 -0500 Subject: [PATCH 05/32] Give up on precisely typing Array.__getitem__ --- arraycontext/context.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/arraycontext/context.py b/arraycontext/context.py index 0640e5a1..71121d2c 100644 --- a/arraycontext/context.py +++ b/arraycontext/context.py @@ -217,7 +217,11 @@ def size(self) -> int: def dtype(self) -> "np.dtype[Any]": ... - def __getitem__(self, index: Union[slice, int]) -> "Array": + # Covering all the possible index variations is hard and (kind of) futile. + # If you'd like to see how, try changing the Any to + # AxisIndex = slice | int | "Array" + # Index = AxisIndex |tuple[AxisIndex] + def __getitem__(self, index: Any) -> "Array": ... From cd124bac661e63a85651e17e6c18b7a007da4fac Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Wed, 31 Jul 2024 13:47:04 -0500 Subject: [PATCH 06/32] Fix doc upload script to properly sync deletions --- doc/upload-docs.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/upload-docs.sh b/doc/upload-docs.sh index 176c9fbd..02615433 100755 --- a/doc/upload-docs.sh +++ b/doc/upload-docs.sh @@ -1,3 +1,3 @@ #! /bin/sh -rsync --verbose --archive --delete _build/html/* doc-upload:doc/arraycontext +rsync --verbose --archive --delete _build/html/ doc-upload:doc/arraycontext From 5d8158de9e88fce371a13d257560a43f6304457a Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Mon, 27 Sep 2021 18:48:17 -0500 Subject: [PATCH 07/32] Deprecate with_container_arithmetic's bcast_numpy_array arg Passing both 'bcast_numpy_array' and '_bcast_actx_array_types' was ill-defined. For example, in the case of an ArrayContext whose thawed array type is np.ndarray the specification would contradict between broadcasting the argument numpy_array to return an object array *OR* peforming the operation with every leaf array. Consider the example below, ( - 'Foo: ArrayContainer' whose arithmetic routines are generated by `with_container_arithmetic(bcast_numpy=True, _bcast_actx_array_types=True)` - 'actx: ArrayContextT' for whom `np.ndarray` is a valid thawed array type. ) Foo(DOFArray(actx, [38*actx.ones(3, np.float64)])) + np.array([3, 4, 5]) could be either of: - array([Foo(DOFArray([array([41, 41, 41])])), Foo(DOFArray([array([42, 42, 42])])), Foo(DOFArray([array([43, 43, 43])]))]), OR, - Foo(DOFArray(actx, array([41, 42, 43]))) --- arraycontext/container/arithmetic.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/arraycontext/container/arithmetic.py b/arraycontext/container/arithmetic.py index 2ef5ddc9..4c8a09a6 100644 --- a/arraycontext/container/arithmetic.py +++ b/arraycontext/container/arithmetic.py @@ -214,6 +214,15 @@ def _deserialize_init_arrays_code(cls, tmpl_instance_name, args): if rel_comparison is None: raise TypeError("rel_comparison must be specified") + if bcast_numpy_array: + from warnings import warn + warn("'bcast_numpy_array=True' is deprecated and will be unsupported" + " from December 2021", DeprecationWarning, stacklevel=2) + + if _bcast_actx_array_type: + raise ValueError("'bcast_numpy_array' and '_bcast_actx_array_type'" + " cannot be both set.") + if rel_comparison and eq_comparison is None: eq_comparison = True From 228ef166ece3d88318b527b1d0cb99b28ac0029a Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Sun, 26 Sep 2021 02:38:28 -0500 Subject: [PATCH 08/32] Implements NumpyArrayContext --- arraycontext/__init__.py | 7 +- arraycontext/container/arithmetic.py | 1 - arraycontext/impl/numpy/__init__.py | 126 +++++++++++++++++++++++ arraycontext/impl/numpy/fake_numpy.py | 143 ++++++++++++++++++++++++++ doc/implementations.rst | 5 + 5 files changed, 278 insertions(+), 4 deletions(-) create mode 100644 arraycontext/impl/numpy/__init__.py create mode 100644 arraycontext/impl/numpy/fake_numpy.py diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py index 1d0efb36..2f2640dc 100644 --- a/arraycontext/__init__.py +++ b/arraycontext/__init__.py @@ -78,6 +78,7 @@ tag_axes, ) from .impl.jax import EagerJAXArrayContext +from .impl.numpy import NumpyArrayContext from .impl.pyopencl import PyOpenCLArrayContext from .impl.pytato import PytatoJAXArrayContext, PytatoPyOpenCLArrayContext from .loopy import make_loopy_program @@ -91,7 +92,6 @@ __all__ = ( - "Array", "Array", "ArrayContainer", "ArrayContainerT", @@ -105,13 +105,13 @@ "EagerJAXArrayContext", "ElementwiseMapKernelTag", "NotAnArrayContainerError", + "NumpyArrayContext", "PyOpenCLArrayContext", "PytatoJAXArrayContext", "PytatoPyOpenCLArrayContext", "PytestArrayContextFactory", "PytestPyOpenCLArrayContextFactory", "Scalar", - "Scalar", "ScalarLike", "dataclass_array_container", "deserialize_container", @@ -146,8 +146,9 @@ "to_numpy", "unflatten", "with_array_context", + "with_container_arithmetic", "with_container_arithmetic" -) + ) # {{{ deprecation handling diff --git a/arraycontext/container/arithmetic.py b/arraycontext/container/arithmetic.py index 4c8a09a6..63f93272 100644 --- a/arraycontext/container/arithmetic.py +++ b/arraycontext/container/arithmetic.py @@ -33,7 +33,6 @@ """ from typing import Any, Callable, Optional, Tuple, Type, TypeVar, Union -from warnings import warn import numpy as np diff --git a/arraycontext/impl/numpy/__init__.py b/arraycontext/impl/numpy/__init__.py new file mode 100644 index 00000000..dbc725f7 --- /dev/null +++ b/arraycontext/impl/numpy/__init__.py @@ -0,0 +1,126 @@ +""" +.. currentmodule:: arraycontext + +A mod :`numpy`-based array context. + +.. autoclass:: NumpyArrayContext +""" + +__copyright__ = """ +Copyright (C) 2021 University of Illinois Board of Trustees +""" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + +from typing import Dict, Sequence, Union + +import numpy as np + +import loopy as lp +from pytools.tag import Tag + +from arraycontext.context import ArrayContext + + +class NumpyArrayContext(ArrayContext): + """ + A :class:`ArrayContext` that uses :class:`numpy.ndarray` to represent arrays. + + .. automethod:: __init__ + """ + def __init__(self): + super().__init__() + self._loopy_transform_cache: \ + Dict[lp.TranslationUnit, lp.TranslationUnit] = {} + + self.array_types = (np.ndarray,) + + def _get_fake_numpy_namespace(self): + from .fake_numpy import NumpyFakeNumpyNamespace + return NumpyFakeNumpyNamespace(self) + + # {{{ ArrayContext interface + + def clone(self): + return type(self)() + + def empty(self, shape, dtype): + return np.empty(shape, dtype=dtype) + + def zeros(self, shape, dtype): + return np.zeros(shape, dtype) + + def from_numpy(self, np_array: np.ndarray): + # Uh oh... + return np_array + + def to_numpy(self, array): + # Uh oh... + return array + + def call_loopy(self, t_unit, **kwargs): + t_unit = t_unit.copy(target=lp.ExecutableCTarget()) + try: + t_unit = self._loopy_transform_cache[t_unit] + except KeyError: + orig_t_unit = t_unit + t_unit = self.transform_loopy_program(t_unit) + self._loopy_transform_cache[orig_t_unit] = t_unit + del orig_t_unit + + _, result = t_unit(**kwargs) + + return result + + def freeze(self, array): + return array + + def thaw(self, array): + return array + + # }}} + + def transform_loopy_program(self, t_unit): + raise ValueError("NumpyArrayContext does not implement " + "transform_loopy_program. Sub-classes are supposed " + "to implement it.") + + def tag(self, tags: Union[Sequence[Tag], Tag], array): + # Numpy doesn't support tagging + return array + + def tag_axis(self, iaxis, tags: Union[Sequence[Tag], Tag], array): + return array + + def einsum(self, spec, *args, arg_names=None, tagged=()): + return np.einsum(spec, *args) + + @property + def permits_inplace_modification(self): + return True + + @property + def supports_nonscalar_broadcasting(self): + return True + + @property + def permits_advanced_indexing(self): + return True diff --git a/arraycontext/impl/numpy/fake_numpy.py b/arraycontext/impl/numpy/fake_numpy.py new file mode 100644 index 00000000..54867c8d --- /dev/null +++ b/arraycontext/impl/numpy/fake_numpy.py @@ -0,0 +1,143 @@ +__copyright__ = """ +Copyright (C) 2021 University of Illinois Board of Trustees +""" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" +from functools import partial, reduce + +import numpy as np + +from arraycontext.container import is_array_container +from arraycontext.container.traversal import ( + multimap_reduce_array_container, + rec_map_array_container, + rec_map_reduce_array_container, + rec_multimap_array_container, + rec_multimap_reduce_array_container, +) +from arraycontext.fake_numpy import ( + BaseFakeNumpyLinalgNamespace, + BaseFakeNumpyNamespace, +) + + +class NumpyFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace): + # Everything is implemented in the base class for now. + pass + + +_NUMPY_UFUNCS = frozenset({"concatenate", "reshape", "transpose", + "ones_like", "where", + *BaseFakeNumpyNamespace._numpy_math_functions + }) + + +class NumpyFakeNumpyNamespace(BaseFakeNumpyNamespace): + """ + A :mod:`numpy` mimic for :class:`NumpyArrayContext`. + """ + def _get_fake_numpy_linalg_namespace(self): + return NumpyFakeNumpyLinalgNamespace(self._array_context) + + def zeros(self, shape, dtype): + return np.zeros(shape, dtype) + + def __getattr__(self, name): + + if name in _NUMPY_UFUNCS: + from functools import partial + return partial(rec_multimap_array_container, + getattr(np, name)) + + raise AttributeError(name) + + def sum(self, a, axis=None, dtype=None): + return rec_map_reduce_array_container(sum, partial(np.sum, + axis=axis, + dtype=dtype), + a) + + def min(self, a, axis=None): + return rec_map_reduce_array_container( + partial(reduce, np.minimum), partial(np.amin, axis=axis), a) + + def max(self, a, axis=None): + return rec_map_reduce_array_container( + partial(reduce, np.maximum), partial(np.amax, axis=axis), a) + + def stack(self, arrays, axis=0): + return rec_multimap_array_container( + lambda *args: np.stack(arrays=args, axis=axis), + *arrays) + + def broadcast_to(self, array, shape): + return rec_map_array_container(partial(np.broadcast_to, shape=shape), array) + + # {{{ relational operators + + def equal(self, x, y): + return rec_multimap_array_container(np.equal, x, y) + + def not_equal(self, x, y): + return rec_multimap_array_container(np.not_equal, x, y) + + def greater(self, x, y): + return rec_multimap_array_container(np.greater, x, y) + + def greater_equal(self, x, y): + return rec_multimap_array_container(np.greater_equal, x, y) + + def less(self, x, y): + return rec_multimap_array_container(np.less, x, y) + + def less_equal(self, x, y): + return rec_multimap_array_container(np.less_equal, x, y) + + # }}} + + def ravel(self, a, order="C"): + return rec_map_array_container(partial(np.ravel, order=order), a) + + def vdot(self, x, y): + return rec_multimap_reduce_array_container(sum, np.vdot, x, y) + + def any(self, a): + return rec_map_reduce_array_container(partial(reduce, np.logical_or), + lambda subary: np.any(subary), a) + + def all(self, a): + return rec_map_reduce_array_container(partial(reduce, np.logical_and), + lambda subary: np.all(subary), a) + + def array_equal(self, a, b): + if type(a) != type(b): + return False + elif not is_array_container(a): + if a.shape != b.shape: + return False + else: + return np.all(np.equal(a, b)) + else: + return multimap_reduce_array_container(partial(reduce, + np.logical_and), + self.array_equal, a, b) + +# vim: fdm=marker diff --git a/doc/implementations.rst b/doc/implementations.rst index db35cca9..4023e37c 100644 --- a/doc/implementations.rst +++ b/doc/implementations.rst @@ -8,6 +8,11 @@ Implementations of the Array Context Abstraction ``` to update the coverage table below! +Array context based on :mod:`numpy` +-------------------------------------------- + +.. automodule:: arraycontext.impl.numpy + Array context based on :mod:`pyopencl.array` -------------------------------------------- From 1dc8c94915a319c390d7537ac6452a54233d1e74 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Sun, 26 Sep 2021 03:03:53 -0500 Subject: [PATCH 09/32] ArrayContainer fixes for numpy arrays as leaf classes --- arraycontext/container/__init__.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/arraycontext/container/__init__.py b/arraycontext/container/__init__.py index ea20a5ac..53506a0f 100644 --- a/arraycontext/container/__init__.py +++ b/arraycontext/container/__init__.py @@ -218,7 +218,11 @@ def is_array_container(ary: Any) -> bool: "cheaper option, see is_array_container_type.", DeprecationWarning, stacklevel=2) return (serialize_container.dispatch(ary.__class__) - is not serialize_container.__wrapped__) # type:ignore[attr-defined] + is not serialize_container.__wrapped__ # type:ignore[attr-defined] + # numpy values with scalar elements aren't array containers + and not (isinstance(ary, np.ndarray) + and ary.dtype.kind != "O") + ) @singledispatch From 51b46bd1c10b339a4395b5fe43ae68d55d445ce3 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Mon, 27 Sep 2021 01:32:30 -0500 Subject: [PATCH 10/32] arithmetic fixes to account for np.ndarray being a leaf array --- arraycontext/container/arithmetic.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/arraycontext/container/arithmetic.py b/arraycontext/container/arithmetic.py index 63f93272..663cdde2 100644 --- a/arraycontext/container/arithmetic.py +++ b/arraycontext/container/arithmetic.py @@ -492,16 +492,17 @@ def {fname}(arg1): bcast_actx_ary_types = () gen(f""" - if {bool(outer_bcast_type_names)}: # optimized away - if isinstance(arg2, - {tup_str(outer_bcast_type_names - + bcast_actx_ary_types)}): - return cls({bcast_same_cls_init_args}) if {numpy_pred("arg2")}: result = np.empty_like(arg2, dtype=object) for i in np.ndindex(arg2.shape): result[i] = {op_str.format("arg1", "arg2[i]")} return result + + if {bool(outer_bcast_type_names)}: # optimized away + if isinstance(arg2, + {tup_str(outer_bcast_type_names + + bcast_actx_ary_types)}): + return cls({bcast_same_cls_init_args}) return NotImplemented """) gen(f"cls.__{dunder_name}__ = {fname}") @@ -538,16 +539,16 @@ def {fname}(arg1): def {fname}(arg2, arg1): # assert other.__cls__ is not cls - if {bool(outer_bcast_type_names)}: # optimized away - if isinstance(arg1, - {tup_str(outer_bcast_type_names - + bcast_actx_ary_types)}): - return cls({bcast_init_args}) if {numpy_pred("arg1")}: result = np.empty_like(arg1, dtype=object) for i in np.ndindex(arg1.shape): result[i] = {op_str.format("arg1[i]", "arg2")} return result + if {bool(outer_bcast_type_names)}: # optimized away + if isinstance(arg1, + {tup_str(outer_bcast_type_names + + bcast_actx_ary_types)}): + return cls({bcast_init_args}) return NotImplemented cls.__r{dunder_name}__ = {fname}""") From 6308dc1a5135e6b66ab7e8f86451de7c5c2b8b89 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Sun, 26 Sep 2021 02:41:25 -0500 Subject: [PATCH 11/32] test NumpyArrayContext --- arraycontext/pytest.py | 22 ++++++++++++++++++++++ test/test_arraycontext.py | 2 ++ 2 files changed, 24 insertions(+) diff --git a/arraycontext/pytest.py b/arraycontext/pytest.py index 8a1e0274..e74a6aef 100644 --- a/arraycontext/pytest.py +++ b/arraycontext/pytest.py @@ -34,6 +34,7 @@ from typing import Any, Callable, Dict, Sequence, Type, Union +from arraycontext import NumpyArrayContext from arraycontext.context import ArrayContext @@ -221,6 +222,26 @@ def __str__(self): return "" +# {{{ _PytestArrayContextFactory + +class _NumpyArrayContextForTests(NumpyArrayContext): + def transform_loopy_program(self, t_unit): + return t_unit + + +class _PytestNumpyArrayContextFactory(PytestArrayContextFactory): + def __init__(self, *args, **kwargs): + super().__init__() + + def __call__(self): + return _NumpyArrayContextForTests() + + def __str__(self): + return "" + +# }}} + + _ARRAY_CONTEXT_FACTORY_REGISTRY: \ Dict[str, Type[PytestArrayContextFactory]] = { "pyopencl": _PytestPyOpenCLArrayContextFactoryWithClass, @@ -229,6 +250,7 @@ def __str__(self): "pytato:pyopencl": _PytestPytatoPyOpenCLArrayContextFactory, "pytato:jax": _PytestPytatoJaxArrayContextFactory, "eagerjax": _PytestEagerJaxArrayContextFactory, + "numpy": _PytestNumpyArrayContextFactory, } diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 3f06156b..ffd7553d 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -46,6 +46,7 @@ ) from arraycontext.pytest import ( _PytestEagerJaxArrayContextFactory, + _PytestNumpyArrayContextFactory, _PytestPyOpenCLArrayContextFactoryWithClass, _PytestPytatoJaxArrayContextFactory, _PytestPytatoPyOpenCLArrayContextFactory, @@ -97,6 +98,7 @@ class _PytatoPyOpenCLArrayContextForTestsFactory( _PytatoPyOpenCLArrayContextForTestsFactory, _PytestEagerJaxArrayContextFactory, _PytestPytatoJaxArrayContextFactory, + _PytestNumpyArrayContextFactory, ]) From b5ea2703fba370f8cdd7effdd2eda2818ab24984 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Mon, 27 Sep 2021 01:35:55 -0500 Subject: [PATCH 12/32] test tweaks for NumpyArrayContext --- test/test_arraycontext.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index ffd7553d..26d4d807 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -1360,6 +1360,8 @@ def test_container_equality(actx_factory): class Foo: u: DOFArray + __array_priority__ = 1 # disallow numpy arithmetic to take precedence + @property def array_context(self): return self.u.array_context From 80c0672fa2da922bf0b8fc439513ac5ac0a3c711 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Fri, 24 May 2024 13:31:59 -0500 Subject: [PATCH 13/32] Numpy actx: add arange, linspace --- arraycontext/impl/numpy/fake_numpy.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/arraycontext/impl/numpy/fake_numpy.py b/arraycontext/impl/numpy/fake_numpy.py index 54867c8d..4ac10055 100644 --- a/arraycontext/impl/numpy/fake_numpy.py +++ b/arraycontext/impl/numpy/fake_numpy.py @@ -140,4 +140,10 @@ def array_equal(self, a, b): np.logical_and), self.array_equal, a, b) + def arange(self, *args, **kwargs): + return np.arange(*args, **kwargs) + + def linspace(self, *args, **kwargs): + return np.linspace(*args, **kwargs) + # vim: fdm=marker From 6d3b02ad65be3fa9bdd7cbc1d9271079a285ccdb Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 20 Jun 2023 17:12:15 -0500 Subject: [PATCH 14/32] Numpy actx: add zeros_like, reshape --- arraycontext/impl/numpy/fake_numpy.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/arraycontext/impl/numpy/fake_numpy.py b/arraycontext/impl/numpy/fake_numpy.py index 4ac10055..b7a2335a 100644 --- a/arraycontext/impl/numpy/fake_numpy.py +++ b/arraycontext/impl/numpy/fake_numpy.py @@ -146,4 +146,13 @@ def arange(self, *args, **kwargs): def linspace(self, *args, **kwargs): return np.linspace(*args, **kwargs) + def zeros_like(self, ary): + return rec_map_array_container(np.zeros_like, ary) + + def reshape(self, a, newshape, order="C"): + return rec_map_array_container( + lambda ary: ary.reshape(newshape, order=order), + a) + + # vim: fdm=marker From 4125e020e82c49041f257976a12be7e940ee270e Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 20 Jun 2023 17:13:18 -0500 Subject: [PATCH 15/32] Numpy actx: better freeze/thaw --- arraycontext/impl/numpy/__init__.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/arraycontext/impl/numpy/__init__.py b/arraycontext/impl/numpy/__init__.py index dbc725f7..89c4e885 100644 --- a/arraycontext/impl/numpy/__init__.py +++ b/arraycontext/impl/numpy/__init__.py @@ -38,6 +38,8 @@ from pytools.tag import Tag from arraycontext.context import ArrayContext +from arraycontext.container.traversal import ( + rec_map_array_container, with_array_context) class NumpyArrayContext(ArrayContext): @@ -91,10 +93,16 @@ def call_loopy(self, t_unit, **kwargs): return result def freeze(self, array): - return array + def _freeze(ary): + return ary + + return with_array_context(rec_map_array_container(_freeze, array), actx=None) def thaw(self, array): - return array + def _thaw(ary): + return ary + + return with_array_context(rec_map_array_container(_thaw, array), actx=self) # }}} From 5da96a8ea4e0904e4fce6348dfbe52c8fcd0dfd9 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Wed, 31 Jul 2024 16:21:23 -0500 Subject: [PATCH 16/32] Numpy actx: Narrow array_types to non-obj arrays --- arraycontext/impl/numpy/__init__.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/arraycontext/impl/numpy/__init__.py b/arraycontext/impl/numpy/__init__.py index 89c4e885..28910150 100644 --- a/arraycontext/impl/numpy/__init__.py +++ b/arraycontext/impl/numpy/__init__.py @@ -30,16 +30,24 @@ THE SOFTWARE. """ -from typing import Dict, Sequence, Union +from typing import Any, Dict, Sequence, Union import numpy as np import loopy as lp from pytools.tag import Tag +from arraycontext.container.traversal import rec_map_array_container, with_array_context from arraycontext.context import ArrayContext -from arraycontext.container.traversal import ( - rec_map_array_container, with_array_context) + + +class NumpyNonObjectArrayMetaclass(type): + def __instancecheck__(cls, instance: Any) -> bool: + return isinstance(instance, np.ndarray) and instance.dtype != object + + +class NumpyNonObjectArray(metaclass=NumpyNonObjectArrayMetaclass): + pass class NumpyArrayContext(ArrayContext): @@ -53,7 +61,7 @@ def __init__(self): self._loopy_transform_cache: \ Dict[lp.TranslationUnit, lp.TranslationUnit] = {} - self.array_types = (np.ndarray,) + array_types = (NumpyNonObjectArray,) def _get_fake_numpy_namespace(self): from .fake_numpy import NumpyFakeNumpyNamespace From aa53572584a049b90acd40f4a4bf9f7c6bf8ed86 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Wed, 31 Jul 2024 13:49:21 -0500 Subject: [PATCH 17/32] Numpy actx: improve type annotations --- arraycontext/impl/numpy/__init__.py | 40 +++++++++++++++++------------ 1 file changed, 23 insertions(+), 17 deletions(-) diff --git a/arraycontext/impl/numpy/__init__.py b/arraycontext/impl/numpy/__init__.py index 28910150..7d724b84 100644 --- a/arraycontext/impl/numpy/__init__.py +++ b/arraycontext/impl/numpy/__init__.py @@ -30,15 +30,20 @@ THE SOFTWARE. """ -from typing import Any, Dict, Sequence, Union +from typing import Any, Dict import numpy as np import loopy as lp -from pytools.tag import Tag +from pytools.tag import ToTagSetConvertible from arraycontext.container.traversal import rec_map_array_container, with_array_context -from arraycontext.context import ArrayContext +from arraycontext.context import ( + ArrayContext, + ArrayOrContainerOrScalar, + ArrayOrContainerOrScalarT, + NumpyOrContainerOrScalar, +) class NumpyNonObjectArrayMetaclass(type): @@ -56,7 +61,7 @@ class NumpyArrayContext(ArrayContext): .. automethod:: __init__ """ - def __init__(self): + def __init__(self) -> None: super().__init__() self._loopy_transform_cache: \ Dict[lp.TranslationUnit, lp.TranslationUnit] = {} @@ -72,18 +77,14 @@ def _get_fake_numpy_namespace(self): def clone(self): return type(self)() - def empty(self, shape, dtype): - return np.empty(shape, dtype=dtype) - - def zeros(self, shape, dtype): - return np.zeros(shape, dtype) - - def from_numpy(self, np_array: np.ndarray): - # Uh oh... - return np_array + def from_numpy(self, + array: NumpyOrContainerOrScalar + ) -> ArrayOrContainerOrScalar: + return array - def to_numpy(self, array): - # Uh oh... + def to_numpy(self, + array: ArrayOrContainerOrScalar + ) -> NumpyOrContainerOrScalar: return array def call_loopy(self, t_unit, **kwargs): @@ -119,11 +120,16 @@ def transform_loopy_program(self, t_unit): "transform_loopy_program. Sub-classes are supposed " "to implement it.") - def tag(self, tags: Union[Sequence[Tag], Tag], array): + def tag(self, + tags: ToTagSetConvertible, + array: ArrayOrContainerOrScalarT) -> ArrayOrContainerOrScalarT: # Numpy doesn't support tagging return array - def tag_axis(self, iaxis, tags: Union[Sequence[Tag], Tag], array): + def tag_axis(self, + iaxis: int, tags: ToTagSetConvertible, + array: ArrayOrContainerOrScalarT) -> ArrayOrContainerOrScalarT: + # Numpy doesn't support tagging return array def einsum(self, spec, *args, arg_names=None, tagged=()): From cf3f4fbc94e3e8e77ae723645a024fe317ae6dbe Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Fri, 12 Jul 2024 16:37:34 -0500 Subject: [PATCH 18/32] Array container arithemtic: drop deprecated fail-safe actx retrieval --- arraycontext/container/arithmetic.py | 107 ++++++--------------------- 1 file changed, 23 insertions(+), 84 deletions(-) diff --git a/arraycontext/container/arithmetic.py b/arraycontext/container/arithmetic.py index 663cdde2..dbfdd5a6 100644 --- a/arraycontext/container/arithmetic.py +++ b/arraycontext/container/arithmetic.py @@ -1,4 +1,6 @@ # mypy: disallow-untyped-defs +from __future__ import annotations + """ .. currentmodule:: arraycontext @@ -32,7 +34,7 @@ THE SOFTWARE. """ -from typing import Any, Callable, Optional, Tuple, Type, TypeVar, Union +from typing import Any, Callable, Optional, Tuple, TypeVar, Union import numpy as np @@ -125,10 +127,6 @@ def _format_binary_op_str(op_str: str, return op_str.format(arg1, arg2) -class _FailSafe: - pass - - def with_container_arithmetic( *, bcast_number: bool = True, @@ -266,34 +264,28 @@ def numpy_pred(name: str) -> str: # }}} def wrap(cls: Any) -> Any: - cls_has_array_context_attr: Optional[Union[bool, Type[_FailSafe]]] = \ + cls_has_array_context_attr: bool | None = \ _cls_has_array_context_attr - bcast_actx_array_type: Optional[Union[bool, Type[_FailSafe]]] = \ + bcast_actx_array_type: bool | None = \ _bcast_actx_array_type if cls_has_array_context_attr is None: if hasattr(cls, "array_context"): - cls_has_array_context_attr = _FailSafe - warn(f"{cls} has an 'array_context' attribute, but it does not " - "set '_cls_has_array_context_attr' to 'True' when calling " - "'with_container_arithmetic'. This is being interpreted " - "as 'array_context' being permitted to fail. Tolerating " - "these failures comes at a substantial cost. It is " - "deprecated and will stop working in 2023. " - "Having a working 'array_context' attribute is desirable " - "to enable arithmetic with other array types supported " - "by the array context. " - f"If '{cls.__name__}.array_context' will not fail, pass " + raise TypeError( + f"{cls} has an 'array_context' attribute, but it does not " + "set '_cls_has_array_context_attr' to *True* when calling " + "with_container_arithmetic. This is being interpreted " + "as 'array_context' being permitted to fail with an exception, " + "which is no longer allowed. " + f"If {cls.__name__}.array_context will not fail, pass " "'_cls_has_array_context_attr=True'. " "If you do not want container arithmetic to make " "use of the array context, set " - "'_cls_has_array_context_attr=False'.", - stacklevel=2) + "'_cls_has_array_context_attr=False'.") if bcast_actx_array_type is None: if cls_has_array_context_attr: if bcast_number: - # copy over _FailSafe if present bcast_actx_array_type = cls_has_array_context_attr else: bcast_actx_array_type = False @@ -310,20 +302,12 @@ def wrap(cls: Any) -> Any: "'_deserialize_init_arrays_code'. If this is a dataclass, " "use the 'dataclass_array_container' decorator first.") - if cls_has_array_context_attr is _FailSafe: - def actx_getter_code(arg: str) -> str: - return f"_get_actx({arg})" - else: - def actx_getter_code(arg: str) -> str: - return f"{arg}.array_context" - from pytools.codegen import CodeGenerator, Indentation gen = CodeGenerator() gen(""" from numbers import Number import numpy as np - from arraycontext import ( - ArrayContainer, get_container_context_recursively) + from arraycontext import ArrayContainer from warnings import warn def _raise_if_actx_none(actx): @@ -331,45 +315,6 @@ def _raise_if_actx_none(actx): raise ValueError("array containers with frozen arrays " "cannot be operated upon") return actx - - def _get_actx(ary): - try: - return ary.array_context - except Exception as e: - warn(f"Accessing '{type(ary).__name__}.array_context' failed " - f"({type(e)}: {e}). This should not happen and is " - "deprecated. " - "Please fix the implementation of " - f"'{type(ary).__name__}.array_context' " - "and then set _cls_has_array_context_attr=True when " - "calling with_container_arithmetic to avoid the run time " - "cost of the check that gave you this warning. " - "Using expensive recovery for now.", - DeprecationWarning, stacklevel=3) - - return get_container_context_recursively(ary) - - def _get_actx_array_types_failsafe(ary): - try: - actx = ary.array_context - except Exception as e: - warn(f"Accessing '{type(ary).__name__}.array_context' failed " - f"({type(e)}: {e}). This should not happen and is " - "deprecated. " - "Please fix the implementation of " - f"'{type(ary).__name__}.array_context' " - "and then set _cls_has_array_context_attr=True when " - "calling with_container_arithmetic to avoid the run time " - "cost of the check that gave you this warning. " - "Using expensive recovery for now.", - DeprecationWarning, stacklevel=3) - - actx = get_container_context_recursively(ary) - - if actx is None: - return () - - return actx.array_types """) gen("") @@ -459,9 +404,9 @@ def {fname}(arg1): gen("if arg2.__class__ is cls:") with Indentation(gen): if __debug__ and cls_has_array_context_attr: - gen(f""" - arg1_actx = {actx_getter_code("arg1")} - arg2_actx = {actx_getter_code("arg2")} + gen(""" + arg1_actx = arg1.array_context + arg2_actx = arg2.array_context if arg1_actx is not arg2_actx: msg = ("array contexts of both arguments " "must match") @@ -477,17 +422,14 @@ def {fname}(arg1): raise ValueError(msg)""") gen(f"return cls({zip_init_args})") - if bcast_actx_array_type is _FailSafe: - bcast_actx_ary_types: Tuple[str, ...] = ( - "*_get_actx_array_types_failsafe(arg1)",) - elif bcast_actx_array_type: + if bcast_actx_array_type: if __debug__: bcast_actx_ary_types = ( "*_raise_if_actx_none(" - f"{actx_getter_code('arg1')}).array_types",) + "arg1.array_context).array_types",) else: bcast_actx_ary_types = ( - f"*{actx_getter_code('arg1')}.array_types",) + "*arg1.array_context.array_types",) else: bcast_actx_ary_types = () @@ -521,17 +463,14 @@ def {fname}(arg1): cls._serialize_init_arrays_code("arg2").items() }) - if bcast_actx_array_type is _FailSafe: - bcast_actx_ary_types = ( - "*_get_actx_array_types_failsafe(arg2)",) - elif bcast_actx_array_type: + if bcast_actx_array_type: if __debug__: bcast_actx_ary_types = ( "*_raise_if_actx_none(" - f"{actx_getter_code('arg2')}).array_types",) + "arg2.array_context).array_types",) else: bcast_actx_ary_types = ( - f"*{actx_getter_code('arg2')}.array_types",) + "*arg2.array_context.array_types",) else: bcast_actx_ary_types = () From 1af76ce12df5ba1855d1854283f65413236972b8 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Wed, 31 Jul 2024 17:01:45 -0500 Subject: [PATCH 19/32] Skip tagging test for numpy actx --- test/test_arraycontext.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 26d4d807..4bc14872 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -1585,8 +1585,8 @@ def test_to_numpy_on_frozen_arrays(actx_factory): def test_tagging(actx_factory): actx = actx_factory() - if isinstance(actx, EagerJAXArrayContext): - pytest.skip("Eager JAX has no tagging support") + if isinstance(actx, (NumpyArrayContext, EagerJAXArrayContext)): + pytest.skip(f"{type(actx)} has no tagging support") from pytools.tag import Tag From b58e38e4b180f696ce3637eb172990ab8546369e Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Wed, 31 Jul 2024 20:20:14 -0500 Subject: [PATCH 20/32] Skip numpy conversion tests when using the numpy actx --- test/test_arraycontext.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 4bc14872..2d952fb1 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -1171,16 +1171,17 @@ def test_numpy_conversion(actx_factory): assert np.allclose(ac.mass, ac_roundtrip.mass) assert np.allclose(ac.momentum[0], ac_roundtrip.momentum[0]) - from dataclasses import replace - ac_with_cl = replace(ac, enthalpy=ac_actx.mass) - with pytest.raises(TypeError): - actx.from_numpy(ac_with_cl) + if not isinstance(actx, NumpyArrayContext): + from dataclasses import replace + ac_with_cl = replace(ac, enthalpy=ac_actx.mass) + with pytest.raises(TypeError): + actx.from_numpy(ac_with_cl) - with pytest.raises(TypeError): - actx.from_numpy(ac_actx) + with pytest.raises(TypeError): + actx.from_numpy(ac_actx) - with pytest.raises(TypeError): - actx.to_numpy(ac) + with pytest.raises(TypeError): + actx.to_numpy(ac) # }}} From eca314ff71e741f410f895bc6978284bb0f1a4f6 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Wed, 31 Jul 2024 17:07:31 -0500 Subject: [PATCH 21/32] Don't expect unflatten failure from numpy array for numpy actx --- test/test_arraycontext.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 2d952fb1..94d7d748 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -1113,9 +1113,10 @@ def test_flatten_array_container_failure(actx_factory): ary = _get_test_containers(actx, shapes=512)[0] flat_ary = _checked_flatten(ary, actx) - with pytest.raises(TypeError): - # cannot unflatten from a numpy array - unflatten(ary, actx.to_numpy(flat_ary), actx) + if not isinstance(actx, NumpyArrayContext): + with pytest.raises(TypeError): + # cannot unflatten from a numpy array (except for numpy actx) + unflatten(ary, actx.to_numpy(flat_ary), actx) with pytest.raises(ValueError): # cannot unflatten non-flat arrays From 4b4ee86f902e82ff8adfdd56023288d2b03ef299 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Wed, 31 Jul 2024 14:17:50 -0500 Subject: [PATCH 22/32] Container serialization: iterable -> sequence, plus type aliases --- arraycontext/__init__.py | 6 +++- arraycontext/container/__init__.py | 50 ++++++++++++++++++++++------- arraycontext/container/traversal.py | 10 +++--- pyproject.toml | 3 ++ 4 files changed, 50 insertions(+), 19 deletions(-) diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py index 2f2640dc..e8e6e9f3 100644 --- a/arraycontext/__init__.py +++ b/arraycontext/__init__.py @@ -32,6 +32,8 @@ ArrayContainer, ArrayContainerT, NotAnArrayContainerError, + SerializationKey, + SerializedContainer, deserialize_container, get_container_context_opt, get_container_context_recursively, @@ -113,6 +115,8 @@ "PytestPyOpenCLArrayContextFactory", "Scalar", "ScalarLike", + "SerializationKey", + "SerializedContainer", "dataclass_array_container", "deserialize_container", "flat_size_and_dtype", @@ -148,7 +152,7 @@ "with_array_context", "with_container_arithmetic", "with_container_arithmetic" - ) +) # {{{ deprecation handling diff --git a/arraycontext/container/__init__.py b/arraycontext/container/__init__.py index 53506a0f..38a23412 100644 --- a/arraycontext/container/__init__.py +++ b/arraycontext/container/__init__.py @@ -12,6 +12,9 @@ Serialization/deserialization ----------------------------- + +.. autoclass:: SerializationKey +.. autoclass:: SerializedContainer .. autofunction:: is_array_container_type .. autofunction:: serialize_container .. autofunction:: deserialize_container @@ -39,6 +42,14 @@ .. class:: ArrayOrContainerT :canonical: arraycontext.ArrayOrContainerT + +.. class:: SerializationKey + + :canonical: arraycontext.SerializationKey + +.. class:: SerializedContainer + + :canonical: arraycontext.SerializedContainer """ from __future__ import annotations @@ -69,12 +80,23 @@ """ from functools import singledispatch -from typing import TYPE_CHECKING, Any, Iterable, Optional, Protocol, Tuple, TypeVar +from typing import ( + TYPE_CHECKING, + Any, + Hashable, + Iterable, + Optional, + Protocol, + Sequence, + Tuple, + TypeVar, +) # For use in singledispatch type annotations, because sphinx can't figure out # what 'np' is. import numpy import numpy as np +from typing_extensions import TypeAlias from arraycontext.context import ArrayContext @@ -142,23 +164,27 @@ class NotAnArrayContainerError(TypeError): """:class:`TypeError` subclass raised when an array container is expected.""" +SerializationKey: TypeAlias = Hashable +SerializedContainer: TypeAlias = Sequence[Tuple[SerializationKey, "ArrayOrContainer"]] + + @singledispatch def serialize_container( - ary: ArrayContainer) -> Iterable[Tuple[Any, ArrayOrContainer]]: - r"""Serialize the array container into an iterable over its components. + ary: ArrayContainer) -> SerializedContainer: + r"""Serialize the array container into a sequence over its components. The order of the components and their identifiers are entirely under the control of the container class. However, the order is required to be deterministic, i.e. two calls to :func:`serialize_container` on array containers of the same types with the same number of - sub-arrays must result in an iterable with the keys in the same + sub-arrays must result in a sequence with the keys in the same order. If *ary* is mutable, the serialization function is not required to ensure that the serialization result reflects the array state at the time of the call to :func:`serialize_container`. - :returns: an :class:`Iterable` of 2-tuples where the first + :returns: a :class:`Sequence` of 2-tuples where the first entry is an identifier for the component and the second entry is an array-like component of the :class:`ArrayContainer`. Components can themselves be :class:`ArrayContainer`\ s, allowing @@ -172,13 +198,13 @@ def serialize_container( @singledispatch def deserialize_container( template: ArrayContainerT, - iterable: Iterable[Tuple[Any, Any]]) -> ArrayContainerT: - """Deserialize an iterable into an array container. + serialized: SerializedContainer) -> ArrayContainerT: + """Deserialize a sequence into an array container following a *template*. :param template: an instance of an existing object that can be used to aid in the deserialization. For a similar choice see :attr:`~numpy.class.__array_finalize__`. - :param iterable: an iterable that mirrors the output of + :param serialized: a sequence that mirrors the output of :meth:`serialize_container`. """ raise NotAnArrayContainerError( @@ -242,7 +268,7 @@ def get_container_context_opt(ary: ArrayContainer) -> Optional[ArrayContext]: @serialize_container.register(np.ndarray) def _serialize_ndarray_container( - ary: numpy.ndarray) -> Iterable[Tuple[Any, ArrayOrContainer]]: + ary: numpy.ndarray) -> SerializedContainer: if ary.dtype.char != "O": raise NotAnArrayContainerError( f"cannot serialize '{type(ary).__name__}' with dtype '{ary.dtype}'") @@ -256,20 +282,20 @@ def _serialize_ndarray_container( for j in range(ary.shape[1]) ] else: - return np.ndenumerate(ary) + return list(np.ndenumerate(ary)) @deserialize_container.register(np.ndarray) # https://github.com/python/mypy/issues/13040 def _deserialize_ndarray_container( # type: ignore[misc] template: numpy.ndarray, - iterable: Iterable[Tuple[Any, ArrayOrContainer]]) -> numpy.ndarray: + serialized: SerializedContainer) -> numpy.ndarray: # disallow subclasses assert type(template) is np.ndarray assert template.dtype.char == "O" result = type(template)(template.shape, dtype=object) - for i, subary in iterable: + for i, subary in serialized: result[i] = subary return result diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py index 4a60a8f9..31b3bcf5 100644 --- a/arraycontext/container/traversal.py +++ b/arraycontext/container/traversal.py @@ -77,6 +77,7 @@ from arraycontext.container import ( ArrayContainer, NotAnArrayContainerError, + SerializationKey, deserialize_container, get_container_context_recursively_opt, serialize_container, @@ -373,12 +374,9 @@ def wrapper(*args: Any) -> Any: # {{{ keyed array container traversal -KeyType = Any - - def keyed_map_array_container( f: Callable[ - [KeyType, ArrayOrContainer], + [SerializationKey, ArrayOrContainer], ArrayOrContainer], ary: ArrayOrContainer) -> ArrayOrContainer: r"""Applies *f* to all components of an :class:`ArrayContainer`. @@ -403,7 +401,7 @@ def keyed_map_array_container( def rec_keyed_map_array_container( - f: Callable[[Tuple[KeyType, ...], ArrayT], ArrayT], + f: Callable[[Tuple[SerializationKey, ...], ArrayT], ArrayT], ary: ArrayOrContainer) -> ArrayOrContainer: """ Works similarly to :func:`rec_map_array_container`, except that *f* also @@ -412,7 +410,7 @@ def rec_keyed_map_array_container( the current array. """ - def rec(keys: Tuple[Union[str, int], ...], + def rec(keys: Tuple[SerializationKey, ...], _ary: ArrayOrContainerT) -> ArrayOrContainerT: try: iterable = serialize_container(_ary) diff --git a/pyproject.toml b/pyproject.toml index ca64c70a..d971ae20 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,9 @@ dependencies = [ "immutabledict>=4.1", "numpy", "pytools>=2024.1.3", + + # for TypeAlias + "typing-extensions>=4; python_version<'3.10'", ] [project.optional-dependencies] From 3d36c07f68184a74d01b50d45e5e8da124649edd Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Wed, 31 Jul 2024 14:20:32 -0500 Subject: [PATCH 23/32] Improve, type, fix array_equal across all array contexts --- arraycontext/impl/jax/fake_numpy.py | 28 ++++++++++++------ arraycontext/impl/numpy/fake_numpy.py | 37 +++++++++++++++--------- arraycontext/impl/pyopencl/fake_numpy.py | 29 +++++++++++++------ arraycontext/impl/pytato/fake_numpy.py | 35 ++++++++++++++-------- 4 files changed, 86 insertions(+), 43 deletions(-) diff --git a/arraycontext/impl/jax/fake_numpy.py b/arraycontext/impl/jax/fake_numpy.py index 3fc5f2e6..bc9481e3 100644 --- a/arraycontext/impl/jax/fake_numpy.py +++ b/arraycontext/impl/jax/fake_numpy.py @@ -27,12 +27,16 @@ import jax.numpy as jnp -from arraycontext.container import NotAnArrayContainerError, serialize_container +from arraycontext.container import ( + NotAnArrayContainerError, + serialize_container, +) from arraycontext.container.traversal import ( rec_map_array_container, rec_map_reduce_array_container, rec_multimap_array_container, ) +from arraycontext.context import Array, ArrayOrContainer from arraycontext.fake_numpy import BaseFakeNumpyLinalgNamespace, BaseFakeNumpyNamespace @@ -156,29 +160,35 @@ def any(self, a): return rec_map_reduce_array_container( partial(reduce, jnp.logical_or), jnp.any, a) - def array_equal(self, a, b): + def array_equal(self, a: ArrayOrContainer, b: ArrayOrContainer) -> Array: actx = self._array_context # NOTE: not all backends support `bool` properly, so use `int8` instead - true = actx.from_numpy(np.int8(True)) - false = actx.from_numpy(np.int8(False)) + true_ary = actx.from_numpy(np.int8(True)) + false_ary = actx.from_numpy(np.int8(False)) def rec_equal(x, y): if type(x) is not type(y): - return false + return false_ary try: - iterable = zip(serialize_container(x), serialize_container(y)) + serialized_x = serialize_container(x) + serialized_y = serialize_container(y) except NotAnArrayContainerError: if x.shape != y.shape: - return false + return false_ary else: return jnp.all(jnp.equal(x, y)) else: + if len(serialized_x) != len(serialized_y): + return false_ary return reduce( jnp.logical_and, - [rec_equal(x_i, y_i) for (_, x_i), (_, y_i) in iterable], - true) + [(true_ary if kx_i == ky_i else false_ary) + and rec_equal(x_i, y_i) + for (kx_i, x_i), (ky_i, y_i) + in zip(serialized_x, serialized_y)], + true_ary) return rec_equal(a, b) diff --git a/arraycontext/impl/numpy/fake_numpy.py b/arraycontext/impl/numpy/fake_numpy.py index b7a2335a..b305717e 100644 --- a/arraycontext/impl/numpy/fake_numpy.py +++ b/arraycontext/impl/numpy/fake_numpy.py @@ -25,14 +25,14 @@ import numpy as np -from arraycontext.container import is_array_container +from arraycontext.container import NotAnArrayContainerError, serialize_container from arraycontext.container.traversal import ( - multimap_reduce_array_container, rec_map_array_container, rec_map_reduce_array_container, rec_multimap_array_container, rec_multimap_reduce_array_container, ) +from arraycontext.context import Array, ArrayOrContainer from arraycontext.fake_numpy import ( BaseFakeNumpyLinalgNamespace, BaseFakeNumpyNamespace, @@ -127,18 +127,29 @@ def all(self, a): return rec_map_reduce_array_container(partial(reduce, np.logical_and), lambda subary: np.all(subary), a) - def array_equal(self, a, b): - if type(a) != type(b): - return False - elif not is_array_container(a): - if a.shape != b.shape: - return False - else: - return np.all(np.equal(a, b)) + def array_equal(self, a: ArrayOrContainer, b: ArrayOrContainer) -> Array: + false_ary = np.array(False) + true_ary = np.array(True) + if type(a) is not type(b): + return false_ary + + try: + serialized_x = serialize_container(a) + serialized_y = serialize_container(b) + except NotAnArrayContainerError: + assert isinstance(a, np.ndarray) + assert isinstance(b, np.ndarray) + return np.array(np.array_equal(a, b)) else: - return multimap_reduce_array_container(partial(reduce, - np.logical_and), - self.array_equal, a, b) + if len(serialized_x) != len(serialized_y): + return false_ary + return reduce( + np.logical_and, + [(true_ary if kx_i == ky_i else false_ary) + and self.array_equal(x_i, y_i) + for (kx_i, x_i), (ky_i, y_i) + in zip(serialized_x, serialized_y)], + true_ary) def arange(self, *args, **kwargs): return np.arange(*args, **kwargs) diff --git a/arraycontext/impl/pyopencl/fake_numpy.py b/arraycontext/impl/pyopencl/fake_numpy.py index 59be99e8..848870a9 100644 --- a/arraycontext/impl/pyopencl/fake_numpy.py +++ b/arraycontext/impl/pyopencl/fake_numpy.py @@ -38,6 +38,7 @@ rec_multimap_array_container, rec_multimap_reduce_array_container, ) +from arraycontext.context import Array, ArrayOrContainer from arraycontext.fake_numpy import BaseFakeNumpyLinalgNamespace from arraycontext.impl.pyopencl.taggable_cl_array import TaggableCLArray from arraycontext.loopy import LoopyBasedFakeNumpyNamespace @@ -215,30 +216,40 @@ def _any(ary): result = result.get()[()] return result - def array_equal(self, a, b): + def array_equal(self, a: ArrayOrContainer, b: ArrayOrContainer) -> Array: actx = self._array_context queue = actx.queue # NOTE: pyopencl doesn't like `bool` much, so use `int8` instead - true = actx.from_numpy(np.int8(True)) - false = actx.from_numpy(np.int8(False)) + true_ary = actx.from_numpy(np.int8(True)) + false_ary = actx.from_numpy(np.int8(False)) - def rec_equal(x, y): + def rec_equal(x: ArrayOrContainer, y: ArrayOrContainer) -> cl_array.Array: if type(x) is not type(y): - return false + return false_ary try: - iterable = zip(serialize_container(x), serialize_container(y)) + serialized_x = serialize_container(x) + serialized_y = serialize_container(y) except NotAnArrayContainerError: + assert isinstance(x, cl_array.Array) + assert isinstance(y, cl_array.Array) + if x.shape != y.shape: - return false + return false_ary else: return (x == y).all() else: + if len(serialized_x) != len(serialized_y): + return false_ary + return reduce( partial(cl_array.minimum, queue=queue), - [rec_equal(x_i, y_i)for (_, x_i), (_, y_i) in iterable], - true) + [(true_ary if kx_i == ky_i else false_ary) + and rec_equal(x_i, y_i) + for (kx_i, x_i), (ky_i, y_i) + in zip(serialized_x, serialized_y)], + true_ary) result = rec_equal(a, b) if not self._array_context._force_device_scalars: diff --git a/arraycontext/impl/pytato/fake_numpy.py b/arraycontext/impl/pytato/fake_numpy.py index d3d018d6..c6508e3a 100644 --- a/arraycontext/impl/pytato/fake_numpy.py +++ b/arraycontext/impl/pytato/fake_numpy.py @@ -22,7 +22,7 @@ THE SOFTWARE. """ from functools import partial, reduce -from typing import Any +from typing import Any, cast import numpy as np @@ -34,6 +34,7 @@ rec_map_reduce_array_container, rec_multimap_array_container, ) +from arraycontext.context import Array, ArrayOrContainer from arraycontext.fake_numpy import BaseFakeNumpyLinalgNamespace from arraycontext.loopy import LoopyBasedFakeNumpyNamespace @@ -171,31 +172,41 @@ def any(self, a): partial(reduce, pt.logical_or), lambda subary: pt.any(subary), a) - def array_equal(self, a, b): + def array_equal(self, a: ArrayOrContainer, b: ArrayOrContainer) -> Array: actx = self._array_context # NOTE: not all backends support `bool` properly, so use `int8` instead - true = actx.from_numpy(np.int8(True)) - false = actx.from_numpy(np.int8(False)) + true_ary = actx.from_numpy(np.int8(True)) + false_ary = actx.from_numpy(np.int8(False)) - def rec_equal(x, y): + def rec_equal(x: ArrayOrContainer, y: ArrayOrContainer) -> pt.Array: if type(x) is not type(y): - return false + return false_ary try: - iterable = zip(serialize_container(x), serialize_container(y)) + serialized_x = serialize_container(x) + serialized_y = serialize_container(y) except NotAnArrayContainerError: + assert isinstance(x, pt.Array) + assert isinstance(y, pt.Array) + if x.shape != y.shape: - return false + return false_ary else: - return pt.all(pt.equal(x, y)) + return pt.all(cast(pt.Array, pt.equal(x, y))) else: + if len(serialized_x) != len(serialized_y): + return false_ary + return reduce( pt.logical_and, - [rec_equal(x_i, y_i) for (_, x_i), (_, y_i) in iterable], - true) + [(true_ary if kx_i == ky_i else false_ary) + and rec_equal(x_i, y_i) + for (kx_i, x_i), (ky_i, y_i) + in zip(serialized_x, serialized_y)], + true_ary) - return rec_equal(a, b) + return cast(Array, rec_equal(a, b)) # }}} From 58acd1f5624ff42c8d3ac52049dfa8ed753558a5 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Wed, 31 Jul 2024 16:20:10 -0500 Subject: [PATCH 24/32] Clarify that actx.array_types allows ABCs --- arraycontext/context.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/arraycontext/context.py b/arraycontext/context.py index 71121d2c..3ccc357c 100644 --- a/arraycontext/context.py +++ b/arraycontext/context.py @@ -75,8 +75,8 @@ .. currentmodule:: arraycontext -The interface of an array context ---------------------------------- +The :class:`ArrayContext` Interface +----------------------------------- .. autoclass:: ArrayContext @@ -278,8 +278,10 @@ class ArrayContext(ABC): A :class:`tuple` of types that are the valid array classes the context can operate on. However, it is not necessary that *all* the - :class:`ArrayContext`\ 's operations would be legal for the types in - *array_types*. + :class:`ArrayContext`\ 's operations are legal for the types in + *array_types*. Note that this tuple is *only* intended for use + with :func:`isinstance`. Other uses are not allowed. This allows + for 'types' with overridden :meth:`class.__instancecheck__`. .. automethod:: freeze .. automethod:: thaw From 0feaae1e9da0ea4db0776edff9bc0cd0daf5e242 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Wed, 31 Jul 2024 16:57:52 -0500 Subject: [PATCH 25/32] Rework dataclass array container arithmetic - Deprecate automatic broadcasting of array context arrays - Warn about uses of numpy array broadcasting, deprecated earlier - Clarify documentation, warning wording --- arraycontext/__init__.py | 5 +- arraycontext/container/arithmetic.py | 202 +++++++++++++++++++++------ arraycontext/container/traversal.py | 4 +- test/test_arraycontext.py | 58 +++----- 4 files changed, 185 insertions(+), 84 deletions(-) diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py index e8e6e9f3..4e0ba830 100644 --- a/arraycontext/__init__.py +++ b/arraycontext/__init__.py @@ -43,7 +43,9 @@ register_multivector_as_array_container, serialize_container, ) -from .container.arithmetic import with_container_arithmetic +from .container.arithmetic import ( + with_container_arithmetic, +) from .container.dataclass import dataclass_array_container from .container.traversal import ( flat_size_and_dtype, @@ -151,7 +153,6 @@ "unflatten", "with_array_context", "with_container_arithmetic", - "with_container_arithmetic" ) diff --git a/arraycontext/container/arithmetic.py b/arraycontext/container/arithmetic.py index dbfdd5a6..b085a7dc 100644 --- a/arraycontext/container/arithmetic.py +++ b/arraycontext/container/arithmetic.py @@ -2,13 +2,12 @@ from __future__ import annotations -""" +__doc__ = """ .. currentmodule:: arraycontext + .. autofunction:: with_container_arithmetic """ -import enum - __copyright__ = """ Copyright (C) 2020-1 University of Illinois Board of Trustees @@ -34,7 +33,9 @@ THE SOFTWARE. """ +import enum from typing import Any, Callable, Optional, Tuple, TypeVar, Union +from warnings import warn import numpy as np @@ -99,8 +100,8 @@ def _format_unary_op_str(op_str: str, arg1: Union[Tuple[str, ...], str]) -> str: def _format_binary_op_str(op_str: str, - arg1: Union[Tuple[str, ...], str], - arg2: Union[Tuple[str, ...], str]) -> str: + arg1: Union[Tuple[str, str], str], + arg2: Union[Tuple[str, str], str]) -> str: if isinstance(arg1, tuple) and isinstance(arg2, tuple): import sys if sys.version_info >= (3, 10): @@ -127,6 +128,36 @@ def _format_binary_op_str(op_str: str, return op_str.format(arg1, arg2) +class NumpyObjectArrayMetaclass(type): + def __instancecheck__(cls, instance: Any) -> bool: + return isinstance(instance, np.ndarray) and instance.dtype == object + + +class NumpyObjectArray(metaclass=NumpyObjectArrayMetaclass): + pass + + +class ComplainingNumpyNonObjectArrayMetaclass(type): + def __instancecheck__(cls, instance: Any) -> bool: + if isinstance(instance, np.ndarray) and instance.dtype != object: + # Example usage site: + # https://github.com/illinois-ceesd/mirgecom/blob/f5d0d97c41e8c8a05546b1d1a6a2979ec8ea3554/mirgecom/inviscid.py#L148-L149 + # where normal is passed in by test_lfr_flux as a 'custom-made' + # numpy array of dtype float64. + warn( + "Broadcasting container against non-object numpy array. " + "This was never documented to work and will now stop working in " + "2025. Convert the array to an object array to preserve the " + "current semantics.", DeprecationWarning, stacklevel=3) + return True + else: + return False + + +class ComplainingNumpyNonObjectArray(metaclass=ComplainingNumpyNonObjectArrayMetaclass): + pass + + def with_container_arithmetic( *, bcast_number: bool = True, @@ -146,22 +177,16 @@ def with_container_arithmetic( :arg bcast_number: If *True*, numbers broadcast over the container (with the container as the 'outer' structure). - :arg _bcast_actx_array_type: If *True*, instances of base array types of the - container's array context are broadcasted over the container. Can be - *True* only if the container has *_cls_has_array_context_attr* set. - Defaulted to *bcast_number* if *_cls_has_array_context_attr* is set, - else *False*. - :arg bcast_obj_array: If *True*, :mod:`numpy` object arrays broadcast over - the container. (with the container as the 'inner' structure) - :arg bcast_numpy_array: If *True*, any :class:`numpy.ndarray` will broadcast - over the container. (with the container as the 'inner' structure) - If this is set to *True*, *bcast_obj_array* must also be *True*. + :arg bcast_obj_array: If *True*, this container will be broadcast + across :mod:`numpy` object arrays + (with the object array as the 'outer' structure). + Add :class:`numpy.ndarray` to *bcast_container_types* to achieve + the 'reverse' broadcasting. :arg bcast_container_types: A sequence of container types that will broadcast - over this container (with this container as the 'outer' structure). + across this container, with this container as the 'outer' structure. :class:`numpy.ndarray` is permitted to be part of this sequence to - indicate that, in such broadcasting situations, this container should - be the 'outer' structure. In this case, *bcast_obj_array* - (and consequently *bcast_numpy_array*) must be *False*. + indicate that object arrays (and *only* object arrays) will be broadcasat. + In this case, *bcast_obj_array* must be *False*. :arg arithmetic: Implement the conventional arithmetic operators, including ``**``, :func:`divmod`, and ``//``. Also includes ``+`` and ``-`` as well as :func:`abs`. @@ -203,6 +228,17 @@ def _deserialize_init_arrays_code(cls, tmpl_instance_name, args): should nest "outside" :func:dataclass_array_container`. """ + # Hard-won design lessons: + # + # - Anything that special-cases np.ndarray by type is broken by design because: + # - np.ndarray is an array context array. + # - numpy object arrays can be array containers. + # Using NumpyObjectArray and NumpyNonObjectArray *may* be better? + # They're new, so there is no operational experience with them. + # + # - Broadcast rules are hard to change once established, particularly + # because one cannot grep for their use. + # {{{ handle inputs if bcast_obj_array is None: @@ -212,9 +248,8 @@ def _deserialize_init_arrays_code(cls, tmpl_instance_name, args): raise TypeError("rel_comparison must be specified") if bcast_numpy_array: - from warnings import warn warn("'bcast_numpy_array=True' is deprecated and will be unsupported" - " from December 2021", DeprecationWarning, stacklevel=2) + " from 2025.", DeprecationWarning, stacklevel=2) if _bcast_actx_array_type: raise ValueError("'bcast_numpy_array' and '_bcast_actx_array_type'" @@ -231,7 +266,7 @@ def _deserialize_init_arrays_code(cls, tmpl_instance_name, args): if bcast_numpy_array: def numpy_pred(name: str) -> str: - return f"isinstance({name}, np.ndarray)" + return f"is_numpy_array({name})" elif bcast_obj_array: def numpy_pred(name: str) -> str: return f"isinstance({name}, np.ndarray) and {name}.dtype.char == 'O'" @@ -241,12 +276,21 @@ def numpy_pred(name: str) -> str: if bcast_container_types is None: bcast_container_types = () - bcast_container_types_count = len(bcast_container_types) if np.ndarray in bcast_container_types and bcast_obj_array: raise ValueError("If numpy.ndarray is part of bcast_container_types, " "bcast_obj_array must be False.") + numpy_check_types: list[type] = [NumpyObjectArray, ComplainingNumpyNonObjectArray] + bcast_container_types = tuple( + new_ct + for old_ct in bcast_container_types + for new_ct in + (numpy_check_types + if old_ct is np.ndarray + else [old_ct]) + ) + desired_op_classes = set() if arithmetic: desired_op_classes.add(_OpClass.ARITHMETIC) @@ -264,10 +308,15 @@ def numpy_pred(name: str) -> str: # }}} def wrap(cls: Any) -> Any: - cls_has_array_context_attr: bool | None = \ - _cls_has_array_context_attr - bcast_actx_array_type: bool | None = \ - _bcast_actx_array_type + if not hasattr(cls, "__array_ufunc__"): + warn(f"{cls} does not have __array_ufunc__ set. " + "This will cause numpy to attempt broadcasting, in a way that " + "is likely undesired. " + f"To avoid this, set __array_ufunc__ = None in {cls}.", + stacklevel=2) + + cls_has_array_context_attr: bool | None = _cls_has_array_context_attr + bcast_actx_array_type: bool | None = _bcast_actx_array_type if cls_has_array_context_attr is None: if hasattr(cls, "array_context"): @@ -275,8 +324,8 @@ def wrap(cls: Any) -> Any: f"{cls} has an 'array_context' attribute, but it does not " "set '_cls_has_array_context_attr' to *True* when calling " "with_container_arithmetic. This is being interpreted " - "as 'array_context' being permitted to fail with an exception, " - "which is no longer allowed. " + "as '.array_context' being permitted to fail " + "with an exception, which is no longer allowed. " f"If {cls.__name__}.array_context will not fail, pass " "'_cls_has_array_context_attr=True'. " "If you do not want container arithmetic to make " @@ -294,6 +343,30 @@ def wrap(cls: Any) -> Any: raise TypeError("_bcast_actx_array_type can be True only if " "_cls_has_array_context_attr is set.") + if bcast_actx_array_type: + if _bcast_actx_array_type: + warn( + f"Broadcasting array context array types across {cls} " + "has been explicitly " + "enabled. As of 2025, this will stop working. " + "There is no replacement as of right now. " + "See the discussion in " + "https://github.com/inducer/arraycontext/pull/190. " + "To opt out now (and avoid this warning), " + "pass _bcast_actx_array_type=False. ", + DeprecationWarning, stacklevel=2) + else: + warn( + f"Broadcasting array context array types across {cls} " + "has been implicitly " + "enabled. As of 2025, this will no longer work. " + "There is no replacement as of right now. " + "See the discussion in " + "https://github.com/inducer/arraycontext/pull/190. " + "To opt out now (and avoid this warning), " + "pass _bcast_actx_array_type=False.", + DeprecationWarning, stacklevel=2) + if (not hasattr(cls, "_serialize_init_arrays_code") or not hasattr(cls, "_deserialize_init_arrays_code")): raise TypeError(f"class '{cls.__name__}' must provide serialization " @@ -304,7 +377,7 @@ def wrap(cls: Any) -> Any: from pytools.codegen import CodeGenerator, Indentation gen = CodeGenerator() - gen(""" + gen(f""" from numbers import Number import numpy as np from arraycontext import ArrayContainer @@ -315,6 +388,24 @@ def _raise_if_actx_none(actx): raise ValueError("array containers with frozen arrays " "cannot be operated upon") return actx + + def is_numpy_array(arg): + if isinstance(arg, np.ndarray): + if arg.dtype != "O": + warn("Operand is a non-object numpy array, " + "and the broadcasting behavior of this array container " + "({cls}) " + "is influenced by this because of its use of " + "the deprecated bcast_numpy_array. This broadcasting " + "behavior will change in 2025. If you would like the " + "broadcasting behavior to stay the same, make sure " + "to convert the passed numpy array to an " + "object array.", + DeprecationWarning, stacklevel=3) + return True + else: + return False + """) gen("") @@ -323,7 +414,7 @@ def _raise_if_actx_none(actx): gen(f"from {bct.__module__} import {bct.__qualname__} as _bctype{i}") gen("") outer_bcast_type_names = tuple([ - f"_bctype{i}" for i in range(bcast_container_types_count) + f"_bctype{i}" for i in range(len(bcast_container_types)) ]) if bcast_number: outer_bcast_type_names += ("Number",) @@ -384,8 +475,6 @@ def {fname}(arg1): continue - # {{{ "forward" binary operators - zip_init_args = cls._deserialize_init_arrays_code("arg1", { same_key(key_arg1, key_arg2): _format_binary_op_str(op_str, expr_arg1, expr_arg2) @@ -393,11 +482,18 @@ def {fname}(arg1): cls._serialize_init_arrays_code("arg1").items(), cls._serialize_init_arrays_code("arg2").items()) }) - bcast_same_cls_init_args = cls._deserialize_init_arrays_code("arg1", { + bcast_init_args_arg1_is_outer = cls._deserialize_init_arrays_code("arg1", { key_arg1: _format_binary_op_str(op_str, expr_arg1, "arg2") for key_arg1, expr_arg1 in cls._serialize_init_arrays_code("arg1").items() }) + bcast_init_args_arg2_is_outer = cls._deserialize_init_arrays_code("arg2", { + key_arg2: _format_binary_op_str(op_str, "arg1", expr_arg2) + for key_arg2, expr_arg2 in + cls._serialize_init_arrays_code("arg2").items() + }) + + # {{{ "forward" binary operators gen(f"def {fname}(arg1, arg2):") with Indentation(gen): @@ -424,7 +520,7 @@ def {fname}(arg1): if bcast_actx_array_type: if __debug__: - bcast_actx_ary_types = ( + bcast_actx_ary_types: tuple[str, ...] = ( "*_raise_if_actx_none(" "arg1.array_context).array_types",) else: @@ -444,7 +540,19 @@ def {fname}(arg1): if isinstance(arg2, {tup_str(outer_bcast_type_names + bcast_actx_ary_types)}): - return cls({bcast_same_cls_init_args}) + if __debug__: + if isinstance(arg2, {tup_str(bcast_actx_ary_types)}): + warn("Broadcasting {cls} over array " + f"context array type {{type(arg2)}} is deprecated " + "and will no longer work in 2025. " + "There is no replacement as of right now. " + "See the discussion in " + "https://github.com/inducer/arraycontext/" + "pull/190. ", + DeprecationWarning, stacklevel=2) + + return cls({bcast_init_args_arg1_is_outer}) + return NotImplemented """) gen(f"cls.__{dunder_name}__ = {fname}") @@ -456,12 +564,6 @@ def {fname}(arg1): if reversible: fname = f"_{cls.__name__.lower()}_r{dunder_name}" - bcast_init_args = cls._deserialize_init_arrays_code("arg2", { - key_arg2: _format_binary_op_str( - op_str, "arg1", expr_arg2) - for key_arg2, expr_arg2 in - cls._serialize_init_arrays_code("arg2").items() - }) if bcast_actx_array_type: if __debug__: @@ -487,7 +589,21 @@ def {fname}(arg2, arg1): if isinstance(arg1, {tup_str(outer_bcast_type_names + bcast_actx_ary_types)}): - return cls({bcast_init_args}) + if __debug__: + if isinstance(arg1, + {tup_str(bcast_actx_ary_types)}): + warn("Broadcasting {cls} over array " + f"context array type {{type(arg1)}} " + "is deprecated " + "and will no longer work in 2025." + "There is no replacement as of right now. " + "See the discussion in " + "https://github.com/inducer/arraycontext/" + "pull/190. ", + DeprecationWarning, stacklevel=2) + + return cls({bcast_init_args_arg2_is_outer}) + return NotImplemented cls.__r{dunder_name}__ = {fname}""") diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py index 31b3bcf5..100f0775 100644 --- a/arraycontext/container/traversal.py +++ b/arraycontext/container/traversal.py @@ -43,6 +43,8 @@ from __future__ import annotations +from arraycontext.container.arithmetic import NumpyObjectArray + __copyright__ = """ Copyright (C) 2020-1 University of Illinois Board of Trustees @@ -964,7 +966,7 @@ def treat_as_scalar(x: Any) -> bool: return ( not isinstance(x, np.ndarray) # This condition is whether "ndarrays should broadcast inside x". - and np.ndarray not in x.__class__._outer_bcast_types) + and NumpyObjectArray not in x.__class__._outer_bcast_types) if treat_as_scalar(a) or treat_as_scalar(b): return a*b diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 94d7d748..868790fe 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -22,6 +22,7 @@ import logging from dataclasses import dataclass +from functools import partial from typing import Union import numpy as np @@ -34,6 +35,7 @@ ArrayContainer, ArrayContext, EagerJAXArrayContext, + NumpyArrayContext, PyOpenCLArrayContext, PytatoPyOpenCLArrayContext, dataclass_array_container, @@ -116,10 +118,10 @@ def _acf(): @with_container_arithmetic( bcast_obj_array=True, - bcast_numpy_array=True, bitwise=True, rel_comparison=True, - _cls_has_array_context_attr=True) + _cls_has_array_context_attr=True, + _bcast_actx_array_type=False) class DOFArray: def __init__(self, actx, data): if not (actx is None or isinstance(actx, ArrayContext)): @@ -207,7 +209,8 @@ def _with_actx_dofarray(ary: DOFArray, actx: ArrayContext) -> DOFArray: # type: @with_container_arithmetic(bcast_obj_array=False, eq_comparison=False, rel_comparison=False, - _cls_has_array_context_attr=True) + _cls_has_array_context_attr=True, + _bcast_actx_array_type=False) @dataclass_array_container @dataclass(frozen=True) class MyContainer: @@ -229,7 +232,8 @@ def array_context(self): bcast_container_types=(DOFArray, np.ndarray), matmul=True, rel_comparison=True, - _cls_has_array_context_attr=True) + _cls_has_array_context_attr=True, + _bcast_actx_array_type=False) @dataclass_array_container @dataclass(frozen=True) class MyContainerDOFBcast: @@ -936,8 +940,6 @@ def test_container_arithmetic(actx_factory): def _check_allclose(f, arg1, arg2, atol=5.0e-14): assert np.linalg.norm(actx.to_numpy(f(arg1) - arg2)) < atol - from functools import partial - from arraycontext import rec_multimap_array_container for ary in [ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs]: rec_multimap_array_container( @@ -1350,13 +1352,13 @@ def test_container_equality(actx_factory): # }}} -# {{{ test_leaf_array_type_broadcasting +# {{{ test_no_leaf_array_type_broadcasting @with_container_arithmetic( bcast_obj_array=True, - bcast_numpy_array=True, rel_comparison=True, - _cls_has_array_context_attr=True) + _cls_has_array_context_attr=True, + _bcast_actx_array_type=False) @dataclass_array_container @dataclass(frozen=True) class Foo: @@ -1369,39 +1371,19 @@ def array_context(self): return self.u.array_context -def test_leaf_array_type_broadcasting(actx_factory): - # test support for https://github.com/inducer/arraycontext/issues/49 +def test_no_leaf_array_type_broadcasting(actx_factory): + # test lack of support for https://github.com/inducer/arraycontext/issues/49 actx = actx_factory() - foo = Foo(DOFArray(actx, (actx.np.zeros(3, dtype=np.float64) + 41, ))) - bar = foo + 4 - baz = foo + actx.from_numpy(4*np.ones((3, ))) - qux = actx.from_numpy(4*np.ones((3, ))) + foo - - np.testing.assert_allclose(actx.to_numpy(bar.u[0]), - actx.to_numpy(baz.u[0])) - - np.testing.assert_allclose(actx.to_numpy(bar.u[0]), - actx.to_numpy(qux.u[0])) - - def _actx_allows_scalar_broadcast(actx): - if not isinstance(actx, PyOpenCLArrayContext): - return True - else: - import pyopencl as cl - - # See https://github.com/inducer/pyopencl/issues/498 - return cl.version.VERSION > (2021, 2, 5) - - if _actx_allows_scalar_broadcast(actx): - quux = foo + actx.from_numpy(np.array(4)) - quuz = actx.from_numpy(np.array(4)) + foo + dof_ary = DOFArray(actx, (actx.np.zeros(3, dtype=np.float64) + 41, )) + foo = Foo(dof_ary) - np.testing.assert_allclose(actx.to_numpy(bar.u[0]), - actx.to_numpy(quux.u[0])) + actx_ary = actx.from_numpy(4*np.ones((3, ))) + with pytest.raises(TypeError): + foo + actx_ary - np.testing.assert_allclose(actx.to_numpy(bar.u[0]), - actx.to_numpy(quuz.u[0])) + with pytest.raises(TypeError): + foo + actx.from_numpy(np.array(4)) # }}} From 4873ef4f10627ad64423c111554cc22526ac380c Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Mon, 5 Aug 2024 19:31:02 -0500 Subject: [PATCH 26/32] Switch to __array_ufunc__ in tests as a way to avoid numpy broadcasting --- test/test_arraycontext.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 868790fe..ae65a3d0 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -133,7 +133,8 @@ def __init__(self, actx, data): self.array_context = actx self.data = data - __array_priority__ = 10 + # prevent numpy broadcasting + __array_ufunc__ = None def __bool__(self): if len(self) == 1 and self.data[0].size == 1: @@ -219,6 +220,8 @@ class MyContainer: momentum: np.ndarray enthalpy: Union[DOFArray, np.ndarray] + __array_ufunc__ = None + @property def array_context(self): if isinstance(self.mass, np.ndarray): @@ -1364,7 +1367,8 @@ def test_container_equality(actx_factory): class Foo: u: DOFArray - __array_priority__ = 1 # disallow numpy arithmetic to take precedence + # prevent numpy arithmetic from taking precedence + __array_ufunc__ = None @property def array_context(self): From 74cd298c3ad98417476ded6eb2b0a855c975fab5 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Tue, 6 Aug 2024 12:49:52 -0500 Subject: [PATCH 27/32] outer: disallow non-object numpy arrays --- arraycontext/container/traversal.py | 16 +++++++++++----- test/test_arraycontext.py | 9 --------- 2 files changed, 11 insertions(+), 14 deletions(-) diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py index 100f0775..a7547df2 100644 --- a/arraycontext/container/traversal.py +++ b/arraycontext/container/traversal.py @@ -949,8 +949,7 @@ def outer(a: Any, b: Any) -> Any: Tweaks the behavior of :func:`numpy.outer` to return a lower-dimensional object if either/both of *a* and *b* are scalars (whereas :func:`numpy.outer` always returns a matrix). Here the definition of "scalar" includes - all non-array-container types and any scalar-like array container types - (including non-object numpy arrays). + all non-array-container types and any scalar-like array container types. If *a* and *b* are both array containers, the result will have the same type as *a*. If both are array containers and neither is an object array, they must @@ -968,12 +967,19 @@ def treat_as_scalar(x: Any) -> bool: # This condition is whether "ndarrays should broadcast inside x". and NumpyObjectArray not in x.__class__._outer_bcast_types) + a_is_ndarray = isinstance(a, np.ndarray) + b_is_ndarray = isinstance(b, np.ndarray) + + if a_is_ndarray and a.dtype != object: + raise TypeError("passing a non-object numpy array is not allowed") + if b_is_ndarray and b.dtype != object: + raise TypeError("passing a non-object numpy array is not allowed") + if treat_as_scalar(a) or treat_as_scalar(b): return a*b - # After this point, "isinstance(o, ndarray)" means o is an object array. - elif isinstance(a, np.ndarray) and isinstance(b, np.ndarray): + elif a_is_ndarray and b_is_ndarray: return np.outer(a, b) - elif isinstance(a, np.ndarray) or isinstance(b, np.ndarray): + elif a_is_ndarray or b_is_ndarray: return map_array_container(lambda x: outer(x, b), a) else: if type(a) is not type(b): diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index ae65a3d0..63387dbe 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -1457,15 +1457,6 @@ def equal(a, b): b_bcast_dc_of_dofs.momentum), enthalpy=a_bcast_dc_of_dofs.enthalpy*b_bcast_dc_of_dofs.enthalpy)) - # Non-object numpy arrays should be treated as scalars - ary_of_floats = np.ones(len(b_bcast_dc_of_dofs.mass)) - assert equal( - outer(ary_of_floats, b_bcast_dc_of_dofs), - ary_of_floats*b_bcast_dc_of_dofs) - assert equal( - outer(a_bcast_dc_of_dofs, ary_of_floats), - a_bcast_dc_of_dofs*ary_of_floats) - # }}} From 125e9365aee01432523dd7f7e019a821fadc7326 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Sun, 11 Aug 2024 22:32:20 +0200 Subject: [PATCH 28/32] Fix ruff C409 failures --- arraycontext/container/arithmetic.py | 5 ++--- arraycontext/context.py | 2 +- arraycontext/impl/pytato/compile.py | 7 ++++--- arraycontext/pytest.py | 2 +- arraycontext/version.py | 2 +- test/test_arraycontext.py | 19 ++++++++----------- 6 files changed, 17 insertions(+), 20 deletions(-) diff --git a/arraycontext/container/arithmetic.py b/arraycontext/container/arithmetic.py index b085a7dc..66e10ff0 100644 --- a/arraycontext/container/arithmetic.py +++ b/arraycontext/container/arithmetic.py @@ -413,9 +413,8 @@ def is_numpy_array(arg): for i, bct in enumerate(bcast_container_types): gen(f"from {bct.__module__} import {bct.__qualname__} as _bctype{i}") gen("") - outer_bcast_type_names = tuple([ - f"_bctype{i}" for i in range(len(bcast_container_types)) - ]) + outer_bcast_type_names = tuple( + f"_bctype{i}" for i in range(len(bcast_container_types))) if bcast_number: outer_bcast_type_names += ("Number",) diff --git a/arraycontext/context.py b/arraycontext/context.py index 3ccc357c..d296f8f7 100644 --- a/arraycontext/context.py +++ b/arraycontext/context.py @@ -475,7 +475,7 @@ def einsum(self, :return: the output of the einsum :mod:`loopy` program """ if arg_names is None: - arg_names = tuple([f"arg{i}" for i in range(len(args))]) + arg_names = tuple(f"arg{i}" for i in range(len(args))) prg = self._get_einsum_prg(spec, arg_names, tagged) out_ary = self.call_loopy( diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py index 523ef3a4..54d2cbb8 100644 --- a/arraycontext/impl/pytato/compile.py +++ b/arraycontext/impl/pytato/compile.py @@ -214,10 +214,10 @@ def _get_f_placeholder_args(arg, kw, arg_id_to_name, actx): :attr:`BaseLazilyCompilingFunctionCaller.f`. """ if np.isscalar(arg): - name = arg_id_to_name[(kw,)] + name = arg_id_to_name[kw,] return pt.make_placeholder(name, (), np.dtype(type(arg))) elif isinstance(arg, pt.Array): - name = arg_id_to_name[(kw,)] + name = arg_id_to_name[kw,] # Transform the DAG to give metadata inference a chance to do its job arg = _to_input_for_compiled(arg, actx) return pt.make_placeholder(name, arg.shape, arg.dtype, @@ -225,7 +225,8 @@ def _get_f_placeholder_args(arg, kw, arg_id_to_name, actx): tags=arg.tags) elif is_array_container_type(arg.__class__): def _rec_to_placeholder(keys, ary): - name = arg_id_to_name[(kw, *keys)] + index = (kw, *keys) + name = arg_id_to_name[index] # Transform the DAG to give metadata inference a chance to do its job ary = _to_input_for_compiled(ary, actx) return pt.make_placeholder(name, diff --git a/arraycontext/pytest.py b/arraycontext/pytest.py index e74a6aef..088c7e3e 100644 --- a/arraycontext/pytest.py +++ b/arraycontext/pytest.py @@ -396,7 +396,7 @@ def inner(metafunc): # NOTE: sorts the args so that parallel pytest works arg_value_tuples = sorted([ - tuple([arg_dict[name] for name in arg_names]) + tuple(arg_dict[name] for name in arg_names) for arg_dict in arg_values_with_actx ], key=lambda x: str(x)) diff --git a/arraycontext/version.py b/arraycontext/version.py index a1da82c1..31baea05 100644 --- a/arraycontext/version.py +++ b/arraycontext/version.py @@ -8,7 +8,7 @@ def _parse_version(version: str) -> Tuple[Tuple[int, ...], str]: m = re.match("^([0-9.]+)([a-z0-9]*?)$", VERSION_TEXT) assert m is not None - return tuple([int(nr) for nr in m.group(1).split(".")]), m.group(2) + return tuple(int(nr) for nr in m.group(1).split(".")), m.group(2) VERSION_TEXT = metadata.version("arraycontext") diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 63387dbe..107539b4 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -170,11 +170,11 @@ def size(self): @property def real(self): - return DOFArray(self.array_context, tuple([subary.real for subary in self])) + return DOFArray(self.array_context, tuple(subary.real for subary in self)) @property def imag(self): - return DOFArray(self.array_context, tuple([subary.imag for subary in self])) + return DOFArray(self.array_context, tuple(subary.imag for subary in self)) @serialize_container.register(DOFArray) @@ -258,9 +258,8 @@ def _get_test_containers(actx, ambient_dim=2, shapes=50_000): if isinstance(shapes, (Number, tuple)): shapes = [shapes] - x = DOFArray(actx, tuple([ - actx.from_numpy(randn(shape, np.float64)) - for shape in shapes])) + x = DOFArray(actx, tuple(actx.from_numpy(randn(shape, np.float64)) + for shape in shapes)) # pylint: disable=unexpected-keyword-arg, no-value-for-parameter dataclass_of_dofs = MyContainer( @@ -1081,13 +1080,11 @@ def test_flatten_array_container(actx_factory, shapes): if isinstance(shapes, (int, tuple)): shapes = [shapes] - ary = DOFArray(actx, tuple([ - actx.from_numpy(randn(shape, np.float64)) - for shape in shapes])) + ary = DOFArray(actx, tuple(actx.from_numpy(randn(shape, np.float64)) + for shape in shapes)) - template = DOFArray(actx, tuple([ - actx.from_numpy(randn(shape, np.complex128)) - for shape in shapes])) + template = DOFArray(actx, tuple(actx.from_numpy(randn(shape, np.complex128)) + for shape in shapes)) flat = flatten(ary, actx) ary_roundtrip = unflatten(template, flat, actx, strict=False) From 9f1cad4a1cf2fd64b75c5c54fafc8f91620e2161 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Tue, 27 Aug 2024 17:51:42 -0500 Subject: [PATCH 29/32] Fix a typo in the pytato actx --- arraycontext/impl/pytato/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index 5ece78e9..099738a9 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -194,7 +194,7 @@ def transform_loopy_program(self, t_unit: lp.TranslationUnit) -> lp.TranslationU "to transform a translation unit. " "This is a no-op and will result in unoptimized C code for" "the requested optimization, all in a single statement." - "This will work, but is unlikely to be performatn." + "This will work, but is unlikely to be performant." f"Instead, subclass {type(self).__name__} and implement " "the specific transform logic required to transform the program " "for your package or application. Check higher-level packages " From 8b1b795b7ac6194e3e638c14f8f6f407bb260f1e Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Tue, 27 Aug 2024 17:51:56 -0500 Subject: [PATCH 30/32] Numpy actx: warn (not error) on no user-provided transforms --- arraycontext/impl/numpy/__init__.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/arraycontext/impl/numpy/__init__.py b/arraycontext/impl/numpy/__init__.py index 7d724b84..77b7b49f 100644 --- a/arraycontext/impl/numpy/__init__.py +++ b/arraycontext/impl/numpy/__init__.py @@ -43,6 +43,7 @@ ArrayOrContainerOrScalar, ArrayOrContainerOrScalarT, NumpyOrContainerOrScalar, + UntransformedCodeWarning, ) @@ -116,9 +117,21 @@ def _thaw(ary): # }}} def transform_loopy_program(self, t_unit): - raise ValueError("NumpyArrayContext does not implement " - "transform_loopy_program. Sub-classes are supposed " - "to implement it.") + from warnings import warn + warn("Using the base " + f"{type(self).__name__}.transform_loopy_program " + "to transform a translation unit. " + "This is a no-op and will result in unoptimized C code for" + "the requested optimization, all in a single statement." + "This will work, but is unlikely to be performant." + f"Instead, subclass {type(self).__name__} and implement " + "the specific transform logic required to transform the program " + "for your package or application. Check higher-level packages " + "(e.g. meshmode), which may already have subclasses you may want " + "to build on.", + UntransformedCodeWarning, stacklevel=2) + + return t_unit def tag(self, tags: ToTagSetConvertible, From 510dc1b3465ee2e6f7521223ba2a299879c2533d Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Wed, 4 Sep 2024 14:56:27 -0500 Subject: [PATCH 31/32] with_container_arithmetic: Rename arguments to signal who broadcasts across who Names suggested by @majosm --- arraycontext/container/arithmetic.py | 161 +++++++++++++++++++-------- test/test_arraycontext.py | 12 +- 2 files changed, 119 insertions(+), 54 deletions(-) diff --git a/arraycontext/container/arithmetic.py b/arraycontext/container/arithmetic.py index 66e10ff0..9366b260 100644 --- a/arraycontext/container/arithmetic.py +++ b/arraycontext/container/arithmetic.py @@ -159,34 +159,40 @@ class ComplainingNumpyNonObjectArray(metaclass=ComplainingNumpyNonObjectArrayMet def with_container_arithmetic( - *, - bcast_number: bool = True, - _bcast_actx_array_type: Optional[bool] = None, - bcast_obj_array: Optional[bool] = None, - bcast_numpy_array: bool = False, - bcast_container_types: Optional[Tuple[type, ...]] = None, - arithmetic: bool = True, - matmul: bool = False, - bitwise: bool = False, - shift: bool = False, - _cls_has_array_context_attr: Optional[bool] = None, - eq_comparison: Optional[bool] = None, - rel_comparison: Optional[bool] = None) -> Callable[[type], type]: + *, + number_bcasts_across: Optional[bool] = None, + bcasts_across_obj_array: Optional[bool] = None, + container_types_bcast_across: Optional[Tuple[type, ...]] = None, + arithmetic: bool = True, + matmul: bool = False, + bitwise: bool = False, + shift: bool = False, + _cls_has_array_context_attr: Optional[bool] = None, + eq_comparison: Optional[bool] = None, + rel_comparison: Optional[bool] = None, + + # deprecated: + bcast_number: Optional[bool] = None, + bcast_obj_array: Optional[bool] = None, + bcast_numpy_array: bool = False, + _bcast_actx_array_type: Optional[bool] = None, + bcast_container_types: Optional[Tuple[type, ...]] = None, + ) -> Callable[[type], type]: """A class decorator that implements built-in operators for array containers by propagating the operations to the elements of the container. - :arg bcast_number: If *True*, numbers broadcast over the container + :arg number_bcasts_across: If *True*, numbers broadcast over the container (with the container as the 'outer' structure). - :arg bcast_obj_array: If *True*, this container will be broadcast + :arg bcasts_across_obj_array: If *True*, this container will be broadcast across :mod:`numpy` object arrays (with the object array as the 'outer' structure). - Add :class:`numpy.ndarray` to *bcast_container_types* to achieve + Add :class:`numpy.ndarray` to *container_types_bcast_across* to achieve the 'reverse' broadcasting. - :arg bcast_container_types: A sequence of container types that will broadcast + :arg container_types_bcast_across: A sequence of container types that will broadcast across this container, with this container as the 'outer' structure. :class:`numpy.ndarray` is permitted to be part of this sequence to - indicate that object arrays (and *only* object arrays) will be broadcasat. - In this case, *bcast_obj_array* must be *False*. + indicate that object arrays (and *only* object arrays) will be broadcast. + In this case, *bcasts_across_obj_array* must be *False*. :arg arithmetic: Implement the conventional arithmetic operators, including ``**``, :func:`divmod`, and ``//``. Also includes ``+`` and ``-`` as well as :func:`abs`. @@ -241,8 +247,71 @@ def _deserialize_init_arrays_code(cls, tmpl_instance_name, args): # {{{ handle inputs - if bcast_obj_array is None: - raise TypeError("bcast_obj_array must be specified") + if rel_comparison and eq_comparison is None: + eq_comparison = True + + if eq_comparison is None: + raise TypeError("eq_comparison must be specified") + + # {{{ handle bcast_number + + if bcast_number is not None: + if number_bcasts_across is not None: + raise TypeError( + "may specify at most one of 'bcast_number' and " + "'number_bcasts_across'") + + warn("'bcast_number' is deprecated and will be unsupported from 2025. " + "Use 'number_bcasts_across', with equivalent meaning.", + DeprecationWarning, stacklevel=2) + number_bcasts_across = bcast_number + else: + if number_bcasts_across is None: + number_bcasts_across = True + + del bcast_number + + # }}} + + # {{{ handle bcast_obj_array + + if bcast_obj_array is not None: + if bcasts_across_obj_array is not None: + raise TypeError( + "may specify at most one of 'bcast_obj_array' and " + "'bcasts_across_obj_array'") + + warn("'bcast_obj_array' is deprecated and will be unsupported from 2025. " + "Use 'bcasts_across_obj_array', with equivalent meaning.", + DeprecationWarning, stacklevel=2) + bcasts_across_obj_array = bcast_obj_array + else: + if bcasts_across_obj_array is None: + raise TypeError("bcasts_across_obj_array must be specified") + + del bcast_obj_array + + # }}} + + # {{{ handle bcast_container_types + + if bcast_container_types is not None: + if container_types_bcast_across is not None: + raise TypeError( + "may specify at most one of 'bcast_container_types' and " + "'container_types_bcast_across'") + + warn("'bcast_container_types' is deprecated and will be unsupported from 2025. " + "Use 'container_types_bcast_across', with equivalent meaning.", + DeprecationWarning, stacklevel=2) + container_types_bcast_across = bcast_container_types + else: + if container_types_bcast_across is None: + container_types_bcast_across = () + + del bcast_container_types + + # }}} if rel_comparison is None: raise TypeError("rel_comparison must be specified") @@ -255,36 +324,27 @@ def _deserialize_init_arrays_code(cls, tmpl_instance_name, args): raise ValueError("'bcast_numpy_array' and '_bcast_actx_array_type'" " cannot be both set.") - if rel_comparison and eq_comparison is None: - eq_comparison = True - - if eq_comparison is None: - raise TypeError("eq_comparison must be specified") - - if not bcast_obj_array and bcast_numpy_array: + if not bcasts_across_obj_array and bcast_numpy_array: raise TypeError("bcast_obj_array must be set if bcast_numpy_array is") if bcast_numpy_array: def numpy_pred(name: str) -> str: return f"is_numpy_array({name})" - elif bcast_obj_array: + elif bcasts_across_obj_array: def numpy_pred(name: str) -> str: return f"isinstance({name}, np.ndarray) and {name}.dtype.char == 'O'" else: def numpy_pred(name: str) -> str: return "False" # optimized away - if bcast_container_types is None: - bcast_container_types = () - - if np.ndarray in bcast_container_types and bcast_obj_array: + if np.ndarray in container_types_bcast_across and bcasts_across_obj_array: raise ValueError("If numpy.ndarray is part of bcast_container_types, " "bcast_obj_array must be False.") numpy_check_types: list[type] = [NumpyObjectArray, ComplainingNumpyNonObjectArray] - bcast_container_types = tuple( + container_types_bcast_across = tuple( new_ct - for old_ct in bcast_container_types + for old_ct in container_types_bcast_across for new_ct in (numpy_check_types if old_ct is np.ndarray @@ -334,7 +394,7 @@ def wrap(cls: Any) -> Any: if bcast_actx_array_type is None: if cls_has_array_context_attr: - if bcast_number: + if number_bcasts_across: bcast_actx_array_type = cls_has_array_context_attr else: bcast_actx_array_type = False @@ -409,14 +469,14 @@ def is_numpy_array(arg): """) gen("") - if bcast_container_types: - for i, bct in enumerate(bcast_container_types): + if container_types_bcast_across: + for i, bct in enumerate(container_types_bcast_across): gen(f"from {bct.__module__} import {bct.__qualname__} as _bctype{i}") gen("") - outer_bcast_type_names = tuple( - f"_bctype{i}" for i in range(len(bcast_container_types))) - if bcast_number: - outer_bcast_type_names += ("Number",) + container_type_names_bcast_across = tuple( + f"_bctype{i}" for i in range(len(container_types_bcast_across))) + if number_bcasts_across: + container_type_names_bcast_across += ("Number",) def same_key(k1: T, k2: T) -> T: assert k1 == k2 @@ -428,9 +488,14 @@ def tup_str(t: Tuple[str, ...]) -> str: else: return "({},)".format(", ".join(t)) - gen(f"cls._outer_bcast_types = {tup_str(outer_bcast_type_names)}") + gen(f"cls._outer_bcast_types = {tup_str(container_type_names_bcast_across)}") + gen("cls._container_types_bcast_across = " + f"{tup_str(container_type_names_bcast_across)}") + gen(f"cls._bcast_numpy_array = {bcast_numpy_array}") - gen(f"cls._bcast_obj_array = {bcast_obj_array}") + + gen(f"cls._bcast_obj_array = {bcasts_across_obj_array}") + gen(f"cls._bcasts_across_obj_array = {bcasts_across_obj_array}") gen("") # {{{ unary operators @@ -535,9 +600,9 @@ def {fname}(arg1): result[i] = {op_str.format("arg1", "arg2[i]")} return result - if {bool(outer_bcast_type_names)}: # optimized away + if {bool(container_type_names_bcast_across)}: # optimized away if isinstance(arg2, - {tup_str(outer_bcast_type_names + {tup_str(container_type_names_bcast_across + bcast_actx_ary_types)}): if __debug__: if isinstance(arg2, {tup_str(bcast_actx_ary_types)}): @@ -584,9 +649,9 @@ def {fname}(arg2, arg1): for i in np.ndindex(arg1.shape): result[i] = {op_str.format("arg1[i]", "arg2")} return result - if {bool(outer_bcast_type_names)}: # optimized away + if {bool(container_type_names_bcast_across)}: # optimized away if isinstance(arg1, - {tup_str(outer_bcast_type_names + {tup_str(container_type_names_bcast_across + bcast_actx_ary_types)}): if __debug__: if isinstance(arg1, diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 107539b4..47d83903 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -117,7 +117,7 @@ def _acf(): # {{{ stand-in DOFArray implementation @with_container_arithmetic( - bcast_obj_array=True, + bcasts_across_obj_array=True, bitwise=True, rel_comparison=True, _cls_has_array_context_attr=True, @@ -208,7 +208,7 @@ def _with_actx_dofarray(ary: DOFArray, actx: ArrayContext) -> DOFArray: # type: # {{{ nested containers -@with_container_arithmetic(bcast_obj_array=False, +@with_container_arithmetic(bcasts_across_obj_array=False, eq_comparison=False, rel_comparison=False, _cls_has_array_context_attr=True, _bcast_actx_array_type=False) @@ -231,7 +231,7 @@ def array_context(self): @with_container_arithmetic( - bcast_obj_array=False, + bcasts_across_obj_array=False, bcast_container_types=(DOFArray, np.ndarray), matmul=True, rel_comparison=True, @@ -1225,7 +1225,7 @@ def test_norm_ord_none(actx_factory, ndim): # {{{ test_actx_compile helpers -@with_container_arithmetic(bcast_obj_array=True, rel_comparison=True) +@with_container_arithmetic(bcasts_across_obj_array=True, rel_comparison=True) @dataclass_array_container @dataclass(frozen=True) class Velocity2D: @@ -1355,7 +1355,7 @@ def test_container_equality(actx_factory): # {{{ test_no_leaf_array_type_broadcasting @with_container_arithmetic( - bcast_obj_array=True, + bcasts_across_obj_array=True, rel_comparison=True, _cls_has_array_context_attr=True, _bcast_actx_array_type=False) @@ -1459,7 +1459,7 @@ def equal(a, b): # {{{ test_array_container_with_numpy -@with_container_arithmetic(bcast_obj_array=True, rel_comparison=True) +@with_container_arithmetic(bcasts_across_obj_array=True, rel_comparison=True) @dataclass_array_container @dataclass(frozen=True) class ArrayContainerWithNumpy: From bc323fc50ba3f8c7a3f387d3113c9c6c3805b6cb Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Wed, 4 Sep 2024 14:57:54 -0500 Subject: [PATCH 32/32] Numpy actx: cache execuctor --- arraycontext/context.py | 2 +- arraycontext/impl/numpy/__init__.py | 27 +++++++++++++++++---------- 2 files changed, 18 insertions(+), 11 deletions(-) diff --git a/arraycontext/context.py b/arraycontext/context.py index d296f8f7..30f58cb1 100644 --- a/arraycontext/context.py +++ b/arraycontext/context.py @@ -339,7 +339,7 @@ def to_numpy(self, @abstractmethod def call_loopy(self, - program: "loopy.TranslationUnit", + t_unit: "loopy.TranslationUnit", **kwargs: Any) -> Dict[str, Array]: """Execute the :mod:`loopy` program *program* on the arguments *kwargs*. diff --git a/arraycontext/impl/numpy/__init__.py b/arraycontext/impl/numpy/__init__.py index 77b7b49f..f8ba95e3 100644 --- a/arraycontext/impl/numpy/__init__.py +++ b/arraycontext/impl/numpy/__init__.py @@ -1,3 +1,6 @@ +from __future__ import annotations + + """ .. currentmodule:: arraycontext @@ -30,7 +33,7 @@ THE SOFTWARE. """ -from typing import Any, Dict +from typing import Any import numpy as np @@ -39,6 +42,7 @@ from arraycontext.container.traversal import rec_map_array_container, with_array_context from arraycontext.context import ( + Array, ArrayContext, ArrayOrContainerOrScalar, ArrayOrContainerOrScalarT, @@ -62,10 +66,12 @@ class NumpyArrayContext(ArrayContext): .. automethod:: __init__ """ + + _loopy_transform_cache: dict[lp.TranslationUnit, lp.ExecutorBase] + def __init__(self) -> None: super().__init__() - self._loopy_transform_cache: \ - Dict[lp.TranslationUnit, lp.TranslationUnit] = {} + self._loopy_transform_cache = {} array_types = (NumpyNonObjectArray,) @@ -88,17 +94,18 @@ def to_numpy(self, ) -> NumpyOrContainerOrScalar: return array - def call_loopy(self, t_unit, **kwargs): + def call_loopy( + self, + t_unit: lp.TranslationUnit, **kwargs: Any + ) -> dict[str, Array]: t_unit = t_unit.copy(target=lp.ExecutableCTarget()) try: - t_unit = self._loopy_transform_cache[t_unit] + executor = self._loopy_transform_cache[t_unit] except KeyError: - orig_t_unit = t_unit - t_unit = self.transform_loopy_program(t_unit) - self._loopy_transform_cache[orig_t_unit] = t_unit - del orig_t_unit + executor = self.transform_loopy_program(t_unit).executor() + self._loopy_transform_cache[t_unit] = executor - _, result = t_unit(**kwargs) + _, result = executor(**kwargs) return result