Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Try out NumpyArrayContext #427

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions meshmode/array_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,12 @@
from warnings import warn

from arraycontext import (
NumpyArrayContext as NumpyArrayContextBase,
PyOpenCLArrayContext as PyOpenCLArrayContextBase,
PytatoPyOpenCLArrayContext as PytatoPyOpenCLArrayContextBase,
)
from arraycontext.pytest import (
_PytestNumpyArrayContextFactory,
_PytestPyOpenCLArrayContextFactoryWithClass,
_PytestPytatoPyOpenCLArrayContextFactory,
register_pytest_array_context_factory,
Expand Down Expand Up @@ -198,6 +200,24 @@ def _transform_with_element_and_dof_inames(t_unit, el_inames, dof_inames):
# }}}


# {{{ numpy array context subclass

class NumpyArrayContext(NumpyArrayContextBase):
def transform_loopy_program(self, t_unit):
default_ep = t_unit.default_entrypoint
options = default_ep.options
if not (options.return_dict and options.no_numpy):
raise ValueError("Loopy kernel passed to call_loopy must "
"have return_dict and no_numpy options set. "
"Did you use arraycontext.make_loopy_program "
"to create this kernel?")

import loopy as lp
return lp.add_inames_for_unused_hw_axes(t_unit)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was an attempt to silence some loopy errors. Not sure if it's a good idea..


# }}}


# {{{ pyopencl array context subclass

class PyOpenCLArrayContext(PyOpenCLArrayContextBase):
Expand Down Expand Up @@ -268,6 +288,11 @@ def transform_loopy_program(self, t_unit):

# {{{ pytest actx factory

class PytestNumpyArrayContextFactory(_PytestNumpyArrayContextFactory):
def __call__(self):
return NumpyArrayContext()


class PytestPyOpenCLArrayContextFactory(
_PytestPyOpenCLArrayContextFactoryWithClass):
actx_class = PyOpenCLArrayContext
Expand All @@ -281,6 +306,8 @@ def actx_class(self):
return PytatoPyOpenCLArrayContext


register_pytest_array_context_factory("meshmode.numpy",
PytestNumpyArrayContextFactory)
register_pytest_array_context_factory("meshmode.pyopencl",
PytestPyOpenCLArrayContextFactory)
register_pytest_array_context_factory("meshmode.pytato_cl",
Expand Down
14 changes: 10 additions & 4 deletions test/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@

from meshmode import _acf # noqa: F401
from meshmode.array_context import (
PytestNumpyArrayContextFactory,
PytestPyOpenCLArrayContextFactory,
PytestPytatoPyOpenCLArrayContextFactory,
)
Expand All @@ -46,10 +47,11 @@


logger = logging.getLogger(__name__)
pytest_generate_tests = pytest_generate_tests_for_array_contexts(
[PytestPytatoPyOpenCLArrayContextFactory,
PytestPyOpenCLArrayContextFactory,
])
pytest_generate_tests = pytest_generate_tests_for_array_contexts([
PytestNumpyArrayContextFactory,
PytestPytatoPyOpenCLArrayContextFactory,
PytestPyOpenCLArrayContextFactory,
])


@with_container_arithmetic(bcast_obj_array=False,
Expand Down Expand Up @@ -163,6 +165,10 @@ class FooAxisTag2(Tag):
def test_dof_array_pickling_tags(actx_factory):
actx = actx_factory()

from meshmode.array_context import NumpyArrayContext
if isinstance(actx, NumpyArrayContext):
pytest.skip(f"{type(actx).__name__} does not support tags")

from pickle import dumps, loads

state = DOFArray(actx, (
Expand Down
9 changes: 6 additions & 3 deletions test/test_chained.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,16 @@
from arraycontext import flatten, pytest_generate_tests_for_array_contexts

from meshmode import _acf # noqa: F401
from meshmode.array_context import PytestPyOpenCLArrayContextFactory
from meshmode.array_context import (
PytestPyOpenCLArrayContextFactory,
)
from meshmode.dof_array import flat_norm


logger = logging.getLogger(__name__)
pytest_generate_tests = pytest_generate_tests_for_array_contexts(
[PytestPyOpenCLArrayContextFactory])
pytest_generate_tests = pytest_generate_tests_for_array_contexts([
PytestPyOpenCLArrayContextFactory,
])


def create_discretization(actx, ndim,
Expand Down
9 changes: 6 additions & 3 deletions test/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@

import meshmode.mesh.generation as mgen
from meshmode import _acf # noqa: F401
from meshmode.array_context import PytestPyOpenCLArrayContextFactory
from meshmode.array_context import (
PytestPyOpenCLArrayContextFactory,
)
from meshmode.discretization import Discretization
from meshmode.discretization.connection import FACE_RESTR_ALL
from meshmode.discretization.poly_element import (
Expand All @@ -43,8 +45,9 @@


logger = logging.getLogger(__name__)
pytest_generate_tests = pytest_generate_tests_for_array_contexts(
[PytestPyOpenCLArrayContextFactory])
pytest_generate_tests = pytest_generate_tests_for_array_contexts([
PytestPyOpenCLArrayContextFactory,
])


@pytest.mark.parametrize("group_factory", [
Expand Down
9 changes: 6 additions & 3 deletions test/test_discretization.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,18 @@

import meshmode.mesh.generation as mgen
from meshmode import _acf # noqa: F401
from meshmode.array_context import PytestPyOpenCLArrayContextFactory
from meshmode.array_context import (
PytestPyOpenCLArrayContextFactory,
)
from meshmode.discretization import Discretization
from meshmode.discretization.poly_element import (
InterpolatoryQuadratureSimplexGroupFactory,
)


pytest_generate_tests = pytest_generate_tests_for_array_contexts(
[PytestPyOpenCLArrayContextFactory])
pytest_generate_tests = pytest_generate_tests_for_array_contexts([
PytestPyOpenCLArrayContextFactory,
])


def test_discr_nodes_caching(actx_factory):
Expand Down
5 changes: 3 additions & 2 deletions test/test_firedrake_interop.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,9 @@


logger = logging.getLogger(__name__)
pytest_generate_tests = pytest_generate_tests_for_array_contexts(
[PytestPyOpenCLArrayContextFactory])
pytest_generate_tests = pytest_generate_tests_for_array_contexts([
PytestPyOpenCLArrayContextFactory,
])

CLOSE_ATOL = 1e-12

Expand Down
9 changes: 6 additions & 3 deletions test/test_interop.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,16 @@
from arraycontext import pytest_generate_tests_for_array_contexts

from meshmode import _acf # noqa: F401
from meshmode.array_context import PytestPyOpenCLArrayContextFactory
from meshmode.array_context import (
PytestPyOpenCLArrayContextFactory,
)
from meshmode.dof_array import flat_norm


logger = logging.getLogger(__name__)
pytest_generate_tests = pytest_generate_tests_for_array_contexts(
[PytestPyOpenCLArrayContextFactory])
pytest_generate_tests = pytest_generate_tests_for_array_contexts([
PytestPyOpenCLArrayContextFactory,
])


@pytest.mark.parametrize("dim", [1, 2, 3])
Expand Down
9 changes: 6 additions & 3 deletions test/test_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@
import meshmode.mesh.io as mio
import meshmode.mesh.processing as mproc
from meshmode import _acf # noqa: F401
from meshmode.array_context import PytestPyOpenCLArrayContextFactory
from meshmode.array_context import (
PytestPyOpenCLArrayContextFactory,
)
from meshmode.discretization.poly_element import (
LegendreGaussLobattoTensorProductGroupFactory,
default_simplex_group_factory,
Expand All @@ -54,8 +56,9 @@


logger = logging.getLogger(__name__)
pytest_generate_tests = pytest_generate_tests_for_array_contexts(
[PytestPyOpenCLArrayContextFactory])
pytest_generate_tests = pytest_generate_tests_for_array_contexts([
PytestPyOpenCLArrayContextFactory,
])

thisdir = pathlib.Path(__file__).parent

Expand Down
10 changes: 6 additions & 4 deletions test/test_meshmode.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import meshmode.mesh.generation as mgen
from meshmode import _acf # noqa: F401
from meshmode.array_context import (
PytestNumpyArrayContextFactory,
PytestPyOpenCLArrayContextFactory,
PytestPytatoPyOpenCLArrayContextFactory,
)
Expand All @@ -58,10 +59,11 @@


logger = logging.getLogger(__name__)
pytest_generate_tests = pytest_generate_tests_for_array_contexts(
[PytestPytatoPyOpenCLArrayContextFactory,
PytestPyOpenCLArrayContextFactory,
])
pytest_generate_tests = pytest_generate_tests_for_array_contexts([
PytestNumpyArrayContextFactory,
PytestPytatoPyOpenCLArrayContextFactory,
PytestPyOpenCLArrayContextFactory,
])

thisdir = pathlib.Path(__file__).parent

Expand Down
9 changes: 6 additions & 3 deletions test/test_modal.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@

import meshmode.mesh.generation as mgen
from meshmode import _acf # noqa: F401
from meshmode.array_context import PytestPyOpenCLArrayContextFactory
from meshmode.array_context import (
PytestPyOpenCLArrayContextFactory,
)
from meshmode.discretization import Discretization
from meshmode.discretization.connection.modal import (
ModalToNodalDiscretizationConnection,
Expand All @@ -51,8 +53,9 @@
from meshmode.mesh import SimplexElementGroup, TensorProductElementGroup


pytest_generate_tests = pytest_generate_tests_for_array_contexts(
[PytestPyOpenCLArrayContextFactory])
pytest_generate_tests = pytest_generate_tests_for_array_contexts([
PytestPyOpenCLArrayContextFactory,
])


@pytest.mark.parametrize("nodal_group_factory", [
Expand Down
9 changes: 6 additions & 3 deletions test/test_partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@
from arraycontext import flatten, pytest_generate_tests_for_array_contexts, unflatten

from meshmode import _acf # noqa: F401
from meshmode.array_context import PytestPyOpenCLArrayContextFactory
from meshmode.array_context import (
PytestPyOpenCLArrayContextFactory,
)
from meshmode.discretization.poly_element import default_simplex_group_factory
from meshmode.dof_array import flat_norm
from meshmode.mesh import (
Expand All @@ -45,8 +47,9 @@


logger = logging.getLogger(__name__)
pytest_generate_tests = pytest_generate_tests_for_array_contexts(
[PytestPyOpenCLArrayContextFactory])
pytest_generate_tests = pytest_generate_tests_for_array_contexts([
PytestPyOpenCLArrayContextFactory,
])

# Is there a smart way of choosing this number?
# Currently it is the same as the base from MPIBoundaryCommSetupHelper
Expand Down
9 changes: 6 additions & 3 deletions test/test_refinement.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@

import meshmode.mesh.generation as mgen
from meshmode import _acf # noqa: F401
from meshmode.array_context import PytestPyOpenCLArrayContextFactory
from meshmode.array_context import (
PytestPyOpenCLArrayContextFactory,
)
from meshmode.discretization.poly_element import (
GaussLegendreTensorProductGroupFactory,
InterpolatoryQuadratureSimplexGroupFactory,
Expand All @@ -46,8 +48,9 @@


logger = logging.getLogger(__name__)
pytest_generate_tests = pytest_generate_tests_for_array_contexts(
[PytestPyOpenCLArrayContextFactory])
pytest_generate_tests = pytest_generate_tests_for_array_contexts([
PytestPyOpenCLArrayContextFactory,
])

thisdir = pathlib.Path(__file__).parent

Expand Down
9 changes: 6 additions & 3 deletions test/test_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@

import meshmode.mesh.generation as mgen
from meshmode import _acf # noqa: F401
from meshmode.array_context import PytestPyOpenCLArrayContextFactory
from meshmode.array_context import (
PytestPyOpenCLArrayContextFactory,
)
from meshmode.discretization.poly_element import (
InterpolatoryQuadratureSimplexGroupFactory,
LegendreGaussLobattoTensorProductGroupFactory,
Expand All @@ -43,8 +45,9 @@


logger = logging.getLogger(__name__)
pytest_generate_tests = pytest_generate_tests_for_array_contexts(
[PytestPyOpenCLArrayContextFactory])
pytest_generate_tests = pytest_generate_tests_for_array_contexts([
PytestPyOpenCLArrayContextFactory,
])

thisdir = pathlib.Path(__file__).parent

Expand Down
Loading