Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

IO: treat arrays with an empty shape like scalars when writing #1783

Merged
merged 9 commits into from
Dec 9, 2024
1 change: 1 addition & 0 deletions docs/release-notes/1783.bugfix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
`write_elem` now filters out incompatible `dataset_kwargs` when saving zero-dimensional arrays {user}`ilia-kats`
16 changes: 13 additions & 3 deletions src/anndata/_io/specs/methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from anndata._core.index import _normalize_indices
from anndata._core.merge import intersect_keys
from anndata._core.sparse_dataset import _CSCDataset, _CSRDataset, sparse_dataset
from anndata._io.utils import H5PY_V3, check_key
from anndata._io.utils import H5PY_V3, check_key, zero_dim_array_as_scalar
from anndata._warnings import OldFormatWarning
from anndata.compat import (
AwkArray,
Expand Down Expand Up @@ -382,6 +382,7 @@ def write_list(
@_REGISTRY.register_write(ZarrGroup, h5py.Dataset, IOSpec("array", "0.2.0"))
@_REGISTRY.register_write(ZarrGroup, np.ma.MaskedArray, IOSpec("array", "0.2.0"))
@_REGISTRY.register_write(ZarrGroup, ZarrArray, IOSpec("array", "0.2.0"))
@zero_dim_array_as_scalar
def write_basic(
f: GroupStorageType,
k: str,
Expand Down Expand Up @@ -477,6 +478,7 @@ def read_string_array_partial(d, items=None, indices=slice(None)):
)
@_REGISTRY.register_write(H5Group, (np.ndarray, "U"), IOSpec("string-array", "0.2.0"))
@_REGISTRY.register_write(H5Group, (np.ndarray, "O"), IOSpec("string-array", "0.2.0"))
@zero_dim_array_as_scalar
def write_vlen_string_array(
f: H5Group,
k: str,
Expand All @@ -498,6 +500,7 @@ def write_vlen_string_array(
)
@_REGISTRY.register_write(ZarrGroup, (np.ndarray, "U"), IOSpec("string-array", "0.2.0"))
@_REGISTRY.register_write(ZarrGroup, (np.ndarray, "O"), IOSpec("string-array", "0.2.0"))
@zero_dim_array_as_scalar
def write_vlen_string_array_zarr(
f: ZarrGroup,
k: str,
Expand Down Expand Up @@ -1134,8 +1137,15 @@ def write_hdf5_scalar(
):
# Can’t compress scalars, error is thrown
dataset_kwargs = dict(dataset_kwargs)
dataset_kwargs.pop("compression", None)
dataset_kwargs.pop("compression_opts", None)
for arg in (
"compression",
"compression_opts",
"chunks",
"shuffle",
"fletcher32",
"scaleoffset",
):
dataset_kwargs.pop(arg, None)
f.create_dataset(key, data=np.array(value), **dataset_kwargs)


Expand Down
31 changes: 27 additions & 4 deletions src/anndata/_io/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from functools import wraps
from functools import WRAPPER_ASSIGNMENTS, wraps
from itertools import pairwise
from typing import TYPE_CHECKING, cast
from warnings import warn
Expand All @@ -12,11 +12,12 @@
from ..compat import add_note

if TYPE_CHECKING:
from collections.abc import Callable
from typing import Literal
from collections.abc import Callable, Mapping
from typing import Any, Literal

from .._types import StorageType
from .._types import ContravariantRWAble, StorageType, _WriteInternal
from ..compat import H5Group, ZarrGroup
from .specs.registry import Writer

Storage = StorageType | BaseCompressedSparseDataset

Expand Down Expand Up @@ -285,3 +286,25 @@ def _read_legacy_raw(
if "varm" in attrs and "raw.varm" in f:
raw["varm"] = read_attr(f["raw.varm"])
return raw


def zero_dim_array_as_scalar(func: _WriteInternal):
    """\
    A decorator for write_elem implementations of arrays where zero-dimensional arrays need special handling.

    If ``elem`` has an empty shape (``elem.shape == ()``), the wrapped array
    writer is bypassed: the element is unwrapped to a scalar via ``elem[()]``
    and re-dispatched through ``_writer.write_elem`` so that the scalar IO
    spec (which filters out incompatible ``dataset_kwargs``) handles it.
    Otherwise the wrapped writer is called unchanged.
    """

    # ``__defaults__``/``__kwdefaults__`` are copied in addition to the usual
    # WRAPPER_ASSIGNMENTS so the wrapper advertises the same default argument
    # values as the wrapped writer.
    @wraps(func, assigned=WRAPPER_ASSIGNMENTS + ("__defaults__", "__kwdefaults__"))
    def func_wrapper(
        f: StorageType,
        k: str,
        elem: ContravariantRWAble,
        *,
        _writer: Writer,
        dataset_kwargs: Mapping[str, Any],
    ):
        if elem.shape == ():
            # Zero-dimensional: unwrap to a scalar and re-dispatch.
            return _writer.write_elem(f, k, elem[()], dataset_kwargs=dataset_kwargs)
        # Propagate the wrapped writer's return value instead of swallowing it.
        return func(f, k, elem, _writer=_writer, dataset_kwargs=dataset_kwargs)

    return func_wrapper
21 changes: 21 additions & 0 deletions tests/test_io_elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,27 @@ def test_io_spec(store, value, encoding_type):
assert get_spec(store[key]) == _REGISTRY.get_spec(value)


@pytest.mark.parametrize(
    ("value", "encoding_type"),
    [
        pytest.param(np.asarray(1), "numeric-scalar", id="scalar_int"),
        pytest.param(np.asarray(1.0), "numeric-scalar", id="scalar_float"),
        pytest.param(np.asarray(True), "numeric-scalar", id="scalar_bool"),
        pytest.param(np.asarray("test"), "string", id="scalar_string"),
    ],
)
def test_io_spec_compressed_scalars(store: G, value: np.ndarray, encoding_type: str):
    """Zero-dim arrays written with compression kwargs round-trip as scalars."""
    key = f"key_for_{encoding_type}"
    compression_kwargs = {"compression": "gzip", "compression_opts": 5}
    write_elem(store, key, value, dataset_kwargs=compression_kwargs)

    assert _read_attr(store[key].attrs, "encoding-type") == encoding_type

    assert_equal(value, read_elem(store[key]))


# Can't instantiate cupy types at the top level, so converting them within the test
@pytest.mark.gpu
@pytest.mark.parametrize(
Expand Down
Loading