diff --git a/ai_edge_torch/generative/examples/t5/t5_attention.py b/ai_edge_torch/generative/examples/t5/t5_attention.py
index 378a6852..9ef99237 100644
--- a/ai_edge_torch/generative/examples/t5/t5_attention.py
+++ b/ai_edge_torch/generative/examples/t5/t5_attention.py
@@ -27,6 +27,8 @@
 from ai_edge_torch.generative.layers.scaled_dot_product_attention import scaled_dot_product_attention  # NOQA
 from ai_edge_torch.generative.layers.scaled_dot_product_attention import scaled_dot_product_attention_with_hlfb  # NOQA
 
+BATCH_SIZE = 1
+
 
 class EncoderDecoderBlock(nn.Module):
 
@@ -44,6 +46,7 @@ def __init__(
 
     super().__init__()
     self.atten_func = T5Attention(
+        BATCH_SIZE,
         config.embedding_dim,
         config.attn_config,
         config.pre_attention_norm_config,
@@ -54,6 +57,7 @@ def __init__(
     # For a decoder, we add a cross attention.
     if config.is_decoder:
       self.cross_atten_func = T5Attention(
+          BATCH_SIZE,
           config.embedding_dim,
           config.attn_config,
           config.pre_attention_norm_config,
@@ -127,6 +131,7 @@ class T5Attention(CrossAttention):
 
   def __init__(
       self,
+      batch: int,
       dim: int,
       config: cfg.AttentionConfig,
       norm_config: cfg.NormalizationConfig,
@@ -144,7 +149,7 @@ def __init__(
       enable_hlfb (bool): whether hlfb is enabled or not.
       has_relative_attention_bias (bool): whether we compute relative bias.
     """
-    super().__init__(dim, dim, config, kv_cache_max, enable_hlfb)
+    super().__init__(batch, dim, dim, config, kv_cache_max, enable_hlfb)
     self.pre_atten_norm = builder.build_norm(dim, norm_config)
     self.has_relative_attention_bias = has_relative_attention_bias
 
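
With this change, T5Attention takes the batch size as its first constructor argument and forwards it to its CrossAttention parent, while EncoderDecoderBlock supplies the module-level BATCH_SIZE constant. Below is a minimal, self-contained sketch of that argument threading; the class bodies, the *Sketch names, and the example dimension (768) are illustrative stand-ins, not the real CrossAttention or T5Attention signatures.

    # Simplified sketch of how the new `batch` argument flows through the
    # constructors in this patch. Names mirror the diff; bodies are stand-ins.

    BATCH_SIZE = 1  # same module-level constant the patch introduces


    class ParentAttentionSketch:
      """Stand-in for CrossAttention: records the batch dimension it is given."""

      def __init__(self, batch: int, query_dim: int, kv_dim: int):
        # The real CrossAttention presumably sizes per-batch state (e.g. a KV
        # cache) from `batch`; here we only store it.
        self.batch = batch
        self.query_dim = query_dim
        self.kv_dim = kv_dim


    class T5AttentionSketch(ParentAttentionSketch):
      """Stand-in for T5Attention: forwards `batch` first, then dim twice."""

      def __init__(self, batch: int, dim: int):
        super().__init__(batch, dim, dim)


    # Mirrors EncoderDecoderBlock passing BATCH_SIZE as the first argument.
    attn = T5AttentionSketch(BATCH_SIZE, 768)  # 768 is an arbitrary example dim
    print(attn.batch, attn.query_dim, attn.kv_dim)  # -> 1 768 768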