Skip to content

Commit

Permalink
fix #149 drugprot polishment (#150)
Browse files Browse the repository at this point in the history
* Introduced method header for _generate_example method in drugprot.py

* Modified test_drugprot.py

* restructured test methods to match other test files' structure and especially test_chemprot.py

* added test_document_to_example() method

* Added drugprot2example() method, enabling conversion from DrugprotDocument back to Example format

* Added drugprot_bigbio2example() method, enabling conversion from DrugprotBigBioDocument back to Example format

* Pre-Commit forgotten

* Connected full-cycle test of document conversion and back

* test method test_example_to_document_and_back_all

* Modified full cycle test

* dataset_variant irrelevant

* test every split

* Modified drugprot.py to align with BRAT format

* split the document-id prefix off the entity ids and relation ids stored in metadata in the example-to-document methods

* adjusted document to example methods to fit

* adjusted related test methods to align with new format

* minor improvements

---------

Co-authored-by: Arne Binder <[email protected]>
  • Loading branch information
kai-car and ArneBinder authored Aug 20, 2024
1 parent 88f6b77 commit b5f08f4
Show file tree
Hide file tree
Showing 2 changed files with 205 additions and 64 deletions.
138 changes: 125 additions & 13 deletions dataset_builders/pie/drugprot/drugprot.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class DrugprotBigbioDocument(TextBasedDocument):


def example2drugprot(example: Dict[str, Any]) -> DrugprotDocument:
metadata = {"entity_ids": []}
metadata = {"entity_ids": [], "relation_ids": []}
id2labeled_span: Dict[str, LabeledSpan] = {}

document = DrugprotDocument(
Expand All @@ -40,30 +40,38 @@ def example2drugprot(example: Dict[str, Any]) -> DrugprotDocument:
id=example["document_id"],
metadata=metadata,
)

for span in example["entities"]:
labeled_span = LabeledSpan(
start=span["offset"][0],
end=span["offset"][1],
label=span["type"],
)
document.entities.append(labeled_span)
document.metadata["entity_ids"].append(span["id"])
id2labeled_span[span["id"]] = labeled_span
entity_id = span["id"].split("_")[1]
document.metadata["entity_ids"].append(entity_id)
id2labeled_span[entity_id] = labeled_span

for relation in example["relations"]:
arg1_id = relation["arg1_id"].split("_")[1]
arg2_id = relation["arg2_id"].split("_")[1]
document.relations.append(
BinaryRelation(
head=id2labeled_span[relation["arg1_id"]],
tail=id2labeled_span[relation["arg2_id"]],
head=id2labeled_span[arg1_id],
tail=id2labeled_span[arg2_id],
label=relation["type"],
)
)
relation_id = "R" + relation["id"].split("_")[1]
document.metadata["relation_ids"].append(relation_id)

return document


def example2drugprot_bigbio(example: Dict[str, Any]) -> DrugprotBigbioDocument:
text = " ".join([" ".join(passage["text"]) for passage in example["passages"]])
doc_id = example["document_id"]
metadata = {"entity_ids": []}
metadata = {"entity_ids": [], "relation_ids": []}
id2labeled_span: Dict[str, LabeledSpan] = {}

document = DrugprotBigbioDocument(
Expand All @@ -79,27 +87,122 @@ def example2drugprot_bigbio(example: Dict[str, Any]) -> DrugprotBigbioDocument:
label=passage["type"],
)
)
# We sort labels and relation to always have an deterministic order for testing purposes.
# We sort labels and relation to always have a deterministic order for testing purposes.
for span in example["entities"]:
labeled_span = LabeledSpan(
start=span["offsets"][0][0],
end=span["offsets"][0][1],
label=span["type"],
)
document.entities.append(labeled_span)
document.metadata["entity_ids"].append(span["id"])
id2labeled_span[span["id"]] = labeled_span
entity_id = span["id"].split("_")[1]
document.metadata["entity_ids"].append(entity_id)
id2labeled_span[entity_id] = labeled_span

for relation in example["relations"]:
arg1_id = relation["arg1_id"].split("_")[1]
arg2_id = relation["arg2_id"].split("_")[1]
document.relations.append(
BinaryRelation(
head=id2labeled_span[relation["arg1_id"]],
tail=id2labeled_span[relation["arg2_id"]],
head=id2labeled_span[arg1_id],
tail=id2labeled_span[arg2_id],
label=relation["type"],
)
)
relation_id = "R" + relation["id"].split("_")[1]
document.metadata["relation_ids"].append(relation_id)

