
Commit

use model part dtype instead
IlyasMoutawwakil committed Jul 25, 2024
1 parent 81d0227 commit 82a2879
Showing 3 changed files with 30 additions and 31 deletions.
15 changes: 15 additions & 0 deletions optimum/onnxruntime/base.py
@@ -24,6 +24,7 @@

 from ..utils import NormalizedConfigManager
 from ..utils.logging import warn_once
+from .io_binding import TypeHelper
 from .modeling_ort import ORTModel
 from .utils import get_ordered_input_names, logging

@@ -62,6 +63,20 @@ def __init__(
     def device(self):
         return self.parent_model.device

+    @property
+    def dtype(self):
+        for dtype in self.input_dtypes.values():
+            torch_dtype = TypeHelper.ort_type_to_torch_type(dtype)
+            if torch_dtype.is_floating_point:
+                return torch_dtype
+
+        for dtype in self.output_dtypes.values():
+            torch_dtype = TypeHelper.ort_type_to_torch_type(dtype)
+            if torch_dtype.is_floating_point:
+                return torch_dtype
+
+        return None
+
     @abstractmethod
     def forward(self, *args, **kwargs):
         pass
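
For illustration, the property added above resolves a torch dtype from the ONNX Runtime type strings exposed by a model part's session, returning the first floating-point dtype found among its inputs and outputs. Below is a minimal standalone sketch of that lookup, assuming input_dtypes/output_dtypes are dicts mapping tensor names to ONNX type strings such as "tensor(float16)"; the mapping table is illustrative and not the actual TypeHelper implementation.

import torch

# Illustrative subset of the ONNX Runtime -> torch dtype mapping
# (an assumption for this sketch, not the real TypeHelper table).
ORT_TO_TORCH = {
    "tensor(float)": torch.float32,
    "tensor(float16)": torch.float16,
    "tensor(int64)": torch.int64,
    "tensor(bool)": torch.bool,
}

def resolve_part_dtype(input_dtypes, output_dtypes):
    # Return the first floating-point dtype among the part's I/O tensors,
    # mirroring the logic of the dtype property added in this commit.
    for ort_type in list(input_dtypes.values()) + list(output_dtypes.values()):
        torch_dtype = ORT_TO_TORCH.get(ort_type)
        if torch_dtype is not None and torch_dtype.is_floating_point:
            return torch_dtype
    return None

# Example: an fp16-exported part resolves to torch.float16.
print(resolve_part_dtype({"sample": "tensor(float16)"}, {"out_sample": "tensor(float16)"}))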
33 changes: 14 additions & 19 deletions optimum/onnxruntime/modeling_diffusion.py
@@ -59,7 +59,6 @@
 from .io_binding import TypeHelper
 from .modeling_ort import ONNX_MODEL_END_DOCSTRING, ORTModel
 from .utils import (
-    _ORT_TO_NP_TYPE,
     ONNX_WEIGHTS_NAME,
     get_provider_for_device,
     parse_device,
@@ -441,24 +440,6 @@ def _from_transformers(
             model_save_dir=save_dir,
         )

-    @property
-    def dtype(self) -> torch.dtype:
-        """
-        `torch.dtype`: The dtype of the model.
-        """
-
-        for dtype in self.unet.input_dtypes.values():
-            torch_dtype = TypeHelper.ort_type_to_torch_type(dtype)
-            if torch_dtype.is_floating_point:
-                return torch_dtype
-
-        for dtype in self.unet.output_dtypes.values():
-            torch_dtype = TypeHelper.ort_type_to_torch_type(dtype)
-            if torch_dtype.is_floating_point:
-                return torch_dtype
-
-        return None
-
     def to(self, device: Union[torch.device, str, int]):
         """
         Changes the ONNX Runtime provider according to the device.
@@ -522,6 +503,20 @@ def __init__(self, session: ort.InferenceSession, parent_model: ORTModel):
     def device(self):
         return self.parent_model.device

+    @property
+    def dtype(self):
+        for dtype in self.input_dtypes.values():
+            torch_dtype = TypeHelper.ort_type_to_torch_type(dtype)
+            if torch_dtype.is_floating_point:
+                return torch_dtype
+
+        for dtype in self.output_dtypes.values():
+            torch_dtype = TypeHelper.ort_type_to_torch_type(dtype)
+            if torch_dtype.is_floating_point:
+                return torch_dtype
+
+        return None
+
     @abstractmethod
     def forward(self, *args, **kwargs):
         pass
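
With the pipeline-level dtype property removed above, the dtype would now be read from an individual model part. A hedged usage sketch, assuming a working optimum onnxruntime install; the model id and the exact dtypes printed are illustrative.

from optimum.onnxruntime import ORTStableDiffusionPipeline

# Export a diffusers checkpoint to ONNX on the fly (illustrative model id).
pipe = ORTStableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", export=True
)

# After this commit each part exposes its own dtype, resolved from the
# floating-point types of its ONNX inputs/outputs.
print(pipe.unet.dtype)         # e.g. torch.float32 for a default export
print(pipe.vae_decoder.dtype)  # parts can in principle differ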
13 changes: 1 addition & 12 deletions optimum/onnxruntime/modeling_seq2seq.py
@@ -57,7 +57,6 @@
     DECODER_WITH_PAST_ONNX_FILE_PATTERN,
     ENCODER_ONNX_FILE_PATTERN,
 )
-from .io_binding import TypeHelper
 from .modeling_ort import ONNX_MODEL_END_DOCSTRING, ORTModel
 from .utils import (
     ONNX_DECODER_NAME,
@@ -1111,17 +1110,7 @@ def dtype(self) -> torch.dtype:
         `torch.dtype`: The dtype of the model.
         """

-        for dtype in self.encoder.input_dtypes:
-            torch_dtype = TypeHelper.ort_type_to_torch_type(dtype)
-            if torch_dtype.is_floating_point:
-                return torch_dtype
-
-        for dtype in self.encoder.output_dtypes:
-            torch_dtype = TypeHelper.ort_type_to_torch_type(dtype)
-            if torch_dtype.is_floating_point:
-                return torch_dtype
-
-        return None
+        return self.encoder.dtype or self.decoder.dtype

     def to(self, device: Union[torch.device, str, int]):
         """
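
For the seq2seq models, dtype now delegates to the parts, with the or-expression falling back to the decoder when the encoder exposes no floating-point tensors. A hedged usage sketch; the model id and printed dtype are illustrative.

from optimum.onnxruntime import ORTModelForSeq2SeqLM

# Export a seq2seq checkpoint to ONNX (illustrative model id).
model = ORTModelForSeq2SeqLM.from_pretrained("t5-small", export=True)

# model.dtype is now encoder.dtype or decoder.dtype: the encoder's
# floating-point dtype if it has one, otherwise the decoder's.
print(model.dtype)          # e.g. torch.float32
print(model.encoder.dtype)  # resolved from the encoder session's I/O types
print(model.decoder.dtype)  # used as the fallback when the encoder returns None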
