Skip to content

Commit

Permalink
recursive Group.members
Browse files Browse the repository at this point in the history
This PR adds a recursive=True flag to Group.members, for recursively
listing the members of some hierarchy.

This is useful for Consolidated Metadata, which needs to recursively
inspect children. IMO, it's useful (and simple) enough to include
in the public API.
  • Loading branch information
TomAugspurger committed Aug 25, 2024
1 parent 90940a0 commit 8ee89f4
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 14 deletions.
53 changes: 43 additions & 10 deletions src/zarr/core/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,20 +424,43 @@ async def update_attributes(self, new_attributes: dict[str, Any]) -> AsyncGroup:
def __repr__(self) -> str:
return f"<AsyncGroup {self.store_path}>"

async def nmembers(self) -> int:
async def nmembers(self, recursive: bool = False) -> int:
"""
Count the number of members in this group.
Parameters
----------
recursive : bool, default False
Whether to recursively count arrays and groups in child groups of
this Group. By default, just immediate child array and group members
are counted.
Returns
-------
count : int
"""
# TODO: consider using aioitertools.builtins.sum for this
# return await aioitertools.builtins.sum((1 async for _ in self.members()), start=0)
n = 0
async for _ in self.members():
async for _ in self.members(recursive=recursive):
n += 1
return n

