Skip to content

Commit

Permalink
apply suggestions + try fix
Browse files Browse the repository at this point in the history
  • Loading branch information
JingyaHuang committed Dec 19, 2024
1 parent bb185a0 commit 35d52d2
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 23 deletions.
29 changes: 17 additions & 12 deletions optimum/exporters/onnx/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from ...onnx import merge_decoders
from ...utils import (
DEFAULT_DUMMY_SHAPES,
DTYPE_MAPPER,
DummyInputGenerator,
DummyLabelsGenerator,
DummySeq2SeqPastKeyValuesGenerator,
Expand Down Expand Up @@ -72,8 +73,6 @@
Generates the dummy inputs necessary for tracing the model. If not explicitly specified, default input shapes are used.
Args:
framework (`str`, defaults to `"pt"`):
The framework for which to create the dummy inputs.
batch_size (`int`, defaults to {batch_size}):
The batch size to use in the dummy inputs.
sequence_length (`int`, defaults to {sequence_length}):
Expand Down Expand Up @@ -467,8 +466,10 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs) -> Dict:
input_was_inserted = False
for dummy_input_gen in dummy_inputs_generators:
if dummy_input_gen.supports_input(input_name):
int_dtype = getattr(DTYPE_MAPPER, framework)(self.int_dtype)
float_dtype = getattr(DTYPE_MAPPER, framework)(self.float_dtype)
dummy_inputs[input_name] = dummy_input_gen.generate(
input_name, framework=framework, int_dtype=self.int_dtype, float_dtype=self.float_dtype
input_name, int_dtype=int_dtype, float_dtype=float_dtype
)
input_was_inserted = True
break
Expand Down Expand Up @@ -679,6 +680,8 @@ def overwrite_shape_and_generate_input(

# TODO: The check `self.task != "text-generation" and self.legacy` is added following the use of a single ONNX for both without/with KV cache, without subgraphs.
# This overwrite may be moved to OnnxSeq2SeqConfigWithPast, but I am afraid it would break encoder-decoder models.
int_dtype = getattr(DTYPE_MAPPER, framework)(self.int_dtype)
float_dtype = getattr(DTYPE_MAPPER, framework)(self.float_dtype)
if (
self.use_past
and self.use_past_in_inputs
Expand All @@ -689,14 +692,10 @@ def overwrite_shape_and_generate_input(
sequence_length = dummy_input_gen.sequence_length
# Use a sequence length of 1 when the KV cache is already populated.
dummy_input_gen.sequence_length = 1
dummy_input = dummy_input_gen.generate(
input_name, framework=framework, int_dtype=self.int_dtype, float_dtype=self.float_dtype
)
dummy_input = dummy_input_gen.generate(input_name, int_dtype=int_dtype, float_dtype=float_dtype)
dummy_input_gen.sequence_length = sequence_length
else:
dummy_input = dummy_input_gen.generate(
input_name, framework=framework, int_dtype=self.int_dtype, float_dtype=self.float_dtype
)
dummy_input = dummy_input_gen.generate(input_name, int_dtype=int_dtype, float_dtype=float_dtype)

return dummy_input

Expand Down Expand Up @@ -740,8 +739,12 @@ def flatten_output_collection_property(self, name: str, field: Iterable[Any]) ->
return flattened_output

def generate_dummy_inputs_for_validation(
self, reference_model_inputs: Dict[str, Any], onnx_input_names: Optional[List[str]] = None
self,
reference_model_inputs: Dict[str, Any],
onnx_input_names: Optional[List[str]] = None,
) -> Dict[str, Any]:
int_dtype = DTYPE_MAPPER.pt(self.int_dtype)
float_dtype = DTYPE_MAPPER.pt(self.float_dtype)
if self.is_merged is True and self.use_cache_branch is True:
reference_model_inputs["use_cache_branch"] = DummyInputGenerator.constant_tensor(shape=[1], value=True)
elif self.is_merged is True and self.use_cache_branch is False:
Expand All @@ -754,7 +757,7 @@ def generate_dummy_inputs_for_validation(
task=self.task, normalized_config=self._normalized_config, sequence_length=1, batch_size=batch_size
)
reference_model_inputs["past_key_values"] = pkv_generator.generate(
"past_key_values", framework="pt", int_dtype=self.int_dtype, float_dtype=self.float_dtype
"past_key_values", int_dtype=int_dtype, float_dtype=float_dtype
)

return reference_model_inputs
Expand Down Expand Up @@ -1081,12 +1084,14 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
for cls_ in self.DUMMY_EXTRA_INPUT_GENERATOR_CLASSES
]

int_dtype = getattr(DTYPE_MAPPER, framework)(self.int_dtype)
float_dtype = getattr(DTYPE_MAPPER, framework)(self.float_dtype)
for input_name in self._tasks_to_extra_inputs[self.task]:
input_was_inserted = False
for dummy_input_gen in dummy_inputs_generators:
if dummy_input_gen.supports_input(input_name):
dummy_inputs[input_name] = dummy_input_gen.generate(
input_name, framework=framework, int_dtype=self.int_dtype, float_dtype=self.float_dtype
input_name, int_dtype=int_dtype, float_dtype=float_dtype
)
input_was_inserted = True
break
Expand Down
20 changes: 9 additions & 11 deletions optimum/utils/input_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def random_int_tensor(
"The `framework` argument is deprecated and will be removed soon. Please use the `dtype` argument instead to indicate the framework.",
FutureWarning,
)
dtype = DummyInputGenerator._set_default_int_dtype() if dtype is None else dtype
dtype = DummyInputGenerator._get_default_int_dtype() if dtype is None else dtype
framework = DummyInputGenerator.infer_framework_from_dtype(dtype)
if framework == "pt":
return torch.randint(low=min_value, high=max_value, size=shape, dtype=dtype)
Expand Down Expand Up @@ -213,7 +213,7 @@ def random_mask_tensor(
)
shape = tuple(shape)
mask_length = random.randint(1, shape[-1] - 1)
dtype = DummyInputGenerator._set_default_int_dtype() if dtype is None else dtype
dtype = DummyInputGenerator._get_default_int_dtype() if dtype is None else dtype
framework = DummyInputGenerator.infer_framework_from_dtype(dtype)
if framework == "pt":
mask_tensor = torch.cat(
Expand Down Expand Up @@ -278,7 +278,7 @@ def random_float_tensor(
"The `framework` argument is deprecated and will be removed soon. Please use the `dtype` argument instead to indicate the framework.",
FutureWarning,
)
dtype = DummyInputGenerator._set_default_float_dtype() if dtype is None else dtype
dtype = DummyInputGenerator._get_default_float_dtype() if dtype is None else dtype
framework = DummyInputGenerator.infer_framework_from_dtype(dtype)
if framework == "pt":
tensor = torch.empty(shape, dtype=dtype).uniform_(min_value, max_value)
Expand Down Expand Up @@ -316,7 +316,7 @@ def constant_tensor(
"The `framework` argument is deprecated and will be removed soon. Please use the `dtype` argument instead to indicate the framework.",
FutureWarning,
)
dtype = DummyInputGenerator._set_default_int_dtype() if dtype is None else dtype
dtype = DummyInputGenerator._get_default_int_dtype() if dtype is None else dtype
framework = DummyInputGenerator.infer_framework_from_dtype(dtype) or framework
if framework == "pt":
return torch.full(shape, value, dtype=dtype)
Expand All @@ -338,8 +338,7 @@ def _infer_framework_from_input(input_) -> str:
raise RuntimeError(f"Could not infer the framework from {input_}")
return framework

@staticmethod
def _set_default_int_dtype():
def _get_default_int_dtype(self):
"Default to int64 of available framework."
if is_torch_available():
return torch.int64
Expand All @@ -348,15 +347,14 @@ def _set_default_int_dtype():
else:
return np.int64

@staticmethod
def _set_default_float_dtype():
def _get_default_float_dtype(self):
"Default to float32 of available framework."
if is_torch_available():
return torch.int64
return torch.float32
elif is_tf_available():
return tf.int64
return tf.float32
else:
return np.int64
return np.float32

@staticmethod
def infer_framework_from_dtype(dtype):
Expand Down

0 comments on commit 35d52d2

Please sign in to comment.