From 85a3b11d9d13d13e0016aebe68ca4f6f54bfea7f Mon Sep 17 00:00:00 2001 From: Nassim Date: Sat, 21 Sep 2024 19:48:41 -0400 Subject: [PATCH] added caching --- src/nested_ragged_tensors/ragged_numpy.py | 33 ++++++++++++++++------- 1 file changed, 24 insertions(+), 9 deletions(-) diff --git a/src/nested_ragged_tensors/ragged_numpy.py b/src/nested_ragged_tensors/ragged_numpy.py index 41382bf..5a9356a 100644 --- a/src/nested_ragged_tensors/ragged_numpy.py +++ b/src/nested_ragged_tensors/ragged_numpy.py @@ -4,6 +4,7 @@ from collections.abc import Sequence from functools import cached_property from pathlib import Path +from functools import lru_cache import numpy as np from safetensors import safe_open @@ -1200,6 +1201,12 @@ def load_slice(cls, fp: Path, idx: int | slice | np.ndarray) -> JointNestedRagge return J._slice(indices) + @staticmethod + @lru_cache(maxsize=None) + def _cached_slice(tensors_fp, k, idx): + with safe_open(tensors_fp, framework="np") as f: + return f.get_slice(k)[idx] + def _slice_single( self, indices: dict[str, slice | np.ndarray], reduce_dim: bool = False ) -> JointNestedRaggedTensorDict: @@ -1223,8 +1230,7 @@ def _slice_single( match idx: case slice() as S: if self._tensors is None: - with safe_open(self._tensors_fp, framework="np") as f: - tensors[new_key] = f.get_slice(k)[S] + tensors[new_key] = self._cached_slice(self._tensors_fp, k, S) else: tensors[new_key] = self._tensors[k][S] case np.ndarray() as arr if arr.ndim == 1: @@ -1261,6 +1267,19 @@ def _slice( case _: raise TypeError(f"{type(indices)} not supported for {self.__class__.__name__} slicing") + + @classmethod + @lru_cache(maxsize=None) + def _cached_bounds(cls, tensors_fp, dim): + with safe_open(tensors_fp, framework="np") as f: + return f.get_slice(f"dim{dim}/bounds") + + def _get_bounds(self, dim): + if self._tensors is None: + return self._cached_bounds(self._tensors_fp, dim) + else: + return self.tensors[f"dim{dim}/bounds"] + def _get_slice_indices( self, idx: int | slice | np.ndarray ) -> tuple[dict[str, slice], bool] | list[dict[str, slice]] | dict[str, slice]: @@ -1276,7 +1295,7 @@ def _get_slice_indices( end_i = S.stop if S.step not in (None, 1): - raise ValueError("Only slices with step size of None or 1 are supported; got {S.step}") + raise ValueError(f"Only slices with step size of None or 1 are supported; got {S.step}") out = {} for key in self.keys_at_dim(0): @@ -1290,11 +1309,7 @@ def _get_slice_indices( else: B_slice = slice(st_i - 1, end_i) - if self._tensors is None: - with safe_open(self._tensors_fp, framework="np") as f: - B = f.get_slice(f"dim{dim}/bounds")[B_slice] - else: - B = self.tensors[f"dim{dim}/bounds"][B_slice] + B = self._get_bounds(dim)[B_slice] if st_i == 0: offset = 0 @@ -1312,4 +1327,4 @@ def _get_slice_indices( return out case _: - raise TypeError(f"{type(idx)} not supported for {self.__class__.__name__} slicing") + raise TypeError(f"{type(idx)} not supported for {self.__class__.__name__} slicing") \ No newline at end of file