Skip to content

Commit

Permalink
In SD decoder/diffusion configuration, add a device type field to che…
Browse files Browse the repository at this point in the history
…ck if hlfb is needed. Also update the converter code to inspect device type flag. This hopefully will limit the number of places where `enable_hlfb` should be flipped.

PiperOrigin-RevId: 691071718
  • Loading branch information
haozha111 authored and copybara-github committed Oct 29, 2024
1 parent 61800ad commit 1c1dcb5
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,13 @@
default=True,
)

_DEVICE_TYPE = flags.DEFINE_string(
'device_type',
'cpu',
help='The device type of the model. Currently supported: cpu, gpu.',
default='cpu',
)


@torch.inference_mode
def convert_stable_diffusion_to_tflite(
Expand All @@ -80,13 +87,17 @@ def convert_stable_diffusion_to_tflite(
)
loader.load(clip_model, strict=False)

diffusion_model = diffusion.Diffusion(diffusion.get_model_config(2))
diffusion_model = diffusion.Diffusion(
diffusion.get_model_config(batch_size=2, device_type=_DEVICE_TYPE.value)
)
diffusion_loader = stable_diffusion_loader.DiffusionModelLoader(
diffusion_ckpt_path, diffusion.TENSOR_NAMES
)
diffusion_loader.load(diffusion_model, strict=False)

decoder_model = decoder.Decoder(decoder.get_model_config())
decoder_model = decoder.Decoder(
decoder.get_model_config(device_type=_DEVICE_TYPE.value)
)
decoder_loader = stable_diffusion_loader.AutoEncoderModelLoader(
decoder_ckpt_path, decoder.TENSOR_NAMES
)
Expand Down
28 changes: 21 additions & 7 deletions ai_edge_torch/generative/examples/stable_diffusion/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,17 +270,23 @@ def forward(self, latents_tensor: torch.Tensor) -> torch.Tensor:
return x


def get_model_config() -> unet_cfg.AutoEncoderConfig:
"""Get configs for the Decoder of Stable Diffusion v1.5"""
def get_model_config(device_type: str = "cpu") -> unet_cfg.AutoEncoderConfig:
"""Get configs for the Decoder of Stable Diffusion v1.5."""
in_channels = 3
latent_channels = 4
out_channels = 3
block_out_channels = [128, 256, 512, 512]
scaling_factor = 0.18215
layers_per_block = 3

# For now, only turns on StableHLO composite ops on GPU backend for better
# performance. CPU should also switch to it once the support is done.
enable_hlfb = True if device_type == "gpu" else False

norm_config = layers_cfg.NormalizationConfig(
layers_cfg.NormalizationType.GROUP_NORM, group_num=32
layers_cfg.NormalizationType.GROUP_NORM,
group_num=32,
enable_hlfb=enable_hlfb,
)

att_config = unet_cfg.AttentionBlock2DConfig(
Expand All @@ -298,7 +304,7 @@ def get_model_config() -> unet_cfg.AutoEncoderConfig:
rotary_base=0,
rotary_percentage=0.0,
),
enable_hlfb=False,
enable_hlfb=enable_hlfb,
)

mid_block_config = unet_cfg.MidBlock2DConfig(
Expand Down Expand Up @@ -327,7 +333,9 @@ def get_model_config() -> unet_cfg.AutoEncoderConfig:
return config


def get_fake_model_config() -> unet_cfg.AutoEncoderConfig:
def get_fake_model_config(
device_type: str = "cpu",
) -> unet_cfg.AutoEncoderConfig:
"""Get fake configs for the Decoder of Stable Diffusion v1.5 for testing."""
in_channels = 3
latent_channels = 4
Expand All @@ -336,8 +344,14 @@ def get_fake_model_config() -> unet_cfg.AutoEncoderConfig:
scaling_factor = 0.18215
layers_per_block = 2

# For now, only turns on StableHLO composite ops on GPU backend for better
# performance. CPU should also switch to it once the support is done.
enable_hlfb = True if device_type == "gpu" else False

norm_config = layers_cfg.NormalizationConfig(
layers_cfg.NormalizationType.GROUP_NORM, group_num=2
layers_cfg.NormalizationType.GROUP_NORM,
group_num=2,
enable_hlfb=enable_hlfb,
)

att_config = unet_cfg.AttentionBlock2DConfig(
Expand All @@ -355,7 +369,7 @@ def get_fake_model_config() -> unet_cfg.AutoEncoderConfig:
rotary_base=0,
rotary_percentage=0.0,
),
enable_hlfb=False,
enable_hlfb=enable_hlfb,
)

