diff --git a/src/pytorch_ie/annotations.py b/src/pytorch_ie/annotations.py index edd285ee..72ce6a87 100644 --- a/src/pytorch_ie/annotations.py +++ b/src/pytorch_ie/annotations.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import Optional, Tuple +from typing import Any, Optional, Tuple from pytorch_ie.core.document import Annotation @@ -48,6 +48,9 @@ class Label(Annotation): def __post_init__(self) -> None: _post_init_single_label(self) + def resolve(self) -> Any: + return self.label + @dataclass(eq=True, frozen=True) class MultiLabel(Annotation): @@ -57,6 +60,9 @@ class MultiLabel(Annotation): def __post_init__(self) -> None: _post_init_multi_label(self) + def resolve(self) -> Any: + return self.label + @dataclass(eq=True, frozen=True) class Span(Annotation): @@ -68,6 +74,12 @@ def __str__(self) -> str: return super().__str__() return str(self.target[self.start : self.end]) + def resolve(self) -> Any: + if self.is_attached: + return self.target[self.start : self.end] + else: + raise ValueError("Span is not attached to a target.") + @dataclass(eq=True, frozen=True) class LabeledSpan(Span): @@ -77,6 +89,9 @@ class LabeledSpan(Span): def __post_init__(self) -> None: _post_init_single_label(self) + def resolve(self) -> Any: + return super().resolve(), self.label + @dataclass(eq=True, frozen=True) class MultiLabeledSpan(Span): @@ -86,6 +101,9 @@ class MultiLabeledSpan(Span): def __post_init__(self) -> None: _post_init_multi_label(self) + def resolve(self) -> Any: + return super().resolve(), self.label + @dataclass(eq=True, frozen=True) class BinaryRelation(Annotation): @@ -97,6 +115,9 @@ class BinaryRelation(Annotation): def __post_init__(self) -> None: _post_init_single_label(self) + def resolve(self) -> Any: + return self.head.resolve(), self.tail.resolve(), self.label + @dataclass(eq=True, frozen=True) class MultiLabeledBinaryRelation(Annotation): @@ -108,6 +129,9 @@ class MultiLabeledBinaryRelation(Annotation): def __post_init__(self) -> None: _post_init_multi_label(self) + def resolve(self) -> Any: + return self.head.resolve(), self.tail.resolve(), self.label + @dataclass(eq=True, frozen=True) class NaryRelation(Annotation): @@ -119,3 +143,6 @@ class NaryRelation(Annotation): def __post_init__(self) -> None: _post_init_arguments_and_roles(self) _post_init_single_label(self) + + def resolve(self) -> Any: + return tuple((role, arg) for arg, role in zip(self.arguments, self.roles)), self.label