Reworked load slice to use a helper. Not sure if this is the right direction yet.
mmcdermott committed Sep 11, 2024
1 parent d239733 commit 1367ac3
Showing 1 changed file with 188 additions and 67 deletions.
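For orientation, here is a rough sketch of how the reworked slicing path is meant to be used, pieced together from the doctests in this diff; the file path below is illustrative, not from the repository.

    from pathlib import Path

    from nested_ragged_tensors.ragged_numpy import JointNestedRaggedTensorDict

    fp = Path("tensors.safetensors")  # hypothetical path; any writable location works

    # Build and save a small ragged tensor collection (data taken from the doctests).
    J = JointNestedRaggedTensorDict({
        "T": [[1, 2, 3], [4, 5]],
        "id": [[[1, 2, 3], [3, 4], [1, 2]], [[3], [3, 2, 2]]],
    })
    J.save(fp)

    # After this commit, load_slice builds a lazy, file-backed instance and routes
    # through the new _get_slice_indices/_slice helpers.
    J2 = JointNestedRaggedTensorDict.load_slice(fp, slice(0, 1))
    dense = J2.to_dense()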
255 changes: 188 additions & 67 deletions src/nested_ragged_tensors/ragged_numpy.py
@@ -1,8 +1,8 @@
from __future__ import annotations

import itertools
- from collections import defaultdict
from collections.abc import Sequence
+ from functools import cached_property
from pathlib import Path

import numpy as np
@@ -67,17 +67,19 @@ class JointNestedRaggedTensorDict:

def __init__(
self,
- tensors: dict[str, list[NESTED_NUM_LIST] | NESTED_NUM_LIST],
+ tensors: dict[str, list[NESTED_NUM_LIST] | NESTED_NUM_LIST] | None = None,
schema: dict[str, np.dtype] | None = None,
- pre_raggedified: bool = False,
+ pre_raggedified: bool | None = None,
+ fp: Path | None = None,
):
"""Initializes JointNestedRaggedTensorDict with the given tensors.
Args:
- tensors: The tensors to be stored.
+ tensors: The tensors to be stored. If `None`, then `fp` must be provided.
schema: The schema for the tensors, if known.
pre_raggedified: If `True`, the tensors are assumed to be pre-raggedified and are stored as-is. If
- `False`, the tensors are assumed to be raw data and are raggedified.
+ `False`, the tensors are assumed to be raw data and are raggedified. If `fp` is provided, this
+ must be `True`.
Examples:
>>> J = JointNestedRaggedTensorDict({
@@ -98,11 +100,36 @@ def __init__(
ValueError: Empty list found for key S! Nested Ragged Tensors does not support empty tensors.
"""

- self.schema = schema if schema is not None else {}
- if pre_raggedified:
-     self.tensors = tensors
- else:
-     self._initialize_tensors(tensors)
if tensors is None and fp is None:
raise ValueError("Either tensors or fp must be provided!")

if pre_raggedified is None:
pre_raggedified = fp is not None
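# Default: tensors loaded from a file are stored pre-raggedified; raw dicts are not.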

if fp is None:
self.schema = schema if schema is not None else {}
self._fp = None
if pre_raggedified:
self._tensors = tensors
else:
self._initialize_tensors(tensors)
else:
self._fp = fp
self._tensors = None

@property
def tensors(self) -> dict[str, np.ndarray]:
if self._tensors is None:
self._tensors = self.load(self._fp)._tensors
return self._tensors

@cached_property
def _tensor_keys(self) -> set[str]:
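# When file-backed, read only the key names from the safetensors header rather
# than materializing any tensor data.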
if self._tensors is None:
with safe_open(self._fp, framework="np") as f:
return set(f.keys())
else:
    return set(self._tensors.keys())

def __repr__(self) -> str:
return f"JointNestedRaggedTensorDict({self.tensors}, schema={self.schema}, pre_raggedified=True)"
@@ -233,7 +260,7 @@ def _infer_dtype(cls, vals: Sequence[NUM_T]) -> np.dtype:

def _initialize_tensors(self, tensors: dict[str, list[NESTED_NUM_LIST] | NESTED_NUM_LIST]):
"""Initializes the tensors from lists of raw data entries."""
- self.tensors = {}
+ self._tensors = {}
for k, T in tensors.items():
if len(T) == 0:
raise ValueError(
@@ -244,7 +271,7 @@ def _initialize_tensors(self, tensors: dict[str, list[NESTED_NUM_LIST] | NESTED_
dim_str = "dim0"
if k not in self.schema:
self.schema[k] = self._infer_dtype(T)
- self.tensors[f"{dim_str}/{k}"] = np.array(T, dtype=self.schema[k])
+ self._tensors[f"{dim_str}/{k}"] = np.array(T, dtype=self.schema[k])
continue

try:
@@ -263,21 +290,25 @@ def _initialize_tensors(self, tensors: dict[str, list[NESTED_NUM_LIST] | NESTED_
dim_str = f"dim{i+1}"

lengths_key = f"{dim_str}/lengths"
- if lengths_key in self.tensors:
-     if not np.array_equal(self.tensors[lengths_key], L):
-         raise ValueError(f"Inconsistent lengths tensors! {self.tensors[lengths_key]} vs. {L}")
+ if lengths_key in self._tensors:
+     if not np.array_equal(self._tensors[lengths_key], L):
+         raise ValueError(
+             f"Inconsistent lengths tensors! {self._tensors[lengths_key]} vs. {L}"
+         )
else:
-     self.tensors[lengths_key] = L
-     self.tensors[f"{dim_str}/bounds"] = np.cumsum(L, axis=0)
+     self._tensors[lengths_key] = L
+     self._tensors[f"{dim_str}/bounds"] = np.cumsum(L, axis=0)

- self.tensors[f"{dim_str}/{k}"] = vals
+ self._tensors[f"{dim_str}/{k}"] = vals

def save(self, fp: Path):
"""Saves the tensor to a file. See `JointNestedRaggedTensorDict.load` for examples.
Args:
fp: The path to which the tensors will be saved.
"""
if self._fp is not None:
raise ValueError("Cannot save a tensor that was loaded from a file!")
save_file(self._tensors_with_flat_values, fp)

@property
@@ -345,7 +376,7 @@ def max_n_dims(self) -> int:
>>> J.max_n_dims
3
"""
- return max(int(k.split("/")[0][3:]) for k in self.tensors.keys()) + 1
+ return max(int(k.split("/")[0][3:]) for k in self._tensor_keys) + 1

@property
def min_n_dims(self) -> int:
@@ -360,7 +391,7 @@ def min_n_dims(self) -> int:
>>> J.min_n_dims
2
"""
- return min(int(k.split("/")[0][3:]) for k in self.tensors.keys()) + 1
+ return min(int(k.split("/")[0][3:]) for k in self._tensor_keys) + 1

def keys(self) -> set[str]:
"""Returns the set of all keys for the stored tensors.
@@ -373,7 +404,7 @@ def keys(self) -> set[str]:
... })
>>> assert J.keys() == {'id', 'T', 'val'}
"""
- return {k.split("/")[1] for k in self.tensors.keys() if not self._is_meta_key(k)}
+ return {k.split("/")[1] for k in self._tensor_keys if not self._is_meta_key(k)}

@classmethod
def _is_meta_key(cls, k: str) -> bool:
@@ -418,11 +449,11 @@ def _get_dim(self, key: str) -> int:
>>> J._get_dim("id")
2
"""
- for k in self.tensors.keys():
+ for k in self._tensor_keys:
if k.endswith(f"/{key}"):
return self._get_dim_from_key_str(k)

- raise KeyError(f"Key {key} not found in {', '.join(self.tensors.keys())}")
+ raise KeyError(f"Key {key} not found in {', '.join(self._tensor_keys)}")

def keys_at_dim(self, dim: int) -> set[str]:
"""Returns the keys for tensors that are at that dimensionality.
@@ -707,7 +738,7 @@ def pad_slice(ln: int, max_ln: int) -> slice:

for key in self.keys_at_dim(dim):
slice_vals = self.tensors[f"dim{dim}/{key}"]
- if not slice_vals:
+ if len(slice_vals) == 0:
continue

out[key] = np.zeros(shape=tuple(shape), dtype=slice_vals[0].dtype)
@@ -1032,7 +1063,9 @@ def load_slice(cls, fp: Path, idx: int | slice | np.ndarray) -> JointNestedRagge
>>> J_dense = J.to_dense()
>>> J2_dense = J2.to_dense()
>>> for k in J_dense.keys():
- ... assert (J_dense[k] == J2_dense[k]).all(), f"Tensors at {k} unequal!"
+ ... assert (J_dense[k] == J2_dense[k]).all(), (
+ ...     f"Tensors at {k} unequal! Want {J_dense[k]}, got {J2_dense[k]}"
+ ... )
>>> J = JointNestedRaggedTensorDict({
... "T": [[1, 2, 3 ], [4, 5 ]],
... "id": [[[1, 2, 3], [3, 4], [1, 2 ]], [[3], [3, 2, 2]]],
@@ -1047,68 +1080,156 @@ def load_slice(cls, fp: Path, idx: int | slice | np.ndarray) -> JointNestedRagge
>>> J_dense = J.to_dense()
>>> J2_dense = J2.to_dense()
>>> for k in J_dense.keys():
- ... assert (J_dense[k] == J2_dense[k]).all(), f"Tensors at {k} unequal!"
+ ... assert (J_dense[k] == J2_dense[k]).all(), (
+ ...     f"Tensors at {k} unequal! Want {J_dense[k]}, got {J2_dense[k]}"
+ ... )
"""

J = cls(fp=fp)
indices = J._get_slice_indices(idx)

return J._slice(indices)

def _slice_single(
self, indices: dict[str, slice | np.ndarray], reduce_dim: bool = False
) -> JointNestedRaggedTensorDict:
""""""
# case int() as i:
# with_dim = self[slice(i, i + 1)]
# out_tensors = {}
# for k, T in with_dim.tensors.items():
# dim, key = k.split("/")
# dim_int = int(dim[3:])

# if dim_int == 1 and key in ("lengths", "bounds"):
# # These keys will be dropped as this tensor will become a 1D tensor in truth.
# continue

# new_key = f"dim{dim_int - 1}/{key}"
# if dim_int == 1:
# out_tensors[new_key] = T[0]
# else:
# out_tensors[new_key] = T

# return self.__class__(out_tensors, schema=self.schema, pre_raggedified=True)

tensors = {}
schema = {}

for k, idx in indices.items():
old_dim, key = k.split("/")
old_dim_int = int(old_dim[3:])

if reduce_dim:
if old_dim_int == 1 and key in ("lengths", "bounds"):
# These keys will be dropped as this tensor will become a 1D tensor in truth.
continue
new_key = f"dim{old_dim_int - 1}/{key}"
else:
new_key = k

match idx:
case slice() as S:
if self._tensors is None:
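# A contiguous range can be read directly from disk via safetensors'
# get_slice, without loading the full tensor.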
with safe_open(self._fp, framework="np") as f:
tensors[new_key] = f.get_slice(k)[S]
else:
tensors[new_key] = self._tensors[k][S]
case np.ndarray() as arr if arr.ndim == 1:
tensors[new_key] = arr
case _:
raise TypeError(f"{type(idx)} not supported for {self.__class__.__name__} slicing")

schema[new_key] = tensors[new_key].dtype

if self._tensors is None:
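# Values read lazily from disk come back flat, so they must be re-split into
# per-row ragged arrays using the (re-based) bounds tensors.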
new_max_dim = self.max_n_dims - 1 if reduce_dim else self.max_n_dims

for new_dim in range(1, new_max_dim):
old_dim = new_dim + 1 if reduce_dim else new_dim
B = tensors[f"dim{new_dim}/bounds"]
for key in self.keys_at_dim(old_dim):
tensors[f"dim{new_dim}/{key}"] = np.split(tensors[f"dim{new_dim}/{key}"], B[:-1])

return self.__class__(tensors, schema=schema, fp=None, pre_raggedified=True)

def _slice(
self,
indices: tuple[dict[str, slice | np.ndarray], bool]
| list[tuple[dict[str, slice | np.ndarray], bool]],
) -> JointNestedRaggedTensorDict:
"""Returns a new JointNestedRaggedTensorDict that is a slice of this one.
Args:
indices: The indices to slice by.
Returns:
A new JointNestedRaggedTensorDict that is a slice of this one.
"""
match indices:
case tuple() as T:
return self._slice_single(*T)
case dict():
return self._slice_single(indices, reduce_dim=False)
case list():
return self.__class__.vstack([self._slice_single(idx, reduce_dim=True) for idx in indices])
case _:
raise TypeError(f"{type(indices)} not supported for {self.__class__.__name__} slicing")

def _get_slice_indices(
self, idx: int | slice | np.ndarray
) -> tuple[dict[str, slice], bool] | list[dict[str, slice]] | dict[str, slice]:
"""Returns the start and end indices for each dimension of self after slicing by idx."""

match idx:
    case np.ndarray() as arr if arr.dtype in (NP_INT_TYPES + NP_UINT_TYPES) and arr.ndim == 1:
-         return cls.vstack([cls.load_slice(fp, int(i)) for i in arr])
+         return [self._get_slice_indices(slice(i, i + 1)) for i in arr]
    case int() as i:
-         return cls.load_slice(fp, slice(i, i + 1))[0]
+         return (self._get_slice_indices(slice(i, i + 1)), True)
case slice() as S:
st_i = 0 if S.start is None else S.start
end_i = S.stop

if S.step not in (None, 1):
raise ValueError("Only slices with step size of None or 1 are supported; got {S.step}")

- tensors = {}
- schema = {}
- with safe_open(fp, framework="np") as f:
-     keys_by_dim = defaultdict(list)
-     for k in f.keys():
-         if cls._is_meta_key(k):
-             continue
-         keys_by_dim[cls._get_dim_from_key_str(k)].append(k)
+ out = {}
+ for key in self.keys_at_dim(0):
+     out[f"dim0/{key}"] = slice(st_i, end_i)

- max_n_dims = max(keys_by_dim.keys()) + 1
+ for dim in range(1, self.max_n_dims):
+     out[f"dim{dim}/lengths"] = slice(st_i, end_i)

- for key in keys_by_dim[0]:
-     v = f.get_slice(k)[st_i:end_i]
-     schema[k] = v.dtype
-     tensors[k] = v
+     if st_i == 0:
+         B_slice = slice(0, end_i)
+     else:
+         B_slice = slice(st_i - 1, end_i)

- for dim in range(1, max_n_dims):
-     try:
-         tensors[f"dim{dim}/lengths"] = f.get_slice(f"dim{dim}/lengths")[st_i:end_i]
-     except SystemError as e:
-         raise ValueError(
-             f"Error loading lengths for dim {dim} with st_i={st_i} and end_i={end_i}"
-         ) from e
+     if self._tensors is None:
+         with safe_open(self._fp, framework="np") as f:
+             B = f.get_slice(f"dim{dim}/bounds")[B_slice]
+     else:
+         B = self.tensors[f"dim{dim}/bounds"][B_slice]

-     if st_i == 0:
-         offset = 0
-         B = f.get_slice(f"dim{dim}/bounds")[st_i:end_i]
-     else:
-         B = f.get_slice(f"dim{dim}/bounds")[st_i - 1 : end_i]
-         offset = B[0]
-         B = B[1:] - offset
+     if st_i == 0:
+         offset = 0
+     else:
+         offset = B[0]
+         B = B[1:] - offset

+     vals_start = offset
+     vals_end = B[-1] + offset

-     tensors[f"dim{dim}/bounds"] = B
+     for key in self.keys_at_dim(dim):
+         out[f"dim{dim}/{key}"] = slice(vals_start, vals_end)

-     for k in keys_by_dim[dim]:
-         v = f.get_slice(k)[vals_start:vals_end]
-         schema[k] = v.dtype
-         tensors[k] = np.split(v, B[:-1])
+     out[f"dim{dim}/bounds"] = B

-     st_i = 0 if st_i == 0 else offset
-     end_i = B[-1] + offset
+     st_i = offset
+     end_i = B[-1] + offset

- return cls(tensors, schema=schema, pre_raggedified=True)
+ return out
case _:
-     raise TypeError(f"{type(idx)} not supported for {cls.__name__}.load_slice")
+     raise TypeError(f"{type(idx)} not supported for {self.__class__.__name__} slicing")
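To make the bounds/offset arithmetic in `_get_slice_indices` concrete, here is a small self-contained sketch; the lengths values are invented for illustration and the variable names mirror the code above.

    import numpy as np

    lengths = np.array([3, 2, 4])    # rows of a ragged dim hold 3, 2, 4 values
    bounds = np.cumsum(lengths)      # -> [3, 5, 9]; row i's values end at bounds[i]

    st_i, end_i = 1, 3               # slice rows 1..2
    B = bounds[st_i - 1 : end_i]     # [3, 5, 9]; the extra leading bound ...
    offset = B[0]                    # ... marks where row st_i starts in the flat values
    B = B[1:] - offset               # re-based bounds for the sliced rows: [2, 6]

    vals_start, vals_end = offset, B[-1] + offset   # flat values to read: [3, 9)
    assert (vals_start, vals_end) == (3, 9)

Reading one extra bound (`st_i - 1`) is what lets the helper translate row indices into a contiguous window over the flattened values, which is exactly the kind of range that safetensors' `get_slice` can serve without loading the whole tensor.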
