Do not output KV cache when not using with-past in the ONNX export (#1358)

* never output KV cache when exporting without past

* let's be honest this was dumb

* wip

* if the CI is green that is a miracle

* cleaning

* fix encoder-decoder

* fix tests

* classic merge mistake
fxmarty authored Sep 13, 2023
1 parent fe94480 commit afd2b5a
Showing 7 changed files with 56 additions and 166 deletions.
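Before the per-file diffs, a short sketch of what the fix changes in practice. Hedged: the checkpoint and task below are illustrative assumptions, not part of this commit; the behavior follows the `outputs` property as modified in `base.py` below.

```python
# Sketch assuming optimum at this commit and a GPT-2 checkpoint.
from transformers import AutoConfig
from optimum.exporters.onnx.model_configs import GPT2OnnxConfig

config = AutoConfig.from_pretrained("gpt2")

# Export without KV cache support: the ONNX no longer declares present.* outputs.
no_cache = GPT2OnnxConfig(config, task="text-generation", use_past=False)
assert not any(name.startswith("present") for name in no_cache.outputs)

# Export with KV cache: present.* outputs are declared whether or not
# past_key_values is also fed as an input (use_past_in_inputs).
with_cache = GPT2OnnxConfig(
    config, task="text-generation", use_past=True, use_past_in_inputs=False
)
assert any(name.startswith("present") for name in with_cache.outputs)
```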
94 changes: 20 additions & 74 deletions optimum/exporters/onnx/base.py
@@ -556,8 +556,7 @@ class OnnxConfigWithPast(OnnxConfig, ABC):
     """
 
     PAD_ATTENTION_MASK_TO_PAST: bool = False
-    USE_PAST_IN_INPUTS: Optional[bool] = None
-    USE_PRESENT_IN_OUTPUTS: Optional[bool] = None
+    SUPPORTS_PAST: bool = True
 
     def __init__(
         self,
@@ -566,89 +565,44 @@ def __init__(
         int_dtype: str = "int64",
         float_dtype: str = "fp32",
         use_past: bool = False,
-        use_past_in_inputs: Optional[bool] = None,
-        use_present_in_outputs: Optional[bool] = None,
+        use_past_in_inputs: bool = False,
         preprocessors: Optional[List[Any]] = None,
     ):
         self.use_past = use_past
-        if use_past_in_inputs is None:
-            use_past_in_inputs = self.USE_PAST_IN_INPUTS
-        if use_present_in_outputs is None:
-            use_present_in_outputs = self.USE_PRESENT_IN_OUTPUTS
-        self.use_past_in_inputs = use_past if use_past_in_inputs is None else use_past_in_inputs
-        self.use_present_in_outputs = use_past if use_present_in_outputs is None else use_present_in_outputs
-
-        if use_past != self.use_past_in_inputs:
-            logger.warning(
-                f"use_past = {use_past} is different than use_past_in_inputs = {use_past_in_inputs}, the value of "
-                "use_past_in_inputs will used for the inputs."
-            )
+        self.use_past_in_inputs = use_past_in_inputs
-        if use_past != self.use_present_in_outputs:
-            logger.warning(
-                f"use_past = {use_past} is different than use_present_in_outputs = {use_present_in_outputs}, the value "
-                "of use_present_in_outputs value will be used for the outputs."
-            )
+
         self.is_merged = False
         self.use_cache_branch = None
         super().__init__(
             config=config, task=task, int_dtype=int_dtype, float_dtype=float_dtype, preprocessors=preprocessors
         )

-    @classmethod
-    def with_past(
-        cls,
-        config: "PretrainedConfig",
-        task: str = "feature-extraction",
-        int_dtype: str = "int64",
-        float_dtype: str = "fp32",
-        preprocessors: Optional[List[Any]] = None,
-    ) -> "OnnxConfigWithPast":
-        """
-        Instantiates a [`~optimum.exporters.onnx.OnnxConfig`] with `use_past` attribute set to `True`.
-        Args:
-            config (`transformers.PretrainedConfig`):
-                The underlying model's config to use when exporting to ONNX.
-            task (`str`, defaults to `"feature-extraction"`):
-                The task the model should be exported for.
-            int_dtype (`str`, defaults to `"int64"`):
-                The data type of integer tensors, could be ["int64", "int32", "int8"], default to "int64".
-            float_dtype (`str`, defaults to `"fp32"`):
-                The data type of float tensors, could be ["fp32", "fp16", "bf16"], default to "fp32".
-        Returns:
-            [`~optimum.exporters.onnx.OnnxConfig`]: The onnx config with `.use_past = True`
-        """
-        return cls(
-            config, task=task, int_dtype=int_dtype, float_dtype=float_dtype, use_past=True, preprocessors=preprocessors
-        )
-
     @property
     def outputs(self) -> Dict[str, Dict[int, str]]:
-        if self.use_past is False:
+        if not self.use_past_in_inputs:
             common_outputs = super().outputs
         # In the other cases, the sequence_length axis is not dynamic, always of length 1
         elif self.task == "feature-extraction":
             common_outputs = OrderedDict({"last_hidden_state": {0: "batch_size"}})
         else:
             common_outputs = OrderedDict({"logits": {0: "batch_size"}})
-        if self.use_present_in_outputs:
+        if self.use_past:
+            # When exporting decoder models with use_cache=True, both the decoder without past and with past have the KV cache as an output.
             self.add_past_key_values(common_outputs, direction="outputs")
         return common_outputs
 
     @property
     def values_override(self) -> Optional[Dict[str, Any]]:
         if hasattr(self._config, "use_cache"):
-            return {"use_cache": self.use_past_in_inputs or self.use_present_in_outputs}
+            return {"use_cache": self.use_past}
 
     @add_dynamic_docstring(text=GENERATE_DUMMY_DOCSTRING, dynamic_elements=DEFAULT_DUMMY_SHAPES)
     def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
         dummy_inputs_generators = self._create_dummy_input_generator_classes(**kwargs)
 
         dummy_inputs = {}
         input_names = [key for key in self.inputs.keys() if not key.startswith("past_key_values")]
-        if self.use_past:
+        if self.use_past_in_inputs and self.use_cache_branch is not False:
             input_names.append("past_key_values")
 
         for input_name in input_names:
@@ -707,16 +661,13 @@ def overwrite_shape_and_generate_input(
         # models from TextSeq2SeqOnnxConfig use decoder_input_ids as input name
         # while models from TextDecoderOnnxConfig use input_ids, hence the check for both
         if (
-            self.use_past is True
+            self.use_past
             and self.use_past_in_inputs
             and self.use_cache_branch is not False
             and input_name in ["decoder_input_ids", "input_ids"]
         ):
             sequence_length = dummy_input_gen.sequence_length
-            if "sequence_length" in input_shapes and input_shapes["sequence_length"] != 1:
-                logger.info(
-                    f"Asked a sequence length of {input_shapes['sequence_length']}, but a sequence length of 1 "
-                    f"will be used with use_past == True for `{input_name}`."
-                )
+            # Use a sequence length of 1 when the KV cache is already populated.
             dummy_input_gen.sequence_length = 1
             dummy_input = dummy_input_gen.generate(
                 input_name, framework=framework, int_dtype=self.int_dtype, float_dtype=self.float_dtype
@@ -816,8 +767,7 @@ def __init__(
         int_dtype: str = "int64",
         float_dtype: str = "fp32",
         use_past: bool = False,
-        use_past_in_inputs: Optional[bool] = None,
-        use_present_in_outputs: Optional[bool] = None,
+        use_past_in_inputs: bool = False,
         behavior: ConfigBehavior = ConfigBehavior.MONOLITH,
         preprocessors: Optional[List[Any]] = None,
     ):
@@ -828,26 +778,19 @@
             float_dtype=float_dtype,
             use_past=use_past,
             use_past_in_inputs=use_past_in_inputs,
-            use_present_in_outputs=use_present_in_outputs,
             preprocessors=preprocessors,
         )
         self._behavior = behavior
-        self.override_attributes_for_behavior()
 
-    def override_attributes_for_behavior(self):
-        """Override this to specify custom attribute change for a given behavior."""
         if self._behavior is ConfigBehavior.ENCODER:
             self.task = "feature-extraction"
             self.use_past_in_inputs = False
-            self.use_present_in_outputs = False
-        if self._behavior is ConfigBehavior.DECODER:
-            self.use_past_in_inputs = self.use_past
-            self.use_present_in_outputs = True
 
     def with_behavior(
         self,
         behavior: Union[str, ConfigBehavior],
         use_past: bool = False,
+        use_past_in_inputs: bool = False,
     ) -> "OnnxSeq2SeqConfigWithPast":
         """
         Creates a copy of the current OnnxConfig but with a different `ConfigBehavior` and `use_past` value.
@@ -856,7 +799,9 @@ def with_behavior(
             behavior ([`ConfigBehavior`]):
                 The behavior to use for the new instance.
             use_past (`bool`, defaults to `False`):
-                Whether or not the new instance should use past.
+                Whether or not the ONNX config to instantiate is for a model using KV cache.
+            use_past_in_inputs (`bool`, defaults to `False`):
+                Whether the KV cache is to be passed as an input to the ONNX.
         Returns:
             `OnnxSeq2SeqConfigWithPast`
@@ -869,6 +814,7 @@
             int_dtype=self.int_dtype,
             float_dtype=self.float_dtype,
             use_past=use_past,
+            use_past_in_inputs=use_past_in_inputs,
             behavior=behavior,
             preprocessors=self._preprocessors,
         )
@@ -895,7 +841,8 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
                 new_axes_names[axis_idx] = axis_name
             common_outputs[name] = new_axes_names
 
-        if self.use_present_in_outputs:
+        if self.use_past:
+            # When exporting decoder models with use_cache=True, both the decoder without past and with past have the KV cache as an output.
             self.add_past_key_values(common_outputs, direction="outputs")
 
         return common_outputs
@@ -917,7 +864,7 @@ def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str):
 
         if (
             self.is_merged is True
-            or (self._behavior is ConfigBehavior.DECODER and self.use_past is False)
+            or (self._behavior is ConfigBehavior.DECODER and not self.use_past_in_inputs)
             or direction == "inputs"
         ):
             # TODO: we only need to call it encoder_sequence_length_out in the merge case - but at torch.onnx.export()
@@ -982,7 +929,6 @@ def post_process_exported_models(
         models_and_onnx_configs[ONNX_DECODER_NAME][1].use_cache_branch = False
 
         # Past key values won't be generated by default, but added in the input
-        models_and_onnx_configs[ONNX_DECODER_NAME][1].use_past = False
         models_and_onnx_configs[ONNX_DECODER_NAME][1].use_past_in_inputs = True
 
        models_and_onnx_configs[ONNX_DECODER_WITH_PAST_NAME][1].use_cache_branch = True
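The reworked `with_behavior` above now separates producing the KV cache (`use_past`) from consuming it (`use_past_in_inputs`). A hedged usage sketch for a two-file seq2seq export; the T5 checkpoint and task are assumptions for illustration:

```python
from transformers import AutoConfig
from optimum.exporters.onnx.model_configs import T5OnnxConfig

config = AutoConfig.from_pretrained("t5-small")
onnx_config = T5OnnxConfig(config, task="text2text-generation", use_past=True)

# Decoder without past: still declares present.* outputs (use_past=True),
# but takes no past_key_values input.
decoder = onnx_config.with_behavior("decoder", use_past=True, use_past_in_inputs=False)

# Decoder with past: both consumes and produces the KV cache.
decoder_with_past = onnx_config.with_behavior("decoder", use_past=True, use_past_in_inputs=True)
```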
10 changes: 3 additions & 7 deletions optimum/exporters/onnx/config.py
@@ -124,7 +124,6 @@ def post_process_exported_models(
         models_and_onnx_configs[ONNX_DECODER_NAME][1].use_cache_branch = False
 
         # Past key values won't be generated by default, but added in the input
-        models_and_onnx_configs[ONNX_DECODER_NAME][1].use_past = False
         models_and_onnx_configs[ONNX_DECODER_NAME][1].use_past_in_inputs = True
 
         models_and_onnx_configs[ONNX_DECODER_WITH_PAST_NAME][1].use_cache_branch = True
@@ -277,8 +276,7 @@ def __init__(
         int_dtype: str = "int64",
         float_dtype: str = "fp32",
         use_past: bool = False,
-        use_past_in_inputs: Optional[bool] = None,
-        use_present_in_outputs: Optional[bool] = None,
+        use_past_in_inputs: bool = False,
         behavior: ConfigBehavior = ConfigBehavior.MONOLITH,
         preprocessors: Optional[List[Any]] = None,
     ):
@@ -289,7 +287,6 @@
             float_dtype=float_dtype,
             use_past=use_past,
             use_past_in_inputs=use_past_in_inputs,
-            use_present_in_outputs=use_present_in_outputs,
             behavior=behavior,
             preprocessors=preprocessors,
         )
@@ -316,7 +313,7 @@ def __init__(
             self.is_decoder_with_past = True
             kwargs["use_past"] = use_past
         else:
-            self.use_present_in_outputs = False
+            self.use_past = False
 
         if use_past and not self.is_decoder_with_past:
             raise ValueError(
@@ -329,7 +326,7 @@
             )
         if issubclass(decoder_onnx_config_constructor.func, OnnxSeq2SeqConfigWithPast):
             self._decoder_onnx_config = self._decoder_onnx_config.with_behavior(
-                self._behavior, use_past=kwargs["use_past"]
+                self._behavior, use_past=kwargs["use_past"], use_past_in_inputs=use_past_in_inputs
             )
 
         self._normalized_config.DECODER_NORMALIZED_CONFIG_CLASS = self._decoder_onnx_config._normalized_config
@@ -429,7 +426,6 @@ def post_process_exported_models(
         models_and_onnx_configs[ONNX_DECODER_NAME][1]._decoder_onnx_config.use_cache_branch = False
 
         # Past key values won't be generated by default, but added in the input
-        models_and_onnx_configs[ONNX_DECODER_NAME][1]._decoder_onnx_config.use_past = False
         models_and_onnx_configs[ONNX_DECODER_NAME][1]._decoder_onnx_config.use_past_in_inputs = True
 
         models_and_onnx_configs[ONNX_DECODER_WITH_PAST_NAME][1]._decoder_onnx_config.use_cache_branch = True
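The post-processing comment above ("Past key values won't be generated by default, but added in the input") is terse; here is a short sketch of the bookkeeping it performs once the two decoders have been merged. Names are taken from this diff, and the surrounding merge code is elided:

```python
# After merging decoder and decoder_with_past into a single ONNX, the plain
# decoder config is reused to describe the merged model. Only the input side
# changes: use_past stays True, so present.* remain declared as outputs.
decoder_config = models_and_onnx_configs[ONNX_DECODER_NAME][1]._decoder_onnx_config
decoder_config.use_cache_branch = False   # dummy inputs exercise the no-cache branch
decoder_config.use_past_in_inputs = True  # past_key_values becomes a declared input
```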
29 changes: 12 additions & 17 deletions optimum/exporters/onnx/model_configs.py
@@ -490,7 +490,8 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
             common_outputs = super().outputs
         else:
             common_outputs = super(OnnxConfigWithPast, self).outputs
-        if self.use_present_in_outputs:
+        if self.use_past:
+            # When exporting decoder models with use_cache=True, both the decoder without past and with past have the KV cache as an output.
             for i in range(self._normalized_config.encoder_num_layers):
                 common_outputs[f"present.{i}.key"] = {0: "batch_size", 2: "past_sequence_length + sequence_length"}
                 common_outputs[f"present.{i}.value"] = {
@@ -1223,8 +1224,7 @@ def __init__(
         int_dtype: str = "int64",
         float_dtype: str = "fp32",
         use_past: bool = False,
-        use_past_in_inputs: Optional[bool] = None,
-        use_present_in_outputs: Optional[bool] = None,
+        use_past_in_inputs: bool = False,
         behavior: ConfigBehavior = ConfigBehavior.MONOLITH,
         preprocessors: Optional[List[Any]] = None,
     ):
@@ -1235,17 +1235,14 @@
             float_dtype=float_dtype,
             use_past=use_past,
             use_past_in_inputs=use_past_in_inputs,
-            use_present_in_outputs=use_present_in_outputs,
             behavior=behavior,
             preprocessors=preprocessors,
         )
 
         # TODO: Check modeling code to fix the issue with use_cache for trocr
-        if config.decoder.model_type == "trocr":
-            if self.use_past_in_inputs:
-                raise ValueError("Exporting past key values is not supported with TrOCR model!")
-
-            self.use_present_in_outputs = False
+        if config.decoder.model_type == "trocr" and use_past:
+            raise ValueError(
+                "Exporting TrOCR to ONNX with past key values is not supported with TrOCR model. Please open an issue in Optimum repository."
+            )
 
     @property
     def inputs(self) -> Dict[str, Dict[int, str]]:
@@ -1405,7 +1402,8 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
                 new_axes_names[axis_idx] = axis_name
             common_outputs[name] = new_axes_names
 
-        if self.use_present_in_outputs:
+        if self.use_past:
+            # When exporting decoder models with use_cache=True, both the decoder without past and with past have the KV cache as an output.
             self.add_past_key_values(common_outputs, direction="outputs")
 
         return common_outputs
@@ -1471,16 +1469,13 @@ def overwrite_shape_and_generate_input(
         # models from TextSeq2SeqOnnxConfig use decoder_input_ids as input name
         # while models from TextDecoderOnnxConfig use input_ids, hence the check for both
         if (
-            self.use_past is True
+            self.use_past
             and self.use_past_in_inputs
             and self.use_cache_branch is not False
             and input_name in ["decoder_input_ids", "input_ids"]
         ):
             sequence_length = dummy_input_gen.sequence_length
-            if "sequence_length" in input_shapes and input_shapes["sequence_length"] != 1:
-                logger.info(
-                    f"Asked a sequence length of {input_shapes['sequence_length']}, but a sequence length of 1 "
-                    f"will be used with use_past == True for `{input_name}`."
-                )
+            # Use a sequence length of 1 when the KV cache is already populated.
             dummy_input_gen.sequence_length = 1
             dummy_input = dummy_input_gen.generate(
                 input_name, framework=framework, int_dtype=self.int_dtype, float_dtype=self.float_dtype
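The new TrOCR guard above can be exercised directly. A hedged example (the checkpoint name and task string are assumptions for illustration; construction with `use_past=True` should now fail fast):

```python
from transformers import AutoConfig
from optimum.exporters.onnx.model_configs import VisionEncoderDecoderOnnxConfig

config = AutoConfig.from_pretrained("microsoft/trocr-base-handwritten")
try:
    VisionEncoderDecoderOnnxConfig(config, task="image-to-text", use_past=True)
except ValueError as err:
    print(err)  # exporting TrOCR with past key values is rejected early
```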
21 changes: 9 additions & 12 deletions optimum/exporters/onnx/model_patcher.py
@@ -101,9 +101,8 @@ def __init__(
             self.real_config = config._onnx_config
         else:
             self.real_config = config
-        allow_past_in_outputs = (
-            hasattr(self.real_config, "use_present_in_outputs") and self.real_config.use_present_in_outputs
-        )
+
+        allow_past_in_outputs = hasattr(self.real_config, "use_past") and self.real_config.use_past
 
         @functools.wraps(self.orig_forward)
         def patched_forward(*args, **kwargs):
@@ -180,9 +179,7 @@ def __init__(
     ):
         super().__init__(config, model, model_kwargs)
 
-        allow_past_in_outputs = (
-            hasattr(self.real_config, "use_present_in_outputs") and self.real_config.use_present_in_outputs
-        )
+        allow_past_in_outputs = hasattr(self.real_config, "use_past") and self.real_config.use_past
 
         # use_cache is by default set to False with pix2struct, so we need to set it to
         # True to export with past key value
@@ -196,7 +193,7 @@ def patched_forward(*args, **kwargs):
 
             outputs = self.orig_forward(*args, **kwargs)
 
-            # Filter out cross attention past key values
+            # Filter out cross attention past key values output from the decoder using KV cache, as they are constants.
             filterd_outputs = {}
             for name, value in outputs.items():
                 onnx_output_name = config.torch_to_onnx_output_map.get(name, name)
@@ -213,10 +210,12 @@
                         filterd_outputs[name] = value
                     else:
                         if self.real_config._behavior == "monolith" or (
-                            self.real_config._behavior == "decoder" and self.real_config.use_past is False
+                            self.real_config._behavior == "decoder"
+                            and (self.real_config.is_merged or not self.real_config.use_past_in_inputs)
                         ):
                             filterd_outputs[name] = value
-                        elif self.real_config._behavior == "decoder" and self.real_config.use_past is True:
+                        elif self.real_config._behavior == "decoder" and self.real_config.use_past_in_inputs:
+                            # The filtering happens here. The decoder with use_past_in_inputs=True corresponds to the autoregressive one.
                             filterd_outputs[name] = tuple([v[:2] for v in value])
 
             return filterd_outputs
@@ -233,9 +232,7 @@ def __init__(
     ):
         super().__init__(config, model, model_kwargs)
 
-        allow_past_in_outputs = (
-            hasattr(self.real_config, "use_present_in_outputs") and self.real_config.use_present_in_outputs
-        )
+        allow_past_in_outputs = hasattr(self.real_config, "use_past") and self.real_config.use_past
 
         @functools.wraps(self.orig_forward)
         def patched_forward(*args, **kwargs):
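To summarize the patcher change, a standalone sketch of the filtering rule, under the assumption (standard for transformers seq2seq models) that each `past_key_values` entry is a 4-tuple whose first two tensors are self-attention key/value and whose last two are cross-attention key/value:

```python
from typing import Tuple

def filter_decoder_past(value: Tuple, behavior: str, is_merged: bool, use_past_in_inputs: bool):
    """Mirror of the patched filtering above (a sketch, not the actual patcher)."""
    if behavior == "monolith" or (behavior == "decoder" and (is_merged or not use_past_in_inputs)):
        # First decoder pass: the cross-attention KV must be produced once, keep all four tensors.
        return value
    if behavior == "decoder" and use_past_in_inputs:
        # Autoregressive pass: the cross-attention KV never changes, drop it from the outputs.
        return tuple(v[:2] for v in value)
```

Cross-attention keys and values are computed from the encoder output, which is fixed during generation, so re-emitting them at every step would only duplicate constants; that is why only `v[:2]` is kept on the autoregressive path.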