diff --git a/src/nested_ragged_tensors/ragged_numpy.py b/src/nested_ragged_tensors/ragged_numpy.py index 2d403dc..5805f18 100644 --- a/src/nested_ragged_tensors/ragged_numpy.py +++ b/src/nested_ragged_tensors/ragged_numpy.py @@ -825,6 +825,10 @@ def concatenate(cls, tensors: list) -> JointNestedRaggedTensorDict: tensors: The tensors to concatenate. Examples: + >>> JointNestedRaggedTensorDict.concatenate([]) + Traceback (most recent call last): + ... + ValueError: Can't concatenate an empty list! >>> J1 = JointNestedRaggedTensorDict({ ... "T": [[1, 2, 3], [4, 5]], ... "id": [[[1, 2, 3], [3, 4], [1, 2]], [[3], [3, 2, 2]]], @@ -871,6 +875,30 @@ def concatenate(cls, tensors: list) -> JointNestedRaggedTensorDict: [4. , 2. , 0. ], [0. , 0. , 0. ], [3. , 0. , 0. ]]]) + >>> J1 = JointNestedRaggedTensorDict({ + ... "T": [1, 2, 3], + ... "id": [[1, 2, 3], [3, 4], [1, 2]], + ... }, schema={"T": int, "id": int, "val": float}) + >>> concatenated = JointNestedRaggedTensorDict.concatenate([J1]) + >>> dense_dict = concatenated.to_dense() + >>> dense_dict['T'] + array([1, 2, 3]) + >>> J2 = JointNestedRaggedTensorDict({ + ... "T": [6, 7, 8, 9], + ... "id": [[3], [3, 2, 2], [1], [1]], + ... }, schema={"T": int, "id": int, "val": float}) + >>> concatenated = JointNestedRaggedTensorDict.concatenate([J1, J2]) + >>> dense_dict = concatenated.to_dense() + >>> dense_dict['T'] + array([1, 2, 3, 6, 7, 8, 9]) + >>> dense_dict['id'] + array([[1, 2, 3], + [3, 4, 0], + [1, 2, 0], + [3, 0, 0], + [3, 2, 2], + [1, 0, 0], + [1, 0, 0]]) """ if len(tensors) == 1: @@ -912,10 +940,15 @@ def concatenate(cls, tensors: list) -> JointNestedRaggedTensorDict: (out_tensors[bounds_key], T.tensors[bounds_key] + out_tensors[bounds_key][-1]) ) - for key in out_keys_at_dim[dim]: - out_tensors[f"dim{dim}/{key}"] = ( - out_tensors[f"dim{dim}/{key}"] + T.tensors[f"dim{dim}/{key}"] - ) + for key in out_keys_at_dim[dim]: + out_tensors[f"dim{dim}/{key}"] = ( + out_tensors[f"dim{dim}/{key}"] + T.tensors[f"dim{dim}/{key}"] + ) + elif dim == 0: + for key in out_keys_at_dim[dim]: + out_tensors[f"dim{dim}/{key}"] = np.concatenate( + (out_tensors[f"dim{dim}/{key}"], T.tensors[f"dim{dim}/{key}"]), axis=0 + ) return cls(out_tensors, pre_raggedified=True, schema=out_schema) @classmethod