From ef4c57485bf71c7847b0c03cd7fe2aea2556360e Mon Sep 17 00:00:00 2001 From: ArneBinder Date: Wed, 8 Nov 2023 15:24:54 +0100 Subject: [PATCH] rename core classes (#369) * rename AnnotationList to AnnotationLayer * rename AnnotationList to AnnotationLayer in README.md * rename RequiresDocumentTypeMixin to WithDocumentTypeMixin --- README.md | 32 ++--- examples/predict/ner_span_classification.py | 4 +- examples/predict/re_generative.py | 6 +- examples/predict/re_text_classification.py | 6 +- src/pytorch_ie/core/__init__.py | 8 +- src/pytorch_ie/core/document.py | 44 +++---- src/pytorch_ie/core/metric.py | 4 +- src/pytorch_ie/core/module_mixins.py | 2 +- src/pytorch_ie/core/taskmodule.py | 4 +- src/pytorch_ie/documents.py | 18 +-- .../transformer_re_text_classification.py | 16 +-- tests/conftest.py | 8 +- tests/core/test_document.py | 8 +- tests/core/test_metric.py | 4 +- tests/metrics/test_f1.py | 4 +- .../pipeline/test_ner_span_classification.py | 4 +- tests/pipeline/test_re_generative.py | 6 +- tests/pipeline/test_re_text_classification.py | 6 +- ..._simple_transformer_text_classification.py | 4 +- tests/taskmodules/test_transformer_seq2seq.py | 6 +- .../test_transformer_token_classification.py | 6 +- tests/test_auto.py | 4 +- tests/test_document.py | 110 +++++++++--------- 23 files changed, 160 insertions(+), 154 deletions(-) diff --git a/README.md b/README.md index c5f4b489..68fbb981 100644 --- a/README.md +++ b/README.md @@ -85,16 +85,17 @@ elements: ```python from typing import Optional -from pytorch_ie.core import Document, AnnotationList, annotation_field +from pytorch_ie.core import Document, AnnotationLayer, annotation_field from pytorch_ie.annotations import LabeledSpan, BinaryRelation, Label + class MyDocument(Document): # data fields (any field that is targeted by an annotation fields) text: str # annotation fields - entities: AnnotationList[LabeledSpan] = annotation_field(target="text") - relations: AnnotationList[BinaryRelation] = annotation_field(target="entities") - label: AnnotationList[Label] = annotation_field() + entities: AnnotationLayer[LabeledSpan] = annotation_field(target="text") + relations: AnnotationLayer[BinaryRelation] = annotation_field(target="entities") + label: AnnotationLayer[Label] = annotation_field() # other fields doc_id: Optional[str] = None ``` @@ -147,7 +148,7 @@ The content of `self.target` is lazily assigned as soon as the annotation is add Note that this now expects a single `collections.abc.Sequence` as `target`, e.g.: ```python -my_spans: AnnotationList[Span] = annotation_field(target="") +my_spans: AnnotationLayer[Span] = annotation_field(target="") ``` If we have multiple targets, we need to define target names to access them. For this, we need to set the special @@ -178,7 +179,7 @@ class MyDocumentWithAlignment(Document): text_a: str text_b: str # `named_targets` defines the mapping from `TARGET_NAMES` to data fields - my_alignments: AnnotationList[Alignment] = annotation_field(named_targets={"text1": "text_a", "text2": "text_b"}) + my_alignments: AnnotationLayer[Alignment] = annotation_field(named_targets={"text1": "text_a", "text2": "text_b"}) ``` Note that `text1` and `text2` can also target the same field. @@ -319,12 +320,14 @@ from dataclasses import dataclass from pytorch_ie.annotations import LabeledSpan from pytorch_ie.auto import AutoPipeline -from pytorch_ie.core import AnnotationList, annotation_field +from pytorch_ie.core import AnnotationLayer, annotation_field from pytorch_ie.documents import TextDocument + @dataclass class ExampleDocument(TextDocument): - entities: AnnotationList[LabeledSpan] = annotation_field(target="text") + entities: AnnotationLayer[LabeledSpan] = annotation_field(target="text") + document = ExampleDocument( "“Making a super tasty alt-chicken wing is only half of it,” said Po Bronson, general partner at SOSV and managing director of IndieBio." @@ -390,14 +393,15 @@ from dataclasses import dataclass from pytorch_ie.annotations import BinaryRelation, LabeledSpan from pytorch_ie.auto import AutoPipeline -from pytorch_ie.core import AnnotationList, annotation_field +from pytorch_ie.core import AnnotationLayer, annotation_field from pytorch_ie.documents import TextDocument @dataclass class ExampleDocument(TextDocument): - entities: AnnotationList[LabeledSpan] = annotation_field(target="text") - relations: AnnotationList[BinaryRelation] = annotation_field(target="entities") + entities: AnnotationLayer[LabeledSpan] = annotation_field(target="text") + relations: AnnotationLayer[BinaryRelation] = annotation_field(target="entities") + document = ExampleDocument( "“Making a super tasty alt-chicken wing is only half of it,” said Po Bronson, general partner at SOSV and managing director of IndieBio." @@ -550,7 +554,7 @@ print(dataset["train"][0]) # >>> CoNLL2003Document(text='EU rejects German call to boycott British lamb .', id='0', metadata={}) dataset["train"][0].entities -# >>> AnnotationList([LabeledSpan(start=0, end=2, label='ORG', score=1.0), LabeledSpan(start=11, end=17, label='MISC', score=1.0), LabeledSpan(start=34, end=41, label='MISC', score=1.0)]) +# >>> AnnotationLayer([LabeledSpan(start=0, end=2, label='ORG', score=1.0), LabeledSpan(start=11, end=17, label='MISC', score=1.0), LabeledSpan(start=34, end=41, label='MISC', score=1.0)]) entity = dataset["train"][0].entities[1] @@ -571,12 +575,12 @@ dataset from that, you have to implement: ```python @dataclass class CoNLL2003Document(TextDocument): - entities: AnnotationList[LabeledSpan] = annotation_field(target="text") + entities: AnnotationLayer[LabeledSpan] = annotation_field(target="text") ``` Here we derive from `TextDocument` that has a simple `text` string as base annotation target. The `CoNLL2003Document` adds one single annotation list called `entities` that consists of `LabeledSpan`s which reference the `text` field of -the document. You can add further annotation types by adding `AnnotationList` fields that may also reference (i.e. +the document. You can add further annotation types by adding `AnnotationLayer` fields that may also reference (i.e. `target`) other annotations as you like. See ['pytorch_ie.annotations`](src/pytorch_ie/annotations.py) for predefined annotation types. diff --git a/examples/predict/ner_span_classification.py b/examples/predict/ner_span_classification.py index f19ae15c..514cbe49 100644 --- a/examples/predict/ner_span_classification.py +++ b/examples/predict/ner_span_classification.py @@ -1,7 +1,7 @@ from dataclasses import dataclass from pytorch_ie.annotations import LabeledSpan -from pytorch_ie.core import AnnotationList, annotation_field +from pytorch_ie.core import AnnotationLayer, annotation_field from pytorch_ie.documents import TextDocument from pytorch_ie.models import TransformerSpanClassificationModel from pytorch_ie.pipeline import Pipeline @@ -10,7 +10,7 @@ @dataclass class ExampleDocument(TextDocument): - entities: AnnotationList[LabeledSpan] = annotation_field(target="text") + entities: AnnotationLayer[LabeledSpan] = annotation_field(target="text") def main(): diff --git a/examples/predict/re_generative.py b/examples/predict/re_generative.py index de0e656b..e56050ec 100644 --- a/examples/predict/re_generative.py +++ b/examples/predict/re_generative.py @@ -1,7 +1,7 @@ from dataclasses import dataclass from pytorch_ie.annotations import BinaryRelation, LabeledSpan -from pytorch_ie.core import AnnotationList, annotation_field +from pytorch_ie.core import AnnotationLayer, annotation_field from pytorch_ie.documents import TextDocument from pytorch_ie.models import TransformerSeq2SeqModel from pytorch_ie.pipeline import Pipeline @@ -10,8 +10,8 @@ @dataclass class ExampleDocument(TextDocument): - entities: AnnotationList[LabeledSpan] = annotation_field(target="text") - relations: AnnotationList[BinaryRelation] = annotation_field(target="entities") + entities: AnnotationLayer[LabeledSpan] = annotation_field(target="text") + relations: AnnotationLayer[BinaryRelation] = annotation_field(target="entities") def main(): diff --git a/examples/predict/re_text_classification.py b/examples/predict/re_text_classification.py index 01309c08..4ef34b17 100644 --- a/examples/predict/re_text_classification.py +++ b/examples/predict/re_text_classification.py @@ -1,7 +1,7 @@ from dataclasses import dataclass from pytorch_ie.annotations import BinaryRelation, LabeledSpan -from pytorch_ie.core import AnnotationList, annotation_field +from pytorch_ie.core import AnnotationLayer, annotation_field from pytorch_ie.documents import TextDocument from pytorch_ie.models import TransformerTextClassificationModel from pytorch_ie.pipeline import Pipeline @@ -10,8 +10,8 @@ @dataclass class ExampleDocument(TextDocument): - entities: AnnotationList[LabeledSpan] = annotation_field(target="text") - relations: AnnotationList[BinaryRelation] = annotation_field(target="entities") + entities: AnnotationLayer[LabeledSpan] = annotation_field(target="text") + relations: AnnotationLayer[BinaryRelation] = annotation_field(target="entities") def main(): diff --git a/src/pytorch_ie/core/__init__.py b/src/pytorch_ie/core/__init__.py index 989047c8..b3462abd 100644 --- a/src/pytorch_ie/core/__init__.py +++ b/src/pytorch_ie/core/__init__.py @@ -1,6 +1,10 @@ -from .document import Annotation, AnnotationList, Document, annotation_field +from .document import Annotation, AnnotationLayer, Document, annotation_field from .metric import DocumentMetric from .model import PyTorchIEModel -from .module_mixins import RequiresDocumentTypeMixin +from .module_mixins import WithDocumentTypeMixin from .statistic import DocumentStatistic from .taskmodule import TaskEncoding, TaskModule + +# backwards compatibility +AnnotationList = AnnotationLayer +RequiresDocumentTypeMixin = WithDocumentTypeMixin diff --git a/src/pytorch_ie/core/document.py b/src/pytorch_ie/core/document.py index 0ad4b292..2a303eca 100644 --- a/src/pytorch_ie/core/document.py +++ b/src/pytorch_ie/core/document.py @@ -118,7 +118,7 @@ def _get_reference_fields_and_container_types( def _get_annotation_fields(fields: List[dataclasses.Field]) -> Set[dataclasses.Field]: - return {field for field in fields if typing.get_origin(field.type) is AnnotationList} + return {field for field in fields if typing.get_origin(field.type) is AnnotationLayer} def annotation_field( @@ -157,7 +157,7 @@ def annotation_field( # for now, we only have annotation lists and texts -TARGET_TYPE = Union["AnnotationList", str] +TARGET_TYPE = Union["AnnotationLayer", str] @dataclasses.dataclass(eq=True, frozen=True) @@ -426,15 +426,15 @@ def target(self) -> Any: return list(tgts.values())[0] @property - def target_layers(self) -> dict[str, "AnnotationList"]: + def target_layers(self) -> dict[str, "AnnotationLayer"]: return { target_name: target for target_name, target in self.targets.items() - if isinstance(target, AnnotationList) + if isinstance(target, AnnotationLayer) } @property - def target_layer(self) -> "AnnotationList": + def target_layer(self) -> "AnnotationLayer": tgt_layers = self.target_layers if len(tgt_layers) != 1: raise ValueError( @@ -443,7 +443,7 @@ def target_layer(self) -> "AnnotationList": return list(tgt_layers.values())[0] -class AnnotationList(BaseAnnotationList[T]): +class AnnotationLayer(BaseAnnotationList[T]): def __init__(self, document: "Document", targets: List["str"]): super().__init__(document=document, targets=targets) self._predictions: BaseAnnotationList[T] = BaseAnnotationList(document, targets=targets) @@ -453,13 +453,13 @@ def predictions(self) -> BaseAnnotationList[T]: return self._predictions def __eq__(self, other: object) -> bool: - if not isinstance(other, AnnotationList): + if not isinstance(other, AnnotationLayer): return NotImplemented return super().__eq__(other) and self.predictions == other.predictions def __repr__(self) -> str: - return f"AnnotationList({str(self._annotations)})" + return f"AnnotationLayer({str(self._annotations)})" D = TypeVar("D", bound="Document") @@ -485,7 +485,7 @@ def fields(cls): def annotation_fields(cls): return _get_annotation_fields(list(dataclasses.fields(cls))) - def __getitem__(self, key: str) -> AnnotationList: + def __getitem__(self, key: str) -> AnnotationLayer: if key not in self._annotation_fields: raise KeyError(f"Document has no attribute '{key}'.") return getattr(self, key) @@ -505,7 +505,7 @@ def __post_init__(self): field_origin = typing.get_origin(field.type) - if field_origin is AnnotationList: + if field_origin is AnnotationLayer: self._annotation_fields.add(field.name) targets = field.metadata.get("targets") @@ -519,7 +519,7 @@ def __post_init__(self): f'annotation target "{target}" is not in field names of the document: {field_names}' ) - # check annotation target names and use them together with target names from the AnnotationList + # check annotation target names and use them together with target names from the AnnotationLayer # to reorder targets, if available target_names = field.metadata.get("target_names") annotation_type = typing.get_args(field.type)[0] @@ -547,8 +547,8 @@ def __post_init__(self): # disallow multiple targets when target names are specified in the definition of the Annotation if len(annotation_target_names) > 1: raise TypeError( - f"A target name mapping is required for AnnotationLists containing Annotations with " - f'TARGET_NAMES, but AnnotationList "{field.name}" has no target_names. You should ' + f"A target name mapping is required for AnnotationLayers containing Annotations with " + f'TARGET_NAMES, but AnnotationLayer "{field.name}" has no target_names. You should ' f"pass the named_targets dict containing the following keys (see Annotation " f'"{annotation_type.__name__}") to annotation_field: {annotation_target_names}' ) @@ -559,7 +559,7 @@ def __post_init__(self): if "_artificial_root" in self._annotation_graph: raise ValueError( 'Failed to add the "_artificial_root" node to the annotation graph because it already exists. Note ' - "that AnnotationList entries with that name are not allowed." + "that AnnotationLayer entries with that name are not allowed." ) self._annotation_graph["_artificial_root"] = list(self._annotation_fields - targeted) @@ -568,7 +568,7 @@ def asdict(self): for field in self.fields(): value = getattr(self, field.name) - if isinstance(value, AnnotationList): + if isinstance(value, AnnotationLayer): dct[field.name] = { "annotations": [v.asdict() for v in value], "predictions": [v.asdict() for v in value.predictions], @@ -621,7 +621,7 @@ def fromdict(cls, dct): continue # TODO: handle single annotations, e.g. a document-level label - if typing.get_origin(field.type) is AnnotationList: + if typing.get_origin(field.type) is AnnotationLayer: annotation_class = typing.get_args(field.type)[0] # build annotations for annotation_data in value["annotations"]: @@ -718,15 +718,15 @@ class Attribute(Annotation): @dataclasses.dataclass class TextBasedDocumentWithEntitiesRelationsAndRelationAttributes(TextBasedDocument): - entities: AnnotationList[LabeledSpan] = annotation_field(target="text") - relations: AnnotationList[BinaryRelation] = annotation_field(target="entities") - relation_attributes: AnnotationList[Attribute] = annotation_field(target="relations") + entities: AnnotationLayer[LabeledSpan] = annotation_field(target="text") + relations: AnnotationLayer[BinaryRelation] = annotation_field(target="entities") + relation_attributes: AnnotationLayer[Attribute] = annotation_field(target="relations") @dataclasses.dataclass class TokenBasedDocumentWithEntitiesRelationsAndRelationAttributes(TokenBasedDocument): - entities: AnnotationList[LabeledSpan] = annotation_field(target="tokens") - relations: AnnotationList[BinaryRelation] = annotation_field(target="entities") - relation_attributes: AnnotationList[Attribute] = annotation_field(target="relations") + entities: AnnotationLayer[LabeledSpan] = annotation_field(target="tokens") + relations: AnnotationLayer[BinaryRelation] = annotation_field(target="entities") + relation_attributes: AnnotationLayer[Attribute] = annotation_field(target="relations") doc_text = TextBasedDocumentWithEntitiesRelationsAndRelationAttributes(text="Hello World!") diff --git a/src/pytorch_ie/core/metric.py b/src/pytorch_ie/core/metric.py index 54febd85..fe261d46 100644 --- a/src/pytorch_ie/core/metric.py +++ b/src/pytorch_ie/core/metric.py @@ -2,12 +2,12 @@ from typing import Dict, Generic, Iterable, Optional, TypeVar, Union from pytorch_ie.core.document import Document -from pytorch_ie.core.module_mixins import RequiresDocumentTypeMixin +from pytorch_ie.core.module_mixins import WithDocumentTypeMixin T = TypeVar("T") -class DocumentMetric(ABC, RequiresDocumentTypeMixin, Generic[T]): +class DocumentMetric(ABC, WithDocumentTypeMixin, Generic[T]): """This defines the interface for a document metric.""" def __init__(self): diff --git a/src/pytorch_ie/core/module_mixins.py b/src/pytorch_ie/core/module_mixins.py index be6f51ab..7811947a 100644 --- a/src/pytorch_ie/core/module_mixins.py +++ b/src/pytorch_ie/core/module_mixins.py @@ -6,7 +6,7 @@ logger = logging.getLogger(__name__) -class RequiresDocumentTypeMixin: +class WithDocumentTypeMixin: DOCUMENT_TYPE: Optional[Type[Document]] = None diff --git a/src/pytorch_ie/core/taskmodule.py b/src/pytorch_ie/core/taskmodule.py index 5e415f76..64314346 100644 --- a/src/pytorch_ie/core/taskmodule.py +++ b/src/pytorch_ie/core/taskmodule.py @@ -10,7 +10,7 @@ from pytorch_ie.core.document import Annotation, Document from pytorch_ie.core.hf_hub_mixin import PieTaskModuleHFHubMixin -from pytorch_ie.core.module_mixins import RequiresDocumentTypeMixin +from pytorch_ie.core.module_mixins import WithDocumentTypeMixin from pytorch_ie.core.registrable import Registrable """ @@ -133,7 +133,7 @@ class TaskModule( PieTaskModuleHFHubMixin, HyperparametersMixin, Registrable, - RequiresDocumentTypeMixin, + WithDocumentTypeMixin, Generic[ DocumentType, InputEncoding, diff --git a/src/pytorch_ie/documents.py b/src/pytorch_ie/documents.py index 160e5b9b..6f83b5ed 100644 --- a/src/pytorch_ie/documents.py +++ b/src/pytorch_ie/documents.py @@ -4,7 +4,7 @@ from typing_extensions import TypeAlias from pytorch_ie.annotations import BinaryRelation, Label, LabeledSpan, MultiLabel, Span -from pytorch_ie.core import AnnotationList, Document, annotation_field +from pytorch_ie.core import AnnotationLayer, Document, annotation_field @dataclasses.dataclass @@ -39,12 +39,12 @@ class TokenBasedDocument(WithMetadata, WithTokens, Document): @dataclasses.dataclass class DocumentWithLabel(Document): - label: AnnotationList[Label] = annotation_field() + label: AnnotationLayer[Label] = annotation_field() @dataclasses.dataclass class DocumentWithMultiLabel(Document): - label: AnnotationList[MultiLabel] = annotation_field() + label: AnnotationLayer[MultiLabel] = annotation_field() @dataclasses.dataclass @@ -59,22 +59,22 @@ class TextDocumentWithMultiLabel(DocumentWithMultiLabel, TextBasedDocument): @dataclasses.dataclass class TextDocumentWithLabeledPartitions(TextBasedDocument): - labeled_partitions: AnnotationList[LabeledSpan] = annotation_field(target="text") + labeled_partitions: AnnotationLayer[LabeledSpan] = annotation_field(target="text") @dataclasses.dataclass class TextDocumentWithSentences(TextBasedDocument): - sentences: AnnotationList[Span] = annotation_field(target="text") + sentences: AnnotationLayer[Span] = annotation_field(target="text") @dataclasses.dataclass class TextDocumentWithSpans(TextBasedDocument): - spans: AnnotationList[Span] = annotation_field(target="text") + spans: AnnotationLayer[Span] = annotation_field(target="text") @dataclasses.dataclass class TextDocumentWithLabeledSpans(TextBasedDocument): - labeled_spans: AnnotationList[LabeledSpan] = annotation_field(target="text") + labeled_spans: AnnotationLayer[LabeledSpan] = annotation_field(target="text") @dataclasses.dataclass @@ -93,7 +93,7 @@ class TextDocumentWithLabeledSpansAndSentences( @dataclasses.dataclass class TextDocumentWithLabeledSpansAndBinaryRelations(TextDocumentWithLabeledSpans): - binary_relations: AnnotationList[BinaryRelation] = annotation_field(target="labeled_spans") + binary_relations: AnnotationLayer[BinaryRelation] = annotation_field(target="labeled_spans") @dataclasses.dataclass @@ -107,7 +107,7 @@ class TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions( @dataclasses.dataclass class TextDocumentWithSpansAndBinaryRelations(TextDocumentWithSpans): - binary_relations: AnnotationList[BinaryRelation] = annotation_field(target="spans") + binary_relations: AnnotationLayer[BinaryRelation] = annotation_field(target="spans") @dataclasses.dataclass diff --git a/src/pytorch_ie/taskmodules/transformer_re_text_classification.py b/src/pytorch_ie/taskmodules/transformer_re_text_classification.py index e25120d7..5fdaca20 100644 --- a/src/pytorch_ie/taskmodules/transformer_re_text_classification.py +++ b/src/pytorch_ie/taskmodules/transformer_re_text_classification.py @@ -36,7 +36,7 @@ NaryRelation, Span, ) -from pytorch_ie.core import AnnotationList, Document, TaskEncoding, TaskModule +from pytorch_ie.core import AnnotationLayer, Document, TaskEncoding, TaskModule from pytorch_ie.documents import ( TextDocument, TextDocumentWithLabeledSpansAndBinaryRelations, @@ -232,11 +232,11 @@ def document_type(self) -> Optional[Type[TextDocument]]: ) return None - def get_relation_layer(self, document: Document) -> AnnotationList[BinaryRelation]: + def get_relation_layer(self, document: Document) -> AnnotationLayer[BinaryRelation]: return document[self.relation_annotation] - def get_entity_layer(self, document: Document) -> AnnotationList[LabeledSpan]: - relations: AnnotationList[BinaryRelation] = self.get_relation_layer(document) + def get_entity_layer(self, document: Document) -> AnnotationLayer[LabeledSpan]: + relations: AnnotationLayer[BinaryRelation] = self.get_relation_layer(document) if len(relations._targets) != 1: raise Exception( f"the relation layer is expected to target exactly one entity layer, but it has " @@ -249,8 +249,8 @@ def _prepare(self, documents: Sequence[TextDocument]) -> None: entity_labels: Set[str] = set() relation_labels: Set[str] = set() for document in documents: - relations: AnnotationList[BinaryRelation] = self.get_relation_layer(document) - entities: AnnotationList[LabeledSpan] = self.get_entity_layer(document) + relations: AnnotationLayer[BinaryRelation] = self.get_relation_layer(document) + entities: AnnotationLayer[LabeledSpan] = self.get_entity_layer(document) for entity in entities: entity_labels.add(entity.label) @@ -303,8 +303,8 @@ def _create_relation_candidates( document: Document, ) -> List[BinaryRelation]: relation_candidates: List[BinaryRelation] = [] - relations: AnnotationList[BinaryRelation] = self.get_relation_layer(document) - entities: AnnotationList[LabeledSpan] = self.get_entity_layer(document) + relations: AnnotationLayer[BinaryRelation] = self.get_relation_layer(document) + entities: AnnotationLayer[LabeledSpan] = self.get_entity_layer(document) arguments_to_relation = {(rel.head, rel.tail): rel for rel in relations} # iterate over all possible argument candidates for head in entities: diff --git a/tests/conftest.py b/tests/conftest.py index de86b55e..0856c53e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,16 +4,16 @@ import pytest from pytorch_ie.annotations import BinaryRelation, LabeledSpan, Span -from pytorch_ie.core import AnnotationList, annotation_field +from pytorch_ie.core import AnnotationLayer, annotation_field from pytorch_ie.documents import TextDocument from tests import FIXTURES_ROOT @dataclasses.dataclass class TestDocument(TextDocument): - sentences: AnnotationList[Span] = annotation_field(target="text") - entities: AnnotationList[LabeledSpan] = annotation_field(target="text") - relations: AnnotationList[BinaryRelation] = annotation_field(target="entities") + sentences: AnnotationLayer[Span] = annotation_field(target="text") + entities: AnnotationLayer[LabeledSpan] = annotation_field(target="text") + relations: AnnotationLayer[BinaryRelation] = annotation_field(target="entities") def example_to_doc_dict(example): diff --git a/tests/core/test_document.py b/tests/core/test_document.py index 3236f0d9..5512e53e 100644 --- a/tests/core/test_document.py +++ b/tests/core/test_document.py @@ -6,7 +6,7 @@ from pytorch_ie.core import Annotation from pytorch_ie.core.document import ( - AnnotationList, + AnnotationLayer, Document, _contains_annotation_type, _get_reference_fields_and_container_types, @@ -269,7 +269,7 @@ def test_annotation_is_attached(): @dataclasses.dataclass class MyDocument(Document): text: str - words: AnnotationList[Span] = annotation_field(target="text") + words: AnnotationLayer[Span] = annotation_field(target="text") document = MyDocument(text="Hello world!") word = Span(start=0, end=5) @@ -292,8 +292,8 @@ def __repr__(self): @dataclasses.dataclass class MyDocument(Document): text: str - words: AnnotationList[Span] = annotation_field(target="text") - attributes: AnnotationList[Attribute] = annotation_field(target="words") + words: AnnotationLayer[Span] = annotation_field(target="text") + attributes: AnnotationLayer[Attribute] = annotation_field(target="words") document = MyDocument(text="Hello world!") word = Span(start=0, end=5) diff --git a/tests/core/test_metric.py b/tests/core/test_metric.py index 987492e0..6041c121 100644 --- a/tests/core/test_metric.py +++ b/tests/core/test_metric.py @@ -4,7 +4,7 @@ import pytest from pytorch_ie.annotations import LabeledSpan -from pytorch_ie.core import AnnotationList, Document, DocumentMetric, annotation_field +from pytorch_ie.core import AnnotationLayer, Document, DocumentMetric, annotation_field from pytorch_ie.documents import TextBasedDocument @@ -12,7 +12,7 @@ def documents(): @dataclass class TextDocumentWithEntities(TextBasedDocument): - entities: AnnotationList[LabeledSpan] = annotation_field(target="text") + entities: AnnotationLayer[LabeledSpan] = annotation_field(target="text") # a test sentence with two entities doc1 = TextDocumentWithEntities( diff --git a/tests/metrics/test_f1.py b/tests/metrics/test_f1.py index 5c7b570c..3c7cd657 100644 --- a/tests/metrics/test_f1.py +++ b/tests/metrics/test_f1.py @@ -3,7 +3,7 @@ import pytest from pytorch_ie.annotations import LabeledSpan -from pytorch_ie.core import AnnotationList, annotation_field +from pytorch_ie.core import AnnotationLayer, annotation_field from pytorch_ie.documents import TextBasedDocument from pytorch_ie.metrics import F1Metric @@ -12,7 +12,7 @@ def documents(): @dataclass class TextDocumentWithEntities(TextBasedDocument): - entities: AnnotationList[LabeledSpan] = annotation_field(target="text") + entities: AnnotationLayer[LabeledSpan] = annotation_field(target="text") # a test sentence with two entities doc1 = TextDocumentWithEntities( diff --git a/tests/pipeline/test_ner_span_classification.py b/tests/pipeline/test_ner_span_classification.py index 74b69665..4c4a3752 100644 --- a/tests/pipeline/test_ner_span_classification.py +++ b/tests/pipeline/test_ner_span_classification.py @@ -4,7 +4,7 @@ from pytorch_ie import AutoPipeline from pytorch_ie.annotations import LabeledSpan -from pytorch_ie.core import AnnotationList, annotation_field +from pytorch_ie.core import AnnotationLayer, annotation_field from pytorch_ie.documents import TextDocument from pytorch_ie.models import TransformerSpanClassificationModel from pytorch_ie.pipeline import Pipeline @@ -13,7 +13,7 @@ @dataclass class ExampleDocument(TextDocument): - entities: AnnotationList[LabeledSpan] = annotation_field(target="text") + entities: AnnotationLayer[LabeledSpan] = annotation_field(target="text") @pytest.mark.slow diff --git a/tests/pipeline/test_re_generative.py b/tests/pipeline/test_re_generative.py index 87899b59..988c2e77 100644 --- a/tests/pipeline/test_re_generative.py +++ b/tests/pipeline/test_re_generative.py @@ -3,7 +3,7 @@ import pytest from pytorch_ie.annotations import BinaryRelation, LabeledSpan -from pytorch_ie.core import AnnotationList, annotation_field +from pytorch_ie.core import AnnotationLayer, annotation_field from pytorch_ie.documents import TextDocument from pytorch_ie.models import TransformerSeq2SeqModel from pytorch_ie.pipeline import Pipeline @@ -12,8 +12,8 @@ @dataclass class ExampleDocument(TextDocument): - entities: AnnotationList[LabeledSpan] = annotation_field(target="text") - relations: AnnotationList[BinaryRelation] = annotation_field(target="entities") + entities: AnnotationLayer[LabeledSpan] = annotation_field(target="text") + relations: AnnotationLayer[BinaryRelation] = annotation_field(target="entities") @pytest.mark.slow diff --git a/tests/pipeline/test_re_text_classification.py b/tests/pipeline/test_re_text_classification.py index 4192c491..35dfb82e 100644 --- a/tests/pipeline/test_re_text_classification.py +++ b/tests/pipeline/test_re_text_classification.py @@ -5,7 +5,7 @@ from pytorch_ie import AutoPipeline from pytorch_ie.annotations import BinaryRelation, LabeledSpan -from pytorch_ie.core import AnnotationList, annotation_field +from pytorch_ie.core import AnnotationLayer, annotation_field from pytorch_ie.documents import TextDocument from pytorch_ie.models import TransformerTextClassificationModel from pytorch_ie.pipeline import Pipeline @@ -14,8 +14,8 @@ @dataclass class ExampleDocument(TextDocument): - entities: AnnotationList[LabeledSpan] = annotation_field(target="text") - relations: AnnotationList[BinaryRelation] = annotation_field(target="entities") + entities: AnnotationLayer[LabeledSpan] = annotation_field(target="text") + relations: AnnotationLayer[BinaryRelation] = annotation_field(target="entities") @pytest.mark.slow diff --git a/tests/taskmodules/test_simple_transformer_text_classification.py b/tests/taskmodules/test_simple_transformer_text_classification.py index 8da9f327..3ce48bc7 100644 --- a/tests/taskmodules/test_simple_transformer_text_classification.py +++ b/tests/taskmodules/test_simple_transformer_text_classification.py @@ -8,7 +8,7 @@ from pytorch_ie import SimpleTransformerTextClassificationTaskModule from pytorch_ie.annotations import Label -from pytorch_ie.core import AnnotationList, Document, annotation_field +from pytorch_ie.core import AnnotationLayer, Document, annotation_field def _config_to_str(cfg: Dict[str, Any]) -> str: @@ -55,7 +55,7 @@ def test_taskmodule(unprepared_taskmodule): @dataclass class ExampleDocument(Document): text: str - label: AnnotationList[Label] = annotation_field() + label: AnnotationLayer[Label] = annotation_field() @pytest.fixture(scope="module") diff --git a/tests/taskmodules/test_transformer_seq2seq.py b/tests/taskmodules/test_transformer_seq2seq.py index 39189f4f..f652fc00 100644 --- a/tests/taskmodules/test_transformer_seq2seq.py +++ b/tests/taskmodules/test_transformer_seq2seq.py @@ -7,7 +7,7 @@ from pytorch_ie import TransformerSeq2SeqTaskModule from pytorch_ie.annotations import BinaryRelation, LabeledSpan -from pytorch_ie.core import AnnotationList, annotation_field +from pytorch_ie.core import AnnotationLayer, annotation_field from pytorch_ie.documents import TextDocument @@ -30,8 +30,8 @@ def test_taskmodule(taskmodule): @dataclass class ExampleDocument(TextDocument): - entities: AnnotationList[LabeledSpan] = annotation_field(target="text") - relations: AnnotationList[BinaryRelation] = annotation_field(target="entities") + entities: AnnotationLayer[LabeledSpan] = annotation_field(target="text") + relations: AnnotationLayer[BinaryRelation] = annotation_field(target="entities") @pytest.fixture(scope="module") diff --git a/tests/taskmodules/test_transformer_token_classification.py b/tests/taskmodules/test_transformer_token_classification.py index e2805b9f..635aa98d 100644 --- a/tests/taskmodules/test_transformer_token_classification.py +++ b/tests/taskmodules/test_transformer_token_classification.py @@ -8,7 +8,7 @@ from pytorch_ie import TransformerTokenClassificationTaskModule from pytorch_ie.annotations import LabeledSpan, Span -from pytorch_ie.core import AnnotationList, Document, annotation_field +from pytorch_ie.core import AnnotationLayer, Document, annotation_field def _config_to_str(cfg: Dict[str, Any]) -> str: @@ -58,8 +58,8 @@ def unprepared_taskmodule(config): @dataclass class ExampleDocument(Document): text: str - entities: AnnotationList[LabeledSpan] = annotation_field(target="text") - sentences: AnnotationList[Span] = annotation_field(target="text") + entities: AnnotationLayer[LabeledSpan] = annotation_field(target="text") + sentences: AnnotationLayer[Span] = annotation_field(target="text") @pytest.fixture(scope="module") diff --git a/tests/test_auto.py b/tests/test_auto.py index 12ec2220..487b6562 100644 --- a/tests/test_auto.py +++ b/tests/test_auto.py @@ -4,7 +4,7 @@ from pytorch_ie.annotations import LabeledSpan from pytorch_ie.auto import AutoModel, AutoPipeline, AutoTaskModule -from pytorch_ie.core import AnnotationList, TaskModule, annotation_field +from pytorch_ie.core import AnnotationLayer, TaskModule, annotation_field from pytorch_ie.documents import TextDocument from pytorch_ie.models import TransformerSpanClassificationModel from pytorch_ie.taskmodules import TransformerSpanClassificationTaskModule @@ -61,7 +61,7 @@ def test_auto_model(): def test_auto_pipeline(): @dataclass class ExampleDocument(TextDocument): - entities: AnnotationList[LabeledSpan] = annotation_field(target="text") + entities: AnnotationLayer[LabeledSpan] = annotation_field(target="text") pipeline = AutoPipeline.from_pretrained("pie/example-ner-spanclf-conll03") diff --git a/tests/test_document.py b/tests/test_document.py index c0381bb7..f1e3432a 100644 --- a/tests/test_document.py +++ b/tests/test_document.py @@ -5,7 +5,7 @@ import pytest from pytorch_ie.annotations import BinaryRelation, Label, LabeledSpan, Span -from pytorch_ie.core import AnnotationList, annotation_field +from pytorch_ie.core import AnnotationLayer, annotation_field from pytorch_ie.core.document import Annotation, Document, _enumerate_dependencies from pytorch_ie.documents import TextDocument, TokenBasedDocument @@ -42,15 +42,15 @@ def test_text_document(): def test_document_with_annotations(): @dataclasses.dataclass class TestDocument(TextDocument): - sentences: AnnotationList[Span] = annotation_field(target="text") - entities: AnnotationList[LabeledSpan] = annotation_field(target="text") - relations: AnnotationList[BinaryRelation] = annotation_field(target="entities") - label: AnnotationList[Label] = annotation_field() + sentences: AnnotationLayer[Span] = annotation_field(target="text") + entities: AnnotationLayer[LabeledSpan] = annotation_field(target="text") + relations: AnnotationLayer[BinaryRelation] = annotation_field(target="entities") + label: AnnotationLayer[Label] = annotation_field() document1 = TestDocument(text="test1") - assert isinstance(document1.sentences, AnnotationList) - assert isinstance(document1.entities, AnnotationList) - assert isinstance(document1.relations, AnnotationList) + assert isinstance(document1.sentences, AnnotationLayer) + assert isinstance(document1.entities, AnnotationLayer) + assert isinstance(document1.relations, AnnotationLayer) assert len(document1.sentences) == 0 assert len(document1.entities) == 0 assert len(document1.relations) == 0 @@ -141,10 +141,10 @@ class TestDocument(Document): text: str text2: str text3: str - tokens0: AnnotationList[Span] = annotation_field(target="text") - tokens1: AnnotationList[Span] = annotation_field(target="text") - tokens2: AnnotationList[Span] = annotation_field(target="text2") - tokens3: AnnotationList[Span] = annotation_field(target="text3") + tokens0: AnnotationLayer[Span] = annotation_field(target="text") + tokens1: AnnotationLayer[Span] = annotation_field(target="text") + tokens2: AnnotationLayer[Span] = annotation_field(target="text2") + tokens3: AnnotationLayer[Span] = annotation_field(target="text3") doc = TestDocument(text="test1", text2="test1", text3="test2") start = 0 @@ -175,18 +175,18 @@ class TestDocument(Document): def test_as_type(): @dataclasses.dataclass class TestDocument1(TextDocument): - sentences: AnnotationList[Span] = annotation_field(target="text") - entities: AnnotationList[LabeledSpan] = annotation_field(target="text") + sentences: AnnotationLayer[Span] = annotation_field(target="text") + entities: AnnotationLayer[LabeledSpan] = annotation_field(target="text") @dataclasses.dataclass class TestDocument2(TextDocument): - sentences: AnnotationList[Span] = annotation_field(target="text") - ents: AnnotationList[LabeledSpan] = annotation_field(target="text") + sentences: AnnotationLayer[Span] = annotation_field(target="text") + ents: AnnotationLayer[LabeledSpan] = annotation_field(target="text") @dataclasses.dataclass class TestDocument3(TextDocument): - entities: AnnotationList[LabeledSpan] = annotation_field(target="text") - relations: AnnotationList[BinaryRelation] = annotation_field(target="entities") + entities: AnnotationLayer[LabeledSpan] = annotation_field(target="text") + relations: AnnotationLayer[BinaryRelation] = annotation_field(target="entities") # create input document with "sentences" and "relations" document1 = TestDocument1(text="test1") @@ -238,7 +238,7 @@ def test_enumerate_dependencies_with_circle(): def test_annotation_list_wrong_target(): @dataclasses.dataclass class TestDocument(TextDocument): - entities: AnnotationList[LabeledSpan] = annotation_field(target="does_not_exist") + entities: AnnotationLayer[LabeledSpan] = annotation_field(target="does_not_exist") with pytest.raises( TypeError, @@ -252,7 +252,7 @@ class TestDocument(TextDocument): def test_annotation_list(): @dataclasses.dataclass class TestDocument(TextDocument): - entities: AnnotationList[LabeledSpan] = annotation_field(target="text") + entities: AnnotationLayer[LabeledSpan] = annotation_field(target="text") document = TestDocument(text="Entity A works at B.") @@ -272,7 +272,7 @@ class TestDocument(TextDocument): document.entities.predictions.append(entity3) document.entities.predictions.append(entity4) - assert isinstance(document.entities, AnnotationList) + assert isinstance(document.entities, AnnotationLayer) assert len(document.entities) == 2 assert document.entities[0] == entity1 assert document.entities[1] == entity2 @@ -307,12 +307,12 @@ class TestDocument(TextDocument): def test_annotation_list_with_multiple_targets(): @dataclasses.dataclass class TestDocument(TextDocument): - entities1: AnnotationList[LabeledSpan] = annotation_field(target="text") - entities2: AnnotationList[LabeledSpan] = annotation_field(target="text") - relations: AnnotationList[BinaryRelation] = annotation_field( + entities1: AnnotationLayer[LabeledSpan] = annotation_field(target="text") + entities2: AnnotationLayer[LabeledSpan] = annotation_field(target="text") + relations: AnnotationLayer[BinaryRelation] = annotation_field( targets=["entities1", "entities2"] ) - label: AnnotationList[Label] = annotation_field() + label: AnnotationLayer[Label] = annotation_field() doc = TestDocument(text="test1") @@ -386,10 +386,10 @@ def test_annotation_list_with_named_targets(): class TestDocument(Document): texta: str textb: str - entities1: AnnotationList[LabeledSpan] = annotation_field(target="texta") - entities2: AnnotationList[LabeledSpan] = annotation_field(target="textb") + entities1: AnnotationLayer[LabeledSpan] = annotation_field(target="texta") + entities2: AnnotationLayer[LabeledSpan] = annotation_field(target="textb") # note that the entries in targets do not follow the order of DoubleTextSpan.TARGET_NAMES - crossrefs: AnnotationList[DoubleTextSpan] = annotation_field( + crossrefs: AnnotationLayer[DoubleTextSpan] = annotation_field( named_targets={"text2": "textb", "text1": "texta"} ) @@ -446,7 +446,7 @@ def __str__(self) -> str: @dataclasses.dataclass class TestDocument(Document): text: str - entities1: AnnotationList[TextSpan] = annotation_field(named_targets={"textx": "text"}) + entities1: AnnotationLayer[TextSpan] = annotation_field(named_targets={"textx": "text"}) with pytest.raises( TypeError, @@ -461,17 +461,15 @@ class TestDocument(Document): texta: str textb: str # note that the entries in targets do not follow the order of DoubleTextSpan.TARGET_NAMES - crossrefs: AnnotationList[DoubleTextSpan] = annotation_field(targets=["textb", "texta"]) + crossrefs: AnnotationLayer[DoubleTextSpan] = annotation_field(targets=["textb", "texta"]) - with pytest.raises( - TypeError, - match=re.escape( - "A target name mapping is required for AnnotationLists containing Annotations with TARGET_NAMES, but " - 'AnnotationList "crossrefs" has no target_names. You should pass the named_targets dict containing the ' - "following keys (see Annotation \"DoubleTextSpan\") to annotation_field: ('text1', 'text2')" - ), - ): + with pytest.raises(TypeError) as excinfo: doc = TestDocument(texta="text1", textb="text2") + assert str(excinfo.value) == ( + "A target name mapping is required for AnnotationLayers containing Annotations with TARGET_NAMES, but " + 'AnnotationLayer "crossrefs" has no target_names. You should pass the named_targets dict containing the ' + "following keys (see Annotation \"DoubleTextSpan\") to annotation_field: ('text1', 'text2')" + ) def test_annotation_list_number_of_targets_mismatch_error(): @@ -479,7 +477,7 @@ def test_annotation_list_number_of_targets_mismatch_error(): class TestDocument(Document): texta: str textb: str - crossrefs: AnnotationList[DoubleTextSpan] = annotation_field(target="texta") + crossrefs: AnnotationLayer[DoubleTextSpan] = annotation_field(target="texta") with pytest.raises( TypeError, @@ -495,13 +493,13 @@ def test_annotation_list_artificial_root_error(): @dataclasses.dataclass class TestDocument(Document): text: str - _artificial_root: AnnotationList[LabeledSpan] = annotation_field(target="text") + _artificial_root: AnnotationLayer[LabeledSpan] = annotation_field(target="text") with pytest.raises( ValueError, match=re.escape( 'Failed to add the "_artificial_root" node to the annotation graph because it already exists. Note ' - "that AnnotationList entries with that name are not allowed." + "that AnnotationLayer entries with that name are not allowed." ), ): doc = TestDocument(text="text1") @@ -511,10 +509,10 @@ def test_annotation_list_targets(): @dataclasses.dataclass class TestDocument(Document): text: str - entities1: AnnotationList[LabeledSpan] = annotation_field(target="text") - entities2: AnnotationList[LabeledSpan] = annotation_field(target="text") - relations1: AnnotationList[BinaryRelation] = annotation_field(target="entities1") - relations2: AnnotationList[BinaryRelation] = annotation_field( + entities1: AnnotationLayer[LabeledSpan] = annotation_field(target="text") + entities2: AnnotationLayer[LabeledSpan] = annotation_field(target="text") + relations1: AnnotationLayer[BinaryRelation] = annotation_field(target="entities1") + relations2: AnnotationLayer[BinaryRelation] = annotation_field( targets=["entities1", "entities2"] ) @@ -584,7 +582,7 @@ class TestAnnotation(Annotation): # assert that nothing changes when adding the annotation to a document @dataclasses.dataclass class TestDocument(TextDocument): - annotations: AnnotationList[TestAnnotation] = annotation_field(target="text") + annotations: AnnotationLayer[TestAnnotation] = annotation_field(target="text") id0 = annotation0._id hash0 = hash(annotation0) @@ -619,13 +617,13 @@ class Attribute(Annotation): def text_document(): @dataclasses.dataclass class TextBasedDocumentWithEntitiesRelationsAndRelationAttributes(TextDocument): - entities1: AnnotationList[LabeledSpan] = annotation_field(target="text") - entities2: AnnotationList[LabeledSpan] = annotation_field(target="text") - relations: AnnotationList[BinaryRelation] = annotation_field( + entities1: AnnotationLayer[LabeledSpan] = annotation_field(target="text") + entities2: AnnotationLayer[LabeledSpan] = annotation_field(target="text") + relations: AnnotationLayer[BinaryRelation] = annotation_field( targets=["entities1", "entities2"] ) - labels: AnnotationList[Label] = annotation_field() - relation_attributes: AnnotationList[Attribute] = annotation_field(target="relations") + labels: AnnotationLayer[Label] = annotation_field() + relation_attributes: AnnotationLayer[Attribute] = annotation_field(target="relations") doc1 = TextBasedDocumentWithEntitiesRelationsAndRelationAttributes(text="Hello World!") e1 = LabeledSpan(0, 5, "word1") @@ -660,13 +658,13 @@ def test_document_extend_from_other_wrong_override_annotation_mapping(text_docum def test_document_extend_from_other_override(text_document): @dataclasses.dataclass class TestDocument2(TokenBasedDocument): - entities1: AnnotationList[LabeledSpan] = annotation_field(target="tokens") - entities2: AnnotationList[LabeledSpan] = annotation_field(target="tokens") - relations: AnnotationList[BinaryRelation] = annotation_field( + entities1: AnnotationLayer[LabeledSpan] = annotation_field(target="tokens") + entities2: AnnotationLayer[LabeledSpan] = annotation_field(target="tokens") + relations: AnnotationLayer[BinaryRelation] = annotation_field( targets=["entities1", "entities2"] ) - labels: AnnotationList[Label] = annotation_field() - relation_attributes: AnnotationList[Attribute] = annotation_field(target="relations") + labels: AnnotationLayer[Label] = annotation_field() + relation_attributes: AnnotationLayer[Attribute] = annotation_field(target="relations") token_document = TestDocument2(tokens=("Hello", "World", "!")) # create new entities