Skip to content

Commit

Permalink
Merge pull request #4 from NOAA-GFDL/fix/break-pyFV3-dependency
Browse files Browse the repository at this point in the history
[fix] Break pyFV3 dependency
  • Loading branch information
fmalatino authored Feb 2, 2024
2 parents f74d364 + d6b9eb7 commit 0e0320c
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 7 deletions.
10 changes: 9 additions & 1 deletion ndsl/dsl/dace/wrapped_halo_exchange.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import dataclasses
from typing import List, Optional
from typing import Any, List, Optional

from ndsl.comm.communicator import Communicator
from ndsl.dsl.dace.orchestration import dace_inhibitor
Expand Down Expand Up @@ -29,6 +29,14 @@ def __init__(
self._qtx_y_names = qty_y_names
self._comm = comm

@staticmethod
def check_for_attribute(state: Any, attr: str):
if dataclasses.is_dataclass(state):
return state.__getattribute__(attr)
elif isinstance(state, dict):
return attr in state.keys()
return False

@dace_inhibitor
def start(self):
if self._qtx_y_names is None:
Expand Down
25 changes: 23 additions & 2 deletions ndsl/grid/eta.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
import math
import os
from dataclasses import dataclass

import numpy as np
import xarray as xr


ETA_0 = 0.252
SURFACE_PRESSURE = 1.0e5 # units of (Pa), from Table VI of DCMIP2016


@dataclass
class HybridPressureCoefficients:
"""
Expand Down Expand Up @@ -75,7 +80,23 @@ def set_hybrid_pressure_coefficients(
return pressure_data


def check_eta(ak, bk):
from pyFV3.initialization.init_utils import compute_eta
def vertical_coordinate(eta_value):
"""
Equation (1) JRMS2006
computes eta_v, the auxiliary variable vertical coordinate
"""
return (eta_value - ETA_0) * math.pi * 0.5


def compute_eta(ak, bk):
"""
Equation (1) JRMS2006
eta is the vertical coordinate and eta_v is an auxiliary vertical coordinate
"""
eta = 0.5 * ((ak[:-1] + ak[1:]) / SURFACE_PRESSURE + bk[:-1] + bk[1:])
eta_v = vertical_coordinate(eta)
return eta, eta_v


def check_eta(ak, bk):
return compute_eta(ak, bk)
16 changes: 12 additions & 4 deletions ndsl/stencils/c2l_ord.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
interval,
region,
)
import pyFV3

import ndsl.dsl.gt4py_utils as utils
from ndsl.comm.communicator import Communicator
Expand Down Expand Up @@ -157,7 +156,7 @@ class CubedToLatLon:

def __init__(
self,
state: pyFV3.DycoreState,
state, # No type hint on purpose to remove dependency on pyFV3
stencil_factory: StencilFactory,
quantity_factory: QuantityFactory,
grid_data: GridData,
Expand Down Expand Up @@ -215,8 +214,6 @@ def __init__(
compute_halos=halos,
)

origin = grid_indexing.origin_compute()
shape = grid_indexing.max_shape
if not self.one_rank:
full_size_xyiz_halo_spec = quantity_factory.get_quantity_halo_spec(
dims=[X_DIM, Y_INTERFACE_DIM, Z_DIM],
Expand All @@ -228,6 +225,17 @@ def __init__(
n_halo=grid_indexing.n_halo,
dtype=Float,
)

# TODO:
# To break the depedency to pyFV3 we allow ourselves to not have a type
# hint around state and we check for u and v to make sure we don't
# have bad input.
# This entire code should be retired when WrappedHaloUpdater is no longer
# required.
if not WrappedHaloUpdater.check_for_attribute(
state, "u"
) and WrappedHaloUpdater.check_for_attribute(state, "v"):
raise RuntimeError("Cube To Lat Lon: state given is not readable.")
self.u__v = WrappedHaloUpdater(
comm.get_vector_halo_updater(
[full_size_xyiz_halo_spec], [full_size_xiyz_halo_spec]
Expand Down

0 comments on commit 0e0320c

Please sign in to comment.