diff --git a/src/pytorch_ie/annotations.py b/src/pytorch_ie/annotations.py index edd285ee..2899e18a 100644 --- a/src/pytorch_ie/annotations.py +++ b/src/pytorch_ie/annotations.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import Optional, Tuple +from typing import Any, Optional, Tuple from pytorch_ie.core.document import Annotation @@ -48,6 +48,9 @@ class Label(Annotation): def __post_init__(self) -> None: _post_init_single_label(self) + def resolve(self) -> Any: + return self.label + @dataclass(eq=True, frozen=True) class MultiLabel(Annotation): @@ -57,6 +60,9 @@ class MultiLabel(Annotation): def __post_init__(self) -> None: _post_init_multi_label(self) + def resolve(self) -> Any: + return self.label + @dataclass(eq=True, frozen=True) class Span(Annotation): @@ -68,6 +74,12 @@ def __str__(self) -> str: return super().__str__() return str(self.target[self.start : self.end]) + def resolve(self) -> Any: + if self.is_attached: + return self.target[self.start : self.end] + else: + raise ValueError(f"{self} is not attached to a target.") + @dataclass(eq=True, frozen=True) class LabeledSpan(Span): @@ -77,6 +89,9 @@ class LabeledSpan(Span): def __post_init__(self) -> None: _post_init_single_label(self) + def resolve(self) -> Any: + return self.label, super().resolve() + @dataclass(eq=True, frozen=True) class MultiLabeledSpan(Span): @@ -86,6 +101,9 @@ class MultiLabeledSpan(Span): def __post_init__(self) -> None: _post_init_multi_label(self) + def resolve(self) -> Any: + return self.label, super().resolve() + @dataclass(eq=True, frozen=True) class BinaryRelation(Annotation): @@ -97,6 +115,9 @@ class BinaryRelation(Annotation): def __post_init__(self) -> None: _post_init_single_label(self) + def resolve(self) -> Any: + return self.label, (self.head.resolve(), self.tail.resolve()) + @dataclass(eq=True, frozen=True) class MultiLabeledBinaryRelation(Annotation): @@ -108,6 +129,9 @@ class MultiLabeledBinaryRelation(Annotation): def __post_init__(self) -> None: _post_init_multi_label(self) + def resolve(self) -> Any: + return self.label, (self.head.resolve(), self.tail.resolve()) + @dataclass(eq=True, frozen=True) class NaryRelation(Annotation): @@ -119,3 +143,9 @@ class NaryRelation(Annotation): def __post_init__(self) -> None: _post_init_arguments_and_roles(self) _post_init_single_label(self) + + def resolve(self) -> Any: + return ( + self.label, + tuple((role, arg.resolve()) for arg, role in zip(self.arguments, self.roles)), + ) diff --git a/src/pytorch_ie/core/document.py b/src/pytorch_ie/core/document.py index 58571dc2..97eba705 100644 --- a/src/pytorch_ie/core/document.py +++ b/src/pytorch_ie/core/document.py @@ -385,6 +385,9 @@ def __lt__(self, other: "Annotation") -> bool: return value < other_value return False + def resolve(self) -> Any: + raise NotImplementedError(f"resolve() is not implemented for {self.__class__}") + T = TypeVar("T", covariant=False, bound="Annotation") diff --git a/tests/core/test_document.py b/tests/core/test_document.py index ee5dbb5d..20986f96 100644 --- a/tests/core/test_document.py +++ b/tests/core/test_document.py @@ -304,6 +304,17 @@ class DummyWithNestedAnnotation(Annotation): ) +def test_annotation_resolve(): + @dataclasses.dataclass(eq=True, frozen=True) + class Dummy(Annotation): + a: int + + dummy = Dummy(a=1) + with pytest.raises(NotImplementedError) as excinfo: + dummy.resolve() + assert str(excinfo.value) == f"resolve() is not implemented for {Dummy}" + + def test_annotation_is_attached(): @dataclasses.dataclass class MyDocument(Document): diff --git a/tests/test_annotations.py b/tests/test_annotations.py index 85af40e2..1cdb6abb 100644 --- a/tests/test_annotations.py +++ b/tests/test_annotations.py @@ -1,7 +1,9 @@ +import dataclasses import re import pytest +from pytorch_ie import AnnotationLayer, annotation_field from pytorch_ie.annotations import ( BinaryRelation, Label, @@ -9,8 +11,10 @@ MultiLabel, MultiLabeledBinaryRelation, MultiLabeledSpan, + NaryRelation, Span, ) +from pytorch_ie.documents import TextBasedDocument from tests.core.test_document import _test_annotation_reconstruction @@ -18,6 +22,7 @@ def test_label(): label1 = Label(label="label1") assert label1.label == "label1" assert label1.score == pytest.approx(1.0) + assert label1.resolve() == "label1" label2 = Label(label="label2", score=0.5) assert label2.label == "label2" @@ -36,6 +41,7 @@ def test_multilabel(): multilabel1 = MultiLabel(label=("label1", "label2")) assert multilabel1.label == ("label1", "label2") assert multilabel1.score == pytest.approx((1.0, 1.0)) + assert multilabel1.resolve() == ("label1", "label2") multilabel2 = MultiLabel(label=("label3", "label4"), score=(0.4, 0.5)) assert multilabel2.label == ("label3", "label4") @@ -68,6 +74,19 @@ def test_span(): _test_annotation_reconstruction(span) + with pytest.raises(ValueError) as excinfo: + span.resolve() + assert str(excinfo.value) == "Span(start=1, end=2) is not attached to a target." + + @dataclasses.dataclass + class TestDocument(TextBasedDocument): + spans: AnnotationLayer[Span] = annotation_field(target="text") + + doc = TestDocument(text="Hello, world!") + span = Span(start=7, end=12) + doc.spans.append(span) + assert span.resolve() == "world" + def test_labeled_span(): labeled_span1 = LabeledSpan(start=1, end=2, label="label1") @@ -92,6 +111,22 @@ def test_labeled_span(): _test_annotation_reconstruction(labeled_span2) + with pytest.raises(ValueError) as excinfo: + labeled_span1.resolve() + assert ( + str(excinfo.value) + == "LabeledSpan(start=1, end=2, label='label1', score=1.0) is not attached to a target." + ) + + @dataclasses.dataclass + class TestDocument(TextBasedDocument): + spans: AnnotationLayer[LabeledSpan] = annotation_field(target="text") + + doc = TestDocument(text="Hello, world!") + labeled_span = LabeledSpan(start=7, end=12, label="LOC") + doc.spans.append(labeled_span) + assert labeled_span.resolve() == ("LOC", "world") + def test_multilabeled_span(): multilabeled_span1 = MultiLabeledSpan(start=1, end=2, label=("label1", "label2")) @@ -123,6 +158,22 @@ def test_multilabeled_span(): ): MultiLabeledSpan(start=5, end=6, label=("label5", "label6"), score=(0.1, 0.2, 0.3)) + with pytest.raises(ValueError) as excinfo: + multilabeled_span1.resolve() + assert ( + str(excinfo.value) + == "MultiLabeledSpan(start=1, end=2, label=('label1', 'label2'), score=(1.0, 1.0)) is not attached to a target." + ) + + @dataclasses.dataclass + class TestDocument(TextBasedDocument): + spans: AnnotationLayer[MultiLabeledSpan] = annotation_field(target="text") + + doc = TestDocument(text="Hello, world!") + multilabeled_span = MultiLabeledSpan(start=7, end=12, label=("LOC", "ORG")) + doc.spans.append(multilabeled_span) + assert multilabeled_span.resolve() == (("LOC", "ORG"), "world") + def test_binary_relation(): head = Span(start=1, end=2) @@ -160,6 +211,23 @@ def test_binary_relation(): ): BinaryRelation.fromdict(binary_relation2.asdict()) + with pytest.raises(ValueError) as excinfo: + binary_relation1.resolve() + assert str(excinfo.value) == "Span(start=1, end=2) is not attached to a target." + + @dataclasses.dataclass + class TestDocument(TextBasedDocument): + spans: AnnotationLayer[Span] = annotation_field(target="text") + relations: AnnotationLayer[BinaryRelation] = annotation_field(target="spans") + + doc = TestDocument(text="Hello, world!") + head = Span(start=0, end=5) + tail = Span(start=7, end=12) + doc.spans.extend([head, tail]) + relation = BinaryRelation(head=head, tail=tail, label="LABEL") + doc.relations.append(relation) + assert relation.resolve() == ("LABEL", ("Hello", "world")) + def test_multilabeled_binary_relation(): head = Span(start=1, end=2) @@ -205,3 +273,79 @@ def test_multilabeled_binary_relation(): MultiLabeledBinaryRelation( head=head, tail=tail, label=("label5", "label6"), score=(0.1, 0.2, 0.3) ) + + with pytest.raises(ValueError) as excinfo: + binary_relation1.resolve() + assert str(excinfo.value) == "Span(start=1, end=2) is not attached to a target." + + @dataclasses.dataclass + class TestDocument(TextBasedDocument): + spans: AnnotationLayer[Span] = annotation_field(target="text") + relations: AnnotationLayer[MultiLabeledBinaryRelation] = annotation_field(target="spans") + + doc = TestDocument(text="Hello, world!") + head = Span(start=0, end=5) + tail = Span(start=7, end=12) + doc.spans.extend([head, tail]) + relation = MultiLabeledBinaryRelation(head=head, tail=tail, label=("LABEL1", "LABEL2")) + doc.relations.append(relation) + assert relation.resolve() == (("LABEL1", "LABEL2"), ("Hello", "world")) + + +def test_nary_relation(): + arg1 = Span(start=1, end=2) + arg2 = Span(start=3, end=4) + arg3 = Span(start=5, end=6) + + nary_relation1 = NaryRelation( + arguments=(arg1, arg2, arg3), roles=("role1", "role2", "role3"), label="label1" + ) + + assert nary_relation1.arguments == (arg1, arg2, arg3) + assert nary_relation1.roles == ("role1", "role2", "role3") + assert nary_relation1.label == "label1" + assert nary_relation1.score == pytest.approx(1.0) + + assert nary_relation1.asdict() == { + "_id": nary_relation1._id, + "arguments": [arg1._id, arg2._id, arg3._id], + "roles": ("role1", "role2", "role3"), + "label": "label1", + "score": 1.0, + } + + annotation_store = { + arg1._id: arg1, + arg2._id: arg2, + arg3._id: arg3, + } + _test_annotation_reconstruction(nary_relation1, annotation_store=annotation_store) + + with pytest.raises( + ValueError, + match=re.escape("Unable to resolve the annotation id without annotation_store."), + ): + NaryRelation.fromdict(nary_relation1.asdict()) + + with pytest.raises(ValueError) as excinfo: + nary_relation1.resolve() + assert str(excinfo.value) == "Span(start=1, end=2) is not attached to a target." + + @dataclasses.dataclass + class TestDocument(TextBasedDocument): + spans: AnnotationLayer[Span] = annotation_field(target="text") + relations: AnnotationLayer[NaryRelation] = annotation_field(target="spans") + + doc = TestDocument(text="Hello, world A and B!") + arg1 = Span(start=0, end=5) + arg2 = Span(start=7, end=14) + arg3 = Span(start=19, end=20) + doc.spans.extend([arg1, arg2, arg3]) + relation = NaryRelation( + arguments=(arg1, arg2, arg3), roles=("ARG1", "ARG2", "ARG3"), label="LABEL" + ) + doc.relations.append(relation) + assert relation.resolve() == ( + "LABEL", + (("ARG1", "Hello"), ("ARG2", "world A"), ("ARG3", "B")), + )