Commit

Add cross attention layer and refactor t5_attention to inherit from cross attention (#49)
yichunk authored Jun 12, 2024
1 parent c0187c2 commit 6b0ba07
Showing 3 changed files with 156 additions and 66 deletions.
49 changes: 10 additions & 39 deletions ai_edge_torch/generative/examples/t5/t5_attention.py
@@ -20,6 +20,7 @@
from torch import nn
import torch.nn.functional as F

from ai_edge_torch.generative.layers.attention import CrossAttention
import ai_edge_torch.generative.layers.builder as builder
from ai_edge_torch.generative.layers.kv_cache import KVCache
import ai_edge_torch.generative.layers.model_config as cfg
@@ -122,7 +123,7 @@ def forward(
return hidden_states, position_bias, encoder_decoder_position_bias


class T5Attention(nn.Module):
class T5Attention(CrossAttention):

def __init__(
self,
@@ -138,51 +139,21 @@ def __init__(
Args:
dim (int): causal attention's input/output dimension.
config (cfg.AttentionConfig): attention specific configurations.
norm_config (cfg.NormalizationConfig): normalization config used before attention.
kv_cache_max (int): determines the size of the KV Cache buffer, if enabled.
enable_hlfb (bool): whether hlfb is enabled or not.
has_relative_attention_bias (bool): whether we compute relative bias.
"""
super().__init__()
super().__init__(dim, dim, config, kv_cache_max, enable_hlfb)
self.pre_atten_norm = builder.build_norm(dim, norm_config)

self.has_relative_attention_bias = has_relative_attention_bias
self.relative_attention_num_buckets = config.relative_attention_num_buckets
self.d_model = dim
self.head_dim = dim // config.num_heads
self.n_heads = config.num_heads
self.inner_dim = self.n_heads * self.head_dim

self.q = nn.Linear(self.d_model, self.inner_dim, bias=config.qkv_use_bias)
self.k = nn.Linear(self.d_model, self.inner_dim, bias=config.qkv_use_bias)
self.v = nn.Linear(self.d_model, self.inner_dim, bias=config.qkv_use_bias)
# output projection
self.proj = nn.Linear(
self.inner_dim, self.d_model, bias=config.output_proj_use_bias
)

if self.has_relative_attention_bias:
self.relative_attention_bias = nn.Embedding(
self.relative_attention_num_buckets, self.n_heads
)

self.config = config
self.kv_cache = None
# Build a k/v cache with size (batch_size, kv_cache_max, n_heads, head_dim).
# Now only supports a max batch_size of 1.
if config.enable_kv_cache:
self.kv_cache = KVCache(
1,
kv_cache_max,
config.num_query_groups,
self.head_dim,
enable_hlfb,
)

if enable_hlfb:
self.sdpa_func = scaled_dot_product_attention_with_hlfb
else:
self.sdpa_func = scaled_dot_product_attention

def forward(
self,
x: torch.Tensor,
@@ -206,7 +177,7 @@ def forward(

x = self.pre_atten_norm(x)
B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
query_states = self.q(x)
query_states = self.q_projection(x)
query_states = query_states.reshape(B, T, -1, self.head_dim) # (B, T, nh_q, hs)

if key_value_states is not None:
@@ -217,13 +188,13 @@
) = (
key_value_states.size()
) # batch size, sequence length, embedding dimensionality (n_embd)
key_states = self.k(key_value_states)
value_states = self.v(key_value_states)
key_states = self.k_projection(key_value_states)
value_states = self.v_projection(key_value_states)
key_states = key_states.reshape(kvB, kvT, -1, self.head_dim)
value_states = value_states.reshape(kvB, kvT, -1, self.head_dim)
else:
key_states = self.k(x)
value_states = self.v(x)
key_states = self.k_projection(x)
value_states = self.v_projection(x)
key_states = key_states.reshape(B, T, -1, self.head_dim)
value_states = value_states.reshape(B, T, -1, self.head_dim)

@@ -251,5 +222,5 @@ def forward(
)
y = y.reshape(B, T, C) # re-assemble all head outputs side by side
# output projection
y = self.proj(y)
y = self.output_projection(y)
return y, position_bias
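
In sketch form, the refactor above moves the q/k/v and output projections out of T5Attention and into the CrossAttention base class; the subclass keeps only the T5-specific pieces (the pre-attention norm and the relative attention bias). A standalone illustration of the pattern, using plain torch.nn rather than the library's actual config objects:

import torch
from torch import nn

class CrossAttentionBase(nn.Module):
  # Owns the projections that every attention variant needs.
  def __init__(self, query_dim: int, cross_dim: int, num_heads: int):
    super().__init__()
    self.n_heads = num_heads
    self.head_dim = query_dim // num_heads
    self.q_projection = nn.Linear(query_dim, query_dim, bias=False)
    self.k_projection = nn.Linear(cross_dim, query_dim, bias=False)
    self.v_projection = nn.Linear(cross_dim, query_dim, bias=False)
    self.output_projection = nn.Linear(query_dim, query_dim, bias=False)

class T5AttentionSketch(CrossAttentionBase):
  # Only adds what the base class does not already build.
  def __init__(self, dim: int, num_heads: int, num_buckets: int):
    super().__init__(dim, dim, num_heads)  # query_dim == cross_dim for T5
    self.pre_atten_norm = nn.LayerNorm(dim)  # stand-in for builder.build_norm(dim, norm_config)
    self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)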
123 changes: 113 additions & 10 deletions ai_edge_torch/generative/layers/attention.py
@@ -28,6 +28,33 @@
from ai_edge_torch.generative.layers.scaled_dot_product_attention import scaled_dot_product_attention_with_hlfb # NOQA


def _embed_rope(
q: torch.Tensor,
k: torch.Tensor,
n_elem: int,
rope: Tuple[torch.Tensor, torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Embed rotary positional embedding for query and key.
Args:
q (torch.Tensor): query tensor.
k (torch.Tensor): key tensor.
n_elem (int): number of elements to embed rotary positional embedding.
rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
"""
if n_elem > 0:
cos, sin = rope
q_roped = rotary_pos_emb.apply_rope(
q[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2)
)
k_roped = rotary_pos_emb.apply_rope(
k[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2)
)
q = torch.cat((q_roped, q[..., n_elem:]), dim=-1)
k = torch.cat((k_roped, k[..., n_elem:]), dim=-1)
return q, k


class TransformerBlock(nn.Module):

def __init__(self, config: cfg.ModelConfig) -> None:
@@ -178,16 +205,7 @@ def forward(

# Compute rotary positional embedding for query and key.
n_elem = int(self.config.rotary_percentage * self.head_dim)
if n_elem > 0:
cos, sin = rope
q_roped = rotary_pos_emb.apply_rope(
q[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2)
)
k_roped = rotary_pos_emb.apply_rope(
k[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2)
)
q = torch.cat((q_roped, q[..., n_elem:]), dim=-1)
k = torch.cat((k_roped, k[..., n_elem:]), dim=-1)
q, k = _embed_rope(q, k, n_elem, rope)

if self.kv_cache is not None:
# TODO(haoliang): Handle when exceeding max sequence length.
@@ -224,3 +242,88 @@ def forward(
return super().forward(
x, rope=rope, mask=torch.zeros((B, T), dtype=torch.float32), input_pos=input_pos
)


class CrossAttention(nn.Module):

def __init__(
self,
query_dim: int,
cross_dim: int,
config: cfg.AttentionConfig,
kv_cache_max: int,
enable_hlfb: bool,
) -> None:
"""Initialize an instance of CrossAttention.
Args:
query_dim (int): query tensor's dimension.
cross_dim (int): dimension of the cross-attention input, used for the key and value tensors.
config (cfg.AttentionConfig): attention specific configurations.
kv_cache_max (int): determines the size of the KV Cache buffer, if enabled.
enable_hlfb (bool): whether hlfb is enabled or not.
"""
super().__init__()
self.config = config
self.head_dim = query_dim // config.num_heads
self.n_heads = config.num_heads
self.q_projection = nn.Linear(query_dim, query_dim, bias=config.qkv_use_bias)
self.k_projection = nn.Linear(cross_dim, query_dim, bias=config.qkv_use_bias)
self.v_projection = nn.Linear(cross_dim, query_dim, bias=config.qkv_use_bias)
self.output_projection = nn.Linear(
query_dim, query_dim, bias=config.output_proj_use_bias
)

self.kv_cache = None
# Build a k/v cache with size (batch_size, kv_cache_max, n_heads, head_dim).
# Now only supports a max batch_size of 1.
if config.enable_kv_cache:
self.kv_cache = KVCache(
1,
kv_cache_max,
config.num_query_groups,
self.head_dim,
enable_hlfb,
)

if enable_hlfb:
self.sdpa_func = scaled_dot_product_attention_with_hlfb
else:
self.sdpa_func = scaled_dot_product_attention

def forward(
self,
x: torch.Tensor,
y: torch.Tensor,
rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
mask: Optional[torch.Tensor] = None,
input_pos: Optional[torch.Tensor] = None,
):
B, T, E = x.size() # batch size, sequence length, embedding dimensionality (n_embd)

q = self.q_projection(x)
k = self.k_projection(y)
v = self.v_projection(y)

# Key/value come from the cross input y and keep its sequence length.
q = q.view(B, T, self.n_heads, self.head_dim)
k = k.view(B, y.size(1), self.n_heads, self.head_dim)
v = v.view(B, y.size(1), self.n_heads, self.head_dim)

if mask is None:
mask = torch.zeros((B, T), dtype=torch.float32)

# Compute rotary positional embedding for query and key.
n_elem = int(self.config.rotary_percentage * self.head_dim)
q, k = _embed_rope(q, k, n_elem, rope)

if self.kv_cache is not None:
# TODO(haoliang): Handle when exceeding max sequence length.
k, v = self.kv_cache.update_cache(input_pos, k, v)

y = self.sdpa_func(q, k, v, self.head_dim, mask=mask)
y = y.reshape(B, T, E)

# Compute the output projection.
y = self.output_projection(y)
return y
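
The new CrossAttention layer projects queries from one input (x) and keys/values from a second input (y) that may have a different sequence length, then defers to the selected scaled-dot-product function. The shape contract can be seen with PyTorch's built-in F.scaled_dot_product_attention, used here only to show the tensor shapes; the layer itself calls sdpa_func, which may be the HLFB variant, and passes tensors in (B, T, n_heads, head_dim) layout rather than the (B, n_heads, T, head_dim) layout the built-in expects:

import torch
import torch.nn.functional as F

# Toy sizes: 4 target positions attending over 7 source positions,
# 8 heads of dimension 64.
B, T, S, n_heads, head_dim = 1, 4, 7, 8, 64
q = torch.randn(B, n_heads, T, head_dim)  # q_projection output, reshaped
k = torch.randn(B, n_heads, S, head_dim)  # k_projection output, reshaped
v = torch.randn(B, n_heads, S, head_dim)  # v_projection output, reshaped

# No causal mask: each target position may attend to every source position.
out = F.scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([1, 8, 4, 64]); reassembled back to (B, T, hidden)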
50 changes: 33 additions & 17 deletions ai_edge_torch/generative/utilities/t5_loader.py
@@ -318,7 +318,7 @@ def _map_attention(
q_name = names.attn_query_proj.format(idx)
k_name = names.attn_key_proj.format(idx)
v_name = names.attn_value_proj.format(idx)
# model.encoder.transformer_blocks[0].atten_func.q.weight
# model.encoder.transformer_blocks[0].atten_func.q_projection.weight
if fuse_attention:
converted_state[f"{prefix}.atten_func.attn.weight"] = self._fuse_qkv(
config,
@@ -334,18 +334,34 @@
state.pop(f"{v_name}.bias"),
)
else:
converted_state[f"{prefix}.atten_func.q.weight"] = state.pop(f"{q_name}.weight")
converted_state[f"{prefix}.atten_func.k.weight"] = state.pop(f"{k_name}.weight")
converted_state[f"{prefix}.atten_func.v.weight"] = state.pop(f"{v_name}.weight")
converted_state[f"{prefix}.atten_func.q_projection.weight"] = state.pop(
f"{q_name}.weight"
)
converted_state[f"{prefix}.atten_func.k_projection.weight"] = state.pop(
f"{k_name}.weight"
)
converted_state[f"{prefix}.atten_func.v_projection.weight"] = state.pop(
f"{v_name}.weight"
)
if config.attn_config.qkv_use_bias:
converted_state[f"{prefix}.atten_func.q.bias"] = state.pop(f"{q_name}.bias")
converted_state[f"{prefix}.atten_func.k.bias"] = state.pop(f"{k_name}.bias")
converted_state[f"{prefix}.atten_func.v.bias"] = state.pop(f"{v_name}.bias")
converted_state[f"{prefix}.atten_func.q_projection.bias"] = state.pop(
f"{q_name}.bias"
)
converted_state[f"{prefix}.atten_func.k_projection.bias"] = state.pop(
f"{k_name}.bias"
)
converted_state[f"{prefix}.atten_func.v_projection.bias"] = state.pop(
f"{v_name}.bias"
)

o_name = names.attn_output_proj.format(idx)
converted_state[f"{prefix}.atten_func.proj.weight"] = state.pop(f"{o_name}.weight")
converted_state[f"{prefix}.atten_func.output_projection.weight"] = state.pop(
f"{o_name}.weight"
)
if config.attn_config.output_proj_use_bias:
converted_state[f"{prefix}.atten_func.proj.bias"] = state.pop(f"{o_name}.bias")
converted_state[f"{prefix}.atten_func.output_projection.bias"] = state.pop(
f"{o_name}.bias"
)

def _map_cross_attention(
self,
@@ -383,32 +399,32 @@ def _map_cross_attention(
state.pop(f"{v_name}.bias"),
)
else:
converted_state[f"{prefix}.cross_atten_func.q.weight"] = state.pop(
converted_state[f"{prefix}.cross_atten_func.q_projection.weight"] = state.pop(
f"{q_name}.weight"
)
converted_state[f"{prefix}.cross_atten_func.k.weight"] = state.pop(
converted_state[f"{prefix}.cross_atten_func.k_projection.weight"] = state.pop(
f"{k_name}.weight"
)
converted_state[f"{prefix}.cross_atten_func.v.weight"] = state.pop(
converted_state[f"{prefix}.cross_atten_func.v_projection.weight"] = state.pop(
f"{v_name}.weight"
)
if config.attn_config.qkv_use_bias:
converted_state[f"{prefix}.cross_atten_func.q.bias"] = state.pop(
converted_state[f"{prefix}.cross_atten_func.q_projection.bias"] = state.pop(
f"{q_name}.bias"
)
converted_state[f"{prefix}.cross_atten_func.k.bias"] = state.pop(
converted_state[f"{prefix}.cross_atten_func.k_projection.bias"] = state.pop(
f"{k_name}.bias"
)
converted_state[f"{prefix}.cross_atten_func.v.bias"] = state.pop(
converted_state[f"{prefix}.cross_atten_func.v_projection.bias"] = state.pop(
f"{v_name}.bias"
)

o_name = names.cross_attn_output_proj.format(idx)
converted_state[f"{prefix}.cross_atten_func.proj.weight"] = state.pop(
converted_state[f"{prefix}.cross_atten_func.output_projection.weight"] = state.pop(
f"{o_name}.weight"
)
if config.attn_config.output_proj_use_bias:
converted_state[f"{prefix}.cross_atten_func.proj.bias"] = state.pop(
converted_state[f"{prefix}.cross_atten_func.output_projection.bias"] = state.pop(
f"{o_name}.bias"
)

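The loader changes are purely a renaming exercise: converted state-dict keys must now match the module names that CrossAttention registers (q_projection, k_projection, v_projection, output_projection) instead of the old q/k/v/proj. A toy sketch of the pattern, with hypothetical source key names standing in for the real T5 checkpoint layout:

import torch

# Hypothetical source checkpoint (key names are illustrative only).
state = {
  "block.0.attn.q.weight": torch.randn(512, 512),
  "block.0.attn.k.weight": torch.randn(512, 512),
  "block.0.attn.v.weight": torch.randn(512, 512),
  "block.0.attn.o.weight": torch.randn(512, 512),
}

prefix = "encoder.transformer_blocks.0"
converted_state = {}
# pop() moves each tensor and leaves the source dict empty once every key
# has been mapped, so leftover (unmapped) keys are easy to spot.
converted_state[f"{prefix}.atten_func.q_projection.weight"] = state.pop("block.0.attn.q.weight")
converted_state[f"{prefix}.atten_func.k_projection.weight"] = state.pop("block.0.attn.k.weight")
converted_state[f"{prefix}.atten_func.v_projection.weight"] = state.pop("block.0.attn.v.weight")
converted_state[f"{prefix}.atten_func.output_projection.weight"] = state.pop("block.0.attn.o.weight")
assert not state  # everything consumed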