Skip to content

Commit

Permalink
implement merge_annotations_from_documents(), helper methods and test
Browse files Browse the repository at this point in the history
  • Loading branch information
ArneBinder committed Oct 8, 2024
1 parent bd5c6ac commit 0c70478
Show file tree
Hide file tree
Showing 2 changed files with 216 additions and 0 deletions.
152 changes: 152 additions & 0 deletions src/pytorch_ie/utils/document.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
from collections import defaultdict
from typing import Dict, Hashable, List, Optional, TypeVar

from pytorch_ie.core.document import Document
from pytorch_ie.documents import WithMetadata


def deduplicate_annotation_dicts(
    annotation_dicts: List[Dict[str, Hashable]]
) -> List[Dict[str, Hashable]]:
    """Drop repeated annotation dictionaries, keeping the first occurrence of each.

    Two dictionaries count as duplicates when they hold the same key/value pairs,
    independent of key order. The relative order of the kept entries is unchanged.

    Args:
        annotation_dicts: The annotation dictionaries to filter. Their values must be
            hashable because each dictionary is turned into a hashable fingerprint.

    Returns:
        A new list that contains the first occurrence of every distinct dictionary.
    """
    # Dicts preserve insertion order, so the values end up in first-seen order.
    first_occurrences: Dict[tuple, Dict[str, Hashable]] = {}
    for candidate in annotation_dicts:
        # Sorting the items gives a fingerprint that does not depend on key order.
        fingerprint = tuple(sorted(candidate.items()))
        first_occurrences.setdefault(fingerprint, candidate)
    return list(first_occurrences.values())


# Type variable bound to Document: lets a function declare that it returns the same
# Document subclass it was given.
D = TypeVar("D", bound=Document)


def deduplicate_annotations(document: D) -> D:
    """Return a new document of the same type without duplicate annotations.

    The document is serialized with ``asdict()``. For every annotation field, both the
    "annotations" and the "predictions" entries are passed through
    ``deduplicate_annotation_dicts``, and the result is turned back into a document
    with ``fromdict()``.

    Args:
        document: The document whose annotations and predictions should be deduplicated.

    Returns:
        A document created from the deduplicated dictionary representation.
    """
    layer_names = [field.name for field in document.annotation_fields()]
    serialized = document.asdict()
    for layer_name in layer_names:
        layer_dict = serialized[layer_name]
        # Gold annotations and predictions are deduplicated independently of each other.
        for entry_kind in ("annotations", "predictions"):
            layer_dict[entry_kind] = deduplicate_annotation_dicts(layer_dict[entry_kind])
    return type(document).fromdict(serialized)


def save_annotation_sources_to_metadata(
    document: D,
    annotation_id2source: Dict[int, List[str]],
    metadata_key: str,
    use_predictions: bool,
) -> None:
    """Save the source names for the annotations or predictions in the document metadata.

    After the call, ``document.metadata[metadata_key]`` is a plain dict that maps each
    annotation layer name to a list with one entry per annotation (or prediction) of that
    layer, in layer order. Each entry is the value stored in ``annotation_id2source`` for
    the ``_id`` of the respective annotation.

    Args:
        document: The document to save the source names in the metadata for.
        annotation_id2source: A mapping from annotation IDs to the source names. Should contain
            the ids of all annotations or predictions (depending on use_predictions) in the
            document.
        metadata_key: The key in the metadata where the source names should be stored.
        use_predictions: Whether to store the source names for the predictions or the
            annotations.

    Raises:
        ValueError: If the document has no ``metadata`` attribute or if ``metadata_key`` is
            already present in the metadata.
    """
    if not hasattr(document, "metadata"):
        raise ValueError("Document does not have metadata, can not store source names.")
    if metadata_key in document.metadata:
        raise ValueError(f"Metadata key '{metadata_key}' already exists in the document.")

    # Collect everything locally and write to the metadata only once at the end. This way a
    # failing lookup (e.g. an id missing from annotation_id2source) does not leave a
    # half-filled entry behind that would block a later retry with "already exists".
    sources_per_layer: Dict[str, List[List[str]]] = {}
    for annotation_field in document.annotation_fields():
        layer_name = annotation_field.name
        layer = document[layer_name]
        if use_predictions:
            layer = layer.predictions
        sources_per_layer[layer_name] = [annotation_id2source[ann._id] for ann in layer]
    document.metadata[metadata_key] = sources_per_layer


