Skip to content

Commit

Permalink
test NumpyArrayContext
Browse files Browse the repository at this point in the history
  • Loading branch information
kaushikcfd authored and inducer committed Sep 1, 2022
1 parent ce8ab7c commit d22dddc
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 0 deletions.
22 changes: 22 additions & 0 deletions arraycontext/pytest.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from typing import Any, Callable, Dict, Sequence, Type, Union

from arraycontext.context import ArrayContext
from arraycontext import NumpyArrayContext


# {{{ array context factories
Expand Down Expand Up @@ -195,6 +196,26 @@ def __str__(self):
return "<PytatoJAXArrayContext>"


# {{{ _PytestArrayContextFactory

class _NumpyArrayContextForTests(NumpyArrayContext):
def transform_loopy_program(self, t_unit):
return t_unit


class _PytestNumpyArrayContextFactory(PytestArrayContextFactory):
def __init__(self, *args, **kwargs):
super().__init__()

def __call__(self):
return _NumpyArrayContextForTests()

def __str__(self):
return "<NumpyArrayContext>"

# }}}


_ARRAY_CONTEXT_FACTORY_REGISTRY: \
Dict[str, Type[PytestArrayContextFactory]] = {
"pyopencl": _PytestPyOpenCLArrayContextFactoryWithClass,
Expand All @@ -203,6 +224,7 @@ def __str__(self):
"pytato:pyopencl": _PytestPytatoPyOpenCLArrayContextFactory,
"pytato:jax": _PytestPytatoJaxArrayContextFactory,
"eagerjax": _PytestEagerJaxArrayContextFactory,
"numpy": _PytestNumpyArrayContextFactory,
}


Expand Down
3 changes: 3 additions & 0 deletions test/test_arraycontext.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@
_PytestPytatoPyOpenCLArrayContextFactory,
_PytestEagerJaxArrayContextFactory,
_PytestPytatoJaxArrayContextFactory)
_PytestPytatoPyOpenCLArrayContextFactory,
_PytestNumpyArrayContextFactory)


import logging
Expand Down Expand Up @@ -93,6 +95,7 @@ class _PytatoPyOpenCLArrayContextForTestsFactory(
_PytatoPyOpenCLArrayContextForTestsFactory,
_PytestEagerJaxArrayContextFactory,
_PytestPytatoJaxArrayContextFactory,
_PytestNumpyArrayContextFactory,
])


Expand Down

0 comments on commit d22dddc

Please sign in to comment.