Added concatenate_datasets function in src/pie_datasets/core/dataset.py
* enables concatenation of multiple pie-datasets

* tests still missing
kai-car committed Aug 7, 2024
1 parent 522742d commit 0a288bb
Showing 2 changed files with 51 additions and 1 deletion.
src/pie_datasets/core/__init__.py (3 changes: 2 additions & 1 deletion)
@@ -1,5 +1,5 @@
from .builder import ArrowBasedBuilder, GeneratorBasedBuilder
-from .dataset import Dataset, IterableDataset
+from .dataset import Dataset, IterableDataset, concatenate_datasets
from .dataset_dict import DatasetDict, load_dataset

__all__ = [
@@ -9,4 +9,5 @@
"IterableDataset",
"DatasetDict",
"load_dataset",
"concatenate_datasets",
]
src/pie_datasets/core/dataset.py (49 changes: 49 additions & 0 deletions)
@@ -588,3 +588,52 @@ def get_pie_dataset_type(
        raise TypeError(
            f"the dataset must be of type Dataset or IterableDataset, but is of type {type(hf_dataset)}"
        )


def _add_dset_name_to_document(doc: Document, name: str) -> Document:
    if not hasattr(doc, "metadata"):
        raise ValueError(
            f"Document does not have a metadata attribute, which is required to save the dataset name: {doc}"
        )
    if "dataset_name" in doc.metadata:
        raise ValueError(
            f"Document already has a dataset_name entry in its metadata: {doc.metadata['dataset_name']}"
        )
    doc.metadata["dataset_name"] = name
    return doc


def concatenate_datasets(
    dsets: Union[
        List[Dataset], List[IterableDataset], Dict[str, Dataset], Dict[str, IterableDataset]
    ]
) -> Union[Dataset, IterableDataset]:
    """Concatenate multiple datasets into a single dataset. The datasets must have the same
    document type.

    Args:
        dsets: A list of datasets or a dictionary with dataset names as keys and datasets as
            values. If a dictionary is provided, the dataset names will be added to the
            documents as metadata.

    Returns:
        A new dataset that is the concatenation of the input datasets.
    """

    if isinstance(dsets, dict):
        dsets = [
            dset.map(_add_dset_name_to_document, fn_kwargs={"name": name})
            for name, dset in dsets.items()
        ]

    if len(dsets) == 0:
        raise ValueError("No datasets to concatenate")

    document_type = dsets[0].document_type
    for dset in dsets[1:]:
        if dset.document_type != document_type:
            raise ValueError("All datasets must have the same document type to concatenate")

    result_hf = datasets.concatenate_datasets(dsets)
    pie_dataset_type = get_pie_dataset_type(dsets[0])

    return pie_dataset_type.from_hf_dataset(result_hf, document_type=document_type)
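
For context, here is a minimal usage sketch of the new function; it is not part of the commit. The dataset names and the split key are placeholders, and the sketch assumes that load_dataset from pie_datasets.core returns a DatasetDict that is indexable by split name and whose splits share the same document type.

# Usage sketch only; the dataset names below are placeholders, not real Hub repos.
from pie_datasets.core import concatenate_datasets, load_dataset

# Load two pie-datasets that are assumed to share the same document type.
train_a = load_dataset("org/pie_dataset_a")["train"]
train_b = load_dataset("org/pie_dataset_b")["train"]

# List input: plain concatenation, no extra metadata is written.
combined = concatenate_datasets([train_a, train_b])

# Dict input: each document additionally gets metadata["dataset_name"] set to
# the key of the dataset it came from ("a" or "b").
combined_named = concatenate_datasets({"a": train_a, "b": train_b})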