def merge_annotations_from_documents(
    documents: Dict[str, D],
    metadata_key_source_annotations: Optional[str] = None,
    metadata_key_source_predictions: Optional[str] = None,
) -> D:
    """Merge annotations from multiple documents into a single document.

    The documents are processed in the sorted order of their source names. The merged
    document starts as a copy of the first document without annotations; the annotations of
    all documents are then added to it and duplicates are removed afterwards.

    Args:
        documents: A dictionary mapping document source (e.g. dataset names) to documents.
            All documents need to have the same type.
        metadata_key_source_annotations: If not None, the key in the metadata where the source
            names for the (gold) annotations are stored.
        metadata_key_source_predictions: If not None, the key in the metadata where the source
            names for the predictions are stored.

    Returns:
        The merged document. For each metadata key that is not None, the source names are
        stored in the metadata at that key, for each layer in the order of the annotations /
        predictions.

    Raises:
        ValueError: If no documents are provided, if the document types do not match, or if
            a WithMetadata document has a non-None id that differs from the id of the merged
            document.
    """
    if not documents:
        raise ValueError("No documents provided.")
    source_names = sorted(documents)
    first_source_name = source_names[0]
    merged_document: D = documents[first_source_name].copy(with_annotations=False)

    added_annotation_id2source_names: Dict[int, List[str]] = defaultdict(list)
    for source_name in source_names:
        document = documents[source_name]
        if type(document) is not type(merged_document):
            raise ValueError(
                f"Document types do not match: {type(document)} and {type(merged_document)}"
            )
        if (
            isinstance(document, WithMetadata)
            and document.id is not None
            and document.id != merged_document.id
        ):
            raise ValueError(f"Document IDs do not match: {document.id} and {merged_document.id}")

        # TODO: add_all_annotations_from_other needs to be fixed! it should return a mapping from
        #  original annotation *IDs* to new annotations!
        # Note: this does not check for duplicates!
        added_annotations = merged_document.add_all_annotations_from_other(
            other=document, strict=True
        )

        # Only the newly added annotations are needed here; layer names and original ids
        # are not used.
        for orig_id2new_annotation in added_annotations.values():
            for new_annotation in orig_id2new_annotation.values():
                added_annotation_id2source_names[new_annotation._id].append(source_name)

    merged_document = deduplicate_annotations(merged_document)

    # save source names in metadata (at key metadata_key_source_annotations /
    # metadata_key_source_predictions for each layer in the order of the annotations /
    # predictions)
    if metadata_key_source_annotations is not None:
        save_annotation_sources_to_metadata(
            document=merged_document,
            annotation_id2source=added_annotation_id2source_names,
            metadata_key=metadata_key_source_annotations,
            use_predictions=False,
        )
    if metadata_key_source_predictions is not None:
        save_annotation_sources_to_metadata(
            document=merged_document,
            annotation_id2source=added_annotation_id2source_names,
            metadata_key=metadata_key_source_predictions,
            use_predictions=True,
        )
    return merged_document
64 changes: 64 additions & 0 deletions tests/utils/test_document.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from pytorch_ie.annotations import LabeledSpan
from pytorch_ie.documents import TextDocumentWithLabeledSpans
from pytorch_ie.utils.document import merge_annotations_from_documents


def test_document_merge_annotations():
    """Merge two copies of a document that share gold spans but differ in predictions.

    Checks the merged gold annotations, the merged predictions, and the source names that
    are written to the metadata for both.
    """
    base_doc = TextDocumentWithLabeledSpans(id="doc1", text="This is a test.")
    # gold annotations shared by both inputs
    for start, end, label in [(0, 4, "label1"), (5, 7, "label2")]:
        base_doc.labeled_spans.append(LabeledSpan(start=start, end=end, label=label, score=1.0))

    # predictions per source as (start, end, label, score)
    predictions_by_source = {
        "doc1": [(0, 4, "label1", 0.9), (5, 7, "label2", 0.7)],
        "doc2": [(0, 4, "label1", 0.8), (5, 7, "label2", 0.7), (5, 7, "label3", 0.6)],
    }
    documents = {}
    for source_name, predicted_spans in predictions_by_source.items():
        input_doc = base_doc.copy()
        for start, end, label, score in predicted_spans:
            input_doc.labeled_spans.predictions.append(
                LabeledSpan(start=start, end=end, label=label, score=score)
            )
        documents[source_name] = input_doc

    result = merge_annotations_from_documents(
        documents,
        metadata_key_source_annotations="annotations_source",
        metadata_key_source_predictions="predictions_source",
    )

    assert result.id == "doc1"
    assert set(result.labeled_spans) == set(base_doc.labeled_spans)
    assert len(result.labeled_spans) == len(base_doc.labeled_spans) == 2
    assert len(result.labeled_spans.predictions) == 4
    assert result.labeled_spans.predictions.resolve() == [
        ("label1", "This"),
        ("label2", "is"),
        ("label1", "This"),
        ("label3", "is"),
    ]

    # gold annotations paired with the sources they came from
    annotation_sources = result.metadata["annotations_source"]["labeled_spans"]
    annotations_with_sources = [
        (ann.copy(), sources) for ann, sources in zip(result.labeled_spans, annotation_sources)
    ]
    assert annotations_with_sources == [
        (LabeledSpan(start=0, end=4, label="label1", score=1.0), ["doc1", "doc2"]),
        (LabeledSpan(start=5, end=7, label="label2", score=1.0), ["doc1", "doc2"]),
    ]

    # predictions paired with the sources they came from
    prediction_sources = result.metadata["predictions_source"]["labeled_spans"]
    predictions_with_sources = [
        (ann.copy(), sources)
        for ann, sources in zip(result.labeled_spans.predictions, prediction_sources)
    ]
    assert predictions_with_sources == [
        (LabeledSpan(start=0, end=4, label="label1", score=0.9), ["doc1"]),
        (LabeledSpan(start=5, end=7, label="label2", score=0.7), ["doc1", "doc2"]),
        (LabeledSpan(start=0, end=4, label="label1", score=0.8), ["doc2"]),
        (LabeledSpan(start=5, end=7, label="label3", score=0.6), ["doc2"]),
    ]

0 comments on commit 0c70478

Please sign in to comment.