Skip to content

Commit

Permalink
Build RoPE cos, sin tensors on demand
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 706805762
  • Loading branch information
talumbau authored and copybara-github committed Dec 19, 2024
1 parent 6abeb94 commit 4d49b54
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 86 deletions.
29 changes: 14 additions & 15 deletions ai_edge_torch/generative/examples/gemma/gemma2.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from ai_edge_torch.generative.layers import kv_cache as kv_utils
import ai_edge_torch.generative.layers.attention_utils as attn_utils
import ai_edge_torch.generative.layers.model_config as cfg
import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
from ai_edge_torch.generative.utilities import model_builder
import ai_edge_torch.generative.utilities.loader as loading_utils
import torch
Expand Down Expand Up @@ -103,17 +104,12 @@ def __init__(self, config: cfg.ModelConfig):
config.embedding_dim,
config.final_norm_config,
)
# Gemma2 has same hyper parameters for each layer except for attention
# types. Use the first layer.
attn_config = config.block_config(0).attn_config
self.rope_cache = attn_utils.build_rope_cache(
size=config.kv_cache_max,
dim=int(attn_config.rotary_percentage * attn_config.head_dim),
base=attn_config.rotary_base,
)
self.mask_cache = attn_utils.build_causal_mask_cache(
size=config.kv_cache_max,
)
# Gemma2 has same hyper parameters for each layer except for attention
# types. Use the first layer.
attn_config = config.block_config(0).attn_config
self.sliding_window_mask_cache = attn_utils.build_sliding_window_mask_cache(
size=config.kv_cache_max,
window_size=attn_config.sliding_window_size,
Expand Down Expand Up @@ -145,24 +141,27 @@ def forward(
" must be the same."
)

cos, sin = self.rope_cache
cos = cos.index_select(0, input_pos)
sin = sin.index_select(0, input_pos)
# RoPE parameters are the same for all blocks. Use the first layer.
attn_config = self.config.block_config(0).attn_config
n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
rope = rotary_pos_emb.build_rope(
input_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
)

# token embeddings of shape (b, t, n_embd)
x = self.tok_embedding(tokens)
x = x * (self.config.embedding_dim**0.5)

updated_kv_entires = []
updated_kv_entries = []
for i, block in enumerate(self.transformer_blocks):
mask = self.get_attention_mask(
block.config.attn_config.attn_type, input_pos
)
kv_entry = kv_cache.caches[i] if kv_cache else None
x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
x, kv_entry = block(x, rope, mask, input_pos, kv_entry)
if kv_entry:
updated_kv_entires.append(kv_entry)
updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
updated_kv_entries.append(kv_entry)
updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))

if export_config is not None:
if (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,14 +72,14 @@ def forward(
mask = self.mask_cache.index_select(2, input_pos)
mask = mask[:, :, :, : self.config.max_seq_len]

updated_kv_entires = []
updated_kv_entries = []
for i, block in enumerate(self.transformer_blocks):
kv_entry = kv_cache.caches[i] if kv_cache else None
x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
if kv_entry:
updated_kv_entires.append(kv_entry)
updated_kv_entries.append(kv_entry)

updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))

