From fca58c3f22ef980ad850ad16cf79ce95a1f51330 Mon Sep 17 00:00:00 2001
From: yichunkuo
Date: Wed, 12 Jun 2024 11:24:40 -0700
Subject: [PATCH] Add 2D blocks used in the diffusion model of Stable
 Diffusion. (#50)

* Add CrossAttentionBlock2D and TransformerBlock2D, which are basic modules
  used in the UNet of the diffusion model.
* Add DownEncoderBlock2D and downsampling-related configs
---
 .../examples/stable_diffusion/decoder.py    |   4 +-
 .../generative/layers/unet/blocks_2d.py     | 478 ++++++++++++++++--
 .../generative/layers/unet/builder.py       |  22 +-
 .../generative/layers/unet/model_config.py  |  99 +++-
 4 files changed, 552 insertions(+), 51 deletions(-)

diff --git a/ai_edge_torch/generative/examples/stable_diffusion/decoder.py b/ai_edge_torch/generative/examples/stable_diffusion/decoder.py
index ee3695b9..5966be08 100644
--- a/ai_edge_torch/generative/examples/stable_diffusion/decoder.py
+++ b/ai_edge_torch/generative/examples/stable_diffusion/decoder.py
@@ -225,7 +225,7 @@ def __init__(self, config: unet_cfg.AutoEncoderConfig):
           num_layers=config.layers_per_block,
           add_upsample=not_final_block,
           upsample_conv=True,
-          sampling_config=unet_cfg.SamplingConfig(
-              2, unet_cfg.SamplingType.NEAREST
+          sampling_config=unet_cfg.UpSamplingConfig(
+              scale_factor=2, mode=unet_cfg.SamplingType.NEAREST
           ),
       )
@@ -271,7 +271,7 @@ def get_model_config() -> unet_cfg.AutoEncoderConfig:
   )
 
   att_config = unet_cfg.AttentionBlock2DConfig(
-      dims=block_out_channels[-1],
+      dim=block_out_channels[-1],
       normalization_config=norm_config,
       attention_config=layers_cfg.AttentionConfig(
           num_heads=1,
diff --git a/ai_edge_torch/generative/layers/unet/blocks_2d.py b/ai_edge_torch/generative/layers/unet/blocks_2d.py
index f0608240..f1a5e0b6 100644
--- a/ai_edge_torch/generative/layers/unet/blocks_2d.py
+++ b/ai_edge_torch/generative/layers/unet/blocks_2d.py
@@ -13,13 +13,15 @@
 #  limitations under the License.
 # ==============================================================================
 
-from typing import Optional
+from typing import List, Optional
 
 import torch
 from torch import nn
 
+from ai_edge_torch.generative.layers.attention import CrossAttention
 from ai_edge_torch.generative.layers.attention import SelfAttention
 import ai_edge_torch.generative.layers.builder as layers_builder
+import ai_edge_torch.generative.layers.model_config as layers_cfg
 import ai_edge_torch.generative.layers.unet.builder as unet_builder
 import ai_edge_torch.generative.layers.unet.model_config as unet_cfg
 
@@ -78,6 +80,7 @@ def forward(
     x = self.act_fn(x)
     x = self.conv_1(x)
     if self.time_emb_proj is not None:
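+      # Apply the activation to the time embedding before projecting it to
+      # the feature dimension, as diffusion UNet residual blocks
+      # conventionally do.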
+      time_emb = self.act_fn(time_emb)
       time_emb = self.time_emb_proj(time_emb)[:, :, None, None]
       x = x + time_emb
     x = self.norm_2(x)
@@ -90,7 +93,8 @@ class AttentionBlock2D(nn.Module):
   """2D self attention block
 
-  x = SelfAttention(Norm(input_tensor))
+  residual = x
+  x = SelfAttention(Norm(input_tensor)) + residual
 
   """
@@ -101,8 +105,8 @@ def __init__(self, config: unet_cfg.AttentionBlock2DConfig):
       config (unet_cfg.AttentionBlock2DConfig): the configuration of this block.
     """
     super().__init__()
-    self.norm = layers_builder.build_norm(config.dims, config.normalization_config)
-    self.attention = SelfAttention(config.dims, config.attention_config, 0, True)
+    self.norm = layers_builder.build_norm(config.dim, config.normalization_config)
+    self.attention = SelfAttention(config.dim, config.attention_config, 0, True)
 
   def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
     """Forward function of the AttentionBlock2D.
@@ -125,28 +129,287 @@ def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
     return x
 
 
-class UpDecoderBlock2D(nn.Module):
-  """Decoder block containing several residual blocks followed by an optional upsampler.
-
-      input_tensor
-           |
-           ▼
- ┌───────────────────┐
- │  ResidualBlock2D  │ num_layers
- └─────────┬─────────┘
-           │
- ┌─────────▼─────────┐
- │    (Optional)     │
- │     Upsampler     │
- └─────────┬─────────┘
-           │
- ┌─────────▼─────────┐
- │    (Optional)     │
- │      Conv2D       │
- └─────────┬─────────┘
-           │
-           ▼
-     hidden_states
+class CrossAttentionBlock2D(nn.Module):
+  """2D cross attention block
+
+  residual = x
+  x = CrossAttention(Norm(input_tensor), context) + residual
+
+  """
+
+  def __init__(self, config: unet_cfg.CrossAttentionBlock2DConfig):
+    """Initialize an instance of the CrossAttentionBlock2D.
+
+    Args:
+      config (unet_cfg.CrossAttentionBlock2DConfig): the configuration of this block.
+    """
+    super().__init__()
+    self.config = config
+    self.norm = layers_builder.build_norm(config.query_dim, config.normalization_config)
+    self.attention = CrossAttention(
+        config.query_dim, config.cross_dim, config.attention_config, 0, True
+    )
+
+  def forward(
+      self, input_tensor: torch.Tensor, context_tensor: torch.Tensor
+  ) -> torch.Tensor:
+    """Forward function of the CrossAttentionBlock2D.
+
+    Args:
+      input_tensor (torch.Tensor): the input tensor.
+      context_tensor (torch.Tensor): the context tensor to apply cross attention on.
+
+    Returns:
+      output activation tensor after cross attention.
+    """
+    residual = input_tensor
+    x = self.norm(input_tensor)
+    B, C, H, W = x.shape
+    x = x.view(B, C, H * W)
+    x = x.transpose(-1, -2)
+    x = self.attention(x, context_tensor)
+    x = x.transpose(-1, -2)
+    x = x.view(B, C, H, W)
+    x = x + residual
+    return x
+
+
+class FeedForwardBlock2D(nn.Module):
+  """2D feed forward block
+
+  residual = x
+  x = w2(Activation(w1(Norm(x)))) + residual
+
+  """
+
+  def __init__(
+      self,
+      config: unet_cfg.FeedForwardBlock2DConfig,
+  ):
+    super().__init__()
+    self.config = config
+    self.act = layers_builder.get_activation(config.activation_config)
+    self.norm = layers_builder.build_norm(config.dim, config.normalization_config)
+    if config.activation_config.type == layers_cfg.ActivationType.GE_GLU:
+      # GE_GLU contains its own input projection, so no separate w1 is needed.
+      self.w1 = nn.Identity()
+      self.w2 = nn.Linear(config.hidden_dim, config.dim, bias=config.use_bias)
+    else:
+      self.w1 = nn.Linear(config.dim, config.hidden_dim, bias=config.use_bias)
+      self.w2 = nn.Linear(config.hidden_dim, config.dim, bias=config.use_bias)
+
+  def forward(self, x: torch.Tensor) -> torch.Tensor:
+    residual = x
+
+    B, C, H, W = x.shape
+    x = x.view((B, C, H * W))
+    x = x.transpose(-1, -2)  # (B, HW, C)
+
+    x = self.norm(x)
+    x = self.w1(x)
+    x = self.act(x)
+    x = self.w2(x)
+
+    x = x.transpose(-1, -2)  # (B, C, HW)
+    x = x.view((B, C, H, W))
+
+    return x + residual
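+
+
+# NOTE: The attention and feed-forward blocks above flatten the spatial
+# dimensions before the sequence operation and restore them afterwards:
+#   (B, C, H, W) -> view -> (B, C, H*W) -> transpose -> (B, H*W, C)
+# i.e. every spatial location is treated as one sequence token.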
+
+
+class TransformerBlock2D(nn.Module):
+  """Basic transformer block used in the UNet of a diffusion model.
+
+      input_tensor    context_tensor
+           |                 |
+  ┌────────▼────────┐        |
+  │      ConvIn     │        |
+  └────────┬────────┘        |
+           │                 |
+  ┌────────▼────────┐        |
+  │ Attention Block │        |
+  └────────┬────────┘        |
+           │                 |
+  ┌────────▼───────────┐     |
+  │CrossAttention Block│◄────┘
+  └────────┬───────────┘
+           │
+  ┌────────▼────────┐
+  │ FeedForwardBlock│
+  └────────┬────────┘
+           │
+  ┌────────▼────────┐
+  │     ConvOut     │
+  └────────┬────────┘
+           │
+           ▼
+     hidden_states
+  """
+
+  def __init__(self, config: unet_cfg.TransformerBlock2DConfig):
+    """Initialize an instance of the TransformerBlock2D.
+
+    Args:
+      config (unet_cfg.TransformerBlock2DConfig): the configuration of this block.
+    """
+    super().__init__()
+    self.config = config
+    self.pre_conv_norm = layers_builder.build_norm(
+        config.attention_block_config.dim, config.pre_conv_normalization_config
+    )
+    self.conv_in = nn.Conv2d(
+        config.attention_block_config.dim,
+        config.attention_block_config.dim,
+        kernel_size=1,
+        padding=0,
+    )
+    self.self_attention = AttentionBlock2D(config.attention_block_config)
+    self.cross_attention = CrossAttentionBlock2D(config.cross_attention_block_config)
+    self.feed_forward = FeedForwardBlock2D(config.feed_forward_block_config)
+    self.conv_out = nn.Conv2d(
+        config.attention_block_config.dim,
+        config.attention_block_config.dim,
+        kernel_size=1,
+        padding=0,
+    )
+
+  def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
+    """Forward function of the TransformerBlock2D.
+
+    Args:
+      x (torch.Tensor): the input tensor.
+      context (torch.Tensor): the context tensor to apply cross attention on.
+
+    Returns:
+      output activation tensor after transformer block.
+    """
+    residual_long = x
+
+    x = self.pre_conv_norm(x)
+    x = self.conv_in(x)
+
+    x = self.self_attention(x)
+    x = self.cross_attention(x, context)
+    x = self.feed_forward(x)
+
+    x = self.conv_out(x)
+    x = x + residual_long
+
+    return x
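+
+
+# Illustrative example (the values are hypothetical, not taken from a real
+# model, and norm_config / attn_config / act_config are placeholders): a
+# transformer block for a 320-channel UNet stage attending to a 768-dim
+# text-encoder context could be configured as
+#   unet_cfg.TransformerBlock2DConfig(
+#       pre_conv_normalization_config=norm_config,
+#       attention_block_config=unet_cfg.AttentionBlock2DConfig(
+#           dim=320,
+#           normalization_config=norm_config,
+#           attention_config=attn_config,
+#       ),
+#       cross_attention_block_config=unet_cfg.CrossAttentionBlock2DConfig(
+#           query_dim=320,
+#           cross_dim=768,
+#           normalization_config=norm_config,
+#           attention_config=attn_config,
+#       ),
+#       feed_forward_block_config=unet_cfg.FeedForwardBlock2DConfig(
+#           dim=320,
+#           hidden_dim=1280,
+#           normalization_config=norm_config,
+#           activation_config=act_config,
+#           use_bias=True,
+#       ),
+#   )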
+
+
+class DownEncoderBlock2D(nn.Module):
+  """Encoder block containing several residual blocks with optional interleaved transformer blocks.
+
+          input_tensor
+                │
+  ┌─────────────▼──────────────┐
+  │   ┌────────────────────┐   │
+  │   │  ResidualBlock2D   │   │
+  │   └──────────┬─────────┘   │
+  │              │             │  num_layers
+  │   ┌──────────▼─────────┐   │
+  │   │     (Optional)     │   │
+  │   │ TransformerBlock2D │   │
+  │   └──────────┬─────────┘   │
+  └──────────────┬─────────────┘
+                 │
+      ┌──────────▼─────────┐
+      │     (Optional)     │
+      │    Downsampler     │
+      └──────────┬─────────┘
+                 │
+                 ▼
+           hidden_states
+  """
+
+  def __init__(self, config: unet_cfg.DownEncoderBlock2DConfig):
+    """Initialize an instance of the DownEncoderBlock2D.
+
+    Args:
+      config (unet_cfg.DownEncoderBlock2DConfig): the configuration of this block.
+    """
+    super().__init__()
+    self.config = config
+    resnets = []
+    transformers = []
+    for i in range(config.num_layers):
+      input_channels = config.in_channels if i == 0 else config.out_channels
+      resnets.append(
+          ResidualBlock2D(
+              unet_cfg.ResidualBlock2DConfig(
+                  in_channels=input_channels,
+                  out_channels=config.out_channels,
+                  time_embedding_channels=config.time_embedding_channels,
+                  normalization_config=config.normalization_config,
+                  activation_config=config.activation_config,
+              )
+          )
+      )
+      if config.transformer_block_config:
+        transformers.append(TransformerBlock2D(config.transformer_block_config))
+      else:
+        # Keep one entry per layer so the zip() in forward() stays aligned
+        # with resnets when no transformer blocks are configured.
+        transformers.append(None)
+    self.resnets = nn.ModuleList(resnets)
+    self.transformers = nn.ModuleList(transformers)
+    if config.add_downsample:
+      self.downsampler = unet_builder.build_downsampling(config.sampling_config)
+    else:
+      self.downsampler = None
+
+  def forward(
+      self,
+      input_tensor: torch.Tensor,
+      time_emb: Optional[torch.Tensor] = None,
+      context_tensor: Optional[torch.Tensor] = None,
+  ) -> torch.Tensor:
+    """Forward function of the DownEncoderBlock2D.
+
+    Args:
+      input_tensor (torch.Tensor): the input tensor.
+      time_emb (torch.Tensor): optional time embedding tensor, if the block is configured to accept
+        time embedding.
+      context_tensor (torch.Tensor): optional context tensor, if the block is configured to use a transformer block.
+
+    Returns:
+      output hidden_states tensor after DownEncoderBlock2D.
+    """
+    hidden_states = input_tensor
+    for resnet, transformer in zip(self.resnets, self.transformers):
+      hidden_states = resnet(hidden_states, time_emb)
+      if transformer is not None:
+        hidden_states = transformer(hidden_states, context_tensor)
+    if self.downsampler:
+      hidden_states = self.downsampler(hidden_states)
+    return hidden_states
+
+
+class UpDecoderBlock2D(nn.Module):
+  """Decoder block containing several residual blocks with optional interleaved transformer blocks.
+
+          input_tensor
+                │
+  ┌─────────────▼──────────────┐
+  │   ┌────────────────────┐   │
+  │   │  ResidualBlock2D   │   │
+  │   └──────────┬─────────┘   │
+  │              │             │  num_layers
+  │   ┌──────────▼─────────┐   │
+  │   │     (Optional)     │   │
+  │   │ TransformerBlock2D │   │
+  │   └──────────┬─────────┘   │
+  └──────────────┬─────────────┘
+                 │
+      ┌──────────▼─────────┐
+      │     (Optional)     │
+      │      Upsampler     │
+      └──────────┬─────────┘
+                 │
+      ┌──────────▼─────────┐
+      │     (Optional)     │
+      │       Conv2D       │
+      └──────────┬─────────┘
+                 │
+                 ▼
+           hidden_states
   """
 
   def __init__(self, config: unet_cfg.UpDecoderBlock2DConfig):
@@ -158,6 +421,7 @@
     super().__init__()
     self.config = config
     resnets = []
+    transformers = []
    for i in range(config.num_layers):
       input_channels = config.in_channels if i == 0 else config.out_channels
       resnets.append(
@@ -171,7 +435,10 @@
             )
           )
       )
+      if config.transformer_block_config:
+        transformers.append(TransformerBlock2D(config.transformer_block_config))
+      else:
+        transformers.append(None)
     self.resnets = nn.ModuleList(resnets)
+    self.transformers = nn.ModuleList(transformers)
     if config.add_upsample:
       self.upsampler = unet_builder.build_upsampling(config.sampling_config)
       if config.upsample_conv:
@@ -182,21 +449,130 @@ def forward(
       self.upsampler = None
 
   def forward(
-      self, input_tensor: torch.Tensor, time_emb: Optional[torch.Tensor] = None
+      self,
+      input_tensor: torch.Tensor,
+      time_emb: Optional[torch.Tensor] = None,
+      context_tensor: Optional[torch.Tensor] = None,
   ) -> torch.Tensor:
     """Forward function of the UpDecoderBlock2D.
 
     Args:
       input_tensor (torch.Tensor): the input tensor.
       time_emb (torch.Tensor): optional time embedding tensor, if the block is configured to accept
-        time embedding context.
+        time embedding.
+      context_tensor (torch.Tensor): optional context tensor, if the block is configured to use a transformer block.
 
     Returns:
       output hidden_states tensor after UpDecoderBlock2D.
     """
     hidden_states = input_tensor
-    for resnet in self.resnets:
+    for resnet, transformer in zip(self.resnets, self.transformers):
       hidden_states = resnet(hidden_states, time_emb)
+      if transformer is not None:
+        hidden_states = transformer(hidden_states, context_tensor)
+    if self.upsampler:
+      hidden_states = self.upsampler(hidden_states)
+      if self.upsample_conv:
+        hidden_states = self.upsample_conv(hidden_states)
+    return hidden_states
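+
+
+# SkipUpDecoderBlock2D concatenates one UNet encoder skip tensor onto the
+# channel dimension before each residual block. Illustrative channel math
+# (hypothetical values): with in_channels=320, out_channels=640,
+# prev_out_channels=1280 and num_layers=3, the three residual blocks see
+# concatenated inputs of width 1280+640, 640+640 and 640+320 channels, and
+# each produces out_channels=640.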
+
+
+class SkipUpDecoderBlock2D(nn.Module):
+  """Decoder block containing skip connections and residual blocks with optional interleaved transformer blocks.
+
+   input_tensor, skip_connection_tensors
+                │
+  ┌─────────────▼──────────────┐
+  │   ┌────────────────────┐   │
+  │   │  ResidualBlock2D   │   │
+  │   └──────────┬─────────┘   │
+  │              │             │  num_layers
+  │   ┌──────────▼─────────┐   │
+  │   │     (Optional)     │   │
+  │   │ TransformerBlock2D │   │
+  │   └──────────┬─────────┘   │
+  └──────────────┬─────────────┘
+                 │
+      ┌──────────▼─────────┐
+      │     (Optional)     │
+      │      Upsampler     │
+      └──────────┬─────────┘
+                 │
+      ┌──────────▼─────────┐
+      │     (Optional)     │
+      │       Conv2D       │
+      └──────────┬─────────┘
+                 │
+                 ▼
+           hidden_states
+  """
+
+  def __init__(self, config: unet_cfg.SkipUpDecoderBlock2DConfig):
+    """Initialize an instance of the SkipUpDecoderBlock2D.
+
+    Args:
+      config (unet_cfg.SkipUpDecoderBlock2DConfig): the configuration of this block.
+    """
+    super().__init__()
+    self.config = config
+    resnets = []
+    transformers = []
+    for i in range(config.num_layers):
+      res_skip_channels = (
+          config.in_channels if (i == config.num_layers - 1) else config.out_channels
+      )
+      resnet_in_channels = config.prev_out_channels if i == 0 else config.out_channels
+      resnets.append(
+          ResidualBlock2D(
+              unet_cfg.ResidualBlock2DConfig(
+                  in_channels=resnet_in_channels + res_skip_channels,
+                  out_channels=config.out_channels,
+                  time_embedding_channels=config.time_embedding_channels,
+                  normalization_config=config.normalization_config,
+                  activation_config=config.activation_config,
+              )
+          )
+      )
+      if config.transformer_block_config:
+        transformers.append(TransformerBlock2D(config.transformer_block_config))
+      else:
+        transformers.append(None)
+    self.resnets = nn.ModuleList(resnets)
+    self.transformers = nn.ModuleList(transformers)
+    if config.add_upsample:
+      self.upsampler = unet_builder.build_upsampling(config.sampling_config)
+      if config.upsample_conv:
+        self.upsample_conv = nn.Conv2d(
+            config.out_channels, config.out_channels, kernel_size=3, stride=1, padding=1
+        )
+    else:
+      self.upsampler = None
+
+  def forward(
+      self,
+      input_tensor: torch.Tensor,
+      skip_connection_tensors: List[torch.Tensor],
+      time_emb: Optional[torch.Tensor] = None,
+      context_tensor: Optional[torch.Tensor] = None,
+  ) -> torch.Tensor:
+    """Forward function of the SkipUpDecoderBlock2D.
+
+    Args:
+      input_tensor (torch.Tensor): the input tensor.
+      skip_connection_tensors (List[torch.Tensor]): the skip connection tensors from encoder blocks.
+      time_emb (torch.Tensor): optional time embedding tensor, if the block is configured to accept
+        time embedding.
+      context_tensor (torch.Tensor): optional context tensor, if the block is configured to use a transformer block.
+
+    Returns:
+      output hidden_states tensor after SkipUpDecoderBlock2D.
+    """
+    hidden_states = input_tensor
+    for resnet, skip_connection_tensor, transformer in zip(
+        self.resnets, skip_connection_tensors, self.transformers
+    ):
+      hidden_states = torch.cat([hidden_states, skip_connection_tensor], dim=1)
+      hidden_states = resnet(hidden_states, time_emb)
+      if transformer is not None:
+        hidden_states = transformer(hidden_states, context_tensor)
     if self.upsampler:
       hidden_states = self.upsampler(hidden_states)
       if self.upsample_conv:
@@ -207,25 +583,30 @@ def forward(
 class MidBlock2D(nn.Module):
   """Middle block containing at least one residual block with optional interleaved attention and/or transformer blocks.
 
-      input_tensor
-           |
-           ▼
- ┌───────────────────┐
- │  ResidualBlock2D  │
- └─────────┬─────────┘
-           │
- ┌─────────────▼─────────────┐
- │  ┌───────────────────┐    │
- │  │    (Optional)     │    │
- │  │ AttentionBlock2D  │    │
- │  └─────────┬─────────┘    │  num_layers
- │            │              │
- │  ┌─────────▼─────────┐    │
- │  │  ResidualBlock2D  │    │
- │  └───────────────────┘    │
- └─────────────┬─────────────┘
-               │
-               ▼
+          input_tensor
+                │
+                ▼
+      ┌───────────────────┐
+      │  ResidualBlock2D  │
+      └─────────┬─────────┘
+                │
+  ┌─────────────▼──────────────┐
+  │   ┌────────────────────┐   │
+  │   │     (Optional)     │   │
+  │   │  AttentionBlock2D  │   │
+  │   └──────────┬─────────┘   │
+  │              │             │
+  │   ┌──────────▼─────────┐   │
+  │   │     (Optional)     │   │  num_layers
+  │   │ TransformerBlock2D │   │
+  │   └──────────┬─────────┘   │
+  │              │             │
+  │   ┌──────────▼─────────┐   │
+  │   │  ResidualBlock2D   │   │
+  │   └────────────────────┘   │
+  └──────────────┬─────────────┘
+                 │
+                 ▼
     hidden_states
   """
@@ -249,9 +630,12 @@ def __init__(self, config: unet_cfg.MidBlock2DConfig):
         )
     ]
     attentions = []
+    transformers = []
     for i in range(config.num_layers):
       if self.config.attention_block_config:
         attentions.append(AttentionBlock2D(config.attention_block_config))
+      else:
+        attentions.append(None)
+      if self.config.transformer_block_config:
+        transformers.append(TransformerBlock2D(config.transformer_block_config))
+      else:
+        transformers.append(None)
       resnets.append(
           ResidualBlock2D(
               unet_cfg.ResidualBlock2DConfig(
@@ -265,23 +649,33 @@ def __init__(self, config: unet_cfg.MidBlock2DConfig):
           )
       )
     self.resnets = nn.ModuleList(resnets)
     self.attentions = nn.ModuleList(attentions)
+    self.transformers = nn.ModuleList(transformers)
 
   def forward(
-      self, input_tensor: torch.Tensor, time_emb: Optional[torch.Tensor] = None
+      self,
+      input_tensor: torch.Tensor,
+      time_emb: Optional[torch.Tensor] = None,
+      context_tensor: Optional[torch.Tensor] = None,
   ) -> torch.Tensor:
     """Forward function of the MidBlock2D.
 
     Args:
      input_tensor (torch.Tensor): the input tensor.
       time_emb (torch.Tensor): optional time embedding tensor, if the block is configured to accept
-        time embedding context.
+        time embedding.
+      context_tensor (torch.Tensor): optional context tensor, if the block is configured to use
+        a transformer block.
 
     Returns:
       output hidden_states tensor after MidBlock2D.
     """
     hidden_states = self.resnets[0](input_tensor, time_emb)
-    for attn, resnet in zip(self.attentions, self.resnets[1:]):
+    for attn, transformer, resnet in zip(
+        self.attentions, self.transformers, self.resnets[1:]
+    ):
       if attn is not None:
         hidden_states = attn(hidden_states)
+      if transformer is not None:
+        hidden_states = transformer(hidden_states, context_tensor)
       hidden_states = resnet(hidden_states, time_emb)
     return hidden_states
diff --git a/ai_edge_torch/generative/layers/unet/builder.py b/ai_edge_torch/generative/layers/unet/builder.py
index d8b6a15f..f4df76f3 100644
--- a/ai_edge_torch/generative/layers/unet/builder.py
+++ b/ai_edge_torch/generative/layers/unet/builder.py
@@ -15,15 +15,33 @@
 # Builder utils for individual components.
 
 from torch import nn
-import torch.nn.functional as F
 
 import ai_edge_torch.generative.layers.unet.model_config as unet_config
 
 
-def build_upsampling(config: unet_config.SamplingConfig):
+def build_upsampling(config: unet_config.UpSamplingConfig):
   if config.mode == unet_config.SamplingType.NEAREST:
     return nn.UpsamplingNearest2d(scale_factor=config.scale_factor)
   elif config.mode == unet_config.SamplingType.BILINEAR:
     return nn.UpsamplingBilinear2d(scale_factor=config.scale_factor)
   else:
     raise ValueError("Unsupported upsampling type.")
+
+
+def build_downsampling(config: unet_config.DownSamplingConfig):
+  if config.mode == unet_config.SamplingType.AVERAGE:
+    return nn.AvgPool2d(config.kernel_size, config.stride, padding=config.padding)
+  elif config.mode == unet_config.SamplingType.CONVOLUTION:
+    out_channels = (
+        config.in_channels if config.out_channels is None else config.out_channels
+    )
+    conv = nn.Conv2d(
+        config.in_channels,
+        out_channels=out_channels,
+        kernel_size=config.kernel_size,
+        stride=config.stride,
+        padding=config.padding,
+    )
+    if config.padding == 0:
+      # nn.Conv2d only supports symmetric padding, so emulate the asymmetric
+      # (left=0, right=1, top=0, bottom=1) padding used by diffusion
+      # downsamplers with an explicit ZeroPad2d in front of the convolution.
+      return nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), conv)
+    return conv
+  else:
+    raise ValueError("Unsupported downsampling type.")
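+
+
+# Illustrative example (hypothetical values): a stride-2 convolutional
+# downsampler for a 128-channel feature map could be built with
+#   build_downsampling(unet_config.DownSamplingConfig(
+#       mode=unet_config.SamplingType.CONVOLUTION,
+#       in_channels=128,
+#       kernel_size=3,
+#       stride=2,
+#       padding=0,
+#   ))
+# where padding=0 selects the asymmetric (0, 1, 0, 1) zero padding above.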
diff --git a/ai_edge_torch/generative/layers/unet/model_config.py b/ai_edge_torch/generative/layers/unet/model_config.py
index 6e491014..0e0e2ca5 100644
--- a/ai_edge_torch/generative/layers/unet/model_config.py
+++ b/ai_edge_torch/generative/layers/unet/model_config.py
@@ -22,16 +22,28 @@ import ai_edge_torch.generative.layers.model_config as layers_cfg
 
-@dataclass
+@enum.unique
 class SamplingType(enum.Enum):
   NEAREST = enum.auto()
   BILINEAR = enum.auto()
+  AVERAGE = enum.auto()
+  CONVOLUTION = enum.auto()
 
 
 @dataclass
-class SamplingConfig:
+class UpSamplingConfig:
+  mode: SamplingType
   scale_factor: float
+
+
+@dataclass
+class DownSamplingConfig:
   mode: SamplingType
+  in_channels: int
+  kernel_size: int
+  stride: int
+  padding: int
+  out_channels: Optional[int] = None
 
 
 @dataclass
@@ -46,11 +58,36 @@ class ResidualBlock2DConfig:
 
 @dataclass
 class AttentionBlock2DConfig:
-  dims: int
+  dim: int
   normalization_config: layers_cfg.NormalizationConfig
   attention_config: layers_cfg.AttentionConfig
+
+
+@dataclass
+class CrossAttentionBlock2DConfig:
+  query_dim: int
+  cross_dim: int
+  normalization_config: layers_cfg.NormalizationConfig
+  attention_config: layers_cfg.AttentionConfig
+
+
+@dataclass
+class FeedForwardBlock2DConfig:
+  dim: int
+  hidden_dim: int
+  normalization_config: layers_cfg.NormalizationConfig
+  activation_config: layers_cfg.ActivationConfig
+  use_bias: bool
+
+
+@dataclass
+class TransformerBlock2DConfig:
+  pre_conv_normalization_config: layers_cfg.NormalizationConfig
+  attention_block_config: AttentionBlock2DConfig
+  cross_attention_block_config: CrossAttentionBlock2DConfig
+  feed_forward_block_config: FeedForwardBlock2DConfig
 
 
 @dataclass
@@ -58,14 +95,62 @@ class UpDecoderBlock2DConfig:
   in_channels: int
   out_channels: int
   normalization_config: layers_cfg.NormalizationConfig
   activation_config: layers_cfg.ActivationConfig
   num_layers: int
-  # Optional time embedding channels if the residual blocks take a time embedding context as input
+  # Optional time embedding channels if the residual blocks take a time embedding as input
   time_embedding_channels: Optional[int] = None
   # Whether to add upsample operation after residual blocks
   add_upsample: bool = True
   # Whether to add a conv2d layer after upsample
   upsample_conv: bool = True
   # Optional sampling config if add_upsample is True.
-  sampling_config: Optional[SamplingConfig] = None
+  sampling_config: Optional[UpSamplingConfig] = None
+  # Optional config of transformer blocks interleaved with residual blocks
+  transformer_block_config: Optional[TransformerBlock2DConfig] = None
+  # Optional dimension of context tensor if context tensor is given as input.
+  context_dim: Optional[int] = None
+
+
+@dataclass
+class SkipUpDecoderBlock2DConfig:
+  in_channels: int
+  out_channels: int
+  # The dimension of output channels of the previously connected block
+  prev_out_channels: int
+  normalization_config: layers_cfg.NormalizationConfig
+  activation_config: layers_cfg.ActivationConfig
+  num_layers: int
+  # Optional time embedding channels if the residual blocks take a time embedding as input
+  time_embedding_channels: Optional[int] = None
+  # Whether to add upsample operation after residual blocks
+  add_upsample: bool = True
+  # Whether to add a conv2d layer after upsample
+  upsample_conv: bool = True
+  # Optional sampling config if add_upsample is True.
+  sampling_config: Optional[UpSamplingConfig] = None
+  # Optional config of transformer blocks interleaved with residual blocks
+  transformer_block_config: Optional[TransformerBlock2DConfig] = None
+  # Optional dimension of context tensor if context tensor is given as input.
+  context_dim: Optional[int] = None
+
+
+@dataclass
+class DownEncoderBlock2DConfig:
+  in_channels: int
+  out_channels: int
+  normalization_config: layers_cfg.NormalizationConfig
+  activation_config: layers_cfg.ActivationConfig
+  num_layers: int
+  # Padding for the downsampling convolution.
+  padding: int = 1
+  # Optional time embedding channels if the residual blocks take a time embedding as input
+  time_embedding_channels: Optional[int] = None
+  # Whether to add downsample operation after residual blocks
+  add_downsample: bool = True
+  # Optional sampling config if add_downsample is True.
+  sampling_config: Optional[DownSamplingConfig] = None
+  # Optional config of transformer blocks interleaved with residual blocks
+  transformer_block_config: Optional[TransformerBlock2DConfig] = None
+  # Optional dimension of context tensor if context tensor is given as input.
+  context_dim: Optional[int] = None
 
 
 @dataclass
@@ -78,6 +163,10 @@ class MidBlock2DConfig:
   time_embedding_channels: Optional[int] = None
   # Optional config of attention blocks interleaved with residual blocks
   attention_block_config: Optional[AttentionBlock2DConfig] = None
+  # Optional config of transformer blocks interleaved with residual blocks
+  transformer_block_config: Optional[TransformerBlock2DConfig] = None
+  # Optional dimension of context tensor if context tensor is given as input.
+  context_dim: Optional[int] = None
 
 
 @dataclass
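
Example usage of the new blocks (a minimal sketch; the channel sizes, the
input shape and the norm_config / act_config placeholders below are
hypothetical, not values taken from this patch):

  import torch
  import ai_edge_torch.generative.layers.unet.blocks_2d as blocks_2d
  import ai_edge_torch.generative.layers.unet.model_config as unet_cfg

  # norm_config and act_config stand in for a NormalizationConfig and an
  # ActivationConfig built elsewhere in the library.
  down_config = unet_cfg.DownEncoderBlock2DConfig(
      in_channels=64,
      out_channels=128,
      normalization_config=norm_config,
      activation_config=act_config,
      num_layers=2,
      add_downsample=True,
      sampling_config=unet_cfg.DownSamplingConfig(
          mode=unet_cfg.SamplingType.CONVOLUTION,
          in_channels=128,
          kernel_size=3,
          stride=2,
          padding=0,
      ),
  )
  block = blocks_2d.DownEncoderBlock2D(down_config)
  # Two 64->128 residual blocks, then a stride-2 conv downsampler:
  out = block(torch.randn(1, 64, 32, 32))  # -> (1, 128, 16, 16)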