Skip to content

Commit

Permalink
Aligned code paths between getitem and load_slice.
Browse files Browse the repository at this point in the history
  • Loading branch information
mmcdermott committed Sep 21, 2024
1 parent e707006 commit ad25b8b
Showing 1 changed file with 122 additions and 99 deletions.
221 changes: 122 additions & 99 deletions src/nested_ragged_tensors/ragged_numpy.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import itertools
from collections import defaultdict
from collections.abc import Sequence
from functools import cached_property
from pathlib import Path
Expand Down Expand Up @@ -187,7 +186,7 @@ def tensors(self) -> dict[str, np.ndarray]:
@cached_property
def _tensor_keys(self) -> set[str]:
if self._tensors is None:
with safe_open(self._fp, framework="np") as f:
with safe_open(self._tensors_fp, framework="np") as f:
return set(f.keys())
else:
return set(self._tensors.keys())
Expand Down Expand Up @@ -360,6 +359,19 @@ def save(self, fp: Path):
Args:
fp: The path to which the tensors will be saved.
Examples:
>>> import tempfile
>>> J = JointNestedRaggedTensorDict({
... "T": [[1, 2, 3 ], [4, 5 ]],
... "id": [[[1, 2, 3], [3, 4], [1, 2 ]], [[3], [3, 2, 2]]],
... "val": [[[1, 0.2, 0], [3.1, 0], [1, 2.2]], [[3], [3.3, 2, 0]]],
... })
>>> with tempfile.TemporaryDirectory() as dirpath:
... fp = Path(dirpath) / "tensors.nrt"
... J.save(fp)
... J2 = JointNestedRaggedTensorDict(tensors_fp=fp)
... assert J == J2
"""
save_file(self.tensors, fp)

Expand Down Expand Up @@ -520,7 +532,7 @@ def __getitem__(self, idx: int | slice | np.ndarray):
>>> J["T"]
Traceback (most recent call last):
...
TypeError: <class 'str'> not supported for JointNestedRaggedTensorDict.__getitem__
TypeError: <class 'str'> not supported for JointNestedRaggedTensorDict slicing
>>> as_dense = J[np.array([0, 2])].to_dense()
>>> as_dense['T']
array([[1, 2, 3],
Expand All @@ -542,58 +554,9 @@ def __getitem__(self, idx: int | slice | np.ndarray):
[1. , 0. , 0. ],
[0. , 0. , 0. ]]])
"""
match idx:
case np.ndarray() as arr if arr.dtype in (NP_INT_TYPES + NP_UINT_TYPES) and arr.ndim == 1:
return self.__class__.vstack([self[int(i)] for i in arr])
case int() as i:
if self.min_n_dims == 1:
raise ValueError(
"Cannot index into a tensor collection with a 1D tensor with an integer."
)

return self[slice(i, i + 1)].squeeze(dim=0)

case slice() as S:
if S.step not in (None, 1):
raise ValueError("Only slices with step size of None or 1 are supported; got {S.step}")

st_i = 0 if S.start is None else S.start
end_i = S.stop

out_tensors = {}
for key in self.keys_at_dim(0):
out_tensors[f"dim0/{key}"] = self.tensors[f"dim0/{key}"][st_i:end_i]

for dim in range(1, self.max_n_dims):
L = self.tensors[f"dim{dim}/lengths"]
out_tensors[f"dim{dim}/lengths"] = L[st_i:end_i]

B = self.tensors[f"dim{dim}/bounds"]

if st_i == 0:
offset = 0
else:
offset = B[st_i - 1]

B = B[st_i:end_i] - offset

out_tensors[f"dim{dim}/bounds"] = B

vals_start = offset
if len(B) == 0:
vals_end = offset
else:
vals_end = B[-1] + offset

for key in self.keys_at_dim(dim):
out_tensors[f"dim{dim}/{key}"] = self.tensors[f"dim{dim}/{key}"][vals_start:vals_end]
indices = self._get_slice_indices(idx)

st_i = offset
end_i = vals_end

return JointNestedRaggedTensorDict(processed_tensors=out_tensors, schema=self.schema)
case _:
raise TypeError(f"{type(idx)} not supported for {self.__class__.__name__}.__getitem__")
return self._slice(indices)

def to_dense(self, padding_side: str = "right") -> dict[str, np.array]:
"""Returns a dense view of these ragged tensors.
Expand Down Expand Up @@ -1123,7 +1086,9 @@ def load_slice(cls, fp: Path, idx: int | slice | np.ndarray) -> JointNestedRagge
>>> J_dense = J.to_dense()
>>> J2_dense = J2.to_dense()
>>> for k in J_dense.keys():
... assert (J_dense[k] == J2_dense[k]).all(), f"Tensors at {k} unequal!"
... assert (J_dense[k] == J2_dense[k]).all(), (
... f"Tensors at {k} unequal! Want {J_dense[k]}, got {J2_dense[k]}"
... )
>>> J = JointNestedRaggedTensorDict({
... "T": [[1, 2, 3 ], [4, 5 ]],
... "id": [[[1, 2, 3], [3, 4], [1, 2 ]], [[3], [3, 2, 2]]],
Expand All @@ -1138,68 +1103,126 @@ def load_slice(cls, fp: Path, idx: int | slice | np.ndarray) -> JointNestedRagge
>>> J_dense = J.to_dense()
>>> J2_dense = J2.to_dense()
>>> for k in J_dense.keys():
... assert (J_dense[k] == J2_dense[k]).all(), f"Tensors at {k} unequal!"
... assert (J_dense[k] == J2_dense[k]).all(), (
... f"Tensors at {k} unequal! Want {J_dense[k]}, got {J2_dense[k]}"
... )
"""

