diff --git a/src/pie_datasets/core/__init__.py b/src/pie_datasets/core/__init__.py index 22f4bd7a..54595955 100644 --- a/src/pie_datasets/core/__init__.py +++ b/src/pie_datasets/core/__init__.py @@ -1,6 +1,6 @@ from .builder import ArrowBasedBuilder, GeneratorBasedBuilder from .dataset import Dataset, IterableDataset, concatenate_datasets -from .dataset_dict import DatasetDict, load_dataset +from .dataset_dict import DatasetDict, concatenate_dataset_dicts, load_dataset __all__ = [ "GeneratorBasedBuilder", @@ -10,4 +10,5 @@ "DatasetDict", "load_dataset", "concatenate_datasets", + "concatenate_dataset_dicts", ] diff --git a/src/pie_datasets/core/dataset.py b/src/pie_datasets/core/dataset.py index 60ae5c34..f592cbd4 100644 --- a/src/pie_datasets/core/dataset.py +++ b/src/pie_datasets/core/dataset.py @@ -225,6 +225,13 @@ def dataset_to_document_type( # remove the document converters because they are not valid anymore result.document_converters = {} + # remove features not declared in the target document type + if result.features is not None: + original_field_names = set(result.features) + target_field_names = {field.name for field in document_type.fields()} + remove_field_names = original_field_names - target_field_names + result = result.remove_columns(list(remove_field_names)) + return result @@ -376,7 +383,9 @@ def map( result_document_type: Optional[Type[Document]] = None, ) -> "Dataset": dataset = super().map( - function=decorate_convert_to_dict_of_lists(function) if as_documents else function, + function=decorate_convert_to_dict_of_lists(function) + if as_documents and function is not None + else function, with_indices=with_indices, with_rank=with_rank, input_columns=input_columns, @@ -582,7 +591,7 @@ def map( # type: ignore function=decorate_convert_to_document_and_back( function, document_type=self.document_type, batched=batched ) - if as_documents + if as_documents and function is not None else function, batched=batched, **kwargs, @@ -664,15 +673,16 @@ def 
get_pie_dataset_type(
         )
 
 
-def _add_dset_name_to_document(doc: Document, name: str) -> Document:
+def _add_dset_name_to_document(doc: Document, name: str, clear_metadata: bool = False) -> Document:
     if not hasattr(doc, "metadata"):
         raise ValueError(
             f"Document does not have metadata attribute which required to save the dataset name: {doc}"
         )
+    # An existing dataset name wins over `name`; it also survives clearing the metadata below
     if "dataset_name" in doc.metadata:
-        raise ValueError(
-            f"Document already has a dataset_name attribute: {doc.metadata['dataset_name']}"
-        )
+        name = doc.metadata["dataset_name"]
+    if clear_metadata:
+        doc.metadata = {}
     doc.metadata["dataset_name"] = name
     return doc
 
@@ -680,21 +690,28 @@ def _add_dset_name_to_document(doc: Document, name: str) -> Document:
 def concatenate_datasets(
     dsets: Union[
         List[Dataset], List[IterableDataset], Dict[str, Dataset], Dict[str, IterableDataset]
-    ]
+    ],
+    clear_metadata: bool = False,
 ) -> Union[Dataset, IterableDataset]:
     """Concatenate multiple datasets into a single dataset. The datasets must have the same
-    document type.
+    document type. If a dictionary is passed, the dataset name is saved in the document metadata.
 
     Args:
         dsets: A list of datasets or a dictionary with dataset names as keys and datasets as values. If
             a dictionary is provided, the dataset names will be added to the documents as metadata.
+        clear_metadata: Whether to drop all other metadata entries of each document before
+            concatenating. Only has an effect if a dictionary is passed; an already existing
+            "dataset_name" metadata entry is always kept. Defaults to False.
 
     Returns: A new dataset that is the concatenation of the input datasets.
     """
 
     if isinstance(dsets, dict):
         dsets = [
-            dset.map(_add_dset_name_to_document, fn_kwargs={"name": name})
+            dset.map(
+                _add_dset_name_to_document,
+                fn_kwargs={"name": name, "clear_metadata": clear_metadata},
+            )
             for name, dset in dsets.items()
         ]
 
