diff --git a/src/nested_ragged_tensors/ragged_numpy.py b/src/nested_ragged_tensors/ragged_numpy.py
index 81fb70d..33c5746 100644
--- a/src/nested_ragged_tensors/ragged_numpy.py
+++ b/src/nested_ragged_tensors/ragged_numpy.py
@@ -87,8 +87,7 @@ def __init__(
         >>> print(J) # doctest: +NORMALIZE_WHITESPACE
         JointNestedRaggedTensorDict({'dim1/lengths': array([3, 2]),
             'dim1/bounds': array([3, 5]),
-            'dim1/A': [array([1, 2, 3], dtype=uint8),
-                       array([4, 5], dtype=uint8)],
+            'dim1/A': array([1, 2, 3, 4, 5], dtype=uint8),
             'dim0/B': array([1, 2], dtype=uint8)},
             schema={'A': <class 'numpy.uint8'>, 'B': <class 'numpy.uint8'>},
             pre_raggedified=True)
@@ -253,10 +252,9 @@ def _initialize_tensors(self, tensors: dict[str, list[NESTED_NUM_LIST] | NESTED_
                 except TypeError as e:
                     raise ValueError(f"Failed to parse {k} as a nested list of numbers!") from e
 
+            flat_vals = list(itertools.chain.from_iterable(vals))
             if k not in self.schema:
-                self.schema[k] = self._infer_dtype(list(itertools.chain.from_iterable(vals)))
-
-            vals = [np.array(v, dtype=self.schema[k]) for v in vals]
+                self.schema[k] = self._infer_dtype(flat_vals)
 
             dim_str = "dim0"
             for i, L in enumerate(lengths):
@@ -270,7 +268,7 @@ def _initialize_tensors(self, tensors: dict[str, list[NESTED_NUM_LIST] | NESTED_
                 self.tensors[lengths_key] = L
                 self.tensors[f"{dim_str}/bounds"] = np.cumsum(L, axis=0)
 
-            self.tensors[f"{dim_str}/{k}"] = vals
+            self.tensors[f"{dim_str}/{k}"] = np.array(flat_vals, dtype=self.schema[k])
 
     def save(self, fp: Path):
         """Saves the tensor to a file. See `JointNestedRaggedTensorDict.load` for examples.
@@ -319,16 +317,9 @@ def load(cls, fp: Path) -> JointNestedRaggedTensorDict:
 
         tensors = {}
         schema = {}
         for k, v in flat_vals_tensors.items():
-            if cls._is_meta_key(k):
-                tensors[k] = v
-            else:
+            tensors[k] = v
+            if not cls._is_meta_key(k):
                 schema[k] = v.dtype
-                dim_str = k.split("/")[0]
-                if dim_str == "dim0":
-                    tensors[k] = v
-                else:
-                    bounds = flat_vals_tensors[f"{dim_str}/bounds"]
-                    tensors[k] = np.split(v, bounds[:-1])
 
         return cls(tensors, schema=schema, pre_raggedified=True)
@@ -531,10 +522,7 @@ def __getitem__(self, idx: int | slice | np.ndarray):
                         continue
 
                     new_key = f"dim{dim_int - 1}/{key}"
-                    if dim_int == 1:
-                        out_tensors[new_key] = T[0]
-                    else:
-                        out_tensors[new_key] = T
+                    out_tensors[new_key] = T
 
                 return self.__class__(out_tensors, schema=self.schema, pre_raggedified=True)
             case slice() as S:
@@ -552,9 +540,6 @@ def __getitem__(self, idx: int | slice | np.ndarray):
                     L = self.tensors[f"dim{dim}/lengths"]
                     out_tensors[f"dim{dim}/lengths"] = L[st_i:end_i]
 
-                    for key in self.keys_at_dim(dim):
-                        out_tensors[f"dim{dim}/{key}"] = self.tensors[f"dim{dim}/{key}"][st_i:end_i]
-
                     B = self.tensors[f"dim{dim}/bounds"]
 
                     if st_i == 0:
@@ -562,10 +547,21 @@ def __getitem__(self, idx: int | slice | np.ndarray):
                     else:
                         offset = B[st_i - 1]
 
-                    out_tensors[f"dim{dim}/bounds"] = B[st_i:end_i] - offset
+                    B = B[st_i:end_i] - offset
+
+                    out_tensors[f"dim{dim}/bounds"] = B
+
+                    vals_start = offset
+                    if len(B) == 0:
+                        vals_end = offset
+                    else:
+                        vals_end = B[-1] + offset
+
+                    for key in self.keys_at_dim(dim):
+                        out_tensors[f"dim{dim}/{key}"] = self.tensors[f"dim{dim}/{key}"][vals_start:vals_end]
 
-                    st_i = 0 if st_i == 0 else B[st_i - 1]
-                    end_i = B[end_i - 1] if end_i is not None else B[-1]
+                    st_i = offset
+                    end_i = vals_end
 
                 return JointNestedRaggedTensorDict(out_tensors, schema=self.schema, pre_raggedified=True)
             case _:
@@ -662,6 +658,7 @@ def to_dense(self, padding_side: str = "right") -> dict[str, np.array]:
             ...
         ValueError: padding_side must be 'left' or 'right'; got 'up'
         """
+        out = {key: self.tensors[f"dim0/{key}"] for key in self.keys_at_dim(0)}
 
         shape = [len(self)]
 
@@ -705,8 +702,10 @@ def pad_slice(ln: int, max_ln: int) -> slice:
             for idx, ln in zip(indices, L):
                 out[f"dim{dim}/mask"][idx + (pad_slice(ln, max_ln),)] = True
 
+            bounds = self.tensors[f"dim{dim}/bounds"]
             for key in self.keys_at_dim(dim):
-                slice_vals = self.tensors[f"dim{dim}/{key}"]
+                slice_vals = np.split(self.tensors[f"dim{dim}/{key}"], bounds[:-1])
+
                 if not slice_vals:
                     continue
 
@@ -761,7 +760,7 @@ def unsqueeze(self, dim: int) -> JointNestedRaggedTensorDict:
         out_tensors = {}
 
         for key in self.keys_at_dim(0):
-            out_tensors[f"dim1/{key}"] = [self.tensors[f"dim0/{key}"]]
+            out_tensors[f"dim1/{key}"] = self.tensors[f"dim0/{key}"]
 
         if self.keys_at_dim(0):
             lengths = np.array([len(self.tensors[f"dim0/{key}"])])
@@ -981,16 +980,15 @@ def concatenate(cls, tensors: list) -> JointNestedRaggedTensorDict:
                     out_tensors[bounds_key] = np.concatenate(
                         (out_tensors[bounds_key], T.tensors[bounds_key] + last_bound)
                     )
-
-                    for key in out_keys_at_dim[dim]:
-                        out_tensors[f"dim{dim}/{key}"] = (
-                            out_tensors[f"dim{dim}/{key}"] + T.tensors[f"dim{dim}/{key}"]
-                        )
-                elif dim == 0:
-                    for key in out_keys_at_dim[dim]:
-                        out_tensors[f"dim{dim}/{key}"] = np.concatenate(
-                            (out_tensors[f"dim{dim}/{key}"], T.tensors[f"dim{dim}/{key}"]), axis=0
-                        )
+                for key in out_keys_at_dim[dim]:
+                    k_str = f"dim{dim}/{key}"
+                    try:
+                        out_tensors[k_str] = np.concatenate((out_tensors[k_str], T.tensors[k_str]), axis=0)
+                    except Exception as e:
+                        raise ValueError(
+                            f"Failed to concatenate {key} at dim {dim} with args "
+                            f"{out_tensors[k_str]} and {T.tensors[k_str]}"
+                        ) from e
         return cls(out_tensors, pre_raggedified=True, schema=out_schema)
 
     @classmethod
@@ -1104,7 +1102,7 @@ def load_slice(cls, fp: Path, idx: int | slice | np.ndarray) -> JointNestedRagge
                 for k in keys_by_dim[dim]:
                     v = f.get_slice(k)[vals_start:vals_end]
                     schema[k] = v.dtype
-                    tensors[k] = np.split(v, B[:-1])
+                    tensors[k] = v # np.split(v, B[:-1])
 
                 st_i = 0 if st_i == 0 else offset
                 end_i = B[-1] + offset
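
For context on the core change: ragged values are now stored as one flat array per key plus the existing `bounds` (cumulative-lengths) array, instead of a Python list of per-row arrays. Below is a minimal plain-NumPy sketch of that layout and of the offset/`vals_start`/`vals_end` arithmetic the new `__getitem__` uses; all names here are illustrative, not the library's API.

```python
import numpy as np

rows = [[1, 2, 3], [4, 5]]                      # ragged input
lengths = np.array([len(r) for r in rows])      # [3, 2]  -> "dim1/lengths"
bounds = np.cumsum(lengths)                     # [3, 5]  -> "dim1/bounds"
flat = np.array([x for r in rows for x in r])   # [1, 2, 3, 4, 5] -> "dim1/A"

# Row i is recovered by slicing the flat array at the bounds:
def row(i: int) -> np.ndarray:
    start = 0 if i == 0 else bounds[i - 1]
    return flat[start : bounds[i]]

assert np.array_equal(row(0), [1, 2, 3])
assert np.array_equal(row(1), [4, 5])

# A row slice st:end maps to ONE contiguous value slice, mirroring the
# re-based bounds in the patched __getitem__ slice branch:
st, end = 1, 2
offset = 0 if st == 0 else bounds[st - 1]       # values consumed by rows < st
new_bounds = bounds[st:end] - offset            # re-based bounds: [2]
vals_end = offset if len(new_bounds) == 0 else new_bounds[-1] + offset
assert np.array_equal(flat[offset:vals_end], [4, 5])
```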
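
Densification then happens lazily: `to_dense` calls `np.split` at the bounds only when a padded matrix is actually needed, rather than keeping per-row arrays resident. A sketch of that padding step under the same illustrative layout as above:

```python
import numpy as np

flat = np.array([1, 2, 3, 4, 5])
bounds = np.array([3, 5])
lengths = np.diff(bounds, prepend=0)            # [3, 2]

# Splitting at all but the last bound yields the per-row arrays on demand:
dense = np.zeros((len(bounds), lengths.max()), dtype=flat.dtype)
for i, r in enumerate(np.split(flat, bounds[:-1])):
    dense[i, : len(r)] = r                      # right-side padding

# dense == [[1, 2, 3],
#           [4, 5, 0]]
```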