diff --git a/meshmode/array_context.py b/meshmode/array_context.py index 76661526..5f714342 100644 --- a/meshmode/array_context.py +++ b/meshmode/array_context.py @@ -28,10 +28,12 @@ from warnings import warn from arraycontext import ( + NumpyArrayContext as NumpyArrayContextBase, PyOpenCLArrayContext as PyOpenCLArrayContextBase, PytatoPyOpenCLArrayContext as PytatoPyOpenCLArrayContextBase, ) from arraycontext.pytest import ( + _PytestNumpyArrayContextFactory, _PytestPyOpenCLArrayContextFactoryWithClass, _PytestPytatoPyOpenCLArrayContextFactory, register_pytest_array_context_factory, @@ -198,6 +200,24 @@ def _transform_with_element_and_dof_inames(t_unit, el_inames, dof_inames): # }}} +# {{{ numpy array context subclass + +class NumpyArrayContext(NumpyArrayContextBase): + def transform_loopy_program(self, t_unit): + default_ep = t_unit.default_entrypoint + options = default_ep.options + if not (options.return_dict and options.no_numpy): + raise ValueError("Loopy kernel passed to call_loopy must " + "have return_dict and no_numpy options set. " + "Did you use arraycontext.make_loopy_program " + "to create this kernel?") + + import loopy as lp + return lp.add_inames_for_unused_hw_axes(t_unit) + +# }}} + + # {{{ pyopencl array context subclass class PyOpenCLArrayContext(PyOpenCLArrayContextBase): @@ -268,6 +288,11 @@ def transform_loopy_program(self, t_unit): # {{{ pytest actx factory +class PytestNumpyArrayContextFactory(_PytestNumpyArrayContextFactory): + def __call__(self): + return NumpyArrayContext() + + class PytestPyOpenCLArrayContextFactory( _PytestPyOpenCLArrayContextFactoryWithClass): actx_class = PyOpenCLArrayContext @@ -281,6 +306,8 @@ def actx_class(self): return PytatoPyOpenCLArrayContext +register_pytest_array_context_factory("meshmode.numpy", + PytestNumpyArrayContextFactory) register_pytest_array_context_factory("meshmode.pyopencl", PytestPyOpenCLArrayContextFactory) register_pytest_array_context_factory("meshmode.pytato_cl", diff --git a/test/test_array.py b/test/test_array.py index ed5e0684..e45e213e 100644 --- a/test/test_array.py +++ b/test/test_array.py @@ -37,6 +37,7 @@ from meshmode import _acf # noqa: F401 from meshmode.array_context import ( + PytestNumpyArrayContextFactory, PytestPyOpenCLArrayContextFactory, PytestPytatoPyOpenCLArrayContextFactory, ) @@ -46,10 +47,11 @@ logger = logging.getLogger(__name__) -pytest_generate_tests = pytest_generate_tests_for_array_contexts( - [PytestPytatoPyOpenCLArrayContextFactory, - PytestPyOpenCLArrayContextFactory, - ]) +pytest_generate_tests = pytest_generate_tests_for_array_contexts([ + PytestNumpyArrayContextFactory, + PytestPytatoPyOpenCLArrayContextFactory, + PytestPyOpenCLArrayContextFactory, + ]) @with_container_arithmetic(bcast_obj_array=False, @@ -163,6 +165,10 @@ class FooAxisTag2(Tag): def test_dof_array_pickling_tags(actx_factory): actx = actx_factory() + from meshmode.array_context import NumpyArrayContext + if isinstance(actx, NumpyArrayContext): + pytest.skip(f"{type(actx).__name__} does not support tags") + from pickle import dumps, loads state = DOFArray(actx, ( diff --git a/test/test_chained.py b/test/test_chained.py index 45b36620..2681bfb2 100644 --- a/test/test_chained.py +++ b/test/test_chained.py @@ -28,13 +28,16 @@ from arraycontext import flatten, pytest_generate_tests_for_array_contexts from meshmode import _acf # noqa: F401 -from meshmode.array_context import PytestPyOpenCLArrayContextFactory +from meshmode.array_context import ( + PytestPyOpenCLArrayContextFactory, +) from meshmode.dof_array import flat_norm logger = logging.getLogger(__name__) -pytest_generate_tests = pytest_generate_tests_for_array_contexts( - [PytestPyOpenCLArrayContextFactory]) +pytest_generate_tests = pytest_generate_tests_for_array_contexts([ + PytestPyOpenCLArrayContextFactory, + ]) def create_discretization(actx, ndim, diff --git a/test/test_connection.py b/test/test_connection.py index 08ee8a4e..b349f839 100644 --- a/test/test_connection.py +++ b/test/test_connection.py @@ -29,7 +29,9 @@ import meshmode.mesh.generation as mgen from meshmode import _acf # noqa: F401 -from meshmode.array_context import PytestPyOpenCLArrayContextFactory +from meshmode.array_context import ( + PytestPyOpenCLArrayContextFactory, +) from meshmode.discretization import Discretization from meshmode.discretization.connection import FACE_RESTR_ALL from meshmode.discretization.poly_element import ( @@ -43,8 +45,9 @@ logger = logging.getLogger(__name__) -pytest_generate_tests = pytest_generate_tests_for_array_contexts( - [PytestPyOpenCLArrayContextFactory]) +pytest_generate_tests = pytest_generate_tests_for_array_contexts([ + PytestPyOpenCLArrayContextFactory, + ]) @pytest.mark.parametrize("group_factory", [ diff --git a/test/test_discretization.py b/test/test_discretization.py index 8c4361c3..b3908ef1 100644 --- a/test/test_discretization.py +++ b/test/test_discretization.py @@ -26,15 +26,18 @@ import meshmode.mesh.generation as mgen from meshmode import _acf # noqa: F401 -from meshmode.array_context import PytestPyOpenCLArrayContextFactory +from meshmode.array_context import ( + PytestPyOpenCLArrayContextFactory, +) from meshmode.discretization import Discretization from meshmode.discretization.poly_element import ( InterpolatoryQuadratureSimplexGroupFactory, ) -pytest_generate_tests = pytest_generate_tests_for_array_contexts( - [PytestPyOpenCLArrayContextFactory]) +pytest_generate_tests = pytest_generate_tests_for_array_contexts([ + PytestPyOpenCLArrayContextFactory, + ]) def test_discr_nodes_caching(actx_factory): diff --git a/test/test_firedrake_interop.py b/test/test_firedrake_interop.py index 6aaaec02..32eb5721 100644 --- a/test/test_firedrake_interop.py +++ b/test/test_firedrake_interop.py @@ -43,8 +43,9 @@ logger = logging.getLogger(__name__) -pytest_generate_tests = pytest_generate_tests_for_array_contexts( - [PytestPyOpenCLArrayContextFactory]) +pytest_generate_tests = pytest_generate_tests_for_array_contexts([ + PytestPyOpenCLArrayContextFactory, + ]) CLOSE_ATOL = 1e-12 diff --git a/test/test_interop.py b/test/test_interop.py index 22956941..df7ff9db 100644 --- a/test/test_interop.py +++ b/test/test_interop.py @@ -28,13 +28,16 @@ from arraycontext import pytest_generate_tests_for_array_contexts from meshmode import _acf # noqa: F401 -from meshmode.array_context import PytestPyOpenCLArrayContextFactory +from meshmode.array_context import ( + PytestPyOpenCLArrayContextFactory, +) from meshmode.dof_array import flat_norm logger = logging.getLogger(__name__) -pytest_generate_tests = pytest_generate_tests_for_array_contexts( - [PytestPyOpenCLArrayContextFactory]) +pytest_generate_tests = pytest_generate_tests_for_array_contexts([ + PytestPyOpenCLArrayContextFactory, + ]) @pytest.mark.parametrize("dim", [1, 2, 3]) diff --git a/test/test_mesh.py b/test/test_mesh.py index c471e095..dbcdbb44 100644 --- a/test/test_mesh.py +++ b/test/test_mesh.py @@ -38,7 +38,9 @@ import meshmode.mesh.io as mio import meshmode.mesh.processing as mproc from meshmode import _acf # noqa: F401 -from meshmode.array_context import PytestPyOpenCLArrayContextFactory +from meshmode.array_context import ( + PytestPyOpenCLArrayContextFactory, +) from meshmode.discretization.poly_element import ( LegendreGaussLobattoTensorProductGroupFactory, default_simplex_group_factory, @@ -54,8 +56,9 @@ logger = logging.getLogger(__name__) -pytest_generate_tests = pytest_generate_tests_for_array_contexts( - [PytestPyOpenCLArrayContextFactory]) +pytest_generate_tests = pytest_generate_tests_for_array_contexts([ + PytestPyOpenCLArrayContextFactory, + ]) thisdir = pathlib.Path(__file__).parent diff --git a/test/test_meshmode.py b/test/test_meshmode.py index f264b646..56c1ac36 100644 --- a/test/test_meshmode.py +++ b/test/test_meshmode.py @@ -33,6 +33,7 @@ import meshmode.mesh.generation as mgen from meshmode import _acf # noqa: F401 from meshmode.array_context import ( + PytestNumpyArrayContextFactory, PytestPyOpenCLArrayContextFactory, PytestPytatoPyOpenCLArrayContextFactory, ) @@ -58,10 +59,11 @@ logger = logging.getLogger(__name__) -pytest_generate_tests = pytest_generate_tests_for_array_contexts( - [PytestPytatoPyOpenCLArrayContextFactory, - PytestPyOpenCLArrayContextFactory, - ]) +pytest_generate_tests = pytest_generate_tests_for_array_contexts([ + PytestNumpyArrayContextFactory, + PytestPytatoPyOpenCLArrayContextFactory, + PytestPyOpenCLArrayContextFactory, + ]) thisdir = pathlib.Path(__file__).parent diff --git a/test/test_modal.py b/test/test_modal.py index c5e39f19..b92a53a9 100644 --- a/test/test_modal.py +++ b/test/test_modal.py @@ -30,7 +30,9 @@ import meshmode.mesh.generation as mgen from meshmode import _acf # noqa: F401 -from meshmode.array_context import PytestPyOpenCLArrayContextFactory +from meshmode.array_context import ( + PytestPyOpenCLArrayContextFactory, +) from meshmode.discretization import Discretization from meshmode.discretization.connection.modal import ( ModalToNodalDiscretizationConnection, @@ -51,8 +53,9 @@ from meshmode.mesh import SimplexElementGroup, TensorProductElementGroup -pytest_generate_tests = pytest_generate_tests_for_array_contexts( - [PytestPyOpenCLArrayContextFactory]) +pytest_generate_tests = pytest_generate_tests_for_array_contexts([ + PytestPyOpenCLArrayContextFactory, + ]) @pytest.mark.parametrize("nodal_group_factory", [ diff --git a/test/test_partition.py b/test/test_partition.py index e179aef3..b5ffdaec 100644 --- a/test/test_partition.py +++ b/test/test_partition.py @@ -33,7 +33,9 @@ from arraycontext import flatten, pytest_generate_tests_for_array_contexts, unflatten from meshmode import _acf # noqa: F401 -from meshmode.array_context import PytestPyOpenCLArrayContextFactory +from meshmode.array_context import ( + PytestPyOpenCLArrayContextFactory, +) from meshmode.discretization.poly_element import default_simplex_group_factory from meshmode.dof_array import flat_norm from meshmode.mesh import ( @@ -45,8 +47,9 @@ logger = logging.getLogger(__name__) -pytest_generate_tests = pytest_generate_tests_for_array_contexts( - [PytestPyOpenCLArrayContextFactory]) +pytest_generate_tests = pytest_generate_tests_for_array_contexts([ + PytestPyOpenCLArrayContextFactory, + ]) # Is there a smart way of choosing this number? # Currently it is the same as the base from MPIBoundaryCommSetupHelper diff --git a/test/test_refinement.py b/test/test_refinement.py index f265d8c3..84625c8d 100644 --- a/test/test_refinement.py +++ b/test/test_refinement.py @@ -31,7 +31,9 @@ import meshmode.mesh.generation as mgen from meshmode import _acf # noqa: F401 -from meshmode.array_context import PytestPyOpenCLArrayContextFactory +from meshmode.array_context import ( + PytestPyOpenCLArrayContextFactory, +) from meshmode.discretization.poly_element import ( GaussLegendreTensorProductGroupFactory, InterpolatoryQuadratureSimplexGroupFactory, @@ -46,8 +48,9 @@ logger = logging.getLogger(__name__) -pytest_generate_tests = pytest_generate_tests_for_array_contexts( - [PytestPyOpenCLArrayContextFactory]) +pytest_generate_tests = pytest_generate_tests_for_array_contexts([ + PytestPyOpenCLArrayContextFactory, + ]) thisdir = pathlib.Path(__file__).parent diff --git a/test/test_visualization.py b/test/test_visualization.py index 0cd0f8fb..a21c8a99 100644 --- a/test/test_visualization.py +++ b/test/test_visualization.py @@ -33,7 +33,9 @@ import meshmode.mesh.generation as mgen from meshmode import _acf # noqa: F401 -from meshmode.array_context import PytestPyOpenCLArrayContextFactory +from meshmode.array_context import ( + PytestPyOpenCLArrayContextFactory, +) from meshmode.discretization.poly_element import ( InterpolatoryQuadratureSimplexGroupFactory, LegendreGaussLobattoTensorProductGroupFactory, @@ -43,8 +45,9 @@ logger = logging.getLogger(__name__) -pytest_generate_tests = pytest_generate_tests_for_array_contexts( - [PytestPyOpenCLArrayContextFactory]) +pytest_generate_tests = pytest_generate_tests_for_array_contexts([ + PytestPyOpenCLArrayContextFactory, + ]) thisdir = pathlib.Path(__file__).parent