diff --git a/dataset_builders/pie/drugprot/drugprot.py b/dataset_builders/pie/drugprot/drugprot.py index f07756ac..b8d0abbb 100644 --- a/dataset_builders/pie/drugprot/drugprot.py +++ b/dataset_builders/pie/drugprot/drugprot.py @@ -30,7 +30,7 @@ class DrugprotBigbioDocument(TextBasedDocument): def example2drugprot(example: Dict[str, Any]) -> DrugprotDocument: - metadata = {"entity_ids": []} + metadata = {"entity_ids": [], "relation_ids": []} id2labeled_span: Dict[str, LabeledSpan] = {} document = DrugprotDocument( @@ -40,6 +40,7 @@ def example2drugprot(example: Dict[str, Any]) -> DrugprotDocument: id=example["document_id"], metadata=metadata, ) + for span in example["entities"]: labeled_span = LabeledSpan( start=span["offset"][0], @@ -47,23 +48,30 @@ def example2drugprot(example: Dict[str, Any]) -> DrugprotDocument: label=span["type"], ) document.entities.append(labeled_span) - document.metadata["entity_ids"].append(span["id"]) - id2labeled_span[span["id"]] = labeled_span + entity_id = span["id"].split("_")[1] + document.metadata["entity_ids"].append(entity_id) + id2labeled_span[entity_id] = labeled_span + for relation in example["relations"]: + arg1_id = relation["arg1_id"].split("_")[1] + arg2_id = relation["arg2_id"].split("_")[1] document.relations.append( BinaryRelation( - head=id2labeled_span[relation["arg1_id"]], - tail=id2labeled_span[relation["arg2_id"]], + head=id2labeled_span[arg1_id], + tail=id2labeled_span[arg2_id], label=relation["type"], ) ) + relation_id = "R" + relation["id"].split("_")[1] + document.metadata["relation_ids"].append(relation_id) + return document def example2drugprot_bigbio(example: Dict[str, Any]) -> DrugprotBigbioDocument: text = " ".join([" ".join(passage["text"]) for passage in example["passages"]]) doc_id = example["document_id"] - metadata = {"entity_ids": []} + metadata = {"entity_ids": [], "relation_ids": []} id2labeled_span: Dict[str, LabeledSpan] = {} document = DrugprotBigbioDocument( @@ -79,7 +87,7 @@ def example2drugprot_bigbio(example: Dict[str, Any]) -> DrugprotBigbioDocument: label=passage["type"], ) ) - # We sort labels and relation to always have an deterministic order for testing purposes. + # We sort labels and relation to always have a deterministic order for testing purposes. for span in example["entities"]: labeled_span = LabeledSpan( start=span["offsets"][0][0], @@ -87,19 +95,114 @@ def example2drugprot_bigbio(example: Dict[str, Any]) -> DrugprotBigbioDocument: label=span["type"], ) document.entities.append(labeled_span) - document.metadata["entity_ids"].append(span["id"]) - id2labeled_span[span["id"]] = labeled_span + entity_id = span["id"].split("_")[1] + document.metadata["entity_ids"].append(entity_id) + id2labeled_span[entity_id] = labeled_span + for relation in example["relations"]: + arg1_id = relation["arg1_id"].split("_")[1] + arg2_id = relation["arg2_id"].split("_")[1] document.relations.append( BinaryRelation( - head=id2labeled_span[relation["arg1_id"]], - tail=id2labeled_span[relation["arg2_id"]], + head=id2labeled_span[arg1_id], + tail=id2labeled_span[arg2_id], label=relation["type"], ) ) + relation_id = "R" + relation["id"].split("_")[1] + document.metadata["relation_ids"].append(relation_id) + return document +def drugprot2example(doc: DrugprotDocument) -> Dict[str, Any]: + entities = [] + for i, entity in enumerate(doc.entities): + entities.append( + { + "id": doc.id + "_" + doc.metadata["entity_ids"][i], + "type": entity.label, + "text": doc.text[entity.start : entity.end], + "offset": [entity.start, entity.end], + } + ) + + relations = [] + for i, relation in enumerate(doc.relations): + relations.append( + { + "id": doc.id + "_" + doc.metadata["relation_ids"][i][1:], + "arg1_id": doc.id + + "_" + + doc.metadata["entity_ids"][doc.entities.index(relation.head)], + "arg2_id": doc.id + + "_" + + doc.metadata["entity_ids"][doc.entities.index(relation.tail)], + "type": relation.label, + } + ) + + return { + "document_id": doc.id, + "title": doc.title, + "abstract": doc.abstract, + "text": doc.text, + "entities": entities, + "relations": relations, + } + + +def drugprot_bigbio2example(doc: DrugprotBigbioDocument) -> Dict[str, Any]: + entities = [] + for i, entity in enumerate(doc.entities): + entities.append( + { + "id": doc.id + "_" + doc.metadata["entity_ids"][i], + "normalized": [], + "offsets": [[entity.start, entity.end]], + "type": entity.label, + "text": [doc.text[entity.start : entity.end]], + } + ) + + relations = [] + for i, relation in enumerate(doc.relations): + relations.append( + { + "id": doc.id + "_" + doc.metadata["relation_ids"][i][1:], + "arg1_id": doc.id + + "_" + + doc.metadata["entity_ids"][doc.entities.index(relation.head)], + "arg2_id": doc.id + + "_" + + doc.metadata["entity_ids"][doc.entities.index(relation.tail)], + "normalized": [], + "type": relation.label, + } + ) + + passages = [] + for passage in doc.passages: + passages.append( + { + "id": doc.id + "_" + passage.label, + "text": [doc.text[passage.start : passage.end]], + "offsets": [[passage.start, passage.end]], + "type": passage.label, + } + ) + + return { + "coreferences": [], + "document_id": doc.id, + "entities": entities, + "events": [], + "id": doc.id, + "passages": passages, + "relations": relations, + } + + class Drugprot(GeneratorBasedBuilder): DOCUMENT_TYPES = { "drugprot_source": DrugprotDocument, @@ -144,8 +247,7 @@ def document_converters(self): raise ValueError(f"Unknown dataset name: {self.config.name}") def _generate_document( - self, - example: Dict[str, Any], + self, example: Dict[str, Any], **kwargs ) -> Union[DrugprotDocument, DrugprotBigbioDocument]: if self.config.name == "drugprot_source": return example2drugprot(example) @@ -153,3 +255,13 @@ def _generate_document( return example2drugprot_bigbio(example) else: raise ValueError(f"Unknown dataset config name: {self.config.name}") + + def _generate_example( + self, document: Union[DrugprotDocument, DrugprotBigbioDocument], **kwargs + ) -> Dict[str, Any]: + if isinstance(document, DrugprotBigbioDocument): + return drugprot_bigbio2example(document) + elif isinstance(document, DrugprotDocument): + return drugprot2example(document) + else: + raise ValueError(f"Unknown document type: {type(document)}") diff --git a/tests/dataset_builders/pie/drugprot/test_drugprot.py b/tests/dataset_builders/pie/drugprot/test_drugprot.py index 63f82d83..20e8bac6 100644 --- a/tests/dataset_builders/pie/drugprot/test_drugprot.py +++ b/tests/dataset_builders/pie/drugprot/test_drugprot.py @@ -22,6 +22,7 @@ from tests.dataset_builders.common import PIE_BASE_PATH DATASET_NAME = "drugprot" +BUILDER_CLASS = Drugprot PIE_DATASET_PATH = PIE_BASE_PATH / DATASET_NAME HF_DATASET_PATH = Drugprot.BASE_DATASET_PATH HF_DATASET_REVISION = Drugprot.BASE_DATASET_REVISION @@ -317,28 +318,7 @@ def test_hf_dataset_all(hf_dataset, split): assert len(example["relations"]) >= 0 -@pytest.fixture(scope="module") -def builder(dataset_variant) -> Drugprot: - return Drugprot(config_name=dataset_variant) - - -def test_document_converters(builder, dataset_variant): - if dataset_variant == "drugprot_source": - assert set(builder.document_converters) == {TextDocumentWithLabeledSpansAndBinaryRelations} - elif dataset_variant == "drugprot_bigbio_kb": - assert set(builder.document_converters) == { - TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions - } - else: - raise ValueError(f"Unknown dataset variant: {dataset_variant}") - - -@pytest.fixture(scope="module") -def document(hf_example, builder) -> Union[DrugprotDocument, DrugprotBigbioDocument]: - return builder._generate_document(hf_example) - - -def test_document(document, dataset_variant): +def test_example_to_document(document, dataset_variant): if dataset_variant == "drugprot_source": assert isinstance(document, DrugprotDocument) assert ( @@ -385,22 +365,23 @@ def test_document(document, dataset_variant): ("GENE-Y", "RDH12"), ("GENE-N", "retinol dehydrogenase"), ] - # check entity ids + # check metadata assert document.metadata["entity_ids"] == [ - "17512723_T1", - "17512723_T2", - "17512723_T3", - "17512723_T4", - "17512723_T5", - "17512723_T6", - "17512723_T7", - "17512723_T8", - "17512723_T9", - "17512723_T10", - "17512723_T11", - "17512723_T12", - "17512723_T13", + "T1", + "T2", + "T3", + "T4", + "T5", + "T6", + "T7", + "T8", + "T9", + "T10", + "T11", + "T12", + "T13", ] + assert document.metadata["relation_ids"] == ["R0"] # check the relations assert document.relations.resolve() == [ @@ -408,6 +389,53 @@ def test_document(document, dataset_variant): ] +@pytest.fixture(scope="module") +def builder(dataset_variant) -> BUILDER_CLASS: + return BUILDER_CLASS(config_name=dataset_variant) + + +def test_builder(builder, dataset_variant): + assert builder is not None + assert builder.config_id == dataset_variant + assert builder.dataset_name == "drugprot" + if dataset_variant == "drugprot_source": + assert builder.document_type == DrugprotDocument + elif dataset_variant == "drugprot_bigbio_kb": + assert builder.document_type == DrugprotBigbioDocument + else: + raise ValueError(f"Unknown dataset variant: {dataset_variant}") + + +def test_document_to_example_and_back(document, builder, hf_example): + hf_example_back = builder._generate_example(document) + assert hf_example_back == hf_example + + +@pytest.mark.slow +def test_example_to_document_and_back_all(hf_dataset, builder): + for ds in hf_dataset.values(): + for example in ds: + document = builder._generate_document(example) + example_back = builder._generate_example(document) + assert example_back == example + + +def test_document_converters(builder, dataset_variant): + if dataset_variant == "drugprot_source": + assert set(builder.document_converters) == {TextDocumentWithLabeledSpansAndBinaryRelations} + elif dataset_variant == "drugprot_bigbio_kb": + assert set(builder.document_converters) == { + TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions + } + else: + raise ValueError(f"Unknown dataset variant: {dataset_variant}") + + +@pytest.fixture(scope="module") +def document(hf_example, builder) -> Union[DrugprotDocument, DrugprotBigbioDocument]: + return builder._generate_document(hf_example) + + @pytest.fixture(scope="module") def pie_dataset(dataset_variant) -> DatasetDict: return load_dataset(str(PIE_DATASET_PATH), name=dataset_variant) @@ -445,7 +473,7 @@ def test_converted_pie_dataset(converted_pie_dataset, converted_document_type): @pytest.fixture(scope="module") -def converted_document(converted_pie_dataset) -> Type[TextBasedDocument]: +def converted_document(converted_pie_dataset) -> TextBasedDocument: return converted_pie_dataset["train"][0] @@ -485,22 +513,23 @@ def test_converted_document(converted_document, converted_document_type): ("GENE-Y", "RDH12"), ("GENE-N", "retinol dehydrogenase"), ] - # check entity ids + # check metadata assert converted_document.metadata["entity_ids"] == [ - "17512723_T1", - "17512723_T2", - "17512723_T3", - "17512723_T4", - "17512723_T5", - "17512723_T6", - "17512723_T7", - "17512723_T8", - "17512723_T9", - "17512723_T10", - "17512723_T11", - "17512723_T12", - "17512723_T13", + "T1", + "T2", + "T3", + "T4", + "T5", + "T6", + "T7", + "T8", + "T9", + "T10", + "T11", + "T12", + "T13", ] + assert converted_document.metadata["relation_ids"] == ["R0"] # check the relations assert converted_document.binary_relations.resolve() == [