if export_config is not None:
if (
Expand Down
33 changes: 4 additions & 29 deletions ai_edge_torch/generative/layers/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,33 +26,6 @@
from torch import nn


def _embed_rope(
q: torch.Tensor,
k: torch.Tensor,
n_elem: int,
rope: Tuple[torch.Tensor, torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Embed rotary positional embedding for query and key.
Args:
q (torch.Tensor): query tensor.
k (torch.Tensor): key tensor.
n_elem (int): number of elements to embed rotarty positional embedding.
rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
"""
if n_elem > 0:
cos, sin = rope
q_roped = rotary_pos_emb.apply_rope(
q[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2)
)
k_roped = rotary_pos_emb.apply_rope(
k[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2)
)
q = torch.cat((q_roped, q[..., n_elem:]), dim=-1)
k = torch.cat((k_roped, k[..., n_elem:]), dim=-1)
return q, k


class TransformerBlock(nn.Module):

def __init__(
Expand Down Expand Up @@ -238,7 +211,8 @@ def forward(
if rope is not None:
# Compute rotary positional embedding for query and key.
n_elem = int(self.config.rotary_percentage * self.config.head_dim)
q, k = _embed_rope(q, k, n_elem, rope)
cos, sin = rope
q, k = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)

if kv_cache is not None:
kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
Expand Down Expand Up @@ -374,7 +348,8 @@ def forward(
if rope is not None:
# Compute rotary positional embedding for query and key.
n_elem = int(self.config.rotary_percentage * self.config.head_dim)
q, k = _embed_rope(q, k, n_elem, rope)
cos, sin = rope
q, k = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)

if kv_cache is not None:
kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
Expand Down
61 changes: 34 additions & 27 deletions ai_edge_torch/generative/layers/rotary_position_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,57 +32,64 @@ def apply_rope(
"""
x = x.transpose(1, 2)
head_size = x.size(-1)
x1 = x[..., : head_size // 2] # (B, nh, T, hs/2)
x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2)
rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs)
roped = (x * cos) + (rotated * sin)
x1, x2 = torch.split(x, head_size // 2, dim=-1)
left = x1 * cos - x2 * sin
right = x2 * cos + x1 * sin
roped = torch.cat([left, right], dim=-1)
return roped.transpose(1, 2).type_as(x)


def apply_rope_inline(
q: torch.Tensor,
k: torch.Tensor,
def build_rope(
input_pos: torch.Tensor,
n_elem: int,
head_dim: int,
base: int = 10_000,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Computes rotary positional embedding inline for a query and key.
"""Computes rotary positional embedding cosine and sine tensors.
Args:
q: the query tensor.
k: the key tensor.
input_pos: the sequence indices for the query and key
n_elem: number of elements of the head dimension for RoPE computation
base: the base of the exponentiated value for RoPE.
Returns:
output the RoPE'd query and key.
cos, sin tensors
"""

if n_elem <= 0:
return q, k
return None, None

theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
freq_exponents = (2.0 / n_elem) * torch.arange(
q.shape[-1] // 2, dtype=torch.float32
head_dim // 2, dtype=torch.float32
)
timescale = float(base) ** freq_exponents
radians = input_pos.clone().unsqueeze(0).unsqueeze(-1) / timescale.unsqueeze(
0
).unsqueeze(0)
cos = torch.cos(radians).type_as(q)
sin = torch.sin(radians).type_as(q)
cos = torch.cos(radians)
sin = torch.sin(radians)
return cos, sin


def apply(x, sin, cos):
x = x.transpose(1, 2)
b, h, s, d = x.shape
ans = torch.split(x, d // 2, dim=-1)
x1, x2 = ans
left = x1 * cos - x2 * sin
right = x2 * cos + x1 * sin
res = torch.cat([left, right], dim=-1)
res = res.transpose(1, 2)
return res
def apply_rope_inline(
q: torch.Tensor,
k: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Computes rotary positional embedding inline for a query and key.
Args:
q: the query tensor.
k: the key tensor.
cos: the cosine tensor.
sin: the sine tensor.
Returns:
output the RoPE'd query and key.
"""

q_roped = apply(q, sin, cos)
k_roped = apply(k, sin, cos)
q_roped = apply_rope(q, cos, sin)
k_roped = apply_rope(k, cos, sin)
return q_roped, k_roped
23 changes: 11 additions & 12 deletions ai_edge_torch/generative/utilities/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from ai_edge_torch.generative.layers import kv_cache as kv_utils
import ai_edge_torch.generative.layers.attention_utils as attn_utils
import ai_edge_torch.generative.layers.model_config as cfg
import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
import ai_edge_torch.generative.utilities.loader as loading_utils
import torch
from torch import nn
Expand Down Expand Up @@ -85,13 +86,6 @@ def __init__(self, config: cfg.ModelConfig):
config.embedding_dim,
config.final_norm_config,
)
# ROPE parameters for all attn_configs are the same. Take the first one.
attn_config = config.block_config(0).attn_config
self.rope_cache = attn_utils.build_rope_cache(
size=config.kv_cache_max,
dim=int(attn_config.rotary_percentage * attn_config.head_dim),
base=attn_config.rotary_base,
)
self.mask_cache = attn_utils.build_causal_mask_cache(
size=config.kv_cache_max,
)
Expand All @@ -113,11 +107,16 @@ def forward(

# token embeddings of shape (b, t, n_embd)
input_embeds = self.tok_embedding(tokens)
cos, sin = self.rope_cache
rope = (cos.index_select(0, input_pos), sin.index_select(0, input_pos))
mask = self.mask_cache.index_select(2, input_pos)
mask = mask[:, :, :, : self.config.kv_cache_max]

# ROPE parameters for all attn_configs are the same. Take the first one.
attn_config = self.config.block_config(0).attn_config
n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
rope = rotary_pos_emb.build_rope(
input_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
)

return self.forward_with_embeds(
input_embeds, rope, mask, input_pos, kv_cache, export_config
)
Expand All @@ -141,13 +140,13 @@ def forward_with_embeds(
if self.config.embedding_scale is not None:
x = x * self.config.embedding_scale

updated_kv_entires = []
updated_kv_entries = []
for i, block in enumerate(self.transformer_blocks):
kv_entry = kv_cache.caches[i] if kv_cache else None
x, kv_entry = block(x, rope, mask, input_pos, kv_entry)
if kv_entry:
updated_kv_entires.append(kv_entry)
updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
updated_kv_entries.append(kv_entry)
updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))

if export_config is not None:
if (
Expand Down

0 comments on commit 4d49b54

Please sign in to comment.