Call CrossAttention __init__ with batch dim (#83)
talumbau authored Jul 10, 2024
1 parent 68752fe · commit 70de3be
Showing 1 changed file with 6 additions and 1 deletion.
ai_edge_torch/generative/examples/t5/t5_attention.py (6 additions, 1 deletion)
@@ -27,6 +27,8 @@
 from ai_edge_torch.generative.layers.scaled_dot_product_attention import scaled_dot_product_attention  # NOQA
 from ai_edge_torch.generative.layers.scaled_dot_product_attention import scaled_dot_product_attention_with_hlfb  # NOQA
 
+BATCH_SIZE = 1
+
 
 class EncoderDecoderBlock(nn.Module):
 
@@ -44,6 +46,7 @@ def __init__(
 
     super().__init__()
     self.atten_func = T5Attention(
+        BATCH_SIZE,
        config.embedding_dim,
        config.attn_config,
        config.pre_attention_norm_config,
@@ -54,6 +57,7 @@ def __init__(
    # For a decoder, we add a cross attention.
    if config.is_decoder:
      self.cross_atten_func = T5Attention(
+         BATCH_SIZE,
          config.embedding_dim,
          config.attn_config,
          config.pre_attention_norm_config,
@@ -127,6 +131,7 @@ class T5Attention(CrossAttention):
 
  def __init__(
      self,
+     batch: int,
      dim: int,
      config: cfg.AttentionConfig,
      norm_config: cfg.NormalizationConfig,
@@ -144,7 +149,7 @@ def __init__(
      enable_hlfb (bool): whether hlfb is enabled or not.
      has_relative_attention_bias (bool): whether we compute relative bias.
    """
-    super().__init__(dim, dim, config, kv_cache_max, enable_hlfb)
+    super().__init__(batch, dim, dim, config, kv_cache_max, enable_hlfb)
    self.pre_atten_norm = builder.build_norm(dim, norm_config)
 
    self.has_relative_attention_bias = has_relative_attention_bias
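
To make the fix concrete, here is a minimal runnable sketch of the call path this commit changes. The stub classes below are illustrative stand-ins, not the library's code: only the argument order of super().__init__ is taken from the diff above, and the idea that a fixed batch dimension lets the layer pre-size internal state is an assumption.

# Illustrative sketch only; stub types stand in for the real
# ai_edge_torch layers. Argument order in super().__init__ matches
# the diff above; everything else is assumed for illustration.

class CrossAttention:
  def __init__(self, batch, query_dim, cross_dim, config, kv_cache_max, enable_hlfb):
    # A known batch size lets the layer pre-allocate state (e.g. a
    # KV cache) with a fixed leading dimension. (Assumption.)
    self.batch = batch

class T5Attention(CrossAttention):
  def __init__(self, batch, dim, config, norm_config, kv_cache_max,
               enable_hlfb, has_relative_attention_bias=False):
    # Before this commit the call was
    #   super().__init__(dim, dim, config, kv_cache_max, enable_hlfb)
    # which passed dim where CrossAttention expects the batch dimension.
    super().__init__(batch, dim, dim, config, kv_cache_max, enable_hlfb)

BATCH_SIZE = 1
attn = T5Attention(BATCH_SIZE, dim=512, config=None, norm_config=None,
                   kv_cache_max=1024, enable_hlfb=False)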
