diff --git a/src/pytorch_ie/core/document.py b/src/pytorch_ie/core/document.py
index ed31999c..d8c87cf3 100644
--- a/src/pytorch_ie/core/document.py
+++ b/src/pytorch_ie/core/document.py
@@ -518,6 +518,51 @@ def annotation_fields(cls) -> Set[dataclasses.Field]:
             if typing.get_origin(ann_field_types[f.name]) is AnnotationLayer
         }
 
+    @classmethod
+    def target_names(cls, field_name: str) -> Set[str]:
+        """Return the names of all targets of an annotation field.
+
+        Args:
+            field_name: Name of an annotation field of this document class.
+
+        Returns:
+            The entries of the field's "targets" metadata as a set.
+
+        Raises:
+            ValueError: If ``field_name`` is not an annotation field of this
+                class, or if the field's metadata has no "targets" entry.
+        """
+        a_field = next((f for f in cls.annotation_fields() if f.name == field_name), None)
+        if a_field is None:
+            raise ValueError(f"'{field_name}' is not an annotation field of {cls.__name__}.")
+        result = a_field.metadata.get("targets")
+        if result is None:
+            raise ValueError(f"Annotation field '{field_name}' has no targets.")
+        return set(result)
+
+    @classmethod
+    def target_name(cls, field_name: str) -> str:
+        """Return the single target name of an annotation field.
+
+        Args:
+            field_name: Name of an annotation field of this document class.
+
+        Returns:
+            The only target name of the field.
+
+        Raises:
+            ValueError: If the field does not have exactly one target, or if
+                ``target_names`` rejects ``field_name``.
+        """
+        target_names = cls.target_names(field_name)
+        if len(target_names) != 1:
+            raise ValueError(
+                f"The annotation field '{field_name}' has more or less than one target, "
+                f"can not return a single target name: {sorted(target_names)}"
+            )
+        # The set holds exactly one element here, so take it without copying to a list.
+        return next(iter(target_names))
+
     def __getitem__(self, key: str) -> AnnotationLayer:
         if key not in self._annotation_fields:
             raise KeyError(f"Document has no attribute '{key}'.")
diff --git a/tests/core/test_document.py b/tests/core/test_document.py
index 5906319e..d5a3c6c2 100644
--- a/tests/core/test_document.py
+++ b/tests/core/test_document.py
@@ -4,6 +4,7 @@
 
 import pytest
 
+from pytorch_ie.annotations import BinaryRelation
 from pytorch_ie.core import Annotation
 from pytorch_ie.core.document import (
     AnnotationLayer,
@@ -347,6 +348,35 @@ class MyDocument(Document):
     assert annotation_field_names == {"words"}
 
 
+def test_document_target_names():
+    @dataclasses.dataclass
+    class MyDocument(Document):
+        text: str
+        words: AnnotationLayer[Span] = annotation_field(target="text")
+        sentences: AnnotationLayer[Span] = annotation_field(target="text")
+        belongs_to: AnnotationLayer[BinaryRelation] = annotation_field(
+            targets=["words", "sentences"]
+        )
+
+    # request target names for annotation field
+    assert MyDocument.target_names("words") == {"text"}
+    assert MyDocument.target_name("words") == "text"
+
+    # requested field is not an annotation field
+    with pytest.raises(ValueError) as excinfo:
+        MyDocument.target_names("text")
+    assert str(excinfo.value) == f"'text' is not an annotation field of {MyDocument.__name__}."
+
+    # requested field has two targets
+    assert MyDocument.target_names("belongs_to") == {"words", "sentences"}
+    with pytest.raises(ValueError) as excinfo:
+        MyDocument.target_name("belongs_to")
+    assert (
+        str(excinfo.value)
+        == "The annotation field 'belongs_to' has more or less than one target, can not return a single target name: ['sentences', 'words']"
+    )
+
+
 def test_document_copy():
     @dataclasses.dataclass
     class MyDocument(Document):