diff --git a/src/pie_datasets/core/dataset_dict.py b/src/pie_datasets/core/dataset_dict.py
index b2a6214c..b46ac63c 100644
--- a/src/pie_datasets/core/dataset_dict.py
+++ b/src/pie_datasets/core/dataset_dict.py
@@ -19,7 +19,12 @@
 from pytorch_ie.core.document import Document
 from pytorch_ie.utils.hydra import resolve_target, serialize_document_type
 
-from .dataset import Dataset, IterableDataset, get_pie_dataset_type
+from .dataset import (
+    Dataset,
+    IterableDataset,
+    concatenate_datasets,
+    get_pie_dataset_type,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -712,3 +717,41 @@ def load_dataset(*args, **kwargs) -> Union[DatasetDict, Dataset, IterableDataset
             f"expected datasets.load_dataset to return {datasets.DatasetDict}, {datasets.IterableDatasetDict}, "
             f"{Dataset}, or {IterableDataset}, but got {type(dataset_or_dataset_dict)}"
         )
+
+
+def concatenate_dataset_dicts(
+    inputs: Dict[str, DatasetDict],
+    split_mappings: Dict[str, Dict[str, str]],
+    clear_metadata: bool = False,
+) -> DatasetDict:
+    """Concatenate the splits of multiple dataset dicts into a single one. The dataset name is
+    saved in the metadata of each document.
+
+    Args:
+        inputs: A mapping from dataset names to dataset dicts that contain the splits to concatenate.
+        split_mappings: A mapping from target split names to mappings from input dataset names to
+            source split names.
+        clear_metadata: Whether to clear the metadata before concatenating. Defaults to False.
+
+    Returns: A dataset dict with the keys of split_mappings as splits and content from the merged
+        input dataset dicts.
+
+    Raises:
+        KeyError: If split_mappings refers to a dataset name that is missing from inputs.
+    """
+
+    input_splits = {}
+    for target_split_name, mapping in split_mappings.items():
+        input_splits[target_split_name] = {
+            ds_name: inputs[ds_name][source_split_name]
+            for ds_name, source_split_name in mapping.items()
+        }
+
+    result = DatasetDict(
+        {
+            target_split_name: concatenate_datasets(dsets, clear_metadata=clear_metadata)
+            for target_split_name, dsets in input_splits.items()
+        }
+    )
+
+    return result
diff --git a/tests/fixtures/dataset_dict/comagc_extract/metadata.json b/tests/fixtures/dataset_dict/comagc_extract/metadata.json
new file mode 100644
index 00000000..6b8bc1f7
--- /dev/null
+++ b/tests/fixtures/dataset_dict/comagc_extract/metadata.json
@@ -0,0 +1,3 @@
+{
+  "document_type": "pytorch_ie.documents.TextDocumentWithLabeledSpansAndBinaryRelations"
+}
\ No newline at end of file
diff --git a/tests/fixtures/dataset_dict/comagc_extract/train/documents.jsonl b/tests/fixtures/dataset_dict/comagc_extract/train/documents.jsonl
new file mode 100644
index 00000000..770f991a
--- /dev/null
+++ b/tests/fixtures/dataset_dict/comagc_extract/train/documents.jsonl
@@ -0,0 +1,3 @@
+{"text": "Thus, FGF6 is increased in PIN and prostate cancer and can promote the proliferation of the transformed prostatic epithelial cells via paracrine and autocrine mechanisms.", "id": "10945637.s12", "metadata": {"CCS": "normalTOcancer", "CGE": "increased", "IGE": "unchanged", "PT": "causality", "cancer_type": "prostate", "expression_change_keyword_1": {"name": "\nNone\n", "pos": null, "type": null}, "expression_change_keyword_2": {"name": "increased", "pos": [14, 22], "type": "Positive_regulation"}}, "labeled_spans": {"annotations": [{"start": 6, "end": 10, "label": "GENE", "score": 1.0, "_id": -4685428526827816387}, {"start": 35, "end": 50, "label": "CANCER", "score": 1.0, "_id": -611854743241672378}], "predictions": []}, "binary_relations": {"annotations": [{"head": -4685428526827816387, "tail": -611854743241672378, "label": "oncogene", "score": 1.0, "_id": 
-1790325547764256303}], "predictions": []}} +{"text": "Isolation and characterization of the major form of human MUC18 cDNA gene and correlation of MUC18 over-expression in prostate cancer cell lines and tissues with malignant progression.", "id": "11722842.s0", "metadata": {"CCS": "normalTOcancer", "CGE": "increased", "IGE": "unchanged", "PT": "observation", "cancer_type": "prostate", "expression_change_keyword_1": {"name": "over-expression", "pos": [99, 113], "type": "Gene_expression"}, "expression_change_keyword_2": {"name": "over-expression", "pos": [99, 113], "type": "Positive_regulation"}}, "labeled_spans": {"annotations": [{"start": 93, "end": 98, "label": "GENE", "score": 1.0, "_id": -2017777239235151954}, {"start": 118, "end": 133, "label": "CANCER", "score": 1.0, "_id": 4129617449961559606}], "predictions": []}, "binary_relations": {"annotations": [{"head": -2017777239235151954, "tail": 4129617449961559606, "label": "biomarker", "score": 1.0, "_id": 7993340717186791454}], "predictions": []}} +{"text": "We therefore conclude that MUC18 expression is increased during prostate cancer initiation (high grade PIN) and progression to carcinoma, and in metastatic cell lines and metastatic carcinoma.", "id": "11722842.s13", "metadata": {"CCS": "normalTOcancer", "CGE": "increased", "IGE": "unchanged", "PT": "observation", "cancer_type": "prostate", "expression_change_keyword_1": {"name": "expression", "pos": [33, 42], "type": "Gene_expression"}, "expression_change_keyword_2": {"name": "increased", "pos": [47, 55], "type": "Positive_regulation"}}, "labeled_spans": {"annotations": [{"start": 27, "end": 32, "label": "GENE", "score": 1.0, "_id": 5431679980839797458}, {"start": 64, "end": 79, "label": "CANCER", "score": 1.0, "_id": 1650882012654160466}], "predictions": []}, "binary_relations": {"annotations": [{"head": 5431679980839797458, "tail": 1650882012654160466, "label": "biomarker", "score": 1.0, "_id": -6073164971037079930}], "predictions": []}} diff --git 
a/tests/fixtures/dataset_dict/tbga_extract/metadata.json b/tests/fixtures/dataset_dict/tbga_extract/metadata.json new file mode 100644 index 00000000..6b8bc1f7 --- /dev/null +++ b/tests/fixtures/dataset_dict/tbga_extract/metadata.json @@ -0,0 +1,3 @@ +{ + "document_type": "pytorch_ie.documents.TextDocumentWithLabeledSpansAndBinaryRelations" +} \ No newline at end of file diff --git a/tests/fixtures/dataset_dict/tbga_extract/test/documents.jsonl b/tests/fixtures/dataset_dict/tbga_extract/test/documents.jsonl new file mode 100644 index 00000000..37668733 --- /dev/null +++ b/tests/fixtures/dataset_dict/tbga_extract/test/documents.jsonl @@ -0,0 +1,3 @@ +{"text": "In addition, the combined cancer genome expression metaanalysis datasets included PDE11A among the top 1% down-regulated genes in PCa.", "id": null, "metadata": {"entity_ids": ["50940", "C0006826"], "entity_names": ["PDE11A", "Malignant Neoplasms"]}, "labeled_spans": {"annotations": [{"start": 82, "end": 88, "label": "ENTITY", "score": 1.0, "_id": -924809712458378694}, {"start": 26, "end": 32, "label": "ENTITY", "score": 1.0, "_id": -8300559430683946006}], "predictions": []}, "binary_relations": {"annotations": [{"head": -924809712458378694, "tail": -8300559430683946006, "label": "NA", "score": 1.0, "_id": -1873235480272460116}], "predictions": []}} +{"text": "We conclude that the CYGB gene is regulated by both promoter methylation and tumour hypoxia in HNSCC and that increased expression of this gene correlates with clincopathological measures of a tumour's biological aggression.", "id": null, "metadata": {"entity_ids": ["114757", "C0001807"], "entity_names": ["CYGB", "Aggressive behavior"]}, "labeled_spans": {"annotations": [{"start": 21, "end": 30, "label": "ENTITY", "score": 1.0, "_id": 4471756672664549063}, {"start": 213, "end": 223, "label": "ENTITY", "score": 1.0, "_id": -3820234498234956495}], "predictions": []}, "binary_relations": {"annotations": [{"head": 4471756672664549063, "tail": 
-3820234498234956495, "label": "NA", "score": 1.0, "_id": -1529179093863665121}], "predictions": []}} +{"text": "Thus, the role of SIVA in tumorigenesis remains unclear.", "id": null, "metadata": {"entity_ids": ["10572", "C0007621"], "entity_names": ["SIVA1", "Neoplastic Cell Transformation"]}, "labeled_spans": {"annotations": [{"start": 18, "end": 22, "label": "ENTITY", "score": 1.0, "_id": 3174421471102386276}, {"start": 26, "end": 39, "label": "ENTITY", "score": 1.0, "_id": -6496953722761076655}], "predictions": []}, "binary_relations": {"annotations": [{"head": 3174421471102386276, "tail": -6496953722761076655, "label": "NA", "score": 1.0, "_id": 2920545352474864205}], "predictions": []}} diff --git a/tests/fixtures/dataset_dict/tbga_extract/train/documents.jsonl b/tests/fixtures/dataset_dict/tbga_extract/train/documents.jsonl new file mode 100644 index 00000000..2122d32a --- /dev/null +++ b/tests/fixtures/dataset_dict/tbga_extract/train/documents.jsonl @@ -0,0 +1,3 @@ +{"text": "A monocyte chemoattractant protein-1 gene polymorphism is associated with occult ischemia in a high-risk asymptomatic population.", "id": null, "metadata": {"entity_ids": ["6347", "C0231221"], "entity_names": ["CCL2", "Asymptomatic"]}, "labeled_spans": {"annotations": [{"start": 2, "end": 36, "label": "ENTITY", "score": 1.0, "_id": 5426963144202911262}, {"start": 105, "end": 117, "label": "ENTITY", "score": 1.0, "_id": 8375553621315725498}], "predictions": []}, "binary_relations": {"annotations": [{"head": 5426963144202911262, "tail": 8375553621315725498, "label": "NA", "score": 1.0, "_id": 8597812253194613001}], "predictions": []}} +{"text": "This study examined the effects of Her2 blockade on tumor angiogenesis, vascular architecture, and hypoxia in Her2(+) and Her2(-) MCF7 xenograft tumors.", "id": null, "metadata": {"entity_ids": ["2064", "C0242184"], "entity_names": ["ERBB2", "Hypoxia"]}, "labeled_spans": {"annotations": [{"start": 122, "end": 126, "label": "ENTITY", "score": 1.0, 
"_id": 8449701248948288217}, {"start": 99, "end": 106, "label": "ENTITY", "score": 1.0, "_id": -971867574717604855}], "predictions": []}, "binary_relations": {"annotations": [{"head": 8449701248948288217, "tail": -971867574717604855, "label": "NA", "score": 1.0, "_id": -2442696185288775855}], "predictions": []}} +{"text": "Eleven deleterious variants, six nonsense and five missense, were identified in seven genes: four LCA-associated genes (CEP290, IQCB1, NMNAT1, and RPGRIP1), one gene responsible for syndromic LCA (ALMS1), and two IRDs-related genes (CTNNA1 and CYP4V2).", "id": null, "metadata": {"entity_ids": ["80184", "C2931258"], "entity_names": ["CEP290", "Amaurosis congenita of Leber, type 1"]}, "labeled_spans": {"annotations": [{"start": 120, "end": 126, "label": "ENTITY", "score": 1.0, "_id": 3602497405587057427}, {"start": 98, "end": 101, "label": "ENTITY", "score": 1.0, "_id": 2172619519622247379}], "predictions": []}, "binary_relations": {"annotations": [{"head": 3602497405587057427, "tail": 2172619519622247379, "label": "genomic_alterations", "score": 1.0, "_id": 8689688816868215711}], "predictions": []}} diff --git a/tests/unit/core/test_dataset.py b/tests/unit/core/test_dataset.py index 16532823..76ef9f1f 100644 --- a/tests/unit/core/test_dataset.py +++ b/tests/unit/core/test_dataset.py @@ -6,6 +6,7 @@ import numpy import pytest import torch +from pyexpat import features from pytorch_ie import Document from pytorch_ie.annotations import BinaryRelation, Label, LabeledSpan, Span from pytorch_ie.core import AnnotationList, annotation_field @@ -203,9 +204,25 @@ def test_register_document_converter_mapping(dataset_with_converter_mapping): def test_to_document_type_function(dataset_with_converter_functions): + # Features are only available for Dataset type (not for IterableDataset) + if isinstance(dataset_with_converter_functions, Dataset): + assert set(dataset_with_converter_functions.features) == { + "entities", + "relations", + "metadata", + "sentences", 
+ "id", + "text", + } + else: + assert dataset_with_converter_functions.features is None assert dataset_with_converter_functions.document_type == TestDocument converted_dataset = dataset_with_converter_functions.to_document_type(TestDocumentWithLabel) assert converted_dataset.document_type == TestDocumentWithLabel + if isinstance(converted_dataset, Dataset): + assert set(converted_dataset.features) == {"id", "label", "metadata", "text"} + else: + assert converted_dataset.features is None assert len(converted_dataset.document_converters) == 0 for doc in converted_dataset: @@ -485,8 +502,12 @@ def _empty_docs(): assert str(excinfo.value) == "No documents to create dataset from" -@pytest.mark.parametrize("as_list", [False, True]) -def test_concatenate_datasets(maybe_iterable_dataset, dataset_with_converter_functions, as_list): +@pytest.mark.parametrize( + "as_list, clear_metadata", [(False, False), (False, True), (True, False), (True, True)] +) +def test_concatenate_datasets( + maybe_iterable_dataset, dataset_with_converter_functions, as_list, clear_metadata +): # Tests four different cases of concatenation of list/dict of Datasets/IterableDatasets if as_list: # Test concatenation of list of datasets @@ -495,11 +516,14 @@ def test_concatenate_datasets(maybe_iterable_dataset, dataset_with_converter_fun maybe_iterable_dataset["train"], maybe_iterable_dataset["validation"], maybe_iterable_dataset["test"], - ] + ], + clear_metadata=clear_metadata, ) else: # Test concatenation of dictionary of datasets - concatenated_dataset = concatenate_datasets(maybe_iterable_dataset) + concatenated_dataset = concatenate_datasets( + maybe_iterable_dataset, clear_metadata=clear_metadata + ) # Check correct output type if isinstance(maybe_iterable_dataset["train"], IterableDataset): @@ -539,7 +563,7 @@ def test_concatenate_datasets_errors(dataset_with_converter_functions): # Test concatenation of empty datasets empty_dataset = list[Dataset]() with pytest.raises(ValueError) as excinfo: - 
concatenate_datasets(empty_dataset) + concatenate_datasets(empty_dataset, clear_metadata=False) assert str(excinfo.value) == "No datasets to concatenate" # Test concatenation of datasets with different document types @@ -547,7 +571,9 @@ def test_concatenate_datasets_errors(dataset_with_converter_functions): TestDocumentWithLabel ) with pytest.raises(ValueError) as excinfo: - concatenate_datasets([dataset_with_converter_functions, dataset_with_converted_doc]) + concatenate_datasets( + [dataset_with_converter_functions, dataset_with_converted_doc], clear_metadata=False + ) assert str(excinfo.value) == "All datasets must have the same document type to concatenate" @@ -556,7 +582,7 @@ def test_add_dset_name_to_document(): doc = Document() assert not hasattr(doc, "metadata") with pytest.raises(ValueError) as excinfo: - _add_dset_name_to_document(doc, "test") + _add_dset_name_to_document(doc, "test", clear_metadata=False) assert ( str(excinfo.value) == "Document does not have metadata attribute which required to save the dataset name: Document()" @@ -565,10 +591,9 @@ def test_add_dset_name_to_document(): # Test adding dataset name to document doc.metadata = {} assert hasattr(doc, "metadata") - _add_dset_name_to_document(doc, "test_dataset_name") + _add_dset_name_to_document(doc, "test_dataset_name", clear_metadata=False) assert doc.metadata["dataset_name"] == "test_dataset_name" - # Test document already having dataset_name in metadata - with pytest.raises(ValueError) as excinfo: - _add_dset_name_to_document(doc, "test") - assert str(excinfo.value) == "Document already has a dataset_name attribute: test_dataset_name" + # Test document already having dataset_name in metadata keeps the old name + _add_dset_name_to_document(doc, "test", clear_metadata=False) + assert doc.metadata["dataset_name"] == "test_dataset_name" diff --git a/tests/unit/core/test_dataset_dict.py b/tests/unit/core/test_dataset_dict.py index 829fc90e..007749f2 100644 --- 
a/tests/unit/core/test_dataset_dict.py
+++ b/tests/unit/core/test_dataset_dict.py
@@ -9,7 +9,13 @@
 from pytorch_ie.core import AnnotationList, Document, annotation_field
 from pytorch_ie.documents import TextBasedDocument, TextDocument
 
