From e50f7a756983d3e76a0f2cfdec7fd92cde3588e6 Mon Sep 17 00:00:00 2001
From: Google AI Edge
Date: Mon, 16 Sep 2024 12:53:25 -0700
Subject: [PATCH] Update CrossAttention config to have hidden_dim and output_dim.

PiperOrigin-RevId: 675257767
---
 .../examples/stable_diffusion/diffusion.py           |  6 ++++++
 ai_edge_torch/generative/layers/attention.py         | 12 ++++++++----
 ai_edge_torch/generative/layers/unet/blocks_2d.py    |  2 ++
 ai_edge_torch/generative/layers/unet/model_config.py |  2 ++
 .../generative/utilities/stable_diffusion_loader.py  |  6 ++++++
 5 files changed, 24 insertions(+), 4 deletions(-)

diff --git a/ai_edge_torch/generative/examples/stable_diffusion/diffusion.py b/ai_edge_torch/generative/examples/stable_diffusion/diffusion.py
index 735f497a..f88f8fa3 100644
--- a/ai_edge_torch/generative/examples/stable_diffusion/diffusion.py
+++ b/ai_edge_torch/generative/examples/stable_diffusion/diffusion.py
@@ -336,6 +336,8 @@ def __init__(self, config: unet_cfg.DiffusionModelConfig):
           cross_attention_block_config=unet_cfg.CrossAttentionBlock2DConfig(
               query_dim=output_channel,
               cross_dim=config.transformer_cross_attention_dim,
+              hidden_dim=output_channel,
+              output_dim=output_channel,
               attention_batch_size=config.transformer_batch_size,
               normalization_config=config.transformer_norm_config,
               attention_config=build_attention_config(
@@ -406,6 +408,8 @@ def __init__(self, config: unet_cfg.DiffusionModelConfig):
         cross_attention_block_config=unet_cfg.CrossAttentionBlock2DConfig(
             query_dim=mid_block_channels,
             cross_dim=config.transformer_cross_attention_dim,
+            hidden_dim=mid_block_channels,
+            output_dim=mid_block_channels,
             attention_batch_size=config.transformer_batch_size,
             normalization_config=config.transformer_norm_config,
             attention_config=build_attention_config(
@@ -477,6 +481,8 @@ def __init__(self, config: unet_cfg.DiffusionModelConfig):
           cross_attention_block_config=unet_cfg.CrossAttentionBlock2DConfig(
               query_dim=output_channel,
               cross_dim=config.transformer_cross_attention_dim,
+              hidden_dim=output_channel,
+              output_dim=output_channel,
               attention_batch_size=config.transformer_batch_size,
               normalization_config=config.transformer_norm_config,
               attention_config=build_attention_config(
diff --git a/ai_edge_torch/generative/layers/attention.py b/ai_edge_torch/generative/layers/attention.py
index 431f5203..32190b54 100644
--- a/ai_edge_torch/generative/layers/attention.py
+++ b/ai_edge_torch/generative/layers/attention.py
@@ -298,6 +298,8 @@ def __init__(
       batch_size: int,
       query_dim: int,
       cross_dim: int,
+      hidden_dim: int,
+      output_dim: int,
       config: cfg.AttentionConfig,
       enable_hlfb: bool,
   ):
@@ -307,6 +309,8 @@ def __init__(
       batch_size (int): batch size of the input tensor.
       query_dim (int): query tensor's dimension.
       cross_dim (int): cross attention's dimensions, for key and value tensors.
+      hidden_dim (int): hidden dimension that q, k, v tensors project to.
+      output_dim (int): output tensor's dimension.
       config (cfg.AttentionConfig): attention specific configurations.
       enable_hlfb (bool): whether hlfb is enabled or not.
     """
@@ -314,16 +318,16 @@ def __init__(
     self.config = config
     self.n_heads = config.num_heads
     self.q_projection = nn.Linear(
-        query_dim, query_dim, bias=config.qkv_use_bias
+        query_dim, hidden_dim, bias=config.qkv_use_bias
     )
     self.k_projection = nn.Linear(
-        cross_dim, query_dim, bias=config.qkv_use_bias
+        cross_dim, hidden_dim, bias=config.qkv_use_bias
     )
     self.v_projection = nn.Linear(
-        cross_dim, query_dim, bias=config.qkv_use_bias
+        cross_dim, hidden_dim, bias=config.qkv_use_bias
     )
     self.output_projection = nn.Linear(
-        query_dim, query_dim, bias=config.output_proj_use_bias
+        hidden_dim, output_dim, bias=config.output_proj_use_bias
     )

     self.sdpa_func = (
diff --git a/ai_edge_torch/generative/layers/unet/blocks_2d.py b/ai_edge_torch/generative/layers/unet/blocks_2d.py
index 5498c9a3..5eb553ce 100644
--- a/ai_edge_torch/generative/layers/unet/blocks_2d.py
+++ b/ai_edge_torch/generative/layers/unet/blocks_2d.py
@@ -178,6 +178,8 @@ def __init__(self, config: unet_cfg.CrossAttentionBlock2DConfig):
         config.attention_batch_size,
         config.query_dim,
         config.cross_dim,
+        config.hidden_dim,
+        config.output_dim,
         config.attention_config,
         enable_hlfb=config.enable_hlfb,
     )
diff --git a/ai_edge_torch/generative/layers/unet/model_config.py b/ai_edge_torch/generative/layers/unet/model_config.py
index 24a70dcc..b5d09934 100644
--- a/ai_edge_torch/generative/layers/unet/model_config.py
+++ b/ai_edge_torch/generative/layers/unet/model_config.py
@@ -68,6 +68,8 @@ class AttentionBlock2DConfig:
 class CrossAttentionBlock2DConfig:
   query_dim: int
   cross_dim: int
+  hidden_dim: int
+  output_dim: int
   normalization_config: layers_cfg.NormalizationConfig
   attention_config: layers_cfg.AttentionConfig
   enable_hlfb: bool = True
diff --git a/ai_edge_torch/generative/utilities/stable_diffusion_loader.py b/ai_edge_torch/generative/utilities/stable_diffusion_loader.py
index 2a9adcf6..c873b186 100644
--- a/ai_edge_torch/generative/utilities/stable_diffusion_loader.py
+++ b/ai_edge_torch/generative/utilities/stable_diffusion_loader.py
@@ -811,6 +811,8 @@ def load(
             cross_attention_block_config=unet_config.CrossAttentionBlock2DConfig(
                 query_dim=output_channel,
                 cross_dim=config.transformer_cross_attention_dim,
+                hidden_dim=output_channel,
+                output_dim=output_channel,
                 normalization_config=config.transformer_norm_config,
                 attention_config=build_attention_config(
                     num_heads=config.transformer_num_attention_heads,
@@ -877,6 +879,8 @@ def load(
         cross_attention_block_config=unet_config.CrossAttentionBlock2DConfig(
             query_dim=mid_block_channels,
             cross_dim=config.transformer_cross_attention_dim,
+            hidden_dim=mid_block_channels,
+            output_dim=mid_block_channels,
             normalization_config=config.transformer_norm_config,
             attention_config=build_attention_config(
                 num_heads=config.transformer_num_attention_heads,
@@ -950,6 +954,8 @@ def load(
             cross_attention_block_config=unet_config.CrossAttentionBlock2DConfig(
                 query_dim=output_channel,
                 cross_dim=config.transformer_cross_attention_dim,
+                hidden_dim=output_channel,
+                output_dim=output_channel,
                 normalization_config=config.transformer_norm_config,
                 attention_config=build_attention_config(
                     num_heads=config.transformer_num_attention_heads,