return document


def drugprot2example(doc: DrugprotDocument) -> Dict[str, Any]:
    """Convert a DrugprotDocument back into a source-format example dict.

    Ids are rebuilt from ``doc.metadata``: the i-th entry of
    ``metadata["entity_ids"]`` / ``metadata["relation_ids"]`` is taken to belong
    to the i-th entity / relation, and each id is prefixed with ``doc.id + "_"``.
    The first character of every stored relation id is dropped before
    prefixing (``example2drugprot`` stores relation ids as ``"R" + suffix``).

    Args:
        doc: The document to convert. Its ``id``, ``title``, ``abstract``,
            ``text``, ``entities``, ``relations`` and ``metadata`` are read.

    Returns:
        A dict with the keys ``document_id``, ``title``, ``abstract``, ``text``,
        ``entities`` and ``relations``.

    Raises:
        ValueError: If a relation argument is not contained in ``doc.entities``.
        IndexError: If a metadata id list is shorter than its annotation layer.
    """
    spans = list(doc.entities)

    def full_entity_id(span) -> str:
        # Prefer an identity match: two entities with the same offsets and
        # label compare equal, so an equality-only lookup would always yield
        # the id of the first of them, even for a relation on the second.
        position = next(
            (idx for idx, candidate in enumerate(spans) if candidate is span),
            None,
        )
        if position is None:
            # Fall back to equality (raises ValueError if there is no match).
            position = spans.index(span)
        return doc.id + "_" + doc.metadata["entity_ids"][position]

    entities = [
        {
            "id": doc.id + "_" + doc.metadata["entity_ids"][i],
            "type": span.label,
            "text": doc.text[span.start : span.end],
            "offset": [span.start, span.end],
        }
        for i, span in enumerate(spans)
    ]

    relations = [
        {
            # [1:] strips the leading "R" added when the document was created.
            "id": doc.id + "_" + doc.metadata["relation_ids"][i][1:],
            "arg1_id": full_entity_id(relation.head),
            "arg2_id": full_entity_id(relation.tail),
            "type": relation.label,
        }
        for i, relation in enumerate(doc.relations)
    ]

    return {
        "document_id": doc.id,
        "title": doc.title,
        "abstract": doc.abstract,
        "text": doc.text,
        "entities": entities,
        "relations": relations,
    }


def drugprot_bigbio2example(doc: DrugprotBigbioDocument) -> Dict[str, Any]:
    """Convert a DrugprotBigbioDocument back into a BigBio-KB-style example dict.

    Ids are rebuilt from ``doc.metadata``: the i-th entry of
    ``metadata["entity_ids"]`` / ``metadata["relation_ids"]`` is taken to belong
    to the i-th entity / relation, and each id is prefixed with ``doc.id + "_"``.
    The first character of every stored relation id is dropped before
    prefixing (``example2drugprot_bigbio`` stores them as ``"R" + suffix``).
    Passage ids are built as ``doc.id + "_" + passage.label``.

    Args:
        doc: The document to convert. Its ``id``, ``text``, ``entities``,
            ``relations``, ``passages`` and ``metadata`` are read.

    Returns:
        A dict with the keys ``coreferences``, ``document_id``, ``entities``,
        ``events``, ``id``, ``passages`` and ``relations``. ``coreferences``,
        ``events`` and all ``normalized`` fields are always empty lists.

    Raises:
        ValueError: If a relation argument is not contained in ``doc.entities``.
        IndexError: If a metadata id list is shorter than its annotation layer.
    """
    spans = list(doc.entities)

    def full_entity_id(span) -> str:
        # Prefer an identity match: two entities with the same offsets and
        # label compare equal, so an equality-only lookup would always yield
        # the id of the first of them, even for a relation on the second.
        position = next(
            (idx for idx, candidate in enumerate(spans) if candidate is span),
            None,
        )
        if position is None:
            # Fall back to equality (raises ValueError if there is no match).
            position = spans.index(span)
        return doc.id + "_" + doc.metadata["entity_ids"][position]

    entities = [
        {
            "id": doc.id + "_" + doc.metadata["entity_ids"][i],
            "normalized": [],
            "offsets": [[span.start, span.end]],
            "type": span.label,
            "text": [doc.text[span.start : span.end]],
        }
        for i, span in enumerate(spans)
    ]

    relations = [
        {
            # [1:] strips the leading "R" added when the document was created.
            "id": doc.id + "_" + doc.metadata["relation_ids"][i][1:],
            "arg1_id": full_entity_id(relation.head),
            "arg2_id": full_entity_id(relation.tail),
            "normalized": [],
            "type": relation.label,
        }
        for i, relation in enumerate(doc.relations)
    ]

    passages = [
        {
            "id": doc.id + "_" + passage.label,
            "text": [doc.text[passage.start : passage.end]],
            "offsets": [[passage.start, passage.end]],
            "type": passage.label,
        }
        for passage in doc.passages
    ]

    return {
        "coreferences": [],
        "document_id": doc.id,
        "entities": entities,
        "events": [],
        "id": doc.id,
        "passages": passages,
        "relations": relations,
    }


