diff --git a/arraycontext/pytest.py b/arraycontext/pytest.py index 1eceb497..26535185 100644 --- a/arraycontext/pytest.py +++ b/arraycontext/pytest.py @@ -35,6 +35,7 @@ from typing import Any, Callable, Dict, Sequence, Type, Union from arraycontext.context import ArrayContext +from arraycontext import NumpyArrayContext # {{{ array context factories @@ -195,6 +196,26 @@ def __str__(self): return "" +# {{{ _PytestArrayContextFactory + +class _NumpyArrayContextForTests(NumpyArrayContext): + def transform_loopy_program(self, t_unit): + return t_unit + + +class _PytestNumpyArrayContextFactory(PytestArrayContextFactory): + def __init__(self, *args, **kwargs): + super().__init__() + + def __call__(self): + return _NumpyArrayContextForTests() + + def __str__(self): + return "" + +# }}} + + _ARRAY_CONTEXT_FACTORY_REGISTRY: \ Dict[str, Type[PytestArrayContextFactory]] = { "pyopencl": _PytestPyOpenCLArrayContextFactoryWithClass, @@ -203,6 +224,7 @@ def __str__(self): "pytato:pyopencl": _PytestPytatoPyOpenCLArrayContextFactory, "pytato:jax": _PytestPytatoJaxArrayContextFactory, "eagerjax": _PytestEagerJaxArrayContextFactory, + "numpy": _PytestNumpyArrayContextFactory, } diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 842d108e..0975d5ce 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -45,6 +45,8 @@ _PytestPytatoPyOpenCLArrayContextFactory, _PytestEagerJaxArrayContextFactory, _PytestPytatoJaxArrayContextFactory) + _PytestPytatoPyOpenCLArrayContextFactory, + _PytestNumpyArrayContextFactory) import logging @@ -93,6 +95,7 @@ class _PytatoPyOpenCLArrayContextForTestsFactory( _PytatoPyOpenCLArrayContextForTestsFactory, _PytestEagerJaxArrayContextFactory, _PytestPytatoJaxArrayContextFactory, + _PytestNumpyArrayContextFactory, ])