Skip to content

Commit

Permalink
rename core classes (#369)
Browse files Browse the repository at this point in the history
* rename AnnotationList to AnnotationLayer

* rename AnnotationList to AnnotationLayer in README.md

* rename RequiresDocumentTypeMixin to WithDocumentTypeMixin
  • Loading branch information
ArneBinder authored Nov 8, 2023
1 parent d35e983 commit ef4c574
Show file tree
Hide file tree
Showing 23 changed files with 160 additions and 154 deletions.
32 changes: 18 additions & 14 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,16 +85,17 @@ elements:

```python
from typing import Optional
from pytorch_ie.core import Document, AnnotationList, annotation_field
from pytorch_ie.core import Document, AnnotationLayer, annotation_field
from pytorch_ie.annotations import LabeledSpan, BinaryRelation, Label


class MyDocument(Document):
# data fields (any field that is targeted by an annotation fields)
text: str
# annotation fields
entities: AnnotationList[LabeledSpan] = annotation_field(target="text")
relations: AnnotationList[BinaryRelation] = annotation_field(target="entities")
label: AnnotationList[Label] = annotation_field()
entities: AnnotationLayer[LabeledSpan] = annotation_field(target="text")
relations: AnnotationLayer[BinaryRelation] = annotation_field(target="entities")
label: AnnotationLayer[Label] = annotation_field()
# other fields
doc_id: Optional[str] = None
```
Expand Down Expand Up @@ -147,7 +148,7 @@ The content of `self.target` is lazily assigned as soon as the annotation is add
Note that this now expects a single `collections.abc.Sequence` as `target`, e.g.:

```python
my_spans: AnnotationList[Span] = annotation_field(target="<NAME_OF_THE_SEQUENCE_FIELD>")
my_spans: AnnotationLayer[Span] = annotation_field(target="<NAME_OF_THE_SEQUENCE_FIELD>")
```

If we have multiple targets, we need to define target names to access them. For this, we need to set the special
Expand Down Expand Up @@ -178,7 +179,7 @@ class MyDocumentWithAlignment(Document):
text_a: str
text_b: str
# `named_targets` defines the mapping from `TARGET_NAMES` to data fields
my_alignments: AnnotationList[Alignment] = annotation_field(named_targets={"text1": "text_a", "text2": "text_b"})
my_alignments: AnnotationLayer[Alignment] = annotation_field(named_targets={"text1": "text_a", "text2": "text_b"})
```

Note that `text1` and `text2` can also target the same field.
Expand Down Expand Up @@ -319,12 +320,14 @@ from dataclasses import dataclass

from pytorch_ie.annotations import LabeledSpan
from pytorch_ie.auto import AutoPipeline
from pytorch_ie.core import AnnotationList, annotation_field
from pytorch_ie.core import AnnotationLayer, annotation_field
from pytorch_ie.documents import TextDocument


@dataclass
class ExampleDocument(TextDocument):
entities: AnnotationList[LabeledSpan] = annotation_field(target="text")
entities: AnnotationLayer[LabeledSpan] = annotation_field(target="text")


document = ExampleDocument(
"“Making a super tasty alt-chicken wing is only half of it,” said Po Bronson, general partner at SOSV and managing director of IndieBio."
Expand Down Expand Up @@ -390,14 +393,15 @@ from dataclasses import dataclass

from pytorch_ie.annotations import BinaryRelation, LabeledSpan
from pytorch_ie.auto import AutoPipeline
from pytorch_ie.core import AnnotationList, annotation_field
from pytorch_ie.core import AnnotationLayer, annotation_field
from pytorch_ie.documents import TextDocument


@dataclass
class ExampleDocument(TextDocument):
entities: AnnotationList[LabeledSpan] = annotation_field(target="text")
relations: AnnotationList[BinaryRelation] = annotation_field(target="entities")
entities: AnnotationLayer[LabeledSpan] = annotation_field(target="text")
relations: AnnotationLayer[BinaryRelation] = annotation_field(target="entities")


document = ExampleDocument(
"“Making a super tasty alt-chicken wing is only half of it,” said Po Bronson, general partner at SOSV and managing director of IndieBio."
Expand Down Expand Up @@ -550,7 +554,7 @@ print(dataset["train"][0])
# >>> CoNLL2003Document(text='EU rejects German call to boycott British lamb .', id='0', metadata={})

dataset["train"][0].entities
# >>> AnnotationList([LabeledSpan(start=0, end=2, label='ORG', score=1.0), LabeledSpan(start=11, end=17, label='MISC', score=1.0), LabeledSpan(start=34, end=41, label='MISC', score=1.0)])
# >>> AnnotationLayer([LabeledSpan(start=0, end=2, label='ORG', score=1.0), LabeledSpan(start=11, end=17, label='MISC', score=1.0), LabeledSpan(start=34, end=41, label='MISC', score=1.0)])

entity = dataset["train"][0].entities[1]

Expand All @@ -571,12 +575,12 @@ dataset from that, you have to implement:
```python
@dataclass
class CoNLL2003Document(TextDocument):
entities: AnnotationList[LabeledSpan] = annotation_field(target="text")
entities: AnnotationLayer[LabeledSpan] = annotation_field(target="text")
```

Here we derive from `TextDocument` that has a simple `text` string as base annotation target. The `CoNLL2003Document`
adds one single annotation list called `entities` that consists of `LabeledSpan`s which reference the `text` field of
the document. You can add further annotation types by adding `AnnotationList` fields that may also reference (i.e.
the document. You can add further annotation types by adding `AnnotationLayer` fields that may also reference (i.e.
`target`) other annotations as you like. See ['pytorch_ie.annotations`](src/pytorch_ie/annotations.py) for predefined
annotation types.