-from pie_datasets import Dataset, DatasetDict, IterableDataset, load_dataset
+from pie_datasets import (
+    Dataset,
+    DatasetDict,
+    IterableDataset,
+    concatenate_dataset_dicts,
+    load_dataset,
+)
 from pie_datasets.core.dataset_dict import (
     EnterDatasetDictMixin,
     EnterDatasetMixin,
@@ -632,3 +638,33 @@ def test_load_dataset_conll2003_wrong_type_single_split():
         ", , "
         "or , but got "
     )
+
+
+@pytest.fixture
+def tbga_extract():
+    """Load the `tbga_extract` fixture directory as a DatasetDict."""
+    return DatasetDict.from_json(data_dir=FIXTURES_ROOT / "dataset_dict" / "tbga_extract")
+
+
+@pytest.fixture
+def comagc_extract():
+    """Load the `comagc_extract` fixture directory as a DatasetDict."""
+    return DatasetDict.from_json(data_dir=FIXTURES_ROOT / "dataset_dict" / "comagc_extract")
+
+
+def test_concatenate_dataset_dicts(tbga_extract, comagc_extract):
+    concatenated_dataset = concatenate_dataset_dicts(
+        inputs={"tbga": tbga_extract, "comagc": comagc_extract},
+        split_mappings={"train": {"tbga": "train", "comagc": "train"}},
+        clear_metadata=True,
+    )
+
+    # only the target splits named in split_mappings are created
+    assert set(concatenated_dataset) == {"train"}
+    assert len(concatenated_dataset["train"]) == len(tbga_extract["train"]) + len(
+        comagc_extract["train"]
+    )
+    # each source dataset contributes exactly its own documents, tagged with its name
+    dataset_names = [doc.metadata["dataset_name"] for doc in concatenated_dataset["train"]]
+    assert dataset_names.count("tbga") == len(tbga_extract["train"])
+    assert dataset_names.count("comagc") == len(comagc_extract["train"])