Commit 7a5df3b
check if the annotation layer names required in the taskmodules match the ones in the usual document types that they request, and request no document type in the case of a mismatch (but log a warning)
ArneBinder committed Oct 23, 2023
1 parent dc05530 commit 7a5df3b
Showing 5 changed files with 75 additions and 14 deletions.
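
In effect, a taskmodule now only advertises its usual document type when its configured annotation layer names match that type's defaults; otherwise document_type returns None and logs a warning. A minimal sketch of the resulting behavior, assuming entity_annotation is a constructor argument (as the diffs below suggest); the tokenizer name is illustrative:

# Sketch only: entity_annotation as a constructor argument is an assumption
# based on the diffs below; the tokenizer name is illustrative.
from pytorch_ie.taskmodules import TransformerTokenClassificationTaskModule

# Default layer name: the usual document type is requested for auto-conversion.
taskmodule = TransformerTokenClassificationTaskModule(
    tokenizer_name_or_path="bert-base-cased"
)
assert taskmodule.document_type is not None  # TextDocumentWithLabeledSpans

# Non-default layer name: no document type is requested, only a warning is logged.
taskmodule = TransformerTokenClassificationTaskModule(
    tokenizer_name_or_path="bert-base-cased", entity_annotation="entities"
)
assert taskmodule.document_type is None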
15 changes: 13 additions & 2 deletions src/pytorch_ie/taskmodules/transformer_re_text_classification.py
@@ -216,9 +216,20 @@ def __init__(
     @property
     def document_type(self) -> Optional[Type[TextDocument]]:
         if self.partition_annotation is not None:
-            return TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
+            dt = TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
         else:
-            return TextDocumentWithLabeledSpansAndBinaryRelations
+            dt = TextDocumentWithLabeledSpansAndBinaryRelations
+
+        if self.relation_annotation == "binary_relations":
+            return dt
+        else:
+            logger.warning(
+                f"relation_annotation={self.relation_annotation} is not the default value "
+                f"('binary_relations'), so the taskmodule {type(self).__name__} cannot request the usual "
+                f"document type ({dt.__name__}) for auto-conversion, because that type uses the default "
+                f"value as the layer name instead of the provided one."
+            )
+            return None

     def get_relation_layer(self, document: Document) -> AnnotationList[BinaryRelation]:
         return document[self.relation_annotation]
19 changes: 18 additions & 1 deletion src/pytorch_ie/taskmodules/transformer_seq2seq.py
@@ -9,7 +9,7 @@

 import logging
 import re
-from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, Union
+from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, Type, Union

 from transformers import AutoTokenizer
 from transformers.file_utils import PaddingStrategy
@@ -69,6 +69,23 @@ def __init__(

         self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)

+    @property
+    def document_type(self) -> Optional[Type[TextDocument]]:
+        dt = self.DOCUMENT_TYPE
+        if (
+            self.entity_annotation == "labeled_spans"
+            and self.relation_annotation == "binary_relations"
+        ):
+            return dt
+        else:
+            logger.warning(
+                f"entity_annotation={self.entity_annotation} and relation_annotation={self.relation_annotation} "
+                f"are not the default values ('labeled_spans' and 'binary_relations'), so the taskmodule "
+                f"{type(self).__name__} cannot request the usual document type ({dt.__name__}) for "
+                f"auto-conversion, because that type uses the default values as layer names instead of the provided ones."
+            )
+            return None
+
     def encode_text(self, text: str) -> InputEncodingType:
         return self.tokenizer(
             text,
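For the seq2seq taskmodule both layer names are checked, so overriding either one is enough to suppress the document-type request. A hedged sketch under the same assumption that the layer names are constructor arguments:

# Sketch only: relation_annotation as a constructor argument is an assumption.
from pytorch_ie.taskmodules import TransformerSeq2SeqTaskModule

taskmodule = TransformerSeq2SeqTaskModule(
    tokenizer_name_or_path="t5-base",
    relation_annotation="relations",  # non-default layer name
)
# entity_annotation keeps its default, but the single mismatch already means
# that no document type is requested (and a warning is logged).
assert taskmodule.document_type is None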
19 changes: 15 additions & 4 deletions src/pytorch_ie/taskmodules/transformer_span_classification.py
@@ -8,7 +8,7 @@
"""

import logging
from typing import Any, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union
from typing import Any, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Type, Union

import numpy as np
import torch
@@ -90,11 +90,22 @@ def __init__(
         self.multi_label = multi_label

     @property
-    def document_type(self) -> TypeAlias:
+    def document_type(self) -> Optional[Type[TextDocument]]:
         if self.single_sentence:
-            return TextDocumentWithLabeledSpansAndSentences
+            dt = TextDocumentWithLabeledSpansAndSentences
         else:
-            return TextDocumentWithLabeledSpans
+            dt = TextDocumentWithLabeledSpans
+
+        if self.entity_annotation == "labeled_spans":
+            return dt
+        else:
+            logger.warning(
+                f"entity_annotation={self.entity_annotation} is not the default value ('labeled_spans'), "
+                f"so the taskmodule {type(self).__name__} cannot request the usual document type "
+                f"({dt.__name__}) for auto-conversion, because that type uses the default value as the "
+                f"layer name instead of the provided one."
+            )
+            return None

     def _config(self) -> Dict[str, Any]:
         config = super()._config()
20 changes: 16 additions & 4 deletions src/pytorch_ie/taskmodules/transformer_text_classification.py
@@ -6,7 +6,7 @@
     -> task_output
         -> document
 """
-
+import logging
 from typing import (
     Any,
     Dict,
@@ -33,6 +33,8 @@
 from pytorch_ie.documents import TextDocument, TextDocumentWithLabel, TextDocumentWithMultiLabel
 from pytorch_ie.models.transformer_text_classification import ModelOutputType, ModelStepInputType

+logger = logging.getLogger(__name__)
+
 InputEncodingType: TypeAlias = MutableMapping[str, Any]
 TargetEncodingType: TypeAlias = Sequence[int]

@@ -104,11 +106,21 @@ def __init__(
         self.multi_label = multi_label

     @property
-    def document_type(self) -> Type[TextDocument]:
+    def document_type(self) -> Optional[Type[TextDocument]]:
         if self.multi_label:
-            return TextDocumentWithMultiLabel
+            dt = TextDocumentWithMultiLabel
         else:
-            return TextDocumentWithLabel
+            dt = TextDocumentWithLabel
+        if self.annotation == "label":
+            return dt
+        else:
+            logger.warning(
+                f"annotation={self.annotation} is not the default value ('label'), so the taskmodule "
+                f"{type(self).__name__} cannot request the usual document type ({dt.__name__}) for "
+                f"auto-conversion, because that type uses the default value as the layer name instead "
+                f"of the provided one."
+            )
+            return None

     def _config(self) -> Dict[str, Any]:
         config = super()._config()
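Here multi_label still selects which of the two document types counts as the usual one, but only the default layer name "label" allows it to be requested. A sketch under the same constructor-argument assumption:

# Sketch only: annotation and multi_label as constructor arguments is an
# assumption based on the diff above.
from pytorch_ie.taskmodules import TransformerTextClassificationTaskModule

taskmodule = TransformerTextClassificationTaskModule(
    tokenizer_name_or_path="bert-base-cased", multi_label=True
)
assert taskmodule.document_type is not None  # TextDocumentWithMultiLabel

taskmodule = TransformerTextClassificationTaskModule(
    tokenizer_name_or_path="bert-base-cased", annotation="sentiment"
)
assert taskmodule.document_type is None  # a warning is logged instead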
16 changes: 13 additions & 3 deletions src/pytorch_ie/taskmodules/transformer_token_classification.py
@@ -95,11 +95,21 @@ def __init__(
         self.include_ill_formed_predictions = include_ill_formed_predictions

     @property
-    def document_type(self) -> Type[TextDocument]:
+    def document_type(self) -> Optional[Type[TextDocument]]:
         if self.partition_annotation is not None:
-            return TextDocumentWithLabeledSpansAndLabeledPartitions
+            dt = TextDocumentWithLabeledSpansAndLabeledPartitions
         else:
-            return TextDocumentWithLabeledSpans
+            dt = TextDocumentWithLabeledSpans
+        if self.entity_annotation == "labeled_spans":
+            return dt
+        else:
+            logger.warning(
+                f"entity_annotation={self.entity_annotation} is not the default value ('labeled_spans'), "
+                f"so the taskmodule {type(self).__name__} cannot request the usual document type "
+                f"({dt.__name__}) for auto-conversion, because that type uses the default value as the "
+                f"layer name instead of the provided one."
+            )
+            return None

     def _config(self) -> Dict[str, Any]:
         config = super()._config()
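Callers of document_type consequently have to tolerate None. A hypothetical consumer-side sketch (not the actual pytorch-ie pipeline code; convert_documents is an illustrative helper, and Document.as_type is assumed to be available for the conversion):

# Hypothetical sketch of how auto-conversion might consume the property.
from typing import List, Sequence

from pytorch_ie.core import Document, TaskModule


def convert_documents(
    documents: Sequence[Document], taskmodule: TaskModule
) -> List[Document]:
    dt = taskmodule.document_type
    if dt is None:
        # Mismatched layer names: the taskmodule already logged a warning,
        # so skip auto-conversion and hand the documents over unchanged.
        return list(documents)
    # Convert each document to the type the taskmodule expects.
    return [document.as_type(dt) for document in documents]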
