Commit b163814

add support task for unet-controlnet

TianmengChen committed Jul 10, 2024
1 parent 4237e1d commit b163814
Showing 2 changed files with 215 additions and 5 deletions.
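For orientation, here is a minimal sketch of how the new task string could be exercised through optimum's ONNX export entry point. The checkpoint ID is a placeholder, and whether `main_export` routes this task end to end depends on the rest of the integration, not just this commit:

# Hypothetical usage sketch for the task added in this commit.
# The model ID is a placeholder; adjust to any SD 1.x checkpoint.
from optimum.exporters.onnx import main_export

main_export(
    model_name_or_path="runwayml/stable-diffusion-v1-5",
    output="sd_controlnet_onnx",
    task="stable-diffusion-controlnet",
)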
5 changes: 5 additions & 0 deletions optimum/exporters/tasks.py
@@ -192,6 +192,7 @@ class TasksManager:
_DIFFUSERS_TASKS_TO_MODEL_LOADERS = {
"stable-diffusion": "StableDiffusionPipeline",
"stable-diffusion-xl": "StableDiffusionXLImg2ImgPipeline",
"stable-diffusion-controlnet": "StableDiffusionPipeline",
}

_TIMM_TASKS_TO_MODEL_LOADERS = {
@@ -304,6 +305,10 @@ class TasksManager:
"feature-extraction",
onnx="CLIPTextWithProjectionOnnxConfig",
),
"unet-controlnet": supported_tasks_mapping(
"semantic-segmentation",
onnx="UNetOnnxConfig",
),
"unet": supported_tasks_mapping(
"semantic-segmentation",
onnx="UNetOnnxConfig",
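The `unet-controlnet` entry registered above lets `TasksManager` resolve an export config constructor for the wrapped U-Net. A minimal sketch of that lookup, mirroring the call this commit adds in utils.py (`pipeline` is assumed to be an already-loaded diffusers pipeline):

# Sketch: resolve the ONNX export config registered for "unet-controlnet".
from optimum.exporters.tasks import TasksManager

config_constructor = TasksManager.get_exporter_config_constructor(
    model=pipeline.unet,
    exporter="onnx",
    library_name="diffusers",
    task="semantic-segmentation",
    model_type="unet-controlnet",
)
unet_onnx_config = config_constructor(pipeline.unet.config)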
215 changes: 210 additions & 5 deletions optimum/exporters/utils.py
@@ -63,7 +63,7 @@
from transformers.modeling_tf_utils import TFPreTrainedModel

if is_diffusers_available():
    from diffusers import ModelMixin, StableDiffusionPipeline, StableDiffusionControlNetPipeline


ENCODER_NAME = "encoder_model"
Expand Down Expand Up @@ -124,6 +124,121 @@ def _get_submodels_for_export_stable_diffusion(

return models_for_export

def _get_submodels_for_export_stable_diffusion_controlnet(
    pipeline: "StableDiffusionControlNetPipeline",
) -> Dict[str, Union["PreTrainedModel", "ModelMixin"]]:
    """
    Returns the components of a Stable Diffusion ControlNet pipeline.
    """
from diffusers import StableDiffusionXLImg2ImgPipeline

models_for_export = {}
if isinstance(pipeline, StableDiffusionXLImg2ImgPipeline):
projection_dim = pipeline.text_encoder_2.config.projection_dim
else:
projection_dim = pipeline.text_encoder.config.projection_dim

# Text encoder
if pipeline.text_encoder is not None:
if isinstance(pipeline, StableDiffusionXLImg2ImgPipeline):
pipeline.text_encoder.config.output_hidden_states = True
models_for_export["text_encoder"] = pipeline.text_encoder

# U-NET
# The ONNX export of torch.nn.functional.scaled_dot_product_attention is not supported for PyTorch < 2.1.0
class UnetWrapper(torch.nn.Module):
def __init__(
self,
unet,
sample_dtype=torch.float32,
timestep_dtype=torch.int64,
encoder_hidden_states=torch.float32,
down_block_additional_residuals=torch.float32,
mid_block_additional_residual=torch.float32,
):
super().__init__()
self.unet = unet
self.config = unet.config
self.sample_dtype = sample_dtype
self.timestep_dtype = timestep_dtype
self.encoder_hidden_states_dtype = encoder_hidden_states
self.down_block_additional_residuals_dtype = down_block_additional_residuals
self.mid_block_additional_residual_dtype = mid_block_additional_residual

def forward(
self,
sample: torch.Tensor,
timestep: torch.Tensor,
encoder_hidden_states: torch.Tensor,
down_block_additional_residual_1: torch.Tensor,
down_block_additional_residual_3: torch.Tensor,
down_block_additional_residual_5: torch.Tensor,
down_block_additional_residual_7: torch.Tensor,
down_block_additional_residual_9: torch.Tensor,
down_block_additional_residual_11: torch.Tensor,
down_block_additional_residual_13: torch.Tensor,
down_block_additional_residual_15: torch.Tensor,
down_block_additional_residual_17: torch.Tensor,
down_block_additional_residual_19: torch.Tensor,
down_block_additional_residual_21: torch.Tensor,
down_block_additional_residual: torch.Tensor,
mid_block_additional_residual: torch.Tensor,
):
sample = sample.to(self.sample_dtype)
timestep = timestep.to(self.timestep_dtype)
encoder_hidden_states = encoder_hidden_states.to(self.encoder_hidden_states_dtype)
down_block_additional_residuals = [
    res.to(self.down_block_additional_residuals_dtype)
    for res in (
        down_block_additional_residual_1,
        down_block_additional_residual_3,
        down_block_additional_residual_5,
        down_block_additional_residual_7,
        down_block_additional_residual_9,
        down_block_additional_residual_11,
        down_block_additional_residual_13,
        down_block_additional_residual_15,
        down_block_additional_residual_17,
        down_block_additional_residual_19,
        down_block_additional_residual_21,
        down_block_additional_residual,
    )
]
mid_block_additional_residual = mid_block_additional_residual.to(self.mid_block_additional_residual_dtype)
return self.unet(
sample,
timestep,
encoder_hidden_states,
down_block_additional_residuals=down_block_additional_residuals,
mid_block_additional_residual=mid_block_additional_residual,
)


is_torch_greater_or_equal_than_2_1 = version.parse(torch.__version__) >= version.parse("2.1.0")
if not is_torch_greater_or_equal_than_2_1:
pipeline.unet.set_attn_processor(AttnProcessor())
pipeline.unet.config.text_encoder_projection_dim = projection_dim
# The shape of the U-NET `time_ids` input depends on the value of `requires_aesthetics_score`
# https://github.com/huggingface/diffusers/blob/v0.18.2/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py#L571
pipeline.unet.config.requires_aesthetics_score = getattr(pipeline.config, "requires_aesthetics_score", False)

models_for_export["unet"] = UnetWrapper(pipeline.unet)

# VAE Encoder https://github.com/huggingface/diffusers/blob/v0.11.1/src/diffusers/models/vae.py#L565
vae_encoder = copy.deepcopy(pipeline.vae)
if not is_torch_greater_or_equal_than_2_1:
vae_encoder = override_diffusers_2_0_attn_processors(vae_encoder)
vae_encoder.forward = lambda sample: {"latent_sample": vae_encoder.encode(x=sample)["latent_dist"].sample()}
models_for_export["vae_encoder"] = vae_encoder

# VAE Decoder https://github.com/huggingface/diffusers/blob/v0.11.1/src/diffusers/models/vae.py#L600
vae_decoder = copy.deepcopy(pipeline.vae)
if not is_torch_greater_or_equal_than_2_1:
vae_decoder = override_diffusers_2_0_attn_processors(vae_decoder)
vae_decoder.forward = lambda latent_sample: vae_decoder.decode(z=latent_sample)
models_for_export["vae_decoder"] = vae_decoder

text_encoder_2 = getattr(pipeline, "text_encoder_2", None)
if text_encoder_2 is not None:
text_encoder_2.config.output_hidden_states = True
models_for_export["text_encoder_2"] = text_encoder_2

return models_for_export
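Because ONNX export wants a flat list of named tensor inputs, the wrapper above spreads the twelve ControlNet down-block residuals and the mid-block residual into separate arguments. Below is a hedged sanity-check sketch with dummy tensors; the shapes assume a stock SD 1.5 U-Net at 512x512 (64x64 latents), and `pipeline` is assumed to be an already-loaded StableDiffusionControlNetPipeline:

# Sketch: drive the wrapped U-Net with dummy inputs shaped for SD 1.5.
import torch

batch = 2  # classifier-free guidance doubles the batch
sample = torch.randn(batch, 4, 64, 64)
timestep = torch.tensor(999, dtype=torch.int64)
encoder_hidden_states = torch.randn(batch, 77, 768)

# Twelve down-block residuals (conv_in plus each down block's outputs),
# then the mid-block residual, matching UnetWrapper.forward's signature.
residual_shapes = [
    (320, 64, 64), (320, 64, 64), (320, 64, 64), (320, 32, 32),
    (640, 32, 32), (640, 32, 32), (640, 16, 16), (1280, 16, 16),
    (1280, 16, 16), (1280, 8, 8), (1280, 8, 8), (1280, 8, 8),
]
down_residuals = [torch.randn(batch, *shape) for shape in residual_shapes]
mid_residual = torch.randn(batch, 1280, 8, 8)

# UnetWrapper is local to the helper above, so grab the wrapped U-Net from
# the returned dict rather than instantiating it directly.
wrapped_unet = _get_submodels_for_export_stable_diffusion_controlnet(pipeline)["unet"]
noise_pred = wrapped_unet(sample, timestep, encoder_hidden_states, *down_residuals, mid_residual)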

def _get_submodels_for_export_decoder(
model: Union["PreTrainedModel", "TFPreTrainedModel"],
@@ -294,7 +409,7 @@ def get_stable_diffusion_models_for_export(
)
models_for_export["text_encoder"] = (models_for_export["text_encoder"], text_encoder_export_config)

# U-NET
export_config_constructor = TasksManager.get_exporter_config_constructor(
model=pipeline.unet,
exporter=exporter,
@@ -344,6 +459,91 @@

return models_for_export

def get_stable_diffusion_controlnet_models_for_export(
    pipeline: "StableDiffusionControlNetPipeline",
    int_dtype: str = "int64",
    float_dtype: str = "fp32",
    exporter: str = "onnx",
) -> Dict[str, Tuple[Union["PreTrainedModel", "ModelMixin"], "ExportConfig"]]:
    """
    Returns the components of a Stable Diffusion ControlNet pipeline and their export configs.

    Args:
        pipeline ([`StableDiffusionControlNetPipeline`]):
            The pipeline to export.
        int_dtype (`str`, defaults to `"int64"`):
            The data type of integer tensors; one of `"int64"`, `"int32"`, `"int8"`.
        float_dtype (`str`, defaults to `"fp32"`):
            The data type of float tensors; one of `"fp32"`, `"fp16"`, `"bf16"`.

    Returns:
        `Dict[str, Tuple[Union[PreTrainedModel, ModelMixin], ExportConfig]]`: A dict mapping each pipeline
        component to its model and export config.
    """
models_for_export = _get_submodels_for_export_stable_diffusion_controlnet(pipeline)
# Text encoder
if "text_encoder" in models_for_export:
text_encoder_config_constructor = TasksManager.get_exporter_config_constructor(
model=pipeline.text_encoder,
exporter=exporter,
library_name="diffusers",
task="feature-extraction",
)
text_encoder_export_config = text_encoder_config_constructor(
pipeline.text_encoder.config, int_dtype=int_dtype, float_dtype=float_dtype
)
models_for_export["text_encoder"] = (models_for_export["text_encoder"], text_encoder_export_config)

# U-NET
export_config_constructor = TasksManager.get_exporter_config_constructor(
model=pipeline.unet,
exporter=exporter,
library_name="diffusers",
task="semantic-segmentation",
model_type="unet-controlnet",
)
unet_export_config = export_config_constructor(pipeline.unet.config, int_dtype=int_dtype, float_dtype=float_dtype)

models_for_export["unet"] = (models_for_export["unet"], unet_export_config)

# VAE Encoder https://github.com/huggingface/diffusers/blob/v0.11.1/src/diffusers/models/vae.py#L565
vae_encoder = models_for_export["vae_encoder"]
vae_config_constructor = TasksManager.get_exporter_config_constructor(
model=vae_encoder,
exporter=exporter,
library_name="diffusers",
task="semantic-segmentation",
model_type="vae-encoder",
)
vae_export_config = vae_config_constructor(vae_encoder.config, int_dtype=int_dtype, float_dtype=float_dtype)
models_for_export["vae_encoder"] = (vae_encoder, vae_export_config)

# VAE Decoder https://github.com/huggingface/diffusers/blob/v0.11.1/src/diffusers/models/vae.py#L600
vae_decoder = models_for_export["vae_decoder"]
vae_config_constructor = TasksManager.get_exporter_config_constructor(
model=vae_decoder,
exporter=exporter,
library_name="diffusers",
task="semantic-segmentation",
model_type="vae-decoder",
)
vae_export_config = vae_config_constructor(vae_decoder.config, int_dtype=int_dtype, float_dtype=float_dtype)
models_for_export["vae_decoder"] = (vae_decoder, vae_export_config)

if "text_encoder_2" in models_for_export:
export_config_constructor = TasksManager.get_exporter_config_constructor(
model=pipeline.text_encoder_2,
exporter=exporter,
library_name="diffusers",
task="feature-extraction",
model_type="clip-text-with-projection",
)
export_config = export_config_constructor(
pipeline.text_encoder_2.config, int_dtype=int_dtype, float_dtype=float_dtype
)
models_for_export["text_encoder_2"] = (models_for_export["text_encoder_2"], export_config)

return models_for_export
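Putting it together, a hedged end-to-end sketch: load a ControlNet pipeline with diffusers, then collect the export-ready submodels and their configs with the function above. The checkpoint IDs are illustrative placeholders:

# Sketch: gather submodels and export configs for a ControlNet pipeline.
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline

controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")
pipeline = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet
)

models_and_configs = get_stable_diffusion_controlnet_models_for_export(pipeline)
for name, (submodel, export_config) in models_and_configs.items():
    print(name, type(submodel).__name__, type(export_config).__name__)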

def get_musicgen_models_for_export(model: Union["PreTrainedModel", "TFPreTrainedModel"], config: "ExportConfig"):
models_for_export = {
@@ -523,9 +723,14 @@ def _get_submodels_and_export_configs(
if not custom_architecture:
if library_name == "diffusers":
export_config = None
if task == "stable-diffusion-controlnet":
models_and_export_configs = get_stable_diffusion_controlnet_models_for_export(
model, int_dtype=int_dtype, float_dtype=float_dtype, exporter=exporter
)
else:
models_and_export_configs = get_stable_diffusion_models_for_export(
model, int_dtype=int_dtype, float_dtype=float_dtype, exporter=exporter
)
else:
export_config_constructor = TasksManager.get_exporter_config_constructor(
model=model, exporter=exporter, task=task, library_name=library_name
