SD3 and Flux support (#2073)
* sd3 support

* unsupported cli model types

* flux transformer support, unet export fixes, updated callback test, updated negative prompt test, flux and sd3 tests

* fixes

* move input generators

* dummy diffusers

* style

* distribute ort tests

* fix

* fix

* fix

* test num images

* single process to reduce re-exports

* test

* revert unnecessary changes

* T5Encoder inherits from TextEncoder

* style

* fix typo in timestep

* style

* only test sd3 and flux on latest transformers

* conditional sd3 and flux modeling

* forgot sd3 inpaint
IlyasMoutawwakil authored Nov 19, 2024
1 parent 400bb82 commit a7a807c
Showing 18 changed files with 791 additions and 217 deletions.
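For orientation, the end-to-end flow this PR enables looks roughly like the following (a minimal sketch: `ORTDiffusionPipeline` is optimum's existing ONNX Runtime auto-pipeline, `export=True` drives the export code paths changed below, and the checkpoint name is only an example):

```python
from optimum.onnxruntime import ORTDiffusionPipeline

# Export the SD3 components (transformer, three text encoders, VAE) to ONNX
# on the fly, then run the full pipeline with ONNX Runtime.
pipeline = ORTDiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",  # example checkpoint
    export=True,
)
image = pipeline("a robot painting a landscape", num_inference_steps=28).images[0]
image.save("out.png")
```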
13 changes: 4 additions & 9 deletions .github/workflows/test_onnxruntime.yml
@@ -26,14 +26,11 @@ jobs:
             os: ubuntu-20.04
 
     runs-on: ${{ matrix.os }}
 
     steps:
       - name: Free Disk Space (Ubuntu)
         if: matrix.os == 'ubuntu-20.04'
         uses: jlumbroso/free-disk-space@main
         with:
           tool-cache: false
           swap-storage: false
           large-packages: false
 
       - name: Checkout code
         uses: actions/checkout@v4
@@ -54,13 +51,11 @@ jobs:
         run: pip install transformers==${{ matrix.transformers-version }}
 
       - name: Test with pytest (in series)
-        working-directory: tests
         run: |
-          pytest onnxruntime -m "run_in_series" --durations=0 -vvvv -s
+          pytest tests/onnxruntime -m "run_in_series" --durations=0 -vvvv -s
 
       - name: Test with pytest (in parallel)
+        run: |
+          pytest tests/onnxruntime -m "not run_in_series" --durations=0 -vvvv -s -n auto
         env:
           HF_HUB_READ_TOKEN: ${{ secrets.HF_HUB_READ_TOKEN }}
-        working-directory: tests
-        run: |
-          pytest onnxruntime -m "not run_in_series" --durations=0 -vvvv -s -n auto
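`run_in_series` is a custom pytest marker used by the suite; tests that must not run under the pytest-xdist parallel step (`-n auto`) opt in roughly like this (a sketch; marker registration in the suite's pytest config is assumed):

```python
import pytest

@pytest.mark.run_in_series  # selected by the -m "run_in_series" step above
def test_heavy_pipeline_export():
    ...
```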
1 change: 1 addition & 0 deletions optimum/exporters/onnx/base.py
@@ -319,6 +319,7 @@ def fix_dynamic_axes(
         input_shapes = {}
         dummy_inputs = self.generate_dummy_inputs(framework="np", **input_shapes)
         dummy_inputs = self.generate_dummy_inputs_for_validation(dummy_inputs, onnx_input_names=onnx_input_names)
+        dummy_inputs = self.rename_ambiguous_inputs(dummy_inputs)
 
         onnx_inputs = {}
         for name, value in dummy_inputs.items():
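The added line makes `fix_dynamic_axes` feed its dummy inputs through the same renaming hook the export itself uses, so configs whose ONNX input names differ from the model's `forward()` argument names validate correctly. For illustration, the existing pattern this hook supports looks like this (the mapping is shown as an example of the pattern, not code from this PR):

```python
from optimum.exporters.onnx.model_configs import VisionOnnxConfig

class ExampleVaeDecoderOnnxConfig(VisionOnnxConfig):
    def rename_ambiguous_inputs(self, inputs):
        # The VAE decoder's forward() names its input "latent_sample",
        # while the export config exposes it as "sample".
        inputs["latent_sample"] = inputs.pop("sample")
        return inputs
```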
4 changes: 4 additions & 0 deletions optimum/exporters/onnx/convert.py
@@ -1183,6 +1183,10 @@ def onnx_export_from_model(
         if tokenizer_2 is not None:
             tokenizer_2.save_pretrained(output.joinpath("tokenizer_2"))
 
+        tokenizer_3 = getattr(model, "tokenizer_3", None)
+        if tokenizer_3 is not None:
+            tokenizer_3.save_pretrained(output.joinpath("tokenizer_3"))
+
         model.save_config(output)
 
         if float_dtype == "bf16":
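SD3 pipelines carry a third tokenizer (for the T5 text encoder), so the export now saves it next to the other two. The exported directory then follows the usual diffusers layout (the path below is illustrative):

```python
from transformers import AutoTokenizer

# tokenizer/   -> first CLIP tokenizer
# tokenizer_2/ -> second CLIP tokenizer
# tokenizer_3/ -> T5 tokenizer, saved by the lines added above
t5_tokenizer = AutoTokenizer.from_pretrained("sd3_onnx/tokenizer_3")
```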
123 changes: 109 additions & 14 deletions optimum/exporters/onnx/model_configs.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Model specific ONNX configurations."""
+
 import random
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Union
@@ -28,6 +29,8 @@
     DummyCodegenDecoderTextInputGenerator,
     DummyDecoderTextInputGenerator,
     DummyEncodecInputGenerator,
+    DummyFluxTransformerTextInputGenerator,
+    DummyFluxTransformerVisionInputGenerator,
     DummyInputGenerator,
     DummyIntGenerator,
     DummyPastKeyValuesGenerator,
@@ -38,6 +41,9 @@
     DummySpeechT5InputGenerator,
     DummyTextInputGenerator,
     DummyTimestepInputGenerator,
+    DummyTransformerTextInputGenerator,
+    DummyTransformerTimestepInputGenerator,
+    DummyTransformerVisionInputGenerator,
     DummyVisionEmbeddingsGenerator,
     DummyVisionEncoderDecoderPastKeyValuesGenerator,
     DummyVisionInputGenerator,
@@ -53,6 +59,7 @@
     NormalizedTextConfig,
     NormalizedTextConfigWithGQA,
     NormalizedVisionConfig,
+    check_if_diffusers_greater,
     check_if_transformers_greater,
     is_diffusers_available,
     logging,
@@ -1039,22 +1046,13 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
             "last_hidden_state": {0: "batch_size", 1: "sequence_length"},
             "pooler_output": {0: "batch_size"},
         }
 
         if self._normalized_config.output_hidden_states:
             for i in range(self._normalized_config.num_layers + 1):
                 common_outputs[f"hidden_states.{i}"] = {0: "batch_size", 1: "sequence_length"}
 
         return common_outputs
 
-    def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
-        dummy_inputs = super().generate_dummy_inputs(framework=framework, **kwargs)
-
-        # TODO: fix should be by casting inputs during inference and not export
-        if framework == "pt":
-            import torch
-
-            dummy_inputs["input_ids"] = dummy_inputs["input_ids"].to(dtype=torch.int32)
-        return dummy_inputs
 
     def patch_model_for_export(
         self,
         model: Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"],
@@ -1064,7 +1062,7 @@ def patch_model_for_export(
 
 
 class UNetOnnxConfig(VisionOnnxConfig):
-    ATOL_FOR_VALIDATION = 1e-3
+    ATOL_FOR_VALIDATION = 1e-4
     # The ONNX export of a CLIPText architecture, another Stable Diffusion component, needs the Trilu
     # operator support, available since opset 14
     DEFAULT_ONNX_OPSET = 14
@@ -1087,17 +1085,19 @@ class UNetOnnxConfig(VisionOnnxConfig):
     def inputs(self) -> Dict[str, Dict[int, str]]:
         common_inputs = {
             "sample": {0: "batch_size", 2: "height", 3: "width"},
-            "timestep": {0: "steps"},
+            "timestep": {},  # a scalar with no dimension
             "encoder_hidden_states": {0: "batch_size", 1: "sequence_length"},
         }
 
-        # TODO : add text_image, image and image_embeds
+        # TODO : add addition_embed_type == text_image, image and image_embeds
         # https://github.com/huggingface/diffusers/blob/9366c8f84bfe47099ff047272661786ebb54721d/src/diffusers/models/unets/unet_2d_condition.py#L671
         if getattr(self._normalized_config, "addition_embed_type", None) == "text_time":
             common_inputs["text_embeds"] = {0: "batch_size"}
             common_inputs["time_ids"] = {0: "batch_size"}
 
         if getattr(self._normalized_config, "time_cond_proj_dim", None) is not None:
             common_inputs["timestep_cond"] = {0: "batch_size"}
 
         return common_inputs
 
     @property
@@ -1136,7 +1136,7 @@ def ordered_inputs(self, model) -> Dict[str, Dict[int, str]]:
 
 
 class VaeEncoderOnnxConfig(VisionOnnxConfig):
-    ATOL_FOR_VALIDATION = 1e-4
+    ATOL_FOR_VALIDATION = 3e-4
     # The ONNX export of a CLIPText architecture, another Stable Diffusion component, needs the Trilu
     # operator support, available since opset 14
     DEFAULT_ONNX_OPSET = 14
@@ -1184,6 +1184,101 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
         }
 
 
+class T5EncoderOnnxConfig(TextEncoderOnnxConfig):
+    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
+    ATOL_FOR_VALIDATION = 1e-4
+    DEFAULT_ONNX_OPSET = 12  # int64 was supported since opset 12
+
+    @property
+    def inputs(self):
+        return {
+            "input_ids": {0: "batch_size", 1: "sequence_length"},
+        }
+
+    @property
+    def outputs(self):
+        return {
+            "last_hidden_state": {0: "batch_size", 1: "sequence_length"},
+        }


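The config above describes a plain T5 encoder pass: token ids in, per-token hidden states out, which is how SD3 and Flux consume their T5 text encoder. For reference, the tensors involved (a small stand-in checkpoint; SD3 actually uses T5-XXL):

```python
from transformers import AutoTokenizer, T5EncoderModel

tok = AutoTokenizer.from_pretrained("google/t5-v1_1-small")
enc = T5EncoderModel.from_pretrained("google/t5-v1_1-small")

input_ids = tok("a photo of a cat", return_tensors="pt").input_ids  # (batch_size, sequence_length)
last_hidden_state = enc(input_ids=input_ids).last_hidden_state      # (batch_size, sequence_length, d_model)
```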
+class SD3TransformerOnnxConfig(VisionOnnxConfig):
+    ATOL_FOR_VALIDATION = 1e-4
+    # The ONNX export of a CLIPText architecture, another Stable Diffusion component, needs the Trilu
+    # operator support, available since opset 14
+    DEFAULT_ONNX_OPSET = 14
+
+    DUMMY_INPUT_GENERATOR_CLASSES = (
+        DummyTransformerTimestepInputGenerator,
+        DummyTransformerVisionInputGenerator,
+        DummyTransformerTextInputGenerator,
+    )
+
+    NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args(
+        image_size="sample_size",
+        num_channels="in_channels",
+        vocab_size="attention_head_dim",
+        hidden_size="joint_attention_dim",
+        projection_size="pooled_projection_dim",
+        allow_new=True,
+    )
+
+    @property
+    def inputs(self) -> Dict[str, Dict[int, str]]:
+        common_inputs = {
+            "hidden_states": {0: "batch_size", 2: "height", 3: "width"},
+            "encoder_hidden_states": {0: "batch_size", 1: "sequence_length"},
+            "pooled_projections": {0: "batch_size"},
+            "timestep": {0: "step"},
+        }
+
+        return common_inputs
+
+    @property
+    def outputs(self) -> Dict[str, Dict[int, str]]:
+        return {
+            "out_hidden_states": {0: "batch_size", 2: "height", 3: "width"},
+        }
+
+    @property
+    def torch_to_onnx_output_map(self) -> Dict[str, str]:
+        return {
+            "sample": "out_hidden_states",
+        }


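Two details worth noting in the config above: `timestep` keeps a dynamic leading axis (one value per batch element, unlike the scalar UNet timestep), and `torch_to_onnx_output_map` renames the diffusers output field `sample` to `out_hidden_states`, so the ONNX output name cannot be confused with the `hidden_states` input. A shape sketch with SD3-medium-like values (numbers illustrative):

```python
batch_size, in_channels, height, width = 2, 16, 64, 64  # latent-space, not pixel, resolution
sequence_length, joint_attention_dim = 333, 4096        # concatenated CLIP + T5 prompt tokens
pooled_projection_dim = 2048

hidden_states = (batch_size, in_channels, height, width)
encoder_hidden_states = (batch_size, sequence_length, joint_attention_dim)
pooled_projections = (batch_size, pooled_projection_dim)
timestep = (batch_size,)
out_hidden_states = (batch_size, in_channels, height, width)  # mirrors the input latents
```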
+class FluxTransformerOnnxConfig(SD3TransformerOnnxConfig):
+    DUMMY_INPUT_GENERATOR_CLASSES = (
+        DummyTransformerTimestepInputGenerator,
+        DummyFluxTransformerVisionInputGenerator,
+        DummyFluxTransformerTextInputGenerator,
+    )
+
+    @property
+    def inputs(self):
+        common_inputs = super().inputs
+        common_inputs["hidden_states"] = {0: "batch_size", 1: "packed_height_width"}
+        common_inputs["txt_ids"] = (
+            {0: "sequence_length"} if check_if_diffusers_greater("0.31.0") else {0: "batch_size", 1: "sequence_length"}
+        )
+        common_inputs["img_ids"] = (
+            {0: "packed_height_width"}
+            if check_if_diffusers_greater("0.31.0")
+            else {0: "batch_size", 1: "packed_height_width"}
+        )
+
+        if getattr(self._normalized_config, "guidance_embeds", False):
+            common_inputs["guidance"] = {0: "batch_size"}
+
+        return common_inputs
+
+    @property
+    def outputs(self):
+        return {
+            "out_hidden_states": {0: "batch_size", 1: "packed_height_width"},
+        }

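`packed_height_width` reflects Flux's latent packing: the transformer sees a token sequence built by flattening 2x2 latent patches rather than a 4D latent, and `img_ids` carries one positional-id row per packed token (unbatched from diffusers 0.31 on, hence the version switch above). The arithmetic, with illustrative values:

```python
height, width = 1024, 1024  # pixel resolution
vae_scale_factor = 8        # VAE downsampling
latent_h, latent_w = height // vae_scale_factor, width // vae_scale_factor  # 128 x 128
packed_height_width = (latent_h // 2) * (latent_w // 2)  # 64 * 64 = 4096 tokens
```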

 class GroupViTOnnxConfig(CLIPOnnxConfig):
     pass
29 changes: 23 additions & 6 deletions optimum/exporters/tasks.py
@@ -335,15 +335,27 @@ class TasksManager:
     }
 
     _DIFFUSERS_SUPPORTED_MODEL_TYPE = {
-        "clip-text-model": supported_tasks_mapping(
+        "t5-encoder": supported_tasks_mapping(
+            "feature-extraction",
+            onnx="T5EncoderOnnxConfig",
+        ),
+        "clip-text": supported_tasks_mapping(
             "feature-extraction",
             onnx="CLIPTextOnnxConfig",
         ),
         "clip-text-with-projection": supported_tasks_mapping(
             "feature-extraction",
             onnx="CLIPTextWithProjectionOnnxConfig",
         ),
-        "unet": supported_tasks_mapping(
+        "flux-transformer-2d": supported_tasks_mapping(
+            "semantic-segmentation",
+            onnx="FluxTransformerOnnxConfig",
+        ),
+        "sd3-transformer-2d": supported_tasks_mapping(
+            "semantic-segmentation",
+            onnx="SD3TransformerOnnxConfig",
+        ),
+        "unet-2d-condition": supported_tasks_mapping(
             "semantic-segmentation",
             onnx="UNetOnnxConfig",
         ),
@@ -1177,12 +1189,17 @@ class TasksManager:
"transformers": _SUPPORTED_MODEL_TYPE,
}
_UNSUPPORTED_CLI_MODEL_TYPE = {
"unet",
# diffusers model types
"clip-text",
"clip-text-with-projection",
"flux-transformer-2d",
"sd3-transformer-2d",
"t5-encoder",
"unet-2d-condition",
"vae-encoder",
"vae-decoder",
"clip-text-model",
"clip-text-with-projection",
"trocr", # supported through the vision-encoder-decoder model type
# redundant model types
"trocr", # same as vision-encoder-decoder
}
     _SUPPORTED_CLI_MODEL_TYPE = (
         set(_SUPPORTED_MODEL_TYPE.keys())
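For context, diffusers components are registered under the closest existing tasks (`feature-extraction` for the text encoders, `semantic-segmentation` for the UNet/transformer components), and export configs resolve through TasksManager's usual lookup. A sketch of that resolution (checkpoint name is an example):

```python
from diffusers import FluxTransformer2DModel
from optimum.exporters.tasks import TasksManager

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer"  # example checkpoint
)
config_constructor = TasksManager.get_exporter_config_constructor(
    exporter="onnx",
    model=transformer,
    task="semantic-segmentation",
    model_type="flux-transformer-2d",
    library_name="diffusers",
)
onnx_config = config_constructor(transformer.config)  # -> FluxTransformerOnnxConfig
```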