Skip to content

Commit

Permalink
Make boilerplate calls more nimble
Browse files Browse the repository at this point in the history
  • Loading branch information
FlorianDeconinck committed Nov 12, 2024
1 parent aed5912 commit 3347431
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 13 deletions.
4 changes: 2 additions & 2 deletions examples/NDSL/03_orchestration_basics.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
")\n",
"from ndsl.constants import X_DIM, Y_DIM, Z_DIM\n",
"from ndsl.dsl.typing import FloatField, Float\n",
"from ndsl.boilerplate import get_factories_single_tile_orchestrated_cpu"
"from ndsl.boilerplate import get_factories_single_tile_orchestrated"
]
},
{
Expand Down Expand Up @@ -126,7 +126,7 @@
" tile_size = (3, 3, 3)\n",
"\n",
" # Setup\n",
" stencil_factory, qty_factory = get_factories_single_tile_orchestrated_cpu(\n",
" stencil_factory, qty_factory = get_factories_single_tile_orchestrated(\n",
" nx=tile_size[0],\n",
" ny=tile_size[1],\n",
" nz=tile_size[2],\n",
Expand Down
13 changes: 6 additions & 7 deletions ndsl/boilerplate.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,31 +79,30 @@ def _get_factories(
return stencil_factory, quantity_factory


def get_factories_single_tile_orchestrated_cpu(
nx, ny, nz, nhalo
def get_factories_single_tile_orchestrated(
nx, ny, nz, nhalo, on_cpu: bool = True
) -> Tuple[StencilFactory, QuantityFactory]:
"""Build a Stencil & Quantity factory for orchestrated CPU, on a single tile topology."""
return _get_factories(
nx=nx,
ny=ny,
nz=nz,
nhalo=nhalo,
backend="dace:cpu",
backend="dace:cpu" if on_cpu else "dace:gpu",
orchestration=DaCeOrchestration.BuildAndRun,
topology="tile",
)


def get_factories_single_tile_numpy(
nx, ny, nz, nhalo
def get_factories_single_tile(
nx, ny, nz, nhalo, backend: str = "numpy"
) -> Tuple[StencilFactory, QuantityFactory]:
"""Build a Stencil & Quantity factory for Numpy, on a single tile topology."""
return _get_factories(
nx=nx,
ny=ny,
nz=nz,
nhalo=nhalo,
backend="numpy",
backend=backend,
orchestration=DaCeOrchestration.Python,
topology="tile",
)
8 changes: 4 additions & 4 deletions tests/test_boilerplate.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@ def test_boilerplate_import_numpy():
Dev Note: the import inside the function are part of the test.
"""
from ndsl.boilerplate import get_factories_single_tile_numpy
from ndsl.boilerplate import get_factories_single_tile

# Boilerplate
stencil_factory, quantity_factory = get_factories_single_tile_numpy(
stencil_factory, quantity_factory = get_factories_single_tile(
nx=5, ny=5, nz=2, nhalo=1
)

Expand All @@ -50,10 +50,10 @@ def test_boilerplate_import_orchestrated_cpu():
Dev Note: the import inside the function are part of the test.
"""
from ndsl.boilerplate import get_factories_single_tile_orchestrated_cpu
from ndsl.boilerplate import get_factories_single_tile_orchestrated

# Boilerplate
stencil_factory, quantity_factory = get_factories_single_tile_orchestrated_cpu(
stencil_factory, quantity_factory = get_factories_single_tile_orchestrated(
nx=5, ny=5, nz=2, nhalo=1
)

Expand Down

0 comments on commit 3347431

Please sign in to comment.