diff --git a/examples/NDSL/03_orchestration_basics.ipynb b/examples/NDSL/03_orchestration_basics.ipynb index eea2415..01a77dd 100644 --- a/examples/NDSL/03_orchestration_basics.ipynb +++ b/examples/NDSL/03_orchestration_basics.ipynb @@ -37,7 +37,7 @@ ")\n", "from ndsl.constants import X_DIM, Y_DIM, Z_DIM\n", "from ndsl.dsl.typing import FloatField, Float\n", - "from ndsl.boilerplate import get_factories_single_tile_orchestrated_cpu" + "from ndsl.boilerplate import get_factories_single_tile_orchestrated" ] }, { @@ -126,7 +126,7 @@ " tile_size = (3, 3, 3)\n", "\n", " # Setup\n", - " stencil_factory, qty_factory = get_factories_single_tile_orchestrated_cpu(\n", + " stencil_factory, qty_factory = get_factories_single_tile_orchestrated(\n", " nx=tile_size[0],\n", " ny=tile_size[1],\n", " nz=tile_size[2],\n", diff --git a/ndsl/boilerplate.py b/ndsl/boilerplate.py index f22e0b9..dece7ce 100644 --- a/ndsl/boilerplate.py +++ b/ndsl/boilerplate.py @@ -79,8 +79,8 @@ def _get_factories( return stencil_factory, quantity_factory -def get_factories_single_tile_orchestrated_cpu( - nx, ny, nz, nhalo +def get_factories_single_tile_orchestrated( + nx, ny, nz, nhalo, on_cpu: bool = True ) -> Tuple[StencilFactory, QuantityFactory]: """Build a Stencil & Quantity factory for orchestrated CPU, on a single tile topology.""" return _get_factories( @@ -88,22 +88,21 @@ def get_factories_single_tile_orchestrated_cpu( ny=ny, nz=nz, nhalo=nhalo, - backend="dace:cpu", + backend="dace:cpu" if on_cpu else "dace:gpu", orchestration=DaCeOrchestration.BuildAndRun, topology="tile", ) -def get_factories_single_tile_numpy( - nx, ny, nz, nhalo +def get_factories_single_tile( + nx, ny, nz, nhalo, backend: str = "numpy" ) -> Tuple[StencilFactory, QuantityFactory]: - """Build a Stencil & Quantity factory for Numpy, on a single tile topology.""" return _get_factories( nx=nx, ny=ny, nz=nz, nhalo=nhalo, - backend="numpy", + backend=backend, orchestration=DaCeOrchestration.Python, topology="tile", ) diff --git a/tests/test_boilerplate.py b/tests/test_boilerplate.py index a8de407..c0211fb 100644 --- a/tests/test_boilerplate.py +++ b/tests/test_boilerplate.py @@ -35,10 +35,10 @@ def test_boilerplate_import_numpy(): Dev Note: the import inside the function are part of the test. """ - from ndsl.boilerplate import get_factories_single_tile_numpy + from ndsl.boilerplate import get_factories_single_tile # Boilerplate - stencil_factory, quantity_factory = get_factories_single_tile_numpy( + stencil_factory, quantity_factory = get_factories_single_tile( nx=5, ny=5, nz=2, nhalo=1 ) @@ -50,10 +50,10 @@ def test_boilerplate_import_orchestrated_cpu(): Dev Note: the import inside the function are part of the test. """ - from ndsl.boilerplate import get_factories_single_tile_orchestrated_cpu + from ndsl.boilerplate import get_factories_single_tile_orchestrated # Boilerplate - stencil_factory, quantity_factory = get_factories_single_tile_orchestrated_cpu( + stencil_factory, quantity_factory = get_factories_single_tile_orchestrated( nx=5, ny=5, nz=2, nhalo=1 )