Skip to content

Commit

Permalink
implement Annotation.resolve() (#409)
Browse files Browse the repository at this point in the history
* implement Annotation.resolve() interface

* add resolve() to all Annotation implementations

* improve error message

* add tests

* put label in the front
  • Loading branch information
ArneBinder authored Mar 21, 2024
1 parent b22f9df commit 821519b
Show file tree
Hide file tree
Showing 4 changed files with 189 additions and 1 deletion.
32 changes: 31 additions & 1 deletion src/pytorch_ie/annotations.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass, field
from typing import Optional, Tuple
from typing import Any, Optional, Tuple

from pytorch_ie.core.document import Annotation

Expand Down Expand Up @@ -48,6 +48,9 @@ class Label(Annotation):
def __post_init__(self) -> None:
_post_init_single_label(self)

def resolve(self) -> Any:
return self.label


@dataclass(eq=True, frozen=True)
class MultiLabel(Annotation):
Expand All @@ -57,6 +60,9 @@ class MultiLabel(Annotation):
def __post_init__(self) -> None:
_post_init_multi_label(self)

def resolve(self) -> Any:
return self.label


@dataclass(eq=True, frozen=True)
class Span(Annotation):
Expand All @@ -68,6 +74,12 @@ def __str__(self) -> str:
return super().__str__()
return str(self.target[self.start : self.end])

def resolve(self) -> Any:
if self.is_attached:
return self.target[self.start : self.end]
else:
raise ValueError(f"{self} is not attached to a target.")


@dataclass(eq=True, frozen=True)
class LabeledSpan(Span):
Expand All @@ -77,6 +89,9 @@ class LabeledSpan(Span):
def __post_init__(self) -> None:
_post_init_single_label(self)

def resolve(self) -> Any:
return self.label, super().resolve()


@dataclass(eq=True, frozen=True)
class MultiLabeledSpan(Span):
Expand All @@ -86,6 +101,9 @@ class MultiLabeledSpan(Span):
def __post_init__(self) -> None:
_post_init_multi_label(self)

def resolve(self) -> Any:
return self.label, super().resolve()


@dataclass(eq=True, frozen=True)
class BinaryRelation(Annotation):
Expand All @@ -97,6 +115,9 @@ class BinaryRelation(Annotation):
def __post_init__(self) -> None:
_post_init_single_label(self)

def resolve(self) -> Any:
return self.label, (self.head.resolve(), self.tail.resolve())


@dataclass(eq=True, frozen=True)
class MultiLabeledBinaryRelation(Annotation):
Expand All @@ -108,6 +129,9 @@ class MultiLabeledBinaryRelation(Annotation):
def __post_init__(self) -> None:
_post_init_multi_label(self)

def resolve(self) -> Any:
return self.label, (self.head.resolve(), self.tail.resolve())


@dataclass(eq=True, frozen=True)
class NaryRelation(Annotation):
Expand All @@ -119,3 +143,9 @@ class NaryRelation(Annotation):
def __post_init__(self) -> None:
_post_init_arguments_and_roles(self)
_post_init_single_label(self)

def resolve(self) -> Any:
return (
self.label,
tuple((role, arg.resolve()) for arg, role in zip(self.arguments, self.roles)),
)
3 changes: 3 additions & 0 deletions src/pytorch_ie/core/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,9 @@ def __lt__(self, other: "Annotation") -> bool:
return value < other_value
return False

def resolve(self) -> Any:
raise NotImplementedError(f"resolve() is not implemented for {self.__class__}")


T = TypeVar("T", covariant=False, bound="Annotation")

Expand Down
11 changes: 11 additions & 0 deletions tests/core/test_document.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,17 @@ class DummyWithNestedAnnotation(Annotation):
)


def test_annotation_resolve():
@dataclasses.dataclass(eq=True, frozen=True)
class Dummy(Annotation):
a: int

dummy = Dummy(a=1)
with pytest.raises(NotImplementedError) as excinfo:
dummy.resolve()
assert str(excinfo.value) == f"resolve() is not implemented for {Dummy}"


