diff --git a/ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py b/ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py index 569d9c29..0a45f6bc 100644 --- a/ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +++ b/ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py @@ -21,11 +21,11 @@ import ai_edge_torch import ai_edge_torch.generative.examples.stable_diffusion.clip as clip import ai_edge_torch.generative.examples.stable_diffusion.decoder as decoder -from ai_edge_torch.generative.examples.stable_diffusion.diffusion import Diffusion # NOQA +import ai_edge_torch.generative.examples.stable_diffusion.diffusion as diffusion from ai_edge_torch.generative.examples.stable_diffusion.encoder import Encoder import ai_edge_torch.generative.examples.stable_diffusion.util as util -import ai_edge_torch.generative.utilities.autoencoder_loader as autoencoder_loader import ai_edge_torch.generative.utilities.loader as loading_utils +import ai_edge_torch.generative.utilities.stable_diffusion_loader as stable_diffusion_loader @torch.inference_mode @@ -45,11 +45,14 @@ def convert_stable_diffusion_to_tflite( encoder = Encoder() encoder.load_state_dict(torch.load(encoder_ckpt_path)) - diffusion = Diffusion() - diffusion.load_state_dict(torch.load(diffusion_ckpt_path)) + diffusion_model = diffusion.Diffusion(diffusion.get_model_config(2)) + diffusion_loader = stable_diffusion_loader.DiffusionModelLoader( + diffusion_ckpt_path, diffusion.TENSORS_NAMES + ) + diffusion_loader.load(diffusion_model) decoder_model = decoder.Decoder(decoder.get_model_config()) - decoder_loader = autoencoder_loader.AutoEncoderModelLoader( + decoder_loader = stable_diffusion_loader.AutoEncoderModelLoader( decoder_ckpt_path, decoder.TENSORS_NAMES ) decoder_loader.load(decoder_model) @@ -84,7 +87,7 @@ def convert_stable_diffusion_to_tflite( # Diffusion ai_edge_torch.signature( 'diffusion', - diffusion, + diffusion_model, (torch.repeat_interleave(input_latents, 2, 0), context, time_embedding), ).convert().export('/tmp/stable_diffusion/diffusion.tflite') diff --git a/ai_edge_torch/generative/examples/stable_diffusion/decoder.py b/ai_edge_torch/generative/examples/stable_diffusion/decoder.py index a8c13082..c3ada8a7 100644 --- a/ai_edge_torch/generative/examples/stable_diffusion/decoder.py +++ b/ai_edge_torch/generative/examples/stable_diffusion/decoder.py @@ -20,20 +20,20 @@ import ai_edge_torch.generative.layers.model_config as layers_cfg import ai_edge_torch.generative.layers.unet.blocks_2d as blocks_2d import ai_edge_torch.generative.layers.unet.model_config as unet_cfg -import ai_edge_torch.generative.utilities.autoencoder_loader as autoencoder_loader +import ai_edge_torch.generative.utilities.stable_diffusion_loader as stable_diffusion_loader -TENSORS_NAMES = autoencoder_loader.AutoEncoderModelLoader.TensorNames( +TENSORS_NAMES = stable_diffusion_loader.AutoEncoderModelLoader.TensorNames( post_quant_conv="0", conv_in="1", - mid_block_tensor_names=autoencoder_loader.MidBlockTensorNames( + mid_block_tensor_names=stable_diffusion_loader.MidBlockTensorNames( residual_block_tensor_names=[ - autoencoder_loader.ResidualBlockTensorNames( + stable_diffusion_loader.ResidualBlockTensorNames( norm_1="2.groupnorm_1", norm_2="2.groupnorm_2", conv_1="2.conv_1", conv_2="2.conv_2", ), - autoencoder_loader.ResidualBlockTensorNames( + stable_diffusion_loader.ResidualBlockTensorNames( norm_1="4.groupnorm_1", norm_2="4.groupnorm_2", conv_1="4.conv_1", @@ -41,7 +41,7 @@ ), ], 
attention_block_tensor_names=[ - autoencoder_loader.AttnetionBlockTensorNames( + stable_diffusion_loader.AttentionBlockTensorNames( norm="3.groupnorm", fused_qkv_proj="3.attention.in_proj", output_proj="3.attention.out_proj", @@ -49,21 +49,21 @@ ], ), up_decoder_blocks_tensor_names=[ - autoencoder_loader.UpDecoderBlockTensorNames( + stable_diffusion_loader.UpDecoderBlockTensorNames( residual_block_tensor_names=[ - autoencoder_loader.ResidualBlockTensorNames( + stable_diffusion_loader.ResidualBlockTensorNames( norm_1="5.groupnorm_1", norm_2="5.groupnorm_2", conv_1="5.conv_1", conv_2="5.conv_2", ), - autoencoder_loader.ResidualBlockTensorNames( + stable_diffusion_loader.ResidualBlockTensorNames( norm_1="6.groupnorm_1", norm_2="6.groupnorm_2", conv_1="6.conv_1", conv_2="6.conv_2", ), - autoencoder_loader.ResidualBlockTensorNames( + stable_diffusion_loader.ResidualBlockTensorNames( norm_1="7.groupnorm_1", norm_2="7.groupnorm_2", conv_1="7.conv_1", @@ -72,21 +72,21 @@ ], upsample_conv="9", ), - autoencoder_loader.UpDecoderBlockTensorNames( + stable_diffusion_loader.UpDecoderBlockTensorNames( residual_block_tensor_names=[ - autoencoder_loader.ResidualBlockTensorNames( + stable_diffusion_loader.ResidualBlockTensorNames( norm_1="10.groupnorm_1", norm_2="10.groupnorm_2", conv_1="10.conv_1", conv_2="10.conv_2", ), - autoencoder_loader.ResidualBlockTensorNames( + stable_diffusion_loader.ResidualBlockTensorNames( norm_1="11.groupnorm_1", norm_2="11.groupnorm_2", conv_1="11.conv_1", conv_2="11.conv_2", ), - autoencoder_loader.ResidualBlockTensorNames( + stable_diffusion_loader.ResidualBlockTensorNames( norm_1="12.groupnorm_1", norm_2="12.groupnorm_2", conv_1="12.conv_1", @@ -95,22 +95,22 @@ ], upsample_conv="14", ), - autoencoder_loader.UpDecoderBlockTensorNames( + stable_diffusion_loader.UpDecoderBlockTensorNames( residual_block_tensor_names=[ - autoencoder_loader.ResidualBlockTensorNames( + stable_diffusion_loader.ResidualBlockTensorNames( norm_1="15.groupnorm_1", norm_2="15.groupnorm_2", conv_1="15.conv_1", conv_2="15.conv_2", residual_layer="15.residual_layer", ), - autoencoder_loader.ResidualBlockTensorNames( + stable_diffusion_loader.ResidualBlockTensorNames( norm_1="16.groupnorm_1", norm_2="16.groupnorm_2", conv_1="16.conv_1", conv_2="16.conv_2", ), - autoencoder_loader.ResidualBlockTensorNames( + stable_diffusion_loader.ResidualBlockTensorNames( norm_1="17.groupnorm_1", norm_2="17.groupnorm_2", conv_1="17.conv_1", @@ -119,22 +119,22 @@ ], upsample_conv="19", ), - autoencoder_loader.UpDecoderBlockTensorNames( + stable_diffusion_loader.UpDecoderBlockTensorNames( residual_block_tensor_names=[ - autoencoder_loader.ResidualBlockTensorNames( + stable_diffusion_loader.ResidualBlockTensorNames( norm_1="20.groupnorm_1", norm_2="20.groupnorm_2", conv_1="20.conv_1", conv_2="20.conv_2", residual_layer="20.residual_layer", ), - autoencoder_loader.ResidualBlockTensorNames( + stable_diffusion_loader.ResidualBlockTensorNames( norm_1="21.groupnorm_1", norm_2="21.groupnorm_2", conv_1="21.conv_1", conv_2="21.conv_2", ), - autoencoder_loader.ResidualBlockTensorNames( + stable_diffusion_loader.ResidualBlockTensorNames( norm_1="22.groupnorm_1", norm_2="22.groupnorm_2", conv_1="22.conv_1", @@ -245,6 +245,14 @@ def __init__(self, config: unet_cfg.AutoEncoderConfig): ) def forward(self, latents_tensor: torch.Tensor) -> torch.Tensor: + """Forward function of decoder model. + + Args: + latents (torch.Tensor): latents space tensor. + + Returns: + output decoded image tensor from decoder model. 
+ """ x = latents_tensor / self.config.scaling_factor x = self.post_quant_conv(x) x = self.conv_in(x) diff --git a/ai_edge_torch/generative/examples/stable_diffusion/diffusion.py b/ai_edge_torch/generative/examples/stable_diffusion/diffusion.py index 19bbf2d1..fd786262 100644 --- a/ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +++ b/ai_edge_torch/generative/examples/stable_diffusion/diffusion.py @@ -15,230 +15,551 @@ import torch from torch import nn -from torch.nn import functional as F -from ai_edge_torch.generative.examples.stable_diffusion.attention import CrossAttention # NOQA -from ai_edge_torch.generative.examples.stable_diffusion.attention import SelfAttention # NOQA - - -class TimeEmbedding(nn.Module): - - def __init__(self, n_embd): - super().__init__() - self.linear_1 = nn.Linear(n_embd, 4 * n_embd) - self.linear_2 = nn.Linear(4 * n_embd, 4 * n_embd) - - def forward(self, x): - x = self.linear_1(x) - x = F.silu(x) - x = self.linear_2(x) - return x - - -class ResidualBlock(nn.Module): - - def __init__(self, in_channels, out_channels, n_time=1280): - super().__init__() - self.groupnorm_feature = nn.GroupNorm(32, in_channels) - self.conv_feature = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1) - self.linear_time = nn.Linear(n_time, out_channels) - - self.groupnorm_merged = nn.GroupNorm(32, out_channels) - self.conv_merged = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1) - - if in_channels == out_channels: - self.residual_layer = nn.Identity() - else: - self.residual_layer = nn.Conv2d( - in_channels, out_channels, kernel_size=1, padding=0 - ) - - def forward(self, feature, time): - residue = feature - - feature = self.groupnorm_feature(feature) - feature = F.silu(feature) - feature = self.conv_feature(feature) - - time = F.silu(time) - time = self.linear_time(time) +import ai_edge_torch.generative.layers.builder as layers_builder +import ai_edge_torch.generative.layers.model_config as layers_cfg +import ai_edge_torch.generative.layers.unet.blocks_2d as blocks_2d +import ai_edge_torch.generative.layers.unet.model_config as unet_cfg +import ai_edge_torch.generative.utilities.stable_diffusion_loader as stable_diffusion_loader + +_down_encoder_blocks_tensor_names = [ + stable_diffusion_loader.DownEncoderBlockTensorNames( + residual_block_tensor_names=[ + stable_diffusion_loader.ResidualBlockTensorNames( + norm_1=f"unet.encoders.{i*3+j+1}.0.groupnorm_feature", + conv_1=f"unet.encoders.{i*3+j+1}.0.conv_feature", + norm_2=f"unet.encoders.{i*3+j+1}.0.groupnorm_merged", + conv_2=f"unet.encoders.{i*3+j+1}.0.conv_merged", + time_embedding=f"unet.encoders.{i*3+j+1}.0.linear_time", + residual_layer=f"unet.encoders.{i*3+j+1}.0.residual_layer" + if (i * 3 + j + 1) in [4, 7] + else None, + ) + for j in range(2) + ], + transformer_block_tensor_names=[ + stable_diffusion_loader.TransformerBlockTensorNames( + pre_conv_norm=f"unet.encoders.{i*3+j+1}.1.groupnorm", + conv_in=f"unet.encoders.{i*3+j+1}.1.conv_input", + conv_out=f"unet.encoders.{i*3+j+1}.1.conv_output", + self_attention=stable_diffusion_loader.AttentionBlockTensorNames( + norm=f"unet.encoders.{i*3+j+1}.1.layernorm_1", + fused_qkv_proj=f"unet.encoders.{i*3+j+1}.1.attention_1.in_proj", + output_proj=f"unet.encoders.{i*3+j+1}.1.attention_1.out_proj", + ), + cross_attention=stable_diffusion_loader.CrossAttentionBlockTensorNames( + norm=f"unet.encoders.{i*3+j+1}.1.layernorm_2", + q_proj=f"unet.encoders.{i*3+j+1}.1.attention_2.q_proj", + k_proj=f"unet.encoders.{i*3+j+1}.1.attention_2.k_proj", + 
v_proj=f"unet.encoders.{i*3+j+1}.1.attention_2.v_proj", + output_proj=f"unet.encoders.{i*3+j+1}.1.attention_2.out_proj", + ), + feed_forward=stable_diffusion_loader.FeedForwardBlockTensorNames( + norm=f"unet.encoders.{i*3+j+1}.1.layernorm_3", + ge_glu=f"unet.encoders.{i*3+j+1}.1.linear_geglu_1", + w2=f"unet.encoders.{i*3+j+1}.1.linear_geglu_2", + ), + ) + for j in range(2) + ] + if i < 3 + else None, + downsample_conv=f"unet.encoders.{i*3+3}.0" if i < 3 else None, + ) + for i in range(4) +] + +_mid_block_tensor_names = stable_diffusion_loader.MidBlockTensorNames( + residual_block_tensor_names=[ + stable_diffusion_loader.ResidualBlockTensorNames( + norm_1=f"unet.bottleneck.{i}.groupnorm_feature", + conv_1=f"unet.bottleneck.{i}.conv_feature", + norm_2=f"unet.bottleneck.{i}.groupnorm_merged", + conv_2=f"unet.bottleneck.{i}.conv_merged", + time_embedding=f"unet.bottleneck.{i}.linear_time", + ) + for i in [0, 2] + ], + transformer_block_tensor_names=[ + stable_diffusion_loader.TransformerBlockTensorNames( + pre_conv_norm=f"unet.bottleneck.{i}.groupnorm", + conv_in=f"unet.bottleneck.{i}.conv_input", + conv_out=f"unet.bottleneck.{i}.conv_output", + self_attention=stable_diffusion_loader.AttentionBlockTensorNames( + norm=f"unet.bottleneck.{i}.layernorm_1", + fused_qkv_proj=f"unet.bottleneck.{i}.attention_1.in_proj", + output_proj=f"unet.bottleneck.{i}.attention_1.out_proj", + ), + cross_attention=stable_diffusion_loader.CrossAttentionBlockTensorNames( + norm=f"unet.bottleneck.{i}.layernorm_2", + q_proj=f"unet.bottleneck.{i}.attention_2.q_proj", + k_proj=f"unet.bottleneck.{i}.attention_2.k_proj", + v_proj=f"unet.bottleneck.{i}.attention_2.v_proj", + output_proj=f"unet.bottleneck.{i}.attention_2.out_proj", + ), + feed_forward=stable_diffusion_loader.FeedForwardBlockTensorNames( + norm=f"unet.bottleneck.{i}.layernorm_3", + ge_glu=f"unet.bottleneck.{i}.linear_geglu_1", + w2=f"unet.bottleneck.{i}.linear_geglu_2", + ), + ) + for i in [1] + ], +) + +_up_decoder_blocks_tensor_names = [ + stable_diffusion_loader.SkipUpDecoderBlockTensorNames( + residual_block_tensor_names=[ + stable_diffusion_loader.ResidualBlockTensorNames( + norm_1=f"unet.decoders.{i*3+j}.0.groupnorm_feature", + conv_1=f"unet.decoders.{i*3+j}.0.conv_feature", + norm_2=f"unet.decoders.{i*3+j}.0.groupnorm_merged", + conv_2=f"unet.decoders.{i*3+j}.0.conv_merged", + time_embedding=f"unet.decoders.{i*3+j}.0.linear_time", + residual_layer=f"unet.decoders.{i*3+j}.0.residual_layer", + ) + for j in range(3) + ], + transformer_block_tensor_names=[ + stable_diffusion_loader.TransformerBlockTensorNames( + pre_conv_norm=f"unet.decoders.{i*3+j}.1.groupnorm", + conv_in=f"unet.decoders.{i*3+j}.1.conv_input", + conv_out=f"unet.decoders.{i*3+j}.1.conv_output", + self_attention=stable_diffusion_loader.AttentionBlockTensorNames( + norm=f"unet.decoders.{i*3+j}.1.layernorm_1", + fused_qkv_proj=f"unet.decoders.{i*3+j}.1.attention_1.in_proj", + output_proj=f"unet.decoders.{i*3+j}.1.attention_1.out_proj", + ), + cross_attention=stable_diffusion_loader.CrossAttentionBlockTensorNames( + norm=f"unet.decoders.{i*3+j}.1.layernorm_2", + q_proj=f"unet.decoders.{i*3+j}.1.attention_2.q_proj", + k_proj=f"unet.decoders.{i*3+j}.1.attention_2.k_proj", + v_proj=f"unet.decoders.{i*3+j}.1.attention_2.v_proj", + output_proj=f"unet.decoders.{i*3+j}.1.attention_2.out_proj", + ), + feed_forward=stable_diffusion_loader.FeedForwardBlockTensorNames( + norm=f"unet.decoders.{i*3+j}.1.layernorm_3", + ge_glu=f"unet.decoders.{i*3+j}.1.linear_geglu_1", + 
w2=f"unet.decoders.{i*3+j}.1.linear_geglu_2", + ), + ) + for j in range(3) + ] + if i > 0 + else None, + upsample_conv=f"unet.decoders.{i*3+2}.2.conv" + if 0 < i < 3 + else (f"unet.decoders.2.1.conv" if i == 0 else None), + ) + for i in range(4) +] - merged = feature + time.unsqueeze(-1).unsqueeze(-1) - merged = self.groupnorm_merged(merged) - merged = F.silu(merged) - merged = self.conv_merged(merged) - return merged + self.residual_layer(residue) +TENSORS_NAMES = stable_diffusion_loader.DiffusionModelLoader.TensorNames( + time_embedding=stable_diffusion_loader.TimeEmbeddingTensorNames( + w1="time_embedding.linear_1", + w2="time_embedding.linear_2", + ), + conv_in="unet.encoders.0.0", + conv_out="final.conv", + final_norm="final.groupnorm", + down_encoder_blocks_tensor_names=_down_encoder_blocks_tensor_names, + mid_block_tensor_names=_mid_block_tensor_names, + up_decoder_blocks_tensor_names=_up_decoder_blocks_tensor_names, +) -class AttentionBlock(nn.Module): +class TimeEmbedding(nn.Module): - def __init__(self, n_head: int, n_embd: int, d_context=768): + def __init__(self, in_dim, out_dim): super().__init__() - channels = n_head * n_embd - - self.groupnorm = nn.GroupNorm(32, channels, eps=1e-6) - self.conv_input = nn.Conv2d(channels, channels, kernel_size=1, padding=0) - - self.layernorm_1 = nn.LayerNorm(channels) - self.attention_1 = SelfAttention(n_head, channels, in_proj_bias=False) - self.layernorm_2 = nn.LayerNorm(channels) - self.attention_2 = CrossAttention(n_head, channels, d_context, in_proj_bias=False) - self.layernorm_3 = nn.LayerNorm(channels) - self.linear_geglu_1 = nn.Linear(channels, 4 * channels * 2) - self.linear_geglu_2 = nn.Linear(4 * channels, channels) - - self.conv_output = nn.Conv2d(channels, channels, kernel_size=1, padding=0) - - def forward(self, x, context): - residue_long = x - - x = self.groupnorm(x) - x = self.conv_input(x) - - n, c, h, w = x.shape - x = x.view((n, c, h * w)) # (n, c, hw) - x = x.transpose(-1, -2) # (n, hw, c) - - residue_short = x - x = self.layernorm_1(x) - x = self.attention_1(x) - x += residue_short - - residue_short = x - x = self.layernorm_2(x) - x = self.attention_2(x, context) - x += residue_short - - residue_short = x - x = self.layernorm_3(x) - x, gate = self.linear_geglu_1(x).chunk(2, dim=-1) - x = x * F.gelu(gate) - x = self.linear_geglu_2(x) - x += residue_short - - x = x.transpose(-1, -2) # (n, c, hw) - x = x.view((n, c, h, w)) # (n, c, h, w) - - return self.conv_output(x) + residue_long + self.w1 = nn.Linear(in_dim, out_dim) + self.w2 = nn.Linear(out_dim, out_dim) + self.act = layers_builder.get_activation( + layers_cfg.ActivationConfig(layers_cfg.ActivationType.SILU) + ) + def forward(self, x: torch.Tensor): + return self.w2(self.act(self.w1(x))) -class Upsample(nn.Module): - def __init__(self, channels): +class Diffusion(nn.Module): + """The Diffusion model used in Stable Diffusion. 
+ + For details, see https://arxiv.org/abs/2103.00020 + + Sturcture of the Diffusion model: + + latents text context time embed + │ │ │ + │ │ │ + ┌─────────▼─────────┐ │ ┌─────────▼─────────┐ + │ ConvIn │ │ │ Time Embedding │ + └─────────┬─────────┘ │ └─────────┬─────────┘ + │ │ │ + ┌─────────▼─────────┐ │ │ + ┌──────┤ DownEncoder2D │ ◄─────┼────────────┤ + │ └─────────┬─────────┘ x 4 │ │ + │ │ │ │ + │ ┌─────────▼─────────┐ │ │ + skip connection │ MidBlock2D │ ◄─────┼────────────┤ + │ └─────────┬─────────┘ │ │ + │ │ │ │ + │ ┌─────────▼─────────┐ │ │ + └──────► SkipUpDecoder2D │ ◄─────┴────────────┘ + └─────────┬─────────┘ x 4 + │ + ┌─────────▼─────────┐ + │ FinalNorm │ + └─────────┬─────────┘ + │ + ┌─────────▼─────────┐ + │ Activation │ + └─────────┬─────────┘ + │ + ┌─────────▼─────────┐ + │ ConvOut │ + └─────────┬─────────┘ + │ + ▼ + output image + """ + + def __init__(self, config: unet_cfg.DiffusionModelConfig): super().__init__() - self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1) - - def forward(self, x): - x = F.interpolate(x, scale_factor=2, mode="nearest") - return self.conv(x) + self.config = config + block_out_channels = config.block_out_channels + reversed_block_out_channels = list(reversed(block_out_channels)) -class SwitchSequential(nn.Sequential): - - def forward(self, x, context, time): - for layer in self: - if isinstance(layer, AttentionBlock): - x = layer(x, context) - elif isinstance(layer, ResidualBlock): - x = layer(x, time) - else: - x = layer(x) - return x - - -class UNet(nn.Module): + time_embedding_blocks_dim = config.time_embedding_blocks_dim + self.time_embedding = TimeEmbedding( + config.time_embedding_dim, config.time_embedding_blocks_dim + ) - def __init__(self): - super().__init__() - self.encoders = nn.ModuleList( - [ - SwitchSequential(nn.Conv2d(4, 320, kernel_size=3, padding=1)), - SwitchSequential(ResidualBlock(320, 320), AttentionBlock(8, 40)), - SwitchSequential(ResidualBlock(320, 320), AttentionBlock(8, 40)), - SwitchSequential(nn.Conv2d(320, 320, kernel_size=3, stride=2, padding=1)), - SwitchSequential(ResidualBlock(320, 640), AttentionBlock(8, 80)), - SwitchSequential(ResidualBlock(640, 640), AttentionBlock(8, 80)), - SwitchSequential(nn.Conv2d(640, 640, kernel_size=3, stride=2, padding=1)), - SwitchSequential(ResidualBlock(640, 1280), AttentionBlock(8, 160)), - SwitchSequential(ResidualBlock(1280, 1280), AttentionBlock(8, 160)), - SwitchSequential(nn.Conv2d(1280, 1280, kernel_size=3, stride=2, padding=1)), - SwitchSequential(ResidualBlock(1280, 1280)), - SwitchSequential(ResidualBlock(1280, 1280)), - ] + self.conv_in = nn.Conv2d( + config.in_channels, block_out_channels[0], kernel_size=3, padding=1 ) - self.bottleneck = SwitchSequential( - ResidualBlock(1280, 1280), - AttentionBlock(8, 160), - ResidualBlock(1280, 1280), + + attention_config = layers_cfg.AttentionConfig( + num_heads=config.transformer_num_attention_heads, + num_query_groups=config.transformer_num_attention_heads, + rotary_percentage=0.0, + qkv_transpose_before_split=True, + qkv_use_bias=False, + output_proj_use_bias=True, + enable_kv_cache=False, ) - self.decoders = nn.ModuleList( - [ - SwitchSequential(ResidualBlock(2560, 1280)), - SwitchSequential(ResidualBlock(2560, 1280)), - SwitchSequential(ResidualBlock(2560, 1280), Upsample(1280)), - SwitchSequential(ResidualBlock(2560, 1280), AttentionBlock(8, 160)), - SwitchSequential(ResidualBlock(2560, 1280), AttentionBlock(8, 160)), - SwitchSequential( - ResidualBlock(1920, 1280), AttentionBlock(8, 160), Upsample(1280) + 
# Down encoders. + down_encoders = [] + output_channel = block_out_channels[0] + for i, block_out_channel in enumerate(block_out_channels): + input_channel = output_channel + output_channel = block_out_channel + not_final_block = i < len(block_out_channels) - 1 + if not_final_block: + down_encoders.append( + blocks_2d.DownEncoderBlock2D( + unet_cfg.DownEncoderBlock2DConfig( + in_channels=input_channel, + out_channels=output_channel, + normalization_config=config.residual_norm_config, + activation_config=layers_cfg.ActivationConfig( + config.residual_activation_type + ), + num_layers=config.layers_per_block, + padding=config.downsample_padding, + time_embedding_channels=time_embedding_blocks_dim, + add_downsample=True, + sampling_config=unet_cfg.DownSamplingConfig( + mode=unet_cfg.SamplingType.CONVOLUTION, + in_channels=output_channel, + out_channels=output_channel, + kernel_size=3, + stride=2, + padding=config.downsample_padding, + ), + transformer_block_config=unet_cfg.TransformerBlock2Dconfig( + attention_block_config=unet_cfg.AttentionBlock2DConfig( + dim=output_channel, + attention_batch_size=config.transformer_batch_size, + normalization_config=config.transformer_norm_config, + attention_config=attention_config, + ), + cross_attention_block_config=unet_cfg.CrossAttentionBlock2DConfig( + query_dim=output_channel, + cross_dim=config.transformer_cross_attention_dim, + attention_batch_size=config.transformer_batch_size, + normalization_config=config.transformer_norm_config, + attention_config=attention_config, + ), + pre_conv_normalization_config=config.transformer_pre_conv_norm_config, + feed_forward_block_config=unet_cfg.FeedForwardBlock2DConfig( + dim=output_channel, + hidden_dim=output_channel * 4, + normalization_config=config.transformer_norm_config, + activation_config=layers_cfg.ActivationConfig( + type=config.transformer_ff_activation_type, + dim_in=output_channel, + dim_out=output_channel * 4, + ), + use_bias=True, + ), + ), + ) + ) + ) + else: + down_encoders.append( + blocks_2d.DownEncoderBlock2D( + unet_cfg.DownEncoderBlock2DConfig( + in_channels=input_channel, + out_channels=output_channel, + normalization_config=config.residual_norm_config, + activation_config=layers_cfg.ActivationConfig( + config.residual_activation_type + ), + num_layers=config.layers_per_block, + padding=config.downsample_padding, + time_embedding_channels=time_embedding_blocks_dim, + add_downsample=False, + ) + ) + ) + self.down_encoders = nn.ModuleList(down_encoders) + + # Mid block. 
+ mid_block_channels = block_out_channels[-1] + self.mid_block = blocks_2d.MidBlock2D( + unet_cfg.MidBlock2DConfig( + in_channels=block_out_channels[-1], + normalization_config=config.residual_norm_config, + activation_config=layers_cfg.ActivationConfig( + config.residual_activation_type ), - SwitchSequential(ResidualBlock(1920, 640), AttentionBlock(8, 80)), - SwitchSequential(ResidualBlock(1280, 640), AttentionBlock(8, 80)), - SwitchSequential( - ResidualBlock(960, 640), AttentionBlock(8, 80), Upsample(640) + num_layers=config.mid_block_layers, + time_embedding_channels=config.time_embedding_blocks_dim, + transformer_block_config=unet_cfg.TransformerBlock2Dconfig( + attention_block_config=unet_cfg.AttentionBlock2DConfig( + dim=mid_block_channels, + attention_batch_size=config.transformer_batch_size, + normalization_config=config.transformer_norm_config, + attention_config=attention_config, + ), + cross_attention_block_config=unet_cfg.CrossAttentionBlock2DConfig( + query_dim=mid_block_channels, + cross_dim=config.transformer_cross_attention_dim, + attention_batch_size=config.transformer_batch_size, + normalization_config=config.transformer_norm_config, + attention_config=attention_config, + ), + pre_conv_normalization_config=config.transformer_pre_conv_norm_config, + feed_forward_block_config=unet_cfg.FeedForwardBlock2DConfig( + dim=mid_block_channels, + hidden_dim=mid_block_channels * 4, + normalization_config=config.transformer_norm_config, + activation_config=layers_cfg.ActivationConfig( + type=config.transformer_ff_activation_type, + dim_in=mid_block_channels, + dim_out=mid_block_channels * 4, + ), + use_bias=True, + ), ), - SwitchSequential(ResidualBlock(960, 320), AttentionBlock(8, 40)), - SwitchSequential(ResidualBlock(640, 320), AttentionBlock(8, 40)), - SwitchSequential(ResidualBlock(640, 320), AttentionBlock(8, 40)), - ] + ) ) - def forward(self, x, context, time): - skip_connections = [] - for layers in self.encoders: - x = layers(x, context, time) - skip_connections.append(x) - - x = self.bottleneck(x, context, time) - - for layers in self.decoders: - x = torch.cat((x, skip_connections.pop()), dim=1) - x = layers(x, context, time) - - return x - - -class FinalLayer(nn.Module): - - def __init__(self, in_channels, out_channels): - super().__init__() - self.groupnorm = nn.GroupNorm(32, in_channels) - self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1) - - def forward(self, x): - x = self.groupnorm(x) - x = F.silu(x) - x = self.conv(x) - return x - - -class Diffusion(nn.Module): - - def __init__(self): - super().__init__() - self.time_embedding = TimeEmbedding(320) - self.unet = UNet() - self.final = FinalLayer(320, 4) + # Up decoders. 
+ up_decoders = [] + up_decoder_layers_per_block = config.layers_per_block + 1 + output_channel = reversed_block_out_channels[0] + for i, block_out_channel in enumerate(reversed_block_out_channels): + prev_out_channel = output_channel + output_channel = block_out_channel + input_channel = reversed_block_out_channels[ + min(i + 1, len(reversed_block_out_channels) - 1) + ] + not_final_block = i < len(reversed_block_out_channels) - 1 + not_first_block = i != 0 + if not_first_block: + up_decoders.append( + blocks_2d.SkipUpDecoderBlock2D( + unet_cfg.SkipUpDecoderBlock2DConfig( + in_channels=input_channel, + out_channels=output_channel, + prev_out_channels=prev_out_channel, + normalization_config=config.residual_norm_config, + activation_config=layers_cfg.ActivationConfig( + config.residual_activation_type + ), + num_layers=up_decoder_layers_per_block, + time_embedding_channels=time_embedding_blocks_dim, + add_upsample=not_final_block, + upsample_conv=True, + sampling_config=unet_cfg.UpSamplingConfig( + mode=unet_cfg.SamplingType.NEAREST, + scale_factor=2, + ), + transformer_block_config=unet_cfg.TransformerBlock2Dconfig( + attention_block_config=unet_cfg.AttentionBlock2DConfig( + dim=output_channel, + attention_batch_size=config.transformer_batch_size, + normalization_config=config.transformer_norm_config, + attention_config=attention_config, + ), + cross_attention_block_config=unet_cfg.CrossAttentionBlock2DConfig( + query_dim=output_channel, + cross_dim=config.transformer_cross_attention_dim, + attention_batch_size=config.transformer_batch_size, + normalization_config=config.transformer_norm_config, + attention_config=attention_config, + ), + pre_conv_normalization_config=config.transformer_pre_conv_norm_config, + feed_forward_block_config=unet_cfg.FeedForwardBlock2DConfig( + dim=output_channel, + hidden_dim=output_channel * 4, + normalization_config=config.transformer_norm_config, + activation_config=layers_cfg.ActivationConfig( + type=config.transformer_ff_activation_type, + dim_in=output_channel, + dim_out=output_channel * 4, + ), + use_bias=True, + ), + ), + ) + ) + ) + else: + up_decoders.append( + blocks_2d.SkipUpDecoderBlock2D( + unet_cfg.SkipUpDecoderBlock2DConfig( + in_channels=input_channel, + out_channels=output_channel, + prev_out_channels=prev_out_channel, + normalization_config=config.residual_norm_config, + activation_config=layers_cfg.ActivationConfig( + config.residual_activation_type + ), + num_layers=up_decoder_layers_per_block, + time_embedding_channels=time_embedding_blocks_dim, + add_upsample=not_final_block, + upsample_conv=True, + sampling_config=unet_cfg.UpSamplingConfig( + mode=unet_cfg.SamplingType.NEAREST, scale_factor=2 + ), + ) + ) + ) + self.up_decoders = nn.ModuleList(up_decoders) + + self.final_norm = layers_builder.build_norm( + reversed_block_out_channels[-1], config.final_norm_config + ) + self.final_act = layers_builder.get_activation( + layers_cfg.ActivationConfig(config.final_activation_type) + ) + self.conv_out = nn.Conv2d( + reversed_block_out_channels[-1], config.out_channels, kernel_size=3, padding=1 + ) @torch.inference_mode - def forward(self, latent, context, time): - time = self.time_embedding(time) - output = self.unet(latent, context, time) - output = self.final(output) - return output + def forward( + self, latents: torch.Tensor, context: torch.Tensor, time_emb: torch.Tensor + ) -> torch.Tensor: + """Forward function of diffusion model. + + Args: + latents (torch.Tensor): latents space tensor. 
+      context (torch.Tensor): context tensor from CLIP text encoder.
+      time_emb (torch.Tensor): the time embedding tensor.
+
+    Returns:
+      output latents from diffusion model.
+    """
+    time_emb = self.time_embedding(time_emb)
+    x = self.conv_in(latents)
+    skip_connection_tensors = [x]
+    for encoder in self.down_encoders:
+      x, hidden_states = encoder(x, time_emb, context, output_hidden_states=True)
+      skip_connection_tensors.extend(hidden_states)
+    x = self.mid_block(x, time_emb, context)
+    for decoder in self.up_decoders:
+      encoder_tensors = [
+          skip_connection_tensors.pop() for i in range(self.config.layers_per_block + 1)
+      ]
+      x = decoder(x, encoder_tensors, time_emb, context)
+    x = self.final_norm(x)
+    x = self.final_act(x)
+    x = self.conv_out(x)
+    return x
-if __name__ == "__main__":
-  diffusion = Diffusion()
-  print(diffusion.state_dict().keys())
+def get_model_config(batch_size: int) -> unet_cfg.DiffusionModelConfig:
+  """Get configs for the Diffusion model of Stable Diffusion v1.5.
+
+  Args:
+    batch_size (int): the batch size of input.
+
+  Returns:
+    The configuration of diffusion model of Stable Diffusion v1.5.
+
+  """
+  in_channels = 4
+  out_channels = 4
+  block_out_channels = [320, 640, 1280, 1280]
+  layers_per_block = 2
+  downsample_padding = 1
+
+  # Residual configs.
+  residual_norm_config = layers_cfg.NormalizationConfig(
+      layers_cfg.NormalizationType.GROUP_NORM, group_num=32
+  )
+  residual_activation_type = layers_cfg.ActivationType.SILU
+
+  # Transformer configs.
+  transformer_num_attention_heads = 8
+  transformer_batch_size = batch_size
+  transformer_cross_attention_dim = 768  # Embedding from CLIP model
+  transformer_pre_conv_norm_config = layers_cfg.NormalizationConfig(
+      layers_cfg.NormalizationType.GROUP_NORM, epsilon=1e-6, group_num=32
+  )
+  transformer_norm_config = layers_cfg.NormalizationConfig(
+      layers_cfg.NormalizationType.LAYER_NORM
+  )
+  transformer_ff_activation_type = layers_cfg.ActivationType.GE_GLU
+
+  # Time embedding configs.
+  time_embedding_dim = 320
+  time_embedding_blocks_dim = 1280
+
+  # Mid block configs.
+  mid_block_layers = 1
+
+  # Final layer configs.
+  final_norm_config = layers_cfg.NormalizationConfig(
+      layers_cfg.NormalizationType.GROUP_NORM, group_num=32
+  )
+  final_activation_type = layers_cfg.ActivationType.SILU
+
+  return unet_cfg.DiffusionModelConfig(
+      in_channels=in_channels,
+      out_channels=out_channels,
+      block_out_channels=block_out_channels,
+      layers_per_block=layers_per_block,
+      downsample_padding=downsample_padding,
+      residual_norm_config=residual_norm_config,
+      residual_activation_type=residual_activation_type,
+      transformer_batch_size=transformer_batch_size,
+      transformer_num_attention_heads=transformer_num_attention_heads,
+      transformer_cross_attention_dim=transformer_cross_attention_dim,
+      transformer_pre_conv_norm_config=transformer_pre_conv_norm_config,
+      transformer_norm_config=transformer_norm_config,
+      transformer_ff_activation_type=transformer_ff_activation_type,
+      mid_block_layers=mid_block_layers,
+      time_embedding_dim=time_embedding_dim,
+      time_embedding_blocks_dim=time_embedding_blocks_dim,
+      final_norm_config=final_norm_config,
+      final_activation_type=final_activation_type,
+  )
diff --git a/ai_edge_torch/generative/layers/attention.py b/ai_edge_torch/generative/layers/attention.py
index 4dab1b8a..2456cb7a 100644
--- a/ai_edge_torch/generative/layers/attention.py
+++ b/ai_edge_torch/generative/layers/attention.py
@@ -244,7 +244,10 @@ def forward(
     """
     B, T, _ = x.size()
     return super().forward(
-        x, rope=rope, mask=torch.zeros((B, T), dtype=torch.float32), input_pos=input_pos
+        x,
+        rope=rope,
+        mask=torch.zeros((B, 1, T, T), dtype=torch.float32),
+        input_pos=input_pos,
     )
 
@@ -304,20 +307,31 @@ def forward(
       mask: Optional[torch.Tensor] = None,
       input_pos: Optional[torch.Tensor] = None,
   ):
-    B, T, E = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)
+    """Forward function of the CrossAttention layer.
+
+    Args:
+      x (torch.Tensor): the target tensor, with shape [B, target_seq_len, ...].
+      y (torch.Tensor): the source tensor, with shape [B, source_seq_len, ...].
+      rope (Tuple[torch.Tensor, torch.Tensor]): the optional input rope tensor.
+      mask (torch.Tensor): the optional mask tensor that can be broadcast to shape [B, n_heads, target_seq_len, source_seq_len].
+      input_pos (torch.Tensor): the optional input position tensor.
+
+    Returns:
+      output activation from this cross attention layer.
+    """
+    batch_size = x.size()[0]
+    target_seq_len = x.size()[1]
+    source_seq_len = y.size()[1]
     q = self.q_projection(x)
-    k = self.k_projection(x)
-    v = self.v_projection(x)
+    k = self.k_projection(y)
+    v = self.v_projection(y)
 
-    interim_shape = (B, T, self.n_heads, self.head_dim)
+    interim_shape = (batch_size, -1, self.n_heads, self.head_dim)
     q = q.view(interim_shape)
     k = k.view(interim_shape)
     v = v.view(interim_shape)
 
-    if mask is None:
-      mask = torch.zeros((B, T), dtype=torch.float32)
-
     # Compute rotary positional embedding for query and key.
     n_elem = int(self.config.rotary_percentage * self.head_dim)
     q, k = _embed_rope(q, k, n_elem, rope)
@@ -325,9 +339,12 @@ def forward(
     if self.kv_cache is not None:
       # TODO(haoliang): Handle when execeeding max sequence length.
       k, v = self.kv_cache.update_cache(input_pos, k, v)
-
+    if mask is None:
+      mask = torch.zeros(
+          (batch_size, 1, target_seq_len, source_seq_len), dtype=torch.float32
+      )
     y = self.sdpa_func(q, k, v, self.head_dim, mask=mask)
-    y = y.reshape(B, T, E)
+    y = y.reshape(batch_size, target_seq_len, -1)
 
     # Compute the output projection.
     y = self.output_projection(y)
diff --git a/ai_edge_torch/generative/layers/unet/blocks_2d.py b/ai_edge_torch/generative/layers/unet/blocks_2d.py
index 8ccb9d16..ff038e11 100644
--- a/ai_edge_torch/generative/layers/unet/blocks_2d.py
+++ b/ai_edge_torch/generative/layers/unet/blocks_2d.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 
-from typing import List, Optional
+from typing import List, Optional, Tuple
 
 import torch
 from torch import nn
@@ -93,8 +93,7 @@ def forward(
 class AttentionBlock2D(nn.Module):
   """2D self attention block
 
-  residual = x
-  x = SelfAttention(Norm(input_tensor)) + residual
+  x = SelfAttention(Norm(input_tensor)) + x
 
   """
 
@@ -105,8 +104,15 @@ def __init__(self, config: unet_cfg.AttentionBlock2DConfig):
       config (unet_cfg.AttentionBlock2DConfig): the configuration of this block.
     """
     super().__init__()
+    self.config = config
     self.norm = layers_builder.build_norm(config.dim, config.normalization_config)
-    self.attention = SelfAttention(config.dim, config.attention_config, 0, True)
+    self.attention = SelfAttention(
+        config.attention_batch_size,
+        config.dim,
+        config.attention_config,
+        0,
+        enable_hlfb=config.enable_hlfb,
+    )
 
   def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
     """Forward function of the AttentionBlock2D.
@@ -118,10 +124,16 @@ def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
       output activation tensor after self attention.
     """
     residual = input_tensor
-    x = self.norm(input_tensor)
-    B, C, H, W = x.shape
-    x = x.view(B, C, H * W)
-    x = x.transpose(-1, -2)
+    B, C, H, W = input_tensor.shape
+    x = input_tensor
+    if self.config.normalization_config.type == layers_cfg.NormalizationType.GROUP_NORM:
+      x = self.norm(x)
+      x = x.view(B, C, H * W)
+      x = x.transpose(-1, -2)
+    else:
+      x = input_tensor.view(B, C, H * W)
+      x = x.transpose(-1, -2)
+      x = self.norm(x)
     x = self.attention(x)
     x = x.transpose(-1, -2)
     x = x.view(B, C, H, W)
@@ -132,8 +144,7 @@ def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
 class CrossAttentionBlock2D(nn.Module):
   """2D cross attention block
 
-  residual = x
-  x = CrossAttention(Norm(input_tensor), context) + residual
+  x = CrossAttention(Norm(input_tensor), context) + x
 
   """
 
@@ -147,7 +158,12 @@ def __init__(self, config: unet_cfg.CrossAttentionBlock2DConfig):
     self.config = config
     self.norm = layers_builder.build_norm(config.query_dim, config.normalization_config)
     self.attention = CrossAttention(
-        config.query_dim, config.cross_dim, config.attention_config, 0, True
+        config.attention_batch_size,
+        config.query_dim,
+        config.cross_dim,
+        config.attention_config,
+        0,
+        enable_hlfb=config.enable_hlfb,
     )
 
   def forward(
@@ -163,10 +179,16 @@ def forward(
       output activation tensor after cross attention.
""" residual = input_tensor - x = self.norm(input_tensor) - B, C, H, W = x.shape - x = x.view(B, C, H * W) - x = x.transpose(-1, -2) + B, C, H, W = input_tensor.shape + x = input_tensor + if self.config.normalization_config.type == layers_cfg.NormalizationType.GROUP_NORM: + x = self.norm(x) + x = input_tensor.view(B, C, H * W) + x = x.transpose(-1, -2) + else: + x = input_tensor.view(B, C, H * W) + x = x.transpose(-1, -2) + x = self.norm(x) x = self.attention(x, context_tensor) x = x.transpose(-1, -2) x = x.view(B, C, H, W) @@ -177,8 +199,7 @@ def forward( class FeedForwardBlock2D(nn.Module): """2D feed forward block - residual = x - x = w2(Activation(w1(Norm(x)))) + residual + x = w2(Activation(w1(Norm(x)))) + x """ @@ -197,14 +218,18 @@ def __init__( self.w1 = nn.Linear(config.dim, config.hidden_dim) self.w2 = nn.Linear(config.hidden_dim, config.dim) - def forward(self, x: torch.Tensor) -> torch.Tensor: - residual = x - - B, C, H, W = x.shape - x = x.view((B, C, H * W)) - x = x.transpose(-1, -2) # (B, HW, C) - - x = self.norm(x) + def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: + residual = input_tensor + B, C, H, W = input_tensor.shape + x = input_tensor + if self.config.normalization_config.type == layers_cfg.NormalizationType.GROUP_NORM: + x = self.norm(x) + x = input_tensor.view(B, C, H * W) + x = x.transpose(-1, -2) + else: + x = input_tensor.view(B, C, H * W) + x = x.transpose(-1, -2) + x = self.norm(x) x = self.w1(x) x = self.act(x) x = self.w2(x) @@ -287,7 +312,6 @@ def forward(self, x: torch.Tensor, context: torch.Tensor): x = self.pre_conv_norm(x) x = self.conv_in(x) - x = self.self_attention(x) x = self.cross_attention(x, context) x = self.feed_forward(x) @@ -360,7 +384,8 @@ def forward( input_tensor: torch.Tensor, time_emb: Optional[torch.Tensor] = None, context_tensor: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + output_hidden_states: bool = False, + ) -> torch.Tensor | Tuple[torch.Tensor, List[torch.Tensor]]: """Forward function of the DownEncoderBlock2D. Args: @@ -368,18 +393,24 @@ def forward( time_emb (torch.Tensor): optional time embedding tensor, if the block is configured to accept time embedding. context_tensor (torch.Tensor): optional context tensor, if the block if configured to use transofrmer block. - + output_hidden_states (bool): whether to output hidden states, usually for skip connections. Returns: output hidden_states tensor after DownEncoderBlock2D. 
""" hidden_states = input_tensor - for resnet, transformer in zip(self.resnets, self.transformers): + output_states = [] + for i, resnet in enumerate(self.resnets): hidden_states = resnet(hidden_states, time_emb) - if transformer is not None: - hidden_states = transformer(hidden_states, context_tensor) + if self.transformers is not None: + hidden_states = self.transformers[i](hidden_states, context_tensor) + output_states.append(hidden_states) if self.downsampler: hidden_states = self.downsampler(hidden_states) - return hidden_states + output_states.append(hidden_states) + if output_hidden_states: + return hidden_states, output_states + else: + return hidden_states class UpDecoderBlock2D(nn.Module): @@ -569,7 +600,7 @@ def forward( for i, (resnet, skip_connection_tensor) in enumerate( zip(self.resnets, skip_connection_tensors) ): - hidden_states = torch.cat([resnet, skip_connection_tensor], dim=1) + hidden_states = torch.cat([hidden_states, skip_connection_tensor], dim=1) hidden_states = resnet(hidden_states, time_emb) if self.transformers is not None: hidden_states = self.transformers[i](hidden_states, context_tensor) diff --git a/ai_edge_torch/generative/layers/unet/model_config.py b/ai_edge_torch/generative/layers/unet/model_config.py index 0e0e2ca5..85bfa22b 100644 --- a/ai_edge_torch/generative/layers/unet/model_config.py +++ b/ai_edge_torch/generative/layers/unet/model_config.py @@ -61,6 +61,8 @@ class AttentionBlock2DConfig: dim: int normalization_config: layers_cfg.NormalizationConfig attention_config: layers_cfg.AttentionConfig + enable_hlfb: bool = True + attention_batch_size: int = 1 @dataclass @@ -69,6 +71,8 @@ class CrossAttentionBlock2DConfig: cross_dim: int normalization_config: layers_cfg.NormalizationConfig attention_config: layers_cfg.AttentionConfig + enable_hlfb: bool = True + attention_batch_size: int = 1 @dataclass @@ -204,3 +208,62 @@ class AutoEncoderConfig: # The configuration of middle blocks, that is, after the last block of encoder and before the first block of decoder. mid_block_config: MidBlock2DConfig + + +@dataclass +class DiffusionModelConfig: + """Configurations of Diffusion model.""" + + # Number of channels in the input tensor. + in_channels: int + + # Number of channels in the output tensor. + out_channels: int + + # The output channels of each block. + block_out_channels: List[int] + + # The layesr number of each block. + layers_per_block: int + + # The padding to use for the downsampling. + downsample_padding: int + + # Normalization config used in residual blocks. + residual_norm_config: layers_cfg.NormalizationConfig + + # Activation config used in residual blocks + residual_activation_type: layers_cfg.ActivationType + + # The batch size used in transformer blocks, for attention layers. + transformer_batch_size: int + + # The number of attention heads used in transformer blocks. + transformer_num_attention_heads: int + + # The dimension of cross attention used in transformer blocks. + transformer_cross_attention_dim: int + + # Normalization config used in prev conv layer of transformer blocks. + transformer_pre_conv_norm_config: layers_cfg.NormalizationConfig + + # Normalization config used in transformer blocks. + transformer_norm_config: layers_cfg.NormalizationConfig + + # Activation type of feed forward used in transformer blocks. + transformer_ff_activation_type: layers_cfg.ActivationType + + # Number of layers in mid block. + mid_block_layers: int + + # Dimension of time embedding. 
+ time_embedding_dim: int + + # Time embedding dimensions for blocks. + time_embedding_blocks_dim: int + + # Normalization config used for final layer + final_norm_config: layers_cfg.NormalizationConfig + + # Activation type used in final layer + final_activation_type: layers_cfg.ActivationType diff --git a/ai_edge_torch/generative/utilities/autoencoder_loader.py b/ai_edge_torch/generative/utilities/autoencoder_loader.py deleted file mode 100644 index e39f2876..00000000 --- a/ai_edge_torch/generative/utilities/autoencoder_loader.py +++ /dev/null @@ -1,298 +0,0 @@ -# Copyright 2024 The AI Edge Torch Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -# Common utility functions for data loading etc. -from dataclasses import dataclass -from typing import Dict, List, Tuple - -import torch - -import ai_edge_torch.generative.layers.model_config as layers_config -import ai_edge_torch.generative.layers.unet.model_config as unet_config -import ai_edge_torch.generative.utilities.loader as loader - - -@dataclass -class ResidualBlockTensorNames: - norm_1: str = None - conv_1: str = None - norm_2: str = None - conv_2: str = None - residual_layer: str = None - - -@dataclass -class AttnetionBlockTensorNames: - norm: str = None - fused_qkv_proj: str = None - output_proj: str = None - - -@dataclass -class MidBlockTensorNames: - residual_block_tensor_names: List[ResidualBlockTensorNames] - attention_block_tensor_names: List[AttnetionBlockTensorNames] - - -@dataclass -class UpDecoderBlockTensorNames: - residual_block_tensor_names: List[ResidualBlockTensorNames] - upsample_conv: str = None - - -def _map_to_converted_state( - state: Dict[str, torch.Tensor], - state_param: str, - converted_state: Dict[str, torch.Tensor], - converted_state_param: str, -): - converted_state[f"{converted_state_param}.weight"] = state.pop( - f"{state_param}.weight" - ) - if f"{state_param}.bias" in state: - converted_state[f"{converted_state_param}.bias"] = state.pop(f"{state_param}.bias") - - -class AutoEncoderModelLoader(loader.ModelLoader): - - @dataclass - class TensorNames: - quant_conv: str = None - post_quant_conv: str = None - conv_in: str = None - conv_out: str = None - final_norm: str = None - mid_block_tensor_names: MidBlockTensorNames = None - up_decoder_blocks_tensor_names: List[UpDecoderBlockTensorNames] = None - - def __init__(self, file_name: str, names: TensorNames): - """AutoEncoderModelLoader constructor. Can be used to load encoder and decoder models. - - Args: - file_name (str): Path to the checkpoint. Can be a directory or an - exact file. - names (TensorNames): An instance of `TensorNames` to determine mappings. - """ - self._file_name = file_name - self._names = names - self._loader = self._get_loader() - - def load( - self, model: torch.nn.Module, strict: bool = True - ) -> Tuple[List[str], List[str]]: - """Load the model from the checkpoint. - - Args: - model (torch.nn.Module): The pytorch model that needs to be loaded. 
- strict (bool, optional): Whether the converted keys are strictly - matched. Defaults to True. - - Returns: - missing_keys (List[str]): a list of str containing the missing keys. - unexpected_keys (List[str]): a list of str containing the unexpected keys. - - Raises: - ValueError: If conversion results in unmapped tensors and strict mode is - enabled. - """ - state = self._loader(self._file_name) - converted_state = dict() - if self._names.quant_conv is not None: - _map_to_converted_state( - state, self._names.quant_conv, converted_state, "quant_conv" - ) - if self._names.post_quant_conv is not None: - _map_to_converted_state( - state, self._names.post_quant_conv, converted_state, "post_quant_conv" - ) - if self._names.conv_in is not None: - _map_to_converted_state(state, self._names.conv_in, converted_state, "conv_in") - if self._names.conv_out is not None: - _map_to_converted_state(state, self._names.conv_out, converted_state, "conv_out") - if self._names.final_norm is not None: - _map_to_converted_state( - state, self._names.final_norm, converted_state, "final_norm" - ) - self._map_mid_block( - state, - converted_state, - model.config.mid_block_config, - self._names.mid_block_tensor_names, - ) - - reversed_block_out_channels = list(reversed(model.config.block_out_channels)) - block_out_channels = reversed_block_out_channels[0] - for i, out_channels in enumerate(reversed_block_out_channels): - prev_output_channel = block_out_channels - block_out_channels = out_channels - not_final_block = i < len(reversed_block_out_channels) - 1 - self._map_up_decoder_block( - state, - converted_state, - f"up_decoder_blocks.{i}", - unet_config.UpDecoderBlock2DConfig( - in_channels=prev_output_channel, - out_channels=block_out_channels, - normalization_config=model.config.normalization_config, - activation_config=model.config.activation_config, - num_layers=model.config.layers_per_block, - add_upsample=not_final_block, - upsample_conv=True, - ), - self._names.up_decoder_blocks_tensor_names[i], - ) - if strict and state: - raise ValueError( - f"Failed to map all tensor. 
Remaing tensor are: {list(state.keys())}" - ) - return model.load_state_dict(converted_state, strict=strict) - - def _map_residual_block( - self, - state: Dict[str, torch.Tensor], - converted_state: Dict[str, torch.Tensor], - tensor_names: ResidualBlockTensorNames, - converted_state_param_prefix: str, - config: unet_config.ResidualBlock2DConfig, - ): - _map_to_converted_state( - state, - tensor_names.norm_1, - converted_state, - f"{converted_state_param_prefix}.norm_1", - ) - _map_to_converted_state( - state, - tensor_names.conv_1, - converted_state, - f"{converted_state_param_prefix}.conv_1", - ) - _map_to_converted_state( - state, - tensor_names.norm_2, - converted_state, - f"{converted_state_param_prefix}.norm_2", - ) - _map_to_converted_state( - state, - tensor_names.conv_2, - converted_state, - f"{converted_state_param_prefix}.conv_2", - ) - if config.in_channels != config.out_channels: - _map_to_converted_state( - state, - tensor_names.residual_layer, - converted_state, - f"{converted_state_param_prefix}.residual_layer", - ) - - def _map_attention_block( - self, - state: Dict[str, torch.Tensor], - converted_state: Dict[str, torch.Tensor], - tensor_names: AttnetionBlockTensorNames, - converted_state_param_prefix: str, - config: unet_config.AttentionBlock2DConfig, - ): - if config.normalization_config.type != layers_config.NormalizationType.NONE: - _map_to_converted_state( - state, - tensor_names.norm, - converted_state, - f"{converted_state_param_prefix}.norm", - ) - attention_layer_prefix = f"{converted_state_param_prefix}.attention" - _map_to_converted_state( - state, - tensor_names.fused_qkv_proj, - converted_state, - f"{attention_layer_prefix}.qkv_projection", - ) - _map_to_converted_state( - state, - tensor_names.output_proj, - converted_state, - f"{attention_layer_prefix}.output_projection", - ) - - def _map_mid_block( - self, - state: Dict[str, torch.Tensor], - converted_state: Dict[str, torch.Tensor], - config: unet_config.MidBlock2DConfig, - tensor_names: MidBlockTensorNames, - ): - converted_state_param_prefix = "mid_block" - residual_block_config = unet_config.ResidualBlock2DConfig( - in_channels=config.in_channels, - out_channels=config.in_channels, - time_embedding_channels=config.time_embedding_channels, - normalization_config=config.normalization_config, - activation_config=config.activation_config, - ) - self._map_residual_block( - state, - converted_state, - tensor_names.residual_block_tensor_names[0], - f"{converted_state_param_prefix}.resnets.0", - residual_block_config, - ) - for i in range(config.num_layers): - if config.attention_block_config: - self._map_attention_block( - state, - converted_state, - tensor_names.attention_block_tensor_names[i], - f"{converted_state_param_prefix}.attentions.{i}", - config.attention_block_config, - ) - self._map_residual_block( - state, - converted_state, - tensor_names.residual_block_tensor_names[i + 1], - f"{converted_state_param_prefix}.resnets.{i+1}", - residual_block_config, - ) - - def _map_up_decoder_block( - self, - state: Dict[str, torch.Tensor], - converted_state: Dict[str, torch.Tensor], - converted_state_param_prefix: str, - config: unet_config.UpDecoderBlock2DConfig, - tensor_names: UpDecoderBlockTensorNames, - ): - for i in range(config.num_layers): - input_channels = config.in_channels if i == 0 else config.out_channels - self._map_residual_block( - state, - converted_state, - tensor_names.residual_block_tensor_names[i], - f"{converted_state_param_prefix}.resnets.{i}", - unet_config.ResidualBlock2DConfig( - 
in_channels=input_channels, - out_channels=config.out_channels, - time_embedding_channels=config.time_embedding_channels, - normalization_config=config.normalization_config, - activation_config=config.activation_config, - ), - ) - if config.add_upsample and config.upsample_conv: - _map_to_converted_state( - state, - tensor_names.upsample_conv, - converted_state, - f"{converted_state_param_prefix}.upsample_conv", - ) diff --git a/ai_edge_torch/generative/utilities/stable_diffusion_loader.py b/ai_edge_torch/generative/utilities/stable_diffusion_loader.py new file mode 100644 index 00000000..aca42fa1 --- /dev/null +++ b/ai_edge_torch/generative/utilities/stable_diffusion_loader.py @@ -0,0 +1,860 @@ +# Copyright 2024 The AI Edge Torch Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Common utility functions for data loading etc. +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple + +import torch + +import ai_edge_torch.generative.layers.model_config as layers_config +import ai_edge_torch.generative.layers.unet.model_config as unet_config +import ai_edge_torch.generative.utilities.loader as loader + + +@dataclass +class ResidualBlockTensorNames: + norm_1: str = None + conv_1: str = None + norm_2: str = None + conv_2: str = None + residual_layer: str = None + time_embedding: str = None + + +@dataclass +class AttentionBlockTensorNames: + norm: str = None + fused_qkv_proj: str = None + output_proj: str = None + + +@dataclass +class CrossAttentionBlockTensorNames: + norm: str = None + q_proj: str = None + k_proj: str = None + v_proj: str = None + output_proj: str = None + + +@dataclass +class TimeEmbeddingTensorNames: + w1: str = None + w2: str = None + + +@dataclass +class FeedForwardBlockTensorNames: + w1: str = None + w2: str = None + norm: str = None + ge_glu: str = None + + +@dataclass +class TransformerBlockTensorNames: + pre_conv_norm: str + conv_in: str + self_attention: AttentionBlockTensorNames + cross_attention: CrossAttentionBlockTensorNames + feed_forward: FeedForwardBlockTensorNames + conv_out: str + + +@dataclass +class MidBlockTensorNames: + residual_block_tensor_names: List[ResidualBlockTensorNames] + attention_block_tensor_names: Optional[List[AttentionBlockTensorNames]] = None + transformer_block_tensor_names: Optional[List[TransformerBlockTensorNames]] = None + + +@dataclass +class DownEncoderBlockTensorNames: + residual_block_tensor_names: List[ResidualBlockTensorNames] + transformer_block_tensor_names: Optional[List[TransformerBlockTensorNames]] = None + downsample_conv: str = None + + +@dataclass +class UpDecoderBlockTensorNames: + residual_block_tensor_names: List[ResidualBlockTensorNames] + transformer_block_tensor_names: Optional[List[TransformerBlockTensorNames]] = None + upsample_conv: str = None + + +@dataclass +class SkipUpDecoderBlockTensorNames: + residual_block_tensor_names: List[ResidualBlockTensorNames] + transformer_block_tensor_names: 
Optional[List[TransformerBlockTensorNames]] = None + upsample_conv: str = None + + +def _map_to_converted_state( + state: Dict[str, torch.Tensor], + state_param: str, + converted_state: Dict[str, torch.Tensor], + converted_state_param: str, +): + converted_state[f"{converted_state_param}.weight"] = state.pop( + f"{state_param}.weight" + ) + if f"{state_param}.bias" in state: + converted_state[f"{converted_state_param}.bias"] = state.pop(f"{state_param}.bias") + + +class BaseLoader(loader.ModelLoader): + + def _map_residual_block( + self, + state: Dict[str, torch.Tensor], + converted_state: Dict[str, torch.Tensor], + tensor_names: ResidualBlockTensorNames, + converted_state_param_prefix: str, + config: unet_config.ResidualBlock2DConfig, + ): + _map_to_converted_state( + state, + tensor_names.norm_1, + converted_state, + f"{converted_state_param_prefix}.norm_1", + ) + _map_to_converted_state( + state, + tensor_names.conv_1, + converted_state, + f"{converted_state_param_prefix}.conv_1", + ) + _map_to_converted_state( + state, + tensor_names.norm_2, + converted_state, + f"{converted_state_param_prefix}.norm_2", + ) + _map_to_converted_state( + state, + tensor_names.conv_2, + converted_state, + f"{converted_state_param_prefix}.conv_2", + ) + if config.in_channels != config.out_channels: + _map_to_converted_state( + state, + tensor_names.residual_layer, + converted_state, + f"{converted_state_param_prefix}.residual_layer", + ) + if config.time_embedding_channels is not None: + _map_to_converted_state( + state, + tensor_names.time_embedding, + converted_state, + f"{converted_state_param_prefix}.time_emb_proj", + ) + + def _map_attention_block( + self, + state: Dict[str, torch.Tensor], + converted_state: Dict[str, torch.Tensor], + tensor_names: AttentionBlockTensorNames, + converted_state_param_prefix: str, + config: unet_config.AttentionBlock2DConfig, + ): + if config.normalization_config.type != layers_config.NormalizationType.NONE: + _map_to_converted_state( + state, + tensor_names.norm, + converted_state, + f"{converted_state_param_prefix}.norm", + ) + attention_layer_prefix = f"{converted_state_param_prefix}.attention" + _map_to_converted_state( + state, + tensor_names.fused_qkv_proj, + converted_state, + f"{attention_layer_prefix}.qkv_projection", + ) + _map_to_converted_state( + state, + tensor_names.output_proj, + converted_state, + f"{attention_layer_prefix}.output_projection", + ) + + def _map_cross_attention_block( + self, + state: Dict[str, torch.Tensor], + converted_state: Dict[str, torch.Tensor], + tensor_names: CrossAttentionBlockTensorNames, + converted_state_param_prefix: str, + config: unet_config.CrossAttentionBlock2DConfig, + ): + if config.normalization_config.type != layers_config.NormalizationType.NONE: + _map_to_converted_state( + state, + tensor_names.norm, + converted_state, + f"{converted_state_param_prefix}.norm", + ) + attention_layer_prefix = f"{converted_state_param_prefix}.attention" + _map_to_converted_state( + state, + tensor_names.q_proj, + converted_state, + f"{attention_layer_prefix}.q_projection", + ) + _map_to_converted_state( + state, + tensor_names.k_proj, + converted_state, + f"{attention_layer_prefix}.k_projection", + ) + _map_to_converted_state( + state, + tensor_names.v_proj, + converted_state, + f"{attention_layer_prefix}.v_projection", + ) + _map_to_converted_state( + state, + tensor_names.output_proj, + converted_state, + f"{attention_layer_prefix}.output_projection", + ) + + def _map_feedforward_block( + self, + state: Dict[str, torch.Tensor], + 
converted_state: Dict[str, torch.Tensor], + tensor_names: FeedForwardBlockTensorNames, + converted_state_param_prefix: str, + config: unet_config.FeedForwardBlock2DConfig, + ): + _map_to_converted_state( + state, + tensor_names.norm, + converted_state, + f"{converted_state_param_prefix}.norm", + ) + if config.activation_config.type == layers_config.ActivationType.GE_GLU: + _map_to_converted_state( + state, + tensor_names.ge_glu, + converted_state, + f"{converted_state_param_prefix}.act.proj", + ) + else: + _map_to_converted_state( + state, tensor_names.w1, converted_state, f"{converted_state_param_prefix}.w1" + ) + + _map_to_converted_state( + state, tensor_names.w2, converted_state, f"{converted_state_param_prefix}.w2" + ) + + def _map_transformer_block( + self, + state: Dict[str, torch.Tensor], + converted_state: Dict[str, torch.Tensor], + tensor_names: TransformerBlockTensorNames, + converted_state_param_prefix: str, + config: unet_config.TransformerBlock2Dconfig, + ): + _map_to_converted_state( + state, + tensor_names.pre_conv_norm, + converted_state, + f"{converted_state_param_prefix}.pre_conv_norm", + ) + _map_to_converted_state( + state, + tensor_names.conv_in, + converted_state, + f"{converted_state_param_prefix}.conv_in", + ) + self._map_attention_block( + state, + converted_state, + tensor_names.self_attention, + f"{converted_state_param_prefix}.self_attention", + config.attention_block_config, + ) + self._map_cross_attention_block( + state, + converted_state, + tensor_names.cross_attention, + f"{converted_state_param_prefix}.cross_attention", + config.cross_attention_block_config, + ) + self._map_feedforward_block( + state, + converted_state, + tensor_names.feed_forward, + f"{converted_state_param_prefix}.feed_forward", + config.feed_forward_block_config, + ) + _map_to_converted_state( + state, + tensor_names.conv_out, + converted_state, + f"{converted_state_param_prefix}.conv_out", + ) + + def _map_mid_block( + self, + state: Dict[str, torch.Tensor], + converted_state: Dict[str, torch.Tensor], + tensor_names: MidBlockTensorNames, + converted_state_param_prefix: str, + config: unet_config.MidBlock2DConfig, + ): + residual_block_config = unet_config.ResidualBlock2DConfig( + in_channels=config.in_channels, + out_channels=config.in_channels, + time_embedding_channels=config.time_embedding_channels, + normalization_config=config.normalization_config, + activation_config=config.activation_config, + ) + self._map_residual_block( + state, + converted_state, + tensor_names.residual_block_tensor_names[0], + f"{converted_state_param_prefix}.resnets.0", + residual_block_config, + ) + for i in range(config.num_layers): + if config.attention_block_config: + self._map_attention_block( + state, + converted_state, + tensor_names.attention_block_tensor_names[i], + f"{converted_state_param_prefix}.attentions.{i}", + config.attention_block_config, + ) + if config.transformer_block_config: + self._map_transformer_block( + state, + converted_state, + tensor_names.transformer_block_tensor_names[i], + f"{converted_state_param_prefix}.transformers.{i}", + config.transformer_block_config, + ) + self._map_residual_block( + state, + converted_state, + tensor_names.residual_block_tensor_names[i + 1], + f"{converted_state_param_prefix}.resnets.{i+1}", + residual_block_config, + ) + + def _map_down_encoder_block( + self, + state: Dict[str, torch.Tensor], + converted_state: Dict[str, torch.Tensor], + converted_state_param_prefix: str, + config: unet_config.DownEncoderBlock2DConfig, + tensor_names: 
DownEncoderBlockTensorNames, + ): + for i in range(config.num_layers): + input_channels = config.in_channels if i == 0 else config.out_channels + self._map_residual_block( + state, + converted_state, + tensor_names.residual_block_tensor_names[i], + f"{converted_state_param_prefix}.resnets.{i}", + unet_config.ResidualBlock2DConfig( + in_channels=input_channels, + out_channels=config.out_channels, + time_embedding_channels=config.time_embedding_channels, + normalization_config=config.normalization_config, + activation_config=config.activation_config, + ), + ) + if config.transformer_block_config: + self._map_transformer_block( + state, + converted_state, + tensor_names.transformer_block_tensor_names[i], + f"{converted_state_param_prefix}.transformers.{i}", + config.transformer_block_config, + ) + if ( + config.add_downsample + and config.sampling_config.mode == unet_config.SamplingType.CONVOLUTION + ): + _map_to_converted_state( + state, + tensor_names.downsample_conv, + converted_state, + f"{converted_state_param_prefix}.downsampler", + ) + + def _map_up_decoder_block( + self, + state: Dict[str, torch.Tensor], + converted_state: Dict[str, torch.Tensor], + converted_state_param_prefix: str, + config: unet_config.UpDecoderBlock2DConfig, + tensor_names: UpDecoderBlockTensorNames, + ): + for i in range(config.num_layers): + input_channels = config.in_channels if i == 0 else config.out_channels + self._map_residual_block( + state, + converted_state, + tensor_names.residual_block_tensor_names[i], + f"{converted_state_param_prefix}.resnets.{i}", + unet_config.ResidualBlock2DConfig( + in_channels=input_channels, + out_channels=config.out_channels, + time_embedding_channels=config.time_embedding_channels, + normalization_config=config.normalization_config, + activation_config=config.activation_config, + ), + ) + if config.transformer_block_config: + self._map_transformer_block( + state, + converted_state, + tensor_names.transformer_block_tensor_names[i], + f"{converted_state_param_prefix}.transformers.{i}", + config.transformer_block_config, + ) + if config.add_upsample and config.upsample_conv: + _map_to_converted_state( + state, + tensor_names.upsample_conv, + converted_state, + f"{converted_state_param_prefix}.upsample_conv", + ) + + def _map_skip_up_decoder_block( + self, + state: Dict[str, torch.Tensor], + converted_state: Dict[str, torch.Tensor], + converted_state_param_prefix: str, + config: unet_config.SkipUpDecoderBlock2DConfig, + tensor_names: UpDecoderBlockTensorNames, + ): + for i in range(config.num_layers): + res_skip_channels = ( + config.in_channels if (i == config.num_layers - 1) else config.out_channels + ) + resnet_in_channels = config.prev_out_channels if i == 0 else config.out_channels + self._map_residual_block( + state, + converted_state, + tensor_names.residual_block_tensor_names[i], + f"{converted_state_param_prefix}.resnets.{i}", + unet_config.ResidualBlock2DConfig( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=config.out_channels, + time_embedding_channels=config.time_embedding_channels, + normalization_config=config.normalization_config, + activation_config=config.activation_config, + ), + ) + if config.transformer_block_config: + self._map_transformer_block( + state, + converted_state, + tensor_names.transformer_block_tensor_names[i], + f"{converted_state_param_prefix}.transformers.{i}", + config.transformer_block_config, + ) + if config.add_upsample and config.upsample_conv: + _map_to_converted_state( + state, + tensor_names.upsample_conv, + 
converted_state, + f"{converted_state_param_prefix}.upsample_conv", + ) + + +class AutoEncoderModelLoader(BaseLoader): + + @dataclass + class TensorNames: + quant_conv: str = None + post_quant_conv: str = None + conv_in: str = None + conv_out: str = None + final_norm: str = None + mid_block_tensor_names: MidBlockTensorNames = None + up_decoder_blocks_tensor_names: List[UpDecoderBlockTensorNames] = None + + def __init__(self, file_name: str, names: TensorNames): + """AutoEncoderModelLoader constructor. Can be used to load encoder and decoder models. + + Args: + file_name (str): Path to the checkpoint. Can be a directory or an + exact file. + names (TensorNames): An instance of `TensorNames` to determine mappings. + """ + self._file_name = file_name + self._names = names + self._loader = self._get_loader() + + def load( + self, model: torch.nn.Module, strict: bool = True + ) -> Tuple[List[str], List[str]]: + """Load the model from the checkpoint. + + Args: + model (torch.nn.Module): The pytorch model that needs to be loaded. + strict (bool, optional): Whether the converted keys are strictly + matched. Defaults to True. + + Returns: + missing_keys (List[str]): a list of str containing the missing keys. + unexpected_keys (List[str]): a list of str containing the unexpected keys. + + Raises: + ValueError: If conversion results in unmapped tensors and strict mode is + enabled. + """ + state = self._loader(self._file_name) + converted_state = dict() + if self._names.quant_conv is not None: + _map_to_converted_state( + state, self._names.quant_conv, converted_state, "quant_conv" + ) + if self._names.post_quant_conv is not None: + _map_to_converted_state( + state, self._names.post_quant_conv, converted_state, "post_quant_conv" + ) + if self._names.conv_in is not None: + _map_to_converted_state(state, self._names.conv_in, converted_state, "conv_in") + if self._names.conv_out is not None: + _map_to_converted_state(state, self._names.conv_out, converted_state, "conv_out") + if self._names.final_norm is not None: + _map_to_converted_state( + state, self._names.final_norm, converted_state, "final_norm" + ) + self._map_mid_block( + state, + converted_state, + self._names.mid_block_tensor_names, + "mid_block", + model.config.mid_block_config, + ) + + reversed_block_out_channels = list(reversed(model.config.block_out_channels)) + block_out_channels = reversed_block_out_channels[0] + for i, out_channels in enumerate(reversed_block_out_channels): + prev_output_channel = block_out_channels + block_out_channels = out_channels + not_final_block = i < len(reversed_block_out_channels) - 1 + self._map_up_decoder_block( + state, + converted_state, + f"up_decoder_blocks.{i}", + unet_config.UpDecoderBlock2DConfig( + in_channels=prev_output_channel, + out_channels=block_out_channels, + normalization_config=model.config.normalization_config, + activation_config=model.config.activation_config, + num_layers=model.config.layers_per_block, + add_upsample=not_final_block, + upsample_conv=True, + ), + self._names.up_decoder_blocks_tensor_names[i], + ) + if strict and state: + raise ValueError( + f"Failed to map all tensor. 
Remaining tensors are: {list(state.keys())}"
+      )
+    return model.load_state_dict(converted_state, strict=strict)
+
+
+class DiffusionModelLoader(BaseLoader):
+
+  @dataclass
+  class TensorNames:
+    time_embedding: TimeEmbeddingTensorNames = None
+    conv_in: str = None
+    conv_out: str = None
+    final_norm: str = None
+    down_encoder_blocks_tensor_names: List[DownEncoderBlockTensorNames] = None
+    mid_block_tensor_names: MidBlockTensorNames = None
+    up_decoder_blocks_tensor_names: List[UpDecoderBlockTensorNames] = None
+
+  def __init__(self, file_name: str, names: TensorNames):
+    """DiffusionModelLoader constructor. Can be used to load the diffusion model of Stable Diffusion.
+
+    Args:
+        file_name (str): Path to the checkpoint. Can be a directory or an
+          exact file.
+        names (TensorNames): An instance of `TensorNames` to determine mappings.
+    """
+    self._file_name = file_name
+    self._names = names
+    self._loader = self._get_loader()
+
+  def load(
+      self, model: torch.nn.Module, strict: bool = True
+  ) -> Tuple[List[str], List[str]]:
+    """Load the model from the checkpoint.
+
+    Args:
+        model (torch.nn.Module): The PyTorch model that needs to be loaded.
+        strict (bool, optional): Whether the converted keys are strictly
+          matched. Defaults to True.
+
+    Returns:
+        missing_keys (List[str]): a list of str containing the missing keys.
+        unexpected_keys (List[str]): a list of str containing the unexpected keys.
+
+    Raises:
+        ValueError: If conversion results in unmapped tensors and strict mode is
+          enabled.
+    """
+    state = self._loader(self._file_name)
+    converted_state = dict()
+    config: unet_config.DiffusionModelConfig = model.config
+    self._map_time_embedding(
+        state, converted_state, "time_embedding", self._names.time_embedding
+    )
+    _map_to_converted_state(state, self._names.conv_in, converted_state, "conv_in")
+    _map_to_converted_state(state, self._names.conv_out, converted_state, "conv_out")
+    _map_to_converted_state(
+        state, self._names.final_norm, converted_state, "final_norm"
+    )
+
+    attention_config = layers_config.AttentionConfig(
+        num_heads=config.transformer_num_attention_heads,
+        num_query_groups=config.transformer_num_attention_heads,
+        rotary_percentage=0.0,
+        qkv_transpose_before_split=True,
+        qkv_use_bias=False,
+        output_proj_use_bias=True,
+        enable_kv_cache=False,
+    )
+
+    # Map down_encoders.
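+    # One DownEncoderBlock2DConfig is built per entry of config.block_out_channels.
+    # Every block except the last one adds a convolutional downsampler and a
+    # transformer block (self-attention, cross-attention, feed-forward); the final
+    # block only carries residual blocks. Tensors are mapped under "down_encoders.{i}".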
+ output_channel = config.block_out_channels[0] + for i, block_out_channel in enumerate(config.block_out_channels): + input_channel = output_channel + output_channel = block_out_channel + not_final_block = i < len(config.block_out_channels) - 1 + if not_final_block: + down_encoder_block_config = unet_config.DownEncoderBlock2DConfig( + in_channels=input_channel, + out_channels=output_channel, + normalization_config=config.residual_norm_config, + activation_config=layers_config.ActivationConfig( + config.residual_activation_type + ), + num_layers=config.layers_per_block, + padding=config.downsample_padding, + time_embedding_channels=config.time_embedding_blocks_dim, + add_downsample=True, + sampling_config=unet_config.DownSamplingConfig( + mode=unet_config.SamplingType.CONVOLUTION, + in_channels=output_channel, + out_channels=output_channel, + kernel_size=3, + stride=2, + padding=config.downsample_padding, + ), + transformer_block_config=unet_config.TransformerBlock2Dconfig( + attention_block_config=unet_config.AttentionBlock2DConfig( + dim=output_channel, + normalization_config=config.transformer_norm_config, + attention_config=attention_config, + ), + cross_attention_block_config=unet_config.CrossAttentionBlock2DConfig( + query_dim=output_channel, + cross_dim=config.transformer_cross_attention_dim, + normalization_config=config.transformer_norm_config, + attention_config=attention_config, + ), + pre_conv_normalization_config=config.transformer_pre_conv_norm_config, + feed_forward_block_config=unet_config.FeedForwardBlock2DConfig( + dim=output_channel, + hidden_dim=output_channel * 4, + normalization_config=config.transformer_norm_config, + activation_config=layers_config.ActivationConfig( + type=config.transformer_ff_activation_type, + dim_in=output_channel, + dim_out=output_channel * 4, + ), + use_bias=True, + ), + ), + ) + else: + down_encoder_block_config = unet_config.DownEncoderBlock2DConfig( + in_channels=input_channel, + out_channels=output_channel, + normalization_config=config.residual_norm_config, + activation_config=layers_config.ActivationConfig( + config.residual_activation_type + ), + num_layers=config.layers_per_block, + padding=config.downsample_padding, + time_embedding_channels=config.time_embedding_blocks_dim, + add_downsample=False, + ) + + self._map_down_encoder_block( + state, + converted_state, + f"down_encoders.{i}", + down_encoder_block_config, + self._names.down_encoder_blocks_tensor_names[i], + ) + + # Map mid block. 
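+    # The mid block keeps the channel count of the last down encoder block
+    # (config.block_out_channels[-1]) and maps one leading residual block followed
+    # by config.mid_block_layers transformer/residual pairs under "mid_block".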
+ mid_block_channels = config.block_out_channels[-1] + mid_block_config = unet_config.MidBlock2DConfig( + in_channels=mid_block_channels, + normalization_config=config.residual_norm_config, + activation_config=layers_config.ActivationConfig( + config.residual_activation_type + ), + num_layers=config.mid_block_layers, + time_embedding_channels=config.time_embedding_blocks_dim, + transformer_block_config=unet_config.TransformerBlock2Dconfig( + attention_block_config=unet_config.AttentionBlock2DConfig( + dim=mid_block_channels, + normalization_config=config.transformer_norm_config, + attention_config=attention_config, + ), + cross_attention_block_config=unet_config.CrossAttentionBlock2DConfig( + query_dim=mid_block_channels, + cross_dim=config.transformer_cross_attention_dim, + normalization_config=config.transformer_norm_config, + attention_config=attention_config, + ), + pre_conv_normalization_config=config.transformer_pre_conv_norm_config, + feed_forward_block_config=unet_config.FeedForwardBlock2DConfig( + dim=mid_block_channels, + hidden_dim=mid_block_channels * 4, + normalization_config=config.transformer_norm_config, + activation_config=layers_config.ActivationConfig( + type=config.transformer_ff_activation_type, + dim_in=mid_block_channels, + dim_out=mid_block_channels * 4, + ), + use_bias=True, + ), + ), + ) + self._map_mid_block( + state, + converted_state, + self._names.mid_block_tensor_names, + "mid_block", + mid_block_config, + ) + + # Map up_decoders. + reversed_block_out_channels = list(reversed(model.config.block_out_channels)) + up_decoder_layers_per_block = config.layers_per_block + 1 + output_channel = reversed_block_out_channels[0] + for i, block_out_channel in enumerate(reversed_block_out_channels): + prev_out_channel = output_channel + output_channel = block_out_channel + input_channel = reversed_block_out_channels[ + min(i + 1, len(reversed_block_out_channels) - 1) + ] + not_final_block = i < len(reversed_block_out_channels) - 1 + not_first_block = i != 0 + if not_first_block: + up_encoder_block_config = unet_config.SkipUpDecoderBlock2DConfig( + in_channels=input_channel, + out_channels=output_channel, + prev_out_channels=prev_out_channel, + normalization_config=config.residual_norm_config, + activation_config=layers_config.ActivationConfig( + config.residual_activation_type + ), + num_layers=up_decoder_layers_per_block, + time_embedding_channels=config.time_embedding_blocks_dim, + add_upsample=not_final_block, + upsample_conv=True, + sampling_config=unet_config.UpSamplingConfig( + mode=unet_config.SamplingType.NEAREST, + scale_factor=2, + ), + transformer_block_config=unet_config.TransformerBlock2Dconfig( + attention_block_config=unet_config.AttentionBlock2DConfig( + dim=output_channel, + normalization_config=config.transformer_norm_config, + attention_config=attention_config, + ), + cross_attention_block_config=unet_config.CrossAttentionBlock2DConfig( + query_dim=output_channel, + cross_dim=config.transformer_cross_attention_dim, + normalization_config=config.transformer_norm_config, + attention_config=attention_config, + ), + pre_conv_normalization_config=config.transformer_pre_conv_norm_config, + feed_forward_block_config=unet_config.FeedForwardBlock2DConfig( + dim=output_channel, + hidden_dim=output_channel * 4, + normalization_config=config.transformer_norm_config, + activation_config=layers_config.ActivationConfig( + type=config.transformer_ff_activation_type, + dim_in=output_channel, + dim_out=output_channel * 4, + ), + use_bias=True, + ), + ), + ) + else: + 
up_encoder_block_config = unet_config.SkipUpDecoderBlock2DConfig(
+            in_channels=input_channel,
+            out_channels=output_channel,
+            prev_out_channels=prev_out_channel,
+            normalization_config=config.residual_norm_config,
+            activation_config=layers_config.ActivationConfig(
+                config.residual_activation_type
+            ),
+            num_layers=up_decoder_layers_per_block,
+            time_embedding_channels=config.time_embedding_blocks_dim,
+            add_upsample=not_final_block,
+            upsample_conv=True,
+            sampling_config=unet_config.UpSamplingConfig(
+                mode=unet_config.SamplingType.NEAREST, scale_factor=2
+            ),
+        )
+      self._map_skip_up_decoder_block(
+          state,
+          converted_state,
+          f"up_decoders.{i}",
+          up_encoder_block_config,
+          self._names.up_decoder_blocks_tensor_names[i],
+      )
+    if strict and state:
+      raise ValueError(
+          f"Failed to map all tensors. Remaining tensors are: {list(state.keys())}"
+      )
+    return model.load_state_dict(converted_state, strict=strict)
+
+  def _map_time_embedding(
+      self,
+      state: Dict[str, torch.Tensor],
+      converted_state: Dict[str, torch.Tensor],
+      converted_state_param_prefix: str,
+      tensor_names: TimeEmbeddingTensorNames,
+  ):
+    _map_to_converted_state(
+        state, tensor_names.w1, converted_state, f"{converted_state_param_prefix}.w1"
+    )
+    _map_to_converted_state(
+        state, tensor_names.w2, converted_state, f"{converted_state_param_prefix}.w2"
+    )
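
# A minimal sketch (not part of the patch above) of the key-renaming mechanic that
# the loaders apply to every mapped tensor: _map_to_converted_state pops
# "<source>.weight" (and "<source>.bias" when present) out of the checkpoint state
# dict and re-inserts it under the converted module path. The key names below
# ("orig.block", "conv_in") are placeholders chosen for illustration, not names
# taken from a real checkpoint.
import torch

state = {"orig.block.weight": torch.zeros(4, 4), "orig.block.bias": torch.zeros(4)}
converted_state = {}

converted_state["conv_in.weight"] = state.pop("orig.block.weight")
if "orig.block.bias" in state:
  converted_state["conv_in.bias"] = state.pop("orig.block.bias")

# Any key left behind in `state` is what triggers the strict-mode ValueError in
# AutoEncoderModelLoader.load and DiffusionModelLoader.load.
assert not state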
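
# A small sketch of the channel bookkeeping used when mapping the skip up decoders
# in DiffusionModelLoader.load. The block_out_channels values are hypothetical; the
# point is how in/out/prev channels are derived per block, which in turn determines
# the skip-connection width (resnet_in_channels + res_skip_channels) that
# _map_skip_up_decoder_block uses to decide whether a residual_layer tensor exists.
block_out_channels = [320, 640, 1280, 1280]  # hypothetical example values

reversed_block_out_channels = list(reversed(block_out_channels))
output_channel = reversed_block_out_channels[0]
for i, block_out_channel in enumerate(reversed_block_out_channels):
  prev_out_channel = output_channel
  output_channel = block_out_channel
  input_channel = reversed_block_out_channels[
      min(i + 1, len(reversed_block_out_channels) - 1)
  ]
  not_final_block = i < len(reversed_block_out_channels) - 1
  print(
      f"block {i}: in={input_channel} out={output_channel} "
      f"prev={prev_out_channel} upsample={not_final_block}"
  )
# Prints:
# block 0: in=1280 out=1280 prev=1280 upsample=True
# block 1: in=640 out=1280 prev=1280 upsample=True
# block 2: in=320 out=640 prev=1280 upsample=True
# block 3: in=320 out=320 prev=640 upsample=False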