Commit e50f7a7
Update CrossAttention config to have hidden_dim and output_dim.
PiperOrigin-RevId: 675257767
ai-edge-bot authored and copybara-github committed Sep 16, 2024
1 parent: 28ce35b · commit: e50f7a7
Showing 5 changed files with 24 additions and 4 deletions.
@@ -336,6 +336,8 @@ def __init__(self, config: unet_cfg.DiffusionModelConfig):
           cross_attention_block_config=unet_cfg.CrossAttentionBlock2DConfig(
               query_dim=output_channel,
               cross_dim=config.transformer_cross_attention_dim,
+              hidden_dim=output_channel,
+              output_dim=output_channel,
               attention_batch_size=config.transformer_batch_size,
               normalization_config=config.transformer_norm_config,
               attention_config=build_attention_config(
@@ -406,6 +408,8 @@ def __init__(self, config: unet_cfg.DiffusionModelConfig):
           cross_attention_block_config=unet_cfg.CrossAttentionBlock2DConfig(
               query_dim=mid_block_channels,
               cross_dim=config.transformer_cross_attention_dim,
+              hidden_dim=mid_block_channels,
+              output_dim=mid_block_channels,
               attention_batch_size=config.transformer_batch_size,
               normalization_config=config.transformer_norm_config,
               attention_config=build_attention_config(
@@ -477,6 +481,8 @@ def __init__(self, config: unet_cfg.DiffusionModelConfig):
           cross_attention_block_config=unet_cfg.CrossAttentionBlock2DConfig(
               query_dim=output_channel,
               cross_dim=config.transformer_cross_attention_dim,
+              hidden_dim=output_channel,
+              output_dim=output_channel,
               attention_batch_size=config.transformer_batch_size,
               normalization_config=config.transformer_norm_config,
               attention_config=build_attention_config(
12 changes: 8 additions & 4 deletions ai_edge_torch/generative/layers/attention.py
@@ -298,6 +298,8 @@ def __init__(
       batch_size: int,
       query_dim: int,
       cross_dim: int,
+      hidden_dim: int,
+      output_dim: int,
       config: cfg.AttentionConfig,
       enable_hlfb: bool,
   ):
@@ -307,23 +309,25 @@ def __init__(
       batch_size (int): batch size of the input tensor.
       query_dim (int): query tensor's dimension.
       cross_dim (int): cross attention's dimensions, for key and value tensors.
+      hidden_dim (int): hidden dimension that q, k, v tensors project to.
+      output_dim (int): output tensor's dimension.
       config (cfg.AttentionConfig): attention specific configurations.
       enable_hlfb (bool): whether hlfb is enabled or not.
     """
     super().__init__()
     self.config = config
     self.n_heads = config.num_heads
     self.q_projection = nn.Linear(
-        query_dim, query_dim, bias=config.qkv_use_bias
+        query_dim, hidden_dim, bias=config.qkv_use_bias
     )
     self.k_projection = nn.Linear(
-        cross_dim, query_dim, bias=config.qkv_use_bias
+        cross_dim, hidden_dim, bias=config.qkv_use_bias
     )
     self.v_projection = nn.Linear(
-        cross_dim, query_dim, bias=config.qkv_use_bias
+        cross_dim, hidden_dim, bias=config.qkv_use_bias
     )
     self.output_projection = nn.Linear(
-        query_dim, query_dim, bias=config.output_proj_use_bias
+        hidden_dim, output_dim, bias=config.output_proj_use_bias
     )

     self.sdpa_func = (
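
Note on the change above: q now projects query_dim to hidden_dim, k and v project cross_dim to hidden_dim, and the output projection maps hidden_dim to output_dim, so the attention's inner width is decoupled from both the query and output widths. A minimal self-contained sketch of that geometry follows; it is not the ai_edge_torch implementation (the real CrossAttention also threads AttentionConfig biases and an HLFB-aware SDPA backend), and the sizes are invented for the shape check.

    # Self-contained sketch of the post-commit projection geometry; plain
    # F.scaled_dot_product_attention stands in for the configurable backend.
    import torch
    from torch import nn
    import torch.nn.functional as F


    class CrossAttentionSketch(nn.Module):

      def __init__(self, query_dim, cross_dim, hidden_dim, output_dim, n_heads):
        super().__init__()
        # hidden_dim is split across heads, so it must divide evenly.
        assert hidden_dim % n_heads == 0
        self.n_heads, self.head_dim = n_heads, hidden_dim // n_heads
        # Before this change, all four widths below were tied to query_dim.
        self.q_projection = nn.Linear(query_dim, hidden_dim)
        self.k_projection = nn.Linear(cross_dim, hidden_dim)
        self.v_projection = nn.Linear(cross_dim, hidden_dim)
        self.output_projection = nn.Linear(hidden_dim, output_dim)

      def forward(self, x, context):
        b, t, _ = x.shape                         # x: (b, t, query_dim)
        s = context.shape[1]                      # context: (b, s, cross_dim)
        heads = lambda y, n: y.view(b, n, self.n_heads, self.head_dim).transpose(1, 2)
        q = heads(self.q_projection(x), t)        # (b, heads, t, head_dim)
        k = heads(self.k_projection(context), s)  # (b, heads, s, head_dim)
        v = heads(self.v_projection(context), s)
        out = F.scaled_dot_product_attention(q, k, v)
        out = out.transpose(1, 2).reshape(b, t, -1)   # (b, t, hidden_dim)
        return self.output_projection(out)            # (b, t, output_dim)


    # Hidden width (128) decoupled from query (64) and output (96) widths:
    attn = CrossAttentionSketch(query_dim=64, cross_dim=768, hidden_dim=128,
                                output_dim=96, n_heads=8)
    assert attn(torch.randn(2, 10, 64), torch.randn(2, 77, 768)).shape == (2, 10, 96)
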
2 changes: 2 additions & 0 deletions ai_edge_torch/generative/layers/unet/blocks_2d.py
@@ -178,6 +178,8 @@ def __init__(self, config: unet_cfg.CrossAttentionBlock2DConfig):
         config.attention_batch_size,
         config.query_dim,
         config.cross_dim,
+        config.hidden_dim,
+        config.output_dim,
         config.attention_config,
         enable_hlfb=config.enable_hlfb,
     )
2 changes: 2 additions & 0 deletions ai_edge_torch/generative/layers/unet/model_config.py
@@ -68,6 +68,8 @@ class AttentionBlock2DConfig:
 class CrossAttentionBlock2DConfig:
   query_dim: int
   cross_dim: int
+  hidden_dim: int
+  output_dim: int
   normalization_config: layers_cfg.NormalizationConfig
   attention_config: layers_cfg.AttentionConfig
   enable_hlfb: bool = True
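
The two new fields have no defaults, so every construction site must now supply them explicitly. A standalone sketch of the pattern this commit uses at its call sites; the mirror dataclass below stands in for unet_cfg.CrossAttentionBlock2DConfig so the snippet runs as-is, with plain strings in place of the layers_cfg config objects the real fields take.

    # Illustrative only: a mirror of the updated dataclass and a typical fill.
    from dataclasses import dataclass


    @dataclass
    class CrossAttentionBlock2DConfigSketch:
      query_dim: int
      cross_dim: int
      hidden_dim: int   # new, required: inner width of the q/k/v projections
      output_dim: int   # new, required: width of the output projection
      normalization_config: str
      attention_config: str
      enable_hlfb: bool = True


    output_channel = 320  # hypothetical width of one UNet stage
    cfg = CrossAttentionBlock2DConfigSketch(
        query_dim=output_channel,
        cross_dim=768,               # e.g. a text-encoder embedding width
        hidden_dim=output_channel,   # this commit ties both new dims to the
        output_dim=output_channel,   # query width, keeping the old geometry
        normalization_config="<norm config placeholder>",
        attention_config="<attention config placeholder>",
    )
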
6 changes: 6 additions & 0 deletions ai_edge_torch/generative/utilities/stable_diffusion_loader.py
@@ -811,6 +811,8 @@ def load(
           cross_attention_block_config=unet_config.CrossAttentionBlock2DConfig(
               query_dim=output_channel,
               cross_dim=config.transformer_cross_attention_dim,
+              hidden_dim=output_channel,
+              output_dim=output_channel,
               normalization_config=config.transformer_norm_config,
               attention_config=build_attention_config(
                   num_heads=config.transformer_num_attention_heads,
@@ -877,6 +879,8 @@ def load(
           cross_attention_block_config=unet_config.CrossAttentionBlock2DConfig(
               query_dim=mid_block_channels,
               cross_dim=config.transformer_cross_attention_dim,
+              hidden_dim=mid_block_channels,
+              output_dim=mid_block_channels,
               normalization_config=config.transformer_norm_config,
               attention_config=build_attention_config(
                   num_heads=config.transformer_num_attention_heads,
@@ -950,6 +954,8 @@ def load(
           cross_attention_block_config=unet_config.CrossAttentionBlock2DConfig(
               query_dim=output_channel,
               cross_dim=config.transformer_cross_attention_dim,
+              hidden_dim=output_channel,
+              output_dim=output_channel,
               normalization_config=config.transformer_norm_config,
               attention_config=build_attention_config(
                   num_heads=config.transformer_num_attention_heads,
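
All six configuration sites in this commit (three in the diffusion model builder, three in the loader above) tie hidden_dim and output_dim to the block's existing query width, so parameter shapes match the pre-change module and previously converted stable-diffusion weights should still load. A quick illustrative check with plain nn.Linear stand-ins for the real projections:

    # Shape-equivalence check (example widths are made up). With
    # hidden_dim = output_dim = query_dim, the new projections have exactly
    # the shapes the old code produced.
    from torch import nn

    query_dim, cross_dim = 320, 768
    hidden_dim = output_dim = query_dim

    old_q = nn.Linear(query_dim, query_dim)    # pre-change: everything was query_dim
    old_k = nn.Linear(cross_dim, query_dim)
    new_q = nn.Linear(query_dim, hidden_dim)   # post-change q projection
    new_k = nn.Linear(cross_dim, hidden_dim)   # post-change k (and v) projection
    new_out = nn.Linear(hidden_dim, output_dim)

    assert old_q.weight.shape == new_q.weight.shape == new_out.weight.shape
    assert old_k.weight.shape == new_k.weight.shape
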
