diff --git a/src/pytorch_ie/taskmodules/transformer_re_text_classification.py b/src/pytorch_ie/taskmodules/transformer_re_text_classification.py index f2ea5edd..e25120d7 100644 --- a/src/pytorch_ie/taskmodules/transformer_re_text_classification.py +++ b/src/pytorch_ie/taskmodules/transformer_re_text_classification.py @@ -215,10 +215,22 @@ def __init__( @property def document_type(self) -> Optional[Type[TextDocument]]: + dt: Type[TextDocument] if self.partition_annotation is not None: - return TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions + dt = TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions else: - return TextDocumentWithLabeledSpansAndBinaryRelations + dt = TextDocumentWithLabeledSpansAndBinaryRelations + + if self.relation_annotation == "binary_relations": + return dt + else: + logger.warning( + f"relation_annotation={self.relation_annotation} is " + f"not the default value ('binary_relations'), so the taskmodule {type(self).__name__} can not request " + f"the usual document type for auto-conversion ({dt.__name__}) because this has the bespoken default " + f"value as layer name instead of the provided one." + ) + return None def get_relation_layer(self, document: Document) -> AnnotationList[BinaryRelation]: return document[self.relation_annotation] diff --git a/src/pytorch_ie/taskmodules/transformer_seq2seq.py b/src/pytorch_ie/taskmodules/transformer_seq2seq.py index 1bffe6c0..c886bde0 100644 --- a/src/pytorch_ie/taskmodules/transformer_seq2seq.py +++ b/src/pytorch_ie/taskmodules/transformer_seq2seq.py @@ -9,7 +9,7 @@ import logging import re -from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, Union +from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, Type, Union from transformers import AutoTokenizer from transformers.file_utils import PaddingStrategy @@ -69,6 +69,23 @@ def __init__( self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path) + @property + def document_type(self) -> Optional[Type[TextDocument]]: + dt: Type[TextDocument] = self.DOCUMENT_TYPE + if ( + self.entity_annotation == "labeled_spans" + and self.relation_annotation == "binary_relations" + ): + return dt + else: + logger.warning( + f"entity_annotation={self.entity_annotation} and relation_annotation={self.relation_annotation} are " + f"not the default values ('labeled_spans' and 'binary_relations'), so the taskmodule " + f"{type(self).__name__} can not request the usual document type ({dt.__name__}) for auto-conversion " + f"because this has the bespoken default values as layer names instead of the provided ones." + ) + return None + def encode_text(self, text: str) -> InputEncodingType: return self.tokenizer( text, diff --git a/src/pytorch_ie/taskmodules/transformer_span_classification.py b/src/pytorch_ie/taskmodules/transformer_span_classification.py index 23f341c6..8a323ca4 100644 --- a/src/pytorch_ie/taskmodules/transformer_span_classification.py +++ b/src/pytorch_ie/taskmodules/transformer_span_classification.py @@ -8,7 +8,7 @@ """ import logging -from typing import Any, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union +from typing import Any, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Type, Union import numpy as np import torch @@ -90,11 +90,23 @@ def __init__( self.multi_label = multi_label @property - def document_type(self) -> TypeAlias: + def document_type(self) -> Optional[Type[TextDocument]]: + dt: Type[TextDocument] if self.single_sentence: - return TextDocumentWithLabeledSpansAndSentences + dt = TextDocumentWithLabeledSpansAndSentences else: - return TextDocumentWithLabeledSpans + dt = TextDocumentWithLabeledSpans + + if self.entity_annotation == "labeled_spans": + return dt + else: + logger.warning( + f"entity_annotation={self.entity_annotation} is " + f"not the default value ('labeled_spans'), so the taskmodule {type(self).__name__} can not request " + f"the usual document type ({dt.__name__}) for auto-conversion because this has the bespoken default " + f"value as layer name instead of the provided one." + ) + return None def _config(self) -> Dict[str, Any]: config = super()._config() diff --git a/src/pytorch_ie/taskmodules/transformer_text_classification.py b/src/pytorch_ie/taskmodules/transformer_text_classification.py index d5b9fb54..81fe680a 100644 --- a/src/pytorch_ie/taskmodules/transformer_text_classification.py +++ b/src/pytorch_ie/taskmodules/transformer_text_classification.py @@ -6,7 +6,7 @@ -> task_output -> document """ - +import logging from typing import ( Any, Dict, @@ -33,6 +33,8 @@ from pytorch_ie.documents import TextDocument, TextDocumentWithLabel, TextDocumentWithMultiLabel from pytorch_ie.models.transformer_text_classification import ModelOutputType, ModelStepInputType +logger = logging.getLogger(__name__) + InputEncodingType: TypeAlias = MutableMapping[str, Any] TargetEncodingType: TypeAlias = Sequence[int] @@ -74,7 +76,7 @@ def __init__( self, tokenizer_name_or_path: str, label_to_verbalizer: Dict[str, str], - annotation: str = "labels", + annotation: str = "label", padding: Union[bool, str, PaddingStrategy] = True, truncation: Union[bool, str, TruncationStrategy] = True, max_length: Optional[int] = None, @@ -104,11 +106,22 @@ def __init__( self.multi_label = multi_label @property - def document_type(self) -> Type[TextDocument]: + def document_type(self) -> Optional[Type[TextDocument]]: + dt: Type[TextDocument] if self.multi_label: - return TextDocumentWithMultiLabel + dt = TextDocumentWithMultiLabel else: - return TextDocumentWithLabel + dt = TextDocumentWithLabel + if self.annotation == "label": + return dt + else: + logger.warning( + f"annotation={self.annotation} is " + f"not the default value ('label'), so the taskmodule {type(self).__name__} can not request " + f"the usual document type ({dt.__name__}) for auto-conversion because this has the bespoken " + f"default value as layer name instead of the provided one." + ) + return None def _config(self) -> Dict[str, Any]: config = super()._config() diff --git a/src/pytorch_ie/taskmodules/transformer_token_classification.py b/src/pytorch_ie/taskmodules/transformer_token_classification.py index 9763e55a..8acd89c0 100644 --- a/src/pytorch_ie/taskmodules/transformer_token_classification.py +++ b/src/pytorch_ie/taskmodules/transformer_token_classification.py @@ -95,11 +95,22 @@ def __init__( self.include_ill_formed_predictions = include_ill_formed_predictions @property - def document_type(self) -> Type[TextDocument]: + def document_type(self) -> Optional[Type[TextDocument]]: + dt: Type[TextDocument] if self.partition_annotation is not None: - return TextDocumentWithLabeledSpansAndLabeledPartitions + dt = TextDocumentWithLabeledSpansAndLabeledPartitions else: - return TextDocumentWithLabeledSpans + dt = TextDocumentWithLabeledSpans + if self.entity_annotation == "labeled_spans": + return dt + else: + logger.warning( + f"entity_annotation={self.entity_annotation} is " + f"not the default value ('labeled_spans'), so the taskmodule {type(self).__name__} can not request " + f"the usual document type ({dt.__name__}) for auto-conversion because this has the bespoken default " + f"value as layer name instead of the provided one." + ) + return None def _config(self) -> Dict[str, Any]: config = super()._config()