diff --git a/src/pytorch_ie/utils/document.py b/src/pytorch_ie/utils/document.py
new file mode 100644
index 00000000..7e0742df
--- /dev/null
+++ b/src/pytorch_ie/utils/document.py
@@ -0,0 +1,152 @@
+from collections import defaultdict
+from typing import Dict, Hashable, List, Optional, TypeVar
+
+from pytorch_ie.core.document import Document
+from pytorch_ie.documents import WithMetadata
+
+
+def deduplicate_annotation_dicts(
+    annotation_dicts: List[Dict[str, Hashable]]
+) -> List[Dict[str, Hashable]]:
+    """Remove duplicate annotation dictionaries from a list of annotation dictionaries.
+
+    Args:
+        annotation_dicts: The list of annotation dictionaries to remove duplicates from.
+
+    Returns:
+        The list of annotation dictionaries with duplicates removed.
+    """
+    unique_annotation_dicts = []
+    seen_annotation_dicts = set()
+    for annotation_dict in annotation_dicts:
+        annotation_dict_tuple = tuple(sorted(annotation_dict.items()))
+        if annotation_dict_tuple not in seen_annotation_dicts:
+            unique_annotation_dicts.append(annotation_dict)
+            seen_annotation_dicts.add(annotation_dict_tuple)
+    return unique_annotation_dicts
+
+
+D = TypeVar("D", bound=Document)
+
+
+def deduplicate_annotations(document: D) -> D:
+    """Remove duplicate annotations from a document.
+
+    Args:
+        document: The document to remove duplicate annotations from.
+
+    Returns:
+        The document with duplicate annotations removed.
+    """
+    annotation_field_names = [field.name for field in document.annotation_fields()]
+    doc_dict = document.asdict()
+    for annotation_field_name in annotation_field_names:
+        doc_dict[annotation_field_name]["annotations"] = deduplicate_annotation_dicts(
+            doc_dict[annotation_field_name]["annotations"]
+        )
+        doc_dict[annotation_field_name]["predictions"] = deduplicate_annotation_dicts(
+            doc_dict[annotation_field_name]["predictions"]
+        )
+    return type(document).fromdict(doc_dict)
+
+
+def save_annotation_sources_to_metadata(
+    document: D,
+    annotation_id2source: Dict[int, List[str]],
+    metadata_key: str,
+    use_predictions: bool,
+) -> None:
+    """Save the source names for the annotations or predictions in the metadata of the document.
+
+    Args:
+        document: The document to save the source names in the metadata for.
+        metadata_key: The key in the metadata where the source names should be stored.
+        annotation_id2source: A mapping from annotation IDs to the source names. Should contain
+            the ids of all annotations or predictions (depending on use_predictions) in the document.
+        use_predictions: Whether to store the source names for the predictions or the annotations.
+    """
+
+    if not hasattr(document, "metadata"):
+        raise ValueError("Document does not have metadata, can not store source names.")
+    if metadata_key in document.metadata:
+        raise ValueError(f"Metadata key '{metadata_key}' already exists in the document.")
+    document.metadata[metadata_key] = defaultdict(dict)
+    for annotation_field in document.annotation_fields():
+        layer_name = annotation_field.name
+        document.metadata[metadata_key][layer_name] = []
+        layer = document[layer_name]
+        if use_predictions:
+            layer = layer.predictions
+        for ann in layer:
+            document.metadata[metadata_key][layer_name].append(annotation_id2source[ann._id])
+    document.metadata[metadata_key] = dict(document.metadata[metadata_key])
+
+
+def merge_annotations_from_documents(
+    documents: Dict[str, D],
+    metadata_key_source_annotations: Optional[str] = None,
+    metadata_key_source_predictions: Optional[str] = None,
+) -> D:
+    """Merge annotations from multiple documents into a single document. Note that this will remove
+    any annotation duplicates.
+
+    Args:
+        documents: A dictionary mapping document sources (e.g. dataset names) to documents.
+        metadata_key_source_annotations: If not None, the key in the metadata where the source names
+            for the (gold) annotations are stored.
+        metadata_key_source_predictions: If not None, the key in the metadata where the source names
+            for the predictions are stored.
+
+    Returns:
+        The merged document with the source names stored in the metadata at the respective keys,
+        for each layer in the order of the annotations / predictions.
+    """
+    if len(documents) == 0:
+        raise ValueError("No documents provided.")
+    source_names = sorted(documents)
+    first_source_name = source_names[0]
+    merged_document: D = documents[first_source_name].copy(with_annotations=False)
+
+    added_annotation_id2source_names: Dict[int, List[str]] = defaultdict(list)
+    for source_name in source_names:
+        document = documents[source_name]
+        if type(document) is not type(merged_document):
+            raise ValueError(
+                f"Document types do not match: {type(document)} and {type(merged_document)}"
+            )
+        if isinstance(document, WithMetadata) and document.id is not None:
+            if document.id != merged_document.id:
+                raise ValueError(
+                    f"Document IDs do not match: {document.id} and {merged_document.id}"
+                )
+
+        # TODO: add_all_annotations_from_other needs to be fixed! It should return a mapping from
+        #  original annotation *IDs* to new annotations!
+        # Note: this does not check for duplicates!
+        added_annotations = merged_document.add_all_annotations_from_other(
+            other=document, strict=True
+        )
+
+        for layer_name, orig_id2new_annotation in added_annotations.items():
+            for orig_id, new_annotation in orig_id2new_annotation.items():
+                added_annotation_id2source_names[new_annotation._id].append(source_name)
+
+    merged_document = deduplicate_annotations(merged_document)
+
+    # save source names in metadata (at key metadata_key_source_annotations / metadata_key_source_predictions
+    # for each layer in the order of the annotations / predictions)
+    if metadata_key_source_annotations is not None:
+        save_annotation_sources_to_metadata(
+            document=merged_document,
+            annotation_id2source=added_annotation_id2source_names,
+            metadata_key=metadata_key_source_annotations,
+            use_predictions=False,
+        )
+    if metadata_key_source_predictions is not None:
+        save_annotation_sources_to_metadata(
+            document=merged_document,
+            annotation_id2source=added_annotation_id2source_names,
+            metadata_key=metadata_key_source_predictions,
+            use_predictions=True,
+        )
+    return merged_document
diff --git a/tests/utils/test_document.py b/tests/utils/test_document.py
new file mode 100644
index 00000000..a27734f2
--- /dev/null
+++ b/tests/utils/test_document.py
@@ -0,0 +1,64 @@
+from pytorch_ie.annotations import LabeledSpan
+from pytorch_ie.documents import TextDocumentWithLabeledSpans
+from pytorch_ie.utils.document import merge_annotations_from_documents
+
+
+def test_document_merge_annotations():
+    base_doc = TextDocumentWithLabeledSpans(id="doc1", text="This is a test.")
+    # add annotations
+    base_doc.labeled_spans.append(LabeledSpan(start=0, end=4, label="label1", score=1.0))
+    base_doc.labeled_spans.append(LabeledSpan(start=5, end=7, label="label2", score=1.0))
+
+    input1 = base_doc.copy()
+    # add predictions
+    input1.labeled_spans.predictions.append(LabeledSpan(start=0, end=4, label="label1", score=0.9))
+    input1.labeled_spans.predictions.append(LabeledSpan(start=5, end=7, label="label2", score=0.7))
+
+    input2 = base_doc.copy()
+    # add predictions
+    input2.labeled_spans.predictions.append(LabeledSpan(start=0, end=4, label="label1", score=0.8))
+    input2.labeled_spans.predictions.append(LabeledSpan(start=5, end=7, label="label2", score=0.7))
+    input2.labeled_spans.predictions.append(LabeledSpan(start=5, end=7, label="label3", score=0.6))
+
+    documents = {
+        "doc1": input1,
+        "doc2": input2,
+    }
+    result = merge_annotations_from_documents(
+        documents,
+        metadata_key_source_annotations="annotations_source",
+        metadata_key_source_predictions="predictions_source",
+    )
+    assert result.id == "doc1"
+    assert set(result.labeled_spans) == set(base_doc.labeled_spans)
+    assert len(result.labeled_spans) == len(base_doc.labeled_spans) == 2
+    assert len(result.labeled_spans.predictions) == 4
+    assert result.labeled_spans.predictions.resolve() == [
+        ("label1", "This"),
+        ("label2", "is"),
+        ("label1", "This"),
+        ("label3", "is"),
+    ]
+    annotations_with_sources = [
+        (ann.copy(), sources)
+        for ann, sources in zip(
+            result.labeled_spans, result.metadata["annotations_source"]["labeled_spans"]
+        )
+    ]
+    assert annotations_with_sources == [
+        (LabeledSpan(start=0, end=4, label="label1", score=1.0), ["doc1", "doc2"]),
+        (LabeledSpan(start=5, end=7, label="label2", score=1.0), ["doc1", "doc2"]),
+    ]
+    predictions_with_sources = [
+        (ann.copy(), sources)
+        for ann, sources in zip(
+            result.labeled_spans.predictions,
+            result.metadata["predictions_source"]["labeled_spans"],
+        )
+    ]
+    assert predictions_with_sources == [
+        (LabeledSpan(start=0, end=4, label="label1", score=0.9), ["doc1"]),
+        (LabeledSpan(start=5, end=7, label="label2", score=0.7), ["doc1", "doc2"]),
+        (LabeledSpan(start=0, end=4, label="label1", score=0.8), ["doc2"]),
+        (LabeledSpan(start=5, end=7, label="label3", score=0.6), ["doc2"]),
+    ]
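
For reference, a small sketch (not part of this patch) of how the deduplicate_annotation_dicts helper above behaves: it keys duplicates on the sorted key/value pairs, so dictionaries with the same content but a different key order collapse into one entry. The dictionary contents below are illustrative; real annotation dicts produced by Document.asdict() contain additional fields.

from pytorch_ie.utils.document import deduplicate_annotation_dicts

annotation_dicts = [
    {"start": 0, "end": 4, "label": "label1"},
    {"label": "label1", "start": 0, "end": 4},  # same content, different key order
    {"start": 5, "end": 7, "label": "label2"},
]
# the second dict is dropped because its sorted items match those of the first dict
assert deduplicate_annotation_dicts(annotation_dicts) == [
    {"start": 0, "end": 4, "label": "label1"},
    {"start": 5, "end": 7, "label": "label2"},
]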
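
A possible consumer of the source metadata written by merge_annotations_from_documents, sketched under the conventions of the test above: the metadata holds one list of source names per prediction, in prediction order, so zipping it with the prediction layer gives per-prediction support. The helper name predictions_with_min_support, the metadata key "predictions_source", and the min_sources threshold are illustrative and not part of this patch.

from typing import Dict, List

from pytorch_ie.annotations import LabeledSpan
from pytorch_ie.documents import TextDocumentWithLabeledSpans
from pytorch_ie.utils.document import merge_annotations_from_documents


def predictions_with_min_support(
    documents: Dict[str, TextDocumentWithLabeledSpans],
    min_sources: int = 2,
    metadata_key: str = "predictions_source",
) -> Dict[str, List[LabeledSpan]]:
    # merge the per-source documents and record which sources produced each prediction
    merged = merge_annotations_from_documents(
        documents, metadata_key_source_predictions=metadata_key
    )
    supported: Dict[str, List[LabeledSpan]] = {}
    for layer_name, sources_per_prediction in merged.metadata[metadata_key].items():
        # one source-name list per prediction, in prediction order
        supported[layer_name] = [
            prediction
            for prediction, sources in zip(
                merged[layer_name].predictions, sources_per_prediction
            )
            if len(sources) >= min_sources
        ]
    return supported

With the documents dict from the test above, only the (5, 7, "label2") prediction would be kept, since it is the only one produced by both doc1 and doc2.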