diff --git a/src/pytorch_ie/annotations.py b/src/pytorch_ie/annotations.py index 40d719a8..2899e18a 100644 --- a/src/pytorch_ie/annotations.py +++ b/src/pytorch_ie/annotations.py @@ -90,7 +90,7 @@ def __post_init__(self) -> None: _post_init_single_label(self) def resolve(self) -> Any: - return super().resolve(), self.label + return self.label, super().resolve() @dataclass(eq=True, frozen=True) @@ -102,7 +102,7 @@ def __post_init__(self) -> None: _post_init_multi_label(self) def resolve(self) -> Any: - return super().resolve(), self.label + return self.label, super().resolve() @dataclass(eq=True, frozen=True) @@ -116,7 +116,7 @@ def __post_init__(self) -> None: _post_init_single_label(self) def resolve(self) -> Any: - return self.head.resolve(), self.tail.resolve(), self.label + return self.label, (self.head.resolve(), self.tail.resolve()) @dataclass(eq=True, frozen=True) @@ -130,7 +130,7 @@ def __post_init__(self) -> None: _post_init_multi_label(self) def resolve(self) -> Any: - return self.head.resolve(), self.tail.resolve(), self.label + return self.label, (self.head.resolve(), self.tail.resolve()) @dataclass(eq=True, frozen=True) @@ -146,6 +146,6 @@ def __post_init__(self) -> None: def resolve(self) -> Any: return ( - tuple((role, arg.resolve()) for arg, role in zip(self.arguments, self.roles)), self.label, + tuple((role, arg.resolve()) for arg, role in zip(self.arguments, self.roles)), ) diff --git a/tests/test_annotations.py b/tests/test_annotations.py index 49219e46..1cdb6abb 100644 --- a/tests/test_annotations.py +++ b/tests/test_annotations.py @@ -125,7 +125,7 @@ class TestDocument(TextBasedDocument): doc = TestDocument(text="Hello, world!") labeled_span = LabeledSpan(start=7, end=12, label="LOC") doc.spans.append(labeled_span) - assert labeled_span.resolve() == ("world", "LOC") + assert labeled_span.resolve() == ("LOC", "world") def test_multilabeled_span(): @@ -172,7 +172,7 @@ class TestDocument(TextBasedDocument): doc = TestDocument(text="Hello, world!") multilabeled_span = MultiLabeledSpan(start=7, end=12, label=("LOC", "ORG")) doc.spans.append(multilabeled_span) - assert multilabeled_span.resolve() == ("world", ("LOC", "ORG")) + assert multilabeled_span.resolve() == (("LOC", "ORG"), "world") def test_binary_relation(): @@ -224,9 +224,9 @@ class TestDocument(TextBasedDocument): head = Span(start=0, end=5) tail = Span(start=7, end=12) doc.spans.extend([head, tail]) - relation = BinaryRelation(head=head, tail=tail, label="RELATION") + relation = BinaryRelation(head=head, tail=tail, label="LABEL") doc.relations.append(relation) - assert relation.resolve() == ("Hello", "world", "RELATION") + assert relation.resolve() == ("LABEL", ("Hello", "world")) def test_multilabeled_binary_relation(): @@ -287,9 +287,9 @@ class TestDocument(TextBasedDocument): head = Span(start=0, end=5) tail = Span(start=7, end=12) doc.spans.extend([head, tail]) - relation = MultiLabeledBinaryRelation(head=head, tail=tail, label=("RELATION1", "RELATION2")) + relation = MultiLabeledBinaryRelation(head=head, tail=tail, label=("LABEL1", "LABEL2")) doc.relations.append(relation) - assert relation.resolve() == ("Hello", "world", ("RELATION1", "RELATION2")) + assert relation.resolve() == (("LABEL1", "LABEL2"), ("Hello", "world")) def test_nary_relation(): @@ -342,10 +342,10 @@ class TestDocument(TextBasedDocument): arg3 = Span(start=19, end=20) doc.spans.extend([arg1, arg2, arg3]) relation = NaryRelation( - arguments=(arg1, arg2, arg3), roles=("ARG1", "ARG2", "ARG3"), label="RELATION" + arguments=(arg1, arg2, arg3), roles=("ARG1", "ARG2", "ARG3"), label="LABEL" ) doc.relations.append(relation) assert relation.resolve() == ( + "LABEL", (("ARG1", "Hello"), ("ARG2", "world A"), ("ARG3", "B")), - "RELATION", )