From afd2b5a36663bebd1f501486acee065c728947bc Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Wed, 13 Sep 2023 20:43:35 +0900 Subject: [PATCH] Do not output KV cache when not using `with-past` in the ONNX export (#1358) * never output KV cache when exporting without past * let's be honest this was dumb * wip * if the CI is green that is a miracle * cleaning * fix encoder-decoder * fix tests * classic merge mistake --- optimum/exporters/onnx/base.py | 94 +++++------------------- optimum/exporters/onnx/config.py | 10 +-- optimum/exporters/onnx/model_configs.py | 29 +++----- optimum/exporters/onnx/model_patcher.py | 21 +++--- optimum/exporters/onnx/utils.py | 13 +++- optimum/exporters/tasks.py | 4 +- tests/exporters/onnx/test_onnx_export.py | 51 +------------ 7 files changed, 56 insertions(+), 166 deletions(-) diff --git a/optimum/exporters/onnx/base.py b/optimum/exporters/onnx/base.py index ee6f76e35eb..05f031322f7 100644 --- a/optimum/exporters/onnx/base.py +++ b/optimum/exporters/onnx/base.py @@ -556,8 +556,7 @@ class OnnxConfigWithPast(OnnxConfig, ABC): """ PAD_ATTENTION_MASK_TO_PAST: bool = False - USE_PAST_IN_INPUTS: Optional[bool] = None - USE_PRESENT_IN_OUTPUTS: Optional[bool] = None + SUPPORTS_PAST: bool = True def __init__( self, @@ -566,81 +565,36 @@ def __init__( int_dtype: str = "int64", float_dtype: str = "fp32", use_past: bool = False, - use_past_in_inputs: Optional[bool] = None, - use_present_in_outputs: Optional[bool] = None, + use_past_in_inputs: bool = False, preprocessors: Optional[List[Any]] = None, ): self.use_past = use_past - if use_past_in_inputs is None: - use_past_in_inputs = self.USE_PAST_IN_INPUTS - if use_present_in_outputs is None: - use_present_in_outputs = self.USE_PRESENT_IN_OUTPUTS - self.use_past_in_inputs = use_past if use_past_in_inputs is None else use_past_in_inputs - self.use_present_in_outputs = use_past if use_present_in_outputs is None else use_present_in_outputs - - if use_past != self.use_past_in_inputs: - logger.warning( - f"use_past = {use_past} is different than use_past_in_inputs = {use_past_in_inputs}, the value of " - "use_past_in_inputs will used for the inputs." - ) + self.use_past_in_inputs = use_past_in_inputs - if use_past != self.use_present_in_outputs: - logger.warning( - f"use_past = {use_past} is different than use_present_in_outputs = {use_present_in_outputs}, the value " - "of use_present_in_outputs value will be used for the outputs." - ) self.is_merged = False self.use_cache_branch = None super().__init__( config=config, task=task, int_dtype=int_dtype, float_dtype=float_dtype, preprocessors=preprocessors ) - @classmethod - def with_past( - cls, - config: "PretrainedConfig", - task: str = "feature-extraction", - int_dtype: str = "int64", - float_dtype: str = "fp32", - preprocessors: Optional[List[Any]] = None, - ) -> "OnnxConfigWithPast": - """ - Instantiates a [`~optimum.exporters.onnx.OnnxConfig`] with `use_past` attribute set to `True`. - - Args: - config (`transformers.PretrainedConfig`): - The underlying model's config to use when exporting to ONNX. - task (`str`, defaults to `"feature-extraction"`): - The task the model should be exported for. - int_dtype (`str`, defaults to `"int64"`): - The data type of integer tensors, could be ["int64", "int32", "int8"], default to "int64". - float_dtype (`str`, defaults to `"fp32"`): - The data type of float tensors, could be ["fp32", "fp16", "bf16"], default to "fp32". 
- - Returns: - [`~optimum.exporters.onnx.OnnxConfig`]: The onnx config with `.use_past = True` - """ - return cls( - config, task=task, int_dtype=int_dtype, float_dtype=float_dtype, use_past=True, preprocessors=preprocessors - ) - @property def outputs(self) -> Dict[str, Dict[int, str]]: - if self.use_past is False: + if not self.use_past_in_inputs: common_outputs = super().outputs # In the other cases, the sequence_length axis is not dynamic, always of length 1 elif self.task == "feature-extraction": common_outputs = OrderedDict({"last_hidden_state": {0: "batch_size"}}) else: common_outputs = OrderedDict({"logits": {0: "batch_size"}}) - if self.use_present_in_outputs: + if self.use_past: + # When exporting decoder models with use_cache=True, both the decoder without past and with past have the KV cache as an output. self.add_past_key_values(common_outputs, direction="outputs") return common_outputs @property def values_override(self) -> Optional[Dict[str, Any]]: if hasattr(self._config, "use_cache"): - return {"use_cache": self.use_past_in_inputs or self.use_present_in_outputs} + return {"use_cache": self.use_past} @add_dynamic_docstring(text=GENERATE_DUMMY_DOCSTRING, dynamic_elements=DEFAULT_DUMMY_SHAPES) def generate_dummy_inputs(self, framework: str = "pt", **kwargs): @@ -648,7 +602,7 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs): dummy_inputs = {} input_names = [key for key in self.inputs.keys() if not key.startswith("past_key_values")] - if self.use_past: + if self.use_past_in_inputs and self.use_cache_branch is not False: input_names.append("past_key_values") for input_name in input_names: @@ -707,16 +661,13 @@ def overwrite_shape_and_generate_input( # models from TextSeq2SeqOnnxConfig use decoder_input_ids as input name # while models from TextDecoderOnnxConfig use input_ids, hence the check for both if ( - self.use_past is True + self.use_past + and self.use_past_in_inputs and self.use_cache_branch is not False and input_name in ["decoder_input_ids", "input_ids"] ): sequence_length = dummy_input_gen.sequence_length - if "sequence_length" in input_shapes and input_shapes["sequence_length"] != 1: - logger.info( - f"Asked a sequence length of {input_shapes['sequence_length']}, but a sequence length of 1 " - f"will be used with use_past == True for `{input_name}`." - ) + # Use a sequence length of 1 when the KV cache is already populated. 
dummy_input_gen.sequence_length = 1
         dummy_input = dummy_input_gen.generate(
             input_name, framework=framework, int_dtype=self.int_dtype, float_dtype=self.float_dtype
@@ -816,8 +767,7 @@ def __init__(
         int_dtype: str = "int64",
         float_dtype: str = "fp32",
         use_past: bool = False,
-        use_past_in_inputs: Optional[bool] = None,
-        use_present_in_outputs: Optional[bool] = None,
+        use_past_in_inputs: bool = False,
         behavior: ConfigBehavior = ConfigBehavior.MONOLITH,
         preprocessors: Optional[List[Any]] = None,
     ):
@@ -828,26 +778,19 @@
             float_dtype=float_dtype,
             use_past=use_past,
             use_past_in_inputs=use_past_in_inputs,
-            use_present_in_outputs=use_present_in_outputs,
             preprocessors=preprocessors,
         )
         self._behavior = behavior
-        self.override_attributes_for_behavior()

-    def override_attributes_for_behavior(self):
-        """Override this to specify custom attribute change for a given behavior."""
         if self._behavior is ConfigBehavior.ENCODER:
             self.task = "feature-extraction"
             self.use_past_in_inputs = False
-            self.use_present_in_outputs = False
-        if self._behavior is ConfigBehavior.DECODER:
-            self.use_past_in_inputs = self.use_past
-            self.use_present_in_outputs = True

     def with_behavior(
         self,
         behavior: Union[str, ConfigBehavior],
         use_past: bool = False,
+        use_past_in_inputs: bool = False,
     ) -> "OnnxSeq2SeqConfigWithPast":
         """
         Creates a copy of the current OnnxConfig but with a different `ConfigBehavior` and `use_past` value.
@@ -856,7 +799,9 @@ def with_behavior(
             behavior ([`ConfigBehavior`]):
                 The behavior to use for the new instance.
             use_past (`bool`, defaults to `False`):
-                Whether or not the new instance should use past.
+                Whether or not the ONNX config to instantiate is for a model using KV cache.
+            use_past_in_inputs (`bool`, defaults to `False`):
+                Whether the KV cache is to be passed as an input to the ONNX model.

         Returns:
             `OnnxSeq2SeqConfigWithPast`
@@ -869,6 +814,7 @@
             int_dtype=self.int_dtype,
             float_dtype=self.float_dtype,
             use_past=use_past,
+            use_past_in_inputs=use_past_in_inputs,
             behavior=behavior,
             preprocessors=self._preprocessors,
         )
@@ -895,7 +841,8 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
                 new_axes_names[axis_idx] = axis_name
             common_outputs[name] = new_axes_names

-        if self.use_present_in_outputs:
+        if self.use_past:
+            # When exporting decoder models with use_cache=True, both the decoder without past and with past have the KV cache as an output.
self.add_past_key_values(common_outputs, direction="outputs") return common_outputs @@ -917,7 +864,7 @@ def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], dire if ( self.is_merged is True - or (self._behavior is ConfigBehavior.DECODER and self.use_past is False) + or (self._behavior is ConfigBehavior.DECODER and not self.use_past_in_inputs) or direction == "inputs" ): # TODO: we only need to call it encoder_sequence_length_out in the merge case - but at torch.onnx.export() @@ -982,7 +929,6 @@ def post_process_exported_models( models_and_onnx_configs[ONNX_DECODER_NAME][1].use_cache_branch = False # Past key values won't be generated by default, but added in the input - models_and_onnx_configs[ONNX_DECODER_NAME][1].use_past = False models_and_onnx_configs[ONNX_DECODER_NAME][1].use_past_in_inputs = True models_and_onnx_configs[ONNX_DECODER_WITH_PAST_NAME][1].use_cache_branch = True diff --git a/optimum/exporters/onnx/config.py b/optimum/exporters/onnx/config.py index dd043aa3f55..537e8846fe8 100644 --- a/optimum/exporters/onnx/config.py +++ b/optimum/exporters/onnx/config.py @@ -124,7 +124,6 @@ def post_process_exported_models( models_and_onnx_configs[ONNX_DECODER_NAME][1].use_cache_branch = False # Past key values won't be generated by default, but added in the input - models_and_onnx_configs[ONNX_DECODER_NAME][1].use_past = False models_and_onnx_configs[ONNX_DECODER_NAME][1].use_past_in_inputs = True models_and_onnx_configs[ONNX_DECODER_WITH_PAST_NAME][1].use_cache_branch = True @@ -277,8 +276,7 @@ def __init__( int_dtype: str = "int64", float_dtype: str = "fp32", use_past: bool = False, - use_past_in_inputs: Optional[bool] = None, - use_present_in_outputs: Optional[bool] = None, + use_past_in_inputs: bool = False, behavior: ConfigBehavior = ConfigBehavior.MONOLITH, preprocessors: Optional[List[Any]] = None, ): @@ -289,7 +287,6 @@ def __init__( float_dtype=float_dtype, use_past=use_past, use_past_in_inputs=use_past_in_inputs, - use_present_in_outputs=use_present_in_outputs, behavior=behavior, preprocessors=preprocessors, ) @@ -316,7 +313,7 @@ def __init__( self.is_decoder_with_past = True kwargs["use_past"] = use_past else: - self.use_present_in_outputs = False + self.use_past = False if use_past and not self.is_decoder_with_past: raise ValueError( @@ -329,7 +326,7 @@ def __init__( ) if issubclass(decoder_onnx_config_constructor.func, OnnxSeq2SeqConfigWithPast): self._decoder_onnx_config = self._decoder_onnx_config.with_behavior( - self._behavior, use_past=kwargs["use_past"] + self._behavior, use_past=kwargs["use_past"], use_past_in_inputs=use_past_in_inputs ) self._normalized_config.DECODER_NORMALIZED_CONFIG_CLASS = self._decoder_onnx_config._normalized_config @@ -429,7 +426,6 @@ def post_process_exported_models( models_and_onnx_configs[ONNX_DECODER_NAME][1]._decoder_onnx_config.use_cache_branch = False # Past key values won't be generated by default, but added in the input - models_and_onnx_configs[ONNX_DECODER_NAME][1]._decoder_onnx_config.use_past = False models_and_onnx_configs[ONNX_DECODER_NAME][1]._decoder_onnx_config.use_past_in_inputs = True models_and_onnx_configs[ONNX_DECODER_WITH_PAST_NAME][1]._decoder_onnx_config.use_cache_branch = True diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index 2672ed45cfe..e9432dd58be 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -490,7 +490,8 @@ def outputs(self) -> Dict[str, Dict[int, str]]: common_outputs = 
super().outputs
         else:
             common_outputs = super(OnnxConfigWithPast, self).outputs
-        if self.use_present_in_outputs:
+        if self.use_past:
+            # When exporting decoder models with use_cache=True, both the decoder without past and with past have the KV cache as an output.
             for i in range(self._normalized_config.encoder_num_layers):
                 common_outputs[f"present.{i}.key"] = {0: "batch_size", 2: "past_sequence_length + sequence_length"}
                 common_outputs[f"present.{i}.value"] = {
@@ -1223,8 +1224,7 @@ def __init__(
         int_dtype: str = "int64",
         float_dtype: str = "fp32",
         use_past: bool = False,
-        use_past_in_inputs: Optional[bool] = None,
-        use_present_in_outputs: Optional[bool] = None,
+        use_past_in_inputs: bool = False,
         behavior: ConfigBehavior = ConfigBehavior.MONOLITH,
         preprocessors: Optional[List[Any]] = None,
     ):
@@ -1235,17 +1235,14 @@
             float_dtype=float_dtype,
             use_past=use_past,
             use_past_in_inputs=use_past_in_inputs,
-            use_present_in_outputs=use_present_in_outputs,
             behavior=behavior,
             preprocessors=preprocessors,
         )

-        # TODO: Check modeling code to fix the issue with use_cache for trocr
-        if config.decoder.model_type == "trocr":
-            if self.use_past_in_inputs:
-                raise ValueError("Exporting past key values is not supported with TrOCR model!")
-
-            self.use_present_in_outputs = False
+        if config.decoder.model_type == "trocr" and use_past:
+            raise ValueError(
+                "Exporting TrOCR to ONNX with past key values is not supported. Please open an issue in the Optimum repository."
+            )

     @property
     def inputs(self) -> Dict[str, Dict[int, str]]:
@@ -1405,7 +1402,8 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
                 new_axes_names[axis_idx] = axis_name
             common_outputs[name] = new_axes_names

-        if self.use_present_in_outputs:
+        if self.use_past:
+            # When exporting decoder models with use_cache=True, both the decoder without past and with past have the KV cache as an output.
             self.add_past_key_values(common_outputs, direction="outputs")

         return common_outputs
@@ -1471,16 +1469,13 @@ def overwrite_shape_and_generate_input(
         # models from TextSeq2SeqOnnxConfig use decoder_input_ids as input name
         # while models from TextDecoderOnnxConfig use input_ids, hence the check for both
         if (
-            self.use_past is True
+            self.use_past
+            and self.use_past_in_inputs
             and self.use_cache_branch is not False
             and input_name in ["decoder_input_ids", "input_ids"]
         ):
             sequence_length = dummy_input_gen.sequence_length
-            if "sequence_length" in input_shapes and input_shapes["sequence_length"] != 1:
-                logger.info(
-                    f"Asked a sequence length of {input_shapes['sequence_length']}, but a sequence length of 1 "
-                    f"will be used with use_past == True for `{input_name}`."
-                )
+            # Use a sequence length of 1 when the KV cache is already populated.
dummy_input_gen.sequence_length = 1 dummy_input = dummy_input_gen.generate( input_name, framework=framework, int_dtype=self.int_dtype, float_dtype=self.float_dtype diff --git a/optimum/exporters/onnx/model_patcher.py b/optimum/exporters/onnx/model_patcher.py index d69d8d6f3a5..e6b50b6dc08 100644 --- a/optimum/exporters/onnx/model_patcher.py +++ b/optimum/exporters/onnx/model_patcher.py @@ -101,9 +101,8 @@ def __init__( self.real_config = config._onnx_config else: self.real_config = config - allow_past_in_outputs = ( - hasattr(self.real_config, "use_present_in_outputs") and self.real_config.use_present_in_outputs - ) + + allow_past_in_outputs = hasattr(self.real_config, "use_past") and self.real_config.use_past @functools.wraps(self.orig_forward) def patched_forward(*args, **kwargs): @@ -180,9 +179,7 @@ def __init__( ): super().__init__(config, model, model_kwargs) - allow_past_in_outputs = ( - hasattr(self.real_config, "use_present_in_outputs") and self.real_config.use_present_in_outputs - ) + allow_past_in_outputs = hasattr(self.real_config, "use_past") and self.real_config.use_past # use_cache is by default set to False with pix2struct, so we need to set it to # True to export with past key value @@ -196,7 +193,7 @@ def patched_forward(*args, **kwargs): outputs = self.orig_forward(*args, **kwargs) - # Filter out cross attention past key values + # Filter out cross attention past key values output from the decoder using KV cache, as they are constants. filterd_outputs = {} for name, value in outputs.items(): onnx_output_name = config.torch_to_onnx_output_map.get(name, name) @@ -213,10 +210,12 @@ def patched_forward(*args, **kwargs): filterd_outputs[name] = value else: if self.real_config._behavior == "monolith" or ( - self.real_config._behavior == "decoder" and self.real_config.use_past is False + self.real_config._behavior == "decoder" + and (self.real_config.is_merged or not self.real_config.use_past_in_inputs) ): filterd_outputs[name] = value - elif self.real_config._behavior == "decoder" and self.real_config.use_past is True: + elif self.real_config._behavior == "decoder" and self.real_config.use_past_in_inputs: + # The filtering happens here. The decoder with use_past_in_inputs=True corresponds to the autoregressive one. 
filterd_outputs[name] = tuple([v[:2] for v in value]) return filterd_outputs @@ -233,9 +232,7 @@ def __init__( ): super().__init__(config, model, model_kwargs) - allow_past_in_outputs = ( - hasattr(self.real_config, "use_present_in_outputs") and self.real_config.use_present_in_outputs - ) + allow_past_in_outputs = hasattr(self.real_config, "use_past") and self.real_config.use_past @functools.wraps(self.orig_forward) def patched_forward(*args, **kwargs): diff --git a/optimum/exporters/onnx/utils.py b/optimum/exporters/onnx/utils.py index 82c99341184..3170ebdcdd2 100644 --- a/optimum/exporters/onnx/utils.py +++ b/optimum/exporters/onnx/utils.py @@ -198,11 +198,11 @@ def get_encoder_decoder_models_for_export( encoder_onnx_config = config.with_behavior("encoder") models_for_export[ONNX_ENCODER_NAME] = (models_for_export[ONNX_ENCODER_NAME], encoder_onnx_config) - decoder_onnx_config = config.with_behavior("decoder", use_past=False) + decoder_onnx_config = config.with_behavior("decoder", use_past=config.use_past, use_past_in_inputs=False) models_for_export[ONNX_DECODER_NAME] = (models_for_export[ONNX_DECODER_NAME], decoder_onnx_config) if config.use_past: - decoder_onnx_config_with_past = config.with_behavior("decoder", use_past=True) + decoder_onnx_config_with_past = config.with_behavior("decoder", use_past=True, use_past_in_inputs=True) models_for_export[ONNX_DECODER_WITH_PAST_NAME] = ( models_for_export[ONNX_DECODER_WITH_PAST_NAME], decoder_onnx_config_with_past, @@ -238,8 +238,8 @@ def get_decoder_models_for_export( onnx_config = config.__class__( model.config, task=config.task, + use_past=config.use_past, use_past_in_inputs=False, - use_present_in_outputs=True, float_dtype=config.float_dtype, int_dtype=config.int_dtype, ) @@ -247,7 +247,12 @@ def get_decoder_models_for_export( if config.use_past: onnx_config_with_past = config.__class__( - model.config, task=config.task, use_past=True, float_dtype=config.float_dtype, int_dtype=config.int_dtype + model.config, + task=config.task, + use_past=True, + use_past_in_inputs=True, + float_dtype=config.float_dtype, + int_dtype=config.int_dtype, ) models_for_export[ONNX_DECODER_WITH_PAST_NAME] = ( models_for_export[ONNX_DECODER_WITH_PAST_NAME], diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py index a31fa6272f6..5882972d758 100644 --- a/optimum/exporters/tasks.py +++ b/optimum/exporters/tasks.py @@ -65,9 +65,9 @@ def is_backend_available(backend): def make_backend_config_constructor_for_task(config_cls: Type, task: str) -> ExportConfigConstructor: if "-with-past" in task: - if not hasattr(config_cls, "with_past"): + if not getattr(config_cls, "SUPPORTS_PAST", False): raise ValueError(f"{config_cls} does not support tasks with past.") - constructor = partial(config_cls.with_past, task=task.replace("-with-past", "")) + constructor = partial(config_cls, use_past=True, task=task.replace("-with-past", "")) else: constructor = partial(config_cls, task=task) return constructor diff --git a/tests/exporters/onnx/test_onnx_export.py b/tests/exporters/onnx/test_onnx_export.py index 910036c8554..10eaeddd13c 100644 --- a/tests/exporters/onnx/test_onnx_export.py +++ b/tests/exporters/onnx/test_onnx_export.py @@ -18,7 +18,6 @@ from tempfile import TemporaryDirectory from typing import Dict from unittest import TestCase -from unittest.mock import patch import onnx import pytest @@ -92,54 +91,6 @@ class OnnxConfigTestCase(TestCase): # TODO: insert relevant tests here. 
-class OnnxConfigWithPastTestCase(TestCase): - """ - Cover the tests for model which have use_cache task (i.e. "with_past" for ONNX) - """ - - SUPPORTED_WITH_PAST_CONFIGS = () - - @patch.multiple(OnnxConfigWithPast, __abstractmethods__=set()) - def test_use_past(self): - """ - Ensures the use_past variable is correctly being set. - """ - for name, config in OnnxConfigWithPastTestCase.SUPPORTED_WITH_PAST_CONFIGS: - with self.subTest(name): - self.assertFalse( - OnnxConfigWithPast(config()).use_past, - "OnnxConfigWithPast should not use_past", - ) - - self.assertTrue( - OnnxConfigWithPast.with_past(config()).use_past, - "OnnxConfigWithPast should use_past", - ) - - @patch.multiple(OnnxConfigWithPast, __abstractmethods__=set()) - def test_values_override(self): - """ - Ensures the use_past variable correctly set the `use_cache` value in model's configuration. - """ - for name, config in OnnxConfigWithPastTestCase.SUPPORTED_WITH_PAST_CONFIGS: - with self.subTest(name): - # Without past - onnx_config_default = OnnxConfigWithPast(config()) - self.assertIsNotNone(onnx_config_default.values_override, "values_override should not be None") - self.assertIn("use_cache", onnx_config_default.values_override, "use_cache should be present") - self.assertFalse( - onnx_config_default.values_override["use_cache"], "use_cache should be False if not using past" - ) - - # With past - onnx_config_default = OnnxConfigWithPast.with_past(config()) - self.assertIsNotNone(onnx_config_default.values_override, "values_override should not be None") - self.assertIn("use_cache", onnx_config_default.values_override, "use_cache should be present") - self.assertTrue( - onnx_config_default.values_override["use_cache"], "use_cache should be False if not using past" - ) - - def _get_models_to_test(export_models_dict: Dict): models_to_test = [] if is_torch_available() or is_tf_available(): @@ -625,8 +576,8 @@ def test_custom_export_trust_remote(self, fn_get_submodels): onnx_config = CustomMPTOnnxConfig( config=config, task="text-generation", + use_past=True, use_past_in_inputs=False, - use_present_in_outputs=True, ) onnx_config_with_past = CustomMPTOnnxConfig(config, task="text-generation", use_past=True)
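
After this patch, `use_past` only declares that the exported model uses a KV cache at all (it drives the `present.*` outputs and the `use_cache` override), while `use_past_in_inputs` controls whether `past_key_values.*` are also graph inputs. Below is a minimal sketch of the resulting behavior, not part of the patch itself: it mirrors what get_decoder_models_for_export() does, GPT2OnnxConfig stands in for any OnnxConfigWithPast subclass, and the asserted tensor names assume the usual `past_key_values.{i}.key` / `present.{i}.key` naming convention.

    # Sketch of the new use_past / use_past_in_inputs semantics (assumes optimum
    # at this commit); the model checkpoint "gpt2" is only illustrative.
    from transformers import AutoConfig
    from optimum.exporters.onnx.model_configs import GPT2OnnxConfig

    config = AutoConfig.from_pretrained("gpt2")

    # "Decoder" model: past_key_values are not inputs, but use_past=True still
    # exposes the KV cache as present.* outputs (and forces use_cache=True via
    # values_override at export time).
    decoder = GPT2OnnxConfig(config, task="text-generation", use_past=True, use_past_in_inputs=False)
    assert not any(n.startswith("past_key_values") for n in decoder.inputs)
    assert any(n.startswith("present") for n in decoder.outputs)

    # "Decoder with past" model: the KV cache is both an input and an output.
    decoder_with_past = GPT2OnnxConfig(config, task="text-generation", use_past=True, use_past_in_inputs=True)
    assert any(n.startswith("past_key_values") for n in decoder_with_past.inputs)
    assert any(n.startswith("present") for n in decoder_with_past.outputs)

    # Plain export: with use_past=False there is no KV cache anywhere, which is
    # exactly the behavior this patch fixes.
    no_past = GPT2OnnxConfig(config, task="text-generation", use_past=False)
    assert not any(n.startswith("present") for n in no_past.outputs)

With the `with_past` classmethod removed, `-with-past` tasks are now dispatched as `partial(config_cls, use_past=True, ...)`, guarded by the new `SUPPORTS_PAST` class attribute instead of a `hasattr(config_cls, "with_past")` check.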