From 1c1dcb5c9a18f589e3c07af0ef899cb619e46ac7 Mon Sep 17 00:00:00 2001
From: Haoliang Zhang
Date: Tue, 29 Oct 2024 10:26:50 -0700
Subject: [PATCH] In the SD decoder/diffusion configuration, add a device type
 field that determines whether HLFB is needed.

Also update the converter code to inspect the device type flag. This should
limit the number of places where `enable_hlfb` has to be flipped.

PiperOrigin-RevId: 691071718
---
 .../stable_diffusion/convert_to_tflite.py  | 14 +++-
 .../examples/stable_diffusion/decoder.py   | 28 ++++++--
 .../examples/stable_diffusion/diffusion.py | 70 ++++++++++++++-----
 .../generative/layers/unet/model_config.py |  3 +
 4 files changed, 87 insertions(+), 28 deletions(-)

diff --git a/ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py b/ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py
index 6b305647..6c6aed7e 100644
--- a/ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py
+++ b/ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py
@@ -61,6 +61,12 @@
     default=True,
 )
 
+_DEVICE_TYPE = flags.DEFINE_string(
+    'device_type',
+    help='The device type of the model. Currently supported: cpu, gpu.',
+    default='cpu',
+)
+
 
 @torch.inference_mode
 def convert_stable_diffusion_to_tflite(
@@ -80,13 +86,17 @@ def convert_stable_diffusion_to_tflite(
   )
   loader.load(clip_model, strict=False)
 
-  diffusion_model = diffusion.Diffusion(diffusion.get_model_config(2))
+  diffusion_model = diffusion.Diffusion(
+      diffusion.get_model_config(batch_size=2, device_type=_DEVICE_TYPE.value)
+  )
   diffusion_loader = stable_diffusion_loader.DiffusionModelLoader(
       diffusion_ckpt_path, diffusion.TENSOR_NAMES
   )
   diffusion_loader.load(diffusion_model, strict=False)
 
-  decoder_model = decoder.Decoder(decoder.get_model_config())
+  decoder_model = decoder.Decoder(
+      decoder.get_model_config(device_type=_DEVICE_TYPE.value)
+  )
   decoder_loader = stable_diffusion_loader.AutoEncoderModelLoader(
       decoder_ckpt_path, decoder.TENSOR_NAMES
   )
diff --git a/ai_edge_torch/generative/examples/stable_diffusion/decoder.py b/ai_edge_torch/generative/examples/stable_diffusion/decoder.py
index e65d2626..8a28faa5 100644
--- a/ai_edge_torch/generative/examples/stable_diffusion/decoder.py
+++ b/ai_edge_torch/generative/examples/stable_diffusion/decoder.py
@@ -270,8 +270,8 @@ def forward(self, latents_tensor: torch.Tensor) -> torch.Tensor:
   return x
 
 
-def get_model_config() -> unet_cfg.AutoEncoderConfig:
-  """Get configs for the Decoder of Stable Diffusion v1.5"""
+def get_model_config(device_type: str = "cpu") -> unet_cfg.AutoEncoderConfig:
+  """Get configs for the Decoder of Stable Diffusion v1.5."""
   in_channels = 3
   latent_channels = 4
   out_channels = 3
@@ -279,8 +279,14 @@ def get_model_config() -> unet_cfg.AutoEncoderConfig:
   scaling_factor = 0.18215
   layers_per_block = 3
 
+  # For now, only turn on StableHLO composite ops on the GPU backend for
+  # better performance. CPU should also switch to them once support is done.
+  enable_hlfb = device_type == "gpu"
+
   norm_config = layers_cfg.NormalizationConfig(
-      layers_cfg.NormalizationType.GROUP_NORM, group_num=32
+      layers_cfg.NormalizationType.GROUP_NORM,
+      group_num=32,
+      enable_hlfb=enable_hlfb,
   )
 
   att_config = unet_cfg.AttentionBlock2DConfig(
@@ -298,7 +304,7 @@ def get_model_config() -> unet_cfg.AutoEncoderConfig:
           rotary_base=0,
           rotary_percentage=0.0,
       ),
-      enable_hlfb=False,
+      enable_hlfb=enable_hlfb,
   )
 
   mid_block_config = unet_cfg.MidBlock2DConfig(
@@ -327,7 +333,9 @@ def get_model_config() -> unet_cfg.AutoEncoderConfig:
   return config
 
 
-def get_fake_model_config() -> unet_cfg.AutoEncoderConfig:
+def get_fake_model_config(
+    device_type: str = "cpu",
+) -> unet_cfg.AutoEncoderConfig:
   """Get fake configs for the Decoder of Stable Diffusion v1.5 for testing."""
   in_channels = 3
   latent_channels = 4
@@ -336,8 +344,14 @@ def get_fake_model_config() -> unet_cfg.AutoEncoderConfig:
   scaling_factor = 0.18215
   layers_per_block = 2
 
+  # For now, only turn on StableHLO composite ops on the GPU backend for
+  # better performance. CPU should also switch to them once support is done.
+  enable_hlfb = device_type == "gpu"
+
   norm_config = layers_cfg.NormalizationConfig(
-      layers_cfg.NormalizationType.GROUP_NORM, group_num=2
+      layers_cfg.NormalizationType.GROUP_NORM,
+      group_num=2,
+      enable_hlfb=enable_hlfb,
   )
 
   att_config = unet_cfg.AttentionBlock2DConfig(
@@ -355,7 +369,7 @@ def get_fake_model_config() -> unet_cfg.AutoEncoderConfig:
           rotary_base=0,
           rotary_percentage=0.0,
       ),
-      enable_hlfb=False,
+      enable_hlfb=enable_hlfb,
   )
 
   mid_block_config = unet_cfg.MidBlock2DConfig(
diff --git a/ai_edge_torch/generative/examples/stable_diffusion/diffusion.py b/ai_edge_torch/generative/examples/stable_diffusion/diffusion.py
index 266ab8dd..661c7023 100644
--- a/ai_edge_torch/generative/examples/stable_diffusion/diffusion.py
+++ b/ai_edge_torch/generative/examples/stable_diffusion/diffusion.py
@@ -333,7 +333,7 @@ def __init__(self, config: unet_cfg.DiffusionModelConfig):
                   dim=output_channel,
                   num_query_groups=config.transformer_num_attention_heads,
               ),
