Skip to content

Commit

Permalink
Fix document_type property of taskmodules (#361)
Browse files Browse the repository at this point in the history
* set default value for annotation to "label" in TransformerTextClassificationTaskModule to match the document type

* check if the annotation layer names required in the taskmodules match the ones in the usual document types that they request and request no document type in the case of a mismatch (but log a warning)

* make mypy happy
  • Loading branch information
ArneBinder authored Oct 23, 2023
1 parent d779a75 commit 2795226
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 15 deletions.
16 changes: 14 additions & 2 deletions src/pytorch_ie/taskmodules/transformer_re_text_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,10 +215,22 @@ def __init__(

@property
def document_type(self) -> Optional[Type[TextDocument]]:
    """Return the document type to request for auto-conversion, or ``None``.

    The usual document types hard-code the layer name ``"binary_relations"``.
    If this taskmodule was configured with a different ``relation_annotation``,
    the usual type would not match, so no document type is requested and a
    warning is logged instead.
    """
    dt: Type[TextDocument]
    # With a partition annotation we also need labeled partitions in the document.
    if self.partition_annotation is not None:
        dt = TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
    else:
        dt = TextDocumentWithLabeledSpansAndBinaryRelations

    if self.relation_annotation == "binary_relations":
        return dt
    logger.warning(
        f"relation_annotation={self.relation_annotation} is not the default value "
        f"('binary_relations'), so the taskmodule {type(self).__name__} cannot request "
        f"the usual document type ({dt.__name__}) for auto-conversion because that type "
        f"uses the default layer name instead of the provided one."
    )
    return None

def get_relation_layer(self, document: Document) -> AnnotationList[BinaryRelation]:
    """Return the annotation layer of *document* configured as the relation layer."""
    layer_name = self.relation_annotation
    return document[layer_name]
Expand Down
19 changes: 18 additions & 1 deletion src/pytorch_ie/taskmodules/transformer_seq2seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import logging
import re
from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, Union
from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, Type, Union

from transformers import AutoTokenizer
from transformers.file_utils import PaddingStrategy
Expand Down Expand Up @@ -69,6 +69,23 @@ def __init__(

self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)

@property
def document_type(self) -> Optional[Type[TextDocument]]:
    """Return the document type to request for auto-conversion, or ``None``.

    The usual document type hard-codes the layer names ``"labeled_spans"`` and
    ``"binary_relations"``. If this taskmodule was configured with different
    ``entity_annotation`` / ``relation_annotation`` values, the usual type would
    not match, so no document type is requested and a warning is logged instead.
    """
    dt: Type[TextDocument] = self.DOCUMENT_TYPE
    if (
        self.entity_annotation == "labeled_spans"
        and self.relation_annotation == "binary_relations"
    ):
        return dt
    logger.warning(
        f"entity_annotation={self.entity_annotation} and relation_annotation="
        f"{self.relation_annotation} are not the default values ('labeled_spans' and "
        f"'binary_relations'), so the taskmodule {type(self).__name__} cannot request "
        f"the usual document type ({dt.__name__}) for auto-conversion because that type "
        f"uses the default layer names instead of the provided ones."
    )
    return None

def encode_text(self, text: str) -> InputEncodingType:
return self.tokenizer(
text,
Expand Down
20 changes: 16 additions & 4 deletions src/pytorch_ie/taskmodules/transformer_span_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
"""

import logging
from typing import Any, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union
from typing import Any, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Type, Union

import numpy as np
import torch
Expand Down Expand Up @@ -90,11 +90,23 @@ def __init__(
self.multi_label = multi_label

@property
def document_type(self) -> Optional[Type[TextDocument]]:
    """Return the document type to request for auto-conversion, or ``None``.

    The usual document types hard-code the layer name ``"labeled_spans"``.
    If this taskmodule was configured with a different ``entity_annotation``,
    the usual type would not match, so no document type is requested and a
    warning is logged instead.
    """
    dt: Type[TextDocument]
    # In single-sentence mode the document also needs a sentence layer.
    if self.single_sentence:
        dt = TextDocumentWithLabeledSpansAndSentences
    else:
        dt = TextDocumentWithLabeledSpans

    if self.entity_annotation == "labeled_spans":
        return dt
    logger.warning(
        f"entity_annotation={self.entity_annotation} is not the default value "
        f"('labeled_spans'), so the taskmodule {type(self).__name__} cannot request "
        f"the usual document type ({dt.__name__}) for auto-conversion because that type "
        f"uses the default layer name instead of the provided one."
    )
    return None

def _config(self) -> Dict[str, Any]:
config = super()._config()
Expand Down
23 changes: 18 additions & 5 deletions src/pytorch_ie/taskmodules/transformer_text_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
-> task_output
-> document
"""

import logging
from typing import (
Any,
Dict,
Expand All @@ -33,6 +33,8 @@
from pytorch_ie.documents import TextDocument, TextDocumentWithLabel, TextDocumentWithMultiLabel
from pytorch_ie.models.transformer_text_classification import ModelOutputType, ModelStepInputType

logger = logging.getLogger(__name__)

InputEncodingType: TypeAlias = MutableMapping[str, Any]
TargetEncodingType: TypeAlias = Sequence[int]

Expand Down Expand Up @@ -74,7 +76,7 @@ def __init__(
self,
tokenizer_name_or_path: str,
label_to_verbalizer: Dict[str, str],
annotation: str = "labels",
annotation: str = "label",
padding: Union[bool, str, PaddingStrategy] = True,
truncation: Union[bool, str, TruncationStrategy] = True,
max_length: Optional[int] = None,
Expand Down Expand Up @@ -104,11 +106,22 @@ def __init__(
self.multi_label = multi_label

@property
def document_type(self) -> Optional[Type[TextDocument]]:
    """Return the document type to request for auto-conversion, or ``None``.

    The usual document types hard-code the annotation name ``"label"``.
    If this taskmodule was configured with a different ``annotation`` value,
    the usual type would not match, so no document type is requested and a
    warning is logged instead.
    """
    dt: Type[TextDocument]
    # Multi-label classification needs the multi-label document variant.
    if self.multi_label:
        dt = TextDocumentWithMultiLabel
    else:
        dt = TextDocumentWithLabel

    if self.annotation == "label":
        return dt
    logger.warning(
        f"annotation={self.annotation} is not the default value ('label'), "
        f"so the taskmodule {type(self).__name__} cannot request the usual document "
        f"type ({dt.__name__}) for auto-conversion because that type uses the default "
        f"annotation name instead of the provided one."
    )
    return None

def _config(self) -> Dict[str, Any]:
config = super()._config()
Expand Down
17 changes: 14 additions & 3 deletions src/pytorch_ie/taskmodules/transformer_token_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,22 @@ def __init__(
self.include_ill_formed_predictions = include_ill_formed_predictions

@property
def document_type(self) -> Optional[Type[TextDocument]]:
    """Return the document type to request for auto-conversion, or ``None``.

    The usual document types hard-code the layer name ``"labeled_spans"``.
    If this taskmodule was configured with a different ``entity_annotation``,
    the usual type would not match, so no document type is requested and a
    warning is logged instead.
    """
    dt: Type[TextDocument]
    # With a partition annotation we also need labeled partitions in the document.
    if self.partition_annotation is not None:
        dt = TextDocumentWithLabeledSpansAndLabeledPartitions
    else:
        dt = TextDocumentWithLabeledSpans

    if self.entity_annotation == "labeled_spans":
        return dt
    logger.warning(
        f"entity_annotation={self.entity_annotation} is not the default value "
        f"('labeled_spans'), so the taskmodule {type(self).__name__} cannot request "
        f"the usual document type ({dt.__name__}) for auto-conversion because that type "
        f"uses the default layer name instead of the provided one."
    )
    return None

def _config(self) -> Dict[str, Any]:
config = super()._config()
Expand Down

0 comments on commit 2795226

Please sign in to comment.