-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
implement merge_annotations_from_documents(), helper methods and test
- Loading branch information
1 parent
bd5c6ac
commit 0c70478
Showing
2 changed files
with
216 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,152 @@ | ||
from collections import defaultdict | ||
from typing import Dict, Hashable, List, Optional, TypeVar | ||
|
||
from pytorch_ie.core.document import Document | ||
from pytorch_ie.documents import WithMetadata | ||
|
||
|
||
def deduplicate_annotation_dicts( | ||
annotation_dicts: List[Dict[str, Hashable]] | ||
) -> List[Dict[str, Hashable]]: | ||
"""Remove duplicate annotation dictionaries from a list of annotation dictionaries. | ||
Args: | ||
annotation_dicts: The list of annotation dictionaries to remove duplicates from. | ||
Returns: | ||
The list of annotation dictionaries with duplicates removed. | ||
""" | ||
unique_annotation_dicts = [] | ||
seen_annotation_dicts = set() | ||
for annotation_dict in annotation_dicts: | ||
annotation_dict_tuple = tuple(sorted(annotation_dict.items())) | ||
if annotation_dict_tuple not in seen_annotation_dicts: | ||
unique_annotation_dicts.append(annotation_dict) | ||
seen_annotation_dicts.add(annotation_dict_tuple) | ||
return unique_annotation_dicts | ||
|
||
|
||
D = TypeVar("D", bound=Document) | ||
|
||
|
||
def deduplicate_annotations(document: D) -> D: | ||
"""Remove duplicate annotations from a document. | ||
Args: | ||
document: The document to remove duplicate annotations from. | ||
Returns: | ||
The document with duplicate annotations removed. | ||
""" | ||
annotation_field_names = [field.name for field in document.annotation_fields()] | ||
doc_dict = document.asdict() | ||
for annotation_field_name in annotation_field_names: | ||
doc_dict[annotation_field_name]["annotations"] = deduplicate_annotation_dicts( | ||
doc_dict[annotation_field_name]["annotations"] | ||
) | ||
doc_dict[annotation_field_name]["predictions"] = deduplicate_annotation_dicts( | ||
doc_dict[annotation_field_name]["predictions"] | ||
) | ||
return type(document).fromdict(doc_dict) | ||
|
||
|
||
def save_annotation_sources_to_metadata( | ||
document: D, | ||
annotation_id2source: Dict[int, List[str]], | ||
metadata_key: str, | ||
use_predictions: bool, | ||
) -> None: | ||
"""Save the source names for the annotations or predictions in the metadata of the document. | ||
Args: | ||
document: The document to save the source names in the metadata for. | ||
metadata_key: The key in the metadata where the source names should be stored. | ||
annotation_id2source: A mapping from annotation IDs to the source names. Should contain | ||
the ids of all annotations or predictions (depending on use_predictions) in the document. | ||
use_predictions: Whether to store the source names for the predictions or the annotations. | ||
""" | ||
|
||
if not hasattr(document, "metadata"): | ||
raise ValueError("Document does not have metadata, can not store source names.") | ||
if metadata_key in document.metadata: | ||
raise ValueError(f"Metadata key '{metadata_key}' already exists in the document.") | ||
document.metadata[metadata_key] = defaultdict(dict) | ||
for annotation_field in document.annotation_fields(): | ||
layer_name = annotation_field.name | ||
document.metadata[metadata_key][layer_name] = [] | ||
layer = document[layer_name] | ||
if use_predictions: | ||
layer = layer.predictions | ||
for ann in layer: | ||
document.metadata[metadata_key][layer_name].append(annotation_id2source[ann._id]) | ||
document.metadata[metadata_key] = dict(document.metadata[metadata_key]) | ||
|
||
|
||
def merge_annotations_from_documents( | ||
documents: Dict[str, D], | ||
metadata_key_source_annotations: Optional[str] = None, | ||
metadata_key_source_predictions: Optional[str] = None, | ||
) -> D: | ||
"""Merge annotations from multiple documents into a single document. Note that this will remove | ||
any annotation duplicates. | ||
Args: | ||
documents: A dictionary mapping document source (e.g. dataset names) to documents. | ||
metadata_key_source_annotations: If not None, the key in the metadata where the source names | ||
for the (gold) annotations are stored. | ||
metadata_key_source_predictions: If not None, the key in the metadata where the source names | ||
for the predictions are stored. | ||
Returns: | ||
The merged document with the source names and annotation scores stored in the metadata at key | ||
metadata_key, for each layer in the order of the predictions. | ||
""" | ||
if len(documents) == 0: | ||
raise ValueError("No documents provided.") | ||
source_names = sorted(documents) | ||
first_source_name = source_names[0] | ||
merged_document: D = documents[first_source_name].copy(with_annotations=False) | ||
|
||
added_annotation_id2source_names: Dict[int, List[str]] = defaultdict(list) | ||
for source_name in source_names: | ||
document = documents[source_name] | ||
if type(document) is not type(merged_document): | ||
raise ValueError( | ||
f"Document types do not match: {type(document)} and {type(merged_document)}" | ||
) | ||
if isinstance(document, WithMetadata) and document.id is not None: | ||
if document.id != merged_document.id: | ||
raise ValueError( | ||
f"Document IDs do not match: {document.id} and {merged_document.id}" | ||
) | ||
|
||
# TODO: add_all_annotations_from_other needs to be fixed! it should return a mapping from | ||
# original annotation *IDs* to new annotations! | ||
# Note: this does not check for duplicates! | ||
added_annotations = merged_document.add_all_annotations_from_other( | ||
other=document, strict=True | ||
) | ||
|
||
for layer_name, orig_id2new_annotation in added_annotations.items(): | ||
for orig_id, new_annotation in orig_id2new_annotation.items(): | ||
added_annotation_id2source_names[new_annotation._id].append(source_name) | ||
|
||
merged_document = deduplicate_annotations(merged_document) | ||
|
||
# save source names in metadata (at key metadata_key_source_annotations / metadata_key_source_predictions | ||
# for each layer in the order of the annotations / predictions) | ||
if metadata_key_source_annotations is not None: | ||
save_annotation_sources_to_metadata( | ||
document=merged_document, | ||
annotation_id2source=added_annotation_id2source_names, | ||
metadata_key=metadata_key_source_annotations, | ||
use_predictions=False, | ||
) | ||
if metadata_key_source_predictions is not None: | ||
save_annotation_sources_to_metadata( | ||
document=merged_document, | ||
annotation_id2source=added_annotation_id2source_names, | ||
metadata_key=metadata_key_source_predictions, | ||
use_predictions=True, | ||
) | ||
return merged_document |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
from pytorch_ie.annotations import LabeledSpan | ||
from pytorch_ie.documents import TextDocumentWithLabeledSpans | ||
from pytorch_ie.utils.document import merge_annotations_from_documents | ||
|
||
|
||
def test_document_merge_annotations(): | ||
base_doc = TextDocumentWithLabeledSpans(id="doc1", text="This is a test.") | ||
# add annotations | ||
base_doc.labeled_spans.append(LabeledSpan(start=0, end=4, label="label1", score=1.0)) | ||
base_doc.labeled_spans.append(LabeledSpan(start=5, end=7, label="label2", score=1.0)) | ||
|
||
input1 = base_doc.copy() | ||
# add predictions | ||
input1.labeled_spans.predictions.append(LabeledSpan(start=0, end=4, label="label1", score=0.9)) | ||
input1.labeled_spans.predictions.append(LabeledSpan(start=5, end=7, label="label2", score=0.7)) | ||
|
||
input2 = base_doc.copy() | ||
# add predictions | ||
input2.labeled_spans.predictions.append(LabeledSpan(start=0, end=4, label="label1", score=0.8)) | ||
input2.labeled_spans.predictions.append(LabeledSpan(start=5, end=7, label="label2", score=0.7)) | ||
input2.labeled_spans.predictions.append(LabeledSpan(start=5, end=7, label="label3", score=0.6)) | ||
|
||
documents = { | ||
"doc1": input1, | ||
"doc2": input2, | ||
} | ||
result = merge_annotations_from_documents( | ||
documents, | ||
metadata_key_source_annotations="annotations_source", | ||
metadata_key_source_predictions="predictions_source", | ||
) | ||
assert result.id == "doc1" | ||
assert set(result.labeled_spans) == set(base_doc.labeled_spans) | ||
assert len(result.labeled_spans) == len(base_doc.labeled_spans) == 2 | ||
assert len(result.labeled_spans.predictions) == 4 | ||
assert result.labeled_spans.predictions.resolve() == [ | ||
("label1", "This"), | ||
("label2", "is"), | ||
("label1", "This"), | ||
("label3", "is"), | ||
] | ||
annotations_with_sources = [ | ||
(ann.copy(), sources) | ||
for ann, sources in zip( | ||
result.labeled_spans, result.metadata["annotations_source"]["labeled_spans"] | ||
) | ||
] | ||
assert annotations_with_sources == [ | ||
(LabeledSpan(start=0, end=4, label="label1", score=1.0), ["doc1", "doc2"]), | ||
(LabeledSpan(start=5, end=7, label="label2", score=1.0), ["doc1", "doc2"]), | ||
] | ||
predictions_with_scores = [ | ||
(ann.copy(), sources) | ||
for ann, sources in zip( | ||
result.labeled_spans.predictions, | ||
result.metadata["predictions_source"]["labeled_spans"], | ||
) | ||
] | ||
assert predictions_with_scores == [ | ||
(LabeledSpan(start=0, end=4, label="label1", score=0.9), ["doc1"]), | ||
(LabeledSpan(start=5, end=7, label="label2", score=0.7), ["doc1", "doc2"]), | ||
(LabeledSpan(start=0, end=4, label="label1", score=0.8), ["doc2"]), | ||
(LabeledSpan(start=5, end=7, label="label3", score=0.6), ["doc2"]), | ||
] |