flux transformer support, unet export fixes, updated callback test, updated negative prompt test, flux and sd3 tests
IlyasMoutawwakil committed Oct 23, 2024
1 parent 08190c7 commit 4518691
Showing 8 changed files with 331 additions and 194 deletions.
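Before the diffs, a rough end-to-end sketch of how the new support might be exercised; the model id, the "text-to-image" task string, and the assumption that main_export routes diffusers pipelines through the configs added below are illustrative and not part of this commit.

from optimum.exporters.onnx import main_export

# Hypothetical invocation (model id and task are placeholders): export an SD3 pipeline so that
# its transformer submodel is traced through the new SD3TransformerOnnxConfig.
main_export(
    model_name_or_path="stabilityai/stable-diffusion-3-medium-diffusers",
    output="sd3_onnx",
    task="text-to-image",
)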
178 changes: 138 additions & 40 deletions optimum/exporters/onnx/model_configs.py
@@ -53,6 +53,7 @@
NormalizedTextConfig,
NormalizedTextConfigWithGQA,
NormalizedVisionConfig,
check_if_diffusers_greater,
check_if_transformers_greater,
is_diffusers_available,
logging,
@@ -1031,7 +1032,7 @@ def patch_model_for_export(


class UNetOnnxConfig(VisionOnnxConfig):
ATOL_FOR_VALIDATION = 1e-3
ATOL_FOR_VALIDATION = 1e-4
# The ONNX export of a CLIPText architecture, an other Stable Diffusion component, needs the Trilu
# operator support, available since opset 14
DEFAULT_ONNX_OPSET = 14
@@ -1054,17 +1055,19 @@ class UNetOnnxConfig(VisionOnnxConfig):
def inputs(self) -> Dict[str, Dict[int, str]]:
common_inputs = {
"sample": {0: "batch_size", 2: "height", 3: "width"},
"timestep": {0: "steps"},
"timestep": {}, # a scalar with no dimension
"encoder_hidden_states": {0: "batch_size", 1: "sequence_length"},
}

# TODO : add text_image, image and image_embeds
# TODO : add addition_embed_type == text_image, image and image_embeds
# https://github.com/huggingface/diffusers/blob/9366c8f84bfe47099ff047272661786ebb54721d/src/diffusers/models/unets/unet_2d_condition.py#L671
if getattr(self._normalized_config, "addition_embed_type", None) == "text_time":
common_inputs["text_embeds"] = {0: "batch_size"}
common_inputs["time_ids"] = {0: "batch_size"}

if getattr(self._normalized_config, "time_cond_proj_dim", None) is not None:
common_inputs["timestep_cond"] = {0: "batch_size"}

return common_inputs
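For illustration only (not part of the diff): with these changes, an SDXL-style UNet with addition_embed_type == "text_time" and time_cond_proj_dim set (e.g. an LCM-distilled UNet) would declare the following dynamic axes.

expected_unet_inputs = {
    "sample": {0: "batch_size", 2: "height", 3: "width"},
    "timestep": {},  # now a 0-dim scalar input rather than a 1-D "steps" axis
    "encoder_hidden_states": {0: "batch_size", 1: "sequence_length"},
    "text_embeds": {0: "batch_size"},     # addition_embed_type == "text_time"
    "time_ids": {0: "batch_size"},
    "timestep_cond": {0: "batch_size"},   # only when time_cond_proj_dim is set
}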

@property
@@ -1151,73 +1154,168 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
}


class PooledProjectionsDummyInputGenerator(DummyInputGenerator):
SUPPORTED_INPUT_NAMES = "pooled_projections"

def __init__(
self,
task: str,
normalized_config: NormalizedConfig,
batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
**kwargs,
):
self.task = task
self.batch_size = batch_size
self.pooled_projection_dim = normalized_config.config.pooled_projection_dim
class T5EncoderOnnxConfig(CLIPTextOnnxConfig):
@property
def inputs(self):
return {
"input_ids": {0: "batch_size", 1: "sequence_length"},
}

def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
return self.random_float_tensor(
[self.batch_size, self.pooled_projection_dim], framework=framework, dtype=float_dtype
)
@property
def outputs(self):
return {
"last_hidden_state": {0: "batch_size", 1: "sequence_length"},
}
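As a hedged usage sketch of the encoder-only T5 export this config describes (the file path is a placeholder and onnxruntime is only assumed to be the consumer):

import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("text_encoder_3/model.onnx")  # placeholder path to the exported T5 encoder
input_ids = np.zeros((1, 77), dtype=np.int64)  # [batch_size, sequence_length]
(last_hidden_state,) = session.run(None, {"input_ids": input_ids})  # [batch_size, sequence_length, hidden_size]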


class DummyTransformerTimestpsInputGenerator(DummyTimestepInputGenerator):
SUPPORTED_INPUT_NAMES = ("timestep",)

def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
if input_name == "timestep":
shape = [self.batch_size]
shape = [self.batch_size] # With transformer diffusers, timestep is a 1D tensor
return self.random_float_tensor(shape, max_value=self.vocab_size, framework=framework, dtype=float_dtype)

return super().generate(input_name, framework, int_dtype, float_dtype)
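To make the contrast concrete (values below are illustrative, not generated output): the UNet export now declares the timestep as a scalar, while this transformer/DiT generator produces a 1-D float tensor of length batch_size.

import torch

unet_timestep = torch.tensor(981.0)     # 0-dim scalar, matching "timestep": {} in UNetOnnxConfig.inputs
dit_timestep = torch.full((2,), 981.0)  # shape [batch_size], the shape produced by DummyTransformerTimestpsInputGenerator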


class DummyTransformerVisionInputGenerator(DummyVisionInputGenerator):
SUPPORTED_INPUT_NAMES = ("hidden_states",)


class DummyTransformerTextInputGenerator(DummySeq2SeqDecoderTextInputGenerator):
SUPPORTED_INPUT_NAMES = (
"encoder_hidden_states",
"pooled_projection",
)

def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
if input_name == "encoder_hidden_states":
return super().generate(input_name, framework, int_dtype, float_dtype)[0]

elif input_name == "pooled_projections":
return self.random_float_tensor(
[self.batch_size, self.normalized_config.projection_size], framework=framework, dtype=float_dtype
)

return super().generate(input_name, framework, int_dtype, float_dtype)


class SD3TransformerOnnxConfig(UNetOnnxConfig):
class SD3TransformerOnnxConfig(VisionOnnxConfig):
ATOL_FOR_VALIDATION = 1e-4
# The ONNX export of a CLIPText architecture, an other Stable Diffusion component, needs the Trilu
# operator support, available since opset 14
DEFAULT_ONNX_OPSET = 14

DUMMY_INPUT_GENERATOR_CLASSES = (
(DummyTransformerTimestpsInputGenerator,)
+ UNetOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES
+ (PooledProjectionsDummyInputGenerator,)
DummyTransformerTimestpsInputGenerator,
DummyTransformerVisionInputGenerator,
DummyTransformerTextInputGenerator,
)

NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args(
image_size="sample_size",
num_channels="in_channels",
hidden_size="joint_attention_dim",
vocab_size="attention_head_dim",
hidden_size="joint_attention_dim",
projection_size="pooled_projection_dim",
allow_new=True,
)

@property
def inputs(self):
common_inputs = super().inputs
common_inputs["pooled_projections"] = {0: "batch_size"}
return common_inputs
def inputs(self) -> Dict[str, Dict[int, str]]:
common_inputs = {
"hidden_states": {0: "batch_size", 2: "height", 3: "width"},
"encoder_hidden_states": {0: "batch_size", 1: "sequence_length"},
"pooled_projections": {0: "batch_size"},
"timestep": {0: "step"},
}

def rename_ambiguous_inputs(self, inputs):
# The input name in the model signature is `x, hence the export input name is updated.
hidden_states = inputs.pop("sample", None)
if hidden_states is not None:
inputs["hidden_states"] = hidden_states
return inputs
return common_inputs

@property
def outputs(self) -> Dict[str, Dict[int, str]]:
return {
"out_hidden_states": {0: "batch_size", 2: "height", 3: "width"},
}

class T5EncoderOnnxConfig(CLIPTextOnnxConfig):
@property
def inputs(self):
def torch_to_onnx_output_map(self) -> Dict[str, str]:
return {
"input_ids": {0: "batch_size", 1: "sequence_length"},
"sample": "out_hidden_states",
}


class DummyFluxTransformerVisionInputGenerator(DummyTransformerVisionInputGenerator):
SUPPORTED_INPUT_NAMES = (
"hidden_states",
"img_ids",
)

def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
if input_name == "hidden_states":
shape = [self.batch_size, (self.height // 2) * (self.width // 2), self.num_channels]
return self.random_float_tensor(shape, framework=framework, dtype=float_dtype)
elif input_name == "img_ids":
shape = (
[(self.height // 2) * (self.width // 2), 3]
if check_if_diffusers_greater("0.31.0")
else [self.batch_size, (self.height // 2) * (self.width // 2), 3]
)
return self.random_int_tensor(shape, max_value=1, framework=framework, dtype=int_dtype)

return super().generate(input_name, framework, int_dtype, float_dtype)


class DummyFluxTransformerTextInputGenerator(DummyTransformerTextInputGenerator):
SUPPORTED_INPUT_NAMES = (
"encoder_hidden_states",
"pooled_projections",
"txt_ids",
)

def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
if input_name == "txt_ids":
shape = (
[self.sequence_length, 3]
if check_if_diffusers_greater("0.31.0")
else [self.batch_size, self.sequence_length, 3]
)
return self.random_int_tensor(shape, max_value=1, framework=framework, dtype=int_dtype)

return super().generate(input_name, framework, int_dtype, float_dtype)
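The resulting position-id shapes depend on the installed diffusers version; a sketch of both cases with made-up sizes (the import path of the version check is assumed):

from optimum.utils import check_if_diffusers_greater  # import path assumed

batch_size, seq_len, height, width = 2, 16, 64, 64
if check_if_diffusers_greater("0.31.0"):
    img_ids_shape = [(height // 2) * (width // 2), 3]  # [packed_height_width, 3]
    txt_ids_shape = [seq_len, 3]                       # [sequence_length, 3]
else:
    img_ids_shape = [batch_size, (height // 2) * (width // 2), 3]
    txt_ids_shape = [batch_size, seq_len, 3]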


class FluxTransformerOnnxConfig(SD3TransformerOnnxConfig):
DUMMY_INPUT_GENERATOR_CLASSES = (
DummyTransformerTimestpsInputGenerator,
DummyFluxTransformerVisionInputGenerator,
DummyFluxTransformerTextInputGenerator,
)

@property
def inputs(self):
common_inputs = super().inputs

common_inputs["hidden_states"] = {0: "batch_size", 1: "packed_height_width"}
common_inputs["txt_ids"] = (
{0: "sequence_length"} if check_if_diffusers_greater("0.31.0") else {0: "batch_size", 1: "sequence_length"}
)
common_inputs["img_ids"] = (
{0: "packed_height_width"}
if check_if_diffusers_greater("0.31.0")
else {0: "batch_size", 1: "packed_height_width"}
)

if getattr(self._normalized_config, "guidance_embeds", False):
common_inputs["guidance"] = {0: "batch_size"}

return common_inputs

@property
def outputs(self):
return {
"last_hidden_state": {0: "batch_size", 1: "sequence_length"},
"out_hidden_states": {0: "batch_size", 1: "packed_height_width"},
}
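Putting the inherited SD3 inputs and the overrides above together, the dynamic axes this config is expected to declare for a guidance-distilled Flux transformer (guidance_embeds=True) on diffusers >= 0.31.0 look roughly like this (illustrative, not taken from the commit):

expected_flux_inputs = {
    "hidden_states": {0: "batch_size", 1: "packed_height_width"},
    "encoder_hidden_states": {0: "batch_size", 1: "sequence_length"},
    "pooled_projections": {0: "batch_size"},
    "timestep": {0: "step"},        # inherited from SD3TransformerOnnxConfig.inputs
    "txt_ids": {0: "sequence_length"},
    "img_ids": {0: "packed_height_width"},
    "guidance": {0: "batch_size"},  # only when guidance_embeds is True
}
expected_flux_outputs = {"out_hidden_states": {0: "batch_size", 1: "packed_height_width"}}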


25 changes: 15 additions & 10 deletions optimum/exporters/tasks.py
@@ -335,26 +335,30 @@ class TasksManager:
}

_DIFFUSERS_SUPPORTED_MODEL_TYPE = {
"t5-encoder": supported_tasks_mapping(
"t5-encoder-model": supported_tasks_mapping(
"feature-extraction",
onnx="T5EncoderOnnxConfig",
),
"clip-text-model": supported_tasks_mapping(
"feature-extraction",
onnx="CLIPTextOnnxConfig",
),
"clip-text-with-projection": supported_tasks_mapping(
"clip-text-model-with-projection": supported_tasks_mapping(
"feature-extraction",
onnx="CLIPTextWithProjectionOnnxConfig",
),
"unet": supported_tasks_mapping(
"flux-transformer-2d-model": supported_tasks_mapping(
"semantic-segmentation",
onnx="UNetOnnxConfig",
onnx="FluxTransformerOnnxConfig",
),
"sd3-transformer": supported_tasks_mapping(
"sd3-transformer-2d-model": supported_tasks_mapping(
"semantic-segmentation",
onnx="SD3TransformerOnnxConfig",
),
"unet-2d-condition": supported_tasks_mapping(
"semantic-segmentation",
onnx="UNetOnnxConfig",
),
"vae-encoder": supported_tasks_mapping(
"semantic-segmentation",
onnx="VaeEncoderOnnxConfig",
@@ -1178,12 +1182,13 @@ class TasksManager:
"transformers": _SUPPORTED_MODEL_TYPE,
}
_UNSUPPORTED_CLI_MODEL_TYPE = {
# diffusers submodels
"clip-text-model",
"clip-text-with-projection",
"sd3-transformer",
"t5-encoder",
"trocr", # supported through the vision-encoder-decoder model type
"unet",
"clip-text-model-with-projection",
"flux-transformer-model",
"sd3-transformer-2d-model",
"t5-encoder-model",
"unet-2d-condition",
"vae-encoder",
"vae-decoder",
}
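A hedged sketch of how the renamed diffusers model types would be resolved to their ONNX configs through TasksManager; the helper name and keyword arguments are assumptions, not taken from this diff.

from optimum.exporters.tasks import TasksManager

# Look up the config constructor registered above for the Flux transformer submodel
# (signature assumed; not verified against this commit).
config_constructor = TasksManager.get_exporter_config_constructor(
    exporter="onnx",
    model_type="flux-transformer-2d-model",
    task="semantic-segmentation",
    library_name="diffusers",
)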