diff --git a/ai_edge_torch/generative/layers/attention.py b/ai_edge_torch/generative/layers/attention.py
index f4c43e40..b52e04b4 100644
--- a/ai_edge_torch/generative/layers/attention.py
+++ b/ai_edge_torch/generative/layers/attention.py
@@ -26,33 +26,6 @@
 from torch import nn


-def _embed_rope(
-    q: torch.Tensor,
-    k: torch.Tensor,
-    n_elem: int,
-    rope: Tuple[torch.Tensor, torch.Tensor],
-) -> Tuple[torch.Tensor, torch.Tensor]:
-  """Embed rotary positional embedding for query and key.
-
-  Args:
-    q (torch.Tensor): query tensor.
-    k (torch.Tensor): key tensor.
-    n_elem (int): number of elements to embed rotarty positional embedding.
-    rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
-  """
-  if n_elem > 0:
-    cos, sin = rope
-    q_roped = rotary_pos_emb.apply_rope(
-        q[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2)
-    )
-    k_roped = rotary_pos_emb.apply_rope(
-        k[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2)
-    )
-    q = torch.cat((q_roped, q[..., n_elem:]), dim=-1)
-    k = torch.cat((k_roped, k[..., n_elem:]), dim=-1)
-  return q, k
-
-
 class TransformerBlock(nn.Module):

   def __init__(
@@ -238,7 +211,8 @@ def forward(
     if rope is not None:
       # Compute rotary positional embedding for query and key.
       n_elem = int(self.config.rotary_percentage * self.config.head_dim)
-      q, k = _embed_rope(q, k, n_elem, rope)
+      cos, sin = rope
+      q, k = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)

     if kv_cache is not None:
       kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
@@ -374,7 +348,8 @@ def forward(
     if rope is not None:
       # Compute rotary positional embedding for query and key.
       n_elem = int(self.config.rotary_percentage * self.config.head_dim)
-      q, k = _embed_rope(q, k, n_elem, rope)
+      cos, sin = rope
+      q, k = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)

     if kv_cache is not None:
       kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
diff --git a/ai_edge_torch/generative/layers/rotary_position_embedding.py b/ai_edge_torch/generative/layers/rotary_position_embedding.py
index c06dc818..f9b8fb05 100644
--- a/ai_edge_torch/generative/layers/rotary_position_embedding.py
+++ b/ai_edge_torch/generative/layers/rotary_position_embedding.py
@@ -39,38 +39,83 @@ def apply_rope(
   return roped.transpose(1, 2).type_as(x)


-def apply_rope_inline(
+def _embed_rope(
     q: torch.Tensor,
     k: torch.Tensor,
+    n_elem: int,
+    rope: Tuple[torch.Tensor, torch.Tensor],
+) -> Tuple[torch.Tensor, torch.Tensor]:
+  """Embed rotary positional embedding for query and key.
+
+  Args:
+    q (torch.Tensor): query tensor.
+    k (torch.Tensor): key tensor.
+    n_elem (int): number of elements to embed rotary positional embedding.
+    rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
+  """
+  if n_elem > 0:
+    cos, sin = rope
+    q_roped = rotary_pos_emb.apply_rope(
+        q[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2)
+    )
+    k_roped = rotary_pos_emb.apply_rope(
+        k[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2)
+    )
+    q = torch.cat((q_roped, q[..., n_elem:]), dim=-1)
+    k = torch.cat((k_roped, k[..., n_elem:]), dim=-1)
+  return q, k
+
+
+def build_rope(
     input_pos: torch.Tensor,
     n_elem: int,
+    head_dim: int,
     base: int = 10_000,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-  """Computes rotary positional embedding inline for a query and key.
+  """Computes rotary positional embedding cosine and sine tensors.

   Args:
-    q: the query tensor.
-    k: the key tensor.
     input_pos: the sequence indices for the query and key
     n_elem: number of elements of the head dimension for RoPE computation
+    base: the base of the exponentiated value for RoPE.

   Returns:
-    output the RoPE'd query and key.
+    cos, sin tensors
   """
   if n_elem <= 0:
-    return q, k
+    return None, None

   theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
   freq_exponents = (2.0 / n_elem) * torch.arange(
-      q.shape[-1] // 2, dtype=torch.float32
+      head_dim // 2, dtype=torch.float32
   )
   timescale = float(base) ** freq_exponents
   radians = input_pos.clone().unsqueeze(0).unsqueeze(-1) / timescale.unsqueeze(
       0
   ).unsqueeze(0)
-  cos = torch.cos(radians).type_as(q)
-  sin = torch.sin(radians).type_as(q)
+  cos = torch.cos(radians)
+  sin = torch.sin(radians)
+  return cos, sin
+
+
+def apply_rope_inline(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    cos: torch.Tensor,
+    sin: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+  """Computes rotary positional embedding inline for a query and key.
+
+  Args:
+    q: the query tensor.
+    k: the key tensor.
+    cos: the cosine tensor.
+    sin: the sine tensor.
+
+  Returns:
+    output the RoPE'd query and key.
+  """

   def apply(x, sin, cos):
     x = x.transpose(1, 2)
diff --git a/ai_edge_torch/generative/utilities/model_builder.py b/ai_edge_torch/generative/utilities/model_builder.py
index 2714eeaa..ca2768c0 100644
--- a/ai_edge_torch/generative/utilities/model_builder.py
+++ b/ai_edge_torch/generative/utilities/model_builder.py
@@ -24,6 +24,7 @@
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import torch
 from torch import nn
@@ -85,13 +86,6 @@ def __init__(self, config: cfg.ModelConfig):
         config.embedding_dim,
         config.final_norm_config,
     )
-    # ROPE parameters for all attn_configs are the same. Take the first one.
-    attn_config = config.block_config(0).attn_config
-    self.rope_cache = attn_utils.build_rope_cache(
-        size=config.kv_cache_max,
-        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=attn_config.rotary_base,
-    )
     self.mask_cache = attn_utils.build_causal_mask_cache(
         size=config.kv_cache_max,
     )
@@ -113,11 +107,16 @@ def forward(
     # token embeddings of shape (b, t, n_embd)
     input_embeds = self.tok_embedding(tokens)

-    cos, sin = self.rope_cache
-    rope = (cos.index_select(0, input_pos), sin.index_select(0, input_pos))
     mask = self.mask_cache.index_select(2, input_pos)
     mask = mask[:, :, :, : self.config.kv_cache_max]

+    # ROPE parameters for all attn_configs are the same. Take the first one.
+    attn_config = self.config.block_config(0).attn_config
+    n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
+    rope = rotary_pos_emb.build_rope(
+        input_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
+    )
+
     return self.forward_with_embeds(
         input_embeds, rope, mask, input_pos, kv_cache, export_config
     )
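
For reviewers, a minimal standalone sketch of the data flow this patch introduces: cos/sin are now built per forward call from input_pos (instead of being gathered from a precomputed rope_cache at model init) and are applied inline to q and k inside attention. The sketch is not part of the patch and assumes only torch; the names build_rope_sketch and apply_rope_inline_sketch, the (batch, seq_len, num_heads, head_dim) layout, and the half-split rotation inside rotate() are illustrative assumptions, since the hunk truncates the real body of apply_rope_inline.

# Sketch only. Helper names and the rotation convention are assumptions,
# not copied from ai_edge_torch.
from typing import Optional, Tuple

import torch


def build_rope_sketch(
    input_pos: torch.Tensor,
    n_elem: int,
    head_dim: int,
    base: int = 10_000,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
  # Mirrors the cos/sin construction in the new build_rope(): one row of
  # angles per position, one column per frequency (head_dim // 2 of them).
  if n_elem <= 0:
    return None, None
  freq_exponents = (2.0 / n_elem) * torch.arange(
      head_dim // 2, dtype=torch.float32
  )
  timescale = float(base) ** freq_exponents
  radians = input_pos.clone().unsqueeze(0).unsqueeze(-1) / timescale.unsqueeze(
      0
  ).unsqueeze(0)  # shape (1, seq_len, head_dim // 2)
  return torch.cos(radians), torch.sin(radians)


def apply_rope_inline_sketch(
    q: torch.Tensor,  # assumed (batch, seq_len, num_heads, head_dim)
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
  # Assumed rotation convention (first-half / second-half split); the real
  # apply_rope_inline body is not shown in the hunk and may pair elements
  # differently.
  def rotate(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = torch.chunk(x, 2, dim=-1)
    c = cos.unsqueeze(2)  # broadcast over the heads axis
    s = sin.unsqueeze(2)
    return torch.cat((x1 * c - x2 * s, x2 * c + x1 * s), dim=-1).type_as(x)

  return rotate(q), rotate(k)


if __name__ == "__main__":
  # Per-call usage, matching the new model_builder flow: build cos/sin from
  # input_pos, then rotate q/k inside attention.
  head_dim, n_heads, seq_len = 64, 8, 16
  input_pos = torch.arange(seq_len)
  cos, sin = build_rope_sketch(input_pos, n_elem=head_dim, head_dim=head_dim)
  q = torch.randn(1, seq_len, n_heads, head_dim)
  k = torch.randn(1, seq_len, n_heads, head_dim)
  q_roped, k_roped = apply_rope_inline_sketch(q, k, cos, sin)
  print(q_roped.shape, k_roped.shape)  # torch.Size([1, 16, 8, 64]) twice

The practical effect of the patch is that no (kv_cache_max, n_elem)-sized rope cache is stored on the model; cos and sin are derived from the actual input_pos on every call, so the RoPE tensors in the exported graph stay tied to the positions the caller provides.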