diff --git a/optimum/intel/openvino/modeling.py b/optimum/intel/openvino/modeling.py
index 4d859be5ea..9408483166 100644
--- a/optimum/intel/openvino/modeling.py
+++ b/optimum/intel/openvino/modeling.py
@@ -129,8 +129,12 @@ def to(self, device: str):
         Use the specified `device` for inference. For example: "cpu" or "gpu". `device` can be in upper or lower case.
         To speed up first inference, call `.compile()` after `.to()`.
         """
-        self._device = str(device).upper()
-        self.request = None
+        if isinstance(device, str):
+            self._device = device.upper()
+            self.request = None
+        else:
+            logger.warning(f"device must be of type {str} but got {type(device)} instead")
+
         return self
 
     def forward(self, *args, **kwargs):
diff --git a/optimum/intel/openvino/modeling_diffusion.py b/optimum/intel/openvino/modeling_diffusion.py
index 154580812b..41f06936f8 100644
--- a/optimum/intel/openvino/modeling_diffusion.py
+++ b/optimum/intel/openvino/modeling_diffusion.py
@@ -88,7 +88,7 @@ def __init__(
         **kwargs,
     ):
         self._internal_dict = config
-        self._device = str(device).upper()
+        self._device = device.upper()
         self.is_dynamic = dynamic_shapes
         self.ov_config = ov_config if ov_config is not None else {}
         self._model_save_dir = (
@@ -330,8 +330,12 @@ def _from_transformers(
         )
 
     def to(self, device: str):
-        self._device = device.upper()
-        self.clear_requests()
+        if isinstance(device, str):
+            self._device = device.upper()
+            self.clear_requests()
+        else:
+            logger.warning(f"device must be of type {str} but got {type(device)} instead")
+
         return self
 
     @property
diff --git a/optimum/intel/openvino/modeling_seq2seq.py b/optimum/intel/openvino/modeling_seq2seq.py
index 089399a252..7e9f582799 100644
--- a/optimum/intel/openvino/modeling_seq2seq.py
+++ b/optimum/intel/openvino/modeling_seq2seq.py
@@ -282,12 +282,16 @@ def __init__(
         pass
 
     def to(self, device: str):
-        self._device = str(device).upper()
-        self.encoder._device = self._device
-        self.decoder._device = self._device
-        if self.use_cache:
-            self.decoder_with_past._device = self._device
-        self.clear_requests()
+        if isinstance(device, str):
+            self._device = device.upper()
+            self.encoder._device = self._device
+            self.decoder._device = self._device
+            if self.use_cache:
+                self.decoder_with_past._device = self._device
+            self.clear_requests()
+        else:
+            logger.warning(f"device must be of type {str} but got {type(device)} instead")
+
         return self
 
     @add_start_docstrings_to_model_forward(