Skip to content

Commit

Permalink
Implement concatenate_dataset_dicts (#153)
Browse files Browse the repository at this point in the history
* implement concatenate_dataset_dicts

* add tests

* wipe metadata from docs in `concatenate_datasets` + add metadata to test datasets

* add feature check in `test_to_document_type_function`

* Fix `map()` when no function used at all.

* remove features not declared in the target document type

* add parameter `clean_metadata` to `concatenate_datasets` and `concatenate_dataset_dicts`

---------

Co-authored-by: Arne Binder <[email protected]>
  • Loading branch information
RainbowRivey and ArneBinder authored Oct 1, 2024
1 parent 8fc0cc9 commit 34fff5d
Show file tree
Hide file tree
Showing 10 changed files with 149 additions and 24 deletions.
3 changes: 2 additions & 1 deletion src/pie_datasets/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .builder import ArrowBasedBuilder, GeneratorBasedBuilder
from .dataset import Dataset, IterableDataset, concatenate_datasets
from .dataset_dict import DatasetDict, load_dataset
from .dataset_dict import DatasetDict, concatenate_dataset_dicts, load_dataset

__all__ = [
"GeneratorBasedBuilder",
Expand All @@ -10,4 +10,5 @@
"DatasetDict",
"load_dataset",
"concatenate_datasets",
"concatenate_dataset_dicts",
]
33 changes: 24 additions & 9 deletions src/pie_datasets/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,13 @@ def dataset_to_document_type(
# remove the document converters because they are not valid anymore
result.document_converters = {}

# remove features not declared in the target document type
if result.features is not None:
original_field_names = set(result.features)
target_field_names = {field.name for field in document_type.fields()}
remove_field_names = original_field_names - target_field_names
result = result.remove_columns(list(remove_field_names))

return result


Expand Down Expand Up @@ -376,7 +383,9 @@ def map(
result_document_type: Optional[Type[Document]] = None,
) -> "Dataset":
dataset = super().map(
function=decorate_convert_to_dict_of_lists(function) if as_documents else function,
function=decorate_convert_to_dict_of_lists(function)
if as_documents and function is not None
else function,
with_indices=with_indices,
with_rank=with_rank,
input_columns=input_columns,
Expand Down Expand Up @@ -582,7 +591,7 @@ def map( # type: ignore
function=decorate_convert_to_document_and_back(
function, document_type=self.document_type, batched=batched
)
if as_documents
if as_documents and function is not None
else function,
batched=batched,
**kwargs,
Expand Down Expand Up @@ -664,37 +673,43 @@ def get_pie_dataset_type(
)


def _add_dset_name_to_document(doc: Document, name: str) -> Document:
def _add_dset_name_to_document(doc: Document, name: str, clear_metadata: bool) -> Document:
if not hasattr(doc, "metadata"):
raise ValueError(
f"Document does not have metadata attribute which required to save the dataset name: {doc}"
)
# Keep the old name if available
if "dataset_name" in doc.metadata:
raise ValueError(
f"Document already has a dataset_name attribute: {doc.metadata['dataset_name']}"
)
name = doc.metadata["dataset_name"]
if clear_metadata:
doc.metadata = {}
doc.metadata["dataset_name"] = name
return doc


def concatenate_datasets(
dsets: Union[
List[Dataset], List[IterableDataset], Dict[str, Dataset], Dict[str, IterableDataset]
]
],
clear_metadata: bool,
) -> Union[Dataset, IterableDataset]:
"""Concatenate multiple datasets into a single dataset. The datasets must have the same
document type.
document type. Dataset name will be saved in Metadata.
Args:
dsets: A list of datasets or a dictionary with dataset names as keys and datasets as values. If
a dictionary is provided, the dataset names will be added to the documents as metadata.
clear_metadata: Whether to clear the metadata before concatenating.
Returns:
A new dataset that is the concatenation of the input datasets.
"""

if isinstance(dsets, dict):
dsets = [
dset.map(_add_dset_name_to_document, fn_kwargs={"name": name})
dset.map(
_add_dset_name_to_document,
fn_kwargs={"name": name, "clear_metadata": clear_metadata},
)
for name, dset in dsets.items()
]

Expand Down
40 changes: 39 additions & 1 deletion src/pie_datasets/core/dataset_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,12 @@
from pytorch_ie.core.document import Document
from pytorch_ie.utils.hydra import resolve_target, serialize_document_type

from .dataset import Dataset, IterableDataset, get_pie_dataset_type
from .dataset import (
Dataset,
IterableDataset,
concatenate_datasets,
get_pie_dataset_type,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -729,3 +734,36 @@ def load_dataset(*args, **kwargs) -> Union[DatasetDict, Dataset, IterableDataset
f"expected datasets.load_dataset to return {datasets.DatasetDict}, {datasets.IterableDatasetDict}, "
f"{Dataset}, or {IterableDataset}, but got {type(dataset_or_dataset_dict)}"
)


def concatenate_dataset_dicts(
inputs: Dict[str, DatasetDict], split_mappings: Dict[str, Dict[str, str]], clear_metadata: bool
):
"""Concatenate the splits of multiple dataset dicts into a single one. Dataset name will be
saved in Metadata.
Args:
inputs: A mapping from dataset names to dataset dicts that contain the splits to concatenate.
split_mappings: A mapping from target split names to mappings from input dataset names to
source split names.
clear_metadata: Whether to clear the metadata before concatenating.
Returns: A dataset dict with keys in split_names as splits and content from the merged input
dataset dicts.
"""

input_splits = {}
for target_split_name, mapping in split_mappings.items():
input_splits[target_split_name] = {
ds_name: inputs[ds_name][source_split_name]
for ds_name, source_split_name in mapping.items()
}

result = DatasetDict(
{
target_split_name: concatenate_datasets(dsets, clear_metadata=clear_metadata)
for target_split_name, dsets in input_splits.items()
}
)

return result
3 changes: 3 additions & 0 deletions tests/fixtures/dataset_dict/comagc_extract/metadata.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
"document_type": "pytorch_ie.documents.TextDocumentWithLabeledSpansAndBinaryRelations"
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{"text": "Thus, FGF6 is increased in PIN and prostate cancer and can promote the proliferation of the transformed prostatic epithelial cells via paracrine and autocrine mechanisms.", "id": "10945637.s12", "metadata": {"CCS": "normalTOcancer", "CGE": "increased", "IGE": "unchanged", "PT": "causality", "cancer_type": "prostate", "expression_change_keyword_1": {"name": "\nNone\n", "pos": null, "type": null}, "expression_change_keyword_2": {"name": "increased", "pos": [14, 22], "type": "Positive_regulation"}}, "labeled_spans": {"annotations": [{"start": 6, "end": 10, "label": "GENE", "score": 1.0, "_id": -4685428526827816387}, {"start": 35, "end": 50, "label": "CANCER", "score": 1.0, "_id": -611854743241672378}], "predictions": []}, "binary_relations": {"annotations": [{"head": -4685428526827816387, "tail": -611854743241672378, "label": "oncogene", "score": 1.0, "_id": -1790325547764256303}], "predictions": []}}
{"text": "Isolation and characterization of the major form of human MUC18 cDNA gene and correlation of MUC18 over-expression in prostate cancer cell lines and tissues with malignant progression.", "id": "11722842.s0", "metadata": {"CCS": "normalTOcancer", "CGE": "increased", "IGE": "unchanged", "PT": "observation", "cancer_type": "prostate", "expression_change_keyword_1": {"name": "over-expression", "pos": [99, 113], "type": "Gene_expression"}, "expression_change_keyword_2": {"name": "over-expression", "pos": [99, 113], "type": "Positive_regulation"}}, "labeled_spans": {"annotations": [{"start": 93, "end": 98, "label": "GENE", "score": 1.0, "_id": -2017777239235151954}, {"start": 118, "end": 133, "label": "CANCER", "score": 1.0, "_id": 4129617449961559606}], "predictions": []}, "binary_relations": {"annotations": [{"head": -2017777239235151954, "tail": 4129617449961559606, "label": "biomarker", "score": 1.0, "_id": 7993340717186791454}], "predictions": []}}
{"text": "We therefore conclude that MUC18 expression is increased during prostate cancer initiation (high grade PIN) and progression to carcinoma, and in metastatic cell lines and metastatic carcinoma.", "id": "11722842.s13", "metadata": {"CCS": "normalTOcancer", "CGE": "increased", "IGE": "unchanged", "PT": "observation", "cancer_type": "prostate", "expression_change_keyword_1": {"name": "expression", "pos": [33, 42], "type": "Gene_expression"}, "expression_change_keyword_2": {"name": "increased", "pos": [47, 55], "type": "Positive_regulation"}}, "labeled_spans": {"annotations": [{"start": 27, "end": 32, "label": "GENE", "score": 1.0, "_id": 5431679980839797458}, {"start": 64, "end": 79, "label": "CANCER", "score": 1.0, "_id": 1650882012654160466}], "predictions": []}, "binary_relations": {"annotations": [{"head": 5431679980839797458, "tail": 1650882012654160466, "label": "biomarker", "score": 1.0, "_id": -6073164971037079930}], "predictions": []}}
3 changes: 3 additions & 0 deletions tests/fixtures/dataset_dict/tbga_extract/metadata.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
"document_type": "pytorch_ie.documents.TextDocumentWithLabeledSpansAndBinaryRelations"
}
3 changes: 3 additions & 0 deletions tests/fixtures/dataset_dict/tbga_extract/test/documents.jsonl
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{"text": "In addition, the combined cancer genome expression metaanalysis datasets included PDE11A among the top 1% down-regulated genes in PCa.", "id": null, "metadata": {"entity_ids": ["50940", "C0006826"], "entity_names": ["PDE11A", "Malignant Neoplasms"]}, "labeled_spans": {"annotations": [{"start": 82, "end": 88, "label": "ENTITY", "score": 1.0, "_id": -924809712458378694}, {"start": 26, "end": 32, "label": "ENTITY", "score": 1.0, "_id": -8300559430683946006}], "predictions": []}, "binary_relations": {"annotations": [{"head": -924809712458378694, "tail": -8300559430683946006, "label": "NA", "score": 1.0, "_id": -1873235480272460116}], "predictions": []}}
{"text": "We conclude that the CYGB gene is regulated by both promoter methylation and tumour hypoxia in HNSCC and that increased expression of this gene correlates with clincopathological measures of a tumour's biological aggression.", "id": null, "metadata": {"entity_ids": ["114757", "C0001807"], "entity_names": ["CYGB", "Aggressive behavior"]}, "labeled_spans": {"annotations": [{"start": 21, "end": 30, "label": "ENTITY", "score": 1.0, "_id": 4471756672664549063}, {"start": 213, "end": 223, "label": "ENTITY", "score": 1.0, "_id": -3820234498234956495}], "predictions": []}, "binary_relations": {"annotations": [{"head": 4471756672664549063, "tail": -3820234498234956495, "label": "NA", "score": 1.0, "_id": -1529179093863665121}], "predictions": []}}
{"text": "Thus, the role of SIVA in tumorigenesis remains unclear.", "id": null, "metadata": {"entity_ids": ["10572", "C0007621"], "entity_names": ["SIVA1", "Neoplastic Cell Transformation"]}, "labeled_spans": {"annotations": [{"start": 18, "end": 22, "label": "ENTITY", "score": 1.0, "_id": 3174421471102386276}, {"start": 26, "end": 39, "label": "ENTITY", "score": 1.0, "_id": -6496953722761076655}], "predictions": []}, "binary_relations": {"annotations": [{"head": 3174421471102386276, "tail": -6496953722761076655, "label": "NA", "score": 1.0, "_id": 2920545352474864205}], "predictions": []}}
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{"text": "A monocyte chemoattractant protein-1 gene polymorphism is associated with occult ischemia in a high-risk asymptomatic population.", "id": null, "metadata": {"entity_ids": ["6347", "C0231221"], "entity_names": ["CCL2", "Asymptomatic"]}, "labeled_spans": {"annotations": [{"start": 2, "end": 36, "label": "ENTITY", "score": 1.0, "_id": 5426963144202911262}, {"start": 105, "end": 117, "label": "ENTITY", "score": 1.0, "_id": 8375553621315725498}], "predictions": []}, "binary_relations": {"annotations": [{"head": 5426963144202911262, "tail": 8375553621315725498, "label": "NA", "score": 1.0, "_id": 8597812253194613001}], "predictions": []}}
{"text": "This study examined the effects of Her2 blockade on tumor angiogenesis, vascular architecture, and hypoxia in Her2(+) and Her2(-) MCF7 xenograft tumors.", "id": null, "metadata": {"entity_ids": ["2064", "C0242184"], "entity_names": ["ERBB2", "Hypoxia"]}, "labeled_spans": {"annotations": [{"start": 122, "end": 126, "label": "ENTITY", "score": 1.0, "_id": 8449701248948288217}, {"start": 99, "end": 106, "label": "ENTITY", "score": 1.0, "_id": -971867574717604855}], "predictions": []}, "binary_relations": {"annotations": [{"head": 8449701248948288217, "tail": -971867574717604855, "label": "NA", "score": 1.0, "_id": -2442696185288775855}], "predictions": []}}
{"text": "Eleven deleterious variants, six nonsense and five missense, were identified in seven genes: four LCA-associated genes (CEP290, IQCB1, NMNAT1, and RPGRIP1), one gene responsible for syndromic LCA (ALMS1), and two IRDs-related genes (CTNNA1 and CYP4V2).", "id": null, "metadata": {"entity_ids": ["80184", "C2931258"], "entity_names": ["CEP290", "Amaurosis congenita of Leber, type 1"]}, "labeled_spans": {"annotations": [{"start": 120, "end": 126, "label": "ENTITY", "score": 1.0, "_id": 3602497405587057427}, {"start": 98, "end": 101, "label": "ENTITY", "score": 1.0, "_id": 2172619519622247379}], "predictions": []}, "binary_relations": {"annotations": [{"head": 3602497405587057427, "tail": 2172619519622247379, "label": "genomic_alterations", "score": 1.0, "_id": 8689688816868215711}], "predictions": []}}
49 changes: 37 additions & 12 deletions tests/unit/core/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpy
import pytest
import torch
from pyexpat import features
from pytorch_ie import Document
from pytorch_ie.annotations import BinaryRelation, Label, LabeledSpan, Span
from pytorch_ie.core import AnnotationList, annotation_field
Expand Down Expand Up @@ -203,9 +204,25 @@ def test_register_document_converter_mapping(dataset_with_converter_mapping):


def test_to_document_type_function(dataset_with_converter_functions):
# Features are only available for Dataset type (not for IterableDataset)
if isinstance(dataset_with_converter_functions, Dataset):
assert set(dataset_with_converter_functions.features) == {
"entities",
"relations",
"metadata",
"sentences",
"id",
"text",
}
else:
assert dataset_with_converter_functions.features is None
assert dataset_with_converter_functions.document_type == TestDocument
converted_dataset = dataset_with_converter_functions.to_document_type(TestDocumentWithLabel)
assert converted_dataset.document_type == TestDocumentWithLabel
if isinstance(converted_dataset, Dataset):
assert set(converted_dataset.features) == {"id", "label", "metadata", "text"}
else:
assert converted_dataset.features is None

assert len(converted_dataset.document_converters) == 0
for doc in converted_dataset:
Expand Down Expand Up @@ -485,8 +502,12 @@ def _empty_docs():
assert str(excinfo.value) == "No documents to create dataset from"


@pytest.mark.parametrize("as_list", [False, True])
def test_concatenate_datasets(maybe_iterable_dataset, dataset_with_converter_functions, as_list):
@pytest.mark.parametrize(
"as_list, clear_metadata", [(False, False), (False, True), (True, False), (True, True)]
)
def test_concatenate_datasets(
maybe_iterable_dataset, dataset_with_converter_functions, as_list, clear_metadata
):
# Tests four different cases of concatenation of list/dict of Datasets/IterableDatasets
if as_list:
# Test concatenation of list of datasets
Expand All @@ -495,11 +516,14 @@ def test_concatenate_datasets(maybe_iterable_dataset, dataset_with_converter_fun
maybe_iterable_dataset["train"],
maybe_iterable_dataset["validation"],
maybe_iterable_dataset["test"],
]
],
clear_metadata=clear_metadata,
)
else:
# Test concatenation of dictionary of datasets
concatenated_dataset = concatenate_datasets(maybe_iterable_dataset)
concatenated_dataset = concatenate_datasets(
maybe_iterable_dataset, clear_metadata=clear_metadata
)

# Check correct output type
if isinstance(maybe_iterable_dataset["train"], IterableDataset):
Expand Down Expand Up @@ -539,15 +563,17 @@ def test_concatenate_datasets_errors(dataset_with_converter_functions):
# Test concatenation of empty datasets
empty_dataset = list[Dataset]()
with pytest.raises(ValueError) as excinfo:
concatenate_datasets(empty_dataset)
concatenate_datasets(empty_dataset, clear_metadata=False)
assert str(excinfo.value) == "No datasets to concatenate"

# Test concatenation of datasets with different document types
dataset_with_converted_doc = dataset_with_converter_functions.to_document_type(
TestDocumentWithLabel
)
with pytest.raises(ValueError) as excinfo:
concatenate_datasets([dataset_with_converter_functions, dataset_with_converted_doc])
concatenate_datasets(
[dataset_with_converter_functions, dataset_with_converted_doc], clear_metadata=False
)
assert str(excinfo.value) == "All datasets must have the same document type to concatenate"


Expand All @@ -556,7 +582,7 @@ def test_add_dset_name_to_document():
doc = Document()
assert not hasattr(doc, "metadata")
with pytest.raises(ValueError) as excinfo:
_add_dset_name_to_document(doc, "test")
_add_dset_name_to_document(doc, "test", clear_metadata=False)
assert (
str(excinfo.value)
== "Document does not have metadata attribute which required to save the dataset name: Document()"
Expand All @@ -565,10 +591,9 @@ def test_add_dset_name_to_document():
# Test adding dataset name to document
doc.metadata = {}
assert hasattr(doc, "metadata")
_add_dset_name_to_document(doc, "test_dataset_name")
_add_dset_name_to_document(doc, "test_dataset_name", clear_metadata=False)
assert doc.metadata["dataset_name"] == "test_dataset_name"

# Test document already having dataset_name in metadata
with pytest.raises(ValueError) as excinfo:
_add_dset_name_to_document(doc, "test")
assert str(excinfo.value) == "Document already has a dataset_name attribute: test_dataset_name"
# Test document already having dataset_name in metadata keeps the old name
_add_dset_name_to_document(doc, "test", clear_metadata=False)
assert doc.metadata["dataset_name"] == "test_dataset_name"
33 changes: 32 additions & 1 deletion tests/unit/core/test_dataset_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,13 @@
from pytorch_ie.core import AnnotationList, Document, annotation_field
from pytorch_ie.documents import TextBasedDocument, TextDocument

from pie_datasets import Dataset, DatasetDict, IterableDataset, load_dataset
from pie_datasets import (
Dataset,
DatasetDict,
IterableDataset,
concatenate_dataset_dicts,
load_dataset,
)
from pie_datasets.core.dataset_dict import (
EnterDatasetDictMixin,
EnterDatasetMixin,
Expand Down Expand Up @@ -693,3 +699,28 @@ def test_load_dataset_conll2003_wrong_type_single_split():
"<class 'datasets.dataset_dict.IterableDatasetDict'>, <class 'pie_datasets.core.dataset.Dataset'>, "
"or <class 'pie_datasets.core.dataset.IterableDataset'>, but got <class 'datasets.arrow_dataset.Dataset'>"
)


@pytest.fixture
def tbga_extract():
return DatasetDict.from_json(data_dir=FIXTURES_ROOT / "dataset_dict" / "tbga_extract")


@pytest.fixture
def comagc_extract():
return DatasetDict.from_json(data_dir=FIXTURES_ROOT / "dataset_dict" / "comagc_extract")


def test_concatenate_dataset_dicts(tbga_extract, comagc_extract):
concatenated_dataset = concatenate_dataset_dicts(
inputs={"tbga": tbga_extract, "comagc": comagc_extract},
split_mappings={"train": {"tbga": "train", "comagc": "train"}},
clear_metadata=True,
)

assert len(concatenated_dataset["train"]) == len(tbga_extract["train"]) + len(
comagc_extract["train"]
)
assert all(
[ds.metadata["dataset_name"] in ["tbga", "comagc"] for ds in concatenated_dataset["train"]]
)

0 comments on commit 34fff5d

Please sign in to comment.