Add UNet model config and refactor diffusion model. (#55)
* Add UNet model config and refactor diffusion model.

* Resolve a broadcasting issue caused by a mismatched mask shape (see the illustrative sketch after this list)

* Make enable_hlfb configurable for 2D attention blocks

* update

* Update for attention batch size
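
The mask-shape bullet refers to a standard PyTorch broadcasting pitfall. The actual fix lives in this commit's diff; the snippet below is only a minimal, hypothetical illustration (shapes and names are invented, not taken from this change) of how an attention mask that lacks explicit batch/head dimensions fails to broadcast against the attention scores, and how inserting singleton dimensions resolves it:

import torch

# Invented shapes, for illustration only.
batch, heads, seq = 2, 8, 16
scores = torch.randn(batch, heads, seq, seq)  # attention logits (B, H, S, S)

# A (S, S) causal mask broadcasts cleanly against (B, H, S, S).
causal = torch.tril(torch.ones(seq, seq))
scores.masked_fill(causal == 0, float("-inf"))  # OK

# A per-example (B, S) padding mask does not: PyTorch aligns trailing
# dimensions, so (2, 16) vs. (2, 8, 16, 16) raises a broadcast error.
pad_mask = torch.ones(batch, seq)
try:
  scores.masked_fill(pad_mask == 0, float("-inf"))
except RuntimeError as err:
  print("broadcast error:", err)

# Reshaping the mask to (B, 1, 1, S) lets it broadcast over heads and rows.
scores.masked_fill(pad_mask[:, None, None, :] == 0, float("-inf"))  # OK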
yichunk authored Jun 13, 2024
1 parent 4b9a118 commit 5b31b82
Showing 8 changed files with 1,576 additions and 571 deletions.
@@ -21,11 +21,11 @@
import ai_edge_torch
import ai_edge_torch.generative.examples.stable_diffusion.clip as clip
import ai_edge_torch.generative.examples.stable_diffusion.decoder as decoder
- from ai_edge_torch.generative.examples.stable_diffusion.diffusion import Diffusion # NOQA
+ import ai_edge_torch.generative.examples.stable_diffusion.diffusion as diffusion
from ai_edge_torch.generative.examples.stable_diffusion.encoder import Encoder
import ai_edge_torch.generative.examples.stable_diffusion.util as util
- import ai_edge_torch.generative.utilities.autoencoder_loader as autoencoder_loader
import ai_edge_torch.generative.utilities.loader as loading_utils
+ import ai_edge_torch.generative.utilities.stable_diffusion_loader as stable_diffusion_loader


@torch.inference_mode
@@ -45,11 +45,14 @@ def convert_stable_diffusion_to_tflite(
encoder = Encoder()
encoder.load_state_dict(torch.load(encoder_ckpt_path))

- diffusion = Diffusion()
- diffusion.load_state_dict(torch.load(diffusion_ckpt_path))
+ diffusion_model = diffusion.Diffusion(diffusion.get_model_config(2))
+ diffusion_loader = stable_diffusion_loader.DiffusionModelLoader(
+     diffusion_ckpt_path, diffusion.TENSORS_NAMES
+ )
+ diffusion_loader.load(diffusion_model)

decoder_model = decoder.Decoder(decoder.get_model_config())
- decoder_loader = autoencoder_loader.AutoEncoderModelLoader(
+ decoder_loader = stable_diffusion_loader.AutoEncoderModelLoader(
decoder_ckpt_path, decoder.TENSORS_NAMES
)
decoder_loader.load(decoder_model)
@@ -84,7 +87,7 @@ def convert_stable_diffusion_to_tflite(
# Diffusion
ai_edge_torch.signature(
'diffusion',
- diffusion,
+ diffusion_model,
(torch.repeat_interleave(input_latents, 2, 0), context, time_embedding),
).convert().export('/tmp/stable_diffusion/diffusion.tflite')

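Taken together, these hunks replace direct load_state_dict initialization of the diffusion model with a config-plus-loader flow. A condensed sketch of just that flow (the checkpoint path is a placeholder, and reading get_model_config(2) as a batch size of 2 for the classifier-free-guidance pair is an inference from the repeat_interleave(..., 2, 0) call, not something the diff states):

import ai_edge_torch.generative.examples.stable_diffusion.diffusion as diffusion
import ai_edge_torch.generative.utilities.stable_diffusion_loader as stable_diffusion_loader

# Build the diffusion UNet from a model config instead of a bare Diffusion().
diffusion_model = diffusion.Diffusion(diffusion.get_model_config(2))

# Map checkpoint tensors onto the model through the new loader rather than
# diffusion.load_state_dict(torch.load(...)).
diffusion_loader = stable_diffusion_loader.DiffusionModelLoader(
    "/path/to/diffusion_ckpt.pt",  # placeholder path, for illustration only
    diffusion.TENSORS_NAMES,
)
diffusion_loader.load(diffusion_model)

The 'diffusion' signature is then exported against diffusion_model, with the latent batch doubled by repeat_interleave as shown in the last hunk above.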
52 changes: 30 additions & 22 deletions ai_edge_torch/generative/examples/stable_diffusion/decoder.py
@@ -20,50 +20,50 @@
import ai_edge_torch.generative.layers.model_config as layers_cfg
import ai_edge_torch.generative.layers.unet.blocks_2d as blocks_2d
import ai_edge_torch.generative.layers.unet.model_config as unet_cfg
- import ai_edge_torch.generative.utilities.autoencoder_loader as autoencoder_loader
+ import ai_edge_torch.generative.utilities.stable_diffusion_loader as stable_diffusion_loader

- TENSORS_NAMES = autoencoder_loader.AutoEncoderModelLoader.TensorNames(
+ TENSORS_NAMES = stable_diffusion_loader.AutoEncoderModelLoader.TensorNames(
post_quant_conv="0",
conv_in="1",
- mid_block_tensor_names=autoencoder_loader.MidBlockTensorNames(
+ mid_block_tensor_names=stable_diffusion_loader.MidBlockTensorNames(
residual_block_tensor_names=[
- autoencoder_loader.ResidualBlockTensorNames(
+ stable_diffusion_loader.ResidualBlockTensorNames(
norm_1="2.groupnorm_1",
norm_2="2.groupnorm_2",
conv_1="2.conv_1",
conv_2="2.conv_2",
),
- autoencoder_loader.ResidualBlockTensorNames(
+ stable_diffusion_loader.ResidualBlockTensorNames(
norm_1="4.groupnorm_1",
norm_2="4.groupnorm_2",
conv_1="4.conv_1",
conv_2="4.conv_2",
),
],
attention_block_tensor_names=[
- autoencoder_loader.AttnetionBlockTensorNames(
+ stable_diffusion_loader.AttentionBlockTensorNames(
norm="3.groupnorm",
fused_qkv_proj="3.attention.in_proj",
output_proj="3.attention.out_proj",
)
],
),
up_decoder_blocks_tensor_names=[
- autoencoder_loader.UpDecoderBlockTensorNames(
+ stable_diffusion_loader.UpDecoderBlockTensorNames(
residual_block_tensor_names=[
- autoencoder_loader.ResidualBlockTensorNames(
+ stable_diffusion_loader.ResidualBlockTensorNames(
norm_1="5.groupnorm_1",
norm_2="5.groupnorm_2",
conv_1="5.conv_1",
conv_2="5.conv_2",
),
- autoencoder_loader.ResidualBlockTensorNames(
+ stable_diffusion_loader.ResidualBlockTensorNames(
norm_1="6.groupnorm_1",
norm_2="6.groupnorm_2",
conv_1="6.conv_1",
conv_2="6.conv_2",
),
- autoencoder_loader.ResidualBlockTensorNames(
+ stable_diffusion_loader.ResidualBlockTensorNames(
norm_1="7.groupnorm_1",
norm_2="7.groupnorm_2",
conv_1="7.conv_1",
@@ -72,21 +72,21 @@
],
upsample_conv="9",
),
- autoencoder_loader.UpDecoderBlockTensorNames(
+ stable_diffusion_loader.UpDecoderBlockTensorNames(
residual_block_tensor_names=[
- autoencoder_loader.ResidualBlockTensorNames(
+ stable_diffusion_loader.ResidualBlockTensorNames(
norm_1="10.groupnorm_1",
norm_2="10.groupnorm_2",
conv_1="10.conv_1",
conv_2="10.conv_2",
),
- autoencoder_loader.ResidualBlockTensorNames(
+ stable_diffusion_loader.ResidualBlockTensorNames(
norm_1="11.groupnorm_1",
norm_2="11.groupnorm_2",
conv_1="11.conv_1",
conv_2="11.conv_2",
),
- autoencoder_loader.ResidualBlockTensorNames(
+ stable_diffusion_loader.ResidualBlockTensorNames(
norm_1="12.groupnorm_1",
norm_2="12.groupnorm_2",
conv_1="12.conv_1",
@@ -95,22 +95,22 @@
],
upsample_conv="14",
),
- autoencoder_loader.UpDecoderBlockTensorNames(
+ stable_diffusion_loader.UpDecoderBlockTensorNames(
residual_block_tensor_names=[
- autoencoder_loader.ResidualBlockTensorNames(
+ stable_diffusion_loader.ResidualBlockTensorNames(
norm_1="15.groupnorm_1",
norm_2="15.groupnorm_2",
conv_1="15.conv_1",
conv_2="15.conv_2",
residual_layer="15.residual_layer",
),
- autoencoder_loader.ResidualBlockTensorNames(
+ stable_diffusion_loader.ResidualBlockTensorNames(
norm_1="16.groupnorm_1",
norm_2="16.groupnorm_2",
conv_1="16.conv_1",
conv_2="16.conv_2",
),
- autoencoder_loader.ResidualBlockTensorNames(
+ stable_diffusion_loader.ResidualBlockTensorNames(
norm_1="17.groupnorm_1",
norm_2="17.groupnorm_2",
conv_1="17.conv_1",
@@ -119,22 +119,22 @@
],
upsample_conv="19",
),
- autoencoder_loader.UpDecoderBlockTensorNames(
+ stable_diffusion_loader.UpDecoderBlockTensorNames(
residual_block_tensor_names=[
- autoencoder_loader.ResidualBlockTensorNames(
+ stable_diffusion_loader.ResidualBlockTensorNames(
norm_1="20.groupnorm_1",
norm_2="20.groupnorm_2",
conv_1="20.conv_1",
conv_2="20.conv_2",
residual_layer="20.residual_layer",
),
- autoencoder_loader.ResidualBlockTensorNames(
+ stable_diffusion_loader.ResidualBlockTensorNames(
norm_1="21.groupnorm_1",
norm_2="21.groupnorm_2",
conv_1="21.conv_1",
conv_2="21.conv_2",
),
- autoencoder_loader.ResidualBlockTensorNames(
+ stable_diffusion_loader.ResidualBlockTensorNames(
norm_1="22.groupnorm_1",
norm_2="22.groupnorm_2",
conv_1="22.conv_1",
@@ -245,6 +245,14 @@ def __init__(self, config: unet_cfg.AutoEncoderConfig):
)

def forward(self, latents_tensor: torch.Tensor) -> torch.Tensor:
"""Forward function of decoder model.
Args:
latents (torch.Tensor): latents space tensor.
Returns:
output decoded image tensor from decoder model.
"""
x = latents_tensor / self.config.scaling_factor
x = self.post_quant_conv(x)
x = self.conv_in(x)
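For orientation, the tensor-name mapping above feeds the same loader pattern used for the diffusion model. Below is a minimal sketch of constructing and running the decoder, based on the conversion-script hunks earlier in this diff and the forward body shown here (checkpoint path and latent shape are assumptions for illustration; a 1x4x64x64 latent corresponds to a 512x512 Stable Diffusion image):

import torch

import ai_edge_torch.generative.examples.stable_diffusion.decoder as decoder
import ai_edge_torch.generative.utilities.stable_diffusion_loader as stable_diffusion_loader

decoder_model = decoder.Decoder(decoder.get_model_config())
decoder_loader = stable_diffusion_loader.AutoEncoderModelLoader(
    "/path/to/decoder_ckpt.pt",  # placeholder path, for illustration only
    decoder.TENSORS_NAMES,
)
decoder_loader.load(decoder_model)

# forward() rescales the input by config.scaling_factor, applies
# post_quant_conv and conv_in, then runs the remaining decoder blocks.
latents = torch.randn(1, 4, 64, 64)  # assumed latent shape
with torch.inference_mode():
  image = decoder_model(latents)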
