Skip to content

Commit

Permalink
Add grib_tree method
Browse files Browse the repository at this point in the history
  • Loading branch information
emfdavid committed Nov 24, 2023
1 parent 7047d14 commit cee6410
Show file tree
Hide file tree
Showing 2 changed files with 215 additions and 2 deletions.
199 changes: 199 additions & 0 deletions kerchunk/grib2.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
import base64
import copy
import logging
from collections import defaultdict
from typing import Iterable, List, Dict, Set

import ujson

try:
import cfgrib
Expand Down Expand Up @@ -354,3 +359,197 @@ def example_combine(
identical_dims=["heightAboveGround", "latitude", "longitude"],
)
return mzz.translate()


def grib_tree(
message_groups: Iterable[Dict],
remote_options=None,
) -> Dict:
"""
Build a hierarchical data model from a set of scanned grib messages. The iterable input groups should
be a collection of results from scan_grib. Multiple grib files can be processed together to produce an
FMRC like collection.
The time (reference_time) and step coordinates will be used as concat_dims in the MultiZarrToZarr
aggregation. Each variable name will become a group with nested subgroups representing the grib
step type and grib level. The resulting hierarchy can be opened as a zarr_group or a xarray datatree.
Grib message variable names that decode as "unknown" are dropped
Grib typeOfLevel attributes that decode as unknown are treated as a single group
Grib steps that are missing due to WrongStepUnitError are patched with NaT
:param message_groups: a collection of zarr store like dictionaries as produced by scan_grib
:param remote_options: remote options to pass to ZarrToMultiZarr
:return: A new zarr store like dictionary for use as a reference filesystem mapper
"""
from kerchunk.combine import MultiZarrToZarr

# Hard code the filters in the correct order for the group hierarchy
filters = ["stepType", "typeOfLevel"]

zarr_store = {}
zroot = zarr.open_group(store=zarr_store)
result = dict(refs=zarr_store)

aggregations: Dict[str, List] = defaultdict(list)
aggregation_dims: Dict[str, Set] = defaultdict(set)

unknown_counter = 0
for msg_ind, group in enumerate(message_groups):
if "version" not in result:
result["version"] = group["version"]

gattrs = ujson.loads(group["refs"][".zattrs"])
coordinates = gattrs["coordinates"].split(" ")

# Find the data variable
vname = None
for key, entry in group["refs"].items():
name = key.split("/")[0]
if name not in [".zattrs", ".zgroup"] and name not in coordinates:
vname = name
break

if vname is None:
raise RuntimeError(
f"Can not find a data var for msg# {msg_ind} in {group['refs'].keys()}"
)

if vname == "unknown":
# To resolve unknown variables add custom grib tables.
# https://confluence.ecmwf.int/display/UDOC/Creating+your+own+local+definitions+-+ecCodes+GRIB+FAQ
# If you process the groups from a single file in order, you can use the msg# to compare with the
# IDX file.
logger.warning(
"Found unknown variable in msg# %s... it will be dropped", msg_ind
)
unknown_counter += 1
continue

logger.debug("Processing vname: %s", vname)
dattrs = ujson.loads(group["refs"][f"{vname}/.zattrs"])
# filter order matters - it determines the hierarchy
gfilters = {}
for key in filters:
attr_val = dattrs.get(f"GRIB_{key}")
if attr_val is None:
continue
if attr_val == "unknown":
logger.warning(
"Found 'unknown' attribute value for key %s in var %s of msg# %s",
key,
vname,
msg_ind,
)
# Use unknown as a group or drop it?

gfilters[key] = attr_val

zgroup = zroot.require_group(vname)
if "name" not in zgroup.attrs:
zgroup.attrs["name"] = dattrs.get("GRIB_name")

for key, value in gfilters.items():
if value: # Ignore empty string and None
# name the group after the attribute values: surface, instant, etc
zgroup = zgroup.require_group(value)
# Add an attribute to give context
zgroup.attrs[key] = value

# add to the list of groups to multi-zarr
aggregations[zgroup.path].append(group)

