Skip to content

Commit

Permalink
add_all_annotations_from_other() returns a dict with annotation ids i…
Browse files Browse the repository at this point in the history
…nstead of annotations as keys
  • Loading branch information
ArneBinder committed Oct 8, 2024
1 parent bd5c6ac commit 38c77a1
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 14 deletions.
8 changes: 4 additions & 4 deletions src/pytorch_ie/core/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -779,7 +779,7 @@ def add_all_annotations_from_other(
process_predictions: bool = True,
strict: bool = True,
verbose: bool = True,
) -> Dict[str, Dict[Annotation, Annotation]]:
) -> Dict[str, Dict[int, Annotation]]:
"""Adds all annotations from another document to this document. It allows to blacklist annotations
and also to override annotations. It returns the original annotations for which a new annotation was
added to the current document.
Expand Down Expand Up @@ -862,7 +862,7 @@ class TokenBasedDocumentWithEntitiesRelationsAndRelationAttributes(TokenBasedDoc
```
"""
removed_annotations = defaultdict(set, removed_annotations or dict())
added_annotations: Dict[str, Dict[Annotation, Annotation]] = defaultdict(dict)
added_annotations: Dict[str, Dict[int, Annotation]] = defaultdict(dict)

annotation_store: Dict[str, Dict[int, Annotation]] = defaultdict(dict)
named_annotation_fields = {field.name: field for field in self.annotation_fields()}
Expand Down Expand Up @@ -905,7 +905,7 @@ class TokenBasedDocumentWithEntitiesRelationsAndRelationAttributes(TokenBasedDoc
if ann._id != new_ann._id:
annotation_store[field_name][ann._id] = new_ann
self[field_name].append(new_ann)
added_annotations[field_name][ann] = new_ann
added_annotations[field_name][ann._id] = new_ann
else:
if strict:
raise ValueError(
Expand All @@ -930,7 +930,7 @@ class TokenBasedDocumentWithEntitiesRelationsAndRelationAttributes(TokenBasedDoc
if ann._id != new_ann._id:
annotation_store[field_name][ann._id] = new_ann
self[field_name].predictions.append(new_ann)
added_annotations[field_name][ann] = new_ann
added_annotations[field_name][ann._id] = new_ann
else:
if strict:
raise ValueError(
Expand Down
21 changes: 11 additions & 10 deletions tests/test_document.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,10 +665,11 @@ def test_document_extend_from_other_full_copy(text_document):
for layer_name, annotation_mapping in added_annotations.items():
assert len(annotation_mapping) > 0
available_annotations = text_document[layer_name]
assert set(annotation_mapping) == set(available_annotations)
available_annotation_ids = [a._id for a in available_annotations]
assert set(annotation_mapping) == set(available_annotation_ids)
assert len(annotation_mapping) == 1
# since we have only one annotation, we can construct the expected mapping
assert annotation_mapping == {available_annotations[0]: doc_new[layer_name][0]}
assert annotation_mapping == {available_annotation_ids[0]: doc_new[layer_name][0]}


def test_document_extend_from_other_wrong_override_annotation_mapping(text_document):
Expand Down Expand Up @@ -711,16 +712,16 @@ class TestDocument2(TokenBasedDocument):
added_annotation_sets = {k: set(v) for k, v in added_annotations.items()}
# check that the added annotations are as expected (the entity annotations are already there)
assert added_annotation_sets == {
"relations": set(text_document.relations),
"relation_attributes": set(text_document.relation_attributes),
"labels": set(text_document.labels),
"relations": {ann._id for ann in text_document.relations},
"relation_attributes": {ann._id for ann in text_document.relation_attributes},
"labels": {ann._id for ann in text_document.labels},
}
for layer_name, annotation_mapping in added_annotations.items():
text_annotations = text_document[layer_name]
token_annotations = token_document[layer_name]
assert len(annotation_mapping) == len(text_annotations) == len(token_annotations) == 1
# since we have only one annotation, we can construct the expected mapping
assert annotation_mapping == {text_annotations[0]: token_annotations[0]}
assert annotation_mapping == {text_annotations[0]._id: token_annotations[0]}

assert (
len(token_document.entities1)
Expand Down Expand Up @@ -753,12 +754,12 @@ def test_document_extend_from_other_remove(text_document):
added_annotation_sets = {k: set(v) for k, v in added_annotations.items()}
# the only entity in entities1 is removed and since the relation has it as head, the relation is removed as well
assert added_annotation_sets == {
"entities2": set(text_document.entities2),
"labels": set(text_document.labels),
"entities2": {ann._id for ann in text_document.entities2},
"labels": {ann._id for ann in text_document.labels},
}
assert added_annotations == {
"entities2": {text_document.entities2[0]: doc_new.entities2[0]},
"labels": {text_document.labels[0]: doc_new.labels[0]},
"entities2": {text_document.entities2[0]._id: doc_new.entities2[0]},
"labels": {text_document.labels[0]._id: doc_new.labels[0]},
}

assert len(doc_new.entities1) == 0
Expand Down

0 comments on commit 38c77a1

Please sign in to comment.