
Commit 4719a07

add parameter clear_metadata to `concatenate_datasets` and `concatenate_dataset_dicts`
RainbowRivey committed Sep 30, 2024
1 parent f09940c commit 4719a07
Showing 4 changed files with 39 additions and 24 deletions.
21 changes: 13 additions & 8 deletions src/pie_datasets/core/dataset.py
@@ -673,38 +673,43 @@ def get_pie_dataset_type(
     )


-def _add_dset_name_to_document(doc: Document, name: str) -> Document:
+def _add_dset_name_to_document(doc: Document, name: str, clear_metadata: bool) -> Document:
     if not hasattr(doc, "metadata"):
         raise ValueError(
             f"Document does not have metadata attribute which required to save the dataset name: {doc}"
         )
+    # Keep the old name if available
     if "dataset_name" in doc.metadata:
-        raise ValueError(
-            f"Document already has a dataset_name attribute: {doc.metadata['dataset_name']}"
-        )
-    doc.metadata = {}
+        name = doc.metadata["dataset_name"]
+    if clear_metadata:
+        doc.metadata = {}
     doc.metadata["dataset_name"] = name
     return doc


 def concatenate_datasets(
     dsets: Union[
         List[Dataset], List[IterableDataset], Dict[str, Dataset], Dict[str, IterableDataset]
-    ]
+    ],
+    clear_metadata: bool,
 ) -> Union[Dataset, IterableDataset]:
     """Concatenate multiple datasets into a single dataset. The datasets must have the same
-    document type. Datasets metadata will be removed, dataset name will be saved instead.
+    document type. Dataset name will be saved in Metadata.

     Args:
         dsets: A list of datasets or a dictionary with dataset names as keys and datasets as values. If
             a dictionary is provided, the dataset names will be added to the documents as metadata.
+        clear_metadata: Whether to clear the metadata before concatenating.

     Returns:
         A new dataset that is the concatenation of the input datasets.
     """

     if isinstance(dsets, dict):
         dsets = [
-            dset.map(_add_dset_name_to_document, fn_kwargs={"name": name})
+            dset.map(
+                _add_dset_name_to_document,
+                fn_kwargs={"name": name, "clear_metadata": clear_metadata},
+            )
             for name, dset in dsets.items()
         ]

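The dict form of the new API, as a minimal sketch. The import paths and the dataset identifier are assumptions; only the `concatenate_datasets(..., clear_metadata=...)` signature comes from this commit:

```python
from pie_datasets import load_dataset
from pie_datasets.core.dataset import concatenate_datasets

# Hypothetical dataset id; any PIE dataset whose documents have a
# `metadata` attribute works the same way.
ds = load_dataset("pie/example_dataset")

# Dict input: every document is tagged with its source key under
# metadata["dataset_name"]. With clear_metadata=True, all other metadata
# entries are dropped first; an already-present dataset_name is kept.
combined = concatenate_datasets(
    {"train": ds["train"], "validation": ds["validation"]},
    clear_metadata=True,
)
```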
9 changes: 5 additions & 4 deletions src/pie_datasets/core/dataset_dict.py
@@ -720,15 +720,16 @@ def load_dataset(*args, **kwargs) -> Union[DatasetDict, Dataset, IterableDataset


 def concatenate_dataset_dicts(
-    inputs: Dict[str, DatasetDict],
-    split_mappings: Dict[str, Dict[str, str]],
+    inputs: Dict[str, DatasetDict], split_mappings: Dict[str, Dict[str, str]], clear_metadata: bool
 ):
-    """Concatenate the splits of multiple dataset dicts into a single one.
+    """Concatenate the splits of multiple dataset dicts into a single one. Dataset name will be
+    saved in Metadata.

     Args:
         inputs: A mapping from dataset names to dataset dicts that contain the splits to concatenate.
         split_mappings: A mapping from target split names to mappings from input dataset names to
             source split names.
+        clear_metadata: Whether to clear the metadata before concatenating.

     Returns: A dataset dict with keys in split_names as splits and content from the merged input
         dataset dicts.
@@ -743,7 +744,7 @@ def concatenate_dataset_dicts(

     result = DatasetDict(
         {
-            target_split_name: concatenate_datasets(dsets)
+            target_split_name: concatenate_datasets(dsets, clear_metadata=clear_metadata)
             for target_split_name, dsets in input_splits.items()
         }
     )
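`concatenate_dataset_dicts` forwards the flag to `concatenate_datasets` once per target split. A sketch under assumed hub ids; the signature itself is from this commit:

```python
from pie_datasets import load_dataset
from pie_datasets.core.dataset_dict import concatenate_dataset_dicts

tbga = load_dataset("pie/tbga")      # placeholder ids, not part of the commit
comagc = load_dataset("pie/comagc")

merged = concatenate_dataset_dicts(
    inputs={"tbga": tbga, "comagc": comagc},
    split_mappings={"train": {"tbga": "train", "comagc": "train"}},
    clear_metadata=True,
)
# merged["train"] now holds documents from both sources, each carrying
# metadata["dataset_name"] == "tbga" or "comagc".
```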
32 changes: 20 additions & 12 deletions tests/unit/core/test_dataset.py
@@ -502,8 +502,12 @@ def _empty_docs():
     assert str(excinfo.value) == "No documents to create dataset from"


-@pytest.mark.parametrize("as_list", [False, True])
-def test_concatenate_datasets(maybe_iterable_dataset, dataset_with_converter_functions, as_list):
+@pytest.mark.parametrize(
+    "as_list, clear_metadata", [(False, False), (False, True), (True, False), (True, True)]
+)
+def test_concatenate_datasets(
+    maybe_iterable_dataset, dataset_with_converter_functions, as_list, clear_metadata
+):
     # Tests four different cases of concatenation of list/dict of Datasets/IterableDatasets
     if as_list:
         # Test concatenation of list of datasets
@@ -512,11 +516,14 @@ def test_concatenate_datasets(maybe_iterable_dataset, dataset_with_converter_fun
                 maybe_iterable_dataset["train"],
                 maybe_iterable_dataset["validation"],
                 maybe_iterable_dataset["test"],
-            ]
+            ],
+            clear_metadata=clear_metadata,
         )
     else:
         # Test concatenation of dictionary of datasets
-        concatenated_dataset = concatenate_datasets(maybe_iterable_dataset)
+        concatenated_dataset = concatenate_datasets(
+            maybe_iterable_dataset, clear_metadata=clear_metadata
+        )

     # Check correct output type
     if isinstance(maybe_iterable_dataset["train"], IterableDataset):
@@ -556,15 +563,17 @@ def test_concatenate_datasets_errors(dataset_with_converter_functions):
     # Test concatenation of empty datasets
     empty_dataset = list[Dataset]()
     with pytest.raises(ValueError) as excinfo:
-        concatenate_datasets(empty_dataset)
+        concatenate_datasets(empty_dataset, clear_metadata=False)
     assert str(excinfo.value) == "No datasets to concatenate"

     # Test concatenation of datasets with different document types
     dataset_with_converted_doc = dataset_with_converter_functions.to_document_type(
         TestDocumentWithLabel
     )
     with pytest.raises(ValueError) as excinfo:
-        concatenate_datasets([dataset_with_converter_functions, dataset_with_converted_doc])
+        concatenate_datasets(
+            [dataset_with_converter_functions, dataset_with_converted_doc], clear_metadata=False
+        )
     assert str(excinfo.value) == "All datasets must have the same document type to concatenate"


@@ -573,7 +582,7 @@ def test_add_dset_name_to_document():
     doc = Document()
     assert not hasattr(doc, "metadata")
     with pytest.raises(ValueError) as excinfo:
-        _add_dset_name_to_document(doc, "test")
+        _add_dset_name_to_document(doc, "test", clear_metadata=False)
     assert (
         str(excinfo.value)
         == "Document does not have metadata attribute which required to save the dataset name: Document()"
@@ -582,10 +591,9 @@
     # Test adding dataset name to document
     doc.metadata = {}
     assert hasattr(doc, "metadata")
-    _add_dset_name_to_document(doc, "test_dataset_name")
+    _add_dset_name_to_document(doc, "test_dataset_name", clear_metadata=False)
     assert doc.metadata["dataset_name"] == "test_dataset_name"

-    # Test document already having dataset_name in metadata
-    with pytest.raises(ValueError) as excinfo:
-        _add_dset_name_to_document(doc, "test")
-    assert str(excinfo.value) == "Document already has a dataset_name attribute: test_dataset_name"
+    # Test document already having dataset_name in metadata keeps the old name
+    _add_dset_name_to_document(doc, "test", clear_metadata=False)
+    assert doc.metadata["dataset_name"] == "test_dataset_name"
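The updated test pins down the helper's new behavior: an existing dataset_name is kept instead of raising. A standalone sketch with a stub class (not the real `Document` type) that mirrors the committed logic:

```python
class Doc:
    """Minimal stand-in for a PIE document; only .metadata matters here."""

    def __init__(self, metadata: dict):
        self.metadata = metadata


def add_dset_name(doc: Doc, name: str, clear_metadata: bool) -> Doc:
    # An already-present dataset_name wins over the incoming one.
    if "dataset_name" in doc.metadata:
        name = doc.metadata["dataset_name"]
    # Optionally drop all other metadata before (re-)writing the name.
    if clear_metadata:
        doc.metadata = {}
    doc.metadata["dataset_name"] = name
    return doc


d = add_dset_name(Doc({"dataset_name": "old", "foo": 1}), "new", clear_metadata=True)
assert d.metadata == {"dataset_name": "old"}  # old name kept, "foo" cleared
```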
1 change: 1 addition & 0 deletions tests/unit/core/test_dataset_dict.py
@@ -654,6 +654,7 @@ def test_concatenate_dataset_dicts(tbga_extract, comagc_extract):
     concatenated_dataset = concatenate_dataset_dicts(
         inputs={"tbga": tbga_extract, "comagc": comagc_extract},
         split_mappings={"train": {"tbga": "train", "comagc": "train"}},
+        clear_metadata=True,
     )

     assert len(concatenated_dataset["train"]) == len(tbga_extract["train"]) + len(
