diff --git a/optimum/exporters/onnx/base.py b/optimum/exporters/onnx/base.py index 7e35691d54b..c0a42d7afd3 100644 --- a/optimum/exporters/onnx/base.py +++ b/optimum/exporters/onnx/base.py @@ -39,6 +39,7 @@ from ...onnx import merge_decoders from ...utils import ( DEFAULT_DUMMY_SHAPES, + DTYPE_MAPPER, DummyInputGenerator, DummyLabelsGenerator, DummySeq2SeqPastKeyValuesGenerator, @@ -72,8 +73,6 @@ Generates the dummy inputs necessary for tracing the model. If not explicitely specified, default input shapes are used. Args: - framework (`str`, defaults to `"pt"`): - The framework for which to create the dummy inputs. batch_size (`int`, defaults to {batch_size}): The batch size to use in the dummy inputs. sequence_length (`int`, defaults to {sequence_length}): @@ -467,8 +466,10 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs) -> Dict: input_was_inserted = False for dummy_input_gen in dummy_inputs_generators: if dummy_input_gen.supports_input(input_name): + int_dtype = getattr(DTYPE_MAPPER, framework)(self.int_dtype) + float_dtype = getattr(DTYPE_MAPPER, framework)(self.float_dtype) dummy_inputs[input_name] = dummy_input_gen.generate( - input_name, framework=framework, int_dtype=self.int_dtype, float_dtype=self.float_dtype + input_name, int_dtype=int_dtype, float_dtype=float_dtype ) input_was_inserted = True break @@ -679,6 +680,8 @@ def overwrite_shape_and_generate_input( # TODO: The check `self.task != "text-generation" and self.legacy` is added following the use of a single ONNX for both without/with KV cache, without subgraphs. # This overwrite may be moved to OnnxSeq2SeqConfigWithPast, but I am afraid it would break encoder-decoder models. 
+ int_dtype = getattr(DTYPE_MAPPER, framework)(self.int_dtype) + float_dtype = getattr(DTYPE_MAPPER, framework)(self.float_dtype) if ( self.use_past and self.use_past_in_inputs @@ -689,14 +692,10 @@ def overwrite_shape_and_generate_input( sequence_length = dummy_input_gen.sequence_length # Use a sequence length of 1 when the KV cache is already populated. dummy_input_gen.sequence_length = 1 - dummy_input = dummy_input_gen.generate( - input_name, framework=framework, int_dtype=self.int_dtype, float_dtype=self.float_dtype - ) + dummy_input = dummy_input_gen.generate(input_name, int_dtype=int_dtype, float_dtype=float_dtype) dummy_input_gen.sequence_length = sequence_length else: - dummy_input = dummy_input_gen.generate( - input_name, framework=framework, int_dtype=self.int_dtype, float_dtype=self.float_dtype - ) + dummy_input = dummy_input_gen.generate(input_name, int_dtype=int_dtype, float_dtype=float_dtype) return dummy_input @@ -740,8 +739,12 @@ def flatten_output_collection_property(self, name: str, field: Iterable[Any]) -> return flattened_output def generate_dummy_inputs_for_validation( - self, reference_model_inputs: Dict[str, Any], onnx_input_names: Optional[List[str]] = None + self, + reference_model_inputs: Dict[str, Any], + onnx_input_names: Optional[List[str]] = None, ) -> Dict[str, Any]: + int_dtype = DTYPE_MAPPER.pt(self.int_dtype) + float_dtype = DTYPE_MAPPER.pt(self.float_dtype) if self.is_merged is True and self.use_cache_branch is True: reference_model_inputs["use_cache_branch"] = DummyInputGenerator.constant_tensor(shape=[1], value=True) elif self.is_merged is True and self.use_cache_branch is False: @@ -754,7 +757,7 @@ def generate_dummy_inputs_for_validation( task=self.task, normalized_config=self._normalized_config, sequence_length=1, batch_size=batch_size ) reference_model_inputs["past_key_values"] = pkv_generator.generate( - "past_key_values", framework="pt", int_dtype=self.int_dtype, float_dtype=self.float_dtype + "past_key_values", 
int_dtype=int_dtype, float_dtype=float_dtype ) return reference_model_inputs @@ -1081,12 +1084,14 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs): for cls_ in self.DUMMY_EXTRA_INPUT_GENERATOR_CLASSES ] + int_dtype = getattr(DTYPE_MAPPER, framework)(self.int_dtype) + float_dtype = getattr(DTYPE_MAPPER, framework)(self.float_dtype) for input_name in self._tasks_to_extra_inputs[self.task]: input_was_inserted = False for dummy_input_gen in dummy_inputs_generators: if dummy_input_gen.supports_input(input_name): dummy_inputs[input_name] = dummy_input_gen.generate( - input_name, framework=framework, int_dtype=self.int_dtype, float_dtype=self.float_dtype + input_name, int_dtype=int_dtype, float_dtype=float_dtype ) input_was_inserted = True break diff --git a/optimum/utils/input_generators.py b/optimum/utils/input_generators.py index 01309ff9f5e..8f9661b1d02 100644 --- a/optimum/utils/input_generators.py +++ b/optimum/utils/input_generators.py @@ -174,7 +174,7 @@ def random_int_tensor( "The `framework` argument is deprecated and will be removed soon. Please use the `dtype` argument instead to indicate the framework.", FutureWarning, ) - dtype = DummyInputGenerator._set_default_int_dtype() if dtype is None else dtype + dtype = DummyInputGenerator._get_default_int_dtype() if dtype is None else dtype framework = DummyInputGenerator.infer_framework_from_dtype(dtype) if framework == "pt": return torch.randint(low=min_value, high=max_value, size=shape, dtype=dtype) @@ -213,7 +213,7 @@ def random_mask_tensor( ) shape = tuple(shape) mask_length = random.randint(1, shape[-1] - 1) - dtype = DummyInputGenerator._set_default_int_dtype() if dtype is None else dtype + dtype = DummyInputGenerator._get_default_int_dtype() if dtype is None else dtype framework = DummyInputGenerator.infer_framework_from_dtype(dtype) if framework == "pt": mask_tensor = torch.cat( @@ -278,7 +278,7 @@ def random_float_tensor( "The `framework` argument is deprecated and will be removed soon. 
Please use the `dtype` argument instead to indicate the framework.", FutureWarning, ) - dtype = DummyInputGenerator._set_default_float_dtype() if dtype is None else dtype + dtype = DummyInputGenerator._get_default_float_dtype() if dtype is None else dtype framework = DummyInputGenerator.infer_framework_from_dtype(dtype) if framework == "pt": tensor = torch.empty(shape, dtype=dtype).uniform_(min_value, max_value) @@ -316,7 +316,7 @@ def constant_tensor( "The `framework` argument is deprecated and will be removed soon. Please use the `dtype` argument instead to indicate the framework.", FutureWarning, ) - dtype = DummyInputGenerator._set_default_int_dtype() if dtype is None else dtype + dtype = DummyInputGenerator._get_default_int_dtype() if dtype is None else dtype framework = DummyInputGenerator.infer_framework_from_dtype(dtype) or framework if framework == "pt": return torch.full(shape, value, dtype=dtype) @@ -338,8 +338,8 @@ def _infer_framework_from_input(input_) -> str: raise RuntimeError(f"Could not infer the framework from {input_}") return framework @staticmethod - def _set_default_int_dtype(): + def _get_default_int_dtype(): "Default to int64 of available framework." if is_torch_available(): return torch.int64 @@ -348,15 +348,15 @@ def _set_default_int_dtype(): else: return np.int64 @staticmethod - def _set_default_float_dtype(): + def _get_default_float_dtype(): "Default to float32 of available framework." if is_torch_available(): - return torch.int64 + return torch.float32 elif is_tf_available(): - return tf.int64 + return tf.float32 else: - return np.int64 + return np.float32 @staticmethod def infer_framework_from_dtype(dtype):