From 4134c3d97816fd2fbd5d3cb916ebf77b7284e0b3 Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Tue, 10 Dec 2024 16:24:57 +0100 Subject: [PATCH 1/2] remove utils.document.deduplicate_annotations in favor of Document.deduplicate_annotations --- src/pytorch_ie/utils/document.py | 24 ++---------------------- tests/utils/test_document.py | 6 +++--- 2 files changed, 5 insertions(+), 25 deletions(-) diff --git a/src/pytorch_ie/utils/document.py b/src/pytorch_ie/utils/document.py index ab0ba969..8797c5d8 100644 --- a/src/pytorch_ie/utils/document.py +++ b/src/pytorch_ie/utils/document.py @@ -29,27 +29,6 @@ def deduplicate_annotation_dicts( D = TypeVar("D", bound=Document) -def deduplicate_annotations(document: D) -> D: - """Remove duplicate annotations from a document. - - Args: - document: The document to remove duplicate annotations from. - - Returns: - The document with duplicate annotations removed. - """ - annotation_field_names = [field.name for field in document.annotation_fields()] - doc_dict = document.asdict() - for annotation_field_name in annotation_field_names: - doc_dict[annotation_field_name]["annotations"] = deduplicate_annotation_dicts( - doc_dict[annotation_field_name]["annotations"] - ) - doc_dict[annotation_field_name]["predictions"] = deduplicate_annotation_dicts( - doc_dict[annotation_field_name]["predictions"] - ) - return type(document).fromdict(doc_dict) - - def save_annotation_sources_to_metadata( document: D, annotation_id2source: Dict[int, List[str]], @@ -135,7 +114,8 @@ def merge_annotations_from_documents( for orig_id, new_annotation in orig_id2new_annotation.items(): added_annotation_id2source_names[new_annotation._id].append(source_name) - merged_document = deduplicate_annotations(merged_document) + # merged_document = deduplicate_annotations(merged_document) + merged_document = merged_document.deduplicate_annotations() # save source names in metadata (at key metadata_key_source_annotations / metadata_key_source_predictions # for each layer in the order of the annotations / predictions) diff --git a/tests/utils/test_document.py b/tests/utils/test_document.py index a27734f2..d367ae50 100644 --- a/tests/utils/test_document.py +++ b/tests/utils/test_document.py @@ -32,11 +32,11 @@ def test_document_merge_annotations(): assert result.id == "doc1" assert set(result.labeled_spans) == set(base_doc.labeled_spans) assert len(result.labeled_spans) == len(base_doc.labeled_spans) == 2 - assert len(result.labeled_spans.predictions) == 4 + # assert len(result.labeled_spans.predictions) == 4 assert result.labeled_spans.predictions.resolve() == [ ("label1", "This"), ("label2", "is"), - ("label1", "This"), + # ("label1", "This"), ("label3", "is"), ] annotations_with_sources = [ @@ -59,6 +59,6 @@ def test_document_merge_annotations(): assert predictions_with_scores == [ (LabeledSpan(start=0, end=4, label="label1", score=0.9), ["doc1"]), (LabeledSpan(start=5, end=7, label="label2", score=0.7), ["doc1", "doc2"]), - (LabeledSpan(start=0, end=4, label="label1", score=0.8), ["doc2"]), + # (LabeledSpan(start=0, end=4, label="label1", score=0.8), ["doc2"]), (LabeledSpan(start=5, end=7, label="label3", score=0.6), ["doc2"]), ] From 7c3b218f790f8d257d62c43eb66d18d964806fdd Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Tue, 17 Dec 2024 11:38:31 +0100 Subject: [PATCH 2/2] cleanup --- src/pytorch_ie/utils/document.py | 4 +--- tests/utils/test_document.py | 3 --- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/src/pytorch_ie/utils/document.py b/src/pytorch_ie/utils/document.py index 8797c5d8..a74246da 100644 --- a/src/pytorch_ie/utils/document.py +++ b/src/pytorch_ie/utils/document.py @@ -103,8 +103,6 @@ def merge_annotations_from_documents( f"Document IDs do not match: {document.id} and {merged_document.id}" ) - # TODO: add_all_annotations_from_other needs to be fixed! it should return a mapping from - # original annotation *IDs* to new annotations! # Note: this does not check for duplicates! added_annotations = merged_document.add_all_annotations_from_other( other=document, strict=True @@ -114,7 +112,7 @@ def merge_annotations_from_documents( for orig_id, new_annotation in orig_id2new_annotation.items(): added_annotation_id2source_names[new_annotation._id].append(source_name) - # merged_document = deduplicate_annotations(merged_document) + # this will remove duplicates. If duplicates have different scores, the one with the highest score will be kept merged_document = merged_document.deduplicate_annotations() # save source names in metadata (at key metadata_key_source_annotations / metadata_key_source_predictions diff --git a/tests/utils/test_document.py b/tests/utils/test_document.py index d367ae50..3be4d0b3 100644 --- a/tests/utils/test_document.py +++ b/tests/utils/test_document.py @@ -32,11 +32,9 @@ def test_document_merge_annotations(): assert result.id == "doc1" assert set(result.labeled_spans) == set(base_doc.labeled_spans) assert len(result.labeled_spans) == len(base_doc.labeled_spans) == 2 - # assert len(result.labeled_spans.predictions) == 4 assert result.labeled_spans.predictions.resolve() == [ ("label1", "This"), ("label2", "is"), - # ("label1", "This"), ("label3", "is"), ] annotations_with_sources = [ @@ -59,6 +57,5 @@ def test_document_merge_annotations(): assert predictions_with_scores == [ (LabeledSpan(start=0, end=4, label="label1", score=0.9), ["doc1"]), (LabeledSpan(start=5, end=7, label="label2", score=0.7), ["doc1", "doc2"]), - # (LabeledSpan(start=0, end=4, label="label1", score=0.8), ["doc2"]), (LabeledSpan(start=5, end=7, label="label3", score=0.6), ["doc2"]), ]