Skip to content

Commit

Permalink
add resolve() to BaseAnnotationList (#411)
Browse files Browse the repository at this point in the history
  • Loading branch information
ArneBinder authored Apr 4, 2024
1 parent 83929c4 commit 27b4c20
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 7 deletions.
3 changes: 3 additions & 0 deletions src/pytorch_ie/core/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,9 @@ def target_layer(self) -> "AnnotationLayer":
)
return list(tgt_layers.values())[0]

def resolve(self) -> List[Any]:
return [annotation.resolve() for annotation in self]


class AnnotationLayer(BaseAnnotationList[T]):
def __init__(self, document: "Document", targets: List["str"]):
Expand Down
22 changes: 15 additions & 7 deletions tests/test_document.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class TestDocument(TextDocument):
relations: AnnotationLayer[BinaryRelation] = annotation_field(target="entities")
label: AnnotationLayer[Label] = annotation_field()

document1 = TestDocument(text="test1")
document1 = TestDocument(text="text1 and some more text.")
assert isinstance(document1.sentences, AnnotationLayer)
assert isinstance(document1.entities, AnnotationLayer)
assert isinstance(document1.relations, AnnotationLayer)
Expand All @@ -76,21 +76,25 @@ class TestDocument(TextDocument):
"label",
}

span1 = Span(start=1, end=2)
span2 = Span(start=3, end=4)
span1 = Span(start=0, end=5)
span2 = Span(start=6, end=9)

document1.sentences.append(span1)
document1.sentences.append(span2)
assert len(document1.sentences) == 2
assert document1.sentences[:2] == [span1, span2]
assert document1.sentences[0].target == document1.text
resolved_sentences = document1.sentences.resolve()
assert resolved_sentences == ["text1", "and"]

labeled_span1 = LabeledSpan(start=1, end=2, label="label1")
labeled_span2 = LabeledSpan(start=3, end=4, label="label2")
labeled_span1 = LabeledSpan(start=0, end=5, label="label1")
labeled_span2 = LabeledSpan(start=6, end=9, label="label2")
document1.entities.append(labeled_span1)
document1.entities.append(labeled_span2)
assert len(document1.entities) == 2
assert document1.sentences[0].target == document1.text
resolved_entities = document1.entities.resolve()
assert resolved_entities == [("label1", "text1"), ("label2", "and")]

relation1 = BinaryRelation(head=labeled_span1, tail=labeled_span2, label="label1")
relation2 = BinaryRelation(head=labeled_span1, tail=labeled_span2, label="label1")
Expand All @@ -101,6 +105,8 @@ class TestDocument(TextDocument):
document1.relations.append(relation1)
assert len(document1.relations) == 1
assert document1.relations[0].target == document1.entities
resolved_relations = document1.relations.resolve()
assert resolved_relations == [("label1", (("label1", "text1"), ("label2", "and")))]

assert document1 == TestDocument.fromdict(document1.asdict())

Expand All @@ -113,11 +119,13 @@ class TestDocument(TextDocument):
):
document1["non_existing_annotation"]

span3 = Span(start=5, end=6)
span4 = Span(start=7, end=8)
span3 = Span(start=10, end=14)
span4 = Span(start=15, end=19)

document1.sentences.predictions.append(span3)
document1.sentences.predictions.append(span4)
resolved_sentences_predictions = document1.sentences.predictions.resolve()
assert resolved_sentences_predictions == ["some", "more"]
# add a prediction that is also an annotation
# remove the annotation to allow reassigning it
relation1_popped = document1.relations.pop(0)
Expand Down

0 comments on commit 27b4c20

Please sign in to comment.