trocr kv cache export
fxmarty committed Oct 16, 2023
1 parent c5ad7f9 commit 0cce136
Showing 5 changed files with 127 additions and 54 deletions.
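
This commit enables ONNX export of VisionEncoderDecoder models with a TrOCR decoder together with the decoder KV cache (past key values), which previously raised a ValueError. A minimal sketch of exercising the change from Python; the main_export helper and the image-to-text-with-past task name are assumptions about the optimum API around the time of this commit, not part of the diff itself:

# Sketch only: main_export and the task alias are assumed; adjust to the
# installed optimum version if they differ.
from optimum.exporters.onnx import main_export

main_export(
    model_name_or_path="microsoft/trocr-small-handwritten",  # any TrOCR checkpoint
    output="trocr_onnx",
    task="image-to-text-with-past",  # requests KV-cache inputs/outputs for the decoder
)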
82 changes: 41 additions & 41 deletions optimum/exporters/onnx/config.py
@@ -335,54 +335,54 @@ def __init__(

         self.is_decoder_with_past = False

-        if self._behavior is not ConfigBehavior.DECODER:
-            encoder_onnx_config_constructor = TasksManager.get_exporter_config_constructor(
-                exporter="onnx", task="feature-extraction", model_type=config.encoder.model_type
-            )
-            self._encoder_onnx_config = encoder_onnx_config_constructor(
-                config.encoder, int_dtype=int_dtype, float_dtype=float_dtype, preprocessors=preprocessors
-            )
-            self._normalized_config.ENCODER_NORMALIZED_CONFIG_CLASS = self._encoder_onnx_config._normalized_config
+        # Set up the encoder ONNX config.
+        encoder_onnx_config_constructor = TasksManager.get_exporter_config_constructor(
+            exporter="onnx", task="feature-extraction", model_type=config.encoder.model_type
+        )
+        self._encoder_onnx_config = encoder_onnx_config_constructor(
+            config.encoder, int_dtype=int_dtype, float_dtype=float_dtype, preprocessors=preprocessors
+        )
+        self._normalized_config.ENCODER_NORMALIZED_CONFIG_CLASS = self._encoder_onnx_config._normalized_config

-        if self._behavior is not ConfigBehavior.ENCODER:
-            decoder_onnx_config_constructor = TasksManager.get_exporter_config_constructor(
-                exporter="onnx", task="feature-extraction", model_type=config.decoder.model_type
-            )
-            kwargs = {}
-            if issubclass(decoder_onnx_config_constructor.func, OnnxConfigWithPast):
-                self.is_decoder_with_past = True
-                kwargs["use_past"] = use_past
-            else:
-                self.use_past = False
+        # Set up the decoder ONNX config.
+        decoder_onnx_config_constructor = TasksManager.get_exporter_config_constructor(
+            exporter="onnx", task="feature-extraction", model_type=config.decoder.model_type
+        )
+        kwargs = {}
+        if issubclass(decoder_onnx_config_constructor.func, OnnxConfigWithPast):
+            self.is_decoder_with_past = True
+            kwargs["use_past"] = use_past
+        else:
+            self.use_past = False

-            if use_past and not self.is_decoder_with_past:
-                raise ValueError(
-                    f"The decoder part of the encoder-decoder model is {config.decoder.model_type} which does not need "
-                    "past key values."
-                )
+        if use_past and not self.is_decoder_with_past:
+            raise ValueError(
+                f"The decoder part of the encoder-decoder model is {config.decoder.model_type} which does not need "
+                "past key values."
+            )

-            self._decoder_onnx_config = decoder_onnx_config_constructor(
-                config.decoder, int_dtype=int_dtype, float_dtype=float_dtype, preprocessors=preprocessors, **kwargs
-            )
-            if issubclass(decoder_onnx_config_constructor.func, OnnxSeq2SeqConfigWithPast):
-                self._decoder_onnx_config = self._decoder_onnx_config.with_behavior(
-                    self._behavior, use_past=kwargs["use_past"], use_past_in_inputs=use_past_in_inputs
-                )
+        self._decoder_onnx_config = decoder_onnx_config_constructor(
+            config.decoder, int_dtype=int_dtype, float_dtype=float_dtype, preprocessors=preprocessors, **kwargs
+        )
+        if issubclass(decoder_onnx_config_constructor.func, OnnxSeq2SeqConfigWithPast):
+            self._decoder_onnx_config = self._decoder_onnx_config.with_behavior(
+                self._behavior, use_past=kwargs["use_past"], use_past_in_inputs=use_past_in_inputs
+            )

-            self._normalized_config.DECODER_NORMALIZED_CONFIG_CLASS = self._decoder_onnx_config._normalized_config
+        self._normalized_config.DECODER_NORMALIZED_CONFIG_CLASS = self._decoder_onnx_config._normalized_config

-            if isinstance(self._decoder_onnx_config, OnnxSeq2SeqConfigWithPast):
-                self._past_key_values_generator = (
-                    DummySeq2SeqDecoderTextInputGenerator,
-                    DummySeq2SeqPastKeyValuesGenerator,
-                )
-            else:
-                self._past_key_values_generator = (
-                    DummySeq2SeqDecoderTextInputGenerator,
-                    DummyPastKeyValuesGenerator,
-                )
+        if isinstance(self._decoder_onnx_config, OnnxSeq2SeqConfigWithPast):
+            self._past_key_values_generator = (
+                DummySeq2SeqDecoderTextInputGenerator,
+                DummySeq2SeqPastKeyValuesGenerator,
+            )
+        else:
+            self._past_key_values_generator = (
+                DummySeq2SeqDecoderTextInputGenerator,
+                DummyPastKeyValuesGenerator,
+            )

-            self.DUMMY_INPUT_GENERATOR_CLASSES += self._past_key_values_generator
+        self.DUMMY_INPUT_GENERATOR_CLASSES += self._past_key_values_generator

     @property
     def inputs(self) -> Dict[str, Dict[int, str]]:
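
The substance of this hunk is the removed guards: the encoder and decoder sub-configs are now built for every ConfigBehavior instead of being skipped for DECODER-only or ENCODER-only exports, so ENCODER_NORMALIZED_CONFIG_CLASS and DECODER_NORMALIZED_CONFIG_CLASS are always populated. The dummy input generators changed later in this commit read fields from both halves even when exporting a single half. A toy, hypothetical illustration of the failure mode the old guards invite (all names here are stand-ins, not optimum code):

from enum import Enum

class Behavior(Enum):
    MONOLITH = "monolith"
    ENCODER = "encoder"
    DECODER = "decoder"

def build_normalized(behavior: Behavior, old_guards: bool) -> dict:
    normalized = {}
    if not old_guards or behavior is not Behavior.DECODER:
        normalized["encoder_hidden_size"] = 384  # toy value
    if not old_guards or behavior is not Behavior.ENCODER:
        normalized["decoder_hidden_size"] = 256  # toy value
    return normalized

# A decoder-only export still needs encoder info, e.g. to size dummy
# encoder_hidden_states; with the old guards it is missing:
print("encoder_hidden_size" in build_normalized(Behavior.DECODER, old_guards=True))   # False
print("encoder_hidden_size" in build_normalized(Behavior.DECODER, old_guards=False))  # True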
24 changes: 17 additions & 7 deletions optimum/exporters/onnx/model_configs.py
@@ -41,6 +41,7 @@
     NormalizedTextAndVisionConfig,
     NormalizedTextConfig,
     NormalizedVisionConfig,
+    TROCRDummyPastKeyValuseGenerator,
     logging,
 )
 from ...utils.normalized_config import NormalizedConfigManager
@@ -56,7 +57,7 @@
     TextSeq2SeqOnnxConfig,
     VisionOnnxConfig,
 )
-from .model_patcher import SAMModelPatcher, WavLMModelPatcher
+from .model_patcher import SAMModelPatcher, VisionEncoderDecoderPatcher, WavLMModelPatcher


 if TYPE_CHECKING:
@@ -1224,15 +1225,19 @@ class TrOCROnnxConfig(TextSeq2SeqOnnxConfig):
         decoder_num_layers="decoder_layers",
         num_layers="decoder_layers",
         decoder_num_attention_heads="decoder_attention_heads",
-        hidden_size="cross_attention_hidden_size",
+        hidden_size="hidden_size",
     )


 class VisionEncoderDecoderOnnxConfig(EncoderDecoderBaseOnnxConfig):
     NORMALIZED_CONFIG_CLASS = NormalizedEncoderDecoderConfig
     ATOL_FOR_VALIDATION = 1e-3

-    DUMMY_INPUT_GENERATOR_CLASSES = (DummyVisionInputGenerator,)
+    DUMMY_INPUT_GENERATOR_CLASSES = (
+        DummyVisionInputGenerator,
+        TROCRDummyPastKeyValuseGenerator,
+        DummySeq2SeqPastKeyValuesGenerator,
+    )

     def __init__(
         self,
@@ -1256,10 +1261,10 @@ def __init__(
             preprocessors=preprocessors,
         )

-        if config.decoder.model_type == "trocr" and use_past:
-            raise ValueError(
-                "Exporting TrOCR to ONNX with past key values is not supported with TrOCR model. Please open an issue in Optimum repository."
-            )
+        if config.decoder.model_type == "trocr":
+            self.DUMMY_PKV_GENERATOR_CLASS = TROCRDummyPastKeyValuseGenerator
+        else:
+            self.DUMMY_PKV_GENERATOR_CLASS = DummySeq2SeqPastKeyValuesGenerator

     @property
     def inputs(self) -> Dict[str, Dict[int, str]]:
@@ -1281,6 +1286,11 @@ def inputs(self) -> Dict[str, Dict[int, str]]:

         return common_inputs

+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ) -> "ModelPatcher":
+        return VisionEncoderDecoderPatcher(self, model, model_kwargs=model_kwargs)
+

 class SamOnnxConfig(OnnxConfig):
     MIN_TRANSFORMERS_VERSION = version.parse("4.29.0.dev0")
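
The new patch_model_for_export override wires this config to the VisionEncoderDecoderPatcher added in model_patcher.py below. Like optimum's other patchers, it is meant to be used as a context manager around the actual export call. A sketch of that consumption, assuming TasksManager's constructor-lookup API as of this commit (the export call itself is elided):

from transformers import VisionEncoderDecoderModel
from optimum.exporters.tasks import TasksManager

model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-small-handwritten")
constructor = TasksManager.get_exporter_config_constructor(
    exporter="onnx", model=model, task="image-to-text"
)
onnx_config = constructor(model.config)

# ModelPatcher implements the context-manager protocol; inside the block the
# model runs with the patched forward.
with onnx_config.patch_model_for_export(model):
    pass  # torch.onnx.export(model, ...) would run here in the real exporter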
14 changes: 14 additions & 0 deletions optimum/exporters/onnx/model_patcher.py
@@ -223,6 +223,20 @@ def patched_forward(*args, **kwargs):
         self.patched_forward = patched_forward


+class VisionEncoderDecoderPatcher(Seq2SeqModelPatcher):
+    def __init__(
+        self,
+        config: "OnnxConfig",
+        model: Union["PreTrainedModel", "TFPreTrainedModel"],
+        model_kwargs: Optional[Dict[str, Any]] = None,
+    ):
+        super().__init__(config, model, model_kwargs)
+        use_cache = hasattr(self.real_config, "use_past")
+
+        if config._behavior == "decoder" and model.config.decoder.model_type == "trocr" and use_cache:
+            model.decoder.model.decoder.config.use_cache = True
+
+
 class WavLMModelPatcher(ModelPatcher):
     def __init__(
         self,
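
What the patcher changes can be checked outside the exporter: for a TrOCR decoder it flips the nested use_cache flag so the traced decoder returns past key values. The attribute path comes from the diff above; the checkpoint name is only an example:

from transformers import VisionEncoderDecoderModel

model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-small-handwritten")
assert model.config.decoder.model_type == "trocr"

# The nested path the patcher touches when exporting the decoder with cache:
model.decoder.model.decoder.config.use_cache = True
print(model.decoder.model.decoder.config.use_cache)  # True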
1 change: 1 addition & 0 deletions optimum/utils/__init__.py
@@ -62,6 +62,7 @@
     FalconDummyPastKeyValuesGenerator,
     GPTBigCodeDummyPastKeyValuesGenerator,
     MistralDummyPastKeyValuesGenerator,
+    TROCRDummyPastKeyValuseGenerator,
 )
 from .modeling_utils import recurse_getattr, recurse_setattr
 from .normalized_config import (
60 changes: 54 additions & 6 deletions optimum/utils/input_generators.py
@@ -22,7 +22,13 @@
 import numpy as np
 from transformers.utils import is_tf_available, is_torch_available

-from .normalized_config import NormalizedConfig, NormalizedSeq2SeqConfig, NormalizedTextConfig, NormalizedVisionConfig
+from .normalized_config import (
+    NormalizedConfig,
+    NormalizedEncoderDecoderConfig,
+    NormalizedSeq2SeqConfig,
+    NormalizedTextConfig,
+    NormalizedVisionConfig,
+)


 if is_torch_available():
@@ -408,7 +414,10 @@ def __init__(
             random_num_choices_range=random_num_choices_range,
         )

-        self.hidden_size = normalized_config.hidden_size
+        if isinstance(normalized_config, NormalizedEncoderDecoderConfig):
+            self.hidden_size = normalized_config.ENCODER_NORMALIZED_CONFIG_CLASS.hidden_size
+        else:
+            self.hidden_size = normalized_config.hidden_size

     def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
         if input_name in ["encoder_outputs", "encoder_hidden_states"]:
@@ -507,17 +516,28 @@ def __init__(
         )

     def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
+        if isinstance(self.normalized_config, NormalizedEncoderDecoderConfig):
+            decoder_hidden_size = self.normalized_config.DECODER_NORMALIZED_CONFIG_CLASS.hidden_size
+            encoder_hidden_size = decoder_hidden_size
+            decoder_num_attention_heads = self.normalized_config.DECODER_NORMALIZED_CONFIG_CLASS.num_attention_heads
+            encoder_num_attention_heads = decoder_num_attention_heads  # This is used for cross-attention KV cache.
+        else:
+            encoder_hidden_size = self.normalized_config.hidden_size
+            decoder_hidden_size = self.normalized_config.hidden_size
+            encoder_num_attention_heads = self.normalized_config.encoder_num_attention_heads
+            decoder_num_attention_heads = self.normalized_config.decoder_num_attention_heads
+
         encoder_shape = (
             self.batch_size,
-            self.normalized_config.encoder_num_attention_heads,
+            encoder_num_attention_heads,
             self.encoder_sequence_length,
-            self.normalized_config.hidden_size // self.normalized_config.encoder_num_attention_heads,
+            encoder_hidden_size // encoder_num_attention_heads,
         )
         decoder_shape = (
             self.batch_size,
-            self.normalized_config.decoder_num_attention_heads,
+            decoder_num_attention_heads,
             self.sequence_length,
-            self.normalized_config.hidden_size // self.normalized_config.decoder_num_attention_heads,
+            decoder_hidden_size // decoder_num_attention_heads,
         )
         return [
             (
@@ -945,3 +965,31 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
             )
             for _ in range(self.num_layers)
         ]
+
+
+class TROCRDummyPastKeyValuseGenerator(DummySeq2SeqPastKeyValuesGenerator):
+    def __init__(
+        self,
+        task: str,
+        normalized_config: NormalizedSeq2SeqConfig,
+        batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
+        sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"],
+        encoder_sequence_length: Optional[int] = None,
+        random_batch_size_range: Optional[Tuple[int, int]] = None,
+        random_sequence_length_range: Optional[Tuple[int, int]] = None,
+        **kwargs,
+    ):
+        super().__init__(
+            task=task,
+            normalized_config=normalized_config,
+            batch_size=batch_size,
+            sequence_length=sequence_length,
+            encoder_sequence_length=encoder_sequence_length,
+            random_batch_size_range=random_batch_size_range,
+            random_sequence_length_range=random_sequence_length_range,
+            **kwargs,
+        )
+
+        image_size = normalized_config.encoder.image_size
+        patch_size = normalized_config.encoder.patch_size
+        self.encoder_sequence_length = (image_size // patch_size) ** 2 + 1
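
For TrOCR the cross-attention KV cache length is fixed by the encoder's patch grid rather than by any text length, which is why this generator overrides encoder_sequence_length. A worked example of the formula, using image_size and patch_size values typical of the ViT-style encoders in TrOCR checkpoints (check config.encoder for the actual numbers); the decoder dimensions are illustrative:

# 384 and 16 are assumed, checkpoint-typical values, not read from the diff.
image_size, patch_size = 384, 16
encoder_sequence_length = (image_size // patch_size) ** 2 + 1  # patch tokens + 1 extra token
print(encoder_sequence_length)  # 577

batch_size, num_heads, hidden_size = 2, 8, 256  # example decoder dims
print((batch_size, num_heads, encoder_sequence_length, hidden_size // num_heads))
# (2, 8, 577, 32): per-layer dummy cross-attention key/value shape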
