Skip to content

Commit

Permalink
added caching
Browse files Browse the repository at this point in the history
  • Loading branch information
Oufattole committed Sep 21, 2024
1 parent ba23512 commit 85a3b11
Showing 1 changed file with 24 additions and 9 deletions.
33 changes: 24 additions & 9 deletions src/nested_ragged_tensors/ragged_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from collections.abc import Sequence
from functools import cached_property
from pathlib import Path
from functools import lru_cache

import numpy as np
from safetensors import safe_open
Expand Down Expand Up @@ -1200,6 +1201,12 @@ def load_slice(cls, fp: Path, idx: int | slice | np.ndarray) -> JointNestedRagge

return J._slice(indices)

@staticmethod
@lru_cache(maxsize=None)
def _cached_slice(tensors_fp, k, idx):
with safe_open(tensors_fp, framework="np") as f:
return f.get_slice(k)[idx]

def _slice_single(
self, indices: dict[str, slice | np.ndarray], reduce_dim: bool = False
) -> JointNestedRaggedTensorDict:
Expand All @@ -1223,8 +1230,7 @@ def _slice_single(
match idx:
case slice() as S:
if self._tensors is None:
with safe_open(self._tensors_fp, framework="np") as f:
tensors[new_key] = f.get_slice(k)[S]
tensors[new_key] = self._cached_slice(self._tensors_fp, k, S)
else:
tensors[new_key] = self._tensors[k][S]
case np.ndarray() as arr if arr.ndim == 1:
Expand Down Expand Up @@ -1261,6 +1267,19 @@ def _slice(
case _:
raise TypeError(f"{type(indices)} not supported for {self.__class__.__name__} slicing")


@classmethod
@lru_cache(maxsize=None)
def _cached_bounds(cls, tensors_fp, dim):
with safe_open(tensors_fp, framework="np") as f:
return f.get_slice(f"dim{dim}/bounds")

def _get_bounds(self, dim):
if self._tensors is None:
return self._cached_bounds(self._tensors_fp, dim)
else:
return self.tensors[f"dim{dim}/bounds"]

def _get_slice_indices(
self, idx: int | slice | np.ndarray
) -> tuple[dict[str, slice], bool] | list[dict[str, slice]] | dict[str, slice]:
Expand All @@ -1276,7 +1295,7 @@ def _get_slice_indices(
end_i = S.stop

if S.step not in (None, 1):
raise ValueError("Only slices with step size of None or 1 are supported; got {S.step}")
raise ValueError(f"Only slices with step size of None or 1 are supported; got {S.step}")

out = {}
for key in self.keys_at_dim(0):
Expand All @@ -1290,11 +1309,7 @@ def _get_slice_indices(
else:
B_slice = slice(st_i - 1, end_i)

if self._tensors is None:
with safe_open(self._tensors_fp, framework="np") as f:
B = f.get_slice(f"dim{dim}/bounds")[B_slice]
else:
B = self.tensors[f"dim{dim}/bounds"][B_slice]
B = self._get_bounds(dim)[B_slice]

if st_i == 0:
offset = 0
Expand All @@ -1312,4 +1327,4 @@ def _get_slice_indices(

return out
case _:
raise TypeError(f"{type(idx)} not supported for {self.__class__.__name__} slicing")
raise TypeError(f"{type(idx)} not supported for {self.__class__.__name__} slicing")

0 comments on commit 85a3b11

Please sign in to comment.