diff --git a/src/pytorch_ie/core/document.py b/src/pytorch_ie/core/document.py index 776a6387..9f8107ab 100644 --- a/src/pytorch_ie/core/document.py +++ b/src/pytorch_ie/core/document.py @@ -779,7 +779,7 @@ def add_all_annotations_from_other( process_predictions: bool = True, strict: bool = True, verbose: bool = True, - ) -> Dict[str, Dict[Annotation, Annotation]]: + ) -> Dict[str, Dict[int, Annotation]]: """Adds all annotations from another document to this document. It allows to blacklist annotations and also to override annotations. It returns the original annotations for which a new annotation was added to the current document. @@ -862,7 +862,7 @@ class TokenBasedDocumentWithEntitiesRelationsAndRelationAttributes(TokenBasedDoc ``` """ removed_annotations = defaultdict(set, removed_annotations or dict()) - added_annotations: Dict[str, Dict[Annotation, Annotation]] = defaultdict(dict) + added_annotations: Dict[str, Dict[int, Annotation]] = defaultdict(dict) annotation_store: Dict[str, Dict[int, Annotation]] = defaultdict(dict) named_annotation_fields = {field.name: field for field in self.annotation_fields()} @@ -905,7 +905,7 @@ class TokenBasedDocumentWithEntitiesRelationsAndRelationAttributes(TokenBasedDoc if ann._id != new_ann._id: annotation_store[field_name][ann._id] = new_ann self[field_name].append(new_ann) - added_annotations[field_name][ann] = new_ann + added_annotations[field_name][ann._id] = new_ann else: if strict: raise ValueError( @@ -930,7 +930,7 @@ class TokenBasedDocumentWithEntitiesRelationsAndRelationAttributes(TokenBasedDoc if ann._id != new_ann._id: annotation_store[field_name][ann._id] = new_ann self[field_name].predictions.append(new_ann) - added_annotations[field_name][ann] = new_ann + added_annotations[field_name][ann._id] = new_ann else: if strict: raise ValueError( diff --git a/tests/test_document.py b/tests/test_document.py index 227ac839..8c43358a 100644 --- a/tests/test_document.py +++ b/tests/test_document.py @@ -665,10 +665,11 @@ def test_document_extend_from_other_full_copy(text_document): for layer_name, annotation_mapping in added_annotations.items(): assert len(annotation_mapping) > 0 available_annotations = text_document[layer_name] - assert set(annotation_mapping) == set(available_annotations) + available_annotation_ids = [a._id for a in available_annotations] + assert set(annotation_mapping) == set(available_annotation_ids) assert len(annotation_mapping) == 1 # since we have only one annotation, we can construct the expected mapping - assert annotation_mapping == {available_annotations[0]: doc_new[layer_name][0]} + assert annotation_mapping == {available_annotation_ids[0]: doc_new[layer_name][0]} def test_document_extend_from_other_wrong_override_annotation_mapping(text_document): @@ -711,16 +712,16 @@ class TestDocument2(TokenBasedDocument): added_annotation_sets = {k: set(v) for k, v in added_annotations.items()} # check that the added annotations are as expected (the entity annotations are already there) assert added_annotation_sets == { - "relations": set(text_document.relations), - "relation_attributes": set(text_document.relation_attributes), - "labels": set(text_document.labels), + "relations": {ann._id for ann in text_document.relations}, + "relation_attributes": {ann._id for ann in text_document.relation_attributes}, + "labels": {ann._id for ann in text_document.labels}, } for layer_name, annotation_mapping in added_annotations.items(): text_annotations = text_document[layer_name] token_annotations = token_document[layer_name] assert len(annotation_mapping) == len(text_annotations) == len(token_annotations) == 1 # since we have only one annotation, we can construct the expected mapping - assert annotation_mapping == {text_annotations[0]: token_annotations[0]} + assert annotation_mapping == {text_annotations[0]._id: token_annotations[0]} assert ( len(token_document.entities1) @@ -753,12 +754,12 @@ def test_document_extend_from_other_remove(text_document): added_annotation_sets = {k: set(v) for k, v in added_annotations.items()} # the only entity in entities1 is removed and since the relation has it as head, the relation is removed as well assert added_annotation_sets == { - "entities2": set(text_document.entities2), - "labels": set(text_document.labels), + "entities2": {ann._id for ann in text_document.entities2}, + "labels": {ann._id for ann in text_document.labels}, } assert added_annotations == { - "entities2": {text_document.entities2[0]: doc_new.entities2[0]}, - "labels": {text_document.labels[0]: doc_new.labels[0]}, + "entities2": {text_document.entities2[0]._id: doc_new.entities2[0]}, + "labels": {text_document.labels[0]._id: doc_new.labels[0]}, } assert len(doc_new.entities1) == 0