Skip to content

Commit

Permalink
Fixed an error in concatenation for dim-0 tensors.
Browse files Browse the repository at this point in the history
  • Loading branch information
mmcdermott committed Sep 9, 2024
1 parent 60fdeae commit b2c746d
Showing 1 changed file with 37 additions and 4 deletions.
41 changes: 37 additions & 4 deletions src/nested_ragged_tensors/ragged_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -825,6 +825,10 @@ def concatenate(cls, tensors: list) -> JointNestedRaggedTensorDict:
tensors: The tensors to concatenate.
Examples:
>>> JointNestedRaggedTensorDict.concatenate([])
Traceback (most recent call last):
...
ValueError: Can't concatenate an empty list!
>>> J1 = JointNestedRaggedTensorDict({
... "T": [[1, 2, 3], [4, 5]],
... "id": [[[1, 2, 3], [3, 4], [1, 2]], [[3], [3, 2, 2]]],
Expand Down Expand Up @@ -871,6 +875,30 @@ def concatenate(cls, tensors: list) -> JointNestedRaggedTensorDict:
[4. , 2. , 0. ],
[0. , 0. , 0. ],
[3. , 0. , 0. ]]])
>>> J1 = JointNestedRaggedTensorDict({
... "T": [1, 2, 3],
... "id": [[1, 2, 3], [3, 4], [1, 2]],
... }, schema={"T": int, "id": int, "val": float})
>>> concatenated = JointNestedRaggedTensorDict.concatenate([J1])
>>> dense_dict = concatenated.to_dense()
>>> dense_dict['T']
array([1, 2, 3])
>>> J2 = JointNestedRaggedTensorDict({
... "T": [6, 7, 8, 9],
... "id": [[3], [3, 2, 2], [1], [1]],
... }, schema={"T": int, "id": int, "val": float})
>>> concatenated = JointNestedRaggedTensorDict.concatenate([J1, J2])
>>> dense_dict = concatenated.to_dense()
>>> dense_dict['T']
array([1, 2, 3, 6, 7, 8, 9])
>>> dense_dict['id']
array([[1, 2, 3],
[3, 4, 0],
[1, 2, 0],
[3, 0, 0],
[3, 2, 2],
[1, 0, 0],
[1, 0, 0]])
"""

if len(tensors) == 1:
Expand Down Expand Up @@ -912,10 +940,15 @@ def concatenate(cls, tensors: list) -> JointNestedRaggedTensorDict:
(out_tensors[bounds_key], T.tensors[bounds_key] + out_tensors[bounds_key][-1])
)

for key in out_keys_at_dim[dim]:
out_tensors[f"dim{dim}/{key}"] = (
out_tensors[f"dim{dim}/{key}"] + T.tensors[f"dim{dim}/{key}"]
)
for key in out_keys_at_dim[dim]:
out_tensors[f"dim{dim}/{key}"] = (
out_tensors[f"dim{dim}/{key}"] + T.tensors[f"dim{dim}/{key}"]
)
elif dim == 0:
for key in out_keys_at_dim[dim]:
out_tensors[f"dim{dim}/{key}"] = np.concatenate(
(out_tensors[f"dim{dim}/{key}"], T.tensors[f"dim{dim}/{key}"]), axis=0
)
return cls(out_tensors, pre_raggedified=True, schema=out_schema)

@classmethod
Expand Down

0 comments on commit b2c746d

Please sign in to comment.