J = cls(tensors_fp=fp)
indices = J._get_slice_indices(idx)

return J._slice(indices)

def _slice_single(
self, indices: dict[str, slice | np.ndarray], reduce_dim: bool = False
) -> JointNestedRaggedTensorDict:
"""Slices this collection of tensors by the given indices."""

tensors = {}
schema = {}

for k, idx in indices.items():
old_dim, key = k.split("/")
old_dim_int = int(old_dim[3:])

if reduce_dim:
if old_dim_int == 1 and key in ("lengths", "bounds"):
# These keys will be dropped as this tensor will become a 1D tensor in truth.
continue
new_key = f"dim{old_dim_int - 1}/{key}"
else:
new_key = k

match idx:
case slice() as S:
if self._tensors is None:
with safe_open(self._tensors_fp, framework="np") as f:
tensors[new_key] = f.get_slice(k)[S]
else:
tensors[new_key] = self._tensors[k][S]
case np.ndarray() as arr if arr.ndim == 1:
tensors[new_key] = arr
case _:
raise TypeError(f"{type(idx)} not supported for {self.__class__.__name__} slicing")

schema[new_key] = tensors[new_key].dtype

return self.__class__(processed_tensors=tensors, schema=schema)

def _slice(
self,
indices: tuple[dict[str, slice | np.ndarray], bool]
| list[tuple[dict[str, slice | np.ndarray], bool]],
) -> JointNestedRaggedTensorDict:
"""Returns a new JointNestedRaggedTensorDict that is a slice of this one.
Args:
indices: The indices to slice by.
Returns:
A new JointNestedRaggedTensorDict that is a slice of this one.
Examples:
"""
match indices:
case tuple() as T:
return self._slice_single(*T)
case dict():
return self._slice_single(indices, reduce_dim=False)
case list():
return self.__class__.vstack([self._slice_single(idx, reduce_dim=True) for idx in indices])
case _:
raise TypeError(f"{type(indices)} not supported for {self.__class__.__name__} slicing")

def _get_slice_indices(
self, idx: int | slice | np.ndarray
) -> tuple[dict[str, slice], bool] | list[dict[str, slice]] | dict[str, slice]:
"""Returns the start and end indices for each dimension of self after slicing by idx."""

match idx:
case np.ndarray() as arr if arr.dtype in (NP_INT_TYPES + NP_UINT_TYPES) and arr.ndim == 1:
return cls.vstack([cls.load_slice(fp, int(i)) for i in arr])
return [self._get_slice_indices(slice(i, i + 1)) for i in arr]
case int() as i:
return cls.load_slice(fp, slice(i, i + 1))[0]
return (self._get_slice_indices(slice(i, i + 1)), True)
case slice() as S:
st_i = 0 if S.start is None else S.start
end_i = S.stop

if S.step not in (None, 1):
raise ValueError("Only slices with step size of None or 1 are supported; got {S.step}")

tensors = {}
schema = {}
with safe_open(fp, framework="np") as f:
keys_by_dim = defaultdict(list)

for k in f.keys():
if cls._is_meta_key(k):
continue

keys_by_dim[cls._get_dim_from_key_str(k)].append(k)

max_n_dims = max(keys_by_dim.keys()) + 1
out = {}
for key in self.keys_at_dim(0):
out[f"dim0/{key}"] = slice(st_i, end_i)

for key in keys_by_dim[0]:
v = f.get_slice(k)[st_i:end_i]
schema[k] = v.dtype
tensors[k] = v
for dim in range(1, self.max_n_dims):
out[f"dim{dim}/lengths"] = slice(st_i, end_i)

for dim in range(1, max_n_dims):
try:
tensors[f"dim{dim}/lengths"] = f.get_slice(f"dim{dim}/lengths")[st_i:end_i]
except SystemError as e:
raise ValueError(
f"Error loading lengths for dim {dim} with st_i={st_i} and end_i={end_i}"
) from e
if st_i == 0:
B_slice = slice(0, end_i)
else:
B_slice = slice(st_i - 1, end_i)

if st_i == 0:
offset = 0
B = f.get_slice(f"dim{dim}/bounds")[st_i:end_i]
else:
B = f.get_slice(f"dim{dim}/bounds")[st_i - 1 : end_i]
offset = B[0]
B = B[1:] - offset
if self._tensors is None:
with safe_open(self._tensors_fp, framework="np") as f:
B = f.get_slice(f"dim{dim}/bounds")[B_slice]
else:
B = self.tensors[f"dim{dim}/bounds"][B_slice]

vals_start = offset
vals_end = B[-1] + offset
if st_i == 0:
offset = 0
else:
offset = B[0]
B = B[1:] - offset

tensors[f"dim{dim}/bounds"] = B
st_i = offset
end_i = (B[-1] + offset) if len(B) > 0 else offset

for k in keys_by_dim[dim]:
v = f.get_slice(k)[vals_start:vals_end]
schema[k] = v.dtype
tensors[k] = v # np.split(v, B[:-1])
for key in self.keys_at_dim(dim):
out[f"dim{dim}/{key}"] = slice(st_i, end_i)

st_i = 0 if st_i == 0 else offset
end_i = B[-1] + offset
out[f"dim{dim}/bounds"] = B

return cls(processed_tensors=tensors, schema=schema)
return out
case _:
raise TypeError(f"{type(idx)} not supported for {cls.__name__}.load_slice")
raise TypeError(f"{type(idx)} not supported for {self.__class__.__name__} slicing")

0 comments on commit ad25b8b

Please sign in to comment.