async def members(self) -> AsyncGenerator[tuple[str, AsyncArray | AsyncGroup], None]:
async def members(
self, recursive: bool = False
) -> AsyncGenerator[tuple[str, AsyncArray | AsyncGroup], None]:
"""
Returns an AsyncGenerator over the arrays and groups contained in this group.
This method requires that `store_path.store` supports directory listing.
The results are not guaranteed to be ordered.
Parameters
----------
recursive : bool, default False
Whether to recursively include arrays and groups in child groups of
this Group. By default, just immediate child array and group members
are included.
"""
if not self.store_path.store.supports_listing:
msg = (
Expand All @@ -456,7 +479,19 @@ async def members(self) -> AsyncGenerator[tuple[str, AsyncArray | AsyncGroup], N
if key in _skip_keys:
continue
try:
yield (key, await self.getitem(key))
obj = await self.getitem(key)
yield (key, obj)

if (
recursive
and hasattr(obj.metadata, "node_type")
and obj.metadata.node_type == "group"
):
# the assert is just for mypy to know that `obj.metadata.node_type`
# implies an AsyncGroup, not an AsyncArray
assert isinstance(obj, AsyncGroup)
async for child_key, val in obj.members(recursive=recursive):
yield "/".join([key, child_key]), val
except KeyError:
# keyerror is raised when `key` names an object (in the object storage sense),
# as opposed to a prefix, in the store under the prefix associated with this group
Expand Down Expand Up @@ -628,17 +663,15 @@ def update_attributes(self, new_attributes: dict[str, Any]) -> Group:
self._sync(self._async_group.update_attributes(new_attributes))
return self

@property
def nmembers(self) -> int:
return self._sync(self._async_group.nmembers())
def nmembers(self, recursive: bool = False) -> int:
return self._sync(self._async_group.nmembers(recursive=recursive))

@property
def members(self) -> tuple[tuple[str, Array | Group], ...]:
def members(self, recursive: bool = False) -> tuple[tuple[str, Array | Group], ...]:
"""
Return the sub-arrays and sub-groups of this group as a tuple of (name, array | group)
pairs
"""
_members = self._sync_iter(self._async_group.members())
_members = self._sync_iter(self._async_group.members(recursive=recursive))

result = tuple(map(lambda kv: (kv[0], _parse_async_node(kv[1])), _members))
return result
Expand Down
52 changes: 48 additions & 4 deletions tests/v3/test_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def test_group_members(store: MemoryStore | LocalStore, zarr_format: ZarrFormat)
members_expected["subgroup"] = group.create_group("subgroup")
# make a sub-subgroup, to ensure that the non-recursive children calculation
# doesn't go too deep in the hierarchy
_ = members_expected["subgroup"].create_group("subsubgroup") # type: ignore
subsubgroup = members_expected["subgroup"].create_group("subsubgroup") # type: ignore

members_expected["subarray"] = group.create_array(
"subarray", shape=(100,), dtype="uint8", chunk_shape=(10,), exists_ok=True
Expand All @@ -101,7 +101,13 @@ def test_group_members(store: MemoryStore | LocalStore, zarr_format: ZarrFormat)
# this creates a directory with a random key in it
# this should not show up as a member
sync(store.set(f"{path}/extra_directory/extra_object-2", Buffer.from_bytes(b"000000")))
members_observed = group.members
members_observed = group.members()
# members are not guaranteed to be ordered, so sort before comparing
assert sorted(dict(members_observed)) == sorted(members_expected)

# recursive=True
members_observed = group.members(recursive=True)
members_expected["subgroup/subsubgroup"] = subsubgroup
# members are not guaranteed to be ordered, so sort before comparing
assert sorted(dict(members_observed)) == sorted(members_expected)

Expand Down Expand Up @@ -349,7 +355,8 @@ def test_group_create_array(
if method == "create_array":
array = group.create_array(name="array", shape=shape, dtype=dtype, data=data)
elif method == "array":
array = group.array(name="array", shape=shape, dtype=dtype, data=data)
with pytest.warns(DeprecationWarning):
array = group.array(name="array", shape=shape, dtype=dtype, data=data)
else:
raise AssertionError

Expand All @@ -358,7 +365,7 @@ def test_group_create_array(
with pytest.raises(ContainsArrayError):
group.create_array(name="array", shape=shape, dtype=dtype, data=data)
elif method == "array":
with pytest.raises(ContainsArrayError):
with pytest.raises(ContainsArrayError), pytest.warns(DeprecationWarning):
group.array(name="array", shape=shape, dtype=dtype, data=data)
assert array.shape == shape
assert array.dtype == np.dtype(dtype)
Expand Down Expand Up @@ -653,3 +660,40 @@ async def test_asyncgroup_update_attributes(

agroup_new_attributes = await agroup.update_attributes(attributes_new)
assert agroup_new_attributes.attrs == attributes_new


async def test_group_members_async(store: LocalStore | MemoryStore):
    """Exercise ``AsyncGroup.members`` and ``AsyncGroup.nmembers`` on a nested hierarchy.

    Builds ``root/{a0, g0/{a1, g1/{a2, g2}}}`` and checks both the default
    listing (immediate children only) and the ``recursive=True`` listing,
    whose keys are "/"-joined paths relative to the root group.
    """
    root = AsyncGroup(
        GroupMetadata(),
        store_path=StorePath(store=store, path="root"),
    )

    # Level 1: one array and one group directly under the root.
    a0 = await root.create_array("a0", (1,))
    g0 = await root.create_group("g0")
    # Level 2: contents of g0.
    a1 = await g0.create_array("a1", (1,))
    g1 = await g0.create_group("g1")
    # Level 3: contents of g1.
    a2 = await g1.create_array("a2", (1,))
    g2 = await g1.create_group("g2")

    def by_name(member):
        # Members are (name, node) pairs; ordering is not guaranteed by
        # `members`, so sort on the name before comparing.
        return member[0]

    # Default: only the immediate children are reported.
    direct = [member async for member in root.members()]
    direct.sort(key=by_name)
    assert direct == [("a0", a0), ("g0", g0)]

    direct_count = await root.nmembers()
    assert direct_count == 2

    # recursive=True: every descendant is reported, keyed by its relative path.
    nested = [member async for member in root.members(recursive=True)]
    nested.sort(key=by_name)
    assert nested == [
        ("a0", a0),
        ("g0", g0),
        ("g0/a1", a1),
        ("g0/g1", g1),
        ("g0/g1/a2", a2),
        ("g0/g1/g2", g2),
    ]

    nested_count = await root.nmembers(recursive=True)
    assert nested_count == 6

0 comments on commit 8ee89f4

Please sign in to comment.