From 07d9952cfbe3250a6d87c7817d2f7867155c4e74 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Mon, 21 Oct 2024 16:17:37 +0200 Subject: [PATCH 1/7] sd3 support --- optimum/exporters/onnx/base.py | 1 + optimum/exporters/onnx/convert.py | 4 + optimum/exporters/onnx/model_configs.py | 81 ++++++++-- optimum/exporters/tasks.py | 8 + optimum/exporters/utils.py | 161 +++++++++++++------ optimum/onnxruntime/modeling_diffusion.py | 138 ++++++++++++++-- optimum/utils/__init__.py | 2 + optimum/utils/constant.py | 4 +- tests/onnxruntime/test_diffusion.py | 66 ++++++-- tests/onnxruntime/utils_onnxruntime_tests.py | 1 + 10 files changed, 385 insertions(+), 81 deletions(-) diff --git a/optimum/exporters/onnx/base.py b/optimum/exporters/onnx/base.py index 8cd94194ffe..7e35691d54b 100644 --- a/optimum/exporters/onnx/base.py +++ b/optimum/exporters/onnx/base.py @@ -319,6 +319,7 @@ def fix_dynamic_axes( input_shapes = {} dummy_inputs = self.generate_dummy_inputs(framework="np", **input_shapes) dummy_inputs = self.generate_dummy_inputs_for_validation(dummy_inputs, onnx_input_names=onnx_input_names) + dummy_inputs = self.rename_ambiguous_inputs(dummy_inputs) onnx_inputs = {} for name, value in dummy_inputs.items(): diff --git a/optimum/exporters/onnx/convert.py b/optimum/exporters/onnx/convert.py index 2661d835979..c12a9ac222a 100644 --- a/optimum/exporters/onnx/convert.py +++ b/optimum/exporters/onnx/convert.py @@ -1183,6 +1183,10 @@ def onnx_export_from_model( if tokenizer_2 is not None: tokenizer_2.save_pretrained(output.joinpath("tokenizer_2")) + tokenizer_3 = getattr(model, "tokenizer_3", None) + if tokenizer_3 is not None: + tokenizer_3.save_pretrained(output.joinpath("tokenizer_3")) + model.save_config(output) if float_dtype == "bf16": diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index e77f649f69b..28f894fc22c 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -1015,22 +1015,13 @@ def outputs(self) -> Dict[str, Dict[int, str]]: "last_hidden_state": {0: "batch_size", 1: "sequence_length"}, "pooler_output": {0: "batch_size"}, } + if self._normalized_config.output_hidden_states: for i in range(self._normalized_config.num_layers + 1): common_outputs[f"hidden_states.{i}"] = {0: "batch_size", 1: "sequence_length"} return common_outputs - def generate_dummy_inputs(self, framework: str = "pt", **kwargs): - dummy_inputs = super().generate_dummy_inputs(framework=framework, **kwargs) - - # TODO: fix should be by casting inputs during inference and not export - if framework == "pt": - import torch - - dummy_inputs["input_ids"] = dummy_inputs["input_ids"].to(dtype=torch.int32) - return dummy_inputs - def patch_model_for_export( self, model: Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"], @@ -1160,6 +1151,76 @@ def outputs(self) -> Dict[str, Dict[int, str]]: } +class PooledProjectionsDummyInputGenerator(DummyInputGenerator): + SUPPORTED_INPUT_NAMES = "pooled_projections" + + def __init__( + self, + task: str, + normalized_config: NormalizedConfig, + batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"], + **kwargs, + ): + self.task = task + self.batch_size = batch_size + self.pooled_projection_dim = normalized_config.config.pooled_projection_dim + + def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): + return self.random_float_tensor( + [self.batch_size, self.pooled_projection_dim], framework=framework, dtype=float_dtype + ) + + +class 
DummyTransformerTimestpsInputGenerator(DummyTimestepInputGenerator): + def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): + if input_name == "timestep": + shape = [self.batch_size] + return self.random_float_tensor(shape, max_value=self.vocab_size, framework=framework, dtype=float_dtype) + return super().generate(input_name, framework, int_dtype, float_dtype) + + +class SD3TransformerOnnxConfig(UNetOnnxConfig): + DUMMY_INPUT_GENERATOR_CLASSES = ( + (DummyTransformerTimestpsInputGenerator,) + + UNetOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES + + (PooledProjectionsDummyInputGenerator,) + ) + NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args( + image_size="sample_size", + num_channels="in_channels", + hidden_size="joint_attention_dim", + vocab_size="attention_head_dim", + allow_new=True, + ) + + @property + def inputs(self): + common_inputs = super().inputs + common_inputs["pooled_projections"] = {0: "batch_size"} + return common_inputs + + def rename_ambiguous_inputs(self, inputs): + # The input name in the model signature is `x, hence the export input name is updated. + hidden_states = inputs.pop("sample", None) + if hidden_states is not None: + inputs["hidden_states"] = hidden_states + return inputs + + +class T5EncoderOnnxConfig(CLIPTextOnnxConfig): + @property + def inputs(self): + return { + "input_ids": {0: "batch_size", 1: "sequence_length"}, + } + + @property + def outputs(self): + return { + "last_hidden_state": {0: "batch_size", 1: "sequence_length"}, + } + + class GroupViTOnnxConfig(CLIPOnnxConfig): pass diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py index a489f34fb06..87ab62b2f29 100644 --- a/optimum/exporters/tasks.py +++ b/optimum/exporters/tasks.py @@ -335,6 +335,10 @@ class TasksManager: } _DIFFUSERS_SUPPORTED_MODEL_TYPE = { + "t5-encoder": supported_tasks_mapping( + "feature-extraction", + onnx="T5EncoderOnnxConfig", + ), "clip-text-model": supported_tasks_mapping( "feature-extraction", onnx="CLIPTextOnnxConfig", @@ -347,6 +351,10 @@ class TasksManager: "semantic-segmentation", onnx="UNetOnnxConfig", ), + "sd3-transformer": supported_tasks_mapping( + "semantic-segmentation", + onnx="SD3TransformerOnnxConfig", + ), "vae-encoder": supported_tasks_mapping( "semantic-segmentation", onnx="VaeEncoderOnnxConfig", diff --git a/optimum/exporters/utils.py b/optimum/exporters/utils.py index 949b54f4685..9d37c7a996b 100644 --- a/optimum/exporters/utils.py +++ b/optimum/exporters/utils.py @@ -50,6 +50,14 @@ StableDiffusionXLInpaintPipeline, StableDiffusionXLPipeline, ) + + if check_if_diffusers_greater("0.30.0"): + from diffusers import ( + StableDiffusion3Img2ImgPipeline, + StableDiffusion3InpaintPipeline, + StableDiffusion3Pipeline, + ) + from diffusers.models.attention_processor import ( Attention, AttnAddedKVProcessor, @@ -87,56 +95,95 @@ def _get_submodels_for_export_diffusion( Returns the components of a Stable Diffusion model. 
""" + models_for_export = {} + is_stable_diffusion_xl = isinstance( pipeline, (StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline, StableDiffusionXLInpaintPipeline) ) - if is_stable_diffusion_xl: - projection_dim = pipeline.text_encoder_2.config.projection_dim - else: - projection_dim = pipeline.text_encoder.config.projection_dim + is_stable_diffusion_3 = isinstance( + pipeline, (StableDiffusion3Pipeline, StableDiffusion3Img2ImgPipeline, StableDiffusion3InpaintPipeline) + ) - models_for_export = {} + is_torch_greater_or_equal_than_2_1 = version.parse(torch.__version__) >= version.parse("2.1.0") # Text encoder text_encoder = getattr(pipeline, "text_encoder", None) if text_encoder is not None: - if is_stable_diffusion_xl: + if is_stable_diffusion_xl or is_stable_diffusion_3: text_encoder.config.output_hidden_states = True + text_encoder.text_model.config.output_hidden_states = True + + if is_stable_diffusion_3: + text_encoder.config.export_model_type = "clip-text-with-projection" + else: + text_encoder.config.export_model_type = "clip-text-model" + models_for_export["text_encoder"] = text_encoder - # U-NET - # ONNX export of torch.nn.functional.scaled_dot_product_attention not supported for < v2.1.0 - is_torch_greater_or_equal_than_2_1 = version.parse(torch.__version__) >= version.parse("2.1.0") - if not is_torch_greater_or_equal_than_2_1: - pipeline.unet.set_attn_processor(AttnProcessor()) + # Text encoder 2 + text_encoder_2 = getattr(pipeline, "text_encoder_2", None) + if text_encoder_2 is not None: + text_encoder_2.config.output_hidden_states = True + text_encoder_2.text_model.config.output_hidden_states = True + text_encoder_2.config.export_model_type = "clip-text-with-projection" - pipeline.unet.config.text_encoder_projection_dim = projection_dim - # The U-NET time_ids inputs shapes depends on the value of `requires_aesthetics_score` - # https://github.com/huggingface/diffusers/blob/v0.18.2/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py#L571 - pipeline.unet.config.requires_aesthetics_score = getattr(pipeline.config, "requires_aesthetics_score", False) - models_for_export["unet"] = pipeline.unet + models_for_export["text_encoder_2"] = text_encoder_2 - # VAE Encoder https://github.com/huggingface/diffusers/blob/v0.11.1/src/diffusers/models/vae.py#L565 + # Text encoder 3 + text_encoder_3 = getattr(pipeline, "text_encoder_3", None) + if text_encoder_3 is not None: + text_encoder_3.config.export_model_type = "t5-encoder" + models_for_export["text_encoder_3"] = text_encoder_3 + + # U-NET + unet = getattr(pipeline, "unet", None) + if unet is not None: + # ONNX export of torch.nn.functional.scaled_dot_product_attention not supported for < v2.1.0 + if not is_torch_greater_or_equal_than_2_1: + unet.set_attn_processor(AttnProcessor()) + + # The U-NET time_ids inputs shapes depends on the value of `requires_aesthetics_score` + # https://github.com/huggingface/diffusers/blob/v0.18.2/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py#L571 + unet.config.requires_aesthetics_score = getattr(pipeline.config, "requires_aesthetics_score", False) + unet.config.time_cond_proj_dim = getattr(pipeline.unet.config, "time_cond_proj_dim", None) + unet.config.text_encoder_projection_dim = pipeline.text_encoder.config.projection_dim + unet.config.export_model_type = "unet" + models_for_export["unet"] = unet + + # Transformer + transformer = getattr(pipeline, "transformer", None) + if transformer is not None: + # ONNX export of 
torch.nn.functional.scaled_dot_product_attention not supported for < v2.1.0 + if not is_torch_greater_or_equal_than_2_1: + transformer.set_attn_processor(AttnProcessor()) + + transformer.config.requires_aesthetics_score = getattr(pipeline.config, "requires_aesthetics_score", False) + transformer.config.time_cond_proj_dim = getattr(pipeline.transformer.config, "time_cond_proj_dim", None) + transformer.config.text_encoder_projection_dim = pipeline.text_encoder.config.projection_dim + transformer.config.export_model_type = "sd3-transformer" + models_for_export["transformer"] = transformer + + # VAE Encoder vae_encoder = copy.deepcopy(pipeline.vae) + + # ONNX export of torch.nn.functional.scaled_dot_product_attention not supported for < v2.1.0 if not is_torch_greater_or_equal_than_2_1: vae_encoder = override_diffusers_2_0_attn_processors(vae_encoder) + # we return the distribution parameters to be able to recreate it in the decoder vae_encoder.forward = lambda sample: {"latent_parameters": vae_encoder.encode(x=sample)["latent_dist"].parameters} models_for_export["vae_encoder"] = vae_encoder - # VAE Decoder https://github.com/huggingface/diffusers/blob/v0.11.1/src/diffusers/models/vae.py#L600 + # VAE Decoder vae_decoder = copy.deepcopy(pipeline.vae) + + # ONNX export of torch.nn.functional.scaled_dot_product_attention not supported for < v2.1.0 if not is_torch_greater_or_equal_than_2_1: vae_decoder = override_diffusers_2_0_attn_processors(vae_decoder) + vae_decoder.forward = lambda latent_sample: vae_decoder.decode(z=latent_sample) models_for_export["vae_decoder"] = vae_decoder - text_encoder_2 = getattr(pipeline, "text_encoder_2", None) - if text_encoder_2 is not None: - text_encoder_2.config.output_hidden_states = True - text_encoder_2.text_model.config.output_hidden_states = True - models_for_export["text_encoder_2"] = text_encoder_2 - return models_for_export @@ -294,31 +341,58 @@ def get_diffusion_models_for_export( `Dict[str, Tuple[Union[`PreTrainedModel`, `TFPreTrainedModel`], `ExportConfig`]: A Dict containing the model and export configs for the different components of the model. 
""" + models_for_export = _get_submodels_for_export_diffusion(pipeline) # Text encoder if "text_encoder" in models_for_export: text_encoder_config_constructor = TasksManager.get_exporter_config_constructor( - model=pipeline.text_encoder, - exporter=exporter, - library_name="diffusers", - task="feature-extraction", + model=pipeline.text_encoder, exporter=exporter, library_name="diffusers", task="feature-extraction" ) text_encoder_export_config = text_encoder_config_constructor( pipeline.text_encoder.config, int_dtype=int_dtype, float_dtype=float_dtype ) models_for_export["text_encoder"] = (models_for_export["text_encoder"], text_encoder_export_config) + # Text encoder 2 + if "text_encoder_2" in models_for_export: + export_config_constructor = TasksManager.get_exporter_config_constructor( + model=pipeline.text_encoder_2, exporter=exporter, library_name="diffusers", task="feature-extraction" + ) + export_config = export_config_constructor( + pipeline.text_encoder_2.config, int_dtype=int_dtype, float_dtype=float_dtype + ) + models_for_export["text_encoder_2"] = (models_for_export["text_encoder_2"], export_config) + + # Text encoder 3 + if "text_encoder_3" in models_for_export: + export_config_constructor = TasksManager.get_exporter_config_constructor( + model=pipeline.text_encoder_3, exporter=exporter, library_name="diffusers", task="feature-extraction" + ) + export_config = export_config_constructor( + pipeline.text_encoder_3.config, int_dtype=int_dtype, float_dtype=float_dtype + ) + models_for_export["text_encoder_3"] = (models_for_export["text_encoder_3"], export_config) + # U-NET - export_config_constructor = TasksManager.get_exporter_config_constructor( - model=pipeline.unet, - exporter=exporter, - library_name="diffusers", - task="semantic-segmentation", - model_type="unet", - ) - unet_export_config = export_config_constructor(pipeline.unet.config, int_dtype=int_dtype, float_dtype=float_dtype) - models_for_export["unet"] = (models_for_export["unet"], unet_export_config) + if "unet" in models_for_export: + export_config_constructor = TasksManager.get_exporter_config_constructor( + model=pipeline.unet, exporter=exporter, library_name="diffusers", task="semantic-segmentation" + ) + unet_export_config = export_config_constructor( + pipeline.unet.config, int_dtype=int_dtype, float_dtype=float_dtype + ) + models_for_export["unet"] = (models_for_export["unet"], unet_export_config) + + # Transformer + if "transformer" in models_for_export: + export_config_constructor = TasksManager.get_exporter_config_constructor( + model=pipeline.transformer, exporter=exporter, library_name="diffusers", task="semantic-segmentation" + ) + transformer_export_config = export_config_constructor( + pipeline.transformer.config, int_dtype=int_dtype, float_dtype=float_dtype + ) + models_for_export["transformer"] = (models_for_export["transformer"], transformer_export_config) # VAE Encoder https://github.com/huggingface/diffusers/blob/v0.11.1/src/diffusers/models/vae.py#L565 vae_encoder = models_for_export["vae_encoder"] @@ -344,19 +418,6 @@ def get_diffusion_models_for_export( vae_export_config = vae_config_constructor(vae_decoder.config, int_dtype=int_dtype, float_dtype=float_dtype) models_for_export["vae_decoder"] = (vae_decoder, vae_export_config) - if "text_encoder_2" in models_for_export: - export_config_constructor = TasksManager.get_exporter_config_constructor( - model=pipeline.text_encoder_2, - exporter=exporter, - library_name="diffusers", - task="feature-extraction", - model_type="clip-text-with-projection", 
- ) - export_config = export_config_constructor( - pipeline.text_encoder_2.config, int_dtype=int_dtype, float_dtype=float_dtype - ) - models_for_export["text_encoder_2"] = (models_for_export["text_encoder_2"], export_config) - return models_for_export diff --git a/optimum/onnxruntime/modeling_diffusion.py b/optimum/onnxruntime/modeling_diffusion.py index 3899a7b36b6..6df9816e056 100644 --- a/optimum/onnxruntime/modeling_diffusion.py +++ b/optimum/onnxruntime/modeling_diffusion.py @@ -32,6 +32,9 @@ AutoPipelineForText2Image, LatentConsistencyModelImg2ImgPipeline, LatentConsistencyModelPipeline, + StableDiffusion3Img2ImgPipeline, + StableDiffusion3InpaintPipeline, + StableDiffusion3Pipeline, StableDiffusionImg2ImgPipeline, StableDiffusionInpaintPipeline, StableDiffusionPipeline, @@ -57,7 +60,9 @@ from ..onnx.utils import _get_model_external_data_paths from ..utils import ( DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER, + DIFFUSION_MODEL_TEXT_ENCODER_3_SUBFOLDER, DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER, + DIFFUSION_MODEL_TRANSFORMER_SUBFOLDER, DIFFUSION_MODEL_UNET_SUBFOLDER, DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER, DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER, @@ -76,7 +81,7 @@ if check_if_diffusers_greater("0.25.0"): from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution else: - from diffusers.models.vae import DiagonalGaussianDistribution + from diffusers.models.vae import DiagonalGaussianDistribution # type: ignore logger = logging.getLogger(__name__) @@ -92,15 +97,18 @@ class ORTDiffusionPipeline(ORTModel, DiffusionPipeline): def __init__( self, scheduler: "SchedulerMixin", - unet_session: ort.InferenceSession, vae_decoder_session: ort.InferenceSession, # optional pipeline models + unet_session: Optional[ort.InferenceSession] = None, + transformer_session: Optional[ort.InferenceSession] = None, vae_encoder_session: Optional[ort.InferenceSession] = None, text_encoder_session: Optional[ort.InferenceSession] = None, text_encoder_2_session: Optional[ort.InferenceSession] = None, + text_encoder_3_session: Optional[ort.InferenceSession] = None, # optional pipeline submodels tokenizer: Optional["CLIPTokenizer"] = None, tokenizer_2: Optional["CLIPTokenizer"] = None, + tokenizer_3: Optional["CLIPTokenizer"] = None, feature_extractor: Optional["CLIPFeatureExtractor"] = None, # stable diffusion xl specific arguments force_zeros_for_empty_prompt: bool = True, @@ -111,16 +119,20 @@ def __init__( model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, **kwargs, ): - self.unet = ORTModelUnet(unet_session, self) - self.vae_decoder = ORTModelVaeDecoder(vae_decoder_session, self) - self.vae_encoder = ORTModelVaeEncoder(vae_encoder_session, self) if vae_encoder_session is not None else None + self.unet = ORTModelUnet(unet_session, self) if unet_session is not None else None + self.transformer = ORTModelTransformer(transformer_session, self) if transformer_session is not None else None self.text_encoder = ( ORTModelTextEncoder(text_encoder_session, self) if text_encoder_session is not None else None ) self.text_encoder_2 = ( ORTModelTextEncoder(text_encoder_2_session, self) if text_encoder_2_session is not None else None ) + self.text_encoder_3 = ( + ORTModelTextEncoder(text_encoder_3_session, self) if text_encoder_3_session is not None else None + ) # We wrap the VAE Decoder & Encoder in a single object to simulate diffusers API + self.vae_encoder = ORTModelVaeEncoder(vae_encoder_session, self) if vae_encoder_session is not None else None + self.vae_decoder = 
ORTModelVaeDecoder(vae_decoder_session, self) if vae_decoder_session is not None else None self.vae = ORTWrapperVae(self.vae_encoder, self.vae_decoder) # we allow passing these as torch models for now @@ -130,18 +142,22 @@ def __init__( self.scheduler = scheduler self.tokenizer = tokenizer self.tokenizer_2 = tokenizer_2 + self.tokenizer_3 = tokenizer_3 self.feature_extractor = feature_extractor all_pipeline_init_args = { "vae": self.vae, "unet": self.unet, + "transformer": self.transformer, "text_encoder": self.text_encoder, "text_encoder_2": self.text_encoder_2, + "text_encoder_3": self.text_encoder_3, "safety_checker": self.safety_checker, "image_encoder": self.image_encoder, "scheduler": self.scheduler, "tokenizer": self.tokenizer, "tokenizer_2": self.tokenizer_2, + "tokenizer_3": self.tokenizer_3, "feature_extractor": self.feature_extractor, "requires_aesthetics_score": requires_aesthetics_score, "force_zeros_for_empty_prompt": force_zeros_for_empty_prompt, @@ -157,7 +173,10 @@ def __init__( # inits ort specific attributes self.shared_attributes_init( - model=unet_session, use_io_binding=use_io_binding, model_save_dir=model_save_dir, **kwargs + model=unet_session if unet_session is not None else transformer_session, + use_io_binding=use_io_binding, + model_save_dir=model_save_dir, + **kwargs, ) def _save_pretrained(self, save_directory: Union[str, Path]): @@ -165,10 +184,12 @@ def _save_pretrained(self, save_directory: Union[str, Path]): models_to_save_paths = { (self.unet, save_directory / DIFFUSION_MODEL_UNET_SUBFOLDER), + (self.transformer, save_directory / DIFFUSION_MODEL_TRANSFORMER_SUBFOLDER), (self.vae_decoder, save_directory / DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER), (self.vae_encoder, save_directory / DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER), (self.text_encoder, save_directory / DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER), (self.text_encoder_2, save_directory / DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER), + (self.text_encoder_3, save_directory / DIFFUSION_MODEL_TEXT_ENCODER_3_SUBFOLDER), } for model, save_path in models_to_save_paths: if model is not None: @@ -192,6 +213,8 @@ def _save_pretrained(self, save_directory: Union[str, Path]): self.tokenizer.save_pretrained(save_directory / "tokenizer") if self.tokenizer_2 is not None: self.tokenizer_2.save_pretrained(save_directory / "tokenizer_2") + if self.tokenizer_3 is not None: + self.tokenizer_3.save_pretrained(save_directory / "tokenizer_3") if self.feature_extractor is not None: self.feature_extractor.save_pretrained(save_directory / "feature_extractor") @@ -208,10 +231,12 @@ def _from_pretrained( cache_dir: str = HUGGINGFACE_HUB_CACHE, token: Optional[Union[bool, str]] = None, unet_file_name: str = ONNX_WEIGHTS_NAME, + transformer_file_name: str = ONNX_WEIGHTS_NAME, vae_decoder_file_name: str = ONNX_WEIGHTS_NAME, vae_encoder_file_name: str = ONNX_WEIGHTS_NAME, text_encoder_file_name: str = ONNX_WEIGHTS_NAME, text_encoder_2_file_name: str = ONNX_WEIGHTS_NAME, + text_encoder_3_file_name: str = ONNX_WEIGHTS_NAME, use_io_binding: Optional[bool] = None, provider: str = "CPUExecutionProvider", provider_options: Optional[Dict[str, Any]] = None, @@ -230,10 +255,12 @@ def _from_pretrained( allow_patterns.update( { unet_file_name, + transformer_file_name, vae_decoder_file_name, vae_encoder_file_name, text_encoder_file_name, text_encoder_2_file_name, + text_encoder_3_file_name, SCHEDULER_CONFIG_NAME, cls.config_name, CONFIG_NAME, @@ -259,10 +286,12 @@ def _from_pretrained( model_paths = { "unet": model_save_path / 
DIFFUSION_MODEL_UNET_SUBFOLDER / unet_file_name, + "transformer": model_save_path / DIFFUSION_MODEL_TRANSFORMER_SUBFOLDER / transformer_file_name, "vae_decoder": model_save_path / DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER / vae_decoder_file_name, "vae_encoder": model_save_path / DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER / vae_encoder_file_name, "text_encoder": model_save_path / DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER / text_encoder_file_name, "text_encoder_2": model_save_path / DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER / text_encoder_2_file_name, + "text_encoder_3": model_save_path / DIFFUSION_MODEL_TEXT_ENCODER_3_SUBFOLDER / text_encoder_3_file_name, } sessions = {} @@ -276,7 +305,7 @@ def _from_pretrained( ) submodels = {} - for submodel in {"scheduler", "tokenizer", "tokenizer_2", "feature_extractor"}: + for submodel in {"scheduler", "tokenizer", "tokenizer_2", "tokenizer_3", "feature_extractor"}: if kwargs.get(submodel, None) is not None: submodels[submodel] = kwargs.pop(submodel) elif config.get(submodel, (None, None))[0] is not None: @@ -385,17 +414,24 @@ def to(self, device: Union[torch.device, str, int]): if device.type == "cuda" and self.providers[0] == "TensorrtExecutionProvider": return self - self.unet.session.set_providers([provider], provider_options=[provider_options]) self.vae_decoder.session.set_providers([provider], provider_options=[provider_options]) + if self.unet is not None: + self.unet.session.set_providers([provider], provider_options=[provider_options]) + if self.transformer is not None: + self.transformer.session.set_providers([provider], provider_options=[provider_options]) if self.vae_encoder is not None: self.vae_encoder.session.set_providers([provider], provider_options=[provider_options]) if self.text_encoder is not None: self.text_encoder.session.set_providers([provider], provider_options=[provider_options]) if self.text_encoder_2 is not None: self.text_encoder_2.session.set_providers([provider], provider_options=[provider_options]) + if self.text_encoder_3 is not None: + self.text_encoder_3.session.set_providers([provider], provider_options=[provider_options]) - self.providers = self.unet.session.get_providers() + self.providers = ( + self.unet.session.get_providers() if self.unet is not None else self.transformer.session.get_providers() + ) self._device = device return self @@ -412,8 +448,10 @@ def components(self) -> Dict[str, Any]: components = { "vae": self.vae, "unet": self.unet, + "transformer": self.transformer, "text_encoder": self.text_encoder, "text_encoder_2": self.text_encoder_2, + "text_encoder_3": self.text_encoder_3, "safety_checker": self.safety_checker, "image_encoder": self.image_encoder, } @@ -581,6 +619,39 @@ def forward( return ModelOutput(**model_outputs) +class ORTModelTransformer(ORTPipelinePart): + def forward( + self, + hidden_states: Union[np.ndarray, torch.Tensor], + timestep: Union[np.ndarray, torch.Tensor], + encoder_hidden_states: Union[np.ndarray, torch.Tensor], + pooled_projections: Union[np.ndarray, torch.Tensor], + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = False, + ): + use_torch = isinstance(hidden_states, torch.Tensor) + + if len(timestep.shape) == 0: + timestep = timestep.unsqueeze(0) + + model_inputs = { + "hidden_states": hidden_states, + "timestep": timestep, + "encoder_hidden_states": encoder_hidden_states, + "pooled_projections": pooled_projections, + **(joint_attention_kwargs or {}), + } + + onnx_inputs = self.prepare_onnx_inputs(use_torch, **model_inputs) + onnx_outputs = 
self.session.run(None, onnx_inputs) + model_outputs = self.prepare_onnx_outputs(use_torch, *onnx_outputs) + + if return_dict: + return model_outputs + + return ModelOutput(**model_outputs) + + class ORTModelTextEncoder(ORTPipelinePart): def forward( self, @@ -599,11 +670,17 @@ def forward( if output_hidden_states: model_outputs["hidden_states"] = [] - for i in range(self.config.num_hidden_layers): + num_layers = ( + self.config.num_hidden_layers if hasattr(self.config, "num_hidden_layers") else self.num_decoder_layers + ) + for i in range(num_layers): model_outputs["hidden_states"].append(model_outputs.pop(f"hidden_states.{i}")) model_outputs["hidden_states"].append(model_outputs.get("last_hidden_state")) else: - for i in range(self.config.num_hidden_layers): + num_layers = ( + self.config.num_hidden_layers if hasattr(self.config, "num_hidden_layers") else self.num_decoder_layers + ) + for i in range(num_layers): model_outputs.pop(f"hidden_states.{i}", None) if return_dict: @@ -871,6 +948,39 @@ class ORTLatentConsistencyModelImg2ImgPipeline(ORTDiffusionPipeline, LatentConsi auto_model_class = LatentConsistencyModelImg2ImgPipeline +@add_end_docstrings(ONNX_MODEL_END_DOCSTRING) +class ORTStableDiffusion3Pipeline(ORTDiffusionPipeline, StableDiffusion3Pipeline): + """ + ONNX Runtime-powered stable diffusion pipeline corresponding to [diffusers.StableDiffusion3Pipeline](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/text2img#diffusers.StableDiffusion3Pipeline). + """ + + main_input_name = "prompt" + export_feature = "text-to-image" + auto_model_class = StableDiffusion3Pipeline + + +@add_end_docstrings(ONNX_MODEL_END_DOCSTRING) +class ORTStableDiffusion3Img2ImgPipeline(ORTDiffusionPipeline, StableDiffusion3Img2ImgPipeline): + """ + ONNX Runtime-powered stable diffusion pipeline corresponding to [diffusers.StableDiffusion3Img2ImgPipeline](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/img2img#diffusers.StableDiffusion3Img2ImgPipeline). + """ + + main_input_name = "image" + export_feature = "image-to-image" + auto_model_class = StableDiffusion3Img2ImgPipeline + + +@add_end_docstrings(ONNX_MODEL_END_DOCSTRING) +class ORTStableDiffusion3InpaintPipeline(ORTDiffusionPipeline, StableDiffusion3InpaintPipeline): + """ + ONNX Runtime-powered stable diffusion pipeline corresponding to [diffusers.StableDiffusion3InpaintPipeline](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/inpaint#diffusers.StableDiffusion3InpaintPipeline). 
+ """ + + main_input_name = "prompt" + export_feature = "inpainting" + auto_model_class = StableDiffusion3InpaintPipeline + + SUPPORTED_ORT_PIPELINES = [ ORTStableDiffusionPipeline, ORTStableDiffusionImg2ImgPipeline, @@ -880,6 +990,9 @@ class ORTLatentConsistencyModelImg2ImgPipeline(ORTDiffusionPipeline, LatentConsi ORTStableDiffusionXLInpaintPipeline, ORTLatentConsistencyModelPipeline, ORTLatentConsistencyModelImg2ImgPipeline, + ORTStableDiffusion3Pipeline, + ORTStableDiffusion3Img2ImgPipeline, + ORTStableDiffusion3InpaintPipeline, ] @@ -900,6 +1013,7 @@ def _get_ort_class(pipeline_class_name: str, throw_error_if_not_exist: bool = Tr ("stable-diffusion", ORTStableDiffusionPipeline), ("stable-diffusion-xl", ORTStableDiffusionXLPipeline), ("latent-consistency", ORTLatentConsistencyModelPipeline), + ("stable-diffusion-3", ORTStableDiffusion3Pipeline), ] ) @@ -908,6 +1022,7 @@ def _get_ort_class(pipeline_class_name: str, throw_error_if_not_exist: bool = Tr ("stable-diffusion", ORTStableDiffusionImg2ImgPipeline), ("stable-diffusion-xl", ORTStableDiffusionXLImg2ImgPipeline), ("latent-consistency", ORTLatentConsistencyModelImg2ImgPipeline), + ("stable-diffusion-3", ORTStableDiffusion3Img2ImgPipeline), ] ) @@ -915,6 +1030,7 @@ def _get_ort_class(pipeline_class_name: str, throw_error_if_not_exist: bool = Tr [ ("stable-diffusion", ORTStableDiffusionInpaintPipeline), ("stable-diffusion-xl", ORTStableDiffusionXLInpaintPipeline), + ("stable-diffusion-3", ORTStableDiffusion3InpaintPipeline), ] ) diff --git a/optimum/utils/__init__.py b/optimum/utils/__init__.py index 5d5044e63e1..cdf7c45b0e1 100644 --- a/optimum/utils/__init__.py +++ b/optimum/utils/__init__.py @@ -16,7 +16,9 @@ from .constant import ( CONFIG_NAME, DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER, + DIFFUSION_MODEL_TEXT_ENCODER_3_SUBFOLDER, DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER, + DIFFUSION_MODEL_TRANSFORMER_SUBFOLDER, DIFFUSION_MODEL_UNET_SUBFOLDER, DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER, DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER, diff --git a/optimum/utils/constant.py b/optimum/utils/constant.py index 4497b5246d4..eb7a67e9ece 100644 --- a/optimum/utils/constant.py +++ b/optimum/utils/constant.py @@ -15,8 +15,10 @@ CONFIG_NAME = "config.json" DIFFUSION_MODEL_UNET_SUBFOLDER = "unet" -DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER = "text_encoder" +DIFFUSION_MODEL_TRANSFORMER_SUBFOLDER = "transformer" DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER = "vae_decoder" DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER = "vae_encoder" +DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER = "text_encoder" DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER = "text_encoder_2" +DIFFUSION_MODEL_TEXT_ENCODER_3_SUBFOLDER = "text_encoder_3" ONNX_WEIGHTS_NAME = "model.onnx" diff --git a/tests/onnxruntime/test_diffusion.py b/tests/onnxruntime/test_diffusion.py index 956566f0e1f..38e9b2b5391 100644 --- a/tests/onnxruntime/test_diffusion.py +++ b/tests/onnxruntime/test_diffusion.py @@ -71,7 +71,8 @@ def _generate_images(height=128, width=128, batch_size=1, channel=3, input_type= class ORTPipelineForText2ImageTest(ORTModelTestMixin): - SUPPORTED_ARCHITECTURES = ["stable-diffusion", "stable-diffusion-xl", "latent-consistency"] + SUPPORTED_ARCHITECTURES = ["stable-diffusion", "stable-diffusion-xl", "latent-consistency", "stable-diffusion-3"] + CALLBACK_SUPPORTED_ARCHITECTURES = ["stable-diffusion", "stable-diffusion-xl", "latent-consistency"] ORTMODEL_CLASS = ORTPipelineForText2Image AUTOMODEL_CLASS = AutoPipelineForText2Image @@ -147,7 +148,7 @@ def test_compare_to_diffusers_pipeline(self, model_arch: str): 
np.testing.assert_allclose(ort_output, diffusers_output, atol=1e-4, rtol=1e-2) - @parameterized.expand(SUPPORTED_ARCHITECTURES) + @parameterized.expand(CALLBACK_SUPPORTED_ARCHITECTURES) @require_diffusers def test_callback(self, model_arch: str): model_args = {"test_name": model_arch, "model_arch": model_arch} @@ -200,9 +201,19 @@ def test_shape(self, model_arch: str): elif output_type == "pt": self.assertEqual(outputs.shape, (batch_size, 3, height, width)) else: + out_channels = ( + pipeline.unet.config.out_channels + if pipeline.unet is not None + else pipeline.transformer.config.out_channels + ) self.assertEqual( outputs.shape, - (batch_size, 4, height // pipeline.vae_scale_factor, width // pipeline.vae_scale_factor), + ( + batch_size, + out_channels, + height // pipeline.vae_scale_factor, + width // pipeline.vae_scale_factor, + ), ) @parameterized.expand(SUPPORTED_ARCHITECTURES) @@ -251,6 +262,21 @@ def test_negative_prompt(self, model_arch: str): do_classifier_free_guidance=True, negative_prompt=negative_prompt, ) + elif model_arch == "stable-diffusion-3": + ( + inputs["prompt_embeds"], + inputs["negative_prompt_embeds"], + inputs["pooled_prompt_embeds"], + inputs["negative_pooled_prompt_embeds"], + ) = pipeline.encode_prompt( + prompt=prompt, + prompt_2=prompt, + prompt_3=prompt, + num_images_per_prompt=1, + device=torch.device("cpu"), + do_classifier_free_guidance=True, + negative_prompt=negative_prompt, + ) else: inputs["prompt_embeds"], inputs["negative_prompt_embeds"] = pipeline.encode_prompt( prompt=prompt, @@ -326,7 +352,8 @@ def test_safety_checker(self, model_arch: str): class ORTPipelineForImage2ImageTest(ORTModelTestMixin): - SUPPORTED_ARCHITECTURES = ["stable-diffusion", "stable-diffusion-xl", "latent-consistency"] + SUPPORTED_ARCHITECTURES = ["stable-diffusion", "stable-diffusion-xl", "latent-consistency", "stable-diffusion-3"] + CALLBACK_SUPPORTED_ARCHITECTURES = ["stable-diffusion", "stable-diffusion-xl", "latent-consistency"] AUTOMODEL_CLASS = AutoPipelineForImage2Image ORTMODEL_CLASS = ORTPipelineForImage2Image @@ -380,7 +407,7 @@ def test_num_images_per_prompt(self, model_arch: str): outputs = pipeline(**inputs, num_images_per_prompt=num_images_per_prompt).images self.assertEqual(outputs.shape, (batch_size * num_images_per_prompt, height, width, 3)) - @parameterized.expand(SUPPORTED_ARCHITECTURES) + @parameterized.expand(CALLBACK_SUPPORTED_ARCHITECTURES) @require_diffusers def test_callback(self, model_arch: str): model_args = {"test_name": model_arch, "model_arch": model_arch} @@ -434,9 +461,19 @@ def test_shape(self, model_arch: str): elif output_type == "pt": self.assertEqual(outputs.shape, (batch_size, 3, height, width)) else: + out_channels = ( + pipeline.unet.config.out_channels + if pipeline.unet is not None + else pipeline.transformer.config.out_channels + ) self.assertEqual( outputs.shape, - (batch_size, 4, height // pipeline.vae_scale_factor, width // pipeline.vae_scale_factor), + ( + batch_size, + out_channels, + height // pipeline.vae_scale_factor, + width // pipeline.vae_scale_factor, + ), ) @parameterized.expand(SUPPORTED_ARCHITECTURES) @@ -541,7 +578,8 @@ def test_safety_checker(self, model_arch: str): class ORTPipelineForInpaintingTest(ORTModelTestMixin): - SUPPORTED_ARCHITECTURES = ["stable-diffusion", "stable-diffusion-xl"] + SUPPORTED_ARCHITECTURES = ["stable-diffusion", "stable-diffusion-xl", "stable-diffusion-3"] + CALLBACK_SUPPORTED_ARCHITECTURES = ["stable-diffusion", "stable-diffusion-xl"] AUTOMODEL_CLASS = AutoPipelineForInpainting 
ORTMODEL_CLASS = ORTPipelineForInpainting @@ -600,7 +638,7 @@ def test_num_images_per_prompt(self, model_arch: str): outputs = pipeline(**inputs, num_images_per_prompt=num_images_per_prompt).images self.assertEqual(outputs.shape, (batch_size * num_images_per_prompt, height, width, 3)) - @parameterized.expand(SUPPORTED_ARCHITECTURES) + @parameterized.expand(CALLBACK_SUPPORTED_ARCHITECTURES) @require_diffusers def test_callback(self, model_arch: str): model_args = {"test_name": model_arch, "model_arch": model_arch} @@ -654,9 +692,19 @@ def test_shape(self, model_arch: str): elif output_type == "pt": self.assertEqual(outputs.shape, (batch_size, 3, height, width)) else: + out_channels = ( + pipeline.unet.config.out_channels + if pipeline.unet is not None + else pipeline.transformer.config.out_channels + ) self.assertEqual( outputs.shape, - (batch_size, 4, height // pipeline.vae_scale_factor, width // pipeline.vae_scale_factor), + ( + batch_size, + out_channels, + height // pipeline.vae_scale_factor, + width // pipeline.vae_scale_factor, + ), ) @parameterized.expand(SUPPORTED_ARCHITECTURES) diff --git a/tests/onnxruntime/utils_onnxruntime_tests.py b/tests/onnxruntime/utils_onnxruntime_tests.py index 5071d0081af..50eec6c95ee 100644 --- a/tests/onnxruntime/utils_onnxruntime_tests.py +++ b/tests/onnxruntime/utils_onnxruntime_tests.py @@ -141,6 +141,7 @@ "squeezebert": "hf-internal-testing/tiny-random-SqueezeBertModel", "speech_to_text": "hf-internal-testing/tiny-random-Speech2TextModel", "stable-diffusion": "hf-internal-testing/tiny-stable-diffusion-torch", + "stable-diffusion-3": "yujiepan/stable-diffusion-3-tiny-random", "stable-diffusion-xl": "echarlaix/tiny-random-stable-diffusion-xl", "swin": "hf-internal-testing/tiny-random-SwinModel", "swin-window": "yujiepan/tiny-random-swin-patch4-window7-224", From 08190c70fbba5d89954a6ca81f49e30902300a65 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Mon, 21 Oct 2024 17:09:30 +0200 Subject: [PATCH 2/7] unsupported cli model types --- optimum/exporters/tasks.py | 8 +++++--- tests/exporters/exporters_utils.py | 3 ++- tests/onnxruntime/utils_onnxruntime_tests.py | 2 +- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py index 87ab62b2f29..82417886f26 100644 --- a/optimum/exporters/tasks.py +++ b/optimum/exporters/tasks.py @@ -1178,12 +1178,14 @@ class TasksManager: "transformers": _SUPPORTED_MODEL_TYPE, } _UNSUPPORTED_CLI_MODEL_TYPE = { - "unet", - "vae-encoder", - "vae-decoder", "clip-text-model", "clip-text-with-projection", + "sd3-transformer", + "t5-encoder", "trocr", # supported through the vision-encoder-decoder model type + "unet", + "vae-encoder", + "vae-decoder", } _SUPPORTED_CLI_MODEL_TYPE = ( set(_SUPPORTED_MODEL_TYPE.keys()) diff --git a/tests/exporters/exporters_utils.py b/tests/exporters/exporters_utils.py index c8a33b0be35..9617ab37a00 100644 --- a/tests/exporters/exporters_utils.py +++ b/tests/exporters/exporters_utils.py @@ -296,9 +296,10 @@ } PYTORCH_DIFFUSION_MODEL = { + "latent-consistency": "echarlaix/tiny-random-latent-consistency", "stable-diffusion": "hf-internal-testing/tiny-stable-diffusion-torch", + "stable-diffusion-3": "yujiepan/stable-diffusion-3-tiny-random", "stable-diffusion-xl": "echarlaix/tiny-random-stable-diffusion-xl", - "latent-consistency": "echarlaix/tiny-random-latent-consistency", } PYTORCH_TIMM_MODEL = { diff --git a/tests/onnxruntime/utils_onnxruntime_tests.py b/tests/onnxruntime/utils_onnxruntime_tests.py index 50eec6c95ee..cb224993ad6 
100644 --- a/tests/onnxruntime/utils_onnxruntime_tests.py +++ b/tests/onnxruntime/utils_onnxruntime_tests.py @@ -107,10 +107,10 @@ "groupvit": "hf-internal-testing/tiny-random-groupvit", "hubert": "hf-internal-testing/tiny-random-HubertModel", "ibert": "hf-internal-testing/tiny-random-IBertModel", - "levit": "hf-internal-testing/tiny-random-LevitModel", "latent-consistency": "echarlaix/tiny-random-latent-consistency", "layoutlm": "hf-internal-testing/tiny-random-LayoutLMModel", "layoutlmv3": "hf-internal-testing/tiny-random-LayoutLMv3Model", + "levit": "hf-internal-testing/tiny-random-LevitModel", "longt5": "hf-internal-testing/tiny-random-LongT5Model", "llama": "optimum-internal-testing/tiny-random-llama", "m2m_100": "hf-internal-testing/tiny-random-m2m_100", From 4518691596c4a6c9a9994facbb5b1eb2162388d2 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 23 Oct 2024 12:27:45 +0200 Subject: [PATCH 3/7] flux transformer support, unet export fixes, updated callback test, updated negative prompt test, flux and sd3 tests --- optimum/exporters/onnx/model_configs.py | 178 ++++++++++++++----- optimum/exporters/tasks.py | 25 +-- optimum/exporters/utils.py | 107 +++++++---- optimum/onnxruntime/modeling_diffusion.py | 66 ++++--- optimum/utils/input_generators.py | 4 +- tests/exporters/exporters_utils.py | 1 + tests/onnxruntime/test_diffusion.py | 141 +++++++-------- tests/onnxruntime/utils_onnxruntime_tests.py | 3 +- 8 files changed, 331 insertions(+), 194 deletions(-) diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index 28f894fc22c..afe87c772da 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -53,6 +53,7 @@ NormalizedTextConfig, NormalizedTextConfigWithGQA, NormalizedVisionConfig, + check_if_diffusers_greater, check_if_transformers_greater, is_diffusers_available, logging, @@ -1031,7 +1032,7 @@ def patch_model_for_export( class UNetOnnxConfig(VisionOnnxConfig): - ATOL_FOR_VALIDATION = 1e-3 + ATOL_FOR_VALIDATION = 1e-4 # The ONNX export of a CLIPText architecture, an other Stable Diffusion component, needs the Trilu # operator support, available since opset 14 DEFAULT_ONNX_OPSET = 14 @@ -1054,17 +1055,19 @@ class UNetOnnxConfig(VisionOnnxConfig): def inputs(self) -> Dict[str, Dict[int, str]]: common_inputs = { "sample": {0: "batch_size", 2: "height", 3: "width"}, - "timestep": {0: "steps"}, + "timestep": {}, # a scalar with no dimension "encoder_hidden_states": {0: "batch_size", 1: "sequence_length"}, } - # TODO : add text_image, image and image_embeds + # TODO : add addition_embed_type == text_image, image and image_embeds + # https://github.com/huggingface/diffusers/blob/9366c8f84bfe47099ff047272661786ebb54721d/src/diffusers/models/unets/unet_2d_condition.py#L671 if getattr(self._normalized_config, "addition_embed_type", None) == "text_time": common_inputs["text_embeds"] = {0: "batch_size"} common_inputs["time_ids"] = {0: "batch_size"} if getattr(self._normalized_config, "time_cond_proj_dim", None) is not None: common_inputs["timestep_cond"] = {0: "batch_size"} + return common_inputs @property @@ -1151,73 +1154,168 @@ def outputs(self) -> Dict[str, Dict[int, str]]: } -class PooledProjectionsDummyInputGenerator(DummyInputGenerator): - SUPPORTED_INPUT_NAMES = "pooled_projections" - - def __init__( - self, - task: str, - normalized_config: NormalizedConfig, - batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"], - **kwargs, - ): - self.task = task - self.batch_size = batch_size - 
self.pooled_projection_dim = normalized_config.config.pooled_projection_dim +class T5EncoderOnnxConfig(CLIPTextOnnxConfig): + @property + def inputs(self): + return { + "input_ids": {0: "batch_size", 1: "sequence_length"}, + } - def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): - return self.random_float_tensor( - [self.batch_size, self.pooled_projection_dim], framework=framework, dtype=float_dtype - ) + @property + def outputs(self): + return { + "last_hidden_state": {0: "batch_size", 1: "sequence_length"}, + } class DummyTransformerTimestpsInputGenerator(DummyTimestepInputGenerator): + SUPPORTED_INPUT_NAMES = ("timestep",) + def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): if input_name == "timestep": - shape = [self.batch_size] + shape = [self.batch_size] # With transformer diffusers, timestep is a 1D tensor return self.random_float_tensor(shape, max_value=self.vocab_size, framework=framework, dtype=float_dtype) + + return super().generate(input_name, framework, int_dtype, float_dtype) + + +class DummyTransformerVisionInputGenerator(DummyVisionInputGenerator): + SUPPORTED_INPUT_NAMES = ("hidden_states",) + + +class DummyTransformerTextInputGenerator(DummySeq2SeqDecoderTextInputGenerator): + SUPPORTED_INPUT_NAMES = ( + "encoder_hidden_states", + "pooled_projection", + ) + + def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): + if input_name == "encoder_hidden_states": + return super().generate(input_name, framework, int_dtype, float_dtype)[0] + + elif input_name == "pooled_projections": + return self.random_float_tensor( + [self.batch_size, self.normalized_config.projection_size], framework=framework, dtype=float_dtype + ) + return super().generate(input_name, framework, int_dtype, float_dtype) -class SD3TransformerOnnxConfig(UNetOnnxConfig): +class SD3TransformerOnnxConfig(VisionOnnxConfig): + ATOL_FOR_VALIDATION = 1e-4 + # The ONNX export of a CLIPText architecture, an other Stable Diffusion component, needs the Trilu + # operator support, available since opset 14 + DEFAULT_ONNX_OPSET = 14 + DUMMY_INPUT_GENERATOR_CLASSES = ( - (DummyTransformerTimestpsInputGenerator,) - + UNetOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES - + (PooledProjectionsDummyInputGenerator,) + DummyTransformerTimestpsInputGenerator, + DummyTransformerVisionInputGenerator, + DummyTransformerTextInputGenerator, ) + NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args( image_size="sample_size", num_channels="in_channels", - hidden_size="joint_attention_dim", vocab_size="attention_head_dim", + hidden_size="joint_attention_dim", + projection_size="pooled_projection_dim", allow_new=True, ) @property - def inputs(self): - common_inputs = super().inputs - common_inputs["pooled_projections"] = {0: "batch_size"} - return common_inputs + def inputs(self) -> Dict[str, Dict[int, str]]: + common_inputs = { + "hidden_states": {0: "batch_size", 2: "height", 3: "width"}, + "encoder_hidden_states": {0: "batch_size", 1: "sequence_length"}, + "pooled_projections": {0: "batch_size"}, + "timestep": {0: "step"}, + } - def rename_ambiguous_inputs(self, inputs): - # The input name in the model signature is `x, hence the export input name is updated. 
- hidden_states = inputs.pop("sample", None) - if hidden_states is not None: - inputs["hidden_states"] = hidden_states - return inputs + return common_inputs + @property + def outputs(self) -> Dict[str, Dict[int, str]]: + return { + "out_hidden_states": {0: "batch_size", 2: "height", 3: "width"}, + } -class T5EncoderOnnxConfig(CLIPTextOnnxConfig): @property - def inputs(self): + def torch_to_onnx_output_map(self) -> Dict[str, str]: return { - "input_ids": {0: "batch_size", 1: "sequence_length"}, + "sample": "out_hidden_states", } + +class DummyFluxTransformerVisionInputGenerator(DummyTransformerVisionInputGenerator): + SUPPORTED_INPUT_NAMES = ( + "hidden_states", + "img_ids", + ) + + def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): + if input_name == "hidden_states": + shape = [self.batch_size, (self.height // 2) * (self.width // 2), self.num_channels] + return self.random_float_tensor(shape, framework=framework, dtype=float_dtype) + elif input_name == "img_ids": + shape = ( + [(self.height // 2) * (self.width // 2), 3] + if check_if_diffusers_greater("0.31.0") + else [self.batch_size, (self.height // 2) * (self.width // 2), 3] + ) + return self.random_int_tensor(shape, max_value=1, framework=framework, dtype=int_dtype) + + return super().generate(input_name, framework, int_dtype, float_dtype) + + +class DummyFluxTransformerTextInputGenerator(DummyTransformerTextInputGenerator): + SUPPORTED_INPUT_NAMES = ( + "encoder_hidden_states", + "pooled_projections", + "txt_ids", + ) + + def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): + if input_name == "txt_ids": + shape = ( + [self.sequence_length, 3] + if check_if_diffusers_greater("0.31.0") + else [self.batch_size, self.sequence_length, 3] + ) + return self.random_int_tensor(shape, max_value=1, framework=framework, dtype=int_dtype) + + return super().generate(input_name, framework, int_dtype, float_dtype) + + +class FluxTransformerOnnxConfig(SD3TransformerOnnxConfig): + DUMMY_INPUT_GENERATOR_CLASSES = ( + DummyTransformerTimestpsInputGenerator, + DummyFluxTransformerVisionInputGenerator, + DummyFluxTransformerTextInputGenerator, + ) + + @property + def inputs(self): + common_inputs = super().inputs + + common_inputs["hidden_states"] = {0: "batch_size", 1: "packed_height_width"} + common_inputs["txt_ids"] = ( + {0: "sequence_length"} if check_if_diffusers_greater("0.31.0") else {0: "batch_size", 1: "sequence_length"} + ) + common_inputs["img_ids"] = ( + {0: "packed_height_width"} + if check_if_diffusers_greater("0.31.0") + else {0: "batch_size", 1: "packed_height_width"} + ) + + if getattr(self._normalized_config, "guidance_embeds", False): + common_inputs["guidance"] = {0: "batch_size"} + + return common_inputs + @property def outputs(self): return { - "last_hidden_state": {0: "batch_size", 1: "sequence_length"}, + "out_hidden_states": {0: "batch_size", 1: "packed_height_width"}, } diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py index 82417886f26..553dcc1a27a 100644 --- a/optimum/exporters/tasks.py +++ b/optimum/exporters/tasks.py @@ -335,7 +335,7 @@ class TasksManager: } _DIFFUSERS_SUPPORTED_MODEL_TYPE = { - "t5-encoder": supported_tasks_mapping( + "t5-encoder-model": supported_tasks_mapping( "feature-extraction", onnx="T5EncoderOnnxConfig", ), @@ -343,18 +343,22 @@ class TasksManager: "feature-extraction", onnx="CLIPTextOnnxConfig", ), - "clip-text-with-projection": supported_tasks_mapping( + 
"clip-text-model-with-projection": supported_tasks_mapping( "feature-extraction", onnx="CLIPTextWithProjectionOnnxConfig", ), - "unet": supported_tasks_mapping( + "flux-transformer-2d-model": supported_tasks_mapping( "semantic-segmentation", - onnx="UNetOnnxConfig", + onnx="FluxTransformerOnnxConfig", ), - "sd3-transformer": supported_tasks_mapping( + "sd3-transformer-2d-model": supported_tasks_mapping( "semantic-segmentation", onnx="SD3TransformerOnnxConfig", ), + "unet-2d-condition": supported_tasks_mapping( + "semantic-segmentation", + onnx="UNetOnnxConfig", + ), "vae-encoder": supported_tasks_mapping( "semantic-segmentation", onnx="VaeEncoderOnnxConfig", @@ -1178,12 +1182,13 @@ class TasksManager: "transformers": _SUPPORTED_MODEL_TYPE, } _UNSUPPORTED_CLI_MODEL_TYPE = { + # diffusers submodels "clip-text-model", - "clip-text-with-projection", - "sd3-transformer", - "t5-encoder", - "trocr", # supported through the vision-encoder-decoder model type - "unet", + "clip-text-model-with-projection", + "flux-transformer-model", + "sd3-transformer-2d-model", + "t5-encoder-model", + "unet-2d-condition", "vae-encoder", "vae-decoder", } diff --git a/optimum/exporters/utils.py b/optimum/exporters/utils.py index 9d37c7a996b..52ad986dd87 100644 --- a/optimum/exporters/utils.py +++ b/optimum/exporters/utils.py @@ -21,7 +21,9 @@ import torch from packaging import version +from transformers.models.clip.modeling_clip import CLIPTextModel, CLIPTextModelWithProjection from transformers.models.speecht5.modeling_speecht5 import SpeechT5HifiGan +from transformers.models.t5.modeling_t5 import T5EncoderModel from transformers.utils import is_tf_available, is_torch_available from ..utils import ( @@ -53,10 +55,20 @@ if check_if_diffusers_greater("0.30.0"): from diffusers import ( + FluxTransformer2DModel, + SD3Transformer2DModel, StableDiffusion3Img2ImgPipeline, StableDiffusion3InpaintPipeline, StableDiffusion3Pipeline, + UNet2DConditionModel, ) + else: + FluxTransformer2DModel = None + SD3Transformer2DModel = None + StableDiffusion3Img2ImgPipeline = None + StableDiffusion3InpaintPipeline = None + StableDiffusion3Pipeline = None + UNet2DConditionModel = None from diffusers.models.attention_processor import ( Attention, @@ -88,6 +100,23 @@ DECODER_MERGED_NAME = "decoder_model_merged" +def _get_diffusers_model_type(model): + if isinstance(model, CLIPTextModel): + return "clip-text-model" + elif isinstance(model, CLIPTextModelWithProjection): + return "clip-text-model-with-projection" + elif isinstance(model, T5EncoderModel): + return "t5-encoder-model" + elif isinstance(model, FluxTransformer2DModel): + return "flux-transformer-2d-model" + elif isinstance(model, SD3Transformer2DModel): + return "sd3-transformer-2d-model" + elif isinstance(model, UNet2DConditionModel): + return "unet-2d-condition" + else: + raise ValueError(f"Unknown model class: {model.__class__}") + + def _get_submodels_for_export_diffusion( pipeline: "DiffusionPipeline", ) -> Dict[str, Union["PreTrainedModel", "ModelMixin"]]: @@ -97,11 +126,13 @@ def _get_submodels_for_export_diffusion( models_for_export = {} - is_stable_diffusion_xl = isinstance( - pipeline, (StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline, StableDiffusionXLInpaintPipeline) + is_sdxl = isinstance( + pipeline, + (StableDiffusionXLImg2ImgPipeline, StableDiffusionXLInpaintPipeline, StableDiffusionXLPipeline), ) - is_stable_diffusion_3 = isinstance( - pipeline, (StableDiffusion3Pipeline, StableDiffusion3Img2ImgPipeline, StableDiffusion3InpaintPipeline) + is_sd3 = 
isinstance( + pipeline, + (StableDiffusion3Img2ImgPipeline, StableDiffusion3InpaintPipeline, StableDiffusion3Pipeline), ) is_torch_greater_or_equal_than_2_1 = version.parse(torch.__version__) >= version.parse("2.1.0") @@ -109,30 +140,27 @@ def _get_submodels_for_export_diffusion( # Text encoder text_encoder = getattr(pipeline, "text_encoder", None) if text_encoder is not None: - if is_stable_diffusion_xl or is_stable_diffusion_3: + if is_sdxl or is_sd3: text_encoder.config.output_hidden_states = True text_encoder.text_model.config.output_hidden_states = True - if is_stable_diffusion_3: - text_encoder.config.export_model_type = "clip-text-with-projection" - else: - text_encoder.config.export_model_type = "clip-text-model" - + text_encoder.config.export_model_type = _get_diffusers_model_type(text_encoder) models_for_export["text_encoder"] = text_encoder # Text encoder 2 text_encoder_2 = getattr(pipeline, "text_encoder_2", None) if text_encoder_2 is not None: - text_encoder_2.config.output_hidden_states = True - text_encoder_2.text_model.config.output_hidden_states = True - text_encoder_2.config.export_model_type = "clip-text-with-projection" + if is_sdxl or is_sd3: + text_encoder_2.config.output_hidden_states = True + text_encoder_2.text_model.config.output_hidden_states = True + text_encoder_2.config.export_model_type = _get_diffusers_model_type(text_encoder_2) models_for_export["text_encoder_2"] = text_encoder_2 # Text encoder 3 text_encoder_3 = getattr(pipeline, "text_encoder_3", None) if text_encoder_3 is not None: - text_encoder_3.config.export_model_type = "t5-encoder" + text_encoder_3.config.export_model_type = _get_diffusers_model_type(text_encoder_3) models_for_export["text_encoder_3"] = text_encoder_3 # U-NET @@ -147,7 +175,7 @@ def _get_submodels_for_export_diffusion( unet.config.requires_aesthetics_score = getattr(pipeline.config, "requires_aesthetics_score", False) unet.config.time_cond_proj_dim = getattr(pipeline.unet.config, "time_cond_proj_dim", None) unet.config.text_encoder_projection_dim = pipeline.text_encoder.config.projection_dim - unet.config.export_model_type = "unet" + unet.config.export_model_type = _get_diffusers_model_type(unet) models_for_export["unet"] = unet # Transformer @@ -160,7 +188,7 @@ def _get_submodels_for_export_diffusion( transformer.config.requires_aesthetics_score = getattr(pipeline.config, "requires_aesthetics_score", False) transformer.config.time_cond_proj_dim = getattr(pipeline.transformer.config, "time_cond_proj_dim", None) transformer.config.text_encoder_projection_dim = pipeline.text_encoder.config.projection_dim - transformer.config.export_model_type = "sd3-transformer" + transformer.config.export_model_type = _get_diffusers_model_type(transformer) models_for_export["transformer"] = transformer # VAE Encoder @@ -346,55 +374,54 @@ def get_diffusion_models_for_export( # Text encoder if "text_encoder" in models_for_export: + text_encoder = models_for_export["text_encoder"] text_encoder_config_constructor = TasksManager.get_exporter_config_constructor( - model=pipeline.text_encoder, exporter=exporter, library_name="diffusers", task="feature-extraction" + model=text_encoder, exporter=exporter, library_name="diffusers", task="feature-extraction" ) text_encoder_export_config = text_encoder_config_constructor( - pipeline.text_encoder.config, int_dtype=int_dtype, float_dtype=float_dtype + text_encoder.config, int_dtype=int_dtype, float_dtype=float_dtype ) models_for_export["text_encoder"] = (models_for_export["text_encoder"], 
text_encoder_export_config) # Text encoder 2 if "text_encoder_2" in models_for_export: + text_encoder_2 = models_for_export["text_encoder_2"] export_config_constructor = TasksManager.get_exporter_config_constructor( - model=pipeline.text_encoder_2, exporter=exporter, library_name="diffusers", task="feature-extraction" - ) - export_config = export_config_constructor( - pipeline.text_encoder_2.config, int_dtype=int_dtype, float_dtype=float_dtype + model=text_encoder_2, exporter=exporter, library_name="diffusers", task="feature-extraction" ) + export_config = export_config_constructor(text_encoder_2.config, int_dtype=int_dtype, float_dtype=float_dtype) models_for_export["text_encoder_2"] = (models_for_export["text_encoder_2"], export_config) # Text encoder 3 if "text_encoder_3" in models_for_export: + text_encoder_3 = models_for_export["text_encoder_3"] export_config_constructor = TasksManager.get_exporter_config_constructor( - model=pipeline.text_encoder_3, exporter=exporter, library_name="diffusers", task="feature-extraction" - ) - export_config = export_config_constructor( - pipeline.text_encoder_3.config, int_dtype=int_dtype, float_dtype=float_dtype + model=text_encoder_3, exporter=exporter, library_name="diffusers", task="feature-extraction" ) + export_config = export_config_constructor(text_encoder_3.config, int_dtype=int_dtype, float_dtype=float_dtype) models_for_export["text_encoder_3"] = (models_for_export["text_encoder_3"], export_config) # U-NET if "unet" in models_for_export: + unet = models_for_export["unet"] export_config_constructor = TasksManager.get_exporter_config_constructor( - model=pipeline.unet, exporter=exporter, library_name="diffusers", task="semantic-segmentation" - ) - unet_export_config = export_config_constructor( - pipeline.unet.config, int_dtype=int_dtype, float_dtype=float_dtype + model=unet, exporter=exporter, library_name="diffusers", task="semantic-segmentation" ) + unet_export_config = export_config_constructor(unet.config, int_dtype=int_dtype, float_dtype=float_dtype) models_for_export["unet"] = (models_for_export["unet"], unet_export_config) # Transformer if "transformer" in models_for_export: + transformer = models_for_export["transformer"] export_config_constructor = TasksManager.get_exporter_config_constructor( - model=pipeline.transformer, exporter=exporter, library_name="diffusers", task="semantic-segmentation" + model=transformer, exporter=exporter, library_name="diffusers", task="semantic-segmentation" ) transformer_export_config = export_config_constructor( - pipeline.transformer.config, int_dtype=int_dtype, float_dtype=float_dtype + transformer.config, int_dtype=int_dtype, float_dtype=float_dtype ) models_for_export["transformer"] = (models_for_export["transformer"], transformer_export_config) - # VAE Encoder https://github.com/huggingface/diffusers/blob/v0.11.1/src/diffusers/models/vae.py#L565 + # VAE Encoder vae_encoder = models_for_export["vae_encoder"] vae_config_constructor = TasksManager.get_exporter_config_constructor( model=vae_encoder, @@ -403,10 +430,12 @@ def get_diffusion_models_for_export( task="semantic-segmentation", model_type="vae-encoder", ) - vae_export_config = vae_config_constructor(vae_encoder.config, int_dtype=int_dtype, float_dtype=float_dtype) - models_for_export["vae_encoder"] = (vae_encoder, vae_export_config) + vae_encoder_export_config = vae_config_constructor( + vae_encoder.config, int_dtype=int_dtype, float_dtype=float_dtype + ) + models_for_export["vae_encoder"] = (vae_encoder, vae_encoder_export_config) - # VAE 
Decoder https://github.com/huggingface/diffusers/blob/v0.11.1/src/diffusers/models/vae.py#L600 + # VAE Decoder vae_decoder = models_for_export["vae_decoder"] vae_config_constructor = TasksManager.get_exporter_config_constructor( model=vae_decoder, @@ -415,8 +444,10 @@ def get_diffusion_models_for_export( task="semantic-segmentation", model_type="vae-decoder", ) - vae_export_config = vae_config_constructor(vae_decoder.config, int_dtype=int_dtype, float_dtype=float_dtype) - models_for_export["vae_decoder"] = (vae_decoder, vae_export_config) + vae_decoder_export_config = vae_config_constructor( + vae_decoder.config, int_dtype=int_dtype, float_dtype=float_dtype + ) + models_for_export["vae_decoder"] = (vae_decoder, vae_decoder_export_config) return models_for_export diff --git a/optimum/onnxruntime/modeling_diffusion.py b/optimum/onnxruntime/modeling_diffusion.py index 6df9816e056..c3d40f1ce1a 100644 --- a/optimum/onnxruntime/modeling_diffusion.py +++ b/optimum/onnxruntime/modeling_diffusion.py @@ -30,6 +30,7 @@ AutoPipelineForImage2Image, AutoPipelineForInpainting, AutoPipelineForText2Image, + FluxPipeline, LatentConsistencyModelImg2ImgPipeline, LatentConsistencyModelPipeline, StableDiffusion3Img2ImgPipeline, @@ -481,9 +482,13 @@ def __init__(self, session: ort.InferenceSession, parent_pipeline: ORTDiffusionP self.input_names = {input_key.name: idx for idx, input_key in enumerate(self.session.get_inputs())} self.output_names = {output_key.name: idx for idx, output_key in enumerate(self.session.get_outputs())} + self.input_dtypes = {input_key.name: input_key.type for input_key in self.session.get_inputs()} self.output_dtypes = {output_key.name: output_key.type for output_key in self.session.get_outputs()} + self.input_shapes = {input_key.name: input_key.shape for input_key in self.session.get_inputs()} + self.output_shapes = {output_key.name: output_key.shape for output_key in self.session.get_outputs()} + config_file_path = Path(session._model_path).parent / self.config_name if not config_file_path.is_file(): # config is mandatory for the model part to be used for inference @@ -581,13 +586,18 @@ def __init__(self, *args, **kwargs): ) self.register_to_config(time_cond_proj_dim=None) + if len(self.input_shapes["timestep"]) > 0: + logger.warning( + "The exported unet onnx model expects a non scalar timestep input. " + "We will have to unsqueeze the timestep input at each iteration which might be inefficient. " + "Please re-export the pipeline with newer version of optimum and diffusers to avoid this warning." 
+ ) + def forward( self, sample: Union[np.ndarray, torch.Tensor], timestep: Union[np.ndarray, torch.Tensor], encoder_hidden_states: Union[np.ndarray, torch.Tensor], - text_embeds: Optional[Union[np.ndarray, torch.Tensor]] = None, - time_ids: Optional[Union[np.ndarray, torch.Tensor]] = None, timestep_cond: Optional[Union[np.ndarray, torch.Tensor]] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, added_cond_kwargs: Optional[Dict[str, Any]] = None, @@ -595,15 +605,13 @@ def forward( ): use_torch = isinstance(sample, torch.Tensor) - if len(timestep.shape) == 0: + if len(self.input_shapes["timestep"]) > 0: timestep = timestep.unsqueeze(0) model_inputs = { "sample": sample, "timestep": timestep, "encoder_hidden_states": encoder_hidden_states, - "text_embeds": text_embeds, - "time_ids": time_ids, "timestep_cond": timestep_cond, **(cross_attention_kwargs or {}), **(added_cond_kwargs or {}), @@ -623,22 +631,25 @@ class ORTModelTransformer(ORTPipelinePart): def forward( self, hidden_states: Union[np.ndarray, torch.Tensor], - timestep: Union[np.ndarray, torch.Tensor], encoder_hidden_states: Union[np.ndarray, torch.Tensor], pooled_projections: Union[np.ndarray, torch.Tensor], + timestep: Union[np.ndarray, torch.Tensor], + guidance: Optional[Union[np.ndarray, torch.Tensor]] = None, + txt_ids: Optional[Union[np.ndarray, torch.Tensor]] = None, + img_ids: Optional[Union[np.ndarray, torch.Tensor]] = None, joint_attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = False, ): use_torch = isinstance(hidden_states, torch.Tensor) - if len(timestep.shape) == 0: - timestep = timestep.unsqueeze(0) - model_inputs = { "hidden_states": hidden_states, - "timestep": timestep, "encoder_hidden_states": encoder_hidden_states, "pooled_projections": pooled_projections, + "timestep": timestep, + "guidance": guidance, + "txt_ids": txt_ids, + "img_ids": img_ids, **(joint_attention_kwargs or {}), } @@ -670,16 +681,12 @@ def forward( if output_hidden_states: model_outputs["hidden_states"] = [] - num_layers = ( - self.config.num_hidden_layers if hasattr(self.config, "num_hidden_layers") else self.num_decoder_layers - ) + num_layers = self.num_hidden_layers if hasattr(self, "num_hidden_layers") else self.num_decoder_layers for i in range(num_layers): model_outputs["hidden_states"].append(model_outputs.pop(f"hidden_states.{i}")) model_outputs["hidden_states"].append(model_outputs.get("last_hidden_state")) else: - num_layers = ( - self.config.num_hidden_layers if hasattr(self.config, "num_hidden_layers") else self.num_decoder_layers - ) + num_layers = self.num_hidden_layers if hasattr(self, "num_hidden_layers") else self.num_decoder_layers for i in range(num_layers): model_outputs.pop(f"hidden_states.{i}", None) @@ -697,7 +704,7 @@ def __init__(self, *args, **kwargs): if not hasattr(self.config, "scaling_factor"): logger.warning( "The `scaling_factor` attribute is missing from the VAE encoder configuration. " - "Please re-export the model with newer version of optimum and diffusers." + "Please re-export the model with newer version of optimum and diffusers to avoid this warning." ) self.register_to_config(scaling_factor=2 ** (len(self.config.block_out_channels) - 1)) @@ -737,7 +744,7 @@ def __init__(self, *args, **kwargs): if not hasattr(self.config, "scaling_factor"): logger.warning( "The `scaling_factor` attribute is missing from the VAE decoder configuration. " - "Please re-export the model with newer version of optimum and diffusers." 
+ "Please re-export the model with newer version of optimum and diffusers to avoid this warning." ) self.register_to_config(scaling_factor=2 ** (len(self.config.block_out_channels) - 1)) @@ -981,6 +988,17 @@ class ORTStableDiffusion3InpaintPipeline(ORTDiffusionPipeline, StableDiffusion3I auto_model_class = StableDiffusion3InpaintPipeline +@add_end_docstrings(ONNX_MODEL_END_DOCSTRING) +class ORTFluxPipeline(ORTDiffusionPipeline, FluxPipeline): + """ + ONNX Runtime-powered stable diffusion pipeline corresponding to [diffusers.FluxPipeline](https://huggingface.co/docs/diffusers/api/pipelines/flux/text2img#diffusers.FluxPipeline). + """ + + main_input_name = "prompt" + export_feature = "text-to-image" + auto_model_class = FluxPipeline + + SUPPORTED_ORT_PIPELINES = [ ORTStableDiffusionPipeline, ORTStableDiffusionImg2ImgPipeline, @@ -993,6 +1011,7 @@ class ORTStableDiffusion3InpaintPipeline(ORTDiffusionPipeline, StableDiffusion3I ORTStableDiffusion3Pipeline, ORTStableDiffusion3Img2ImgPipeline, ORTStableDiffusion3InpaintPipeline, + ORTFluxPipeline, ] @@ -1010,27 +1029,28 @@ def _get_ort_class(pipeline_class_name: str, throw_error_if_not_exist: bool = Tr ORT_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict( [ - ("stable-diffusion", ORTStableDiffusionPipeline), - ("stable-diffusion-xl", ORTStableDiffusionXLPipeline), + ("flux", ORTFluxPipeline), ("latent-consistency", ORTLatentConsistencyModelPipeline), + ("stable-diffusion", ORTStableDiffusionPipeline), ("stable-diffusion-3", ORTStableDiffusion3Pipeline), + ("stable-diffusion-xl", ORTStableDiffusionXLPipeline), ] ) ORT_IMAGE2IMAGE_PIPELINES_MAPPING = OrderedDict( [ - ("stable-diffusion", ORTStableDiffusionImg2ImgPipeline), - ("stable-diffusion-xl", ORTStableDiffusionXLImg2ImgPipeline), ("latent-consistency", ORTLatentConsistencyModelImg2ImgPipeline), + ("stable-diffusion", ORTStableDiffusionImg2ImgPipeline), ("stable-diffusion-3", ORTStableDiffusion3Img2ImgPipeline), + ("stable-diffusion-xl", ORTStableDiffusionXLImg2ImgPipeline), ] ) ORT_INPAINT_PIPELINES_MAPPING = OrderedDict( [ ("stable-diffusion", ORTStableDiffusionInpaintPipeline), - ("stable-diffusion-xl", ORTStableDiffusionXLInpaintPipeline), ("stable-diffusion-3", ORTStableDiffusion3InpaintPipeline), + ("stable-diffusion-xl", ORTStableDiffusionXLInpaintPipeline), ] ) diff --git a/optimum/utils/input_generators.py b/optimum/utils/input_generators.py index dac14a38114..dd29821766a 100644 --- a/optimum/utils/input_generators.py +++ b/optimum/utils/input_generators.py @@ -871,8 +871,8 @@ def __init__( def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): if input_name == "timestep": - shape = [self.batch_size] - return self.random_int_tensor(shape, max_value=self.vocab_size, framework=framework, dtype=int_dtype) + shape = [] # a scalar with no dimension (it can be int or float depending on the sd architecture) + return self.random_float_tensor(shape, max_value=self.vocab_size, framework=framework, dtype=float_dtype) if input_name == "text_embeds": dim = self.text_encoder_projection_dim diff --git a/tests/exporters/exporters_utils.py b/tests/exporters/exporters_utils.py index 9617ab37a00..e7fe1035dc1 100644 --- a/tests/exporters/exporters_utils.py +++ b/tests/exporters/exporters_utils.py @@ -296,6 +296,7 @@ } PYTORCH_DIFFUSION_MODEL = { + "flux": "optimum-internal-testing/tiny-random-flux", "latent-consistency": "echarlaix/tiny-random-latent-consistency", "stable-diffusion": "hf-internal-testing/tiny-stable-diffusion-torch", 
"stable-diffusion-3": "yujiepan/stable-diffusion-3-tiny-random", diff --git a/tests/onnxruntime/test_diffusion.py b/tests/onnxruntime/test_diffusion.py index 38e9b2b5391..5d70bb2de1d 100644 --- a/tests/onnxruntime/test_diffusion.py +++ b/tests/onnxruntime/test_diffusion.py @@ -71,8 +71,25 @@ def _generate_images(height=128, width=128, batch_size=1, channel=3, input_type= class ORTPipelineForText2ImageTest(ORTModelTestMixin): - SUPPORTED_ARCHITECTURES = ["stable-diffusion", "stable-diffusion-xl", "latent-consistency", "stable-diffusion-3"] - CALLBACK_SUPPORTED_ARCHITECTURES = ["stable-diffusion", "stable-diffusion-xl", "latent-consistency"] + SUPPORTED_ARCHITECTURES = [ + "stable-diffusion", + "stable-diffusion-xl", + "latent-consistency", + "stable-diffusion-3", + "flux", + ] + NEGATIVE_PROMPT_SUPPORTED_ARCHITECTURES = [ + "stable-diffusion", + "stable-diffusion-xl", + "latent-consistency", + "stable-diffusion-3", + ] + CALLBACK_SUPPORTED_ARCHITECTURES = [ + "stable-diffusion", + "stable-diffusion-xl", + "latent-consistency", + "flux", + ] ORTMODEL_CLASS = ORTPipelineForText2Image AUTOMODEL_CLASS = AutoPipelineForText2Image @@ -143,10 +160,10 @@ def test_compare_to_diffusers_pipeline(self, model_arch: str): for output_type in ["latent", "np", "pt"]: inputs["output_type"] = output_type - ort_output = ort_pipeline(**inputs, generator=get_generator("pt", SEED)).images - diffusers_output = diffusers_pipeline(**inputs, generator=get_generator("pt", SEED)).images + ort_images = ort_pipeline(**inputs, generator=get_generator("pt", SEED)).images + diffusers_images = diffusers_pipeline(**inputs, generator=get_generator("pt", SEED)).images - np.testing.assert_allclose(ort_output, diffusers_output, atol=1e-4, rtol=1e-2) + np.testing.assert_allclose(ort_images, diffusers_images, atol=1e-4, rtol=1e-2) @parameterized.expand(CALLBACK_SUPPORTED_ARCHITECTURES) @require_diffusers @@ -165,6 +182,7 @@ def __init__(self): def __call__(self, *args, **kwargs) -> None: self.has_been_called = True self.number_of_steps += 1 + return kwargs ort_callback = Callback() auto_callback = Callback() @@ -172,9 +190,8 @@ def __call__(self, *args, **kwargs) -> None: ort_pipe = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch]) auto_pipe = self.AUTOMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch]) - # callback_steps=1 to trigger callback every step - ort_pipe(**inputs, callback=ort_callback, callback_steps=1) - auto_pipe(**inputs, callback=auto_callback, callback_steps=1) + ort_pipe(**inputs, callback_on_step_end=ort_callback) + auto_pipe(**inputs, callback_on_step_end=auto_callback) self.assertTrue(ort_callback.has_been_called) self.assertTrue(auto_callback.has_been_called) @@ -201,20 +218,20 @@ def test_shape(self, model_arch: str): elif output_type == "pt": self.assertEqual(outputs.shape, (batch_size, 3, height, width)) else: - out_channels = ( - pipeline.unet.config.out_channels - if pipeline.unet is not None - else pipeline.transformer.config.out_channels - ) - self.assertEqual( - outputs.shape, - ( - batch_size, - out_channels, - height // pipeline.vae_scale_factor, - width // pipeline.vae_scale_factor, - ), - ) + expected_height = height // pipeline.vae_scale_factor + expected_width = width // pipeline.vae_scale_factor + + if model_arch == "flux": + channels = pipeline.transformer.config.in_channels + expected_shape = (batch_size, expected_height * expected_width, channels) + elif model_arch == "stable-diffusion-3": + out_channels = pipeline.transformer.config.out_channels + expected_shape = 
(batch_size, out_channels, expected_height, expected_width) + else: + out_channels = pipeline.unet.config.out_channels + expected_shape = (batch_size, out_channels, expected_height, expected_width) + + self.assertEqual(outputs.shape, expected_shape) @parameterized.expand(SUPPORTED_ARCHITECTURES) @require_diffusers @@ -235,60 +252,22 @@ def test_image_reproducibility(self, model_arch: str): self.assertFalse(np.array_equal(ort_outputs_1.images[0], ort_outputs_3.images[0])) np.testing.assert_allclose(ort_outputs_1.images[0], ort_outputs_2.images[0], atol=1e-4, rtol=1e-2) - @parameterized.expand(SUPPORTED_ARCHITECTURES) + @parameterized.expand(NEGATIVE_PROMPT_SUPPORTED_ARCHITECTURES) def test_negative_prompt(self, model_arch: str): model_args = {"test_name": model_arch, "model_arch": model_arch} self._setup(model_args) height, width, batch_size = 64, 64, 1 inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size) + inputs["negative_prompt"] = ["This is a negative prompt"] * batch_size - negative_prompt = ["This is a negative prompt"] - pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch]) + ort_pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch]) + diffusers_pipeline = self.AUTOMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch]) + + ort_images = ort_pipeline(**inputs, generator=get_generator("pt", SEED)).images + diffusers_images = diffusers_pipeline(**inputs, generator=get_generator("pt", SEED)).images - images_1 = pipeline(**inputs, negative_prompt=negative_prompt, generator=get_generator("pt", SEED)).images - prompt = inputs.pop("prompt") - - if model_arch == "stable-diffusion-xl": - ( - inputs["prompt_embeds"], - inputs["negative_prompt_embeds"], - inputs["pooled_prompt_embeds"], - inputs["negative_pooled_prompt_embeds"], - ) = pipeline.encode_prompt( - prompt=prompt, - num_images_per_prompt=1, - device=torch.device("cpu"), - do_classifier_free_guidance=True, - negative_prompt=negative_prompt, - ) - elif model_arch == "stable-diffusion-3": - ( - inputs["prompt_embeds"], - inputs["negative_prompt_embeds"], - inputs["pooled_prompt_embeds"], - inputs["negative_pooled_prompt_embeds"], - ) = pipeline.encode_prompt( - prompt=prompt, - prompt_2=prompt, - prompt_3=prompt, - num_images_per_prompt=1, - device=torch.device("cpu"), - do_classifier_free_guidance=True, - negative_prompt=negative_prompt, - ) - else: - inputs["prompt_embeds"], inputs["negative_prompt_embeds"] = pipeline.encode_prompt( - prompt=prompt, - num_images_per_prompt=1, - device=torch.device("cpu"), - do_classifier_free_guidance=True, - negative_prompt=negative_prompt, - ) - - images_2 = pipeline(**inputs, generator=get_generator("pt", SEED)).images - - np.testing.assert_allclose(images_1, images_2, atol=1e-4, rtol=1e-2) + np.testing.assert_allclose(ort_images, diffusers_images, atol=1e-4, rtol=1e-2) @parameterized.expand( grid_parameters( @@ -425,15 +404,16 @@ def __init__(self): def __call__(self, *args, **kwargs) -> None: self.has_been_called = True self.number_of_steps += 1 + return kwargs ort_pipe = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch]) auto_pipe = self.AUTOMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch]) ort_callback = Callback() auto_callback = Callback() - # callback_steps=1 to trigger callback every step - ort_pipe(**inputs, callback=ort_callback, callback_steps=1) - auto_pipe(**inputs, callback=auto_callback, callback_steps=1) + + ort_pipe(**inputs, callback_on_step_end=ort_callback) + auto_pipe(**inputs, 
callback_on_step_end=auto_callback) self.assertTrue(ort_callback.has_been_called) self.assertEqual(ort_callback.number_of_steps, auto_callback.number_of_steps) @@ -491,10 +471,10 @@ def test_compare_to_diffusers_pipeline(self, model_arch: str): for output_type in ["latent", "np", "pt"]: inputs["output_type"] = output_type - ort_output = ort_pipeline(**inputs, generator=get_generator("pt", SEED)).images - diffusers_output = diffusers_pipeline(**inputs, generator=get_generator("pt", SEED)).images + ort_images = ort_pipeline(**inputs, generator=get_generator("pt", SEED)).images + diffusers_images = diffusers_pipeline(**inputs, generator=get_generator("pt", SEED)).images - np.testing.assert_allclose(ort_output, diffusers_output, atol=1e-4, rtol=1e-2) + np.testing.assert_allclose(ort_images, diffusers_images, atol=1e-4, rtol=1e-2) @parameterized.expand(SUPPORTED_ARCHITECTURES) @require_diffusers @@ -656,15 +636,16 @@ def __init__(self): def __call__(self, *args, **kwargs) -> None: self.has_been_called = True self.number_of_steps += 1 + return kwargs ort_pipe = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch]) auto_pipe = self.AUTOMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch]) ort_callback = Callback() auto_callback = Callback() - # callback_steps=1 to trigger callback every step - ort_pipe(**inputs, callback=ort_callback, callback_steps=1) - auto_pipe(**inputs, callback=auto_callback, callback_steps=1) + + ort_pipe(**inputs, callback_on_step_end=ort_callback) + auto_pipe(**inputs, callback_on_step_end=auto_callback) self.assertTrue(ort_callback.has_been_called) self.assertEqual(ort_callback.number_of_steps, auto_callback.number_of_steps) @@ -722,10 +703,10 @@ def test_compare_to_diffusers_pipeline(self, model_arch: str): for output_type in ["latent", "np", "pt"]: inputs["output_type"] = output_type - ort_output = ort_pipeline(**inputs, generator=get_generator("pt", SEED)).images - diffusers_output = diffusers_pipeline(**inputs, generator=get_generator("pt", SEED)).images + ort_images = ort_pipeline(**inputs, generator=get_generator("pt", SEED)).images + diffusers_images = diffusers_pipeline(**inputs, generator=get_generator("pt", SEED)).images - np.testing.assert_allclose(ort_output, diffusers_output, atol=1e-4, rtol=1e-2) + np.testing.assert_allclose(ort_images, diffusers_images, atol=1e-4, rtol=1e-2) @parameterized.expand(SUPPORTED_ARCHITECTURES) @require_diffusers diff --git a/tests/onnxruntime/utils_onnxruntime_tests.py b/tests/onnxruntime/utils_onnxruntime_tests.py index cb224993ad6..b6f5efaef76 100644 --- a/tests/onnxruntime/utils_onnxruntime_tests.py +++ b/tests/onnxruntime/utils_onnxruntime_tests.py @@ -98,6 +98,7 @@ }, "falcon": "fxmarty/really-tiny-falcon-testing", "flaubert": "hf-internal-testing/tiny-random-flaubert", + "flux": "optimum-internal-testing/tiny-random-flux", "gemma": "fxmarty/tiny-random-GemmaForCausalLM", "gpt2": "hf-internal-testing/tiny-random-gpt2", "gpt_bigcode": "hf-internal-testing/tiny-random-GPTBigCodeModel", @@ -141,7 +142,7 @@ "squeezebert": "hf-internal-testing/tiny-random-SqueezeBertModel", "speech_to_text": "hf-internal-testing/tiny-random-Speech2TextModel", "stable-diffusion": "hf-internal-testing/tiny-stable-diffusion-torch", - "stable-diffusion-3": "yujiepan/stable-diffusion-3-tiny-random", + "stable-diffusion-3": "optimum-internal-testing/tiny-random-stable-diffusion-3", "stable-diffusion-xl": "echarlaix/tiny-random-stable-diffusion-xl", "swin": "hf-internal-testing/tiny-random-SwinModel", "swin-window": 
"yujiepan/tiny-random-swin-patch4-window7-224", From aa74f63d87560a934374c6fdf80b290ce341095f Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 23 Oct 2024 12:45:35 +0200 Subject: [PATCH 4/7] fixes --- optimum/exporters/onnx/model_configs.py | 2 +- optimum/exporters/tasks.py | 6 ++++-- tests/exporters/onnx/test_onnx_export.py | 2 -- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index afe87c772da..6bd32c19cef 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -1106,7 +1106,7 @@ def ordered_inputs(self, model) -> Dict[str, Dict[int, str]]: class VaeEncoderOnnxConfig(VisionOnnxConfig): - ATOL_FOR_VALIDATION = 1e-4 + ATOL_FOR_VALIDATION = 3e-4 # The ONNX export of a CLIPText architecture, an other Stable Diffusion component, needs the Trilu # operator support, available since opset 14 DEFAULT_ONNX_OPSET = 14 diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py index 553dcc1a27a..8bb0a9f1ba9 100644 --- a/optimum/exporters/tasks.py +++ b/optimum/exporters/tasks.py @@ -1182,15 +1182,17 @@ class TasksManager: "transformers": _SUPPORTED_MODEL_TYPE, } _UNSUPPORTED_CLI_MODEL_TYPE = { - # diffusers submodels + # diffusers model types "clip-text-model", "clip-text-model-with-projection", - "flux-transformer-model", + "flux-transformer-2d-model", "sd3-transformer-2d-model", "t5-encoder-model", "unet-2d-condition", "vae-encoder", "vae-decoder", + # redundant model types + "trocr", # same as vision-encoder-decoder } _SUPPORTED_CLI_MODEL_TYPE = ( set(_SUPPORTED_MODEL_TYPE.keys()) diff --git a/tests/exporters/onnx/test_onnx_export.py b/tests/exporters/onnx/test_onnx_export.py index 7671d6cd2e6..88288547c95 100644 --- a/tests/exporters/onnx/test_onnx_export.py +++ b/tests/exporters/onnx/test_onnx_export.py @@ -299,7 +299,6 @@ def _onnx_export_diffusion_models(self, model_type: str, model_name: str, device with TemporaryDirectory() as tmpdirname: _, onnx_outputs = export_models( models_and_onnx_configs=models_and_onnx_configs, - opset=14, output_dir=Path(tmpdirname), device=device, ) @@ -307,7 +306,6 @@ def _onnx_export_diffusion_models(self, model_type: str, model_name: str, device models_and_onnx_configs=models_and_onnx_configs, onnx_named_outputs=onnx_outputs, output_dir=Path(tmpdirname), - atol=1e-4, use_subprocess=False, ) From b566392f8eac9444e9a29d56ca85c7c2bc110734 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Thu, 24 Oct 2024 15:39:34 +0200 Subject: [PATCH 5/7] move input generators --- optimum/exporters/onnx/model_configs.py | 80 ++----------------------- optimum/utils/__init__.py | 5 ++ optimum/utils/import_utils.py | 1 + optimum/utils/input_generators.py | 78 +++++++++++++++++++++++- 4 files changed, 88 insertions(+), 76 deletions(-) diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index 6bd32c19cef..7588e300ea3 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -18,6 +18,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Union from packaging import version + from transformers.utils import is_tf_available from ...onnx import merge_decoders @@ -28,6 +29,8 @@ DummyCodegenDecoderTextInputGenerator, DummyDecoderTextInputGenerator, DummyEncodecInputGenerator, + DummyFluxTransformerTextInputGenerator, + DummyFluxTransformerVisionInputGenerator, DummyInputGenerator, DummyIntGenerator, 
DummyPastKeyValuesGenerator, @@ -38,6 +41,9 @@ DummySpeechT5InputGenerator, DummyTextInputGenerator, DummyTimestepInputGenerator, + DummyTransformerTextInputGenerator, + DummyTransformerTimestpsInputGenerator, + DummyTransformerVisionInputGenerator, DummyVisionEmbeddingsGenerator, DummyVisionEncoderDecoderPastKeyValuesGenerator, DummyVisionInputGenerator, @@ -1168,39 +1174,6 @@ def outputs(self): } -class DummyTransformerTimestpsInputGenerator(DummyTimestepInputGenerator): - SUPPORTED_INPUT_NAMES = ("timestep",) - - def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): - if input_name == "timestep": - shape = [self.batch_size] # With transformer diffusers, timestep is a 1D tensor - return self.random_float_tensor(shape, max_value=self.vocab_size, framework=framework, dtype=float_dtype) - - return super().generate(input_name, framework, int_dtype, float_dtype) - - -class DummyTransformerVisionInputGenerator(DummyVisionInputGenerator): - SUPPORTED_INPUT_NAMES = ("hidden_states",) - - -class DummyTransformerTextInputGenerator(DummySeq2SeqDecoderTextInputGenerator): - SUPPORTED_INPUT_NAMES = ( - "encoder_hidden_states", - "pooled_projection", - ) - - def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): - if input_name == "encoder_hidden_states": - return super().generate(input_name, framework, int_dtype, float_dtype)[0] - - elif input_name == "pooled_projections": - return self.random_float_tensor( - [self.batch_size, self.normalized_config.projection_size], framework=framework, dtype=float_dtype - ) - - return super().generate(input_name, framework, int_dtype, float_dtype) - - class SD3TransformerOnnxConfig(VisionOnnxConfig): ATOL_FOR_VALIDATION = 1e-4 # The ONNX export of a CLIPText architecture, an other Stable Diffusion component, needs the Trilu @@ -1246,46 +1219,6 @@ def torch_to_onnx_output_map(self) -> Dict[str, str]: } -class DummyFluxTransformerVisionInputGenerator(DummyTransformerVisionInputGenerator): - SUPPORTED_INPUT_NAMES = ( - "hidden_states", - "img_ids", - ) - - def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): - if input_name == "hidden_states": - shape = [self.batch_size, (self.height // 2) * (self.width // 2), self.num_channels] - return self.random_float_tensor(shape, framework=framework, dtype=float_dtype) - elif input_name == "img_ids": - shape = ( - [(self.height // 2) * (self.width // 2), 3] - if check_if_diffusers_greater("0.31.0") - else [self.batch_size, (self.height // 2) * (self.width // 2), 3] - ) - return self.random_int_tensor(shape, max_value=1, framework=framework, dtype=int_dtype) - - return super().generate(input_name, framework, int_dtype, float_dtype) - - -class DummyFluxTransformerTextInputGenerator(DummyTransformerTextInputGenerator): - SUPPORTED_INPUT_NAMES = ( - "encoder_hidden_states", - "pooled_projections", - "txt_ids", - ) - - def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): - if input_name == "txt_ids": - shape = ( - [self.sequence_length, 3] - if check_if_diffusers_greater("0.31.0") - else [self.batch_size, self.sequence_length, 3] - ) - return self.random_int_tensor(shape, max_value=1, framework=framework, dtype=int_dtype) - - return super().generate(input_name, framework, int_dtype, float_dtype) - - class FluxTransformerOnnxConfig(SD3TransformerOnnxConfig): DUMMY_INPUT_GENERATOR_CLASSES = ( 
DummyTransformerTimestpsInputGenerator, @@ -1296,7 +1229,6 @@ class FluxTransformerOnnxConfig(SD3TransformerOnnxConfig): @property def inputs(self): common_inputs = super().inputs - common_inputs["hidden_states"] = {0: "batch_size", 1: "packed_height_width"} common_inputs["txt_ids"] = ( {0: "sequence_length"} if check_if_diffusers_greater("0.31.0") else {0: "batch_size", 1: "sequence_length"} diff --git a/optimum/utils/__init__.py b/optimum/utils/__init__.py index cdf7c45b0e1..3755517ab1b 100644 --- a/optimum/utils/__init__.py +++ b/optimum/utils/__init__.py @@ -53,6 +53,8 @@ DummyCodegenDecoderTextInputGenerator, DummyDecoderTextInputGenerator, DummyEncodecInputGenerator, + DummyFluxTransformerTextInputGenerator, + DummyFluxTransformerVisionInputGenerator, DummyInputGenerator, DummyIntGenerator, DummyLabelsGenerator, @@ -64,6 +66,9 @@ DummySpeechT5InputGenerator, DummyTextInputGenerator, DummyTimestepInputGenerator, + DummyTransformerTextInputGenerator, + DummyTransformerTimestpsInputGenerator, + DummyTransformerVisionInputGenerator, DummyVisionEmbeddingsGenerator, DummyVisionEncoderDecoderPastKeyValuesGenerator, DummyVisionInputGenerator, diff --git a/optimum/utils/import_utils.py b/optimum/utils/import_utils.py index 4a57fda79ce..49688d60837 100644 --- a/optimum/utils/import_utils.py +++ b/optimum/utils/import_utils.py @@ -22,6 +22,7 @@ import numpy as np from packaging import version + from transformers.utils import is_torch_available diff --git a/optimum/utils/input_generators.py b/optimum/utils/input_generators.py index dd29821766a..27cc2d075bc 100644 --- a/optimum/utils/input_generators.py +++ b/optimum/utils/input_generators.py @@ -20,9 +20,10 @@ from typing import Any, List, Optional, Tuple, Union import numpy as np + from transformers.utils import is_tf_available, is_torch_available -from ..utils import check_if_transformers_greater +from ..utils import check_if_diffusers_greater, check_if_transformers_greater from .normalized_config import ( NormalizedConfig, NormalizedEncoderDecoderConfig, @@ -36,7 +37,7 @@ import torch if is_tf_available(): - import tensorflow as tf + import tensorflow as tf # type: ignore def check_framework_is_available(func): @@ -1411,3 +1412,76 @@ def generate( float_dtype: str = "fp32", ): return self.random_int_tensor(shape=(1,), min_value=20, max_value=22, framework=framework, dtype=int_dtype) + + +class DummyTransformerTimestpsInputGenerator(DummyTimestepInputGenerator): + SUPPORTED_INPUT_NAMES = ("timestep",) + + def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): + if input_name == "timestep": + shape = [self.batch_size] # With transformer diffusers, timestep is a 1D tensor + return self.random_float_tensor(shape, max_value=self.vocab_size, framework=framework, dtype=float_dtype) + + return super().generate(input_name, framework, int_dtype, float_dtype) + + +class DummyTransformerVisionInputGenerator(DummyVisionInputGenerator): + SUPPORTED_INPUT_NAMES = ("hidden_states",) + + +class DummyTransformerTextInputGenerator(DummySeq2SeqDecoderTextInputGenerator): + SUPPORTED_INPUT_NAMES = ( + "encoder_hidden_states", + "pooled_projection", + ) + + def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): + if input_name == "encoder_hidden_states": + return super().generate(input_name, framework, int_dtype, float_dtype)[0] + + elif input_name == "pooled_projections": + return self.random_float_tensor( + [self.batch_size, 
self.normalized_config.projection_size], framework=framework, dtype=float_dtype + ) + + return super().generate(input_name, framework, int_dtype, float_dtype) + + +class DummyFluxTransformerVisionInputGenerator(DummyTransformerVisionInputGenerator): + SUPPORTED_INPUT_NAMES = ( + "hidden_states", + "img_ids", + ) + + def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): + if input_name == "hidden_states": + shape = [self.batch_size, (self.height // 2) * (self.width // 2), self.num_channels] + return self.random_float_tensor(shape, framework=framework, dtype=float_dtype) + elif input_name == "img_ids": + shape = ( + [(self.height // 2) * (self.width // 2), 3] + if check_if_diffusers_greater("0.31.0") + else [self.batch_size, (self.height // 2) * (self.width // 2), 3] + ) + return self.random_int_tensor(shape, max_value=1, framework=framework, dtype=int_dtype) + + return super().generate(input_name, framework, int_dtype, float_dtype) + + +class DummyFluxTransformerTextInputGenerator(DummyTransformerTextInputGenerator): + SUPPORTED_INPUT_NAMES = ( + "encoder_hidden_states", + "pooled_projections", + "txt_ids", + ) + + def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): + if input_name == "txt_ids": + shape = ( + [self.sequence_length, 3] + if check_if_diffusers_greater("0.31.0") + else [self.batch_size, self.sequence_length, 3] + ) + return self.random_int_tensor(shape, max_value=1, framework=framework, dtype=int_dtype) + + return super().generate(input_name, framework, int_dtype, float_dtype) From 88901adfec780802188fd907b29adbd5f9db7084 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Thu, 24 Oct 2024 16:07:13 +0200 Subject: [PATCH 6/7] dummy diffusers --- optimum/onnxruntime/__init__.py | 72 ++++++++++++++++++----- optimum/utils/dummy_diffusers_objects.py | 74 ++++++++++++++++++++++-- 2 files changed, 126 insertions(+), 20 deletions(-) diff --git a/optimum/onnxruntime/__init__.py b/optimum/onnxruntime/__init__.py index 4e25a436909..f3f1535fd45 100644 --- a/optimum/onnxruntime/__init__.py +++ b/optimum/onnxruntime/__init__.py @@ -74,33 +74,51 @@ raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: _import_structure[".utils.dummy_diffusers_objects"] = [ - "ORTStableDiffusionPipeline", + "ORTDiffusionPipeline", + "ORTPipelineForText2Image", + "ORTPipelineForImage2Image", + "ORTPipelineForInpainting", + # flux + "ORTFluxPipeline", + # lcm + "ORTLatentConsistencyModelImg2ImgPipeline", + "ORTLatentConsistencyModelPipeline", + # sd3 + "ORTStableDiffusion3Img2ImgPipeline", + "ORTStableDiffusion3InpaintPipeline", + "ORTStableDiffusion3Pipeline", + # sd "ORTStableDiffusionImg2ImgPipeline", "ORTStableDiffusionInpaintPipeline", - "ORTStableDiffusionXLPipeline", + "ORTStableDiffusionPipeline", + # xl "ORTStableDiffusionXLImg2ImgPipeline", "ORTStableDiffusionXLInpaintPipeline", - "ORTLatentConsistencyModelPipeline", - "ORTLatentConsistencyModelImg2ImgPipeline", - "ORTPipelineForImage2Image", - "ORTPipelineForInpainting", - "ORTPipelineForText2Image", - "ORTDiffusionPipeline", + "ORTStableDiffusionXLPipeline", ] else: _import_structure["modeling_diffusion"] = [ - "ORTStableDiffusionPipeline", + "ORTDiffusionPipeline", + "ORTPipelineForText2Image", + "ORTPipelineForImage2Image", + "ORTPipelineForInpainting", + # flux + "ORTFluxPipeline", + # lcm + "ORTLatentConsistencyModelImg2ImgPipeline", + "ORTLatentConsistencyModelPipeline", + # sd3 + 
"ORTStableDiffusion3Img2ImgPipeline", + "ORTStableDiffusion3InpaintPipeline", + "ORTStableDiffusion3Pipeline", + # sd "ORTStableDiffusionImg2ImgPipeline", "ORTStableDiffusionInpaintPipeline", - "ORTStableDiffusionXLPipeline", + "ORTStableDiffusionPipeline", + # xl "ORTStableDiffusionXLImg2ImgPipeline", "ORTStableDiffusionXLInpaintPipeline", - "ORTLatentConsistencyModelImg2ImgPipeline", - "ORTLatentConsistencyModelPipeline", - "ORTPipelineForImage2Image", - "ORTPipelineForInpainting", - "ORTPipelineForText2Image", - "ORTDiffusionPipeline", + "ORTStableDiffusionXLPipeline", ] @@ -151,30 +169,52 @@ raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ..utils.dummy_diffusers_objects import ( + # generic entrypoint ORTDiffusionPipeline, + # flux + ORTFluxPipeline, + # lcm ORTLatentConsistencyModelImg2ImgPipeline, ORTLatentConsistencyModelPipeline, + # task-specific entrypoints ORTPipelineForImage2Image, ORTPipelineForInpainting, ORTPipelineForText2Image, + # sd3 + ORTStableDiffusion3Img2ImgPipeline, + ORTStableDiffusion3InpaintPipeline, + ORTStableDiffusion3Pipeline, + # sd ORTStableDiffusionImg2ImgPipeline, ORTStableDiffusionInpaintPipeline, ORTStableDiffusionPipeline, + # xl ORTStableDiffusionXLImg2ImgPipeline, ORTStableDiffusionXLInpaintPipeline, ORTStableDiffusionXLPipeline, ) else: from .modeling_diffusion import ( + # generic entrypoint ORTDiffusionPipeline, + # flux + ORTFluxPipeline, + # lcm ORTLatentConsistencyModelImg2ImgPipeline, ORTLatentConsistencyModelPipeline, + # task-specific entrypoints ORTPipelineForImage2Image, ORTPipelineForInpainting, ORTPipelineForText2Image, + # sd3 + ORTStableDiffusion3Img2ImgPipeline, + ORTStableDiffusion3InpaintPipeline, + ORTStableDiffusion3Pipeline, + # sd ORTStableDiffusionImg2ImgPipeline, ORTStableDiffusionInpaintPipeline, ORTStableDiffusionPipeline, + # xl ORTStableDiffusionXLImg2ImgPipeline, ORTStableDiffusionXLInpaintPipeline, ORTStableDiffusionXLPipeline, diff --git a/optimum/utils/dummy_diffusers_objects.py b/optimum/utils/dummy_diffusers_objects.py index 35d1ffe9fc7..ff8b587e19f 100644 --- a/optimum/utils/dummy_diffusers_objects.py +++ b/optimum/utils/dummy_diffusers_objects.py @@ -15,6 +15,50 @@ from .import_utils import DummyObject, requires_backends +class ORTDiffusionPipeline(metaclass=DummyObject): + _backends = ["diffusers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["diffusers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["diffusers"]) + + +class ORTPipelineForText2Image(metaclass=DummyObject): + _backends = ["diffusers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["diffusers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["diffusers"]) + + +class ORTPipelineForImage2Image(metaclass=DummyObject): + _backends = ["diffusers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["diffusers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["diffusers"]) + + +class ORTPipelineForInpainting(metaclass=DummyObject): + _backends = ["diffusers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["diffusers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["diffusers"]) + + class ORTStableDiffusionPipeline(metaclass=DummyObject): _backends = ["diffusers"] @@ -70,6 +114,17 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["diffusers"]) +class 
ORTStableDiffusionXLInpaintPipeline(metaclass=DummyObject): + _backends = ["diffusers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["diffusers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["diffusers"]) + + class ORTLatentConsistencyModelPipeline(metaclass=DummyObject): _backends = ["diffusers"] @@ -81,7 +136,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["diffusers"]) -class ORTDiffusionPipeline(metaclass=DummyObject): +class ORTLatentConsistencyModelImg2ImgPipeline(metaclass=DummyObject): _backends = ["diffusers"] def __init__(self, *args, **kwargs): @@ -92,7 +147,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["diffusers"]) -class ORTPipelineForText2Image(metaclass=DummyObject): +class ORTStableDiffusion3Pipeline(metaclass=DummyObject): _backends = ["diffusers"] def __init__(self, *args, **kwargs): @@ -103,7 +158,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["diffusers"]) -class ORTPipelineForImage2Image(metaclass=DummyObject): +class ORTStableDiffusion3Img2ImgPipeline(metaclass=DummyObject): _backends = ["diffusers"] def __init__(self, *args, **kwargs): @@ -114,7 +169,18 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["diffusers"]) -class ORTPipelineForInpainting(metaclass=DummyObject): +class ORTStableDiffusion3InpaintPipeline(metaclass=DummyObject): + _backends = ["diffusers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["diffusers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["diffusers"]) + + +class ORTFluxPipeline(metaclass=DummyObject): _backends = ["diffusers"] def __init__(self, *args, **kwargs): From 11467aebe4e5ecc8a89a1aea66679e7261759002 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Mon, 28 Oct 2024 23:33:49 +0100 Subject: [PATCH 7/7] style --- optimum/exporters/onnx/model_configs.py | 2 +- optimum/utils/import_utils.py | 1 - optimum/utils/input_generators.py | 1 - 3 files changed, 1 insertion(+), 3 deletions(-) diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index 7588e300ea3..0c39168f4ec 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -13,12 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. """Model specific ONNX configurations.""" + import random from pathlib import Path from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Union from packaging import version - from transformers.utils import is_tf_available from ...onnx import merge_decoders diff --git a/optimum/utils/import_utils.py b/optimum/utils/import_utils.py index 49688d60837..4a57fda79ce 100644 --- a/optimum/utils/import_utils.py +++ b/optimum/utils/import_utils.py @@ -22,7 +22,6 @@ import numpy as np from packaging import version - from transformers.utils import is_torch_available diff --git a/optimum/utils/input_generators.py b/optimum/utils/input_generators.py index 27cc2d075bc..4418c523ef0 100644 --- a/optimum/utils/input_generators.py +++ b/optimum/utils/input_generators.py @@ -20,7 +20,6 @@ from typing import Any, List, Optional, Tuple, Union import numpy as np - from transformers.utils import is_tf_available, is_torch_available from ..utils import check_if_diffusers_greater, check_if_transformers_greater
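
---
A minimal usage sketch of the Flux/SD3 support added in this series (illustrative only, not part of the patch). It assumes on-the-fly ONNX export via `export=True` behaves for diffusion pipelines as it does for other ORT models, and it uses the tiny Flux test checkpoint referenced in the test suite above; the prompt and image sizes are arbitrary.

    # Hypothetical example: load a Flux text-to-image pipeline through the new ORT entrypoint.
    from optimum.onnxruntime import ORTPipelineForText2Image

    # "optimum-internal-testing/tiny-random-flux" is the tiny test model used in the tests;
    # export=True (assumed to be supported here, as for other ORT models) first exports the
    # pipeline submodels (text encoders, transformer, VAE) to ONNX.
    pipeline = ORTPipelineForText2Image.from_pretrained(
        "optimum-internal-testing/tiny-random-flux", export=True
    )

    # Standard diffusers-style call; the config-driven mapping ("flux" in
    # ORT_TEXT2IMAGE_PIPELINES_MAPPING) dispatches this to ORTFluxPipeline.
    image = pipeline(
        prompt="a sailboat in a storm", height=64, width=64, num_inference_steps=2
    ).images[0]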