Skip to content

Commit

Permalink
implement Document.deduplicate_annotations()
Browse files Browse the repository at this point in the history
  • Loading branch information
ArneBinder committed Nov 13, 2024
1 parent 7c63106 commit 09c494a
Show file tree
Hide file tree
Showing 2 changed files with 158 additions and 1 deletion.
51 changes: 51 additions & 0 deletions src/pytorch_ie/core/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -948,6 +948,57 @@ class TokenBasedDocumentWithEntitiesRelationsAndRelationAttributes(TokenBasedDoc

return dict(added_annotations)

def deduplicate_annotations(
    self: D,
) -> D:
    """Return a copy of the document in which duplicated annotations are merged.

    The method is useful, for instance, if annotations are added by window-based
    processing with overlaps, which may lead to duplicated annotations.

    Annotation layers are visited in the order produced by ``_enumerate_dependencies``
    on the annotation graph (presumably targets before the layers that depend on them,
    so that references can be remapped -- confirm against that helper). Within each
    layer, gold annotations and predictions are deduplicated independently:

    * annotations that hash and compare equal form one group of duplicates,
    * from each group the annotation with the highest ``score`` is kept; an annotation
      without a ``score`` attribute, or with ``score`` set to ``None``, counts as 1.0,
      and on ties the first occurrence wins,
    * the kept annotation is copied via ``copy_with_store`` using the mapping from
      original annotation ids to already deduplicated copies, and appended to the
      corresponding layer of the result.

    Gold annotations are copied against the mapping of gold annotations only;
    predictions are copied against the gold mapping overlaid with the prediction
    mapping.

    Returns:
        A new document created with ``self.copy(with_annotations=False)`` that holds
        the deduplicated annotations. ``self`` is not modified by this method itself.
    """

    dependency_ordered_fields: List[str] = []
    _enumerate_dependencies(
        dependency_ordered_fields,
        dependency_graph=self._annotation_graph,
        nodes=self._annotation_graph["_artificial_root"],
    )

    def get_score(annotation: Annotation) -> float:
        # A missing or None score is treated as the score 1.0.
        score = getattr(annotation, "score", None)
        return 1.0 if score is None else score

    result = self.copy(with_annotations=False)
    # Maps ids of original annotations to their deduplicated copies,
    # kept separately for gold annotations and predictions.
    store: Dict[int, Annotation] = {}
    store_predictions: Dict[int, Annotation] = {}
    for field_name in dependency_ordered_fields:
        if field_name not in self._annotation_fields:
            continue
        layer = self[field_name]
        target_layer = result[field_name]
        for is_prediction, anns in [(False, layer), (True, layer.predictions)]:
            # Snapshot of the mapping as it was before this pass: copies created
            # during the pass are not visible to copy_with_store() in the same pass.
            current_store = dict(store)
            if is_prediction:
                current_store.update(store_predictions)
                target_annotations = target_layer.predictions
                id_to_new_ann = store_predictions
            else:
                target_annotations = target_layer
                id_to_new_ann = store

            # Group annotations that are equal to each other (insertion ordered).
            ann2duplicates = defaultdict(list)
            for ann in anns:
                ann2duplicates[ann].append(ann)

            for duplicates in ann2duplicates.values():
                # max() returns the first maximal item, i.e. the same element as
                # sorted(duplicates, key=get_score, reverse=True)[0], without sorting.
                best_duplicate = max(duplicates, key=get_score)
                new_ann = best_duplicate.copy_with_store(
                    override_annotation_store=current_store,
                    # an empty collection: no annotation id is flagged as invalid
                    invalid_annotation_ids={},
                )
                # Every duplicate resolves to the single copy that was kept.
                for ann in duplicates:
                    id_to_new_ann[ann._id] = new_ann
                target_annotations.append(new_ann)
    return result


def resolve_annotation(
id_or_annotation: Union[int, Annotation],
Expand Down
108 changes: 107 additions & 1 deletion tests/core/test_document.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import pytest

from pytorch_ie.annotations import BinaryRelation
from pytorch_ie.annotations import BinaryRelation, LabeledSpan
from pytorch_ie.core import Annotation
from pytorch_ie.core.document import (
AnnotationLayer,
Expand Down Expand Up @@ -490,3 +490,109 @@ class MyDocumentTwoTargets(Document):

assert document.words.target_names == ["text1", "text2"]
assert document.words.targets == {"text1": "Hello world!", "text2": "Hello world again!"}


def test_deduplicate_annotations():
    """Check Document.deduplicate_annotations() on a document with entities and relations.

    The document holds duplicated gold entities, one predicted entity, duplicated gold
    relations and duplicated predicted relations. After deduplication each layer (and
    its predictions) must contain every distinct annotation exactly once, relations
    must point to the deduplicated entities, and the tp / fp / fn counts of predicted
    vs. gold relations must be the same as before.
    """

    @dataclasses.dataclass
    class MyDocument(Document):
        text: str
        entities: AnnotationLayer[LabeledSpan] = annotation_field(target="text")
        relations: AnnotationLayer[BinaryRelation] = annotation_field(target="entities")

    def relation_counts(doc):
        # (true positives, false positives, false negatives) of predicted vs. gold relations
        predicted = set(doc.relations.predictions)
        gold = set(doc.relations)
        return len(predicted & gold), len(predicted - gold), len(gold - predicted)

    document = MyDocument(text="Hello world!")

    # gold entities: the first two are duplicates of each other (they differ only in score)
    gold_spans = [
        LabeledSpan(start=0, end=5, label="entity"),
        LabeledSpan(start=0, end=5, label="entity", score=0.5),
        LabeledSpan(start=6, end=11, label="entity"),
        LabeledSpan(start=6, end=11, label="entity_2"),
    ]
    for span in gold_spans:
        document.entities.append(span)
    assert document.entities.resolve() == [
        ("entity", "Hello"),
        ("entity", "Hello"),
        ("entity", "world"),
        ("entity_2", "world"),
    ]
    document.entities.predictions.append(LabeledSpan(start=0, end=5, label="entity", score=0.9))

    # gold relations, given as (head index, tail index) into the gold entities
    for head_idx, tail_idx in [(0, 2), (0, 2), (1, 2), (1, 3)]:
        document.relations.append(
            BinaryRelation(
                head=document.entities[head_idx],
                tail=document.entities[tail_idx],
                label="relation",
            )
        )
    assert document.relations.resolve() == [
        ("relation", (("entity", "Hello"), ("entity", "world"))),
        ("relation", (("entity", "Hello"), ("entity", "world"))),
        ("relation", (("entity", "Hello"), ("entity", "world"))),
        ("relation", (("entity", "Hello"), ("entity_2", "world"))),
    ]

    # predicted relations: the last one has a predicted entity as its head
    predicted_relations = [
        BinaryRelation(head=document.entities[1], tail=document.entities[3], label="relation"),
        BinaryRelation(
            head=document.entities[1], tail=document.entities[3], label="relation", score=0.5
        ),
        BinaryRelation(
            head=document.entities.predictions[0],
            tail=document.entities[3],
            label="relation",
            score=0.8,
        ),
    ]
    for relation in predicted_relations:
        document.relations.predictions.append(relation)

    tp, fp, fn = relation_counts(document)
    assert tp == 1
    assert fp == 0
    assert fn == 1

    deduplicated_doc = document.deduplicate_annotations()
    new_entities = deduplicated_doc.entities
    new_relations = deduplicated_doc.relations

    assert len(new_entities) == 3
    assert {entity.copy() for entity in new_entities} == {
        LabeledSpan(start=0, end=5, label="entity", score=1.0),
        LabeledSpan(start=6, end=11, label="entity", score=1.0),
        LabeledSpan(start=6, end=11, label="entity_2", score=1.0),
    }
    assert len(new_entities.predictions) == 1
    assert {entity.copy() for entity in new_entities.predictions} == {
        LabeledSpan(start=0, end=5, label="entity", score=0.9)
    }

    assert len(new_relations) == 2
    assert {relation.copy() for relation in new_relations} == {
        BinaryRelation(head=new_entities[0], tail=new_entities[1], label="relation"),
        BinaryRelation(head=new_entities[0], tail=new_entities[2], label="relation"),
    }
    assert len(new_relations.predictions) == 1
    assert {relation.copy() for relation in new_relations.predictions} == {
        BinaryRelation(head=new_entities[0], tail=new_entities[2], label="relation"),
    }

    assert new_relations.resolve() == [
        ("relation", (("entity", "Hello"), ("entity", "world"))),
        ("relation", (("entity", "Hello"), ("entity_2", "world"))),
    ]
    assert new_relations.predictions.resolve() == [
        ("relation", (("entity", "Hello"), ("entity_2", "world")))
    ]

    # deduplication must not change the evaluation counts
    tp, fp, fn = relation_counts(deduplicated_doc)
    assert tp == 1
    assert fp == 0
    assert fn == 1

0 comments on commit 09c494a

Please sign in to comment.