Skip to content

Commit

Permalink
Updated import method to mimic other package styles
Browse files Browse the repository at this point in the history
  • Loading branch information
fmalatino committed Feb 28, 2024
1 parent fba0d95 commit 8400c83
Show file tree
Hide file tree
Showing 23 changed files with 113 additions and 127 deletions.
121 changes: 92 additions & 29 deletions ndsl/__init__.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,110 @@
from .checkpointer import Checkpointer, NullCheckpointer, SnapshotCheckpointer
from .comm import (
from .buffer import Buffer
from .checkpointer.base import Checkpointer
from .checkpointer.null import NullCheckpointer
from .checkpointer.snapshots import SnapshotCheckpointer, _Snapshots
from .checkpointer.thresholds import (
InsufficientTrialsError,
SavepointThresholds,
Threshold,
ThresholdCalibrationCheckpointer,
)
from .checkpointer.validation import ValidationCheckpointer
from .comm.boundary import Boundary, SimpleBoundary
from .comm.caching_comm import (
CachingCommData,
CachingCommReader,
CachingCommWriter,
Comm,
Communicator,
ConcurrencyError,
CubedSphereCommunicator,
CubedSpherePartitioner,
LocalComm,
MPIComm,
NullComm,
TileCommunicator,
TilePartitioner,
)
from .dsl import (
CachingRequestReader,
CachingRequestWriter,
NullRequest,
)
from .comm.comm_abc import Comm, Request
from .comm.communicator import Communicator, CubedSphereCommunicator, TileCommunicator
from .comm.local_comm import AsyncResult, ConcurrencyError, LocalComm
from .comm.mpi import MPIComm
from .comm.null_comm import NullAsyncResult, NullComm
from .comm.partitioner import CubedSpherePartitioner, Partitioner, TilePartitioner
from .constants import ConstantVersions
from .dsl.caches.codepath import FV3CodePath
from .dsl.dace.dace_config import DaceConfig, DaCeOrchestration, FrozenCompiledSDFG
from .dsl.dace.orchestration import orchestrate, orchestrate_function
from .dsl.dace.utils import (
ArrayReport,
DaCeProgress,
MaxBandwithBenchmarkProgram,
StorageReport,
)
from .dsl.dace.wrapped_halo_exchange import WrappedHaloUpdater
from .dsl.stencil import (
CompareToNumpyStencil,
CompilationConfig,
DaceConfig,
DaCeOrchestration,
FrozenStencil,
GridIndexing,
RunMode,
StencilConfig,
StencilFactory,
WrappedHaloUpdater,
TimingCollector,
)
from .dsl.stencil_config import CompilationConfig, RunMode, StencilConfig
from .exceptions import OutOfBoundsError
from .halo import HaloDataTransformer, HaloExchangeSpec, HaloUpdater
from .initialization import GridSizer, QuantityFactory, SubtileGridSizer
from .grid.eta import HybridPressureCoefficients
from .grid.generation import GridDefinition, GridDefinitions, MetricTerms
from .grid.helper import (
AngleGridData,
ContravariantGridData,
DampingCoefficients,
DriverGridData,
GridData,
HorizontalGridData,
VerticalGridData,
)
from .halo.data_transformer import (
HaloDataTransformer,
HaloDataTransformerCPU,
HaloDataTransformerGPU,
HaloExchangeSpec,
)
from .halo.updater import HaloUpdater, HaloUpdateRequest, VectorInterfaceHaloUpdater
from .initialization.allocator import QuantityFactory, StorageNumpy
from .initialization.sizer import GridSizer, SubtileGridSizer
from .logging import ndsl_log
from .monitor import NetCDFMonitor, ZarrMonitor
from .performance import NullTimer, PerformanceCollector, Timer
from .quantity import Quantity, QuantityHaloSpec
from .stencils import (
CubedToLatLon,
Grid,
from .monitor.netcdf_monitor import NetCDFMonitor
from .monitor.protocol import Protocol
from .monitor.zarr_monitor import ZarrMonitor
from .namelist import Namelist
from .optional_imports import RaiseWhenAccessed
from .performance.collector import (
AbstractPerformanceCollector,
NullPerformanceCollector,
PerformanceCollector,
)
from .performance.config import PerformanceConfig
from .performance.profiler import NullProfiler, Profiler
from .performance.report import Experiment, Report, TimeReport
from .performance.timer import NullTimer, Timer
from .quantity import (
BoundaryArrayView,
BoundedArrayView,
Quantity,
QuantityHaloSpec,
QuantityMetadata,
)
from .stencils.c2l_ord import CubedToLatLon
from .stencils.corners import CopyCorners, CopyCornersXY, FillCornersBGrid
from .stencils.testing.grid import Grid # type: ignore
from .stencils.testing.parallel_translate import (
ParallelTranslate,
ParallelTranslate2Py,
ParallelTranslate2PyState,
ParallelTranslateBaseSlicing,
ParallelTranslateGrid,
)
from .stencils.testing.savepoint import SavepointCase, Translate, dataset_to_dict
from .stencils.testing.temporaries import assert_same_temporaries, copy_temporaries
from .stencils.testing.translate import (
TranslateFortranData2Py,
TranslateGrid,
pad_field_in_j,
read_serialized_data,
)
from .testing import DummyComm
from .testing.dummy_comm import DummyComm
from .types import Allocator, AsyncRequest, NumpyModule
from .units import UnitsError
from .utils import MetaEnumStr
3 changes: 0 additions & 3 deletions ndsl/checkpointer/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +0,0 @@
from .base import Checkpointer
from .null import NullCheckpointer
from .snapshots import SnapshotCheckpointer
7 changes: 0 additions & 7 deletions ndsl/comm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +0,0 @@
from .caching_comm import CachingCommReader, CachingCommWriter
from .comm_abc import Comm
from .communicator import Communicator, CubedSphereCommunicator, TileCommunicator
from .local_comm import ConcurrencyError, LocalComm
from .mpi import MPIComm
from .null_comm import NullComm
from .partitioner import CubedSpherePartitioner, TilePartitioner
11 changes: 0 additions & 11 deletions ndsl/dsl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,6 @@

from ndsl.comm.mpi import MPI

from . import dace
from .dace import (
DaceConfig,
DaCeOrchestration,
WrappedHaloUpdater,
orchestrate,
orchestrate_function,
)
from .stencil import CompareToNumpyStencil, FrozenStencil, GridIndexing, StencilFactory
from .stencil_config import CompilationConfig, RunMode, StencilConfig


if MPI is not None:
import os
Expand Down
3 changes: 0 additions & 3 deletions ndsl/dsl/dace/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +0,0 @@
from .dace_config import DaceConfig, DaCeOrchestration
from .orchestration import orchestrate, orchestrate_function
from .wrapped_halo_exchange import WrappedHaloUpdater
6 changes: 4 additions & 2 deletions ndsl/dsl/stencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from gt4py.cartesian import gtscript
from gt4py.cartesian.gtc.passes.oir_pipeline import DefaultPipeline, OirPipeline

from ndsl import testing
from ndsl.comm.comm_abc import Comm
from ndsl.comm.communicator import Communicator
from ndsl.comm.decomposition import block_waiting_for_compilation, unblock_waiting_tiles
Expand All @@ -34,6 +33,9 @@
from ndsl.initialization.sizer import GridSizer, SubtileGridSizer
from ndsl.quantity import Quantity

# from ndsl import testing
from ndsl.testing import comparison


try:
import cupy as cp
Expand Down Expand Up @@ -68,7 +70,7 @@ def report_difference(args, kwargs, args_copy, kwargs_copy, function_name, gt_id


def report_diff(arg: np.ndarray, numpy_arg: np.ndarray, label) -> str:
metric_err = testing.compare_arr(arg, numpy_arg)
metric_err = comparison.compare_arr(arg, numpy_arg)
nans_match = np.logical_and(np.isnan(arg), np.isnan(numpy_arg))
n_points = np.product(arg.shape)
failures_14 = n_points - np.sum(
Expand Down
13 changes: 0 additions & 13 deletions ndsl/grid/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +0,0 @@
# flake8: noqa: F401

from .eta import set_hybrid_pressure_coefficients
from .gnomonic import (
great_circle_distance_along_axis,
great_circle_distance_lon_lat,
lon_lat_corner_to_cell_center,
lon_lat_midpoint,
lon_lat_to_xyz,
xyz_midpoint,
xyz_to_lon_lat,
)
from .stretch_transformation import direct_transform
2 changes: 1 addition & 1 deletion ndsl/grid/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from ndsl.constants import Z_DIM, Z_INTERFACE_DIM
from ndsl.filesystem import get_fs
from ndsl.grid.generation import MetricTerms
from ndsl.initialization import QuantityFactory
from ndsl.initialization.allocator import QuantityFactory
from ndsl.quantity import Quantity


Expand Down
2 changes: 0 additions & 2 deletions ndsl/halo/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +0,0 @@
from .data_transformer import HaloDataTransformer, HaloExchangeSpec
from .updater import HaloUpdater
2 changes: 0 additions & 2 deletions ndsl/initialization/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +0,0 @@
from .allocator import QuantityFactory
from .sizer import GridSizer, SubtileGridSizer
2 changes: 0 additions & 2 deletions ndsl/monitor/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +0,0 @@
from .netcdf_monitor import NetCDFMonitor
from .zarr_monitor import ZarrMonitor
2 changes: 0 additions & 2 deletions ndsl/performance/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +0,0 @@
from .collector import PerformanceCollector
from .timer import NullTimer, Timer
13 changes: 0 additions & 13 deletions ndsl/stencils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1 @@
from .c2l_ord import CubedToLatLon
from .testing import (
Grid,
ParallelTranslate,
ParallelTranslate2Py,
ParallelTranslate2PyState,
ParallelTranslateBaseSlicing,
ParallelTranslateGrid,
TranslateFortranData2Py,
TranslateGrid,
)


__version__ = "0.2.0"
17 changes: 0 additions & 17 deletions ndsl/stencils/testing/__init__.py
Original file line number Diff line number Diff line change
@@ -1,17 +0,0 @@
from . import parallel_translate, translate
from .grid import Grid # type: ignore
from .parallel_translate import (
ParallelTranslate,
ParallelTranslate2Py,
ParallelTranslate2PyState,
ParallelTranslateBaseSlicing,
ParallelTranslateGrid,
)
from .savepoint import dataset_to_dict
from .temporaries import assert_same_temporaries, copy_temporaries
from .translate import (
TranslateFortranData2Py,
TranslateGrid,
pad_field_in_j,
read_serialized_data,
)
5 changes: 3 additions & 2 deletions ndsl/stencils/testing/test_translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@
from ndsl.dsl.stencil import CompilationConfig, StencilConfig
from ndsl.quantity import Quantity
from ndsl.restart._legacy_restart import RESTART_PROPERTIES
from ndsl.stencils.testing import SavepointCase, dataset_to_dict
from ndsl.testing import compare_scalar, perturb, success, success_array
from ndsl.stencils.testing.savepoint import SavepointCase, dataset_to_dict
from ndsl.testing.comparison import compare_scalar, success, success_array
from ndsl.testing.perturbation import perturb


# this only matters for manually-added print statements
Expand Down
3 changes: 0 additions & 3 deletions ndsl/testing/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +0,0 @@
from .comparison import compare_arr, compare_scalar, success, success_array
from .dummy_comm import ConcurrencyError, DummyComm
from .perturbation import perturb
1 change: 0 additions & 1 deletion ndsl/testing/dummy_comm.py
Original file line number Diff line number Diff line change
@@ -1,2 +1 @@
from ndsl.comm.local_comm import ConcurrencyError # noqa
from ndsl.comm.local_comm import LocalComm as DummyComm # noqa
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def local_pkg(name: str, relative_path: str) -> str:
"mpi4py",
"cftime",
"xarray",
"f90nml>=1.1.0",
"fsspec",
"netcdf4",
"scipy", # restart capacities only
Expand Down
6 changes: 1 addition & 5 deletions tests/checkpointer/test_thresholds.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
import numpy as np
import pytest

from ndsl.checkpointer.thresholds import (
InsufficientTrialsError,
Threshold,
ThresholdCalibrationCheckpointer,
)
from ndsl import InsufficientTrialsError, Threshold, ThresholdCalibrationCheckpointer


def test_thresholds_no_trials():
Expand Down
7 changes: 2 additions & 5 deletions tests/checkpointer/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,8 @@
import numpy as np
import pytest

from ndsl.checkpointer.thresholds import SavepointThresholds, Threshold
from ndsl.checkpointer.validation import (
ValidationCheckpointer,
_clip_pace_array_to_target,
)
from ndsl import SavepointThresholds, Threshold, ValidationCheckpointer
from ndsl.checkpointer.validation import _clip_pace_array_to_target
from ndsl.optional_imports import xarray as xr


Expand Down
2 changes: 1 addition & 1 deletion tests/dsl/test_caches.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
StencilFactory,
)
from ndsl.comm.mpi import MPI
from ndsl.dsl.dace import orchestrate
from ndsl.dsl.dace.orchestration import orchestrate


def _make_storage(
Expand Down
2 changes: 1 addition & 1 deletion tests/mpi/test_mpi_mock.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import numpy as np
import pytest

from ndsl import ConcurrencyError, DummyComm
from ndsl.comm.communicator import recv_buffer
from ndsl.testing import ConcurrencyError, DummyComm
from tests.mpi.mpi_comm import MPI


Expand Down
9 changes: 7 additions & 2 deletions tests/test_halo_data_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,13 @@
import numpy as np
import pytest

from ndsl import HaloDataTransformer, HaloExchangeSpec, Quantity, QuantityHaloSpec
from ndsl.buffer import Buffer
from ndsl import (
Buffer,
HaloDataTransformer,
HaloExchangeSpec,
Quantity,
QuantityHaloSpec,
)
from ndsl.comm import _boundary_utils
from ndsl.constants import (
EAST,
Expand Down

0 comments on commit 8400c83

Please sign in to comment.