zarr-python v3 compatibility #516

Draft · wants to merge 23 commits into base: main
83 changes: 80 additions & 3 deletions kerchunk/codecs.py
@@ -1,11 +1,18 @@
import ast
from dataclasses import dataclass
import io
from typing import Self, TYPE_CHECKING

import numcodecs
from numcodecs.abc import Codec
import numpy as np
import threading
import zlib
from zarr.core.array_spec import ArraySpec
from zarr.abc.codec import ArrayBytesCodec
from zarr.core.buffer import Buffer, NDArrayLike, NDBuffer
from zarr.core.common import JSON, parse_enum, parse_named_configuration
from zarr.registry import register_codec


class FillStringsCodec(Codec):
@@ -115,6 +122,78 @@ def decode(self, buf, out=None):
numcodecs.register_codec(GRIBCodec, "grib")


@dataclass(frozen=True)
class GRIBZarrCodec(ArrayBytesCodec):
    eclock = threading.RLock()

    var: str
    dtype: np.dtype

    def __init__(self, *, var: str, dtype: np.dtype) -> None:
        object.__setattr__(self, "var", var)
        object.__setattr__(self, "dtype", dtype)

    @classmethod
    def from_dict(cls, data: dict[str, JSON]) -> Self:
        _, configuration_parsed = parse_named_configuration(
            data, "grib", require_configuration=True
        )
        configuration_parsed = configuration_parsed or {}
        return cls(**configuration_parsed)  # type: ignore[arg-type]

    def to_dict(self) -> dict[str, JSON]:
        return {
            "name": "grib",
            "configuration": {"var": self.var, "dtype": self.dtype},
        }

    async def _decode_single(
        self,
        chunk_bytes: Buffer,
        chunk_spec: ArraySpec,
    ) -> NDBuffer:
        assert isinstance(chunk_bytes, Buffer)
        import eccodes

        if self.var in ["latitude", "longitude"]:
            var = self.var + "s"
            dt = self.dtype or "float64"
        else:
            var = "values"
            dt = self.dtype or "float32"

        with self.eclock:
            mid = eccodes.codes_new_from_message(chunk_bytes.to_bytes())
            try:
                data = eccodes.codes_get_array(mid, var)
                missingValue = eccodes.codes_get_string(mid, "missingValue")
                if var == "values" and missingValue:
                    data[data == float(missingValue)] = np.nan
                return data.astype(dt, copy=False)

            finally:
                eccodes.codes_release(mid)

    async def _encode_single(
        self,
        chunk_array: NDBuffer,
        chunk_spec: ArraySpec,
    ) -> Buffer | None:
        # This is a one-way codec
        raise NotImplementedError

    def compute_encoded_size(
        self, input_byte_length: int, _chunk_spec: ArraySpec
    ) -> int:
        raise NotImplementedError


register_codec("grib", GRIBZarrCodec)


class AsciiTableCodec(numcodecs.abc.Codec):
"""Decodes ASCII-TABLE extensions in FITS files"""

@@ -166,7 +245,6 @@ def decode(self, buf, out=None):
        arr2 = np.empty((self.nrow,), dtype=dt_out)
        heap = buf[arr.nbytes :]
        for name in dt_out.names:

            if dt_out[name] == "O":
                dt = np.dtype(self.ftypes[self.types[name]])
                counts = arr[name][:, 0]
@@ -244,8 +322,7 @@ def encode(self, buf):
class ZlibCodec(Codec):
    codec_id = "zlib"

    def __init__(self):
        ...
    def __init__(self): ...

    def decode(self, data, out=None):
        if out:
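A quick usage sketch (not part of the diff) of the zarr v3 codec metadata produced by the new GRIBZarrCodec above; the variable name "t2m" is made up:

from kerchunk.codecs import GRIBZarrCodec

codec = GRIBZarrCodec(var="t2m", dtype="float32")
# Serialises to zarr v3 "named configuration" metadata under the registered name.
print(codec.to_dict())
# {'name': 'grib', 'configuration': {'var': 't2m', 'dtype': 'float32'}}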
8 changes: 4 additions & 4 deletions kerchunk/combine.py
@@ -203,7 +203,7 @@ def append(
        ds = xr.open_dataset(
            fs.get_mapper(), engine="zarr", backend_kwargs={"consolidated": False}
        )
        z = zarr.open(fs.get_mapper())
        z = zarr.open(fs.get_mapper(), zarr_format=2)
        mzz = MultiZarrToZarr(
            path,
            out=fs.references,  # dict or parquet/lazy
@@ -360,7 +360,7 @@ def first_pass(self):
                fs._dircache_from_items()

            logger.debug("First pass: %s", i)
            z = zarr.open_group(fs.get_mapper(""))
            z = zarr.open_group(fs.get_mapper(""), zarr_format=2)
            for var in self.concat_dims:
                value = self._get_value(i, z, var, fn=self._paths[i])
                if isinstance(value, np.ndarray):
@@ -387,7 +387,7 @@ def store_coords(self):
"""
kv = {}
store = zarr.storage.KVStore(kv)
group = zarr.open(store)
group = zarr.open(store, zarr_format=2)
m = self.fss[0].get_mapper("")
z = zarr.open(m)
for k, v in self.coos.items():
@@ -461,7 +461,7 @@ def second_pass(self):
        for i, fs in enumerate(self.fss):
            to_download = {}
            m = fs.get_mapper("")
            z = zarr.open(m)
            z = zarr.open(m, zarr_format=2)

            if no_deps is None:
                # done first time only
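For context on these combine.py changes, a minimal sketch (with a made-up reference dict) of why the explicit zarr_format=2 matters: kerchunk references describe zarr v2 metadata keys such as .zgroup and .zarray, so with zarr-python v3 installed the open calls are pinned to the v2 format rather than being interpreted against v3 metadata (zarr.json):

import fsspec
import zarr

refs = {".zgroup": '{"zarr_format": 2}'}  # illustrative minimal reference set
fs = fsspec.filesystem("reference", fo=refs)
g = zarr.open_group(fs.get_mapper(""), zarr_format=2)  # mirrors the calls in the diff above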
7 changes: 4 additions & 3 deletions kerchunk/fits.py
@@ -8,7 +8,7 @@
from fsspec.implementations.reference import LazyReferenceMapper


from kerchunk.utils import class_factory
from kerchunk.utils import class_factory, dict_to_store
from kerchunk.codecs import AsciiTableCodec, VarArrCodec

try:
@@ -72,7 +72,8 @@ def process_file(

    storage_options = storage_options or {}
    out = out or {}
    g = zarr.open(out)
    store = dict_to_store(out)
    g = zarr.open_group(store=store, zarr_format=2)

    with fsspec.open(url, mode="rb", **storage_options) as f:
        infile = fits.open(f, do_not_scale_image_data=True)
@@ -164,7 +165,7 @@ def process_file(
            # TODO: we could sub-chunk on biggest dimension
            name = hdu.name or str(ext)
            arr = g.empty(
                name, dtype=dtype, shape=shape, chunks=shape, compression=None, **kwargs
                name=name, dtype=dtype, shape=shape, chunks=shape, compressor=None, zarr_format=2, **kwargs
            )
            arr.attrs.update(
                {
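A minimal, self-contained sketch of the new fits.py setup shown above, with made-up array metadata; dict_to_store is the helper imported from kerchunk.utils in this PR, which (as used here) wraps a plain dict as a zarr store, and compressor= replaces the older compression= keyword:

import zarr
from kerchunk.utils import dict_to_store

out = {}  # reference entries and metadata accumulate in this dict
store = dict_to_store(out)
g = zarr.open_group(store=store, zarr_format=2)
arr = g.empty(
    name="PRIMARY", dtype="float32", shape=(100, 100), chunks=(100, 100),
    compressor=None, zarr_format=2,
)
arr.attrs["_ARRAY_DIMENSIONS"] = ["y", "x"]  # illustrative attribute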
31 changes: 17 additions & 14 deletions kerchunk/grib2.py
@@ -11,7 +11,7 @@
import xarray
import numpy as np

from kerchunk.utils import class_factory, _encode_for_JSON
from kerchunk.utils import class_factory, _encode_for_JSON, dict_to_store, translate_refs_serializable
from kerchunk.codecs import GRIBCodec
from kerchunk.combine import MultiZarrToZarr, drop
from kerchunk._grib_idx import parse_grib_idx, build_idx_grib_mapping, map_from_index
@@ -71,13 +71,13 @@ def _store_array(store, z, data, var, inline_threshold, offset, size, attr):
    shape = tuple(data.shape or ())
    if nbytes < inline_threshold:
        logger.debug(f"Store {var} inline")
        d = z.create_dataset(
        d = z.create_array(
            name=var,
            shape=shape,
            chunks=shape,
            dtype=data.dtype,
            fill_value=attr.get("missingValue", None),
            compressor=False,
            compressor=None,
        )
        if hasattr(data, "tobytes"):
            b = data.tobytes()
@@ -91,15 +91,14 @@ def _store_array(store, z, data, var, inline_threshold, offset, size, attr):
store[f"{var}/0"] = b.decode("ascii")
else:
logger.debug(f"Store {var} reference")
d = z.create_dataset(
d = z.create_array(
name=var,
shape=shape,
chunks=shape,
dtype=data.dtype,
fill_value=attr.get("missingValue", None),
filters=[GRIBCodec(var=var, dtype=str(data.dtype))],
compressor=False,
overwrite=True,
compressor=None,
)
store[f"{var}/" + ".".join(["0"] * len(shape))] = ["{{u}}", offset, size]
d.attrs.update(attr)
@@ -153,7 +152,9 @@ def scan_grib(
with fsspec.open(url, "rb", **storage_options) as f:
logger.debug(f"File {url}")
for offset, size, data in _split_file(f, skip=skip):
store = {}
store_dict = {}
store = dict_to_store(store_dict)

mid = eccodes.codes_new_from_message(data)
m = cfgrib.cfmessage.CfMessage(mid)

@@ -191,7 +192,7 @@ def scan_grib(
            if good is False:
                continue

            z = zarr.open_group(store)
            z = zarr.open_group(store, zarr_format=2)
            global_attrs = {
                f"GRIB_{k}": m[k]
                for k in cfgrib.dataset.GLOBAL_ATTRIBUTES_KEYS
@@ -227,7 +228,7 @@ def scan_grib(
varName = m["cfVarName"]
if varName in ("undef", "unknown"):
varName = m["shortName"]
_store_array(store, z, vals, varName, inline_threshold, offset, size, attrs)
_store_array(store_dict, z, vals, varName, inline_threshold, offset, size, attrs)
if "typeOfLevel" in message_keys and "level" in message_keys:
name = m["typeOfLevel"]
coordinates.append(name)
@@ -241,7 +242,7 @@ def scan_grib(
                    attrs = {}
                attrs["_ARRAY_DIMENSIONS"] = []
                _store_array(
                    store, z, data, name, inline_threshold, offset, size, attrs
                    store_dict, z, data, name, inline_threshold, offset, size, attrs
                )
            dims = (
                ["y", "x"]
@@ -298,7 +299,7 @@ def scan_grib(
                dims = [coord]
                attrs = cfgrib.dataset.COORD_ATTRS[coord]
                _store_array(
                    store,
                    store_dict,
                    z,
                    x,
                    coord,
@@ -311,10 +312,11 @@ def scan_grib(
            if coordinates:
                z.attrs["coordinates"] = " ".join(coordinates)

            translate_refs_serializable(store_dict)
            out.append(
                {
                    "version": 1,
                    "refs": _encode_for_JSON(store),
                    "refs": _encode_for_JSON(store_dict),
                    "templates": {"u": url},
                }
            )
@@ -397,8 +399,9 @@ def grib_tree(
filters = ["stepType", "typeOfLevel"]

# TODO allow passing a LazyReferenceMapper as output?
zarr_store = {}
zroot = zarr.open_group(store=zarr_store)
zarr_store_dict = {}
zarr_store = dict_to_store(zarr_store_dict)
zroot = zarr.open_group(store=zarr_store, zarr_format=2)

aggregations: Dict[str, List] = defaultdict(list)
aggregation_dims: Dict[str, Set] = defaultdict(set)
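To tie the grib2.py changes together, a condensed sketch of the pattern scan_grib and grib_tree now follow (all values below are illustrative): arrays are declared with create_array on a zarr_format=2 group whose store wraps a plain dict, and chunk keys are then pointed at byte ranges of the original GRIB file by writing references straight into that dict:

import zarr
from kerchunk.utils import dict_to_store

store_dict = {}
z = zarr.open_group(dict_to_store(store_dict), zarr_format=2)
d = z.create_array(
    name="t2m", shape=(361, 720), chunks=(361, 720), dtype="float32",
    fill_value=None, compressor=None,
)
d.attrs.update({"_ARRAY_DIMENSIONS": ["y", "x"]})
# A reference entry: [url template, byte offset, length] into the GRIB message
store_dict["t2m/0.0"] = ["{{u}}", 0, 123456]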