diff --git a/kerchunk/combine.py b/kerchunk/combine.py
index 4426d445..199c7360 100644
--- a/kerchunk/combine.py
+++ b/kerchunk/combine.py
@@ -566,6 +566,8 @@ def concatenate_arrays(
         Input reference sets, maybe generated by ``kerchunk.zarr.single_zarr``
     storage_options: dict | None
         To create the filesystems, such at target/remote protocol and target/remote options
+    axis: int
+        Axis to concatenate along.
     key_seperator: str
         "." or "/", how the zarr keys are stored
     path: str or None
@@ -589,15 +591,43 @@ def _replace(l: list, i: int, v) -> list:
         return l
 
     n_files = len(files)
+    fss = [
+        fsspec.filesystem("reference", fo=fn, **(storage_options or {})) for fn in files
+    ]
+    zarrays = [ujson.load(fs.open(f"{path}.zarray")) for fs in fss]
+
+    # Determine chunk type
+    _prev_chunks = zarrays[0]["chunks"][axis]
+    for zarray in zarrays:
+        axis_chunks = zarray["chunks"][axis]
+        if isinstance(axis_chunks, list) or (axis_chunks != _prev_chunks):
+            chunk_type = "variable"
+            break
+        _prev_chunks = axis_chunks
+    else:
+        chunk_type = "fixed"
 
     chunks_offset = 0
-    for i, fn in enumerate(files):
-        fs = fsspec.filesystem("reference", fo=fn, **(storage_options or {}))
-        zarray = ujson.load(fs.open(f"{path}.zarray"))
+    for i, (fs, zarray) in enumerate(zip(fss, zarrays)):
         shape = zarray["shape"]
         chunks = zarray["chunks"]
-        n_chunks, rem = divmod(shape[axis], chunks[axis])
-        n_chunks += rem > 0
+
+        if isinstance(chunks[axis], int):
+            n_chunks, rem = divmod(shape[axis], chunks[axis])
+            n_chunks += rem > 0
+            if chunk_type == "variable":
+                if rem != 0:
+                    raise ValueError(
+                        "Cannot handle padded chunks when creating variably chunked arrays. "
+                        f"{i}-th array has {rem} padding for the last chunk."
+                    )
+                assert rem == 0
+                chunks[axis] = [chunks[axis] for _ in range(n_chunks)]
+        else:
+            n_chunks = len(chunks[axis])
+            rem = 0
+            # Can this ever work?
+            # rem = sum(chunks[axis]) - shape[axis]
 
         if i == 0:
             base_shape = _replace(shape, axis, None)
@@ -610,6 +640,13 @@ def _replace(l: list, i: int, v) -> list:
                 out[name] = fs.references[name]
         else:
             result_shape[axis] += shape[axis]
+            if chunk_type == "fixed":
+                assert isinstance(chunks[axis], int)
+            elif chunk_type == "variable":
+                assert isinstance(chunks[axis], list)
+                result_zarray["chunks"][axis] = (
+                    result_zarray["chunks"][axis] + chunks[axis]
+                )
 
         # Safety checks
         if check_arrays:
@@ -620,10 +657,22 @@ def _replace(l: list, i: int, v) -> list:
                 raise ValueError(
                     f"Incompatible array shape at index {i}. Expected {expected_shape}, got {shape}."
                 )
-            if chunks != base_chunks:
-                raise ValueError(
-                    f"Incompatible array chunks at index {i}. Expected {base_chunks}, got {chunks}."
-                )
+            if chunk_type == "fixed":
+                if chunks != base_chunks:
+                    raise ValueError(
+                        f"Incompatible array chunks at index {i}. Expected {base_chunks}, got {chunks}."
+                    )
+            else:
+                assert chunk_type == "variable"
+                base_chunks_to_compare = [
+                    v for i, v in enumerate(base_chunks) if i != axis
+                ]
+                chunks_to_compare = [v for i, v in enumerate(chunks) if i != axis]
+                # TODO: deduplicate
+                if chunks_to_compare != base_chunks_to_compare:
+                    raise ValueError(
+                        f"Incompatible array chunks at index {i}. Expected {base_chunks}, got {chunks}."
+                    )
             if i < (n_files - 1) and rem != 0:
                 raise ValueError(
                     f"Array at index {i} has irregular chunking at its boundary. "
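# Illustrative sketch (not part of the patch): a standalone rendering of the
# chunk-type detection added above, showing when the "variable" code path is taken.
# The dicts below are hypothetical stand-ins for the parsed ".zarray" metadata;
# only the "chunks" entry matters here.
def detect_chunk_type(zarrays, axis):
    # Any list-valued chunk spec, or a change of chunk size between inputs,
    # forces variable-length chunking in the combined array.
    prev = zarrays[0]["chunks"][axis]
    for zarray in zarrays:
        axis_chunks = zarray["chunks"][axis]
        if isinstance(axis_chunks, list) or axis_chunks != prev:
            return "variable"
        prev = axis_chunks
    return "fixed"

assert detect_chunk_type([{"chunks": [5]}, {"chunks": [5]}], axis=0) == "fixed"
assert detect_chunk_type([{"chunks": [5]}, {"chunks": [3]}], axis=0) == "variable"
assert detect_chunk_type([{"chunks": [[2, 3]]}, {"chunks": [[5]]}], axis=0) == "variable"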
" diff --git a/kerchunk/tests/test_combine_concat.py b/kerchunk/tests/test_combine_concat.py index 3f7ff823..a89f3d3d 100644 --- a/kerchunk/tests/test_combine_concat.py +++ b/kerchunk/tests/test_combine_concat.py @@ -103,7 +103,10 @@ def test_fail_chunks(tmpdir): ref1 = kerchunk.zarr.single_zarr(fn1, inline=0) ref2 = kerchunk.zarr.single_zarr(fn2, inline=0) - with pytest.raises(ValueError, match=r"Incompatible array chunks at index 1.*"): + with pytest.raises( + ValueError, + match=r"Cannot handle padded chunks when creating variably chunked arrays.", + ): kerchunk.combine.concatenate_arrays([ref1, ref2], path="x", check_arrays=True) @@ -139,3 +142,114 @@ def test_fail_irregular_chunk_boundaries(tmpdir): with pytest.raises(ValueError, match=r"Array at index 0 has irregular chunking.*"): kerchunk.combine.concatenate_arrays([ref1, ref2], path="x", check_arrays=True) + + +@pytest.mark.parametrize( + "arrays,axis,expected_chunks", + [ + ( + [ + zarr.array(np.arange(10), chunks=([2, 2, 2, 2, 2],)), + zarr.array(np.arange(10, 20), chunks=([3, 3, 3, 1],)), + ], + 0, + ([2, 2, 2, 2, 2, 3, 3, 3, 1],), + ), + ( + [ + zarr.array( + np.broadcast_to(np.arange(6), (10, 6)), chunks=([5, 5], [6]) + ), + zarr.array( + np.broadcast_to(np.arange(7, 10), (10, 3)), chunks=([5, 5], [3]) + ), + ], + 1, + ([5, 5], [6, 3]), + ), + ( + [ + zarr.array(np.arange(10), chunks=(5,)), + zarr.array(np.arange(10, 20), chunks=([3, 7],)), + ], + 0, + ([5, 5, 3, 7],), + ), + ( # Inferring variable chunking from fixed inputs + [ + zarr.array(np.arange(12), chunks=(6,)), + zarr.array(np.arange(12, 24), chunks=(4,)), + ], + 0, + ([6, 6, 4, 4, 4],), + ), + ], +) +def test_variable_length_chunks_success(tmpdir, arrays, axis, expected_chunks): + fns = [] + refs = [] + for i, x in enumerate(arrays): + fn = f"{tmpdir}/out{i}.zarr" + g = zarr.open(fn) + g.create_dataset("x", data=x, chunks=x.chunks) + fns.append(fn) + ref = kerchunk.zarr.single_zarr(fn, inline=0) + refs.append(ref) + + out = kerchunk.combine.concatenate_arrays( + refs, axis=axis, path="x", check_arrays=True + ) + + mapper = fsspec.get_mapper("reference://", fo=out) + g_result = zarr.open(mapper) + + assert g_result["x"].chunks == expected_chunks + np.testing.assert_array_equal( + g_result["x"][...], np.concatenate([a[...] 
for a in arrays], axis=axis) + ) + + +def test_variable_length_chunks_mismatch_chunk_failure(tmpdir): + arrays = [ + zarr.array(np.arange(12).reshape(4, 3), chunks=([2, 2], [1, 2])), + zarr.array(np.arange(12, 24).reshape(4, 3), chunks=([4], [2, 1])), + ] + axis = 0 + + fns = [] + refs = [] + for i, x in enumerate(arrays): + fn = f"{tmpdir}/out{i}.zarr" + g = zarr.open(fn) + g.create_dataset("x", data=x, chunks=x.chunks) + fns.append(fn) + ref = kerchunk.zarr.single_zarr(fn, inline=0) + refs.append(ref) + + with pytest.raises(ValueError): + kerchunk.combine.concatenate_arrays( + refs, axis=axis, path="x", check_arrays=True + ) + + +def test_fixed_to_variable_padding_failure(tmpdir): + arrays = [ + zarr.array(np.arange(12), chunks=(6,)), + zarr.array(np.arange(12, 24), chunks=(8,)), + ] + axis = 0 + + fns = [] + refs = [] + for i, x in enumerate(arrays): + fn = f"{tmpdir}/out{i}.zarr" + g = zarr.open(fn) + g.create_dataset("x", data=x, chunks=x.chunks) + fns.append(fn) + ref = kerchunk.zarr.single_zarr(fn, inline=0) + refs.append(ref) + + with pytest.raises(ValueError): + kerchunk.combine.concatenate_arrays( + refs, axis=axis, path="x", check_arrays=True + ) diff --git a/pyproject.toml b/pyproject.toml index 767b0aca..4debcdf3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ dependencies = [ "numcodecs<=0.11", "numpy", "ujson", - "zarr", + "zarr @ git+https://github.com/martindurant/zarr.git@varchunk", ] [project.optional-dependencies]
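# Usage sketch mirroring the new "inferring variable chunking from fixed inputs"
# test case above. It assumes a zarr build with variable-length chunk support
# (the pinned branch in pyproject.toml); the /tmp paths are hypothetical.
import fsspec
import numpy as np
import zarr

import kerchunk.combine
import kerchunk.zarr

refs = []
for i, (data, chunks) in enumerate([(np.arange(12), (6,)), (np.arange(12, 24), (4,))]):
    fn = f"/tmp/part{i}.zarr"
    g = zarr.open(fn)
    g.create_dataset("x", data=data, chunks=chunks)
    refs.append(kerchunk.zarr.single_zarr(fn, inline=0))

# Differing chunk sizes along the concatenation axis trigger the variable-chunk
# path; the combined array records each chunk length explicitly.
out = kerchunk.combine.concatenate_arrays(refs, axis=0, path="x", check_arrays=True)
g_result = zarr.open(fsspec.get_mapper("reference://", fo=out))
assert g_result["x"].chunks == ([6, 6, 4, 4, 4],)
np.testing.assert_array_equal(g_result["x"][...], np.arange(24))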