Skip to content

Commit

Permalink
Backport PR scverse#1783: IO: treat arrays with an empty shape like s…
Browse files Browse the repository at this point in the history
…calars when writing
  • Loading branch information
ilia-kats authored and meeseeksmachine committed Dec 9, 2024
1 parent 8bf09f8 commit 793dac6
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 7 deletions.
1 change: 1 addition & 0 deletions docs/release-notes/1783.bugfix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
`write_elem` now filters out incompatible `dataset_kwargs` when saving zero-dimensional arrays {user}`ilia-kats`
16 changes: 13 additions & 3 deletions src/anndata/_io/specs/methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from anndata._core.index import _normalize_indices
from anndata._core.merge import intersect_keys
from anndata._core.sparse_dataset import _CSCDataset, _CSRDataset, sparse_dataset
from anndata._io.utils import H5PY_V3, check_key
from anndata._io.utils import H5PY_V3, check_key, zero_dim_array_as_scalar
from anndata._warnings import OldFormatWarning
from anndata.compat import (
AwkArray,
Expand Down Expand Up @@ -382,6 +382,7 @@ def write_list(
@_REGISTRY.register_write(ZarrGroup, h5py.Dataset, IOSpec("array", "0.2.0"))
@_REGISTRY.register_write(ZarrGroup, np.ma.MaskedArray, IOSpec("array", "0.2.0"))
@_REGISTRY.register_write(ZarrGroup, ZarrArray, IOSpec("array", "0.2.0"))
@zero_dim_array_as_scalar
def write_basic(
f: GroupStorageType,
k: str,
Expand Down Expand Up @@ -477,6 +478,7 @@ def read_string_array_partial(d, items=None, indices=slice(None)):
)
@_REGISTRY.register_write(H5Group, (np.ndarray, "U"), IOSpec("string-array", "0.2.0"))
@_REGISTRY.register_write(H5Group, (np.ndarray, "O"), IOSpec("string-array", "0.2.0"))
@zero_dim_array_as_scalar
def write_vlen_string_array(
f: H5Group,
k: str,
Expand All @@ -498,6 +500,7 @@ def write_vlen_string_array(
)
@_REGISTRY.register_write(ZarrGroup, (np.ndarray, "U"), IOSpec("string-array", "0.2.0"))
@_REGISTRY.register_write(ZarrGroup, (np.ndarray, "O"), IOSpec("string-array", "0.2.0"))
@zero_dim_array_as_scalar
def write_vlen_string_array_zarr(
f: ZarrGroup,
k: str,
Expand Down Expand Up @@ -1134,8 +1137,15 @@ def write_hdf5_scalar(
):
# Can’t compress scalars, error is thrown
dataset_kwargs = dict(dataset_kwargs)
dataset_kwargs.pop("compression", None)
dataset_kwargs.pop("compression_opts", None)
for arg in (
"compression",
"compression_opts",
"chunks",
"shuffle",
"fletcher32",
"scaleoffset",
):
dataset_kwargs.pop(arg, None)
f.create_dataset(key, data=np.array(value), **dataset_kwargs)


Expand Down
31 changes: 27 additions & 4 deletions src/anndata/_io/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from functools import wraps
from functools import WRAPPER_ASSIGNMENTS, wraps
from itertools import pairwise
from typing import TYPE_CHECKING, cast
from warnings import warn
Expand All @@ -12,11 +12,12 @@
from ..compat import add_note

if TYPE_CHECKING:
from collections.abc import Callable
from typing import Literal
from collections.abc import Callable, Mapping
from typing import Any, Literal

from .._types import StorageType
from .._types import ContravariantRWAble, StorageType, _WriteInternal
from ..compat import H5Group, ZarrGroup
from .specs.registry import Writer

Storage = StorageType | BaseCompressedSparseDataset

Expand Down Expand Up @@ -285,3 +286,25 @@ def _read_legacy_raw(
if "varm" in attrs and "raw.varm" in f:
raw["varm"] = read_attr(f["raw.varm"])
return raw


def zero_dim_array_as_scalar(func: _WriteInternal):
"""\
A decorator for write_elem implementations of arrays where zero-dimensional arrays need special handling.
"""

@wraps(func, assigned=WRAPPER_ASSIGNMENTS + ("__defaults__", "__kwdefaults__"))
def func_wrapper(
f: StorageType,
k: str,
elem: ContravariantRWAble,
*,
_writer: Writer,
dataset_kwargs: Mapping[str, Any],
):
if elem.shape == ():
_writer.write_elem(f, k, elem[()], dataset_kwargs=dataset_kwargs)
else:
func(f, k, elem, _writer=_writer, dataset_kwargs=dataset_kwargs)

return func_wrapper
21 changes: 21 additions & 0 deletions tests/test_io_elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,27 @@ def test_io_spec(store, value, encoding_type):
assert get_spec(store[key]) == _REGISTRY.get_spec(value)


@pytest.mark.parametrize(
("value", "encoding_type"),
[
pytest.param(np.asarray(1), "numeric-scalar", id="scalar_int"),
pytest.param(np.asarray(1.0), "numeric-scalar", id="scalar_float"),
pytest.param(np.asarray(True), "numeric-scalar", id="scalar_bool"),
pytest.param(np.asarray("test"), "string", id="scalar_string"),
],
)
def test_io_spec_compressed_scalars(store: G, value: np.ndarray, encoding_type: str):
key = f"key_for_{encoding_type}"
write_elem(
store, key, value, dataset_kwargs={"compression": "gzip", "compression_opts": 5}
)

assert encoding_type == _read_attr(store[key].attrs, "encoding-type")

from_disk = read_elem(store[key])
assert_equal(value, from_disk)


# Can't instantiate cupy types at the top level, so converting them within the test
@pytest.mark.gpu
@pytest.mark.parametrize(
Expand Down

0 comments on commit 793dac6

Please sign in to comment.