Expand Down
4 changes: 2 additions & 2 deletions examples/predict/ner_span_classification.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from dataclasses import dataclass

from pytorch_ie.annotations import LabeledSpan
from pytorch_ie.core import AnnotationList, annotation_field
from pytorch_ie.core import AnnotationLayer, annotation_field
from pytorch_ie.documents import TextDocument
from pytorch_ie.models import TransformerSpanClassificationModel
from pytorch_ie.pipeline import Pipeline
Expand All @@ -10,7 +10,7 @@

@dataclass
class ExampleDocument(TextDocument):
entities: AnnotationList[LabeledSpan] = annotation_field(target="text")
entities: AnnotationLayer[LabeledSpan] = annotation_field(target="text")


def main():
Expand Down
6 changes: 3 additions & 3 deletions examples/predict/re_generative.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from dataclasses import dataclass

from pytorch_ie.annotations import BinaryRelation, LabeledSpan
from pytorch_ie.core import AnnotationList, annotation_field
from pytorch_ie.core import AnnotationLayer, annotation_field
from pytorch_ie.documents import TextDocument
from pytorch_ie.models import TransformerSeq2SeqModel
from pytorch_ie.pipeline import Pipeline
Expand All @@ -10,8 +10,8 @@

@dataclass
class ExampleDocument(TextDocument):
entities: AnnotationList[LabeledSpan] = annotation_field(target="text")
relations: AnnotationList[BinaryRelation] = annotation_field(target="entities")
entities: AnnotationLayer[LabeledSpan] = annotation_field(target="text")
relations: AnnotationLayer[BinaryRelation] = annotation_field(target="entities")


def main():
Expand Down
6 changes: 3 additions & 3 deletions examples/predict/re_text_classification.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from dataclasses import dataclass

from pytorch_ie.annotations import BinaryRelation, LabeledSpan
from pytorch_ie.core import AnnotationList, annotation_field
from pytorch_ie.core import AnnotationLayer, annotation_field
from pytorch_ie.documents import TextDocument
from pytorch_ie.models import TransformerTextClassificationModel
from pytorch_ie.pipeline import Pipeline
Expand All @@ -10,8 +10,8 @@

@dataclass
class ExampleDocument(TextDocument):
entities: AnnotationList[LabeledSpan] = annotation_field(target="text")
relations: AnnotationList[BinaryRelation] = annotation_field(target="entities")
entities: AnnotationLayer[LabeledSpan] = annotation_field(target="text")
relations: AnnotationLayer[BinaryRelation] = annotation_field(target="entities")


def main():
Expand Down
8 changes: 6 additions & 2 deletions src/pytorch_ie/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from .document import Annotation, AnnotationList, Document, annotation_field
from .document import Annotation, AnnotationLayer, Document, annotation_field
from .metric import DocumentMetric
from .model import PyTorchIEModel
from .module_mixins import RequiresDocumentTypeMixin
from .module_mixins import WithDocumentTypeMixin
from .statistic import DocumentStatistic
from .taskmodule import TaskEncoding, TaskModule

