Skip to content

Commit

Permalink
put label in the front
Browse files Browse the repository at this point in the history
  • Loading branch information
ArneBinder committed Mar 21, 2024
1 parent c0c7903 commit 18737a9
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 13 deletions.
10 changes: 5 additions & 5 deletions src/pytorch_ie/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def __post_init__(self) -> None:
_post_init_single_label(self)

def resolve(self) -> Any:
return super().resolve(), self.label
return self.label, super().resolve()


@dataclass(eq=True, frozen=True)
Expand All @@ -102,7 +102,7 @@ def __post_init__(self) -> None:
_post_init_multi_label(self)

def resolve(self) -> Any:
return super().resolve(), self.label
return self.label, super().resolve()


@dataclass(eq=True, frozen=True)
Expand All @@ -116,7 +116,7 @@ def __post_init__(self) -> None:
_post_init_single_label(self)

def resolve(self) -> Any:
return self.head.resolve(), self.tail.resolve(), self.label
return self.label, (self.head.resolve(), self.tail.resolve())


@dataclass(eq=True, frozen=True)
Expand All @@ -130,7 +130,7 @@ def __post_init__(self) -> None:
_post_init_multi_label(self)

def resolve(self) -> Any:
return self.head.resolve(), self.tail.resolve(), self.label
return self.label, (self.head.resolve(), self.tail.resolve())


@dataclass(eq=True, frozen=True)
Expand All @@ -146,6 +146,6 @@ def __post_init__(self) -> None:

def resolve(self) -> Any:
return (
tuple((role, arg.resolve()) for arg, role in zip(self.arguments, self.roles)),
self.label,
tuple((role, arg.resolve()) for arg, role in zip(self.arguments, self.roles)),
)
16 changes: 8 additions & 8 deletions tests/test_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ class TestDocument(TextBasedDocument):
doc = TestDocument(text="Hello, world!")
labeled_span = LabeledSpan(start=7, end=12, label="LOC")
doc.spans.append(labeled_span)
assert labeled_span.resolve() == ("world", "LOC")
assert labeled_span.resolve() == ("LOC", "world")


def test_multilabeled_span():
Expand Down Expand Up @@ -172,7 +172,7 @@ class TestDocument(TextBasedDocument):
doc = TestDocument(text="Hello, world!")
multilabeled_span = MultiLabeledSpan(start=7, end=12, label=("LOC", "ORG"))
doc.spans.append(multilabeled_span)
assert multilabeled_span.resolve() == ("world", ("LOC", "ORG"))
assert multilabeled_span.resolve() == (("LOC", "ORG"), "world")


def test_binary_relation():
Expand Down Expand Up @@ -224,9 +224,9 @@ class TestDocument(TextBasedDocument):
head = Span(start=0, end=5)
tail = Span(start=7, end=12)
doc.spans.extend([head, tail])
relation = BinaryRelation(head=head, tail=tail, label="RELATION")
relation = BinaryRelation(head=head, tail=tail, label="LABEL")
doc.relations.append(relation)
assert relation.resolve() == ("Hello", "world", "RELATION")
assert relation.resolve() == ("LABEL", ("Hello", "world"))


def test_multilabeled_binary_relation():
Expand Down Expand Up @@ -287,9 +287,9 @@ class TestDocument(TextBasedDocument):
head = Span(start=0, end=5)
tail = Span(start=7, end=12)
doc.spans.extend([head, tail])
relation = MultiLabeledBinaryRelation(head=head, tail=tail, label=("RELATION1", "RELATION2"))
relation = MultiLabeledBinaryRelation(head=head, tail=tail, label=("LABEL1", "LABEL2"))
doc.relations.append(relation)
assert relation.resolve() == ("Hello", "world", ("RELATION1", "RELATION2"))
assert relation.resolve() == (("LABEL1", "LABEL2"), ("Hello", "world"))


def test_nary_relation():
Expand Down Expand Up @@ -342,10 +342,10 @@ class TestDocument(TextBasedDocument):
arg3 = Span(start=19, end=20)
doc.spans.extend([arg1, arg2, arg3])
relation = NaryRelation(
arguments=(arg1, arg2, arg3), roles=("ARG1", "ARG2", "ARG3"), label="RELATION"
arguments=(arg1, arg2, arg3), roles=("ARG1", "ARG2", "ARG3"), label="LABEL"
)
doc.relations.append(relation)
assert relation.resolve() == (
"LABEL",
(("ARG1", "Hello"), ("ARG2", "world A"), ("ARG3", "B")),
"RELATION",
)

0 comments on commit 18737a9

Please sign in to comment.