Skip to content

Commit

Permalink
Revert "A further attempt using np.put"
Browse files Browse the repository at this point in the history
This reverts commit b5f3c44. This change, despite using more numpy specific logic, had major negative consequences on overall performance.
  • Loading branch information
mmcdermott committed Sep 12, 2024
1 parent f687e9e commit f2e0466
Showing 1 changed file with 5 additions and 12 deletions.
17 changes: 5 additions & 12 deletions src/nested_ragged_tensors/ragged_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -683,16 +683,6 @@ def pad_slice(ln: int, max_ln: int) -> slice:
offset = offset_fn(ln, max_ln)
return slice(offset, offset + ln)

def flat_indices(shape: list[int], indices: list[tuple], L: np.ndarray) -> np.ndarray:
orig_indices = []
for idx, ln in zip(indices, L):
sl = pad_slice(ln, shape[-1])
for i in range(sl.start, sl.stop):
orig_indices.append(idx + (i,))

multi_index = tuple(np.array(x) for x in zip(*orig_indices))
return np.ravel_multi_index(multi_index, shape)

for dim in range(1, self.max_n_dims):
old_max_ln = max(L)
indices = list(
Expand All @@ -712,13 +702,16 @@ def flat_indices(shape: list[int], indices: list[tuple], L: np.ndarray) -> np.nd
for idx, ln in zip(indices, L):
out[f"dim{dim}/mask"][idx + (pad_slice(ln, max_ln),)] = True

flat_idx = flat_indices(shape, indices, L)
bounds = self.tensors[f"dim{dim}/bounds"]
for key in self.keys_at_dim(dim):
if len(self.tensors[f"dim{dim}/{key}"]) == 0:
continue

out[key] = np.zeros(shape=tuple(shape), dtype=self.tensors[f"dim{dim}/{key}"].dtype)
np.put(out[key], flat_idx, self.tensors[f"dim{dim}/{key}"])
st = 0
for idx, ln, b in zip(indices, L, bounds):
out[key][idx + (pad_slice(ln, max_ln),)] = self.tensors[f"dim{dim}/{key}"][st:b]
st = b

return out

Expand Down

0 comments on commit f2e0466

Please sign in to comment.