diff --git a/optimum/onnxruntime/modeling_diffusion.py b/optimum/onnxruntime/modeling_diffusion.py index f4e54752115..4bbfb2eda2a 100644 --- a/optimum/onnxruntime/modeling_diffusion.py +++ b/optimum/onnxruntime/modeling_diffusion.py @@ -452,10 +452,14 @@ def to(self, device: Union[torch.device, str, int]): Returns: `ORTModel`: the model placed on the requested device. """ + device, provider_options = parse_device(device) provider = get_provider_for_device(device) validate_provider_availability(provider) # raise error if the provider is not available - self.device = device + + if device.type == "cuda" and self.providers[0] == "TensorrtExecutionProvider": + return self + self.vae_decoder.session.set_providers([provider], provider_options=[provider_options]) self.text_encoder.session.set_providers([provider], provider_options=[provider_options]) self.unet.session.set_providers([provider], provider_options=[provider_options]) @@ -464,6 +468,8 @@ def to(self, device: Union[torch.device, str, int]): self.vae_encoder.session.set_providers([provider], provider_options=[provider_options]) self.providers = self.vae_decoder.session.get_providers() + self._device = device + return self @classmethod diff --git a/optimum/onnxruntime/modeling_seq2seq.py b/optimum/onnxruntime/modeling_seq2seq.py index 3b1af05d0f5..09489eb7527 100644 --- a/optimum/onnxruntime/modeling_seq2seq.py +++ b/optimum/onnxruntime/modeling_seq2seq.py @@ -1124,12 +1124,13 @@ def to(self, device: Union[torch.device, str, int]): provider = get_provider_for_device(device) validate_provider_availability(provider) # raise error if the provider is not available - self.device = device self.encoder.session.set_providers([provider], provider_options=[provider_options]) self.decoder.session.set_providers([provider], provider_options=[provider_options]) if self.decoder_with_past is not None: self.decoder_with_past.session.set_providers([provider], provider_options=[provider_options]) + self.providers = self.encoder.session.get_providers() + self._device = device return self