def test_annotation_is_attached():
@dataclasses.dataclass
class MyDocument(Document):
Expand Down
144 changes: 144 additions & 0 deletions tests/test_annotations.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,28 @@
import dataclasses
import re

import pytest

from pytorch_ie import AnnotationLayer, annotation_field
from pytorch_ie.annotations import (
BinaryRelation,
Label,
LabeledSpan,
MultiLabel,
MultiLabeledBinaryRelation,
MultiLabeledSpan,
NaryRelation,
Span,
)
from pytorch_ie.documents import TextBasedDocument
from tests.core.test_document import _test_annotation_reconstruction


def test_label():
label1 = Label(label="label1")
assert label1.label == "label1"
assert label1.score == pytest.approx(1.0)
assert label1.resolve() == "label1"

label2 = Label(label="label2", score=0.5)
assert label2.label == "label2"
Expand All @@ -36,6 +41,7 @@ def test_multilabel():
multilabel1 = MultiLabel(label=("label1", "label2"))
assert multilabel1.label == ("label1", "label2")
assert multilabel1.score == pytest.approx((1.0, 1.0))
assert multilabel1.resolve() == ("label1", "label2")

multilabel2 = MultiLabel(label=("label3", "label4"), score=(0.4, 0.5))
assert multilabel2.label == ("label3", "label4")
Expand Down Expand Up @@ -68,6 +74,19 @@ def test_span():

_test_annotation_reconstruction(span)

with pytest.raises(ValueError) as excinfo:
span.resolve()
assert str(excinfo.value) == "Span(start=1, end=2) is not attached to a target."

@dataclasses.dataclass
class TestDocument(TextBasedDocument):
spans: AnnotationLayer[Span] = annotation_field(target="text")

doc = TestDocument(text="Hello, world!")
span = Span(start=7, end=12)
doc.spans.append(span)
assert span.resolve() == "world"


def test_labeled_span():
labeled_span1 = LabeledSpan(start=1, end=2, label="label1")
Expand All @@ -92,6 +111,22 @@ def test_labeled_span():

_test_annotation_reconstruction(labeled_span2)

with pytest.raises(ValueError) as excinfo:
labeled_span1.resolve()
assert (
str(excinfo.value)
== "LabeledSpan(start=1, end=2, label='label1', score=1.0) is not attached to a target."
)

@dataclasses.dataclass
class TestDocument(TextBasedDocument):
spans: AnnotationLayer[LabeledSpan] = annotation_field(target="text")

doc = TestDocument(text="Hello, world!")
labeled_span = LabeledSpan(start=7, end=12, label="LOC")
doc.spans.append(labeled_span)
assert labeled_span.resolve() == ("LOC", "world")


def test_multilabeled_span():
multilabeled_span1 = MultiLabeledSpan(start=1, end=2, label=("label1", "label2"))
Expand Down Expand Up @@ -123,6 +158,22 @@ def test_multilabeled_span():
):
MultiLabeledSpan(start=5, end=6, label=("label5", "label6"), score=(0.1, 0.2, 0.3))

with pytest.raises(ValueError) as excinfo:
multilabeled_span1.resolve()
assert (
str(excinfo.value)
== "MultiLabeledSpan(start=1, end=2, label=('label1', 'label2'), score=(1.0, 1.0)) is not attached to a target."
)

@dataclasses.dataclass
class TestDocument(TextBasedDocument):
spans: AnnotationLayer[MultiLabeledSpan] = annotation_field(target="text")

doc = TestDocument(text="Hello, world!")
multilabeled_span = MultiLabeledSpan(start=7, end=12, label=("LOC", "ORG"))
doc.spans.append(multilabeled_span)
assert multilabeled_span.resolve() == (("LOC", "ORG"), "world")


def test_binary_relation():
head = Span(start=1, end=2)
Expand Down Expand Up @@ -160,6 +211,23 @@ def test_binary_relation():
):
BinaryRelation.fromdict(binary_relation2.asdict())

