Skip to content

Commit

Permalink
Fix ov device (#530)
Browse files Browse the repository at this point in the history
  • Loading branch information
echarlaix authored Jan 25, 2024
1 parent 5e9c1b7 commit d96ebfa
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 11 deletions.
8 changes: 6 additions & 2 deletions optimum/intel/openvino/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,12 @@ def to(self, device: str):
Use the specified `device` for inference. For example: "cpu" or "gpu". `device` can
be in upper or lower case. To speed up first inference, call `.compile()` after `.to()`.
"""
self._device = str(device).upper()
self.request = None
if isinstance(device, str):
self._device = device.upper()
self.request = None
else:
logger.warning(f"device must be of type {str} but got {type(device)} instead")

return self

def forward(self, *args, **kwargs):
Expand Down
10 changes: 7 additions & 3 deletions optimum/intel/openvino/modeling_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def __init__(
**kwargs,
):
self._internal_dict = config
self._device = str(device).upper()
self._device = device.upper()
self.is_dynamic = dynamic_shapes
self.ov_config = ov_config if ov_config is not None else {}
self._model_save_dir = (
Expand Down Expand Up @@ -330,8 +330,12 @@ def _from_transformers(
)

def to(self, device: str):
self._device = device.upper()
self.clear_requests()
if isinstance(device, str):
self._device = device.upper()
self.clear_requests()
else:
logger.warning(f"device must be of type {str} but got {type(device)} instead")

return self

@property
Expand Down
16 changes: 10 additions & 6 deletions optimum/intel/openvino/modeling_seq2seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,12 +282,16 @@ def __init__(
pass

def to(self, device: str):
self._device = str(device).upper()
self.encoder._device = self._device
self.decoder._device = self._device
if self.use_cache:
self.decoder_with_past._device = self._device
self.clear_requests()
if isinstance(device, str):
self._device = device.upper()
self.encoder._device = self._device
self.decoder._device = self._device
if self.use_cache:
self.decoder_with_past._device = self._device
self.clear_requests()
else:
logger.warning(f"device must be of type {str} but got {type(device)} instead")

return self

@add_start_docstrings_to_model_forward(
Expand Down

0 comments on commit d96ebfa

Please sign in to comment.