diff --git a/src/pie_datasets/core/dataset_dict.py b/src/pie_datasets/core/dataset_dict.py
index b2a6214c..1c3689bf 100644
--- a/src/pie_datasets/core/dataset_dict.py
+++ b/src/pie_datasets/core/dataset_dict.py
@@ -694,6 +694,28 @@ def cast_document_type(
         )
         return result
 
+    def shuffle(self, **kwargs):
+        """Delegate to the parent ``shuffle`` and re-wrap the outcome as a ``DatasetDict``.
+
+        Args:
+            **kwargs: Forwarded unchanged to the parent class's ``shuffle``.
+
+        Returns:
+            The object built by ``DatasetDict.from_hf`` from the parent's result and
+            this instance's ``document_type``, with every split's
+            ``document_converters`` set to those of the same-named split in ``self``.
+        """
+        hf_shuffled = super().shuffle(**kwargs)
+        shuffled = DatasetDict.from_hf(hf_shuffled, document_type=self.document_type)
+
+        # TODO: integrate into DatasetDict.from_hf
+        # from_hf only receives the document type here, so the converters are
+        # re-attached split by split (same objects as on self, not copies).
+        for name, shuffled_split in shuffled.items():
+            shuffled_split.document_converters = self[name].document_converters
+
+        return shuffled
+
 
 def load_dataset(*args, **kwargs) -> Union[DatasetDict, Dataset, IterableDataset]:
     dataset_or_dataset_dict = datasets.load_dataset(*args, **kwargs)