Skip to content

Commit

Permalink
Added concatenate_datasets method in src/dataset.py
Browse files Browse the repository at this point in the history
* enables concatenation of multiple pie-datasets

* tests still missing
  • Loading branch information
kai-car committed Aug 8, 2024
1 parent 88f6b77 commit dd873a8
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/pie_datasets/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .builder import ArrowBasedBuilder, GeneratorBasedBuilder
from .dataset import Dataset, IterableDataset
from .dataset import Dataset, IterableDataset, concatenate_datasets
from .dataset_dict import DatasetDict, load_dataset

__all__ = [
Expand All @@ -9,4 +9,5 @@
"IterableDataset",
"DatasetDict",
"load_dataset",
"concatenate_datasets",
]
49 changes: 49 additions & 0 deletions src/pie_datasets/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,3 +588,52 @@ def get_pie_dataset_type(
raise TypeError(
f"the dataset must be of type Dataset or IterableDataset, but is of type {type(hf_dataset)}"
)


def _add_dset_name_to_document(doc: Document, name: str) -> Document:
    """Record the originating dataset name in a document's metadata.

    Args:
        doc: The document to tag. It must expose a ``metadata`` mapping.
        name: The dataset name stored under the ``"dataset_name"`` metadata key.

    Returns:
        The same document instance, modified in place.

    Raises:
        ValueError: If the document has no ``metadata`` attribute, or if its
            metadata already contains a ``"dataset_name"`` entry.
    """
    if not hasattr(doc, "metadata"):
        raise ValueError(
            f"Document does not have metadata attribute which required to save the dataset name: {doc}"
        )
    if "dataset_name" in doc.metadata:
        # Report the existing value via the key that was just checked; looking up
        # a different key would raise KeyError instead of this ValueError.
        raise ValueError(
            f"Document already has a dataset_name attribute: {doc.metadata['dataset_name']}"
        )
    doc.metadata["dataset_name"] = name
    return doc

Check warning on line 603 in src/pie_datasets/core/dataset.py

View check run for this annotation

Codecov / codecov/patch

src/pie_datasets/core/dataset.py#L602-L603

Added lines #L602 - L603 were not covered by tests


def concatenate_datasets(
    dsets: Union[
        List[Dataset], List[IterableDataset], Dict[str, Dataset], Dict[str, IterableDataset]
    ]
) -> Union[Dataset, IterableDataset]:
    """Merge several datasets that share one document type into a single dataset.

    Args:
        dsets: Either a list of datasets, or a dictionary mapping dataset names to
            datasets. For a dictionary, each dataset is first mapped with
            ``_add_dset_name_to_document`` so that every document carries its
            dataset name in its metadata.

    Returns:
        A new dataset, built with the pie dataset type of the first input, that
        holds the concatenation of all inputs.

    Raises:
        ValueError: If no datasets are given, or if the datasets do not all have
            the same document type.
    """
    if isinstance(dsets, dict):
        # Tag every document with the name of the dataset it came from.
        named_dsets = []
        for dset_name, dset in dsets.items():
            named_dsets.append(
                dset.map(_add_dset_name_to_document, fn_kwargs={"name": dset_name})
            )
        dsets = named_dsets

    if len(dsets) == 0:
        raise ValueError("No datasets to concatenate")

    # The first dataset defines the document type all others must match.
    first_dset = dsets[0]
    document_type = first_dset.document_type
    if not all(other.document_type == document_type for other in dsets[1:]):
        raise ValueError("All datasets must have the same document type to concatenate")

    result_hf = datasets.concatenate_datasets(dsets)
    pie_dataset_type = get_pie_dataset_type(first_dset)

    return pie_dataset_type.from_hf_dataset(result_hf, document_type=document_type)

Check warning on line 639 in src/pie_datasets/core/dataset.py

View check run for this annotation

Codecov / codecov/patch

src/pie_datasets/core/dataset.py#L639

Added line #L639 was not covered by tests

0 comments on commit dd873a8

Please sign in to comment.