diff --git a/ndsl/dsl/dace/wrapped_halo_exchange.py b/ndsl/dsl/dace/wrapped_halo_exchange.py index add7488d..78a68fa4 100644 --- a/ndsl/dsl/dace/wrapped_halo_exchange.py +++ b/ndsl/dsl/dace/wrapped_halo_exchange.py @@ -1,5 +1,5 @@ import dataclasses -from typing import List, Optional +from typing import Any, List, Optional from ndsl.comm.communicator import Communicator from ndsl.dsl.dace.orchestration import dace_inhibitor @@ -29,6 +29,14 @@ def __init__( self._qtx_y_names = qty_y_names self._comm = comm + @staticmethod + def check_for_attribute(state: Any, attr: str): + if dataclasses.is_dataclass(state): + return state.__getattribute__(attr) + elif isinstance(state, dict): + return attr in state.keys() + return False + @dace_inhibitor def start(self): if self._qtx_y_names is None: diff --git a/ndsl/grid/eta.py b/ndsl/grid/eta.py index f804b3f5..19663fde 100644 --- a/ndsl/grid/eta.py +++ b/ndsl/grid/eta.py @@ -1,3 +1,4 @@ +import math import os from dataclasses import dataclass @@ -5,6 +6,10 @@ import xarray as xr +ETA_0 = 0.252 +SURFACE_PRESSURE = 1.0e5 # units of (Pa), from Table VI of DCMIP2016 + + @dataclass class HybridPressureCoefficients: """ @@ -75,7 +80,23 @@ def set_hybrid_pressure_coefficients( return pressure_data -def check_eta(ak, bk): - from pyFV3.initialization.init_utils import compute_eta +def vertical_coordinate(eta_value): + """ + Equation (1) JRMS2006 + computes eta_v, the auxiliary variable vertical coordinate + """ + return (eta_value - ETA_0) * math.pi * 0.5 + +def compute_eta(ak, bk): + """ + Equation (1) JRMS2006 + eta is the vertical coordinate and eta_v is an auxiliary vertical coordinate + """ + eta = 0.5 * ((ak[:-1] + ak[1:]) / SURFACE_PRESSURE + bk[:-1] + bk[1:]) + eta_v = vertical_coordinate(eta) + return eta, eta_v + + +def check_eta(ak, bk): return compute_eta(ak, bk) diff --git a/ndsl/stencils/c2l_ord.py b/ndsl/stencils/c2l_ord.py index a10b797b..4e18c1ff 100644 --- a/ndsl/stencils/c2l_ord.py +++ b/ndsl/stencils/c2l_ord.py @@ -6,7 +6,6 @@ interval, region, ) -import pyFV3 import ndsl.dsl.gt4py_utils as utils from ndsl.comm.communicator import Communicator @@ -157,7 +156,7 @@ class CubedToLatLon: def __init__( self, - state: pyFV3.DycoreState, + state, # No type hint on purpose to remove dependency on pyFV3 stencil_factory: StencilFactory, quantity_factory: QuantityFactory, grid_data: GridData, @@ -215,8 +214,6 @@ def __init__( compute_halos=halos, ) - origin = grid_indexing.origin_compute() - shape = grid_indexing.max_shape if not self.one_rank: full_size_xyiz_halo_spec = quantity_factory.get_quantity_halo_spec( dims=[X_DIM, Y_INTERFACE_DIM, Z_DIM], @@ -228,6 +225,17 @@ def __init__( n_halo=grid_indexing.n_halo, dtype=Float, ) + + # TODO: + # To break the depedency to pyFV3 we allow ourselves to not have a type + # hint around state and we check for u and v to make sure we don't + # have bad input. + # This entire code should be retired when WrappedHaloUpdater is no longer + # required. + if not WrappedHaloUpdater.check_for_attribute( + state, "u" + ) and WrappedHaloUpdater.check_for_attribute(state, "v"): + raise RuntimeError("Cube To Lat Lon: state given is not readable.") self.u__v = WrappedHaloUpdater( comm.get_vector_halo_updater( [full_size_xyiz_halo_spec], [full_size_xiyz_halo_spec]