Skip to content

Commit

Permalink
extend split_mappings format
Browse files — browse the repository at this point in the history
  • Loading branch information
RainbowRivey committed Nov 4, 2024
1 parent 2b3ba52 commit c0638e4
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 8 deletions.
26 changes: 18 additions & 8 deletions src/pie_datasets/core/dataset_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -737,27 +737,37 @@ def load_dataset(*args, **kwargs) -> Union[DatasetDict, Dataset, IterableDataset


def concatenate_dataset_dicts(
inputs: Dict[str, DatasetDict], split_mappings: Dict[str, Dict[str, str]], clear_metadata: bool
inputs: Dict[str, DatasetDict],
split_mappings: Dict[str, Dict[str, Union[str, List[str]]]],
clear_metadata: bool,
):
"""Concatenate the splits of multiple dataset dicts into a single one. Dataset name will be
saved in Metadata.
Args:
inputs: A mapping from dataset names to dataset dicts that contain the splits to concatenate.
split_mappings: A mapping from target split names to mappings from input dataset names to
source split names.
split_mappings: A mapping from target split name to mappings from input dataset name to
source split name or list of names.
clear_metadata: Whether to clear the metadata before concatenating.
Returns: A dataset dict with keys in split_names as splits and content from the merged input
dataset dicts.
"""

input_splits = {}
input_splits: Dict[str, Dict[str, Union[Dataset, IterableDataset]]] = {}
for target_split_name, mapping in split_mappings.items():
input_splits[target_split_name] = {
ds_name: inputs[ds_name][source_split_name]
for ds_name, source_split_name in mapping.items()
}
input_splits[target_split_name] = {}
for ds_name, source_split_name in mapping.items():
if isinstance(source_split_name, str):
input_splits[target_split_name][ds_name] = inputs[ds_name][source_split_name]
elif isinstance(source_split_name, list):
input_splits[target_split_name][ds_name] = concatenate_datasets(
[
inputs[ds_name][_source_split_name]
for _source_split_name in source_split_name
],
clear_metadata=clear_metadata,
)

result = DatasetDict(
{
Expand Down
16 changes: 16 additions & 0 deletions tests/unit/core/test_dataset_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -724,3 +724,19 @@ def test_concatenate_dataset_dicts(tbga_extract, comagc_extract):
assert all(
[ds.metadata["dataset_name"] in ["tbga", "comagc"] for ds in concatenated_dataset["train"]]
)

concatenated_dataset_with_list_in_mapping = concatenate_dataset_dicts(
inputs={"tbga": tbga_extract, "comagc": comagc_extract},
split_mappings={"train": {"tbga": ["train", "test"], "comagc": "train"}},
clear_metadata=True,
)

assert len(concatenated_dataset_with_list_in_mapping["train"]) == len(
tbga_extract["train"]
) + len(tbga_extract["test"]) + len(comagc_extract["train"])
assert all(
[
ds.metadata["dataset_name"] in ["tbga", "comagc"]
for ds in concatenated_dataset_with_list_in_mapping["train"]
]
)

0 comments on commit c0638e4

Please sign in to comment.