diff --git a/optimum/onnxruntime/modeling_decoder.py b/optimum/onnxruntime/modeling_decoder.py index e51943bd4ac..e3754049a4e 100644 --- a/optimum/onnxruntime/modeling_decoder.py +++ b/optimum/onnxruntime/modeling_decoder.py @@ -151,7 +151,7 @@ def __init__( self.use_fp16 = False for inp in model.get_inputs(): - if inp.name == "past_key_values" and inp.type == "tensor(float16)": + if (inp.name == "past_key_values" or inp.name in self.key_value_input_names) and inp.type == "tensor(float16)": self.use_fp16 = True break