diff --git a/optimum/exporters/onnx/constants.py b/optimum/exporters/onnx/constants.py
index 0a6f9f9b363..c994c01ff71 100644
--- a/optimum/exporters/onnx/constants.py
+++ b/optimum/exporters/onnx/constants.py
@@ -38,4 +38,5 @@
     "bart",
     "musicgen",
     "whisper",
+    "gemma2",
 ]
diff --git a/optimum/exporters/onnx/model_patcher.py b/optimum/exporters/onnx/model_patcher.py
index 1f873e4e716..ff8b336dcfd 100644
--- a/optimum/exporters/onnx/model_patcher.py
+++ b/optimum/exporters/onnx/model_patcher.py
@@ -43,7 +43,7 @@
     AttentionMaskConverter = None
 
 if _transformers_version >= version.parse("4.42"):
-    from transformers.cache_utils import SlidingWindowCache, StaticCache
+    from transformers.cache_utils import DynamicCache, SlidingWindowCache, StaticCache
 
 if TYPE_CHECKING:
     from transformers import PreTrainedModel, TFPreTrainedModel
@@ -144,12 +144,50 @@ def __init__(
 
         allow_past_in_outputs = hasattr(self.real_config, "use_past") and self.real_config.use_past
 
+        signature = inspect.signature(self.orig_forward)
+        parameters = list(signature.parameters.keys())
+        # `list.index` raises a ValueError for a missing parameter, so fall back to -1
+        # for models whose forward signature lacks one of these arguments.
+        index_of_pkv_parameter = parameters.index("past_key_values") if "past_key_values" in parameters else -1
+        index_of_attention_mask_parameter = parameters.index("attention_mask") if "attention_mask" in parameters else -1
+        index_of_input_ids_parameter = parameters.index("input_ids") if "input_ids" in parameters else -1
+
+        # Globally override torch.tril so that the export works with ONNX Runtime.
+        # Otherwise we get `onnxruntime.capi.onnxruntime_pybind11_state.NotImplemented: [ONNXRuntimeError] : 9 : NOT_IMPLEMENTED : Could not find an implementation for Trilu(14) node with name '/decoder/Trilu'`.
+        # See https://github.com/huggingface/optimum/pull/1391 for more information.
+        # TODO: Remove when https://github.com/microsoft/onnxruntime/pull/20917 is available in a released version.
+        original_tril = torch.tril
+
+        def patched_tril(input, diagonal=0):
+            # ONNX Runtime has no bool kernel for Trilu: compute in float, cast back.
+            if input.dtype == torch.bool:
+                return original_tril(input.to(torch.float), diagonal=diagonal).to(torch.bool)
+            return original_tril(input, diagonal=diagonal)
+
+        torch.tril = patched_tril
+
         @functools.wraps(self.orig_forward)
         def patched_forward(*args, **kwargs):
-            signature = inspect.signature(self.orig_forward)
             args, kwargs = override_arguments(args, kwargs, signature, model_kwargs=self.model_kwargs)
 
+            if index_of_pkv_parameter != -1:
+                # Wrap the legacy tuple-of-tuples cache in a DynamicCache, which is what
+                # the model expects on transformers >= 4.42.
+                dynamic_cache = DynamicCache.from_legacy_cache(args[index_of_pkv_parameter])
+                attention_mask = args[index_of_attention_mask_parameter]
+                input_ids = args[index_of_input_ids_parameter]
+                sequence_length = input_ids.shape[1]
+
+                # DynamicCache.get_max_length returns None by default; report a concrete
+                # maximum length so the causal mask can be built during export.
+                dynamic_cache.get_max_length = lambda: attention_mask.shape[-1] + sequence_length + 1
+                args[index_of_pkv_parameter] = dynamic_cache
+
             outputs = self.orig_forward(*args, **kwargs)
 
+            if index_of_pkv_parameter != -1:
+                # Convert back so the exported graph keeps flat past_key_values outputs.
+                outputs.past_key_values = outputs.past_key_values.to_legacy_cache()
+
             # This code block handles different cases of the filterd_outputs input to align it with the expected
             # format of outputs. It is common for the output type of a model to vary, such as tensor, list,
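
The torch.tril override above is self-contained and can be sanity-checked outside the exporter. A minimal standalone sketch of the same workaround, assuming only a PyTorch install; the toy boolean mask is illustrative and not part of the diff:

    import torch

    # Keep a handle on the stock implementation before monkey-patching.
    original_tril = torch.tril

    def patched_tril(input, diagonal=0):
        # ONNX Runtime's Trilu kernel has no bool implementation, so compute the
        # lower triangle in float and cast the result back to bool.
        if input.dtype == torch.bool:
            return original_tril(input.to(torch.float), diagonal=diagonal).to(torch.bool)
        return original_tril(input, diagonal=diagonal)

    torch.tril = patched_tril

    # A boolean causal mask like the ones built inside the attention code now
    # routes through the float path during tracing:
    mask = torch.tril(torch.ones(4, 4, dtype=torch.bool))
    print(mask)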
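
The cache handling in patched_forward converts the exporter's legacy tuple-of-tuples past_key_values into a DynamicCache on the way in and back to the legacy layout on the way out. A minimal sketch of that round trip, assuming transformers >= 4.42; the layer count and tensor shapes below are illustrative:

    import torch
    from transformers.cache_utils import DynamicCache

    # A legacy cache is a tuple of (key, value) pairs, one per layer, each of
    # shape (batch, num_heads, seq_len, head_dim). Toy sizes here.
    legacy = tuple((torch.zeros(1, 2, 8, 4), torch.zeros(1, 2, 8, 4)) for _ in range(3))

    cache = DynamicCache.from_legacy_cache(legacy)

    # After the forward pass, converting back keeps the flat tuple-of-tensors
    # inputs/outputs that the exported ONNX graph exposes.
    round_trip = cache.to_legacy_cache()
    assert len(round_trip) == 3
    assert torch.equal(round_trip[0][0], legacy[0][0])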