Skip to content

Commit

Permalink
Merge pull request #45 from mmcdermott/improve_test_coverage
Browse files — browse the repository at this point in the history
Improve test coverage to 100%
  • Loading branch information
mmcdermott authored Nov 7, 2024
2 parents 84faf1c + e168dd7 commit d915700
Showing 1 changed file with 215 additions and 24 deletions.
239 changes: 215 additions & 24 deletions src/nested_ragged_tensors/ragged_numpy.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -99,6 +99,12 @@ def __init__(
... J.save(fp)
... J3 = JointNestedRaggedTensorDict(tensors_fp=fp)
... assert J == J3
>>> with tempfile.TemporaryDirectory() as dirpath:
... fp = Path(dirpath) / "tensors.nrt"
... J.save(fp)
... JointNestedRaggedTensorDict(tensors_fp=fp) # doctest: +NORMALIZE_WHITESPACE
JointNestedRaggedTensorDict(tensors_fp=.../tensors.nrt,
schema={'B': dtype('uint8'), 'A': dtype('uint8')})
>>> JointNestedRaggedTensorDict({"S": []})
Traceback (most recent call last):
...
Expand All @@ -121,6 +127,13 @@ def __init__(
Traceback (most recent call last):
...
FileNotFoundError: Tensors filepath must exist, got ...
>>> JointNestedRaggedTensorDict({
... "id": [[1], [3, 4], [1]],
... "D": [[[[5, 6], [7]], [[[8]], [[]]], [[[9, 10]]]]],
... })
Traceback (most recent call last):
...
ValueError: Failed to parse D as a nested list of numbers!
"""
args = [
("raw_tensors", raw_tensors),
Expand All @@ -146,7 +159,29 @@ def __init__(
self._tensors_fp = tensors_fp

def __eq__(self, other: object) -> bool:
"""Checks if this JointNestedRaggedTensorDict is equal to another object."""
"""Checks if this JointNestedRaggedTensorDict is equal to another object.
Examples:
>>> data = {"A": [[1, 2, 3], [4, 5]], "B": [1, 2]}
>>> J = JointNestedRaggedTensorDict(data)
>>> J == J
True
>>> J == JointNestedRaggedTensorDict(data)
True
>>> J == JointNestedRaggedTensorDict({"A": [[1, 2, 4], [4, 5]], "B": [1, 2]})
False
>>> J == data
False
>>> J == JointNestedRaggedTensorDict({"A": data["A"]})
False
>>> import tempfile
>>> with tempfile.NamedTemporaryFile() as f:
... fp = Path(f.name)
... J.save(fp)
... J2 = JointNestedRaggedTensorDict(tensors_fp=fp)
... J == J2
True
"""

if not isinstance(other, JointNestedRaggedTensorDict):
return False
Expand All @@ -167,20 +202,19 @@ def __repr__(self) -> str:
return f"{prefix}processed_tensors={self._tensors}, {schema_arg})"
elif self._tensors_fp is not None:
return f"{prefix}tensors_fp={str(self._tensors_fp)}, {schema_arg})"
else:
raise ValueError("No tensors found!")
else: # pragma: no cover
raise ValueError("No tensors found! This error should not happen")

def __str__(self) -> str:
return self.__repr__()

@property
def schema(self) -> dict[str, np.dtype]:
if not self._schema:
if self._schema is None:
self._schema = {}

for k in self._tensor_keys:
for k in sorted(self._tensor_keys):
dim, key = k.split("/")
if key == "bounds":
continue
with self._tensor_at_key(k) as T:
self._schema[key] = T[:1].dtype
return self._schema
Expand Down Expand Up @@ -361,7 +395,7 @@ def _initialize_tensors(self, tensors: dict[str, list[NESTED_NUM_LIST_T] | NESTE
bounds_key = f"{dim_str}/bounds"
B = np.cumsum(L, axis=0)
if bounds_key in self._tensors:
if not np.array_equal(self._tensors[bounds_key], B):
if not np.array_equal(self._tensors[bounds_key], B): # pragma: no cover
raise ValueError(f"Inconsistent bounds tensors! {self._tensors[bounds_key]} vs. {B}")
else:
self._tensors[bounds_key] = B
Expand All @@ -386,6 +420,14 @@ def save(self, fp: Path):
... J.save(fp)
... J2 = JointNestedRaggedTensorDict(tensors_fp=fp)
... assert J == J2
>>> with tempfile.TemporaryDirectory() as dirpath:
... fp = Path(dirpath) / "tensors.nrt"
... J.save(fp)
... J2 = JointNestedRaggedTensorDict(tensors_fp=fp)
... J2.save(fp)
Traceback (most recent call last):
...
ValueError: Already saved to .../tensors.nrt!
"""
if self._tensors is None:
raise ValueError(f"Already saved to {self._tensors_fp}!")
Expand Down Expand Up @@ -474,12 +516,17 @@ def _get_dim(self, key: str) -> int:
1
>>> J._get_dim("id")
2
>>> J._get_dim("value")
Traceback (most recent call last):
...
KeyError: "Key 'value' not found in 'T', 'id', 'val'"
"""
for k in self._tensor_keys:
if k.endswith(f"/{key}"):
return self._get_dim_from_key_str(k)

raise KeyError(f"Key {key} not found in {', '.join(self.tensors.keys())}")
keys = "', '".join(sorted(self.keys()))
raise KeyError(f"Key '{key}' not found in '{keys}'")

def keys_at_dim(self, dim: int) -> set[str]:
"""Returns the keys for tensors that are at that dimensionality.
Expand Down Expand Up @@ -738,12 +785,8 @@ def pad_slice(ln: int, max_ln: int) -> slice:
)

B = self.tensors[f"dim{dim}/bounds"]
if len(B) > 0:
L = np.concatenate([[B[0]], np.diff(B, axis=0)], axis=0)
max_ln = max(L)
else:
L = []
max_ln = 0
L = np.concatenate([[B[0]], np.diff(B, axis=0)], axis=0)
max_ln = max(L)

shape.append(max_ln)

Expand Down Expand Up @@ -801,6 +844,10 @@ def squeeze(self, dim: int) -> JointNestedRaggedTensorDict:
[[3. , 0. , 0. ],
[3.3, 2. , 0. ],
[0. , 0. , 0. ]]])
>>> J.squeeze(dim=1)
Traceback (most recent call last):
...
ValueError: Only supports dim = 0 for now; got 1
"""
if dim != 0:
raise ValueError(f"Only supports dim = 0 for now; got {dim}")
Expand Down Expand Up @@ -831,10 +878,10 @@ def unsqueeze(self, dim: int) -> JointNestedRaggedTensorDict:
ValueError: If dim != 0 or if the tensors are exclusively non-ragged 1D tensors.
Examples:
>>> J = JointNestedRaggedTensorDict({"T": [1, 2, 3]}, schema={"T": int})
>>> J = JointNestedRaggedTensorDict({"T": [[1, 2, 3]]}, schema={"T": int})
>>> dense_dict = J.unsqueeze(dim=0).to_dense()
>>> dense_dict['T']
array([[1, 2, 3]])
array([[[1, 2, 3]]])
>>> J = JointNestedRaggedTensorDict({
... "T": [1, 2],
... "id": [[[1, 2, 3], [3, 4], [1, 2]], [[3], [3, 2, 2]]],
Expand All @@ -859,6 +906,10 @@ def unsqueeze(self, dim: int) -> JointNestedRaggedTensorDict:
[[3. , 0. , 0. ],
[3.3, 2. , 0. ],
[0. , 0. , 0. ]]]])
>>> J.unsqueeze(dim=1)
Traceback (most recent call last):
...
ValueError: Only supports dim = 0 for now; got 1
"""
if dim != 0:
raise ValueError(f"Only supports dim = 0 for now; got {dim}")
Expand Down Expand Up @@ -950,10 +1001,31 @@ def flatten(self, dim: int = -1) -> JointNestedRaggedTensorDict:
array([1, 2, 3, 3, 4, 1, 2], dtype=uint8)
>>> dense_dict['val']
array([1. , 0.2, 0. , 3.1, 0. , 1. , 2.2], dtype=float32)
>>> J = JointNestedRaggedTensorDict({
... "ts": [1, 2, 3],
... "id": [[1], [3, 4], [1]],
... "D": [[[[5, 6], [7]]], [[[8]], [[]]], [[[9, 10]]]],
... })
>>> dense_dict = J.flatten().to_dense()
>>> dense_dict['ts']
array([1, 2, 3], dtype=uint8)
>>> dense_dict['id']
array([[1, 0],
[3, 4],
[1, 0]], dtype=uint8)
>>> dense_dict['D']
array([[[ 5, 6, 7],
[ 0, 0, 0]],
<BLANKLINE>
[[ 8, 0, 0],
[ 0, 0, 0]],
<BLANKLINE>
[[ 9, 10, 0],
[ 0, 0, 0]]], dtype=uint8)
>>> J.flatten(dim=0)
Traceback (most recent call last):
...
ValueError: Only supports dim = -1 or 1 for now; got 0
ValueError: Only supports dim = -1 or 3 for now; got 0
"""
if dim < 0:
target_dim = self.max_n_dims + dim
Expand Down Expand Up @@ -1017,6 +1089,15 @@ def __len__(self) -> int:
... J2 = JointNestedRaggedTensorDict(tensors_fp=fp)
... len(J2)
2
>>> J = JointNestedRaggedTensorDict({"T": [1, 2, 3]})
>>> len(J)
3
>>> with tempfile.NamedTemporaryFile() as f:
... fp = Path(f.name)
... J.save(fp)
... J2 = JointNestedRaggedTensorDict(tensors_fp=fp)
... len(J2)
3
"""
if self._tensors is None:
with safe_open(self._tensors_fp, framework="np") as f:
Expand Down Expand Up @@ -1169,6 +1250,31 @@ def concatenate(cls, tensors: list) -> JointNestedRaggedTensorDict:
array([9])
>>> dense_dict['id']
array([[1]])
>>> J1 = JointNestedRaggedTensorDict({"T": [1, 2, 3]})
>>> J2 = JointNestedRaggedTensorDict({"B": [1, 2, 3]})
>>> JointNestedRaggedTensorDict.concatenate([J1, J2])
Traceback (most recent call last):
...
ValueError: Keys inconsistent! {'B'} != {'T'}
>>> J1 = JointNestedRaggedTensorDict({"T": [1, 2, 3]})
>>> J2 = JointNestedRaggedTensorDict({"T": [[1, 2, 3]]})
>>> JointNestedRaggedTensorDict.concatenate([J1, J2])
Traceback (most recent call last):
...
ValueError: Max dims inconsistent! 2 != 1
>>> J1 = JointNestedRaggedTensorDict({"T": [1, 2, 3]})
>>> J2 = JointNestedRaggedTensorDict({"T": [1.1, 2.1, 3.4]})
>>> JointNestedRaggedTensorDict.concatenate([J1, J2]) # doctest: +NORMALIZE_WHITESPACE
Traceback (most recent call last):
...
ValueError: Schema inconsistent!
{'T': <class 'numpy.float32'>} != {'T': <class 'numpy.uint8'>}
>>> J1 = JointNestedRaggedTensorDict({"T": [1, 2, 3], "B": [[1], [2], [3]]})
>>> J2 = JointNestedRaggedTensorDict({"B": [1, 2, 3], "T": [[1], [2], [3]]})
>>> JointNestedRaggedTensorDict.concatenate([J1, J2])
Traceback (most recent call last):
...
ValueError: Keys inconsistent @ dim 0! {'B'} != {'T'}
"""

if len(tensors) == 1:
Expand Down Expand Up @@ -1210,7 +1316,7 @@ def concatenate(cls, tensors: list) -> JointNestedRaggedTensorDict:
k_str = f"dim{dim}/{key}"
try:
out_tensors[k_str] = np.concatenate((out_tensors[k_str], T.tensors[k_str]), axis=0)
except Exception as e:
except Exception as e: # pragma: no cover
raise ValueError(
f"Failed to concatenate {key} at dim {dim} with args "
f"{out_tensors[k_str]} and {T.tensors[k_str]}"
Expand All @@ -1219,10 +1325,28 @@ def concatenate(cls, tensors: list) -> JointNestedRaggedTensorDict:

def _slice_single(
self,
indices: dict[str, slice | np.ndarray],
indices: dict[str, slice],
squeeze_dims: list[int] | None = None,
) -> JointNestedRaggedTensorDict:
"""Slices this collection of tensors by the given indices."""
"""Slices this collection of tensors by the given indices.
Args:
indices: The indices to slice by, structured as a dictionary of tensor keys to slices.
squeeze_dims: The dimensions to squeeze.
Returns:
A new JointNestedRaggedTensorDict that is a slice of this one.
Examples:
>>> J = JointNestedRaggedTensorDict(
... {"T": [1, 2, 3], "id": [[1, 2, 3], [3, 4], [1, 2]]},
... schema={"T": int, "id": int, "val": float}
... )
>>> J._slice_single({"dim0/T": [1, 3]})
Traceback (most recent call last):
...
TypeError: <class 'list'> not supported for JointNestedRaggedTensorDict slicing
"""

tensors = {}
schema = {}
Expand Down Expand Up @@ -1273,6 +1397,11 @@ def _slice(
A new JointNestedRaggedTensorDict that is a slice of this one.
Examples:
>>> J = JointNestedRaggedTensorDict({"T": [1, 2, 3]})
>>> J._slice(3.4)
Traceback (most recent call last):
...
TypeError: <class 'float'> not supported for JointNestedRaggedTensorDict slicing
"""
match indices:
case tuple() as T:
Expand All @@ -1287,7 +1416,29 @@ def _slice(
def _get_slice_indices(
self, idx: int | slice | tuple | np.ndarray
) -> tuple[dict[str, slice], bool] | list[dict[str, slice]] | dict[str, slice]:
"""Returns the start and end indices for each dimension of self after slicing by idx."""
"""Returns the start and end indices for each dimension of self after slicing by idx.
Args:
idx: The index to slice by.
Returns:
The slice that should be used for each nested tensor by key.
Examples:
>>> J = JointNestedRaggedTensorDict({"T": [1, 2, 3]})
>>> J._get_slice_indices(1)
({'dim0/T': slice(1, 2, None)}, [0])
>>> J._get_slice_indices(slice(1, 3))
{'dim0/T': slice(1, 3, None)}
>>> J._get_slice_indices([1, 2])
Traceback (most recent call last):
...
TypeError: <class 'list'> not supported for JointNestedRaggedTensorDict slicing
>>> J._get_slice_indices((1, 2.4))
Traceback (most recent call last):
...
TypeError: <class 'float'> at index 1 not supported for JointNestedRaggedTensorDict tuple slicing
"""

match idx:
case np.ndarray() as arr if arr.dtype in (NP_INT_TYPES + NP_UINT_TYPES) and arr.ndim == 1:
Expand Down Expand Up @@ -1323,13 +1474,53 @@ def _get_slice_indices(
def _get_slice_indices_internal(
self, idx: slice, starting_dim: int, curr_indices: dict[str, slice]
) -> dict[str, slice]:
"""Returns the start and end indices for each dimension of self after slicing by idx."""
"""Returns the resolved slice for the given input slice and starting dimension.
Slice resolution takes into account the nested bounds of the ragged tensor.
Args:
idx: The slice to resolve.
starting_dim: The dimension to start resolving from.
curr_indices: The current resolved indices.
Returns:
The updated resolved indices.
Examples:
>>> J = JointNestedRaggedTensorDict({
... "T": [1, 2, 3],
... "id": [[1, 2], [3, 4], [5, 6]],
... "val": [[[1.0], []], [[3.1, 0.], [1., 2.2]], [[3], [3.3, 2., 0]]],
... })
>>> J._get_slice_indices_internal(slice(1, 3), 0, {}) # doctest: +NORMALIZE_WHITESPACE
{'dim0/T': slice(1, 3, None),
'dim1/bounds': slice(1, 3, None),
'dim1/id': slice(np.int64(2), np.int64(6), None),
'dim2/bounds': slice(np.int64(2), np.int64(6), None),
'dim2/val': slice(np.int64(1), np.int64(9), None)}
>>> J._get_slice_indices_internal(slice(3, 3), 0, {}) # doctest: +NORMALIZE_WHITESPACE
{'dim0/T': slice(3, 3, None),
'dim1/bounds': slice(3, 3, None),
'dim1/id': slice(np.int64(6), np.int64(6), None),
'dim2/bounds': slice(np.int64(6), np.int64(6), None),
'dim2/val': slice(np.int64(9), np.int64(9), None)}
>>> J._get_slice_indices_internal(slice(4, 4), 0, {}) # doctest: +NORMALIZE_WHITESPACE
{'dim0/T': slice(4, 4, None),
'dim1/bounds': slice(4, 4, None),
'dim1/id': slice(np.int64(6), np.int64(6), None),
'dim2/bounds': slice(np.int64(6), np.int64(6), None),
'dim2/val': slice(np.int64(9), np.int64(9), None)}
>>> J._get_slice_indices_internal(slice(1, 3, 2), 0, {})
Traceback (most recent call last):
...
ValueError: Only slices with step size of None or 1 are supported; got 2
"""

st_i = 0 if idx.start is None else idx.start
end_i = idx.stop

if idx.step not in (None, 1):
raise ValueError("Only slices with step size of None or 1 are supported; got {idx.step}")
raise ValueError(f"Only slices with step size of None or 1 are supported; got {idx.step}")

out = {**curr_indices}

Expand Down

0 comments on commit d915700

Please sign in to comment.