Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Concatenate arrays with varchunks #374

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 58 additions & 9 deletions kerchunk/combine.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,6 +566,8 @@ def concatenate_arrays(
Input reference sets, maybe generated by ``kerchunk.zarr.single_zarr``
storage_options: dict | None
To create the filesystems, such as target/remote protocol and target/remote options
axis
Axis to concatenate along.
key_seperator: str
"." or "/", how the zarr keys are stored
path: str or None
Expand All @@ -589,15 +591,43 @@ def _replace(l: list, i: int, v) -> list:
return l

n_files = len(files)
fss = [
fsspec.filesystem("reference", fo=fn, **(storage_options or {})) for fn in files
]
zarrays = [ujson.load(fs.open(f"{path}.zarray")) for fs in fss]

# Determine chunk type
_prev_chunks = zarrays[0]["chunks"][axis]
for zarray in zarrays:
axis_chunks = zarray["chunks"][axis]
if isinstance(axis_chunks, list) or (axis_chunks != _prev_chunks):
chunk_type = "variable"
break
_prev_chunks = axis_chunks
else:
chunk_type = "fixed"

chunks_offset = 0
for i, fn in enumerate(files):
fs = fsspec.filesystem("reference", fo=fn, **(storage_options or {}))
zarray = ujson.load(fs.open(f"{path}.zarray"))
for i, (fs, zarray) in enumerate(zip(fss, zarrays)):
shape = zarray["shape"]
chunks = zarray["chunks"]
n_chunks, rem = divmod(shape[axis], chunks[axis])
n_chunks += rem > 0

if isinstance(chunks[axis], int):
n_chunks, rem = divmod(shape[axis], chunks[axis])
n_chunks += rem > 0
if chunk_type == "variable":
if rem != 0:
raise ValueError(
"Cannot handle padded chunks when creating variably chunked arrays. "
f"{i}-th array has {rem} padding for the last chunk."
)
assert rem == 0
chunks[axis] = [chunks[axis] for _ in range(n_chunks)]
else:
n_chunks = len(chunks[axis])
rem = 0
# Can this ever work?
# rem = sum(chunks[axis]) - shape[axis]

if i == 0:
base_shape = _replace(shape, axis, None)
Expand All @@ -610,6 +640,13 @@ def _replace(l: list, i: int, v) -> list:
out[name] = fs.references[name]
else:
result_shape[axis] += shape[axis]
if chunk_type == "fixed":
assert isinstance(chunks[axis], int)
elif chunk_type == "variable":
assert isinstance(chunks[axis], list)
result_zarray["chunks"][axis] = (
result_zarray["chunks"][axis] + chunks[axis]
)

# Safety checks
if check_arrays:
Expand All @@ -620,10 +657,22 @@ def _replace(l: list, i: int, v) -> list:
raise ValueError(
f"Incompatible array shape at index {i}. Expected {expected_shape}, got {shape}."
)
if chunks != base_chunks:
raise ValueError(
f"Incompatible array chunks at index {i}. Expected {base_chunks}, got {chunks}."
)
if chunk_type == "fixed":
if chunks != base_chunks:
raise ValueError(
f"Incompatible array chunks at index {i}. Expected {base_chunks}, got {chunks}."
)
else:
assert chunk_type == "variable"
base_chunks_to_compare = [
v for i, v in enumerate(base_chunks) if i != axis
]
chunks_to_compare = [v for i, v in enumerate(chunks) if i != axis]
# TODO: deduplicate
if chunks_to_compare != base_chunks_to_compare:
raise ValueError(
f"Incompatible array chunks at index {i}. Expected {base_chunks}, got {chunks}."
)
if i < (n_files - 1) and rem != 0:
raise ValueError(
f"Array at index {i} has irregular chunking at its boundary. "
Expand Down
116 changes: 115 additions & 1 deletion kerchunk/tests/test_combine_concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,10 @@ def test_fail_chunks(tmpdir):
ref1 = kerchunk.zarr.single_zarr(fn1, inline=0)
ref2 = kerchunk.zarr.single_zarr(fn2, inline=0)

with pytest.raises(ValueError, match=r"Incompatible array chunks at index 1.*"):
with pytest.raises(
ValueError,
match=r"Cannot handle padded chunks when creating variably chunked arrays.",
):
kerchunk.combine.concatenate_arrays([ref1, ref2], path="x", check_arrays=True)


Expand Down Expand Up @@ -139,3 +142,114 @@ def test_fail_irregular_chunk_boundaries(tmpdir):

with pytest.raises(ValueError, match=r"Array at index 0 has irregular chunking.*"):
kerchunk.combine.concatenate_arrays([ref1, ref2], path="x", check_arrays=True)


@pytest.mark.parametrize(
    "arrays,axis,expected_chunks",
    [
        (  # Both inputs already variably chunked along the concat axis
            [
                zarr.array(np.arange(10), chunks=([2, 2, 2, 2, 2],)),
                zarr.array(np.arange(10, 20), chunks=([3, 3, 3, 1],)),
            ],
            0,
            ([2, 2, 2, 2, 2, 3, 3, 3, 1],),
        ),
        (  # 2-D inputs joined along axis 1; axis-0 chunks are identical
            [
                zarr.array(
                    np.broadcast_to(np.arange(6), (10, 6)), chunks=([5, 5], [6])
                ),
                zarr.array(
                    np.broadcast_to(np.arange(7, 10), (10, 3)), chunks=([5, 5], [3])
                ),
            ],
            1,
            ([5, 5], [6, 3]),
        ),
        (  # Fixed-chunk input followed by a variably chunked input
            [
                zarr.array(np.arange(10), chunks=(5,)),
                zarr.array(np.arange(10, 20), chunks=([3, 7],)),
            ],
            0,
            ([5, 5, 3, 7],),
        ),
        (  # Inferring variable chunking from fixed inputs
            [
                zarr.array(np.arange(12), chunks=(6,)),
                zarr.array(np.arange(12, 24), chunks=(4,)),
            ],
            0,
            ([6, 6, 4, 4, 4],),
        ),
    ],
)
def test_variable_length_chunks_success(tmpdir, arrays, axis, expected_chunks):
    """Concatenating arrays yields the expected variable chunking and data.

    Each input array is written to its own zarr store under ``tmpdir`` as
    dataset ``"x"``, turned into a reference set with
    ``kerchunk.zarr.single_zarr`` and then combined along ``axis`` with
    ``kerchunk.combine.concatenate_arrays``.  The combined references are
    opened through a ``reference://`` mapper and checked for both the
    resulting chunk layout and the concatenated values.
    """
    refs = []
    for i, x in enumerate(arrays):
        fn = f"{tmpdir}/out{i}.zarr"
        g = zarr.open(fn)
        # Re-use the chunk layout of the in-memory array for the on-disk copy.
        g.create_dataset("x", data=x, chunks=x.chunks)
        refs.append(kerchunk.zarr.single_zarr(fn, inline=0))

    out = kerchunk.combine.concatenate_arrays(
        refs, axis=axis, path="x", check_arrays=True
    )

    mapper = fsspec.get_mapper("reference://", fo=out)
    g_result = zarr.open(mapper)

    assert g_result["x"].chunks == expected_chunks
    np.testing.assert_array_equal(
        g_result["x"][...], np.concatenate([a[...] for a in arrays], axis=axis)
    )


def test_variable_length_chunks_mismatch_chunk_failure(tmpdir):
    """Variably chunked inputs whose off-axis chunks differ are rejected.

    Both arrays have shape (4, 3) and are concatenated along axis 0, but
    their axis-1 chunks differ (``[1, 2]`` vs ``[2, 1]``), so
    ``concatenate_arrays`` with ``check_arrays=True`` must raise the
    incompatible-chunks ``ValueError`` for the second array (index 1).
    """
    arrays = [
        zarr.array(np.arange(12).reshape(4, 3), chunks=([2, 2], [1, 2])),
        zarr.array(np.arange(12, 24).reshape(4, 3), chunks=([4], [2, 1])),
    ]
    axis = 0

    refs = []
    for i, x in enumerate(arrays):
        fn = f"{tmpdir}/out{i}.zarr"
        g = zarr.open(fn)
        g.create_dataset("x", data=x, chunks=x.chunks)
        refs.append(kerchunk.zarr.single_zarr(fn, inline=0))

    # Pin the message so an unrelated ValueError cannot satisfy the test.
    with pytest.raises(ValueError, match=r"Incompatible array chunks at index 1"):
        kerchunk.combine.concatenate_arrays(
            refs, axis=axis, path="x", check_arrays=True
        )


def test_fixed_to_variable_padding_failure(tmpdir):
    """Fixed-chunk inputs with a padded final chunk cannot become variable.

    The two inputs use different fixed chunk sizes (6 and 8), so the result
    has to be variably chunked.  The second array has 12 elements in chunks
    of 8, leaving a partially filled last chunk, which
    ``concatenate_arrays`` must reject with a ``ValueError``.
    """
    arrays = [
        zarr.array(np.arange(12), chunks=(6,)),
        zarr.array(np.arange(12, 24), chunks=(8,)),
    ]
    axis = 0

    refs = []
    for i, x in enumerate(arrays):
        fn = f"{tmpdir}/out{i}.zarr"
        g = zarr.open(fn)
        g.create_dataset("x", data=x, chunks=x.chunks)
        refs.append(kerchunk.zarr.single_zarr(fn, inline=0))

    # Same message check as test_fail_chunks, so only the padding error passes.
    with pytest.raises(
        ValueError,
        match=r"Cannot handle padded chunks when creating variably chunked arrays",
    ):
        kerchunk.combine.concatenate_arrays(
            refs, axis=axis, path="x", check_arrays=True
        )
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ dependencies = [
"numcodecs<=0.11",
"numpy",
"ujson",
"zarr",
"zarr @ git+https://github.com/martindurant/zarr.git@varchunk",
]

[project.optional-dependencies]
Expand Down
Loading