Skip to content

Commit

Permalink
add NumpyArrayContext subclass and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
alexfikl committed Sep 5, 2024
1 parent a449133 commit abd4b30
Show file tree
Hide file tree
Showing 13 changed files with 100 additions and 37 deletions.
27 changes: 27 additions & 0 deletions meshmode/array_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,12 @@
from warnings import warn

from arraycontext import (
NumpyArrayContext as NumpyArrayContextBase,
PyOpenCLArrayContext as PyOpenCLArrayContextBase,
PytatoPyOpenCLArrayContext as PytatoPyOpenCLArrayContextBase,
)
from arraycontext.pytest import (
_PytestNumpyArrayContextFactory,
_PytestPyOpenCLArrayContextFactoryWithClass,
_PytestPytatoPyOpenCLArrayContextFactory,
register_pytest_array_context_factory,
Expand Down Expand Up @@ -198,6 +200,24 @@ def _transform_with_element_and_dof_inames(t_unit, el_inames, dof_inames):
# }}}


# {{{ numpy array context subclass

class NumpyArrayContext(NumpyArrayContextBase):
def transform_loopy_program(self, t_unit):
default_ep = t_unit.default_entrypoint
options = default_ep.options
if not (options.return_dict and options.no_numpy):
raise ValueError("Loopy kernel passed to call_loopy must "
"have return_dict and no_numpy options set. "
"Did you use arraycontext.make_loopy_program "
"to create this kernel?")

import loopy as lp
return lp.add_inames_for_unused_hw_axes(t_unit)

# }}}


# {{{ pyopencl array context subclass

class PyOpenCLArrayContext(PyOpenCLArrayContextBase):
Expand Down Expand Up @@ -268,6 +288,11 @@ def transform_loopy_program(self, t_unit):

# {{{ pytest actx factory

class PytestNumpyArrayContextFactory(_PytestNumpyArrayContextFactory):
def __call__(self):
return NumpyArrayContext()


class PytestPyOpenCLArrayContextFactory(
_PytestPyOpenCLArrayContextFactoryWithClass):
actx_class = PyOpenCLArrayContext
Expand All @@ -281,6 +306,8 @@ def actx_class(self):
return PytatoPyOpenCLArrayContext


register_pytest_array_context_factory("meshmode.numpy",
PytestNumpyArrayContextFactory)
register_pytest_array_context_factory("meshmode.pyopencl",
PytestPyOpenCLArrayContextFactory)
register_pytest_array_context_factory("meshmode.pytato_cl",
Expand Down
14 changes: 10 additions & 4 deletions test/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@

from meshmode import _acf # noqa: F401
from meshmode.array_context import (
PytestNumpyArrayContextFactory,
PytestPyOpenCLArrayContextFactory,
PytestPytatoPyOpenCLArrayContextFactory,
)
Expand All @@ -46,10 +47,11 @@


logger = logging.getLogger(__name__)
pytest_generate_tests = pytest_generate_tests_for_array_contexts(
[PytestPytatoPyOpenCLArrayContextFactory,
PytestPyOpenCLArrayContextFactory,
])
pytest_generate_tests = pytest_generate_tests_for_array_contexts([
PytestNumpyArrayContextFactory,
PytestPytatoPyOpenCLArrayContextFactory,
PytestPyOpenCLArrayContextFactory,
])


@with_container_arithmetic(bcast_obj_array=False,
Expand Down Expand Up @@ -163,6 +165,10 @@ class FooAxisTag2(Tag):
def test_dof_array_pickling_tags(actx_factory):
actx = actx_factory()

from meshmode.array_context import NumpyArrayContext
if isinstance(actx, NumpyArrayContext):
pytest.skip(f"{type(actx).__name__} does not support tags")

from pickle import dumps, loads

state = DOFArray(actx, (
Expand Down
9 changes: 6 additions & 3 deletions test/test_chained.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,16 @@
from arraycontext import flatten, pytest_generate_tests_for_array_contexts

from meshmode import _acf # noqa: F401
from meshmode.array_context import PytestPyOpenCLArrayContextFactory
from meshmode.array_context import (
PytestPyOpenCLArrayContextFactory,
)
from meshmode.dof_array import flat_norm


logger = logging.getLogger(__name__)
pytest_generate_tests = pytest_generate_tests_for_array_contexts(
[PytestPyOpenCLArrayContextFactory])
pytest_generate_tests = pytest_generate_tests_for_array_contexts([
PytestPyOpenCLArrayContextFactory,
])


def create_discretization(actx, ndim,
Expand Down
9 changes: 6 additions & 3 deletions test/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@

import meshmode.mesh.generation as mgen
from meshmode import _acf # noqa: F401
from meshmode.array_context import PytestPyOpenCLArrayContextFactory
from meshmode.array_context import (
PytestPyOpenCLArrayContextFactory,
)
from meshmode.discretization import Discretization
from meshmode.discretization.connection import FACE_RESTR_ALL
from meshmode.discretization.poly_element import (
Expand All @@ -43,8 +45,9 @@


logger = logging.getLogger(__name__)
pytest_generate_tests = pytest_generate_tests_for_array_contexts(
[PytestPyOpenCLArrayContextFactory])
pytest_generate_tests = pytest_generate_tests_for_array_contexts([
PytestPyOpenCLArrayContextFactory,
])


@pytest.mark.parametrize("group_factory", [
Expand Down
9 changes: 6 additions & 3 deletions test/test_discretization.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,18 @@

import meshmode.mesh.generation as mgen
from meshmode import _acf # noqa: F401
from meshmode.array_context import PytestPyOpenCLArrayContextFactory
from meshmode.array_context import (
PytestPyOpenCLArrayContextFactory,
)
from meshmode.discretization import Discretization
from meshmode.discretization.poly_element import (
InterpolatoryQuadratureSimplexGroupFactory,
)


pytest_generate_tests = pytest_generate_tests_for_array_contexts(
[PytestPyOpenCLArrayContextFactory])
pytest_generate_tests = pytest_generate_tests_for_array_contexts([
PytestPyOpenCLArrayContextFactory,
])


def test_discr_nodes_caching(actx_factory):
Expand Down
5 changes: 3 additions & 2 deletions test/test_firedrake_interop.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,9 @@


logger = logging.getLogger(__name__)
pytest_generate_tests = pytest_generate_tests_for_array_contexts(
[PytestPyOpenCLArrayContextFactory])
pytest_generate_tests = pytest_generate_tests_for_array_contexts([
PytestPyOpenCLArrayContextFactory,
])

CLOSE_ATOL = 1e-12

Expand Down
9 changes: 6 additions & 3 deletions test/test_interop.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,16 @@
from arraycontext import pytest_generate_tests_for_array_contexts

from meshmode import _acf # noqa: F401
from meshmode.array_context import PytestPyOpenCLArrayContextFactory
from meshmode.array_context import (
PytestPyOpenCLArrayContextFactory,
)
from meshmode.dof_array import flat_norm


logger = logging.getLogger(__name__)
pytest_generate_tests = pytest_generate_tests_for_array_contexts(
[PytestPyOpenCLArrayContextFactory])
pytest_generate_tests = pytest_generate_tests_for_array_contexts([
PytestPyOpenCLArrayContextFactory,
])


@pytest.mark.parametrize("dim", [1, 2, 3])
Expand Down
9 changes: 6 additions & 3 deletions test/test_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@
import meshmode.mesh.io as mio
import meshmode.mesh.processing as mproc
from meshmode import _acf # noqa: F401
from meshmode.array_context import PytestPyOpenCLArrayContextFactory
from meshmode.array_context import (
PytestPyOpenCLArrayContextFactory,
)
from meshmode.discretization.poly_element import (
LegendreGaussLobattoTensorProductGroupFactory,
default_simplex_group_factory,
Expand All @@ -54,8 +56,9 @@


logger = logging.getLogger(__name__)
pytest_generate_tests = pytest_generate_tests_for_array_contexts(
[PytestPyOpenCLArrayContextFactory])
pytest_generate_tests = pytest_generate_tests_for_array_contexts([
PytestPyOpenCLArrayContextFactory,
])

thisdir = pathlib.Path(__file__).parent

Expand Down
10 changes: 6 additions & 4 deletions test/test_meshmode.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import meshmode.mesh.generation as mgen
from meshmode import _acf # noqa: F401
from meshmode.array_context import (
PytestNumpyArrayContextFactory,
PytestPyOpenCLArrayContextFactory,
PytestPytatoPyOpenCLArrayContextFactory,
)
Expand All @@ -58,10 +59,11 @@


logger = logging.getLogger(__name__)
pytest_generate_tests = pytest_generate_tests_for_array_contexts(
[PytestPytatoPyOpenCLArrayContextFactory,
PytestPyOpenCLArrayContextFactory,
])
pytest_generate_tests = pytest_generate_tests_for_array_contexts([
PytestNumpyArrayContextFactory,
PytestPytatoPyOpenCLArrayContextFactory,
PytestPyOpenCLArrayContextFactory,
])

thisdir = pathlib.Path(__file__).parent

Expand Down
9 changes: 6 additions & 3 deletions test/test_modal.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@

import meshmode.mesh.generation as mgen
from meshmode import _acf # noqa: F401
from meshmode.array_context import PytestPyOpenCLArrayContextFactory
from meshmode.array_context import (
PytestPyOpenCLArrayContextFactory,
)
from meshmode.discretization import Discretization
from meshmode.discretization.connection.modal import (
ModalToNodalDiscretizationConnection,
Expand All @@ -51,8 +53,9 @@
from meshmode.mesh import SimplexElementGroup, TensorProductElementGroup


pytest_generate_tests = pytest_generate_tests_for_array_contexts(
[PytestPyOpenCLArrayContextFactory])
pytest_generate_tests = pytest_generate_tests_for_array_contexts([
PytestPyOpenCLArrayContextFactory,
])


@pytest.mark.parametrize("nodal_group_factory", [
Expand Down
9 changes: 6 additions & 3 deletions test/test_partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@
from arraycontext import flatten, pytest_generate_tests_for_array_contexts, unflatten

from meshmode import _acf # noqa: F401
from meshmode.array_context import PytestPyOpenCLArrayContextFactory
from meshmode.array_context import (
PytestPyOpenCLArrayContextFactory,
)
from meshmode.discretization.poly_element import default_simplex_group_factory
from meshmode.dof_array import flat_norm
from meshmode.mesh import (
Expand All @@ -45,8 +47,9 @@


logger = logging.getLogger(__name__)
pytest_generate_tests = pytest_generate_tests_for_array_contexts(
[PytestPyOpenCLArrayContextFactory])
pytest_generate_tests = pytest_generate_tests_for_array_contexts([
PytestPyOpenCLArrayContextFactory,
])

# Is there a smart way of choosing this number?
# Currently it is the same as the base from MPIBoundaryCommSetupHelper
Expand Down
9 changes: 6 additions & 3 deletions test/test_refinement.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@

import meshmode.mesh.generation as mgen
from meshmode import _acf # noqa: F401
from meshmode.array_context import PytestPyOpenCLArrayContextFactory
from meshmode.array_context import (
PytestPyOpenCLArrayContextFactory,
)
from meshmode.discretization.poly_element import (
GaussLegendreTensorProductGroupFactory,
InterpolatoryQuadratureSimplexGroupFactory,
Expand All @@ -46,8 +48,9 @@


logger = logging.getLogger(__name__)
pytest_generate_tests = pytest_generate_tests_for_array_contexts(
[PytestPyOpenCLArrayContextFactory])
pytest_generate_tests = pytest_generate_tests_for_array_contexts([
PytestPyOpenCLArrayContextFactory,
])

thisdir = pathlib.Path(__file__).parent

Expand Down
9 changes: 6 additions & 3 deletions test/test_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@

import meshmode.mesh.generation as mgen
from meshmode import _acf # noqa: F401
from meshmode.array_context import PytestPyOpenCLArrayContextFactory
from meshmode.array_context import (
PytestPyOpenCLArrayContextFactory,
)
from meshmode.discretization.poly_element import (
InterpolatoryQuadratureSimplexGroupFactory,
LegendreGaussLobattoTensorProductGroupFactory,
Expand All @@ -43,8 +45,9 @@


logger = logging.getLogger(__name__)
pytest_generate_tests = pytest_generate_tests_for_array_contexts(
[PytestPyOpenCLArrayContextFactory])
pytest_generate_tests = pytest_generate_tests_for_array_contexts([
PytestPyOpenCLArrayContextFactory,
])

thisdir = pathlib.Path(__file__).parent

Expand Down

0 comments on commit abd4b30

Please sign in to comment.