diff --git a/src/pie_datasets/core/dataset_dict.py b/src/pie_datasets/core/dataset_dict.py
index 3003b06b..3341207f 100644
--- a/src/pie_datasets/core/dataset_dict.py
+++ b/src/pie_datasets/core/dataset_dict.py
@@ -737,27 +737,37 @@ def load_dataset(*args, **kwargs) -> Union[DatasetDict, Dataset, IterableDataset
 
 
 def concatenate_dataset_dicts(
-    inputs: Dict[str, DatasetDict], split_mappings: Dict[str, Dict[str, str]], clear_metadata: bool
+    inputs: Dict[str, DatasetDict],
+    split_mappings: Dict[str, Dict[str, Union[str, List[str]]]],
+    clear_metadata: bool,
 ):
     """Concatenate the splits of multiple dataset dicts into a single one. Dataset name will be
     saved in Metadata.
 
     Args:
         inputs: A mapping from dataset names to dataset dicts that contain the splits to concatenate.
-        split_mappings: A mapping from target split names to mappings from input dataset names to
-            source split names.
+        split_mappings: A mapping from target split name to mappings from input dataset name to
+            source split name or list of names.
         clear_metadata: Whether to clear the metadata before concatenating.
 
     Returns: A dataset dict with keys in split_names as splits and content from the merged input
         dataset dicts.
     """
 
-    input_splits = {}
+    input_splits: Dict[str, Dict[str, Union[Dataset, IterableDataset]]] = {}
     for target_split_name, mapping in split_mappings.items():
-        input_splits[target_split_name] = {
-            ds_name: inputs[ds_name][source_split_name]
-            for ds_name, source_split_name in mapping.items()
-        }
+        input_splits[target_split_name] = {}
+        for ds_name, source_split_name in mapping.items():
+            if isinstance(source_split_name, str):
+                input_splits[target_split_name][ds_name] = inputs[ds_name][source_split_name]
+            elif isinstance(source_split_name, list):
+                input_splits[target_split_name][ds_name] = concatenate_datasets(
+                    [
+                        inputs[ds_name][_source_split_name]
+                        for _source_split_name in source_split_name
+                    ],
+                    clear_metadata=clear_metadata,
+                )
 
     result = DatasetDict(
         {
diff --git a/tests/unit/core/test_dataset_dict.py b/tests/unit/core/test_dataset_dict.py
index a43aa949..6d7b677d 100644
--- a/tests/unit/core/test_dataset_dict.py
+++ b/tests/unit/core/test_dataset_dict.py
@@ -724,3 +724,19 @@ def test_concatenate_dataset_dicts(tbga_extract, comagc_extract):
     assert all(
         [ds.metadata["dataset_name"] in ["tbga", "comagc"] for ds in concatenated_dataset["train"]]
     )
+
+    concatenated_dataset_with_list_in_mapping = concatenate_dataset_dicts(
+        inputs={"tbga": tbga_extract, "comagc": comagc_extract},
+        split_mappings={"train": {"tbga": ["train", "test"], "comagc": "train"}},
+        clear_metadata=True,
+    )
+
+    assert len(concatenated_dataset_with_list_in_mapping["train"]) == len(
+        tbga_extract["train"]
+    ) + len(tbga_extract["test"]) + len(comagc_extract["train"])
+    assert all(
+        [
+            ds.metadata["dataset_name"] in ["tbga", "comagc"]
+            for ds in concatenated_dataset_with_list_in_mapping["train"]
+        ]
+    )
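
Usage sketch (not part of the patch): with the list form introduced above, one target split can be assembled from several source splits of the same input dataset. A minimal, hedged example; the import path follows the file location shown in this diff, and ds_a / ds_b are illustrative, already-loaded DatasetDict objects, not names from the repository:

    from pie_datasets.core.dataset_dict import concatenate_dataset_dicts

    # ds_a and ds_b are placeholder DatasetDicts that each expose "train"/"test"
    # splits; how they are loaded is outside the scope of this sketch.
    merged = concatenate_dataset_dicts(
        inputs={"ds_a": ds_a, "ds_b": ds_b},
        # The resulting "train" split combines ds_a's train and test splits
        # (new list form) with ds_b's train split (existing single-string form).
        split_mappings={"train": {"ds_a": ["train", "test"], "ds_b": "train"}},
        clear_metadata=True,
    )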