
Commit

(testing) ensure correct cache format
xenova committed Jul 25, 2024
1 parent 97f539b commit ba65bb7
Showing 2 changed files with 31 additions and 2 deletions.
1 change: 1 addition & 0 deletions optimum/exporters/onnx/constants.py
@@ -38,4 +38,5 @@
     "bart",
     "musicgen",
     "whisper",
+    "gemma2",
 ]
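
For context, this hunk shows only the tail of the list being extended. A minimal sketch of how the constant might read after the change (the name SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED and the earlier entries are assumptions, not shown in this diff):

# Assumed constant name: architectures in this list are exported with eager
# attention because their SDPA code path is not ONNX-exportable.
SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED = [
    # ... earlier entries elided ...
    "bart",
    "musicgen",
    "whisper",
    "gemma2",  # added by this commit
]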
32 changes: 30 additions & 2 deletions optimum/exporters/onnx/model_patcher.py
Expand Up @@ -43,7 +43,7 @@
AttentionMaskConverter = None

if _transformers_version >= version.parse("4.42"):
from transformers.cache_utils import SlidingWindowCache, StaticCache
from transformers.cache_utils import SlidingWindowCache, StaticCache, DynamicCache

if TYPE_CHECKING:
from transformers import PreTrainedModel, TFPreTrainedModel
@@ -144,12 +144,40 @@ def __init__(
 
         allow_past_in_outputs = hasattr(self.real_config, "use_past") and self.real_config.use_past
 
+        signature = inspect.signature(self.orig_forward)
+        parameters = list(signature.parameters.keys())
+        index_of_pkv_parameter = parameters.index('past_key_values')
+        index_of_attention_mask_parameter = parameters.index('attention_mask')
+        index_of_input_ids_parameter = parameters.index('input_ids')
+
+        # Globally override torch.tril so that the export works with ONNXRuntime.
+        # Otherwise we get `onnxruntime.capi.onnxruntime_pybind11_state.NotImplemented: [ONNXRuntimeError] : 9 : NOT_IMPLEMENTED : Could not find an implementation for Trilu(14) node with name '/decoder/Trilu'`
+        # See https://github.com/huggingface/optimum/pull/1391 for more information
+        # TODO: Remove when https://github.com/microsoft/onnxruntime/pull/20917 is available in a released version.
+        original_tril = torch.tril
+        def patched_tril(input, diagonal=0):
+            if input.dtype == torch.bool:
+                return original_tril(input.to(torch.float), diagonal=diagonal).to(torch.bool)
+            else:
+                return original_tril(input, diagonal=diagonal)
+        torch.tril = patched_tril
+
         @functools.wraps(self.orig_forward)
         def patched_forward(*args, **kwargs):
-            signature = inspect.signature(self.orig_forward)
             args, kwargs = override_arguments(args, kwargs, signature, model_kwargs=self.model_kwargs)
 
+            if index_of_pkv_parameter != -1:
+                dynamic_cache = DynamicCache.from_legacy_cache(args[index_of_pkv_parameter])
+                attention_mask = args[index_of_attention_mask_parameter]
+                input_ids = args[index_of_input_ids_parameter]
+                sequence_length = input_ids.shape[1]
+
+                dynamic_cache.get_max_length = lambda: attention_mask.shape[-1] + sequence_length + 1
+                args[index_of_pkv_parameter] = dynamic_cache
+
             outputs = self.orig_forward(*args, **kwargs)
+            if index_of_pkv_parameter != -1:
+                outputs.past_key_values = outputs.past_key_values.to_legacy_cache()
 
             # This code block handles different cases of the filterd_outputs input to align it with the expected
             # format of outputs. It is common for the output type of a model to vary, such as tensor, list,
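As a standalone illustration of the torch.tril override above, here is a minimal sketch (not part of the commit; shapes are arbitrary): boolean masks are lowered to float before tril and cast back, so the exported graph avoids the bool Trilu kernel that ONNX Runtime does not implement.

import torch

original_tril = torch.tril

def patched_tril(input, diagonal=0):
    # Bool Trilu has no ONNX Runtime kernel, so compute in float and cast back.
    if input.dtype == torch.bool:
        return original_tril(input.to(torch.float), diagonal=diagonal).to(torch.bool)
    return original_tril(input, diagonal=diagonal)

mask = torch.ones(4, 4, dtype=torch.bool)
causal_mask = patched_tril(mask)  # lower-triangular boolean mask
assert causal_mask.dtype == torch.bool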

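A minimal sketch of the cache round trip that patched_forward performs for models such as gemma2 (layer count and tensor shapes are made up for illustration; requires transformers >= 4.42, matching the version gate above):

import torch
from transformers.cache_utils import DynamicCache

# Legacy cache format: a tuple over layers of (key, value) tensors shaped
# (batch, num_heads, past_sequence_length, head_dim).
legacy_cache = tuple(
    (torch.zeros(1, 4, 5, 8), torch.zeros(1, 4, 5, 8)) for _ in range(2)
)

dynamic_cache = DynamicCache.from_legacy_cache(legacy_cache)

# Mirror the patch: pin get_max_length so the causal mask is built from the
# traced attention_mask/input_ids lengths rather than a static cache size.
attention_mask_length, sequence_length = 6, 1
dynamic_cache.get_max_length = lambda: attention_mask_length + sequence_length + 1

# ... the model's original forward would run here with dynamic_cache ...

# Convert back so the exported model still returns the legacy tuple format.
legacy_again = dynamic_cache.to_legacy_cache()
assert legacy_again[0][0].shape == legacy_cache[0][0].shape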