# NOTE: this text was reconstructed from a whitespace-collapsed git diff. The
# same change set also re-exports ``concatenate_datasets`` from
# ``pie_datasets.core`` (import in ``__init__.py`` plus an ``__all__`` entry).


def _add_dset_name_to_document(doc: Document, name: str) -> Document:
    """Store the name of the originating dataset in the document's metadata.

    The document is modified in place and returned, so the function can be used
    directly as a ``map`` callback.

    Args:
        doc: The document to tag. It must expose a ``metadata`` mapping.
        name: The dataset name to store under the ``"dataset_name"`` key.
    Returns:
        The same document instance, with ``metadata["dataset_name"]`` set.
    Raises:
        ValueError: If the document has no ``metadata`` attribute, or if
            ``"dataset_name"`` is already present in its metadata.
    """
    if not hasattr(doc, "metadata"):
        raise ValueError(
            f"Document does not have metadata attribute which required to save the dataset name: {doc}"
        )
    if "dataset_name" in doc.metadata:
        # Report the existing value under the key that was actually checked;
        # looking up any other key here would raise KeyError and mask this error.
        raise ValueError(
            f"Document already has a dataset_name attribute: {doc.metadata['dataset_name']}"
        )
    doc.metadata["dataset_name"] = name
    return doc


def concatenate_datasets(
    dsets: Union[
        List[Dataset], List[IterableDataset], Dict[str, Dataset], Dict[str, IterableDataset]
    ]
) -> Union[Dataset, IterableDataset]:
    """Concatenate multiple datasets into a single dataset.

    The datasets must have the same document type.

    Args:
        dsets: A list of datasets or a dictionary with dataset names as keys and
            datasets as values. If a dictionary is provided, the dataset names
            will be added to the documents as metadata (key ``"dataset_name"``).
    Returns:
        A new dataset that is the concatenation of the input datasets. Its PIE
        dataset type is derived from the first input dataset.
    Raises:
        ValueError: If no datasets are provided, if the datasets do not all share
            the same document type, or (dictionary input only) if a document
            cannot be tagged with its dataset name.
    """
    # Fail fast on empty input, before any per-document mapping is attempted.
    if not dsets:
        raise ValueError("No datasets to concatenate")

    if isinstance(dsets, dict):
        # Tag every document with the key its dataset was registered under.
        dset_list = [
            dset.map(_add_dset_name_to_document, fn_kwargs={"name": name})
            for name, dset in dsets.items()
        ]
    else:
        dset_list = list(dsets)

    first = dset_list[0]
    document_type = first.document_type
    for dset in dset_list[1:]:
        if dset.document_type != document_type:
            raise ValueError("All datasets must have the same document type to concatenate")

    result_hf = datasets.concatenate_datasets(dset_list)
    pie_dataset_type = get_pie_dataset_type(first)

    return pie_dataset_type.from_hf_dataset(result_hf, document_type=document_type)