Skip to content

Commit

Permalink
Some more changes; still not working.
Browse files Browse the repository at this point in the history
  • Loading branch information
mmcdermott committed Sep 22, 2024
1 parent 347dd60 commit 54515dc
Showing 1 changed file with 18 additions and 13 deletions.
31 changes: 18 additions & 13 deletions src/nested_ragged_tensors/ragged_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,12 +589,9 @@ def __getitem__(self, idx: int | slice | np.ndarray):
>>> as_dense['val']
array([[3.3, 2. , 0. ]])
>>> as_dense = J[1:2, 1:].to_dense()
>>> as_dense['T']
array([[5]])
>>> as_dense['id']
array([[[3, 2, 2]]])
>>> as_dense['val']
array([[[3.3, 2. , 0. ]]])
Traceback (most recent call last):
...
ValueError: Multi-level non-int slicing is not supported.
"""
return self._slice(self._get_slice_indices(idx))

Expand Down Expand Up @@ -1267,10 +1264,17 @@ def _get_slice_indices(
case tuple() as T:
squeeze_dims = []
out_indices = {}
seen_non_int = False
for dim, idx in enumerate(T):
if seen_non_int:
raise ValueError("Multi-level non-int slicing is not supported.")

if isinstance(idx, int):
idx = slice(idx, idx + 1)
squeeze_dims.append(dim)
else:
seen_non_int = True

if not isinstance(idx, slice):
raise TypeError(
f"{type(idx)} at index {dim} not supported for "
Expand Down Expand Up @@ -1300,16 +1304,17 @@ def _get_slice_indices_internal(
for key in self.keys_at_dim(0):
out[f"dim0/{key}"] = slice(st_i, end_i)

if f"dim{starting_dim}/bounds" in out:
st_i += out[f"dim{starting_dim}/bounds"].start
if f"dim{starting_dim+1}/bounds" in out:
out_B_slice = out[f"dim{starting_dim+1}/bounds"]
st_i += out_B_slice.start
if end_i is None:
end_i = None
end_i = out_B_slice.stop
else:
end_i = end_i + out[f"dim{starting_dim}/bounds"].start
if out[f"dim{starting_dim}/bounds"].stop is not None:
end_i = min(end_i, out[f"dim{starting_dim}/bounds"].stop)
end_i += out_B_slice.start
if out_B_slice.stop is not None:
end_i = min(end_i, out_B_slice.stop)

for dim in range(max(starting_dim, 1), self.max_n_dims):
for dim in range(max(starting_dim + 1, 1), self.max_n_dims):
out[f"dim{dim}/bounds"] = slice(st_i, end_i)

with self._tensor_at_key(f"dim{dim}/bounds") as bounds:
Expand Down

0 comments on commit 54515dc

Please sign in to comment.