-              enable_hlfb=False,
+              enable_hlfb=config.enable_hlfb,
           ),
           cross_attention_block_config=unet_cfg.CrossAttentionBlock2DConfig(
               query_dim=output_channel,
@@ -347,7 +347,7 @@ def __init__(self, config: unet_cfg.DiffusionModelConfig):
                   dim=output_channel,
                   num_query_groups=config.transformer_num_attention_heads,
               ),
-              enable_hlfb=False,
+              enable_hlfb=config.enable_hlfb,
           ),
           pre_conv_normalization_config=config.transformer_pre_conv_norm_config,
           feed_forward_block_config=unet_cfg.FeedForwardBlock2DConfig(
@@ -405,7 +405,7 @@ def __init__(self, config: unet_cfg.DiffusionModelConfig):
               dim=mid_block_channels,
               num_query_groups=config.transformer_num_attention_heads,
           ),
-          enable_hlfb=False,
+          enable_hlfb=config.enable_hlfb,
       ),
       cross_attention_block_config=unet_cfg.CrossAttentionBlock2DConfig(
           query_dim=mid_block_channels,
@@ -419,7 +419,7 @@ def __init__(self, config: unet_cfg.DiffusionModelConfig):
              dim=mid_block_channels,
              num_query_groups=config.transformer_num_attention_heads,
          ),
