From 6b0ba07af6db766965d19bade584ff139e728247 Mon Sep 17 00:00:00 2001
From: yichunkuo
Date: Tue, 11 Jun 2024 20:10:06 -0700
Subject: [PATCH] Add cross attention layer and refactor t5_attention to inherit from cross attention (#49)

---
 .../generative/examples/t5/t5_attention.py    |  49 ++-----
 ai_edge_torch/generative/layers/attention.py  | 123 ++++++++++++++++--
 .../generative/utilities/t5_loader.py         |  50 ++++---
 3 files changed, 156 insertions(+), 66 deletions(-)

diff --git a/ai_edge_torch/generative/examples/t5/t5_attention.py b/ai_edge_torch/generative/examples/t5/t5_attention.py
index 08ea16d7..378a6852 100644
--- a/ai_edge_torch/generative/examples/t5/t5_attention.py
+++ b/ai_edge_torch/generative/examples/t5/t5_attention.py
@@ -20,6 +20,7 @@
 from torch import nn
 import torch.nn.functional as F
 
+from ai_edge_torch.generative.layers.attention import CrossAttention
 import ai_edge_torch.generative.layers.builder as builder
 from ai_edge_torch.generative.layers.kv_cache import KVCache
 import ai_edge_torch.generative.layers.model_config as cfg
@@ -122,7 +123,7 @@ def forward(
     return hidden_states, position_bias, encoder_decoder_position_bias
 
 
-class T5Attention(nn.Module):
+class T5Attention(CrossAttention):
 
   def __init__(
       self,
@@ -138,51 +139,21 @@ def __init__(
     Args:
       dim (int): causal attention's input/output dimmension.
      config (cfg.AttentionConfig): attention specific configurations.
+      norm_config (cfg.NormalizationConfig): normalization config used before attention.
       kv_cache_max (int): determines the size of the KV Cache buffer, if enabled.
       enable_hlfb (bool): whether hlfb is enabled or not.
       has_relative_attention_bias (bool): whether we compute relative bias.
     """
-    super().__init__()
+    super().__init__(dim, dim, config, kv_cache_max, enable_hlfb)
     self.pre_atten_norm = builder.build_norm(dim, norm_config)
 
     self.has_relative_attention_bias = has_relative_attention_bias
     self.relative_attention_num_buckets = config.relative_attention_num_buckets
-    self.d_model = dim
-    self.head_dim = dim // config.num_heads
-    self.n_heads = config.num_heads
-    self.inner_dim = self.n_heads * self.head_dim
-
-    self.q = nn.Linear(self.d_model, self.inner_dim, bias=config.qkv_use_bias)
-    self.k = nn.Linear(self.d_model, self.inner_dim, bias=config.qkv_use_bias)
-    self.v = nn.Linear(self.d_model, self.inner_dim, bias=config.qkv_use_bias)
-    # output projection
-    self.proj = nn.Linear(
-        self.inner_dim, self.d_model, bias=config.output_proj_use_bias
-    )
-
     if self.has_relative_attention_bias:
       self.relative_attention_bias = nn.Embedding(
           self.relative_attention_num_buckets, self.n_heads
       )
 
-    self.config = config
-    self.kv_cache = None
-    # Build a k/v cache with size (batch_size, kv_cache_max, n_heads, head_dim).
-    # Now only supports a max batch_size of 1.
-    if config.enable_kv_cache:
-      self.kv_cache = KVCache(
-          1,
-          kv_cache_max,
-          config.num_query_groups,
-          self.head_dim,
-          enable_hlfb,
-      )
-
-    if enable_hlfb:
-      self.sdpa_func = scaled_dot_product_attention_with_hlfb
-    else:
-      self.sdpa_func = scaled_dot_product_attention
-
   def forward(
       self,
       x: torch.Tensor,
@@ -206,7 +177,7 @@ def forward(
 
     x = self.pre_atten_norm(x)
     B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)
-    query_states = self.q(x)
+    query_states = self.q_projection(x)
     query_states = query_states.reshape(B, T, -1, self.head_dim)  # (B, T, nh_q, hs)
 
     if key_value_states is not None:
@@ -217,13 +188,13 @@ def forward(
       (
           kvB,
           kvT,
          kvC,
       ) = (
           key_value_states.size()
       )  # batch size, sequence length, embedding dimensionality (n_embd)
-      key_states = self.k(key_value_states)
-      value_states = self.v(key_value_states)
+      key_states = self.k_projection(key_value_states)
+      value_states = self.v_projection(key_value_states)
       key_states = key_states.reshape(kvB, kvT, -1, self.head_dim)
       value_states = value_states.reshape(kvB, kvT, -1, self.head_dim)
     else:
-      key_states = self.k(x)
-      value_states = self.v(x)
+      key_states = self.k_projection(x)
+      value_states = self.v_projection(x)
       key_states = key_states.reshape(B, T, -1, self.head_dim)
       value_states = value_states.reshape(B, T, -1, self.head_dim)
 
@@ -251,5 +222,5 @@ def forward(
     )
     y = y.reshape(B, T, C)  # re-assemble all head outputs side by side
     # output projection
-    y = self.proj(y)
+    y = self.output_projection(y)
     return y, position_bias
diff --git a/ai_edge_torch/generative/layers/attention.py b/ai_edge_torch/generative/layers/attention.py
index e3b5ecd6..80aea87a 100644
--- a/ai_edge_torch/generative/layers/attention.py
+++ b/ai_edge_torch/generative/layers/attention.py
@@ -28,6 +28,33 @@
 from ai_edge_torch.generative.layers.scaled_dot_product_attention import scaled_dot_product_attention_with_hlfb  # NOQA
 
 
+def _embed_rope(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    n_elem: int,
+    rope: Tuple[torch.Tensor, torch.Tensor],
+) -> Tuple[torch.Tensor, torch.Tensor]:
+  """Embed rotary positional embedding for query and key.
+
+  Args:
+    q (torch.Tensor): query tensor.
+    k (torch.Tensor): key tensor.
+    n_elem (int): number of elements to embed rotary positional embedding.
+    rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
+  """
+  if n_elem > 0:
+    cos, sin = rope
+    q_roped = rotary_pos_emb.apply_rope(
+        q[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2)
+    )
+    k_roped = rotary_pos_emb.apply_rope(
+        k[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2)
+    )
+    q = torch.cat((q_roped, q[..., n_elem:]), dim=-1)
+    k = torch.cat((k_roped, k[..., n_elem:]), dim=-1)
+  return q, k
+
+
 class TransformerBlock(nn.Module):
 
   def __init__(self, config: cfg.ModelConfig) -> None:
@@ -178,16 +205,7 @@ def forward(
 
     # Compute rotary positional embedding for query and key.
     n_elem = int(self.config.rotary_percentage * self.head_dim)
-    if n_elem > 0:
-      cos, sin = rope
-      q_roped = rotary_pos_emb.apply_rope(
-          q[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2)
-      )
-      k_roped = rotary_pos_emb.apply_rope(
-          k[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2)
-      )
-      q = torch.cat((q_roped, q[..., n_elem:]), dim=-1)
-      k = torch.cat((k_roped, k[..., n_elem:]), dim=-1)
+    q, k = _embed_rope(q, k, n_elem, rope)
 
     if self.kv_cache is not None:
       # TODO(haoliang): Handle when execeeding max sequence length.
@@ -224,3 +242,88 @@ def forward(
     return super().forward(
         x, rope=rope, mask=torch.zeros((B, T), dtype=torch.float32), input_pos=input_pos
     )
+
+
+class CrossAttention(nn.Module):
+
+  def __init__(
+      self,
+      query_dim: int,
+      cross_dim: int,
+      config: cfg.AttentionConfig,
+      kv_cache_max: int,
+      enable_hlfb: bool,
+  ) -> None:
+    """Initialize an instance of CrossAttention.
+
+    Args:
+      query_dim (int): query tensor's dimension.
+      cross_dim (int): cross attention's dimension, for key and value tensors.
+      config (cfg.AttentionConfig): attention specific configurations.
+      kv_cache_max (int): determines the size of the KV Cache buffer, if enabled.
+      enable_hlfb (bool): whether hlfb is enabled or not.
+    """
+    super().__init__()
+    self.config = config
+    self.head_dim = query_dim // config.num_heads
+    self.n_heads = config.num_heads
+    self.q_projection = nn.Linear(query_dim, query_dim, bias=config.qkv_use_bias)
+    self.k_projection = nn.Linear(cross_dim, query_dim, bias=config.qkv_use_bias)
+    self.v_projection = nn.Linear(cross_dim, query_dim, bias=config.qkv_use_bias)
+    self.output_projection = nn.Linear(
+        query_dim, query_dim, bias=config.output_proj_use_bias
+    )
+
+    self.kv_cache = None
+    # Build a k/v cache with size (batch_size, kv_cache_max, n_heads, head_dim).
+    # Now only supports a max batch_size of 1.
+    if config.enable_kv_cache:
+      self.kv_cache = KVCache(
+          1,
+          kv_cache_max,
+          config.num_query_groups,
+          self.head_dim,
+          enable_hlfb,
+      )
+
+    if enable_hlfb:
+      self.sdpa_func = scaled_dot_product_attention_with_hlfb
+    else:
+      self.sdpa_func = scaled_dot_product_attention
+
+  def forward(
+      self,
+      x: torch.Tensor,
+      y: torch.Tensor,
+      rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+      mask: Optional[torch.Tensor] = None,
+      input_pos: Optional[torch.Tensor] = None,
+  ):
+    B, T, E = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)
+
+    q = self.q_projection(x)
+    k = self.k_projection(y)
+    v = self.v_projection(y)
+
+    interim_shape = (B, -1, self.n_heads, self.head_dim)
+    q = q.view(interim_shape)
+    k = k.view(interim_shape)
+    v = v.view(interim_shape)
+
+    if mask is None:
+      mask = torch.zeros((B, T), dtype=torch.float32)
+
+    # Compute rotary positional embedding for query and key.
+    n_elem = int(self.config.rotary_percentage * self.head_dim)
+    q, k = _embed_rope(q, k, n_elem, rope)
+
+    if self.kv_cache is not None:
+      # TODO(haoliang): Handle when exceeding max sequence length.
+      k, v = self.kv_cache.update_cache(input_pos, k, v)
+
+    y = self.sdpa_func(q, k, v, self.head_dim, mask=mask)
+    y = y.reshape(B, T, E)
+
+    # Compute the output projection.
+    y = self.output_projection(y)
+    return y
diff --git a/ai_edge_torch/generative/utilities/t5_loader.py b/ai_edge_torch/generative/utilities/t5_loader.py
index 54aab3e1..5a826da6 100644
--- a/ai_edge_torch/generative/utilities/t5_loader.py
+++ b/ai_edge_torch/generative/utilities/t5_loader.py
@@ -318,7 +318,7 @@ def _map_attention(
     q_name = names.attn_query_proj.format(idx)
     k_name = names.attn_key_proj.format(idx)
     v_name = names.attn_value_proj.format(idx)
-    # model.encoder.transformer_blocks[0].atten_func.q.weight
+    # model.encoder.transformer_blocks[0].atten_func.q_projection.weight
     if fuse_attention:
       converted_state[f"{prefix}.atten_func.attn.weight"] = self._fuse_qkv(
           config,
@@ -334,18 +334,34 @@ def _map_attention(
           state.pop(f"{v_name}.bias"),
       )
     else:
-      converted_state[f"{prefix}.atten_func.q.weight"] = state.pop(f"{q_name}.weight")
-      converted_state[f"{prefix}.atten_func.k.weight"] = state.pop(f"{k_name}.weight")
-      converted_state[f"{prefix}.atten_func.v.weight"] = state.pop(f"{v_name}.weight")
+      converted_state[f"{prefix}.atten_func.q_projection.weight"] = state.pop(
+          f"{q_name}.weight"
+      )
+      converted_state[f"{prefix}.atten_func.k_projection.weight"] = state.pop(
+          f"{k_name}.weight"
+      )
+      converted_state[f"{prefix}.atten_func.v_projection.weight"] = state.pop(
+          f"{v_name}.weight"
+      )
       if config.attn_config.qkv_use_bias:
-        converted_state[f"{prefix}.atten_func.q.bias"] = state.pop(f"{q_name}.bias")
-        converted_state[f"{prefix}.atten_func.k.bias"] = state.pop(f"{k_name}.bias")
-        converted_state[f"{prefix}.atten_func.v.bias"] = state.pop(f"{v_name}.bias")
+        converted_state[f"{prefix}.atten_func.q_projection.bias"] = state.pop(
+            f"{q_name}.bias"
+        )
+        converted_state[f"{prefix}.atten_func.k_projection.bias"] = state.pop(
+            f"{k_name}.bias"
+        )
+        converted_state[f"{prefix}.atten_func.v_projection.bias"] = state.pop(
+            f"{v_name}.bias"
+        )
 
     o_name = names.attn_output_proj.format(idx)
-    converted_state[f"{prefix}.atten_func.proj.weight"] = state.pop(f"{o_name}.weight")
+    converted_state[f"{prefix}.atten_func.output_projection.weight"] = state.pop(
+        f"{o_name}.weight"
+    )
     if config.attn_config.output_proj_use_bias:
-      converted_state[f"{prefix}.atten_func.proj.bias"] = state.pop(f"{o_name}.bias")
+      converted_state[f"{prefix}.atten_func.output_projection.bias"] = state.pop(
+          f"{o_name}.bias"
+      )
 
   def _map_cross_attention(
       self,
@@ -383,32 +399,32 @@ def _map_cross_attention(
           state.pop(f"{v_name}.bias"),
       )
     else:
-      converted_state[f"{prefix}.cross_atten_func.q.weight"] = state.pop(
+      converted_state[f"{prefix}.cross_atten_func.q_projection.weight"] = state.pop(
          f"{q_name}.weight"
      )
-      converted_state[f"{prefix}.cross_atten_func.k.weight"] = state.pop(
+      converted_state[f"{prefix}.cross_atten_func.k_projection.weight"] = state.pop(
          f"{k_name}.weight"
      )
-      converted_state[f"{prefix}.cross_atten_func.v.weight"] = state.pop(
+      converted_state[f"{prefix}.cross_atten_func.v_projection.weight"] = state.pop(
          f"{v_name}.weight"
      )
       if config.attn_config.qkv_use_bias:
-        converted_state[f"{prefix}.cross_atten_func.q.bias"] = state.pop(
+        converted_state[f"{prefix}.cross_atten_func.q_projection.bias"] = state.pop(
            f"{q_name}.bias"
        )
-        converted_state[f"{prefix}.cross_atten_func.k.bias"] = state.pop(
+        converted_state[f"{prefix}.cross_atten_func.k_projection.bias"] = state.pop(
            f"{k_name}.bias"
        )
-        converted_state[f"{prefix}.cross_atten_func.v.bias"] = state.pop(
+        converted_state[f"{prefix}.cross_atten_func.v_projection.bias"] = state.pop(
            f"{v_name}.bias"
        )
 
     o_name = names.cross_attn_output_proj.format(idx)
-    converted_state[f"{prefix}.cross_atten_func.proj.weight"] = state.pop(
+    converted_state[f"{prefix}.cross_atten_func.output_projection.weight"] = state.pop(
        f"{o_name}.weight"
    )
     if config.attn_config.output_proj_use_bias:
-      converted_state[f"{prefix}.cross_atten_func.proj.bias"] = state.pop(
+      converted_state[f"{prefix}.cross_atten_func.output_projection.bias"] = state.pop(
          f"{o_name}.bias"
      )
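
Usage note (reviewer sketch, not part of the patch): the snippet below is a minimal smoke test of the new CrossAttention layer introduced above. It is a sketch under stated assumptions: a types.SimpleNamespace stands in for cfg.AttentionConfig and carries only the attributes the layer actually reads in __init__/forward (num_heads, num_query_groups, qkv_use_bias, output_proj_use_bias, enable_kv_cache, rotary_percentage); real callers would build a cfg.AttentionConfig, and all tensor sizes are arbitrary example values.

# Illustrative sketch only -- not part of the patch.
from types import SimpleNamespace

import torch

from ai_edge_torch.generative.layers.attention import CrossAttention

# Stand-in for cfg.AttentionConfig with just the fields CrossAttention reads.
config = SimpleNamespace(
    num_heads=8,
    num_query_groups=8,
    qkv_use_bias=False,
    output_proj_use_bias=False,
    enable_kv_cache=False,  # skip the KV cache path in this sketch
    rotary_percentage=0.0,  # n_elem == 0, so passing rope=None is fine
)

# query_dim=128 for the target states, cross_dim=64 for the source states.
layer = CrossAttention(128, 64, config, kv_cache_max=0, enable_hlfb=False)

x = torch.randn(1, 16, 128)  # target / query sequence
y = torch.randn(1, 16, 64)   # source / key-value sequence
out = layer(x, y)
print(out.shape)  # torch.Size([1, 16, 128])

Because T5Attention now inherits from CrossAttention, the same q_projection / k_projection / v_projection / output_projection parameter names apply to it as well, which is what the t5_loader changes above map the checkpoint tensors onto.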