From 4719a0741bfde824c9651ef2c3ff0f16c0005e16 Mon Sep 17 00:00:00 2001
From: R33v4
Date: Mon, 30 Sep 2024 13:15:26 +0200
Subject: [PATCH] add parameter `clear_metadata` to `concatenate_datasets` and
 `concatenate_dataset_dicts`

---
 src/pie_datasets/core/dataset.py      | 21 +++++++++++-------
 src/pie_datasets/core/dataset_dict.py |  9 ++++----
 tests/unit/core/test_dataset.py       | 32 +++++++++++++++++----------
 tests/unit/core/test_dataset_dict.py  |  1 +
 4 files changed, 39 insertions(+), 24 deletions(-)

diff --git a/src/pie_datasets/core/dataset.py b/src/pie_datasets/core/dataset.py
index 911ec568..f592cbd4 100644
--- a/src/pie_datasets/core/dataset.py
+++ b/src/pie_datasets/core/dataset.py
@@ -673,16 +673,16 @@ def get_pie_dataset_type(
     )
 
 
-def _add_dset_name_to_document(doc: Document, name: str) -> Document:
+def _add_dset_name_to_document(doc: Document, name: str, clear_metadata: bool) -> Document:
     if not hasattr(doc, "metadata"):
         raise ValueError(
             f"Document does not have metadata attribute which required to save the dataset name: {doc}"
         )
+    # Keep the existing dataset name if the document already has one
     if "dataset_name" in doc.metadata:
-        raise ValueError(
-            f"Document already has a dataset_name attribute: {doc.metadata['dataset_name']}"
-        )
-    doc.metadata = {}
+        name = doc.metadata["dataset_name"]
+    if clear_metadata:
+        doc.metadata = {}
     doc.metadata["dataset_name"] = name
     return doc
 
@@ -690,21 +690,26 @@ def _add_dset_name_to_document(doc: Document, name: str) -> Document:
 def concatenate_datasets(
     dsets: Union[
         List[Dataset], List[IterableDataset], Dict[str, Dataset], Dict[str, IterableDataset]
-    ]
+    ],
+    clear_metadata: bool,
 ) -> Union[Dataset, IterableDataset]:
     """Concatenate multiple datasets into a single dataset. The datasets must have the same
-    document type. Datasets metadata will be removed, dataset name will be saved instead.
+    document type. The dataset name will be saved in the document metadata.
 
     Args:
         dsets: A list of datasets or a dictionary with dataset names as keys and datasets as
             values. If a dictionary is provided, the dataset names will be added to the documents
             as metadata.
+        clear_metadata: Whether to clear the document metadata before adding the dataset name.
 
     Returns:
         A new dataset that is the concatenation of the input datasets.
     """
     if isinstance(dsets, dict):
         dsets = [
-            dset.map(_add_dset_name_to_document, fn_kwargs={"name": name})
+            dset.map(
+                _add_dset_name_to_document,
+                fn_kwargs={"name": name, "clear_metadata": clear_metadata},
+            )
             for name, dset in dsets.items()
         ]
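Note on the semantics above: an existing metadata["dataset_name"] now wins over the
dict key, and clearing the metadata is opt-in. A minimal self-contained sketch of the
patched behavior, using a stand-in class instead of a real pytorch-ie Document (the
names `Doc` and `add_name` are made up for illustration; not part of the patch):

    class Doc:
        # stand-in for a Document that carries a plain metadata dict
        def __init__(self, metadata):
            self.metadata = metadata

    def add_name(doc, name, clear_metadata):
        # mirrors the patched _add_dset_name_to_document
        if "dataset_name" in doc.metadata:
            name = doc.metadata["dataset_name"]  # keep the existing name
        if clear_metadata:
            doc.metadata = {}
        doc.metadata["dataset_name"] = name
        return doc

    doc = Doc({"dataset_name": "orig", "other": 1})
    add_name(doc, "new", clear_metadata=True)
    assert doc.metadata == {"dataset_name": "orig"}  # name kept, rest cleared

    doc = Doc({"other": 1})
    add_name(doc, "new", clear_metadata=False)
    assert doc.metadata == {"other": 1, "dataset_name": "new"}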
diff --git a/src/pie_datasets/core/dataset_dict.py b/src/pie_datasets/core/dataset_dict.py
index 31b80c49..b46ac63c 100644
--- a/src/pie_datasets/core/dataset_dict.py
+++ b/src/pie_datasets/core/dataset_dict.py
@@ -720,15 +720,16 @@ def load_dataset(*args, **kwargs) -> Union[DatasetDict, Dataset, IterableDataset
 
 
 def concatenate_dataset_dicts(
-    inputs: Dict[str, DatasetDict],
-    split_mappings: Dict[str, Dict[str, str]],
+    inputs: Dict[str, DatasetDict], split_mappings: Dict[str, Dict[str, str]], clear_metadata: bool
 ):
-    """Concatenate the splits of multiple dataset dicts into a single one.
+    """Concatenate the splits of multiple dataset dicts into a single one. The dataset name
+    will be saved in the document metadata.
 
     Args:
         inputs: A mapping from dataset names to dataset dicts that contain the splits to
             concatenate.
         split_mappings: A mapping from target split names to mappings from input dataset names to
             source split names.
+        clear_metadata: Whether to clear the document metadata before adding the dataset name.
 
     Returns:
         A dataset dict with keys in split_names as splits and content from the merged input
         dataset dicts.
@@ -743,7 +744,7 @@ def concatenate_dataset_dicts(
 
     result = DatasetDict(
         {
-            target_split_name: concatenate_datasets(dsets)
+            target_split_name: concatenate_datasets(dsets, clear_metadata=clear_metadata)
             for target_split_name, dsets in input_splits.items()
         }
     )

diff --git a/tests/unit/core/test_dataset.py b/tests/unit/core/test_dataset.py
index 66f2b1e7..76ef9f1f 100644
--- a/tests/unit/core/test_dataset.py
+++ b/tests/unit/core/test_dataset.py
@@ -502,8 +502,12 @@ def _empty_docs():
     assert str(excinfo.value) == "No documents to create dataset from"
 
 
-@pytest.mark.parametrize("as_list", [False, True])
-def test_concatenate_datasets(maybe_iterable_dataset, dataset_with_converter_functions, as_list):
+@pytest.mark.parametrize(
+    "as_list, clear_metadata", [(False, False), (False, True), (True, False), (True, True)]
+)
+def test_concatenate_datasets(
+    maybe_iterable_dataset, dataset_with_converter_functions, as_list, clear_metadata
+):
     # Tests four different cases of concatenation of list/dict of Datasets/IterableDatasets
     if as_list:
         # Test concatenation of list of datasets
@@ -512,11 +516,14 @@ def test_concatenate_datasets(maybe_iterable_dataset, dataset_with_converter_fun
                 maybe_iterable_dataset["train"],
                 maybe_iterable_dataset["validation"],
                 maybe_iterable_dataset["test"],
-            ]
+            ],
+            clear_metadata=clear_metadata,
         )
     else:
         # Test concatenation of dictionary of datasets
-        concatenated_dataset = concatenate_datasets(maybe_iterable_dataset)
+        concatenated_dataset = concatenate_datasets(
+            maybe_iterable_dataset, clear_metadata=clear_metadata
+        )
 
     # Check correct output type
     if isinstance(maybe_iterable_dataset["train"], IterableDataset):
@@ -556,7 +563,7 @@ def test_concatenate_datasets_errors(dataset_with_converter_functions):
     # Test concatenation of empty datasets
     empty_dataset = list[Dataset]()
     with pytest.raises(ValueError) as excinfo:
-        concatenate_datasets(empty_dataset)
+        concatenate_datasets(empty_dataset, clear_metadata=False)
     assert str(excinfo.value) == "No datasets to concatenate"
 
     # Test concatenation of datasets with different document types
@@ -564,7 +571,9 @@ def test_concatenate_datasets_errors(dataset_with_converter_functions):
     dataset_with_converted_doc = dataset_with_converter_functions.to_document_type(
         TestDocumentWithLabel
     )
     with pytest.raises(ValueError) as excinfo:
-        concatenate_datasets([dataset_with_converter_functions, dataset_with_converted_doc])
+        concatenate_datasets(
+            [dataset_with_converter_functions, dataset_with_converted_doc], clear_metadata=False
+        )
     assert str(excinfo.value) == "All datasets must have the same document type to concatenate"
 
 
@@ -573,7 +582,7 @@ def test_add_dset_name_to_document():
     doc = Document()
     assert not hasattr(doc, "metadata")
     with pytest.raises(ValueError) as excinfo:
-        _add_dset_name_to_document(doc, "test")
+        _add_dset_name_to_document(doc, "test", clear_metadata=False)
     assert (
         str(excinfo.value)
         == "Document does not have metadata attribute which required to save the dataset name: Document()"
     )
 
@@ -582,10 +591,9 @@ def test_add_dset_name_to_document():
     # Test adding dataset name to document
     doc.metadata = {}
     assert hasattr(doc, "metadata")
-    _add_dset_name_to_document(doc, "test_dataset_name")
+    _add_dset_name_to_document(doc, "test_dataset_name", clear_metadata=False)
     assert doc.metadata["dataset_name"] == "test_dataset_name"
 
-    # Test document already having dataset_name in metadata
-    with pytest.raises(ValueError) as excinfo:
-        _add_dset_name_to_document(doc, "test")
-    assert str(excinfo.value) == "Document already has a dataset_name attribute: test_dataset_name"
+    # Test document already having dataset_name in metadata keeps the old name
+    _add_dset_name_to_document(doc, "test", clear_metadata=False)
+    assert doc.metadata["dataset_name"] == "test_dataset_name"
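For reference, the call pattern the new test below exercises, written out as a sketch
(`tbga` and `comagc` stand for the DatasetDicts provided by the test fixtures; the
import path assumes the module layout shown in this patch):

    from pie_datasets.core.dataset_dict import concatenate_dataset_dicts

    combined = concatenate_dataset_dicts(
        inputs={"tbga": tbga, "comagc": comagc},
        split_mappings={"train": {"tbga": "train", "comagc": "train"}},
        clear_metadata=True,
    )
    # every document in combined["train"] now carries
    # metadata["dataset_name"] == "tbga" or "comagc"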
"Document already has a dataset_name attribute: test_dataset_name" + # Test document already having dataset_name in metadata keeps the old name + _add_dset_name_to_document(doc, "test", clear_metadata=False) + assert doc.metadata["dataset_name"] == "test_dataset_name" diff --git a/tests/unit/core/test_dataset_dict.py b/tests/unit/core/test_dataset_dict.py index 9813bd31..007749f2 100644 --- a/tests/unit/core/test_dataset_dict.py +++ b/tests/unit/core/test_dataset_dict.py @@ -654,6 +654,7 @@ def test_concatenate_dataset_dicts(tbga_extract, comagc_extract): concatenated_dataset = concatenate_dataset_dicts( inputs={"tbga": tbga_extract, "comagc": comagc_extract}, split_mappings={"train": {"tbga": "train", "comagc": "train"}}, + clear_metadata=True, ) assert len(concatenated_dataset["train"]) == len(tbga_extract["train"]) + len(