fix model patcher for opt models
echarlaix committed Sep 20, 2023
1 parent be836b5 commit b05f599
Showing 1 changed file with 44 additions and 28 deletions.
72 changes: 44 additions & 28 deletions optimum/exporters/onnx/model_patcher.py
@@ -354,23 +354,22 @@ def __init__(
         model_kwargs: Optional[Dict[str, Any]] = None,
     ):
         super().__init__(config, model, model_kwargs)
-        self.orig_prepare_attn_mask = getattr(self._model.transformer, "_prepare_attn_mask")
+
+        self.patch = self.real_config.task == "text-generation" and self.real_config.use_past
+        if self.patch:
+            self.orig_prepare_attn_mask = getattr(self._model.transformer, "_prepare_attn_mask")

     def __enter__(self):
         super().__enter__()
-        if self.real_config.task == "text-generation" and self.real_config.use_past:
+        if self.patch:
             setattr(self._model.transformer, "_prepare_attn_mask", _prepare_attn_mask)

     def __exit__(self, exc_type, exc_value, traceback):
         super().__exit__(exc_type, exc_value, traceback)
-        if self.real_config.task == "text-generation" and self.real_config.use_past:
+        if self.patch:
             setattr(self._model.transformer, "_prepare_attn_mask", self.orig_prepare_attn_mask)


-class MPTModelPatcher(BloomModelPatcher):
-    pass
-
-
 class LlamaModelPatcher(ModelPatcher):
     def __init__(
         self,
@@ -379,16 +378,19 @@ def __init__(
         model_kwargs: Optional[Dict[str, Any]] = None,
     ):
         super().__init__(config, model, model_kwargs)
-        self.orig_prepare_attn_mask = getattr(self._model.model, "_prepare_decoder_attention_mask")
+
+        self.patch = self.real_config.task == "text-generation" and self.real_config.use_past
+        if self.patch:
+            self.orig_prepare_attn_mask = getattr(self._model.model, "_prepare_decoder_attention_mask")

     def __enter__(self):
         super().__enter__()
-        if self.real_config.task == "text-generation" and self.real_config.use_past:
+        if self.patch:
             setattr(self._model.model, "_prepare_decoder_attention_mask", _prepare_decoder_attention_mask)

     def __exit__(self, exc_type, exc_value, traceback):
         super().__exit__(exc_type, exc_value, traceback)
-        if self.real_config.task == "text-generation" and self.real_config.use_past:
+        if self.patch:
             setattr(self._model.model, "_prepare_decoder_attention_mask", self.orig_prepare_attn_mask)


@@ -400,32 +402,49 @@ def __init__(
         model_kwargs: Optional[Dict[str, Any]] = None,
     ):
         super().__init__(config, model, model_kwargs)
-        if (
-            self.real_config._behavior == "decoder"
-            and self.real_config.task == "text-generation"
-            and self.real_config.use_past
-        ):
+        self.patch = self.real_config.task == "text-generation" and self.real_config.use_past and self.real_config._behavior == "decoder"
+        if self.patch:
             self.orig_prepare_attn_mask = getattr(self._model.model.decoder, "_prepare_decoder_attention_mask")

     def __enter__(self):
         super().__enter__()
-        if (
-            self.real_config._behavior == "decoder"
-            and self.real_config.task == "text-generation"
-            and self.real_config.use_past
-        ):
+        if self.patch:
             setattr(self._model.model.decoder, "_prepare_decoder_attention_mask", _prepare_decoder_attention_mask)

     def __exit__(self, exc_type, exc_value, traceback):
         super().__exit__(exc_type, exc_value, traceback)
-        if (
-            self.real_config._behavior == "decoder"
-            and self.real_config.task == "text-generation"
-            and self.real_config.use_past
-        ):
+        if self.patch:
             setattr(self._model.model.decoder, "_prepare_decoder_attention_mask", self.orig_prepare_attn_mask)

+
+class OPTModelPatcher(ModelPatcher):
+    def __init__(
+        self,
+        config: "OnnxConfig",
+        model: Union["PreTrainedModel", "TFPreTrainedModel"],
+        model_kwargs: Optional[Dict[str, Any]] = None,
+    ):
+        super().__init__(config, model, model_kwargs)
+        self.patch = self.real_config.task == "text-generation" and self.real_config.use_past
+        if self.patch:
+            self.orig_prepare_attn_mask = getattr(self._model.model.decoder, "_prepare_decoder_attention_mask")
+
+    def __enter__(self):
+        super().__enter__()
+        if self.patch:
+            setattr(self._model.model.decoder, "_prepare_decoder_attention_mask", _prepare_decoder_attention_mask)
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        if self.patch:
+            setattr(self._model.model.decoder, "_prepare_decoder_attention_mask", self.orig_prepare_attn_mask)
+
+
+
+class MPTModelPatcher(BloomModelPatcher):
+    pass
+

 class BlenderbotSmallModelPatcher(BartModelPatcher):
     pass

@@ -437,6 +456,3 @@ class BlenderbotModelPatcher(BartModelPatcher):
 class PegasusModelPatcher(BartModelPatcher):
     pass

-
-class OPTModelPatcher(BartModelPatcher):
-    pass

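The change applies one pattern to the Bloom, Llama, Bart and (newly dedicated) OPT patchers: decide once in __init__ whether patching is needed and cache that as self.patch, swap the model's attention-mask helper in __enter__, and restore the original in __exit__. The snippet below is a minimal, self-contained sketch of that context-manager pattern; AttnMaskPatcher and the toy Decoder are hypothetical stand-ins for illustration, not optimum code.

# Illustrative sketch only: AttnMaskPatcher and Decoder mirror the
# __init__/__enter__/__exit__ pattern used by the patchers in this commit;
# they are not part of optimum.
class AttnMaskPatcher:
    def __init__(self, module, should_patch, replacement):
        # Decide once whether to patch, like the new `self.patch` flag.
        self.patch = should_patch
        self._module = module
        self._replacement = replacement
        if self.patch:
            # Only capture the original helper when it will actually be overridden.
            self._orig = getattr(module, "_prepare_decoder_attention_mask")

    def __enter__(self):
        if self.patch:
            # Swap in the export-friendly helper for the duration of the block.
            setattr(self._module, "_prepare_decoder_attention_mask", self._replacement)
        return self._module

    def __exit__(self, exc_type, exc_value, traceback):
        if self.patch:
            # Restore the original helper once export/tracing is done.
            setattr(self._module, "_prepare_decoder_attention_mask", self._orig)


class Decoder:  # toy stand-in for model.model.decoder of an OPT checkpoint
    def _prepare_decoder_attention_mask(self):
        return "original"


decoder = Decoder()
with AttnMaskPatcher(decoder, should_patch=True, replacement=lambda: "patched"):
    assert decoder._prepare_decoder_attention_mask() == "patched"  # patched inside the block
assert decoder._prepare_decoder_attention_mask() == "original"  # restored on exit

With should_patch=False the context manager is a no-op, which matches how the patchers now skip both the override and the caching of the original method when the task is not text-generation with use_past.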