-         enable_hlfb=False,
+         enable_hlfb=config.enable_hlfb,
      ),
      pre_conv_normalization_config=config.transformer_pre_conv_norm_config,
      feed_forward_block_config=unet_cfg.FeedForwardBlock2DConfig(
@@ -478,7 +478,7 @@ def __init__(self, config: unet_cfg.DiffusionModelConfig):
                   dim=output_channel,
                   num_query_groups=config.transformer_num_attention_heads,
               ),
-              enable_hlfb=False,
+              enable_hlfb=config.enable_hlfb,
           ),
          cross_attention_block_config=unet_cfg.CrossAttentionBlock2DConfig(
              query_dim=output_channel,
@@ -492,7 +492,7 @@ def __init__(self, config: unet_cfg.DiffusionModelConfig):
                  dim=output_channel,
                  num_query_groups=config.transformer_num_attention_heads,
              ),
-             enable_hlfb=False,
+             enable_hlfb=config.enable_hlfb,
          ),
          pre_conv_normalization_config=config.transformer_pre_conv_norm_config,
          feed_forward_block_config=unet_cfg.FeedForwardBlock2DConfig(
@@ -581,13 +581,16 @@ def forward(
   return x
 
 
-def get_model_config(batch_size: int) -> unet_cfg.DiffusionModelConfig:
-  """Get configs for the Diffusion model of Stable Diffusion v1.5
+def get_model_config(
+    batch_size: int, device_type: str = "cpu"
+) -> unet_cfg.DiffusionModelConfig:
+  """Get configs for the Diffusion model of Stable Diffusion v1.5.
 
   Args:
     batch_size (int): the batch size of input.
+    device_type (str): the device type of the model. Defaults to "cpu".
 
-  Retruns:
+  Returns:
     The configuration of diffusion model of Stable Diffusion v1.5.
   """
   in_channels = 4
@@ -596,9 +599,15 @@ def get_model_config(batch_size: int) -> unet_cfg.DiffusionModelConfig:
   layers_per_block = 2
   downsample_padding = 1
 
+  # For now, only turn on StableHLO composite ops on the GPU backend for
+  # better performance. CPU should also switch to them once support is done.
+  enable_hlfb = device_type == "gpu"
+
   # Residual configs.
   residual_norm_config = layers_cfg.NormalizationConfig(
-      layers_cfg.NormalizationType.GROUP_NORM, group_num=32
+      layers_cfg.NormalizationType.GROUP_NORM,
+      group_num=32,
+      enable_hlfb=enable_hlfb,
   )
   residual_activation_type = layers_cfg.ActivationType.SILU
 
@@ -607,10 +616,14 @@ def get_model_config(batch_size: int) -> unet_cfg.DiffusionModelConfig:
   transformer_batch_size = batch_size
   transformer_cross_attention_dim = 768  # Embedding from CLIP model
   transformer_pre_conv_norm_config = layers_cfg.NormalizationConfig(
-      layers_cfg.NormalizationType.GROUP_NORM, epsilon=1e-6, group_num=32
+      layers_cfg.NormalizationType.GROUP_NORM,
+      epsilon=1e-6,
+      group_num=32,
+      enable_hlfb=enable_hlfb,
   )
   transformer_norm_config = layers_cfg.NormalizationConfig(
-      layers_cfg.NormalizationType.LAYER_NORM
+      layers_cfg.NormalizationType.LAYER_NORM,
+      enable_hlfb=enable_hlfb,
   )
   transformer_ff_activation_type = layers_cfg.ActivationType.GE_GLU
 
@@ -623,7 +636,9 @@ def get_model_config(batch_size: int) -> unet_cfg.DiffusionModelConfig:
 
   # Finaly layer configs.
   final_norm_config = layers_cfg.NormalizationConfig(
-      layers_cfg.NormalizationType.GROUP_NORM, group_num=32
+      layers_cfg.NormalizationType.GROUP_NORM,
+      group_num=32,
+      enable_hlfb=enable_hlfb,
   )
   final_activation_type = layers_cfg.ActivationType.SILU
 
@@ -646,16 +661,20 @@ def get_model_config(batch_size: int) -> unet_cfg.DiffusionModelConfig:
       time_embedding_blocks_dim=time_embedding_blocks_dim,
       final_norm_config=final_norm_config,
       final_activation_type=final_activation_type,
+      enable_hlfb=enable_hlfb,
   )
 
 
-def get_fake_model_config(batch_size: int) -> unet_cfg.DiffusionModelConfig:
+def get_fake_model_config(
+    batch_size: int, device_type: str = "cpu"
+) -> unet_cfg.DiffusionModelConfig:
   """Get fake configs for the Diffusion model of Stable Diffusion v1.5 for testing.
 
   Args:
     batch_size (int): the batch size of input.
+    device_type (str): the device type of the model. Defaults to "cpu".
 
-  Retruns:
+  Returns:
     The configuration of diffusion model of Stable Diffusion v1.5.
""" in_channels = 4 @@ -664,9 +683,15 @@ def get_fake_model_config(batch_size: int) -> unet_cfg.DiffusionModelConfig: layers_per_block = 1 downsample_padding = 1 + # For now, only turns on StableHLO composite ops on GPU backend for better + # performance. CPU should also switch to it once the support is done. + enable_hlfb = True if device_type == "gpu" else False + # Residual configs. residual_norm_config = layers_cfg.NormalizationConfig( - layers_cfg.NormalizationType.GROUP_NORM, group_num=2 + layers_cfg.NormalizationType.GROUP_NORM, + group_num=2, + enable_hlfb=enable_hlfb, ) residual_activation_type = layers_cfg.ActivationType.SILU @@ -675,10 +700,14 @@ def get_fake_model_config(batch_size: int) -> unet_cfg.DiffusionModelConfig: transformer_batch_size = batch_size transformer_cross_attention_dim = 4 # Embedding from CLIP model transformer_pre_conv_norm_config = layers_cfg.NormalizationConfig( - layers_cfg.NormalizationType.GROUP_NORM, epsilon=1e-6, group_num=2 + layers_cfg.NormalizationType.GROUP_NORM, + epsilon=1e-6, + group_num=2, + enable_hlfb=enable_hlfb, ) transformer_norm_config = layers_cfg.NormalizationConfig( - layers_cfg.NormalizationType.LAYER_NORM + layers_cfg.NormalizationType.LAYER_NORM, + enable_hlfb=enable_hlfb, ) transformer_ff_activation_type = layers_cfg.ActivationType.GE_GLU @@ -691,7 +720,9 @@ def get_fake_model_config(batch_size: int) -> unet_cfg.DiffusionModelConfig: # Finaly layer configs. final_norm_config = layers_cfg.NormalizationConfig( - layers_cfg.NormalizationType.GROUP_NORM, group_num=2 + layers_cfg.NormalizationType.GROUP_NORM, + group_num=2, + enable_hlfb=enable_hlfb, ) final_activation_type = layers_cfg.ActivationType.SILU @@ -714,4 +745,5 @@ def get_fake_model_config(batch_size: int) -> unet_cfg.DiffusionModelConfig: time_embedding_blocks_dim=time_embedding_blocks_dim, final_norm_config=final_norm_config, final_activation_type=final_activation_type, + enable_hlfb=enable_hlfb, ) diff --git a/ai_edge_torch/generative/layers/unet/model_config.py b/ai_edge_torch/generative/layers/unet/model_config.py index d7d91b79..1236f246 100644 --- a/ai_edge_torch/generative/layers/unet/model_config.py +++ b/ai_edge_torch/generative/layers/unet/model_config.py @@ -269,3 +269,6 @@ class DiffusionModelConfig: # Activation type used in final layer final_activation_type: layers_cfg.ActivationType + + # Whether to enable StableHLO composite ops in the model. + enable_hlfb: bool = False