# keep track of the level coordinate variables and their values
for key, entry in group["refs"].items():
name = key.split("/")[0]
if name == gfilters.get("typeOfLevel") and key.endswith("0"):
if isinstance(entry, list):
entry = tuple(entry)
aggregation_dims[zgroup.path].add(entry)

concat_dims = ["time", "step"]
identical_dims = ["longitude", "latitude"]
for path in aggregations.keys():
# Parallelize this step!
catdims = concat_dims.copy()
idims = identical_dims.copy()

level_dimension_value_count = len(aggregation_dims.get(path, ()))
level_group_name = path.split("/")[-1]
if level_dimension_value_count == 0:
logger.debug(
"Path % has no value coordinate value associated with the level name %s",
path,
level_group_name,
)
elif level_dimension_value_count == 1:
idims.append(level_group_name)
elif level_dimension_value_count > 1:
# The level name should be the last element in the path
catdims.insert(3, level_group_name)

logger.info(
"%s calling MultiZarrToZarr with idims %s and catdims %s",
path,
idims,
catdims,
)

fix_group_step = add_missing_step_var(aggregations[path], path)
mzz = MultiZarrToZarr(
fix_group_step,
remote_options=remote_options,
concat_dims=catdims,
identical_dims=idims,
)
try:
group = mzz.translate()
except KeyError:
import pprint

gstr = pprint.pformat(fix_group_step)
logger.exception(f"Failed to multizarr {path}\n{gstr}")
continue

for key, value in group["refs"].items():
if key not in [".zattrs", ".zgroup"]:
zarr_store[f"{path}/{key}"] = value

return result


def add_missing_step_var(groups: List[dict], path: str) -> List[dict]:
"""
Attempt to fill in missing step var. Should this be done where the step unit error is handled
in scan grib?
:param groups:
:param path:
:return:
"""
result = []
for group in groups:
if "step/.zarray" not in group["refs"]:
group = copy.deepcopy(group)
logger.warning("Adding missing step variable to group path %s", path)
group["refs"]["step/.zarray"] = (
'{"chunks":[],"compressor":null,"dtype":"<f8","fill_value":"NaN","filters":null,"order":"C",'
'"shape":[],"zarr_format":2}'
)
group["refs"]["step/.zattrs"] = (
'{"_ARRAY_DIMENSIONS":[],"long_name":"time since forecast_reference_time",'
'"standard_name":"forecast_period","units":"hours"}'
)

# # Try to set the value - this doesn't work
# import xarray
# fo = fsspec.filesystem("reference", fo=group, mode="rw")
# xd = xarray.open_dataset(fo.get_mapper(), engine="zarr", consolidated=False)
# if np.isnan(xd.step.values[()]):
# logger.info("%s has step val %s", path, xd.step)
# xd.step[()] = xd.valid_time.values - xd.time.values
# xd.close()
# for k, v in group["refs"].items():
# if "step" in k:
# logger.info("New step value %s, %s", k, v)

result.append(group)

return result
18 changes: 16 additions & 2 deletions kerchunk/tests/test_grib.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
import numpy as np
import pytest
import xarray as xr

from kerchunk.grib2 import scan_grib, _split_file, GribToZarr
import zarr
from kerchunk.grib2 import scan_grib, _split_file, GribToZarr, grib_tree

cfgrib = pytest.importorskip("cfgrib")
here = os.path.dirname(__file__)
Expand Down Expand Up @@ -83,3 +83,17 @@ def test_subhourly():
fpath = os.path.join(here, "hrrr.wrfsubhf.sample.grib2")
result = scan_grib(fpath)
assert len(result) == 2, "Expected two grib messages"


def test_grib_tree():
"""
Additional testing here would be good.
Maybe add json files with scan_grib output?
"""
fpath = os.path.join(here, "hrrr.wrfsubhf.sample.grib2")
scanned_msg_groups = scan_grib(fpath)
result = grib_tree(scanned_msg_groups)
fs = fsspec.filesystem("reference", fo=result)
zg = zarr.open_group(fs.get_mapper(""))
isinstance(zg["refc/instant/atmosphere/refc"], zarr.Array)
isinstance(zg["vbdsf/avg/surface/vbdsf"], zarr.Array)

0 comments on commit cee6410

Please sign in to comment.