Skip to content

Commit

Permalink
implement target_name(field_name) and target_names(field_name) for Document
Browse files Browse the repository at this point in the history
  • Loading branch information
ArneBinder committed Mar 21, 2024
1 parent a5cadff commit aec82f6
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 0 deletions.
20 changes: 20 additions & 0 deletions src/pytorch_ie/core/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,26 @@ def annotation_fields(cls) -> Set[dataclasses.Field]:
if typing.get_origin(ann_field_types[f.name]) is AnnotationLayer
}

@classmethod
def target_names(cls, field_name: str) -> Set[str]:
    """Return the set of target names declared for an annotation field.

    The field is looked up by name among ``cls.annotation_fields()`` and its
    targets are read from the ``"targets"`` entry of the field's metadata.

    Args:
        field_name: Name of the annotation field to inspect.

    Returns:
        A new set containing the entries of the field's ``"targets"`` metadata.

    Raises:
        ValueError: If no annotation field is named ``field_name``, or if the
            field's metadata has no ``"targets"`` entry (or it is ``None``).
    """
    # Linear search; the ``else`` branch runs only when no field matched.
    for candidate in cls.annotation_fields():
        if candidate.name == field_name:
            break
    else:
        raise ValueError(f"'{field_name}' is not an annotation field of {cls.__name__}.")

    declared_targets = candidate.metadata.get("targets")
    if declared_targets is None:
        raise ValueError(f"Annotation field '{field_name}' has no targets.")
    return set(declared_targets)

@classmethod
def target_name(cls, field_name: str) -> str:
    """Return the single target name of an annotation field.

    Delegates to ``cls.target_names(field_name)`` and succeeds only when that
    call yields exactly one name.

    Args:
        field_name: Name of the annotation field to inspect.

    Returns:
        The one and only target name of the field.

    Raises:
        ValueError: If the field has zero or several targets. Errors raised
            by ``cls.target_names`` propagate unchanged.
    """
    names = cls.target_names(field_name)
    if len(names) == 1:
        # Exactly one element: unpack it directly instead of indexing a list copy.
        (only_name,) = names
        return only_name
    raise ValueError(
        f"The annotation field '{field_name}' has more or less than one target, "
        f"can not return a single target name: {sorted(names)}"
    )

def __getitem__(self, key: str) -> AnnotationLayer:
if key not in self._annotation_fields:
raise KeyError(f"Document has no attribute '{key}'.")
Expand Down
30 changes: 30 additions & 0 deletions tests/core/test_document.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import pytest

from pytorch_ie.annotations import BinaryRelation
from pytorch_ie.core import Annotation
from pytorch_ie.core.document import (
AnnotationLayer,
Expand Down Expand Up @@ -347,6 +348,35 @@ class MyDocument(Document):
assert annotation_field_names == {"words"}


def test_document_target_names():
    """Exercise ``target_names`` and ``target_name`` on a small document class.

    Covers a layer with one target, a non-annotation field (error), and a
    layer with two targets (``target_names`` works, ``target_name`` errors).
    """

    @dataclasses.dataclass
    class MyDocument(Document):
        text: str
        words: AnnotationLayer[Span] = annotation_field(target="text")
        sentences: AnnotationLayer[Span] = annotation_field(target="text")
        belongs_to: AnnotationLayer[BinaryRelation] = annotation_field(
            targets=["words", "sentences"]
        )

    # A layer with exactly one target: both accessors succeed.
    names_of_words = MyDocument.target_names("words")
    assert names_of_words == {"text"}
    single_name = MyDocument.target_name("words")
    assert single_name == "text"

    # "text" is a plain data field, not an annotation layer.
    with pytest.raises(ValueError) as not_a_layer:
        MyDocument.target_names("text")
    not_a_layer_message = str(not_a_layer.value)
    assert not_a_layer_message == f"'text' is not an annotation field of {MyDocument.__name__}."

    # A layer with two targets: the plural accessor works, the singular one fails.
    names_of_relation = MyDocument.target_names("belongs_to")
    assert names_of_relation == {"words", "sentences"}
    with pytest.raises(ValueError) as ambiguous:
        MyDocument.target_name("belongs_to")
    ambiguous_message = str(ambiguous.value)
    assert (
        ambiguous_message
        == f"The annotation field 'belongs_to' has more or less than one target, can not return a single target name: ['sentences', 'words']"
    )


def test_document_copy():
@dataclasses.dataclass
class MyDocument(Document):
Expand Down

0 comments on commit aec82f6

Please sign in to comment.