diff --git a/optimum/onnxruntime/base.py b/optimum/onnxruntime/base.py index 16461dce957..121273dac93 100644 --- a/optimum/onnxruntime/base.py +++ b/optimum/onnxruntime/base.py @@ -24,6 +24,7 @@ from ..utils import NormalizedConfigManager from ..utils.logging import warn_once +from .io_binding import TypeHelper from .modeling_ort import ORTModel from .utils import get_ordered_input_names, logging @@ -62,6 +63,20 @@ def __init__( def device(self): return self.parent_model.device + @property + def dtype(self): + for dtype in self.input_dtypes.values(): + torch_dtype = TypeHelper.ort_type_to_torch_type(dtype) + if torch_dtype.is_floating_point: + return torch_dtype + + for dtype in self.output_dtypes.values(): + torch_dtype = TypeHelper.ort_type_to_torch_type(dtype) + if torch_dtype.is_floating_point: + return torch_dtype + + return None + @abstractmethod def forward(self, *args, **kwargs): pass diff --git a/optimum/onnxruntime/modeling_diffusion.py b/optimum/onnxruntime/modeling_diffusion.py index f22ec235f5b..2a8764117e8 100644 --- a/optimum/onnxruntime/modeling_diffusion.py +++ b/optimum/onnxruntime/modeling_diffusion.py @@ -59,7 +59,6 @@ from .io_binding import TypeHelper from .modeling_ort import ONNX_MODEL_END_DOCSTRING, ORTModel from .utils import ( - _ORT_TO_NP_TYPE, ONNX_WEIGHTS_NAME, get_provider_for_device, parse_device, @@ -441,24 +440,6 @@ def _from_transformers( model_save_dir=save_dir, ) - @property - def dtype(self) -> torch.dtype: - """ - `torch.dtype`: The dtype of the model. - """ - - for dtype in self.unet.input_dtypes.values(): - torch_dtype = TypeHelper.ort_type_to_torch_type(dtype) - if torch_dtype.is_floating_point: - return torch_dtype - - for dtype in self.unet.output_dtypes.values(): - torch_dtype = TypeHelper.ort_type_to_torch_type(dtype) - if torch_dtype.is_floating_point: - return torch_dtype - - return None - def to(self, device: Union[torch.device, str, int]): """ Changes the ONNX Runtime provider according to the device. @@ -522,6 +503,20 @@ def __init__(self, session: ort.InferenceSession, parent_model: ORTModel): def device(self): return self.parent_model.device + @property + def dtype(self): + for dtype in self.input_dtypes.values(): + torch_dtype = TypeHelper.ort_type_to_torch_type(dtype) + if torch_dtype.is_floating_point: + return torch_dtype + + for dtype in self.output_dtypes.values(): + torch_dtype = TypeHelper.ort_type_to_torch_type(dtype) + if torch_dtype.is_floating_point: + return torch_dtype + + return None + @abstractmethod def forward(self, *args, **kwargs): pass diff --git a/optimum/onnxruntime/modeling_seq2seq.py b/optimum/onnxruntime/modeling_seq2seq.py index 61b658ecbaf..6bf0e1a20a9 100644 --- a/optimum/onnxruntime/modeling_seq2seq.py +++ b/optimum/onnxruntime/modeling_seq2seq.py @@ -57,7 +57,6 @@ DECODER_WITH_PAST_ONNX_FILE_PATTERN, ENCODER_ONNX_FILE_PATTERN, ) -from .io_binding import TypeHelper from .modeling_ort import ONNX_MODEL_END_DOCSTRING, ORTModel from .utils import ( ONNX_DECODER_NAME, @@ -1111,17 +1110,7 @@ def dtype(self) -> torch.dtype: `torch.dtype`: The dtype of the model. """ - for dtype in self.encoder.input_dtypes: - torch_dtype = TypeHelper.ort_type_to_torch_type(dtype) - if torch_dtype.is_floating_point: - return torch_dtype - - for dtype in self.encoder.output_dtypes: - torch_dtype = TypeHelper.ort_type_to_torch_type(dtype) - if torch_dtype.is_floating_point: - return torch_dtype - - return None + return self.encoder.dtype or self.decoder.dtype def to(self, device: Union[torch.device, str, int]): """