Skip to content

Commit

Permalink
add resolve() to all Annotation implementations
Browse files Browse the repository at this point in the history
  • Loading branch information
ArneBinder committed Mar 21, 2024
1 parent 9c92b47 commit 3dfca5d
Showing 1 changed file with 28 additions and 1 deletion.
29 changes: 28 additions & 1 deletion src/pytorch_ie/annotations.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass, field
from typing import Optional, Tuple
from typing import Any, Optional, Tuple

from pytorch_ie.core.document import Annotation

Expand Down Expand Up @@ -48,6 +48,9 @@ class Label(Annotation):
def __post_init__(self) -> None:
_post_init_single_label(self)

def resolve(self) -> Any:
return self.label


@dataclass(eq=True, frozen=True)
class MultiLabel(Annotation):
Expand All @@ -57,6 +60,9 @@ class MultiLabel(Annotation):
def __post_init__(self) -> None:
_post_init_multi_label(self)

def resolve(self) -> Any:
return self.label


@dataclass(eq=True, frozen=True)
class Span(Annotation):
Expand All @@ -68,6 +74,12 @@ def __str__(self) -> str:
return super().__str__()
return str(self.target[self.start : self.end])

def resolve(self) -> Any:
if self.is_attached:
return self.target[self.start : self.end]
else:
raise ValueError("Span is not attached to a target.")


@dataclass(eq=True, frozen=True)
class LabeledSpan(Span):
Expand All @@ -77,6 +89,9 @@ class LabeledSpan(Span):
def __post_init__(self) -> None:
_post_init_single_label(self)

def resolve(self) -> Any:
return super().resolve(), self.label


@dataclass(eq=True, frozen=True)
class MultiLabeledSpan(Span):
Expand All @@ -86,6 +101,9 @@ class MultiLabeledSpan(Span):
def __post_init__(self) -> None:
_post_init_multi_label(self)

def resolve(self) -> Any:
return super().resolve(), self.label


@dataclass(eq=True, frozen=True)
class BinaryRelation(Annotation):
Expand All @@ -97,6 +115,9 @@ class BinaryRelation(Annotation):
def __post_init__(self) -> None:
_post_init_single_label(self)

def resolve(self) -> Any:
return self.head.resolve(), self.tail.resolve(), self.label


@dataclass(eq=True, frozen=True)
class MultiLabeledBinaryRelation(Annotation):
Expand All @@ -108,6 +129,9 @@ class MultiLabeledBinaryRelation(Annotation):
def __post_init__(self) -> None:
_post_init_multi_label(self)

def resolve(self) -> Any:
return self.head.resolve(), self.tail.resolve(), self.label


@dataclass(eq=True, frozen=True)
class NaryRelation(Annotation):
Expand All @@ -119,3 +143,6 @@ class NaryRelation(Annotation):
def __post_init__(self) -> None:
_post_init_arguments_and_roles(self)
_post_init_single_label(self)

def resolve(self) -> Any:
return tuple((role, arg) for arg, role in zip(self.arguments, self.roles)), self.label

0 comments on commit 3dfca5d

Please sign in to comment.