class Drugprot(GeneratorBasedBuilder):
DOCUMENT_TYPES = {
"drugprot_source": DrugprotDocument,
Expand Down Expand Up @@ -144,12 +247,21 @@ def document_converters(self):
raise ValueError(f"Unknown dataset name: {self.config.name}")

def _generate_document(
self,
example: Dict[str, Any],
self, example: Dict[str, Any], **kwargs
) -> Union[DrugprotDocument, DrugprotBigbioDocument]:
if self.config.name == "drugprot_source":
return example2drugprot(example)
elif self.config.name == "drugprot_bigbio_kb":
return example2drugprot_bigbio(example)
else:
raise ValueError(f"Unknown dataset config name: {self.config.name}")

def _generate_example(
    self, document: Union[DrugprotDocument, DrugprotBigbioDocument], **kwargs
) -> Dict[str, Any]:
    """Convert a document back into an example dict.

    Dispatches on the document's type: DrugprotBigbioDocument instances go to
    ``drugprot_bigbio2example``, DrugprotDocument instances go to
    ``drugprot2example``. Extra keyword arguments are accepted but not used.

    Raises:
        ValueError: If ``document`` is an instance of neither document type.
    """
    # The BigBio type is checked first, as in the original dispatch order.
    if isinstance(document, DrugprotBigbioDocument):
        return drugprot_bigbio2example(document)
    if isinstance(document, DrugprotDocument):
        return drugprot2example(document)
    raise ValueError(f"Unknown document type: {type(document)}")
131 changes: 80 additions & 51 deletions tests/dataset_builders/pie/drugprot/test_drugprot.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from tests.dataset_builders.common import PIE_BASE_PATH

DATASET_NAME = "drugprot"
BUILDER_CLASS = Drugprot
PIE_DATASET_PATH = PIE_BASE_PATH / DATASET_NAME
HF_DATASET_PATH = Drugprot.BASE_DATASET_PATH
HF_DATASET_REVISION = Drugprot.BASE_DATASET_REVISION
Expand Down Expand Up @@ -317,28 +318,7 @@ def test_hf_dataset_all(hf_dataset, split):
assert len(example["relations"]) >= 0


@pytest.fixture(scope="module")
def builder(dataset_variant) -> Drugprot:
return Drugprot(config_name=dataset_variant)


def test_document_converters(builder, dataset_variant):
if dataset_variant == "drugprot_source":
assert set(builder.document_converters) == {TextDocumentWithLabeledSpansAndBinaryRelations}
elif dataset_variant == "drugprot_bigbio_kb":
assert set(builder.document_converters) == {
TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
}
else:
raise ValueError(f"Unknown dataset variant: {dataset_variant}")


@pytest.fixture(scope="module")
def document(hf_example, builder) -> Union[DrugprotDocument, DrugprotBigbioDocument]:
return builder._generate_document(hf_example)


def test_document(document, dataset_variant):
def test_example_to_document(document, dataset_variant):
if dataset_variant == "drugprot_source":
assert isinstance(document, DrugprotDocument)
assert (
Expand Down Expand Up @@ -385,29 +365,77 @@ def test_document(document, dataset_variant):
("GENE-Y", "RDH12"),
("GENE-N", "retinol dehydrogenase"),
]
# check entity ids
# check metadata
assert document.metadata["entity_ids"] == [
"17512723_T1",
"17512723_T2",
"17512723_T3",
"17512723_T4",
"17512723_T5",
"17512723_T6",
"17512723_T7",
"17512723_T8",
"17512723_T9",
"17512723_T10",
"17512723_T11",
"17512723_T12",
"17512723_T13",
"T1",
"T2",
"T3",
"T4",
"T5",
"T6",
"T7",
"T8",
"T9",
"T10",
"T11",
"T12",
"T13",
]
assert document.metadata["relation_ids"] == ["R0"]

# check the relations
assert document.relations.resolve() == [
("PRODUCT-OF", (("CHEMICAL", "androstanediol"), ("GENE-Y", "human type 12 RDH")))
]


@pytest.fixture(scope="module")
def builder(dataset_variant) -> BUILDER_CLASS:
    """Module-scoped fixture: a builder instance configured for ``dataset_variant``."""
    builder_instance = BUILDER_CLASS(config_name=dataset_variant)
    return builder_instance


def test_builder(builder, dataset_variant):
    """Check the builder's config id, dataset name and per-variant document type."""
    assert builder is not None
    assert builder.config_id == dataset_variant
    assert builder.dataset_name == "drugprot"

    # Expected document type for every known dataset variant.
    expected_document_types = {
        "drugprot_source": DrugprotDocument,
        "drugprot_bigbio_kb": DrugprotBigbioDocument,
    }
    if dataset_variant not in expected_document_types:
        raise ValueError(f"Unknown dataset variant: {dataset_variant}")
    assert builder.document_type == expected_document_types[dataset_variant]


def test_document_to_example_and_back(document, builder, hf_example):
    """Converting the fixture document back must reproduce the original example."""
    assert builder._generate_example(document) == hf_example


@pytest.mark.slow
def test_example_to_document_and_back_all(hf_dataset, builder):
    """Round-trip every example of every split through document conversion."""
    for split in hf_dataset.values():
        for original in split:
            restored = builder._generate_example(builder._generate_document(original))
            assert restored == original


def test_document_converters(builder, dataset_variant):
    """Check which target document types the builder offers converters for."""
    # Expected converter target types for every known dataset variant.
    expected_targets = {
        "drugprot_source": {TextDocumentWithLabeledSpansAndBinaryRelations},
        "drugprot_bigbio_kb": {
            TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
        },
    }
    if dataset_variant not in expected_targets:
        raise ValueError(f"Unknown dataset variant: {dataset_variant}")
    assert set(builder.document_converters) == expected_targets[dataset_variant]


@pytest.fixture(scope="module")
def document(hf_example, builder) -> Union[DrugprotDocument, DrugprotBigbioDocument]:
    """Module-scoped fixture: the document the builder generates from ``hf_example``."""
    generated = builder._generate_document(hf_example)
    return generated


@pytest.fixture(scope="module")
def pie_dataset(dataset_variant) -> DatasetDict:
    """Module-scoped fixture: the dataset loaded from PIE_DATASET_PATH for the variant."""
    dataset_path = str(PIE_DATASET_PATH)
    return load_dataset(dataset_path, name=dataset_variant)
Expand Down Expand Up @@ -445,7 +473,7 @@ def test_converted_pie_dataset(converted_pie_dataset, converted_document_type):


@pytest.fixture(scope="module")
def converted_document(converted_pie_dataset) -> Type[TextBasedDocument]:
def converted_document(converted_pie_dataset) -> TextBasedDocument:
return converted_pie_dataset["train"][0]


Expand Down Expand Up @@ -485,22 +513,23 @@ def test_converted_document(converted_document, converted_document_type):
("GENE-Y", "RDH12"),
("GENE-N", "retinol dehydrogenase"),
]
# check entity ids
# check metadata
assert converted_document.metadata["entity_ids"] == [
"17512723_T1",
"17512723_T2",
"17512723_T3",
"17512723_T4",
"17512723_T5",
"17512723_T6",
"17512723_T7",
"17512723_T8",
"17512723_T9",
"17512723_T10",
"17512723_T11",
"17512723_T12",
"17512723_T13",
"T1",
"T2",
"T3",
"T4",
"T5",
"T6",
"T7",
"T8",
"T9",
"T10",
"T11",
"T12",
"T13",
]
assert converted_document.metadata["relation_ids"] == ["R0"]

# check the relations
assert converted_document.binary_relations.resolve() == [
Expand Down

0 comments on commit b5f08f4

Please sign in to comment.