From b05f59915d83b27ab414b6b883865de2709c2b8b Mon Sep 17 00:00:00 2001
From: Ella Charlaix
Date: Wed, 20 Sep 2023 10:51:15 +0200
Subject: [PATCH] fix model patcher for opt models

---
 optimum/exporters/onnx/model_patcher.py | 72 +++++++++++++++----------
 1 file changed, 44 insertions(+), 28 deletions(-)

diff --git a/optimum/exporters/onnx/model_patcher.py b/optimum/exporters/onnx/model_patcher.py
index 4f16624d292..77a0345a9c8 100644
--- a/optimum/exporters/onnx/model_patcher.py
+++ b/optimum/exporters/onnx/model_patcher.py
@@ -354,23 +354,22 @@ def __init__(
         model_kwargs: Optional[Dict[str, Any]] = None,
     ):
         super().__init__(config, model, model_kwargs)
-        self.orig_prepare_attn_mask = getattr(self._model.transformer, "_prepare_attn_mask")
+
+        self.patch = self.real_config.task == "text-generation" and self.real_config.use_past
+        if self.patch:
+            self.orig_prepare_attn_mask = getattr(self._model.transformer, "_prepare_attn_mask")
 
     def __enter__(self):
         super().__enter__()
-        if self.real_config.task == "text-generation" and self.real_config.use_past:
+        if self.patch:
             setattr(self._model.transformer, "_prepare_attn_mask", _prepare_attn_mask)
 
     def __exit__(self, exc_type, exc_value, traceback):
         super().__exit__(exc_type, exc_value, traceback)
-        if self.real_config.task == "text-generation" and self.real_config.use_past:
+        if self.patch:
             setattr(self._model.transformer, "_prepare_attn_mask", self.orig_prepare_attn_mask)
 
 
-class MPTModelPatcher(BloomModelPatcher):
-    pass
-
-
 class LlamaModelPatcher(ModelPatcher):
     def __init__(
         self,
@@ -379,16 +378,19 @@ def __init__(
         model_kwargs: Optional[Dict[str, Any]] = None,
     ):
         super().__init__(config, model, model_kwargs)
-        self.orig_prepare_attn_mask = getattr(self._model.model, "_prepare_decoder_attention_mask")
+
+        self.patch = self.real_config.task == "text-generation" and self.real_config.use_past
+        if self.patch:
+            self.orig_prepare_attn_mask = getattr(self._model.model, "_prepare_decoder_attention_mask")
 
     def __enter__(self):
         super().__enter__()
-        if self.real_config.task == "text-generation" and self.real_config.use_past:
+        if self.patch:
             setattr(self._model.model, "_prepare_decoder_attention_mask", _prepare_decoder_attention_mask)
 
     def __exit__(self, exc_type, exc_value, traceback):
         super().__exit__(exc_type, exc_value, traceback)
-        if self.real_config.task == "text-generation" and self.real_config.use_past:
+        if self.patch:
             setattr(self._model.model, "_prepare_decoder_attention_mask", self.orig_prepare_attn_mask)
 
 
@@ -400,32 +402,49 @@ def __init__(
         model_kwargs: Optional[Dict[str, Any]] = None,
     ):
         super().__init__(config, model, model_kwargs)
-        if (
-            self.real_config._behavior == "decoder"
-            and self.real_config.task == "text-generation"
-            and self.real_config.use_past
-        ):
+        self.patch = self.real_config.task == "text-generation" and self.real_config.use_past and self.real_config._behavior == "decoder"
+        if self.patch:
             self.orig_prepare_attn_mask = getattr(self._model.model.decoder, "_prepare_decoder_attention_mask")
 
     def __enter__(self):
         super().__enter__()
-        if (
-            self.real_config._behavior == "decoder"
-            and self.real_config.task == "text-generation"
-            and self.real_config.use_past
-        ):
+        if self.patch:
             setattr(self._model.model.decoder, "_prepare_decoder_attention_mask", _prepare_decoder_attention_mask)
 
     def __exit__(self, exc_type, exc_value, traceback):
         super().__exit__(exc_type, exc_value, traceback)
-        if (
-            self.real_config._behavior == "decoder"
-            and self.real_config.task == "text-generation"
-            and self.real_config.use_past
-        ):
+        if self.patch:
             setattr(self._model.model.decoder, "_prepare_decoder_attention_mask", self.orig_prepare_attn_mask)
 
 
+class OPTModelPatcher(ModelPatcher):
+    def __init__(
+        self,
+        config: "OnnxConfig",
+        model: Union["PreTrainedModel", "TFPreTrainedModel"],
+        model_kwargs: Optional[Dict[str, Any]] = None,
+    ):
+        super().__init__(config, model, model_kwargs)
+
+        self.patch = self.real_config.task == "text-generation" and self.real_config.use_past
+        if self.patch:
+            self.orig_prepare_attn_mask = getattr(self._model.model.decoder, "_prepare_decoder_attention_mask")
+
+    def __enter__(self):
+        super().__enter__()
+        if self.patch:
+            setattr(self._model.model.decoder, "_prepare_decoder_attention_mask", _prepare_decoder_attention_mask)
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        if self.patch:
+            setattr(self._model.model.decoder, "_prepare_decoder_attention_mask", self.orig_prepare_attn_mask)
+
+
+class MPTModelPatcher(BloomModelPatcher):
+    pass
+
+
 class BlenderbotSmallModelPatcher(BartModelPatcher):
     pass
 
@@ -437,6 +456,3 @@ class BlenderbotModelPatcher(BartModelPatcher):
 
 class PegasusModelPatcher(BartModelPatcher):
     pass
-
-class OPTModelPatcher(BartModelPatcher):
-    pass
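
Note for reviewers: the patchers touched here are ordinary context managers, so entering one swaps the model's attention-mask helper for the export-friendly version and exiting restores the original. A minimal usage sketch follows; `onnx_config`, `model`, and `export_to_onnx` are placeholders for whatever the exporter builds, not identifiers introduced by this patch.

    # Sketch only: assumes an OPT checkpoint exported for text-generation with use_past=True.
    patcher = OPTModelPatcher(onnx_config, model, model_kwargs=None)

    with patcher:
        # Inside the block, model.model.decoder._prepare_decoder_attention_mask has been
        # replaced by the export-friendly _prepare_decoder_attention_mask helper.
        export_to_onnx(model)  # placeholder for the actual ONNX export call

    # On exit, the original _prepare_decoder_attention_mask is restored on the decoder.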