# backwards compatibility
AnnotationList = AnnotationLayer
RequiresDocumentTypeMixin = WithDocumentTypeMixin
44 changes: 22 additions & 22 deletions src/pytorch_ie/core/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def _get_reference_fields_and_container_types(


def _get_annotation_fields(fields: List[dataclasses.Field]) -> Set[dataclasses.Field]:
return {field for field in fields if typing.get_origin(field.type) is AnnotationList}
return {field for field in fields if typing.get_origin(field.type) is AnnotationLayer}


def annotation_field(
Expand Down Expand Up @@ -157,7 +157,7 @@ def annotation_field(


# for now, we only have annotation lists and texts
TARGET_TYPE = Union["AnnotationList", str]
TARGET_TYPE = Union["AnnotationLayer", str]


@dataclasses.dataclass(eq=True, frozen=True)
Expand Down Expand Up @@ -426,15 +426,15 @@ def target(self) -> Any:
return list(tgts.values())[0]

@property
def target_layers(self) -> dict[str, "AnnotationList"]:
def target_layers(self) -> dict[str, "AnnotationLayer"]:
return {
target_name: target
for target_name, target in self.targets.items()
if isinstance(target, AnnotationList)
if isinstance(target, AnnotationLayer)
}

@property
def target_layer(self) -> "AnnotationList":
def target_layer(self) -> "AnnotationLayer":
tgt_layers = self.target_layers
if len(tgt_layers) != 1:
raise ValueError(
Expand All @@ -443,7 +443,7 @@ def target_layer(self) -> "AnnotationList":
return list(tgt_layers.values())[0]


class AnnotationList(BaseAnnotationList[T]):
class AnnotationLayer(BaseAnnotationList[T]):
def __init__(self, document: "Document", targets: List["str"]):
super().__init__(document=document, targets=targets)
self._predictions: BaseAnnotationList[T] = BaseAnnotationList(document, targets=targets)
Expand All @@ -453,13 +453,13 @@ def predictions(self) -> BaseAnnotationList[T]:
return self._predictions

def __eq__(self, other: object) -> bool:
if not isinstance(other, AnnotationList):
if not isinstance(other, AnnotationLayer):
return NotImplemented

return super().__eq__(other) and self.predictions == other.predictions

def __repr__(self) -> str:
return f"AnnotationList({str(self._annotations)})"
return f"AnnotationLayer({str(self._annotations)})"


D = TypeVar("D", bound="Document")
Expand All @@ -485,7 +485,7 @@ def fields(cls):
def annotation_fields(cls):
return _get_annotation_fields(list(dataclasses.fields(cls)))

def __getitem__(self, key: str) -> AnnotationList:
def __getitem__(self, key: str) -> AnnotationLayer:
if key not in self._annotation_fields:
raise KeyError(f"Document has no attribute '{key}'.")
return getattr(self, key)
Expand All @@ -505,7 +505,7 @@ def __post_init__(self):

field_origin = typing.get_origin(field.type)

if field_origin is AnnotationList:
if field_origin is AnnotationLayer:
self._annotation_fields.add(field.name)

targets = field.metadata.get("targets")
Expand All @@ -519,7 +519,7 @@ def __post_init__(self):
f'annotation target "{target}" is not in field names of the document: {field_names}'
)

# check annotation target names and use them together with target names from the AnnotationList
# check annotation target names and use them together with target names from the AnnotationLayer
# to reorder targets, if available
target_names = field.metadata.get("target_names")
annotation_type = typing.get_args(field.type)[0]
Expand Down Expand Up @@ -547,8 +547,8 @@ def __post_init__(self):
# disallow multiple targets when target names are specified in the definition of the Annotation
if len(annotation_target_names) > 1:
raise TypeError(
f"A target name mapping is required for AnnotationLists containing Annotations with "
f'TARGET_NAMES, but AnnotationList "{field.name}" has no target_names. You should '
f"A target name mapping is required for AnnotationLayers containing Annotations with "
f'TARGET_NAMES, but AnnotationLayer "{field.name}" has no target_names. You should '
f"pass the named_targets dict containing the following keys (see Annotation "
f'"{annotation_type.__name__}") to annotation_field: {annotation_target_names}'
)
Expand All @@ -559,7 +559,7 @@ def __post_init__(self):
if "_artificial_root" in self._annotation_graph:
raise ValueError(
'Failed to add the "_artificial_root" node to the annotation graph because it already exists. Note '
"that AnnotationList entries with that name are not allowed."
"that AnnotationLayer entries with that name are not allowed."
)
self._annotation_graph["_artificial_root"] = list(self._annotation_fields - targeted)

Expand All @@ -568,7 +568,7 @@ def asdict(self):
for field in self.fields():
value = getattr(self, field.name)

if isinstance(value, AnnotationList):
if isinstance(value, AnnotationLayer):
dct[field.name] = {
"annotations": [v.asdict() for v in value],
"predictions": [v.asdict() for v in value.predictions],
Expand Down Expand Up @@ -621,7 +621,7 @@ def fromdict(cls, dct):
continue

# TODO: handle single annotations, e.g. a document-level label
if typing.get_origin(field.type) is AnnotationList:
if typing.get_origin(field.type) is AnnotationLayer:
annotation_class = typing.get_args(field.type)[0]
# build annotations
for annotation_data in value["annotations"]:
Expand Down Expand Up @@ -718,15 +718,15 @@ class Attribute(Annotation):
@dataclasses.dataclass
class TextBasedDocumentWithEntitiesRelationsAndRelationAttributes(TextBasedDocument):
entities: AnnotationList[LabeledSpan] = annotation_field(target="text")
relations: AnnotationList[BinaryRelation] = annotation_field(target="entities")
relation_attributes: AnnotationList[Attribute] = annotation_field(target="relations")
entities: AnnotationLayer[LabeledSpan] = annotation_field(target="text")
relations: AnnotationLayer[BinaryRelation] = annotation_field(target="entities")
relation_attributes: AnnotationLayer[Attribute] = annotation_field(target="relations")
@dataclasses.dataclass
class TokenBasedDocumentWithEntitiesRelationsAndRelationAttributes(TokenBasedDocument):
entities: AnnotationList[LabeledSpan] = annotation_field(target="tokens")
relations: AnnotationList[BinaryRelation] = annotation_field(target="entities")
relation_attributes: AnnotationList[Attribute] = annotation_field(target="relations")
entities: AnnotationLayer[LabeledSpan] = annotation_field(target="tokens")
relations: AnnotationLayer[BinaryRelation] = annotation_field(target="entities")
relation_attributes: AnnotationLayer[Attribute] = annotation_field(target="relations")
doc_text = TextBasedDocumentWithEntitiesRelationsAndRelationAttributes(text="Hello World!")
Expand Down
4 changes: 2 additions & 2 deletions src/pytorch_ie/core/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
from typing import Dict, Generic, Iterable, Optional, TypeVar, Union

from pytorch_ie.core.document import Document
from pytorch_ie.core.module_mixins import RequiresDocumentTypeMixin
from pytorch_ie.core.module_mixins import WithDocumentTypeMixin

T = TypeVar("T")


class DocumentMetric(ABC, RequiresDocumentTypeMixin, Generic[T]):
class DocumentMetric(ABC, WithDocumentTypeMixin, Generic[T]):
"""This defines the interface for a document metric."""

def __init__(self):
Expand Down
2 changes: 1 addition & 1 deletion src/pytorch_ie/core/module_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
logger = logging.getLogger(__name__)


class RequiresDocumentTypeMixin:
class WithDocumentTypeMixin:

DOCUMENT_TYPE: Optional[Type[Document]] = None

Expand Down
4 changes: 2 additions & 2 deletions src/pytorch_ie/core/taskmodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from pytorch_ie.core.document import Annotation, Document
from pytorch_ie.core.hf_hub_mixin import PieTaskModuleHFHubMixin
from pytorch_ie.core.module_mixins import RequiresDocumentTypeMixin
from pytorch_ie.core.module_mixins import WithDocumentTypeMixin
from pytorch_ie.core.registrable import Registrable

"""
Expand Down Expand Up @@ -133,7 +133,7 @@ class TaskModule(
PieTaskModuleHFHubMixin,
HyperparametersMixin,
Registrable,
RequiresDocumentTypeMixin,
WithDocumentTypeMixin,
Generic[
DocumentType,
InputEncoding,
Expand Down
Loading

0 comments on commit ef4c574

Please sign in to comment.