with pytest.raises(ValueError) as excinfo:
binary_relation1.resolve()
assert str(excinfo.value) == "Span(start=1, end=2) is not attached to a target."

@dataclasses.dataclass
class TestDocument(TextBasedDocument):
spans: AnnotationLayer[Span] = annotation_field(target="text")
relations: AnnotationLayer[BinaryRelation] = annotation_field(target="spans")

doc = TestDocument(text="Hello, world!")
head = Span(start=0, end=5)
tail = Span(start=7, end=12)
doc.spans.extend([head, tail])
relation = BinaryRelation(head=head, tail=tail, label="LABEL")
doc.relations.append(relation)
assert relation.resolve() == ("LABEL", ("Hello", "world"))


def test_multilabeled_binary_relation():
head = Span(start=1, end=2)
Expand Down Expand Up @@ -205,3 +273,79 @@ def test_multilabeled_binary_relation():
MultiLabeledBinaryRelation(
head=head, tail=tail, label=("label5", "label6"), score=(0.1, 0.2, 0.3)
)

with pytest.raises(ValueError) as excinfo:
binary_relation1.resolve()
assert str(excinfo.value) == "Span(start=1, end=2) is not attached to a target."

@dataclasses.dataclass
class TestDocument(TextBasedDocument):
spans: AnnotationLayer[Span] = annotation_field(target="text")
relations: AnnotationLayer[MultiLabeledBinaryRelation] = annotation_field(target="spans")

doc = TestDocument(text="Hello, world!")
head = Span(start=0, end=5)
tail = Span(start=7, end=12)
doc.spans.extend([head, tail])
relation = MultiLabeledBinaryRelation(head=head, tail=tail, label=("LABEL1", "LABEL2"))
doc.relations.append(relation)
assert relation.resolve() == (("LABEL1", "LABEL2"), ("Hello", "world"))


def test_nary_relation():
arg1 = Span(start=1, end=2)
arg2 = Span(start=3, end=4)
arg3 = Span(start=5, end=6)

nary_relation1 = NaryRelation(
arguments=(arg1, arg2, arg3), roles=("role1", "role2", "role3"), label="label1"
)

assert nary_relation1.arguments == (arg1, arg2, arg3)
assert nary_relation1.roles == ("role1", "role2", "role3")
assert nary_relation1.label == "label1"
assert nary_relation1.score == pytest.approx(1.0)

assert nary_relation1.asdict() == {
"_id": nary_relation1._id,
"arguments": [arg1._id, arg2._id, arg3._id],
"roles": ("role1", "role2", "role3"),
"label": "label1",
"score": 1.0,
}

annotation_store = {
arg1._id: arg1,
arg2._id: arg2,
arg3._id: arg3,
}
_test_annotation_reconstruction(nary_relation1, annotation_store=annotation_store)

with pytest.raises(
ValueError,
match=re.escape("Unable to resolve the annotation id without annotation_store."),
):
NaryRelation.fromdict(nary_relation1.asdict())

with pytest.raises(ValueError) as excinfo:
nary_relation1.resolve()
assert str(excinfo.value) == "Span(start=1, end=2) is not attached to a target."

@dataclasses.dataclass
class TestDocument(TextBasedDocument):
spans: AnnotationLayer[Span] = annotation_field(target="text")
relations: AnnotationLayer[NaryRelation] = annotation_field(target="spans")

doc = TestDocument(text="Hello, world A and B!")
arg1 = Span(start=0, end=5)
arg2 = Span(start=7, end=14)
arg3 = Span(start=19, end=20)
doc.spans.extend([arg1, arg2, arg3])
relation = NaryRelation(
arguments=(arg1, arg2, arg3), roles=("ARG1", "ARG2", "ARG3"), label="LABEL"
)
doc.relations.append(relation)
assert relation.resolve() == (
"LABEL",
(("ARG1", "Hello"), ("ARG2", "world A"), ("ARG3", "B")),
)

0 comments on commit 821519b

Please sign in to comment.