mid_block_config = unet_cfg.MidBlock2DConfig(
Expand Down
70 changes: 51 additions & 19 deletions ai_edge_torch/generative/examples/stable_diffusion/diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ def __init__(self, config: unet_cfg.DiffusionModelConfig):
dim=output_channel,
num_query_groups=config.transformer_num_attention_heads,
),
enable_hlfb=False,
enable_hlfb=config.enable_hlfb,
),
cross_attention_block_config=unet_cfg.CrossAttentionBlock2DConfig(
query_dim=output_channel,
Expand All @@ -347,7 +347,7 @@ def __init__(self, config: unet_cfg.DiffusionModelConfig):
dim=output_channel,
num_query_groups=config.transformer_num_attention_heads,
),
enable_hlfb=False,
enable_hlfb=config.enable_hlfb,
),
pre_conv_normalization_config=config.transformer_pre_conv_norm_config,
feed_forward_block_config=unet_cfg.FeedForwardBlock2DConfig(
Expand Down Expand Up @@ -405,7 +405,7 @@ def __init__(self, config: unet_cfg.DiffusionModelConfig):
dim=mid_block_channels,
num_query_groups=config.transformer_num_attention_heads,
),
enable_hlfb=False,
enable_hlfb=config.enable_hlfb,
),
cross_attention_block_config=unet_cfg.CrossAttentionBlock2DConfig(
query_dim=mid_block_channels,
Expand All @@ -419,7 +419,7 @@ def __init__(self, config: unet_cfg.DiffusionModelConfig):
dim=mid_block_channels,
num_query_groups=config.transformer_num_attention_heads,
),
enable_hlfb=False,
enable_hlfb=config.enable_hlfb,
),
pre_conv_normalization_config=config.transformer_pre_conv_norm_config,
feed_forward_block_config=unet_cfg.FeedForwardBlock2DConfig(
Expand Down Expand Up @@ -478,7 +478,7 @@ def __init__(self, config: unet_cfg.DiffusionModelConfig):
dim=output_channel,
num_query_groups=config.transformer_num_attention_heads,
),
enable_hlfb=False,
enable_hlfb=config.enable_hlfb,
),
cross_attention_block_config=unet_cfg.CrossAttentionBlock2DConfig(
query_dim=output_channel,
Expand All @@ -492,7 +492,7 @@ def __init__(self, config: unet_cfg.DiffusionModelConfig):
dim=output_channel,
num_query_groups=config.transformer_num_attention_heads,
),
enable_hlfb=False,
enable_hlfb=config.enable_hlfb,
),
pre_conv_normalization_config=config.transformer_pre_conv_norm_config,
feed_forward_block_config=unet_cfg.FeedForwardBlock2DConfig(
Expand Down Expand Up @@ -581,13 +581,16 @@ def forward(
return x


def get_model_config(batch_size: int) -> unet_cfg.DiffusionModelConfig:
"""Get configs for the Diffusion model of Stable Diffusion v1.5
def get_model_config(
batch_size: int, device_type: str = "cpu"
) -> unet_cfg.DiffusionModelConfig:
"""Get configs for the Diffusion model of Stable Diffusion v1.5.
Args:
batch_size (int): the batch size of input.
device_type (str): the device type of the model. Default to "cpu".
Retruns:
Returns:
The configuration of diffusion model of Stable Diffusion v1.5.
"""
in_channels = 4
Expand All @@ -596,9 +599,15 @@ def get_model_config(batch_size: int) -> unet_cfg.DiffusionModelConfig:
layers_per_block = 2
downsample_padding = 1

# For now, only turns on StableHLO composite ops on GPU backend for better
# performance. CPU should also switch to it once the support is done.
enable_hlfb = True if device_type == "gpu" else False

# Residual configs.
residual_norm_config = layers_cfg.NormalizationConfig(
layers_cfg.NormalizationType.GROUP_NORM, group_num=32
layers_cfg.NormalizationType.GROUP_NORM,
group_num=32,
enable_hlfb=enable_hlfb,
)
residual_activation_type = layers_cfg.ActivationType.SILU

Expand All @@ -607,10 +616,14 @@ def get_model_config(batch_size: int) -> unet_cfg.DiffusionModelConfig:
transformer_batch_size = batch_size
transformer_cross_attention_dim = 768 # Embedding from CLIP model
transformer_pre_conv_norm_config = layers_cfg.NormalizationConfig(
layers_cfg.NormalizationType.GROUP_NORM, epsilon=1e-6, group_num=32
layers_cfg.NormalizationType.GROUP_NORM,
epsilon=1e-6,
group_num=32,
enable_hlfb=enable_hlfb,
)
transformer_norm_config = layers_cfg.NormalizationConfig(
layers_cfg.NormalizationType.LAYER_NORM
layers_cfg.NormalizationType.LAYER_NORM,
enable_hlfb=enable_hlfb,
)
transformer_ff_activation_type = layers_cfg.ActivationType.GE_GLU

Expand All @@ -623,7 +636,9 @@ def get_model_config(batch_size: int) -> unet_cfg.DiffusionModelConfig:

# Finaly layer configs.
final_norm_config = layers_cfg.NormalizationConfig(
layers_cfg.NormalizationType.GROUP_NORM, group_num=32
layers_cfg.NormalizationType.GROUP_NORM,
group_num=32,
enable_hlfb=enable_hlfb,
)
final_activation_type = layers_cfg.ActivationType.SILU

Expand All @@ -646,16 +661,20 @@ def get_model_config(batch_size: int) -> unet_cfg.DiffusionModelConfig:
time_embedding_blocks_dim=time_embedding_blocks_dim,
final_norm_config=final_norm_config,
final_activation_type=final_activation_type,
enable_hlfb=enable_hlfb,
)


def get_fake_model_config(batch_size: int) -> unet_cfg.DiffusionModelConfig:
def get_fake_model_config(
batch_size: int, device_type: str = "cpu"
) -> unet_cfg.DiffusionModelConfig:
"""Get fake configs for the Diffusion model of Stable Diffusion v1.5 for testing.
Args:
batch_size (int): the batch size of input.
device_type (str): the device type of the model. Default to "cpu".
Retruns:
Returns:
The configuration of diffusion model of Stable Diffusion v1.5.
"""
in_channels = 4
Expand All @@ -664,9 +683,15 @@ def get_fake_model_config(batch_size: int) -> unet_cfg.DiffusionModelConfig:
layers_per_block = 1
downsample_padding = 1

# For now, only turns on StableHLO composite ops on GPU backend for better
# performance. CPU should also switch to it once the support is done.
enable_hlfb = True if device_type == "gpu" else False

# Residual configs.
residual_norm_config = layers_cfg.NormalizationConfig(
layers_cfg.NormalizationType.GROUP_NORM, group_num=2
layers_cfg.NormalizationType.GROUP_NORM,
group_num=2,
enable_hlfb=enable_hlfb,
)
residual_activation_type = layers_cfg.ActivationType.SILU

Expand All @@ -675,10 +700,14 @@ def get_fake_model_config(batch_size: int) -> unet_cfg.DiffusionModelConfig:
transformer_batch_size = batch_size
transformer_cross_attention_dim = 4 # Embedding from CLIP model
transformer_pre_conv_norm_config = layers_cfg.NormalizationConfig(
layers_cfg.NormalizationType.GROUP_NORM, epsilon=1e-6, group_num=2
layers_cfg.NormalizationType.GROUP_NORM,
epsilon=1e-6,
group_num=2,
enable_hlfb=enable_hlfb,
)
transformer_norm_config = layers_cfg.NormalizationConfig(
layers_cfg.NormalizationType.LAYER_NORM
layers_cfg.NormalizationType.LAYER_NORM,
enable_hlfb=enable_hlfb,
)
transformer_ff_activation_type = layers_cfg.ActivationType.GE_GLU

Expand All @@ -691,7 +720,9 @@ def get_fake_model_config(batch_size: int) -> unet_cfg.DiffusionModelConfig:

# Finaly layer configs.
final_norm_config = layers_cfg.NormalizationConfig(
layers_cfg.NormalizationType.GROUP_NORM, group_num=2
layers_cfg.NormalizationType.GROUP_NORM,
group_num=2,
enable_hlfb=enable_hlfb,
)
final_activation_type = layers_cfg.ActivationType.SILU

Expand All @@ -714,4 +745,5 @@ def get_fake_model_config(batch_size: int) -> unet_cfg.DiffusionModelConfig:
time_embedding_blocks_dim=time_embedding_blocks_dim,
final_norm_config=final_norm_config,
final_activation_type=final_activation_type,
enable_hlfb=enable_hlfb,
)
3 changes: 3 additions & 0 deletions ai_edge_torch/generative/layers/unet/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,3 +269,6 @@ class DiffusionModelConfig:

# Activation type used in final layer
final_activation_type: layers_cfg.ActivationType

# Whether to enable StableHLO composite ops in the model.
enable_hlfb: bool = False

0 comments on commit 1c1dcb5

Please sign in to comment.