Skip to content

Commit

Permalink
Build RoPE cos, sin tensors on demand
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 706805762
  • Loading branch information
talumbau authored and copybara-github committed Dec 16, 2024
1 parent 0704751 commit 097b8ee
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 47 deletions.
33 changes: 4 additions & 29 deletions ai_edge_torch/generative/layers/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,33 +26,6 @@
from torch import nn


def _embed_rope(
q: torch.Tensor,
k: torch.Tensor,
n_elem: int,
rope: Tuple[torch.Tensor, torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Embed rotary positional embedding for query and key.
Args:
q (torch.Tensor): query tensor.
k (torch.Tensor): key tensor.
n_elem (int): number of elements to embed rotarty positional embedding.
rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
"""
if n_elem > 0:
cos, sin = rope
q_roped = rotary_pos_emb.apply_rope(
q[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2)
)
k_roped = rotary_pos_emb.apply_rope(
k[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2)
)
q = torch.cat((q_roped, q[..., n_elem:]), dim=-1)
k = torch.cat((k_roped, k[..., n_elem:]), dim=-1)
return q, k


class TransformerBlock(nn.Module):

def __init__(
Expand Down Expand Up @@ -238,7 +211,8 @@ def forward(
if rope is not None:
# Compute rotary positional embedding for query and key.
n_elem = int(self.config.rotary_percentage * self.config.head_dim)
q, k = _embed_rope(q, k, n_elem, rope)
cos, sin = rope
q, k = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)

if kv_cache is not None:
kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
Expand Down Expand Up @@ -374,7 +348,8 @@ def forward(
if rope is not None:
# Compute rotary positional embedding for query and key.
n_elem = int(self.config.rotary_percentage * self.config.head_dim)
q, k = _embed_rope(q, k, n_elem, rope)
cos, sin = rope
q, k = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)

if kv_cache is not None:
kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
Expand Down
63 changes: 54 additions & 9 deletions ai_edge_torch/generative/layers/rotary_position_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,38 +39,83 @@ def apply_rope(
return roped.transpose(1, 2).type_as(x)


def apply_rope_inline(
def _embed_rope(
q: torch.Tensor,
k: torch.Tensor,
n_elem: int,
rope: Tuple[torch.Tensor, torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Embed rotary positional embedding for query and key.
Args:
q (torch.Tensor): query tensor.
k (torch.Tensor): key tensor.
n_elem (int): number of elements to embed rotarty positional embedding.
rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
"""
if n_elem > 0:
cos, sin = rope
q_roped = rotary_pos_emb.apply_rope(
q[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2)
)
k_roped = rotary_pos_emb.apply_rope(
k[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2)
)
q = torch.cat((q_roped, q[..., n_elem:]), dim=-1)
k = torch.cat((k_roped, k[..., n_elem:]), dim=-1)
return q, k


def build_rope(
input_pos: torch.Tensor,
n_elem: int,
head_dim: int,
base: int = 10_000,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Computes rotary positional embedding inline for a query and key.
"""Computes rotary positional embedding cosine and sine tensors.
Args:
q: the query tensor.
k: the key tensor.
input_pos: the sequence indices for the query and key
n_elem: number of elements of the head dimension for RoPE computation
base: the base of the exponentiated value for RoPE.
Returns:
output the RoPE'd query and key.
cos, sin tensors
"""

if n_elem <= 0:
return q, k
return None, None

theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
freq_exponents = (2.0 / n_elem) * torch.arange(
q.shape[-1] // 2, dtype=torch.float32
head_dim // 2, dtype=torch.float32
)
timescale = float(base) ** freq_exponents
radians = input_pos.clone().unsqueeze(0).unsqueeze(-1) / timescale.unsqueeze(
0
).unsqueeze(0)
cos = torch.cos(radians).type_as(q)
sin = torch.sin(radians).type_as(q)
cos = torch.cos(radians)
sin = torch.sin(radians)
return cos, sin


def apply_rope_inline(
q: torch.Tensor,
k: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Computes rotary positional embedding inline for a query and key.
Args:
q: the query tensor.
k: the key tensor.
cos: the cosine tensor.
sin: the sine tensor.
Returns:
output the RoPE'd query and key.
"""

def apply(x, sin, cos):
x = x.transpose(1, 2)
Expand Down
17 changes: 8 additions & 9 deletions ai_edge_torch/generative/utilities/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from ai_edge_torch.generative.layers import kv_cache as kv_utils
import ai_edge_torch.generative.layers.attention_utils as attn_utils
import ai_edge_torch.generative.layers.model_config as cfg
import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
import ai_edge_torch.generative.utilities.loader as loading_utils
import torch
from torch import nn
Expand Down Expand Up @@ -85,13 +86,6 @@ def __init__(self, config: cfg.ModelConfig):
config.embedding_dim,
config.final_norm_config,
)
# ROPE parameters for all attn_configs are the same. Take the first one.
attn_config = config.block_config(0).attn_config
self.rope_cache = attn_utils.build_rope_cache(
size=config.kv_cache_max,
dim=int(attn_config.rotary_percentage * attn_config.head_dim),
base=attn_config.rotary_base,
)
self.mask_cache = attn_utils.build_causal_mask_cache(
size=config.kv_cache_max,
)
Expand All @@ -113,11 +107,16 @@ def forward(

# token embeddings of shape (b, t, n_embd)
input_embeds = self.tok_embedding(tokens)
cos, sin = self.rope_cache
rope = (cos.index_select(0, input_pos), sin.index_select(0, input_pos))
mask = self.mask_cache.index_select(2, input_pos)
mask = mask[:, :, :, : self.config.kv_cache_max]

# ROPE parameters for all attn_configs are the same. Take the first one.
attn_config = self.config.block_config(0).attn_config
n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
rope = rotary_pos_emb.build_rope(
input_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
)

return self.forward_with_embeds(
input_embeds, rope, mask, input_pos, kv_cache, export_config
)
Expand Down

0 comments on commit 097b8ee

Please sign in to comment.