Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement Document.deduplicate_annotations() #436

Merged
merged 4 commits into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions src/pytorch_ie/core/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -948,6 +948,52 @@ class TokenBasedDocumentWithEntitiesRelationsAndRelationAttributes(TokenBasedDoc

return dict(added_annotations)

def deduplicate_annotations(
self: D,
) -> D:
"""
Deduplicates annotations in the document. The method is useful, for instance, if annotations
are added by window-based processing with overlaps which may lead to duplicated annotations.
"""

dependency_ordered_fields: List[str] = []
_enumerate_dependencies(
dependency_ordered_fields,
dependency_graph=self._annotation_graph,
nodes=self._annotation_graph["_artificial_root"],
)

def get_score(annotation: Annotation) -> float:
score = getattr(annotation, "score", None)
return 1.0 if score is None else score

result = self.copy(with_annotations=False)
store: Dict[int, Annotation] = {}
for field_name in dependency_ordered_fields:
if field_name in self._annotation_fields:
layer = self[field_name]
new_mapping: Dict[int, Annotation] = {}
for is_prediction, anns in [(False, layer), (True, layer.predictions)]:
ann2duplicates = defaultdict(list)
for ann in anns:
ann2duplicates[ann].append(ann)
for duplicates in ann2duplicates.values():
duplicates_sorted = sorted(duplicates, key=get_score, reverse=True)
best_duplicate = duplicates_sorted[0]
new_ann = best_duplicate.copy_with_store(
override_annotation_store=store,
invalid_annotation_ids={},
)
for ann in duplicates:
new_mapping[ann._id] = new_ann
target_layer = result[field_name]
if is_prediction:
target_layer.predictions.append(new_ann)
else:
target_layer.append(new_ann)
store.update(new_mapping)
return result


def resolve_annotation(
id_or_annotation: Union[int, Annotation],
Expand Down
108 changes: 107 additions & 1 deletion tests/core/test_document.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import pytest

from pytorch_ie.annotations import BinaryRelation
from pytorch_ie.annotations import BinaryRelation, LabeledSpan
from pytorch_ie.core import Annotation
from pytorch_ie.core.document import (
AnnotationLayer,
Expand Down Expand Up @@ -490,3 +490,109 @@ class MyDocumentTwoTargets(Document):

assert document.words.target_names == ["text1", "text2"]
assert document.words.targets == {"text1": "Hello world!", "text2": "Hello world again!"}


def test_deduplicate_annotations():
@dataclasses.dataclass
class MyDocument(Document):
text: str
entities: AnnotationLayer[LabeledSpan] = annotation_field(target="text")
relations: AnnotationLayer[BinaryRelation] = annotation_field(target="entities")

document = MyDocument(text="Hello world!")
document.entities.append(LabeledSpan(start=0, end=5, label="entity"))
document.entities.append(LabeledSpan(start=0, end=5, label="entity", score=0.5))
document.entities.append(LabeledSpan(start=6, end=11, label="entity"))
document.entities.append(LabeledSpan(start=6, end=11, label="entity_2"))
assert document.entities.resolve() == [
("entity", "Hello"),
("entity", "Hello"),
("entity", "world"),
("entity_2", "world"),
]
document.entities.predictions.append(LabeledSpan(start=0, end=5, label="entity", score=0.9))

document.relations.append(
BinaryRelation(head=document.entities[0], tail=document.entities[2], label="relation")
)
document.relations.append(
BinaryRelation(head=document.entities[0], tail=document.entities[2], label="relation")
)
document.relations.append(
BinaryRelation(head=document.entities[1], tail=document.entities[2], label="relation")
)
document.relations.append(
BinaryRelation(head=document.entities[1], tail=document.entities[3], label="relation")
)
assert document.relations.resolve() == [
("relation", (("entity", "Hello"), ("entity", "world"))),
("relation", (("entity", "Hello"), ("entity", "world"))),
("relation", (("entity", "Hello"), ("entity", "world"))),
("relation", (("entity", "Hello"), ("entity_2", "world"))),
]

document.relations.predictions.append(
BinaryRelation(head=document.entities[1], tail=document.entities[3], label="relation")
)
document.relations.predictions.append(
BinaryRelation(
head=document.entities[1], tail=document.entities[3], label="relation", score=0.5
)
)
document.relations.predictions.append(
BinaryRelation(
head=document.entities.predictions[0],
tail=document.entities[3],
label="relation",
score=0.8,
)
)

tp = len(set(document.relations.predictions) & set(document.relations))
fp = len(set(document.relations.predictions) - set(document.relations))
fn = len(set(document.relations) - set(document.relations.predictions))
assert tp == 1
assert fp == 0
assert fn == 1

deduplicated_doc = document.deduplicate_annotations()
assert len(deduplicated_doc.entities) == 3
assert {ann.copy() for ann in deduplicated_doc.entities} == {
LabeledSpan(start=0, end=5, label="entity", score=1.0),
LabeledSpan(start=6, end=11, label="entity", score=1.0),
LabeledSpan(start=6, end=11, label="entity_2", score=1.0),
}
assert len(deduplicated_doc.entities.predictions) == 1
assert {ann.copy() for ann in deduplicated_doc.entities.predictions} == {
LabeledSpan(start=0, end=5, label="entity", score=0.9)
}

assert len(deduplicated_doc.relations) == 2
assert {ann.copy() for ann in deduplicated_doc.relations} == {
BinaryRelation(
head=deduplicated_doc.entities[0], tail=deduplicated_doc.entities[1], label="relation"
),
BinaryRelation(
head=deduplicated_doc.entities[0], tail=deduplicated_doc.entities[2], label="relation"
),
}
assert len(deduplicated_doc.relations.predictions) == 1
assert {ann.copy() for ann in deduplicated_doc.relations.predictions} == {
BinaryRelation(
head=deduplicated_doc.entities[0], tail=deduplicated_doc.entities[2], label="relation"
),
}

assert deduplicated_doc.relations.resolve() == [
("relation", (("entity", "Hello"), ("entity", "world"))),
("relation", (("entity", "Hello"), ("entity_2", "world"))),
]
assert deduplicated_doc.relations.predictions.resolve() == [
("relation", (("entity", "Hello"), ("entity_2", "world")))
]
tp = len(set(deduplicated_doc.relations.predictions) & set(deduplicated_doc.relations))
fp = len(set(deduplicated_doc.relations.predictions) - set(deduplicated_doc.relations))
fn = len(set(deduplicated_doc.relations) - set(deduplicated_doc.relations.predictions))
assert tp == 1
assert fp == 0
assert fn == 1
Loading