From 4518691596c4a6c9a9994facbb5b1eb2162388d2 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 23 Oct 2024 12:27:45 +0200 Subject: [PATCH] flux transformer support, unet export fixes, updated callback test, updated negative prompt test, flux and sd3 tests --- optimum/exporters/onnx/model_configs.py | 178 ++++++++++++++----- optimum/exporters/tasks.py | 25 +-- optimum/exporters/utils.py | 107 +++++++---- optimum/onnxruntime/modeling_diffusion.py | 66 ++++--- optimum/utils/input_generators.py | 4 +- tests/exporters/exporters_utils.py | 1 + tests/onnxruntime/test_diffusion.py | 141 +++++++-------- tests/onnxruntime/utils_onnxruntime_tests.py | 3 +- 8 files changed, 331 insertions(+), 194 deletions(-) diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index 28f894fc22c..afe87c772da 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -53,6 +53,7 @@ NormalizedTextConfig, NormalizedTextConfigWithGQA, NormalizedVisionConfig, + check_if_diffusers_greater, check_if_transformers_greater, is_diffusers_available, logging, @@ -1031,7 +1032,7 @@ def patch_model_for_export( class UNetOnnxConfig(VisionOnnxConfig): - ATOL_FOR_VALIDATION = 1e-3 + ATOL_FOR_VALIDATION = 1e-4 # The ONNX export of a CLIPText architecture, an other Stable Diffusion component, needs the Trilu # operator support, available since opset 14 DEFAULT_ONNX_OPSET = 14 @@ -1054,17 +1055,19 @@ class UNetOnnxConfig(VisionOnnxConfig): def inputs(self) -> Dict[str, Dict[int, str]]: common_inputs = { "sample": {0: "batch_size", 2: "height", 3: "width"}, - "timestep": {0: "steps"}, + "timestep": {}, # a scalar with no dimension "encoder_hidden_states": {0: "batch_size", 1: "sequence_length"}, } - # TODO : add text_image, image and image_embeds + # TODO : add addition_embed_type == text_image, image and image_embeds + # https://github.com/huggingface/diffusers/blob/9366c8f84bfe47099ff047272661786ebb54721d/src/diffusers/models/unets/unet_2d_condition.py#L671 if getattr(self._normalized_config, "addition_embed_type", None) == "text_time": common_inputs["text_embeds"] = {0: "batch_size"} common_inputs["time_ids"] = {0: "batch_size"} if getattr(self._normalized_config, "time_cond_proj_dim", None) is not None: common_inputs["timestep_cond"] = {0: "batch_size"} + return common_inputs @property @@ -1151,73 +1154,168 @@ def outputs(self) -> Dict[str, Dict[int, str]]: } -class PooledProjectionsDummyInputGenerator(DummyInputGenerator): - SUPPORTED_INPUT_NAMES = "pooled_projections" - - def __init__( - self, - task: str, - normalized_config: NormalizedConfig, - batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"], - **kwargs, - ): - self.task = task - self.batch_size = batch_size - self.pooled_projection_dim = normalized_config.config.pooled_projection_dim +class T5EncoderOnnxConfig(CLIPTextOnnxConfig): + @property + def inputs(self): + return { + "input_ids": {0: "batch_size", 1: "sequence_length"}, + } - def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): - return self.random_float_tensor( - [self.batch_size, self.pooled_projection_dim], framework=framework, dtype=float_dtype - ) + @property + def outputs(self): + return { + "last_hidden_state": {0: "batch_size", 1: "sequence_length"}, + } class DummyTransformerTimestpsInputGenerator(DummyTimestepInputGenerator): + SUPPORTED_INPUT_NAMES = ("timestep",) + def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): if input_name == "timestep": - shape = [self.batch_size] + shape = [self.batch_size] # With transformer diffusers, timestep is a 1D tensor return self.random_float_tensor(shape, max_value=self.vocab_size, framework=framework, dtype=float_dtype) + + return super().generate(input_name, framework, int_dtype, float_dtype) + + +class DummyTransformerVisionInputGenerator(DummyVisionInputGenerator): + SUPPORTED_INPUT_NAMES = ("hidden_states",) + + +class DummyTransformerTextInputGenerator(DummySeq2SeqDecoderTextInputGenerator): + SUPPORTED_INPUT_NAMES = ( + "encoder_hidden_states", + "pooled_projection", + ) + + def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): + if input_name == "encoder_hidden_states": + return super().generate(input_name, framework, int_dtype, float_dtype)[0] + + elif input_name == "pooled_projections": + return self.random_float_tensor( + [self.batch_size, self.normalized_config.projection_size], framework=framework, dtype=float_dtype + ) + return super().generate(input_name, framework, int_dtype, float_dtype) -class SD3TransformerOnnxConfig(UNetOnnxConfig): +class SD3TransformerOnnxConfig(VisionOnnxConfig): + ATOL_FOR_VALIDATION = 1e-4 + # The ONNX export of a CLIPText architecture, an other Stable Diffusion component, needs the Trilu + # operator support, available since opset 14 + DEFAULT_ONNX_OPSET = 14 + DUMMY_INPUT_GENERATOR_CLASSES = ( - (DummyTransformerTimestpsInputGenerator,) - + UNetOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES - + (PooledProjectionsDummyInputGenerator,) + DummyTransformerTimestpsInputGenerator, + DummyTransformerVisionInputGenerator, + DummyTransformerTextInputGenerator, ) + NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args( image_size="sample_size", num_channels="in_channels", - hidden_size="joint_attention_dim", vocab_size="attention_head_dim", + hidden_size="joint_attention_dim", + projection_size="pooled_projection_dim", allow_new=True, ) @property - def inputs(self): - common_inputs = super().inputs - common_inputs["pooled_projections"] = {0: "batch_size"} - return common_inputs + def inputs(self) -> Dict[str, Dict[int, str]]: + common_inputs = { + "hidden_states": {0: "batch_size", 2: "height", 3: "width"}, + "encoder_hidden_states": {0: "batch_size", 1: "sequence_length"}, + "pooled_projections": {0: "batch_size"}, + "timestep": {0: "step"}, + } - def rename_ambiguous_inputs(self, inputs): - # The input name in the model signature is `x, hence the export input name is updated. - hidden_states = inputs.pop("sample", None) - if hidden_states is not None: - inputs["hidden_states"] = hidden_states - return inputs + return common_inputs + @property + def outputs(self) -> Dict[str, Dict[int, str]]: + return { + "out_hidden_states": {0: "batch_size", 2: "height", 3: "width"}, + } -class T5EncoderOnnxConfig(CLIPTextOnnxConfig): @property - def inputs(self): + def torch_to_onnx_output_map(self) -> Dict[str, str]: return { - "input_ids": {0: "batch_size", 1: "sequence_length"}, + "sample": "out_hidden_states", } + +class DummyFluxTransformerVisionInputGenerator(DummyTransformerVisionInputGenerator): + SUPPORTED_INPUT_NAMES = ( + "hidden_states", + "img_ids", + ) + + def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): + if input_name == "hidden_states": + shape = [self.batch_size, (self.height // 2) * (self.width // 2), self.num_channels] + return self.random_float_tensor(shape, framework=framework, dtype=float_dtype) + elif input_name == "img_ids": + shape = ( + [(self.height // 2) * (self.width // 2), 3] + if check_if_diffusers_greater("0.31.0") + else [self.batch_size, (self.height // 2) * (self.width // 2), 3] + ) + return self.random_int_tensor(shape, max_value=1, framework=framework, dtype=int_dtype) + + return super().generate(input_name, framework, int_dtype, float_dtype) + + +class DummyFluxTransformerTextInputGenerator(DummyTransformerTextInputGenerator): + SUPPORTED_INPUT_NAMES = ( + "encoder_hidden_states", + "pooled_projections", + "txt_ids", + ) + + def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): + if input_name == "txt_ids": + shape = ( + [self.sequence_length, 3] + if check_if_diffusers_greater("0.31.0") + else [self.batch_size, self.sequence_length, 3] + ) + return self.random_int_tensor(shape, max_value=1, framework=framework, dtype=int_dtype) + + return super().generate(input_name, framework, int_dtype, float_dtype) + + +class FluxTransformerOnnxConfig(SD3TransformerOnnxConfig): + DUMMY_INPUT_GENERATOR_CLASSES = ( + DummyTransformerTimestpsInputGenerator, + DummyFluxTransformerVisionInputGenerator, + DummyFluxTransformerTextInputGenerator, + ) + + @property + def inputs(self): + common_inputs = super().inputs + + common_inputs["hidden_states"] = {0: "batch_size", 1: "packed_height_width"} + common_inputs["txt_ids"] = ( + {0: "sequence_length"} if check_if_diffusers_greater("0.31.0") else {0: "batch_size", 1: "sequence_length"} + ) + common_inputs["img_ids"] = ( + {0: "packed_height_width"} + if check_if_diffusers_greater("0.31.0") + else {0: "batch_size", 1: "packed_height_width"} + ) + + if getattr(self._normalized_config, "guidance_embeds", False): + common_inputs["guidance"] = {0: "batch_size"} + + return common_inputs + @property def outputs(self): return { - "last_hidden_state": {0: "batch_size", 1: "sequence_length"}, + "out_hidden_states": {0: "batch_size", 1: "packed_height_width"}, } diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py index 82417886f26..553dcc1a27a 100644 --- a/optimum/exporters/tasks.py +++ b/optimum/exporters/tasks.py @@ -335,7 +335,7 @@ class TasksManager: } _DIFFUSERS_SUPPORTED_MODEL_TYPE = { - "t5-encoder": supported_tasks_mapping( + "t5-encoder-model": supported_tasks_mapping( "feature-extraction", onnx="T5EncoderOnnxConfig", ), @@ -343,18 +343,22 @@ class TasksManager: "feature-extraction", onnx="CLIPTextOnnxConfig", ), - "clip-text-with-projection": supported_tasks_mapping( + "clip-text-model-with-projection": supported_tasks_mapping( "feature-extraction", onnx="CLIPTextWithProjectionOnnxConfig", ), - "unet": supported_tasks_mapping( + "flux-transformer-2d-model": supported_tasks_mapping( "semantic-segmentation", - onnx="UNetOnnxConfig", + onnx="FluxTransformerOnnxConfig", ), - "sd3-transformer": supported_tasks_mapping( + "sd3-transformer-2d-model": supported_tasks_mapping( "semantic-segmentation", onnx="SD3TransformerOnnxConfig", ), + "unet-2d-condition": supported_tasks_mapping( + "semantic-segmentation", + onnx="UNetOnnxConfig", + ), "vae-encoder": supported_tasks_mapping( "semantic-segmentation", onnx="VaeEncoderOnnxConfig", @@ -1178,12 +1182,13 @@ class TasksManager: "transformers": _SUPPORTED_MODEL_TYPE, } _UNSUPPORTED_CLI_MODEL_TYPE = { + # diffusers submodels "clip-text-model", - "clip-text-with-projection", - "sd3-transformer", - "t5-encoder", - "trocr", # supported through the vision-encoder-decoder model type - "unet", + "clip-text-model-with-projection", + "flux-transformer-model", + "sd3-transformer-2d-model", + "t5-encoder-model", + "unet-2d-condition", "vae-encoder", "vae-decoder", } diff --git a/optimum/exporters/utils.py b/optimum/exporters/utils.py index 9d37c7a996b..52ad986dd87 100644 --- a/optimum/exporters/utils.py +++ b/optimum/exporters/utils.py @@ -21,7 +21,9 @@ import torch from packaging import version +from transformers.models.clip.modeling_clip import CLIPTextModel, CLIPTextModelWithProjection from transformers.models.speecht5.modeling_speecht5 import SpeechT5HifiGan +from transformers.models.t5.modeling_t5 import T5EncoderModel from transformers.utils import is_tf_available, is_torch_available from ..utils import ( @@ -53,10 +55,20 @@ if check_if_diffusers_greater("0.30.0"): from diffusers import ( + FluxTransformer2DModel, + SD3Transformer2DModel, StableDiffusion3Img2ImgPipeline, StableDiffusion3InpaintPipeline, StableDiffusion3Pipeline, + UNet2DConditionModel, ) + else: + FluxTransformer2DModel = None + SD3Transformer2DModel = None + StableDiffusion3Img2ImgPipeline = None + StableDiffusion3InpaintPipeline = None + StableDiffusion3Pipeline = None + UNet2DConditionModel = None from diffusers.models.attention_processor import ( Attention, @@ -88,6 +100,23 @@ DECODER_MERGED_NAME = "decoder_model_merged" +def _get_diffusers_model_type(model): + if isinstance(model, CLIPTextModel): + return "clip-text-model" + elif isinstance(model, CLIPTextModelWithProjection): + return "clip-text-model-with-projection" + elif isinstance(model, T5EncoderModel): + return "t5-encoder-model" + elif isinstance(model, FluxTransformer2DModel): + return "flux-transformer-2d-model" + elif isinstance(model, SD3Transformer2DModel): + return "sd3-transformer-2d-model" + elif isinstance(model, UNet2DConditionModel): + return "unet-2d-condition" + else: + raise ValueError(f"Unknown model class: {model.__class__}") + + def _get_submodels_for_export_diffusion( pipeline: "DiffusionPipeline", ) -> Dict[str, Union["PreTrainedModel", "ModelMixin"]]: @@ -97,11 +126,13 @@ def _get_submodels_for_export_diffusion( models_for_export = {} - is_stable_diffusion_xl = isinstance( - pipeline, (StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline, StableDiffusionXLInpaintPipeline) + is_sdxl = isinstance( + pipeline, + (StableDiffusionXLImg2ImgPipeline, StableDiffusionXLInpaintPipeline, StableDiffusionXLPipeline), ) - is_stable_diffusion_3 = isinstance( - pipeline, (StableDiffusion3Pipeline, StableDiffusion3Img2ImgPipeline, StableDiffusion3InpaintPipeline) + is_sd3 = isinstance( + pipeline, + (StableDiffusion3Img2ImgPipeline, StableDiffusion3InpaintPipeline, StableDiffusion3Pipeline), ) is_torch_greater_or_equal_than_2_1 = version.parse(torch.__version__) >= version.parse("2.1.0") @@ -109,30 +140,27 @@ def _get_submodels_for_export_diffusion( # Text encoder text_encoder = getattr(pipeline, "text_encoder", None) if text_encoder is not None: - if is_stable_diffusion_xl or is_stable_diffusion_3: + if is_sdxl or is_sd3: text_encoder.config.output_hidden_states = True text_encoder.text_model.config.output_hidden_states = True - if is_stable_diffusion_3: - text_encoder.config.export_model_type = "clip-text-with-projection" - else: - text_encoder.config.export_model_type = "clip-text-model" - + text_encoder.config.export_model_type = _get_diffusers_model_type(text_encoder) models_for_export["text_encoder"] = text_encoder # Text encoder 2 text_encoder_2 = getattr(pipeline, "text_encoder_2", None) if text_encoder_2 is not None: - text_encoder_2.config.output_hidden_states = True - text_encoder_2.text_model.config.output_hidden_states = True - text_encoder_2.config.export_model_type = "clip-text-with-projection" + if is_sdxl or is_sd3: + text_encoder_2.config.output_hidden_states = True + text_encoder_2.text_model.config.output_hidden_states = True + text_encoder_2.config.export_model_type = _get_diffusers_model_type(text_encoder_2) models_for_export["text_encoder_2"] = text_encoder_2 # Text encoder 3 text_encoder_3 = getattr(pipeline, "text_encoder_3", None) if text_encoder_3 is not None: - text_encoder_3.config.export_model_type = "t5-encoder" + text_encoder_3.config.export_model_type = _get_diffusers_model_type(text_encoder_3) models_for_export["text_encoder_3"] = text_encoder_3 # U-NET @@ -147,7 +175,7 @@ def _get_submodels_for_export_diffusion( unet.config.requires_aesthetics_score = getattr(pipeline.config, "requires_aesthetics_score", False) unet.config.time_cond_proj_dim = getattr(pipeline.unet.config, "time_cond_proj_dim", None) unet.config.text_encoder_projection_dim = pipeline.text_encoder.config.projection_dim - unet.config.export_model_type = "unet" + unet.config.export_model_type = _get_diffusers_model_type(unet) models_for_export["unet"] = unet # Transformer @@ -160,7 +188,7 @@ def _get_submodels_for_export_diffusion( transformer.config.requires_aesthetics_score = getattr(pipeline.config, "requires_aesthetics_score", False) transformer.config.time_cond_proj_dim = getattr(pipeline.transformer.config, "time_cond_proj_dim", None) transformer.config.text_encoder_projection_dim = pipeline.text_encoder.config.projection_dim - transformer.config.export_model_type = "sd3-transformer" + transformer.config.export_model_type = _get_diffusers_model_type(transformer) models_for_export["transformer"] = transformer # VAE Encoder @@ -346,55 +374,54 @@ def get_diffusion_models_for_export( # Text encoder if "text_encoder" in models_for_export: + text_encoder = models_for_export["text_encoder"] text_encoder_config_constructor = TasksManager.get_exporter_config_constructor( - model=pipeline.text_encoder, exporter=exporter, library_name="diffusers", task="feature-extraction" + model=text_encoder, exporter=exporter, library_name="diffusers", task="feature-extraction" ) text_encoder_export_config = text_encoder_config_constructor( - pipeline.text_encoder.config, int_dtype=int_dtype, float_dtype=float_dtype + text_encoder.config, int_dtype=int_dtype, float_dtype=float_dtype ) models_for_export["text_encoder"] = (models_for_export["text_encoder"], text_encoder_export_config) # Text encoder 2 if "text_encoder_2" in models_for_export: + text_encoder_2 = models_for_export["text_encoder_2"] export_config_constructor = TasksManager.get_exporter_config_constructor( - model=pipeline.text_encoder_2, exporter=exporter, library_name="diffusers", task="feature-extraction" - ) - export_config = export_config_constructor( - pipeline.text_encoder_2.config, int_dtype=int_dtype, float_dtype=float_dtype + model=text_encoder_2, exporter=exporter, library_name="diffusers", task="feature-extraction" ) + export_config = export_config_constructor(text_encoder_2.config, int_dtype=int_dtype, float_dtype=float_dtype) models_for_export["text_encoder_2"] = (models_for_export["text_encoder_2"], export_config) # Text encoder 3 if "text_encoder_3" in models_for_export: + text_encoder_3 = models_for_export["text_encoder_3"] export_config_constructor = TasksManager.get_exporter_config_constructor( - model=pipeline.text_encoder_3, exporter=exporter, library_name="diffusers", task="feature-extraction" - ) - export_config = export_config_constructor( - pipeline.text_encoder_3.config, int_dtype=int_dtype, float_dtype=float_dtype + model=text_encoder_3, exporter=exporter, library_name="diffusers", task="feature-extraction" ) + export_config = export_config_constructor(text_encoder_3.config, int_dtype=int_dtype, float_dtype=float_dtype) models_for_export["text_encoder_3"] = (models_for_export["text_encoder_3"], export_config) # U-NET if "unet" in models_for_export: + unet = models_for_export["unet"] export_config_constructor = TasksManager.get_exporter_config_constructor( - model=pipeline.unet, exporter=exporter, library_name="diffusers", task="semantic-segmentation" - ) - unet_export_config = export_config_constructor( - pipeline.unet.config, int_dtype=int_dtype, float_dtype=float_dtype + model=unet, exporter=exporter, library_name="diffusers", task="semantic-segmentation" ) + unet_export_config = export_config_constructor(unet.config, int_dtype=int_dtype, float_dtype=float_dtype) models_for_export["unet"] = (models_for_export["unet"], unet_export_config) # Transformer if "transformer" in models_for_export: + transformer = models_for_export["transformer"] export_config_constructor = TasksManager.get_exporter_config_constructor( - model=pipeline.transformer, exporter=exporter, library_name="diffusers", task="semantic-segmentation" + model=transformer, exporter=exporter, library_name="diffusers", task="semantic-segmentation" ) transformer_export_config = export_config_constructor( - pipeline.transformer.config, int_dtype=int_dtype, float_dtype=float_dtype + transformer.config, int_dtype=int_dtype, float_dtype=float_dtype ) models_for_export["transformer"] = (models_for_export["transformer"], transformer_export_config) - # VAE Encoder https://github.com/huggingface/diffusers/blob/v0.11.1/src/diffusers/models/vae.py#L565 + # VAE Encoder vae_encoder = models_for_export["vae_encoder"] vae_config_constructor = TasksManager.get_exporter_config_constructor( model=vae_encoder, @@ -403,10 +430,12 @@ def get_diffusion_models_for_export( task="semantic-segmentation", model_type="vae-encoder", ) - vae_export_config = vae_config_constructor(vae_encoder.config, int_dtype=int_dtype, float_dtype=float_dtype) - models_for_export["vae_encoder"] = (vae_encoder, vae_export_config) + vae_encoder_export_config = vae_config_constructor( + vae_encoder.config, int_dtype=int_dtype, float_dtype=float_dtype + ) + models_for_export["vae_encoder"] = (vae_encoder, vae_encoder_export_config) - # VAE Decoder https://github.com/huggingface/diffusers/blob/v0.11.1/src/diffusers/models/vae.py#L600 + # VAE Decoder vae_decoder = models_for_export["vae_decoder"] vae_config_constructor = TasksManager.get_exporter_config_constructor( model=vae_decoder, @@ -415,8 +444,10 @@ def get_diffusion_models_for_export( task="semantic-segmentation", model_type="vae-decoder", ) - vae_export_config = vae_config_constructor(vae_decoder.config, int_dtype=int_dtype, float_dtype=float_dtype) - models_for_export["vae_decoder"] = (vae_decoder, vae_export_config) + vae_decoder_export_config = vae_config_constructor( + vae_decoder.config, int_dtype=int_dtype, float_dtype=float_dtype + ) + models_for_export["vae_decoder"] = (vae_decoder, vae_decoder_export_config) return models_for_export diff --git a/optimum/onnxruntime/modeling_diffusion.py b/optimum/onnxruntime/modeling_diffusion.py index 6df9816e056..c3d40f1ce1a 100644 --- a/optimum/onnxruntime/modeling_diffusion.py +++ b/optimum/onnxruntime/modeling_diffusion.py @@ -30,6 +30,7 @@ AutoPipelineForImage2Image, AutoPipelineForInpainting, AutoPipelineForText2Image, + FluxPipeline, LatentConsistencyModelImg2ImgPipeline, LatentConsistencyModelPipeline, StableDiffusion3Img2ImgPipeline, @@ -481,9 +482,13 @@ def __init__(self, session: ort.InferenceSession, parent_pipeline: ORTDiffusionP self.input_names = {input_key.name: idx for idx, input_key in enumerate(self.session.get_inputs())} self.output_names = {output_key.name: idx for idx, output_key in enumerate(self.session.get_outputs())} + self.input_dtypes = {input_key.name: input_key.type for input_key in self.session.get_inputs()} self.output_dtypes = {output_key.name: output_key.type for output_key in self.session.get_outputs()} + self.input_shapes = {input_key.name: input_key.shape for input_key in self.session.get_inputs()} + self.output_shapes = {output_key.name: output_key.shape for output_key in self.session.get_outputs()} + config_file_path = Path(session._model_path).parent / self.config_name if not config_file_path.is_file(): # config is mandatory for the model part to be used for inference @@ -581,13 +586,18 @@ def __init__(self, *args, **kwargs): ) self.register_to_config(time_cond_proj_dim=None) + if len(self.input_shapes["timestep"]) > 0: + logger.warning( + "The exported unet onnx model expects a non scalar timestep input. " + "We will have to unsqueeze the timestep input at each iteration which might be inefficient. " + "Please re-export the pipeline with newer version of optimum and diffusers to avoid this warning." + ) + def forward( self, sample: Union[np.ndarray, torch.Tensor], timestep: Union[np.ndarray, torch.Tensor], encoder_hidden_states: Union[np.ndarray, torch.Tensor], - text_embeds: Optional[Union[np.ndarray, torch.Tensor]] = None, - time_ids: Optional[Union[np.ndarray, torch.Tensor]] = None, timestep_cond: Optional[Union[np.ndarray, torch.Tensor]] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, added_cond_kwargs: Optional[Dict[str, Any]] = None, @@ -595,15 +605,13 @@ def forward( ): use_torch = isinstance(sample, torch.Tensor) - if len(timestep.shape) == 0: + if len(self.input_shapes["timestep"]) > 0: timestep = timestep.unsqueeze(0) model_inputs = { "sample": sample, "timestep": timestep, "encoder_hidden_states": encoder_hidden_states, - "text_embeds": text_embeds, - "time_ids": time_ids, "timestep_cond": timestep_cond, **(cross_attention_kwargs or {}), **(added_cond_kwargs or {}), @@ -623,22 +631,25 @@ class ORTModelTransformer(ORTPipelinePart): def forward( self, hidden_states: Union[np.ndarray, torch.Tensor], - timestep: Union[np.ndarray, torch.Tensor], encoder_hidden_states: Union[np.ndarray, torch.Tensor], pooled_projections: Union[np.ndarray, torch.Tensor], + timestep: Union[np.ndarray, torch.Tensor], + guidance: Optional[Union[np.ndarray, torch.Tensor]] = None, + txt_ids: Optional[Union[np.ndarray, torch.Tensor]] = None, + img_ids: Optional[Union[np.ndarray, torch.Tensor]] = None, joint_attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = False, ): use_torch = isinstance(hidden_states, torch.Tensor) - if len(timestep.shape) == 0: - timestep = timestep.unsqueeze(0) - model_inputs = { "hidden_states": hidden_states, - "timestep": timestep, "encoder_hidden_states": encoder_hidden_states, "pooled_projections": pooled_projections, + "timestep": timestep, + "guidance": guidance, + "txt_ids": txt_ids, + "img_ids": img_ids, **(joint_attention_kwargs or {}), } @@ -670,16 +681,12 @@ def forward( if output_hidden_states: model_outputs["hidden_states"] = [] - num_layers = ( - self.config.num_hidden_layers if hasattr(self.config, "num_hidden_layers") else self.num_decoder_layers - ) + num_layers = self.num_hidden_layers if hasattr(self, "num_hidden_layers") else self.num_decoder_layers for i in range(num_layers): model_outputs["hidden_states"].append(model_outputs.pop(f"hidden_states.{i}")) model_outputs["hidden_states"].append(model_outputs.get("last_hidden_state")) else: - num_layers = ( - self.config.num_hidden_layers if hasattr(self.config, "num_hidden_layers") else self.num_decoder_layers - ) + num_layers = self.num_hidden_layers if hasattr(self, "num_hidden_layers") else self.num_decoder_layers for i in range(num_layers): model_outputs.pop(f"hidden_states.{i}", None) @@ -697,7 +704,7 @@ def __init__(self, *args, **kwargs): if not hasattr(self.config, "scaling_factor"): logger.warning( "The `scaling_factor` attribute is missing from the VAE encoder configuration. " - "Please re-export the model with newer version of optimum and diffusers." + "Please re-export the model with newer version of optimum and diffusers to avoid this warning." ) self.register_to_config(scaling_factor=2 ** (len(self.config.block_out_channels) - 1)) @@ -737,7 +744,7 @@ def __init__(self, *args, **kwargs): if not hasattr(self.config, "scaling_factor"): logger.warning( "The `scaling_factor` attribute is missing from the VAE decoder configuration. " - "Please re-export the model with newer version of optimum and diffusers." + "Please re-export the model with newer version of optimum and diffusers to avoid this warning." ) self.register_to_config(scaling_factor=2 ** (len(self.config.block_out_channels) - 1)) @@ -981,6 +988,17 @@ class ORTStableDiffusion3InpaintPipeline(ORTDiffusionPipeline, StableDiffusion3I auto_model_class = StableDiffusion3InpaintPipeline +@add_end_docstrings(ONNX_MODEL_END_DOCSTRING) +class ORTFluxPipeline(ORTDiffusionPipeline, FluxPipeline): + """ + ONNX Runtime-powered stable diffusion pipeline corresponding to [diffusers.FluxPipeline](https://huggingface.co/docs/diffusers/api/pipelines/flux/text2img#diffusers.FluxPipeline). + """ + + main_input_name = "prompt" + export_feature = "text-to-image" + auto_model_class = FluxPipeline + + SUPPORTED_ORT_PIPELINES = [ ORTStableDiffusionPipeline, ORTStableDiffusionImg2ImgPipeline, @@ -993,6 +1011,7 @@ class ORTStableDiffusion3InpaintPipeline(ORTDiffusionPipeline, StableDiffusion3I ORTStableDiffusion3Pipeline, ORTStableDiffusion3Img2ImgPipeline, ORTStableDiffusion3InpaintPipeline, + ORTFluxPipeline, ] @@ -1010,27 +1029,28 @@ def _get_ort_class(pipeline_class_name: str, throw_error_if_not_exist: bool = Tr ORT_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict( [ - ("stable-diffusion", ORTStableDiffusionPipeline), - ("stable-diffusion-xl", ORTStableDiffusionXLPipeline), + ("flux", ORTFluxPipeline), ("latent-consistency", ORTLatentConsistencyModelPipeline), + ("stable-diffusion", ORTStableDiffusionPipeline), ("stable-diffusion-3", ORTStableDiffusion3Pipeline), + ("stable-diffusion-xl", ORTStableDiffusionXLPipeline), ] ) ORT_IMAGE2IMAGE_PIPELINES_MAPPING = OrderedDict( [ - ("stable-diffusion", ORTStableDiffusionImg2ImgPipeline), - ("stable-diffusion-xl", ORTStableDiffusionXLImg2ImgPipeline), ("latent-consistency", ORTLatentConsistencyModelImg2ImgPipeline), + ("stable-diffusion", ORTStableDiffusionImg2ImgPipeline), ("stable-diffusion-3", ORTStableDiffusion3Img2ImgPipeline), + ("stable-diffusion-xl", ORTStableDiffusionXLImg2ImgPipeline), ] ) ORT_INPAINT_PIPELINES_MAPPING = OrderedDict( [ ("stable-diffusion", ORTStableDiffusionInpaintPipeline), - ("stable-diffusion-xl", ORTStableDiffusionXLInpaintPipeline), ("stable-diffusion-3", ORTStableDiffusion3InpaintPipeline), + ("stable-diffusion-xl", ORTStableDiffusionXLInpaintPipeline), ] ) diff --git a/optimum/utils/input_generators.py b/optimum/utils/input_generators.py index dac14a38114..dd29821766a 100644 --- a/optimum/utils/input_generators.py +++ b/optimum/utils/input_generators.py @@ -871,8 +871,8 @@ def __init__( def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): if input_name == "timestep": - shape = [self.batch_size] - return self.random_int_tensor(shape, max_value=self.vocab_size, framework=framework, dtype=int_dtype) + shape = [] # a scalar with no dimension (it can be int or float depending on the sd architecture) + return self.random_float_tensor(shape, max_value=self.vocab_size, framework=framework, dtype=float_dtype) if input_name == "text_embeds": dim = self.text_encoder_projection_dim diff --git a/tests/exporters/exporters_utils.py b/tests/exporters/exporters_utils.py index 9617ab37a00..e7fe1035dc1 100644 --- a/tests/exporters/exporters_utils.py +++ b/tests/exporters/exporters_utils.py @@ -296,6 +296,7 @@ } PYTORCH_DIFFUSION_MODEL = { + "flux": "optimum-internal-testing/tiny-random-flux", "latent-consistency": "echarlaix/tiny-random-latent-consistency", "stable-diffusion": "hf-internal-testing/tiny-stable-diffusion-torch", "stable-diffusion-3": "yujiepan/stable-diffusion-3-tiny-random", diff --git a/tests/onnxruntime/test_diffusion.py b/tests/onnxruntime/test_diffusion.py index 38e9b2b5391..5d70bb2de1d 100644 --- a/tests/onnxruntime/test_diffusion.py +++ b/tests/onnxruntime/test_diffusion.py @@ -71,8 +71,25 @@ def _generate_images(height=128, width=128, batch_size=1, channel=3, input_type= class ORTPipelineForText2ImageTest(ORTModelTestMixin): - SUPPORTED_ARCHITECTURES = ["stable-diffusion", "stable-diffusion-xl", "latent-consistency", "stable-diffusion-3"] - CALLBACK_SUPPORTED_ARCHITECTURES = ["stable-diffusion", "stable-diffusion-xl", "latent-consistency"] + SUPPORTED_ARCHITECTURES = [ + "stable-diffusion", + "stable-diffusion-xl", + "latent-consistency", + "stable-diffusion-3", + "flux", + ] + NEGATIVE_PROMPT_SUPPORTED_ARCHITECTURES = [ + "stable-diffusion", + "stable-diffusion-xl", + "latent-consistency", + "stable-diffusion-3", + ] + CALLBACK_SUPPORTED_ARCHITECTURES = [ + "stable-diffusion", + "stable-diffusion-xl", + "latent-consistency", + "flux", + ] ORTMODEL_CLASS = ORTPipelineForText2Image AUTOMODEL_CLASS = AutoPipelineForText2Image @@ -143,10 +160,10 @@ def test_compare_to_diffusers_pipeline(self, model_arch: str): for output_type in ["latent", "np", "pt"]: inputs["output_type"] = output_type - ort_output = ort_pipeline(**inputs, generator=get_generator("pt", SEED)).images - diffusers_output = diffusers_pipeline(**inputs, generator=get_generator("pt", SEED)).images + ort_images = ort_pipeline(**inputs, generator=get_generator("pt", SEED)).images + diffusers_images = diffusers_pipeline(**inputs, generator=get_generator("pt", SEED)).images - np.testing.assert_allclose(ort_output, diffusers_output, atol=1e-4, rtol=1e-2) + np.testing.assert_allclose(ort_images, diffusers_images, atol=1e-4, rtol=1e-2) @parameterized.expand(CALLBACK_SUPPORTED_ARCHITECTURES) @require_diffusers @@ -165,6 +182,7 @@ def __init__(self): def __call__(self, *args, **kwargs) -> None: self.has_been_called = True self.number_of_steps += 1 + return kwargs ort_callback = Callback() auto_callback = Callback() @@ -172,9 +190,8 @@ def __call__(self, *args, **kwargs) -> None: ort_pipe = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch]) auto_pipe = self.AUTOMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch]) - # callback_steps=1 to trigger callback every step - ort_pipe(**inputs, callback=ort_callback, callback_steps=1) - auto_pipe(**inputs, callback=auto_callback, callback_steps=1) + ort_pipe(**inputs, callback_on_step_end=ort_callback) + auto_pipe(**inputs, callback_on_step_end=auto_callback) self.assertTrue(ort_callback.has_been_called) self.assertTrue(auto_callback.has_been_called) @@ -201,20 +218,20 @@ def test_shape(self, model_arch: str): elif output_type == "pt": self.assertEqual(outputs.shape, (batch_size, 3, height, width)) else: - out_channels = ( - pipeline.unet.config.out_channels - if pipeline.unet is not None - else pipeline.transformer.config.out_channels - ) - self.assertEqual( - outputs.shape, - ( - batch_size, - out_channels, - height // pipeline.vae_scale_factor, - width // pipeline.vae_scale_factor, - ), - ) + expected_height = height // pipeline.vae_scale_factor + expected_width = width // pipeline.vae_scale_factor + + if model_arch == "flux": + channels = pipeline.transformer.config.in_channels + expected_shape = (batch_size, expected_height * expected_width, channels) + elif model_arch == "stable-diffusion-3": + out_channels = pipeline.transformer.config.out_channels + expected_shape = (batch_size, out_channels, expected_height, expected_width) + else: + out_channels = pipeline.unet.config.out_channels + expected_shape = (batch_size, out_channels, expected_height, expected_width) + + self.assertEqual(outputs.shape, expected_shape) @parameterized.expand(SUPPORTED_ARCHITECTURES) @require_diffusers @@ -235,60 +252,22 @@ def test_image_reproducibility(self, model_arch: str): self.assertFalse(np.array_equal(ort_outputs_1.images[0], ort_outputs_3.images[0])) np.testing.assert_allclose(ort_outputs_1.images[0], ort_outputs_2.images[0], atol=1e-4, rtol=1e-2) - @parameterized.expand(SUPPORTED_ARCHITECTURES) + @parameterized.expand(NEGATIVE_PROMPT_SUPPORTED_ARCHITECTURES) def test_negative_prompt(self, model_arch: str): model_args = {"test_name": model_arch, "model_arch": model_arch} self._setup(model_args) height, width, batch_size = 64, 64, 1 inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size) + inputs["negative_prompt"] = ["This is a negative prompt"] * batch_size - negative_prompt = ["This is a negative prompt"] - pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch]) + ort_pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch]) + diffusers_pipeline = self.AUTOMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch]) + + ort_images = ort_pipeline(**inputs, generator=get_generator("pt", SEED)).images + diffusers_images = diffusers_pipeline(**inputs, generator=get_generator("pt", SEED)).images - images_1 = pipeline(**inputs, negative_prompt=negative_prompt, generator=get_generator("pt", SEED)).images - prompt = inputs.pop("prompt") - - if model_arch == "stable-diffusion-xl": - ( - inputs["prompt_embeds"], - inputs["negative_prompt_embeds"], - inputs["pooled_prompt_embeds"], - inputs["negative_pooled_prompt_embeds"], - ) = pipeline.encode_prompt( - prompt=prompt, - num_images_per_prompt=1, - device=torch.device("cpu"), - do_classifier_free_guidance=True, - negative_prompt=negative_prompt, - ) - elif model_arch == "stable-diffusion-3": - ( - inputs["prompt_embeds"], - inputs["negative_prompt_embeds"], - inputs["pooled_prompt_embeds"], - inputs["negative_pooled_prompt_embeds"], - ) = pipeline.encode_prompt( - prompt=prompt, - prompt_2=prompt, - prompt_3=prompt, - num_images_per_prompt=1, - device=torch.device("cpu"), - do_classifier_free_guidance=True, - negative_prompt=negative_prompt, - ) - else: - inputs["prompt_embeds"], inputs["negative_prompt_embeds"] = pipeline.encode_prompt( - prompt=prompt, - num_images_per_prompt=1, - device=torch.device("cpu"), - do_classifier_free_guidance=True, - negative_prompt=negative_prompt, - ) - - images_2 = pipeline(**inputs, generator=get_generator("pt", SEED)).images - - np.testing.assert_allclose(images_1, images_2, atol=1e-4, rtol=1e-2) + np.testing.assert_allclose(ort_images, diffusers_images, atol=1e-4, rtol=1e-2) @parameterized.expand( grid_parameters( @@ -425,15 +404,16 @@ def __init__(self): def __call__(self, *args, **kwargs) -> None: self.has_been_called = True self.number_of_steps += 1 + return kwargs ort_pipe = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch]) auto_pipe = self.AUTOMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch]) ort_callback = Callback() auto_callback = Callback() - # callback_steps=1 to trigger callback every step - ort_pipe(**inputs, callback=ort_callback, callback_steps=1) - auto_pipe(**inputs, callback=auto_callback, callback_steps=1) + + ort_pipe(**inputs, callback_on_step_end=ort_callback) + auto_pipe(**inputs, callback_on_step_end=auto_callback) self.assertTrue(ort_callback.has_been_called) self.assertEqual(ort_callback.number_of_steps, auto_callback.number_of_steps) @@ -491,10 +471,10 @@ def test_compare_to_diffusers_pipeline(self, model_arch: str): for output_type in ["latent", "np", "pt"]: inputs["output_type"] = output_type - ort_output = ort_pipeline(**inputs, generator=get_generator("pt", SEED)).images - diffusers_output = diffusers_pipeline(**inputs, generator=get_generator("pt", SEED)).images + ort_images = ort_pipeline(**inputs, generator=get_generator("pt", SEED)).images + diffusers_images = diffusers_pipeline(**inputs, generator=get_generator("pt", SEED)).images - np.testing.assert_allclose(ort_output, diffusers_output, atol=1e-4, rtol=1e-2) + np.testing.assert_allclose(ort_images, diffusers_images, atol=1e-4, rtol=1e-2) @parameterized.expand(SUPPORTED_ARCHITECTURES) @require_diffusers @@ -656,15 +636,16 @@ def __init__(self): def __call__(self, *args, **kwargs) -> None: self.has_been_called = True self.number_of_steps += 1 + return kwargs ort_pipe = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch]) auto_pipe = self.AUTOMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch]) ort_callback = Callback() auto_callback = Callback() - # callback_steps=1 to trigger callback every step - ort_pipe(**inputs, callback=ort_callback, callback_steps=1) - auto_pipe(**inputs, callback=auto_callback, callback_steps=1) + + ort_pipe(**inputs, callback_on_step_end=ort_callback) + auto_pipe(**inputs, callback_on_step_end=auto_callback) self.assertTrue(ort_callback.has_been_called) self.assertEqual(ort_callback.number_of_steps, auto_callback.number_of_steps) @@ -722,10 +703,10 @@ def test_compare_to_diffusers_pipeline(self, model_arch: str): for output_type in ["latent", "np", "pt"]: inputs["output_type"] = output_type - ort_output = ort_pipeline(**inputs, generator=get_generator("pt", SEED)).images - diffusers_output = diffusers_pipeline(**inputs, generator=get_generator("pt", SEED)).images + ort_images = ort_pipeline(**inputs, generator=get_generator("pt", SEED)).images + diffusers_images = diffusers_pipeline(**inputs, generator=get_generator("pt", SEED)).images - np.testing.assert_allclose(ort_output, diffusers_output, atol=1e-4, rtol=1e-2) + np.testing.assert_allclose(ort_images, diffusers_images, atol=1e-4, rtol=1e-2) @parameterized.expand(SUPPORTED_ARCHITECTURES) @require_diffusers diff --git a/tests/onnxruntime/utils_onnxruntime_tests.py b/tests/onnxruntime/utils_onnxruntime_tests.py index cb224993ad6..b6f5efaef76 100644 --- a/tests/onnxruntime/utils_onnxruntime_tests.py +++ b/tests/onnxruntime/utils_onnxruntime_tests.py @@ -98,6 +98,7 @@ }, "falcon": "fxmarty/really-tiny-falcon-testing", "flaubert": "hf-internal-testing/tiny-random-flaubert", + "flux": "optimum-internal-testing/tiny-random-flux", "gemma": "fxmarty/tiny-random-GemmaForCausalLM", "gpt2": "hf-internal-testing/tiny-random-gpt2", "gpt_bigcode": "hf-internal-testing/tiny-random-GPTBigCodeModel", @@ -141,7 +142,7 @@ "squeezebert": "hf-internal-testing/tiny-random-SqueezeBertModel", "speech_to_text": "hf-internal-testing/tiny-random-Speech2TextModel", "stable-diffusion": "hf-internal-testing/tiny-stable-diffusion-torch", - "stable-diffusion-3": "yujiepan/stable-diffusion-3-tiny-random", + "stable-diffusion-3": "optimum-internal-testing/tiny-random-stable-diffusion-3", "stable-diffusion-xl": "echarlaix/tiny-random-stable-diffusion-xl", "swin": "hf-internal-testing/tiny-random-SwinModel", "swin-window": "yujiepan/tiny-random-swin-patch4-window7-224",