From 0bde3aeb219eaac7d1c64015827eff258a865b1a Mon Sep 17 00:00:00 2001
From: ArneBinder
Date: Thu, 7 Sep 2023 18:56:14 +0200
Subject: [PATCH] add model and taskmodule interface classes (#328)

* add the model interfaces RequiresModelNameOrPath and RequiresNumClasses;
  make the tokenizer_vocab_size parameter optional for
  TransformerTextClassificationModel
* add the ChangesTokenizerVocabSize taskmodule interface
* re-add the freeze_model parameter (removing it breaks tests) and actually
  implement it
---
 src/pytorch_ie/models/interface.py            |  6 ++++
 src/pytorch_ie/models/transformer_seq2seq.py  |  3 +-
 .../models/transformer_span_classification.py |  5 ++-
 .../models/transformer_text_classification.py | 33 +++++++------------
 .../transformer_token_classification.py       |  5 ++-
 src/pytorch_ie/taskmodules/interface.py       |  2 ++
 .../transformer_re_text_classification.py     |  3 +-
 7 files changed, 31 insertions(+), 26 deletions(-)
 create mode 100644 src/pytorch_ie/models/interface.py
 create mode 100644 src/pytorch_ie/taskmodules/interface.py

diff --git a/src/pytorch_ie/models/interface.py b/src/pytorch_ie/models/interface.py
new file mode 100644
index 00000000..c1de876b
--- /dev/null
+++ b/src/pytorch_ie/models/interface.py
@@ -0,0 +1,6 @@
+class RequiresModelNameOrPath:
+    pass
+
+
+class RequiresNumClasses:
+    pass
diff --git a/src/pytorch_ie/models/transformer_seq2seq.py b/src/pytorch_ie/models/transformer_seq2seq.py
index d3b5b324..578e2719 100644
--- a/src/pytorch_ie/models/transformer_seq2seq.py
+++ b/src/pytorch_ie/models/transformer_seq2seq.py
@@ -6,6 +6,7 @@
 from typing_extensions import TypeAlias
 
 from pytorch_ie.core import PyTorchIEModel
+from pytorch_ie.models.interface import RequiresModelNameOrPath
 
 ModelInputType: TypeAlias = BatchEncoding
 ModelOutputType: TypeAlias = Seq2SeqLMOutput
@@ -14,7 +15,7 @@
 
 
 @PyTorchIEModel.register()
-class TransformerSeq2SeqModel(PyTorchIEModel):
+class TransformerSeq2SeqModel(PyTorchIEModel, RequiresModelNameOrPath):
     def __init__(self, model_name_or_path: str, learning_rate: float = 1e-5, **kwargs) -> None:
         super().__init__(**kwargs)
 
diff --git a/src/pytorch_ie/models/transformer_span_classification.py b/src/pytorch_ie/models/transformer_span_classification.py
index b809dc90..2b4a5af5 100644
--- a/src/pytorch_ie/models/transformer_span_classification.py
+++ b/src/pytorch_ie/models/transformer_span_classification.py
@@ -9,6 +9,7 @@
 from typing_extensions import TypeAlias
 
 from pytorch_ie.core import PyTorchIEModel
+from pytorch_ie.models.interface import RequiresModelNameOrPath, RequiresNumClasses
 from pytorch_ie.models.modules.mlp import MLP
 
 ModelInputType: TypeAlias = BatchEncoding
@@ -29,7 +30,9 @@
 
 
 @PyTorchIEModel.register()
-class TransformerSpanClassificationModel(PyTorchIEModel):
+class TransformerSpanClassificationModel(
+    PyTorchIEModel, RequiresModelNameOrPath, RequiresNumClasses
+):
     def __init__(
         self,
         model_name_or_path: str,
diff --git a/src/pytorch_ie/models/transformer_text_classification.py b/src/pytorch_ie/models/transformer_text_classification.py
index 78ff3c65..3d26d9d9 100644
--- a/src/pytorch_ie/models/transformer_text_classification.py
+++ b/src/pytorch_ie/models/transformer_text_classification.py
@@ -8,6 +8,7 @@
 from typing_extensions import TypeAlias
 
 from pytorch_ie.core import PyTorchIEModel
+from pytorch_ie.models.interface import RequiresModelNameOrPath, RequiresNumClasses
 
 ModelInputType: TypeAlias = MutableMapping[str, Any]
 ModelOutputType: TypeAlias = Dict[str, Any]
@@ -25,12 +26,14 @@
 
 
 @PyTorchIEModel.register()
-class TransformerTextClassificationModel(PyTorchIEModel):
+class TransformerTextClassificationModel(
+    PyTorchIEModel, RequiresModelNameOrPath, RequiresNumClasses
+):
     def __init__(
         self,
         model_name_or_path: str,
         num_classes: int,
-        tokenizer_vocab_size: int,
+        tokenizer_vocab_size: Optional[int] = None,
         ignore_index: Optional[int] = None,
         learning_rate: float = 1e-5,
         task_learning_rate: float = 1e-4,
@@ -58,11 +61,13 @@ def __init__(
             self.model = AutoModel.from_config(config=config)
         else:
             self.model = AutoModel.from_pretrained(model_name_or_path, config=config)
-        self.model.resize_token_embeddings(tokenizer_vocab_size)
 
-        # if freeze_model:
-        #     for param in self.model.parameters():
-        #         param.requires_grad = False
+        if freeze_model:
+            for param in self.model.parameters():
+                param.requires_grad = False
+
+        if tokenizer_vocab_size is not None:
+            self.model.resize_token_embeddings(tokenizer_vocab_size)
 
         classifier_dropout = (
             config.classifier_dropout
@@ -131,19 +136,3 @@ def configure_optimizers(self):
             return [optimizer], [{"scheduler": scheduler, "interval": "step"}]
         else:
             return optimizer
-
-        # param_optimizer = list(self.named_parameters())
-        # # TODO: this needs fixing (does not work models other than BERT)
-        # optimizer_grouped_parameters = [
-        #     {"params": [p for n, p in param_optimizer if "bert" in n]},
-        #     {
-        #         "params": [p for n, p in param_optimizer if "bert" not in n],
-        #         "lr": self.task_learning_rate,
-        #     },
-        # ]
-        # optimizer = AdamW(optimizer_grouped_parameters, lr=self.learning_rate)
-        # scheduler = get_linear_schedule_with_warmup(
-        #     optimizer, int(self.t_total * self.warmup_proportion), self.t_total
-        # )
-        # return [optimizer], [scheduler]
-        # return torch.optim.Adam(self.parameters(), lr=self.learning_rate)
diff --git a/src/pytorch_ie/models/transformer_token_classification.py b/src/pytorch_ie/models/transformer_token_classification.py
index b93de100..7db1a25a 100644
--- a/src/pytorch_ie/models/transformer_token_classification.py
+++ b/src/pytorch_ie/models/transformer_token_classification.py
@@ -7,6 +7,7 @@
 from typing_extensions import TypeAlias
 
 from pytorch_ie.core import PyTorchIEModel
+from pytorch_ie.models.interface import RequiresModelNameOrPath, RequiresNumClasses
 
 ModelInputType: TypeAlias = BatchEncoding
 ModelOutputType: TypeAlias = Dict[str, Any]
@@ -23,7 +24,9 @@
 
 
 @PyTorchIEModel.register()
-class TransformerTokenClassificationModel(PyTorchIEModel):
+class TransformerTokenClassificationModel(
+    PyTorchIEModel, RequiresModelNameOrPath, RequiresNumClasses
+):
     def __init__(
         self,
         model_name_or_path: str,
diff --git a/src/pytorch_ie/taskmodules/interface.py b/src/pytorch_ie/taskmodules/interface.py
new file mode 100644
index 00000000..9cc5d189
--- /dev/null
+++ b/src/pytorch_ie/taskmodules/interface.py
@@ -0,0 +1,2 @@
+class ChangesTokenizerVocabSize:
+    pass
diff --git a/src/pytorch_ie/taskmodules/transformer_re_text_classification.py b/src/pytorch_ie/taskmodules/transformer_re_text_classification.py
index b23406ec..06786f1a 100644
--- a/src/pytorch_ie/taskmodules/transformer_re_text_classification.py
+++ b/src/pytorch_ie/taskmodules/transformer_re_text_classification.py
@@ -27,6 +27,7 @@
 from pytorch_ie.core import AnnotationList, Document, TaskEncoding, TaskModule
 from pytorch_ie.documents import TextDocument
 from pytorch_ie.models.transformer_text_classification import ModelOutputType, ModelStepInputType
+from pytorch_ie.taskmodules.interface import ChangesTokenizerVocabSize
 from pytorch_ie.utils.span import get_token_slice, is_contained_in
 from pytorch_ie.utils.window import get_window_around_slice
@@ -109,7 +110,7 @@ def shift_token_span(self, value: int):
 
 
 @TaskModule.register()
-class TransformerRETextClassificationTaskModule(TaskModuleType):
+class TransformerRETextClassificationTaskModule(TaskModuleType, ChangesTokenizerVocabSize):
     """Marker based relation extraction. This taskmodule prepares the input token ids in such a way
     that before and after the candidate head and tail entities special marker tokens are inserted.
     Then, the modified token ids can be simply passed into a transformer based text classifier
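
The new classes are empty marker interfaces: they cost nothing at runtime but
let generic code inspect what a model needs and what a taskmodule does. Below
is a minimal sketch of how they might be consumed when wiring a taskmodule to
a model. The collect_model_kwargs helper and the taskmodule attributes it
reads (tokenizer_name_or_path, label_to_id) are hypothetical illustrations,
not confirmed pytorch-ie API:

    from pytorch_ie.models.interface import RequiresModelNameOrPath, RequiresNumClasses
    from pytorch_ie.taskmodules.interface import ChangesTokenizerVocabSize

    def collect_model_kwargs(model_cls, taskmodule, tokenizer) -> dict:
        # Pass a constructor argument only if the model class declares,
        # via a marker interface, that it requires it.
        kwargs = {}
        if issubclass(model_cls, RequiresModelNameOrPath):
            # hypothetical attribute naming the backbone checkpoint
            kwargs["model_name_or_path"] = taskmodule.tokenizer_name_or_path
        if issubclass(model_cls, RequiresNumClasses):
            # hypothetical label-to-id mapping built by the taskmodule
            kwargs["num_classes"] = len(taskmodule.label_to_id)
        if isinstance(taskmodule, ChangesTokenizerVocabSize):
            # The taskmodule added special tokens (e.g. entity markers), so
            # the model must resize its embedding matrix. len(tokenizer)
            # includes added tokens, unlike tokenizer.vocab_size.
            kwargs["tokenizer_vocab_size"] = len(tokenizer)
        return kwargs

With TransformerTextClassificationModel, for example, the collected kwargs map
directly onto its constructor, and the now-optional tokenizer_vocab_size is
only set when the taskmodule actually changed the vocabulary.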