diff --git a/optimum/exporters/onnx/utils.py b/optimum/exporters/onnx/utils.py index 3bd6a1f7676..d927fee4cea 100644 --- a/optimum/exporters/onnx/utils.py +++ b/optimum/exporters/onnx/utils.py @@ -131,8 +131,10 @@ def _get_submodels_for_export_stable_diffusion( models_for_export["text_encoder"] = pipeline.text_encoder # U-NET - # PyTorch does not support the ONNX export of torch.nn.functional.scaled_dot_product_attention - pipeline.unet.set_attn_processor(AttnProcessor()) + # ONNX export of torch.nn.functional.scaled_dot_product_attention not supported for < v2.1.0 + is_torch_greater_or_equal_than_2_1 = version.parse(torch.__version__) >= version.parse("2.1.0") + if not is_torch_greater_or_equal_than_2_1: + pipeline.unet.set_attn_processor(AttnProcessor()) pipeline.unet.config.text_encoder_projection_dim = projection_dim # The U-NET time_ids inputs shapes depends on the value of `requires_aesthetics_score` # https://github.com/huggingface/diffusers/blob/v0.18.2/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py#L571 @@ -141,14 +143,14 @@ def _get_submodels_for_export_stable_diffusion( # VAE Encoder https://github.com/huggingface/diffusers/blob/v0.11.1/src/diffusers/models/vae.py#L565 vae_encoder = copy.deepcopy(pipeline.vae) - if not version.parse(torch.__version__) >= version.parse("2.1.0"): + if not is_torch_greater_or_equal_than_2_1: vae_encoder = override_diffusers_2_0_attn_processors(vae_encoder) vae_encoder.forward = lambda sample: {"latent_sample": vae_encoder.encode(x=sample)["latent_dist"].sample()} models_for_export["vae_encoder"] = vae_encoder # VAE Decoder https://github.com/huggingface/diffusers/blob/v0.11.1/src/diffusers/models/vae.py#L600 vae_decoder = copy.deepcopy(pipeline.vae) - if not version.parse(torch.__version__) >= version.parse("2.1.0"): + if not is_torch_greater_or_equal_than_2_1: vae_decoder = override_diffusers_2_0_attn_processors(vae_decoder) vae_decoder.forward = lambda latent_sample: vae_decoder.decode(z=latent_sample) models_for_export["vae_decoder"] = vae_decoder