diff --git a/src/pytorch_ie/core/document.py b/src/pytorch_ie/core/document.py index 15e4a01b..55046ea3 100644 --- a/src/pytorch_ie/core/document.py +++ b/src/pytorch_ie/core/document.py @@ -948,6 +948,57 @@ class TokenBasedDocumentWithEntitiesRelationsAndRelationAttributes(TokenBasedDoc return dict(added_annotations) + def deduplicate_annotations( + self: D, + ) -> D: + """ + Deduplicates annotations in the document. The method is useful, for instance, if annotations + are added by window-based processing with overlaps which may lead to duplicated annotations. + """ + + dependency_ordered_fields: List[str] = [] + _enumerate_dependencies( + dependency_ordered_fields, + dependency_graph=self._annotation_graph, + nodes=self._annotation_graph["_artificial_root"], + ) + + def get_score(annotation: Annotation) -> float: + score = getattr(annotation, "score", None) + return 1.0 if score is None else score + + result = self.copy(with_annotations=False) + store: Dict[int, Annotation] = {} + store_predictions: Dict[int, Annotation] = {} + for field_name in dependency_ordered_fields: + if field_name in self._annotation_fields: + layer = self[field_name] + for is_prediction, anns in [(False, layer), (True, layer.predictions)]: + current_store = dict(store) + if is_prediction: + current_store.update(store_predictions) + ann2duplicates = defaultdict(list) + for ann in anns: + ann2duplicates[ann].append(ann) + for duplicates in ann2duplicates.values(): + duplicates_sorted = sorted(duplicates, key=get_score, reverse=True) + best_duplicate = duplicates_sorted[0] + new_ann = best_duplicate.copy_with_store( + override_annotation_store=current_store, + invalid_annotation_ids={}, + ) + + target_layer = result[field_name] + if is_prediction: + for ann in duplicates: + store_predictions[ann._id] = new_ann + target_layer.predictions.append(new_ann) + else: + for ann in duplicates: + store[ann._id] = new_ann + target_layer.append(new_ann) + return result + def resolve_annotation( id_or_annotation: Union[int, Annotation], diff --git a/tests/core/test_document.py b/tests/core/test_document.py index 20986f96..9d1bda89 100644 --- a/tests/core/test_document.py +++ b/tests/core/test_document.py @@ -4,7 +4,7 @@ import pytest -from pytorch_ie.annotations import BinaryRelation +from pytorch_ie.annotations import BinaryRelation, LabeledSpan from pytorch_ie.core import Annotation from pytorch_ie.core.document import ( AnnotationLayer, @@ -490,3 +490,109 @@ class MyDocumentTwoTargets(Document): assert document.words.target_names == ["text1", "text2"] assert document.words.targets == {"text1": "Hello world!", "text2": "Hello world again!"} + + +def test_deduplicate_annotations(): + @dataclasses.dataclass + class MyDocument(Document): + text: str + entities: AnnotationLayer[LabeledSpan] = annotation_field(target="text") + relations: AnnotationLayer[BinaryRelation] = annotation_field(target="entities") + + document = MyDocument(text="Hello world!") + document.entities.append(LabeledSpan(start=0, end=5, label="entity")) + document.entities.append(LabeledSpan(start=0, end=5, label="entity", score=0.5)) + document.entities.append(LabeledSpan(start=6, end=11, label="entity")) + document.entities.append(LabeledSpan(start=6, end=11, label="entity_2")) + assert document.entities.resolve() == [ + ("entity", "Hello"), + ("entity", "Hello"), + ("entity", "world"), + ("entity_2", "world"), + ] + document.entities.predictions.append(LabeledSpan(start=0, end=5, label="entity", score=0.9)) + + document.relations.append( + BinaryRelation(head=document.entities[0], tail=document.entities[2], label="relation") + ) + document.relations.append( + BinaryRelation(head=document.entities[0], tail=document.entities[2], label="relation") + ) + document.relations.append( + BinaryRelation(head=document.entities[1], tail=document.entities[2], label="relation") + ) + document.relations.append( + BinaryRelation(head=document.entities[1], tail=document.entities[3], label="relation") + ) + assert document.relations.resolve() == [ + ("relation", (("entity", "Hello"), ("entity", "world"))), + ("relation", (("entity", "Hello"), ("entity", "world"))), + ("relation", (("entity", "Hello"), ("entity", "world"))), + ("relation", (("entity", "Hello"), ("entity_2", "world"))), + ] + + document.relations.predictions.append( + BinaryRelation(head=document.entities[1], tail=document.entities[3], label="relation") + ) + document.relations.predictions.append( + BinaryRelation( + head=document.entities[1], tail=document.entities[3], label="relation", score=0.5 + ) + ) + document.relations.predictions.append( + BinaryRelation( + head=document.entities.predictions[0], + tail=document.entities[3], + label="relation", + score=0.8, + ) + ) + + tp = len(set(document.relations.predictions) & set(document.relations)) + fp = len(set(document.relations.predictions) - set(document.relations)) + fn = len(set(document.relations) - set(document.relations.predictions)) + assert tp == 1 + assert fp == 0 + assert fn == 1 + + deduplicated_doc = document.deduplicate_annotations() + assert len(deduplicated_doc.entities) == 3 + assert {ann.copy() for ann in deduplicated_doc.entities} == { + LabeledSpan(start=0, end=5, label="entity", score=1.0), + LabeledSpan(start=6, end=11, label="entity", score=1.0), + LabeledSpan(start=6, end=11, label="entity_2", score=1.0), + } + assert len(deduplicated_doc.entities.predictions) == 1 + assert {ann.copy() for ann in deduplicated_doc.entities.predictions} == { + LabeledSpan(start=0, end=5, label="entity", score=0.9) + } + + assert len(deduplicated_doc.relations) == 2 + assert {ann.copy() for ann in deduplicated_doc.relations} == { + BinaryRelation( + head=deduplicated_doc.entities[0], tail=deduplicated_doc.entities[1], label="relation" + ), + BinaryRelation( + head=deduplicated_doc.entities[0], tail=deduplicated_doc.entities[2], label="relation" + ), + } + assert len(deduplicated_doc.relations.predictions) == 1 + assert {ann.copy() for ann in deduplicated_doc.relations.predictions} == { + BinaryRelation( + head=deduplicated_doc.entities[0], tail=deduplicated_doc.entities[2], label="relation" + ), + } + + assert deduplicated_doc.relations.resolve() == [ + ("relation", (("entity", "Hello"), ("entity", "world"))), + ("relation", (("entity", "Hello"), ("entity_2", "world"))), + ] + assert deduplicated_doc.relations.predictions.resolve() == [ + ("relation", (("entity", "Hello"), ("entity_2", "world"))) + ] + tp = len(set(deduplicated_doc.relations.predictions) & set(deduplicated_doc.relations)) + fp = len(set(deduplicated_doc.relations.predictions) - set(deduplicated_doc.relations)) + fn = len(set(deduplicated_doc.relations) - set(deduplicated_doc.relations.predictions)) + assert tp == 1 + assert fp == 0 + assert fn == 1