Build RoPE Inline with configurable function
PiperOrigin-RevId: 712951237
talumbau authored and copybara-github committed Jan 7, 2025
1 parent b183411 commit 4e38df6
Showing 8 changed files with 163 additions and 146 deletions.
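
This commit removes the RoPE caches precomputed in each model's __init__ and instead computes the cos/sin tensors inline during forward, through a builder callable that a ModelConfig can override via its build_rope field (as llama.py and phi3.py do below). A minimal sketch of that pattern (the function name here is hypothetical; the real default implementation lives in ai_edge_torch/generative/layers/rotary_position_embedding.py):

from functools import partial
from typing import Tuple

import torch


def sketch_build_rope(
    input_pos: torch.Tensor,  # positions to embed for this forward pass
    n_elem: int,  # number of rotary dimensions (rotary_percentage * head_dim)
    base: int,  # RoPE base frequency
    condense_ratio: int = 1,
    **kwargs,  # extra caller-supplied arguments (e.g. head_dim) are ignored here
) -> Tuple[torch.Tensor, torch.Tensor]:
  # Plain RoPE for the requested positions, mirroring the shape of the
  # model-specific builders in llama.py and phi3.py below.
  theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
  idx_theta = torch.outer(input_pos / condense_ratio, theta)
  return torch.cos(idx_theta), torch.sin(idx_theta)


# Bind model-specific constants once; the model then calls the bound function
# with per-call arguments during forward.
build_rope = partial(sketch_build_rope, condense_ratio=1)
cos, sin = build_rope(torch.arange(8), n_elem=64, base=10_000)
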
71 changes: 46 additions & 25 deletions ai_edge_torch/generative/examples/gemma/gemma2.py
@@ -15,13 +15,14 @@

"""Example of building a Gemma2 model."""

from typing import Optional, Tuple
from typing import List, Optional, Tuple

from ai_edge_torch.generative.layers import attention
from ai_edge_torch.generative.layers import builder
from ai_edge_torch.generative.layers import kv_cache as kv_utils
import ai_edge_torch.generative.layers.attention_utils as attn_utils
import ai_edge_torch.generative.layers.model_config as cfg
import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
from ai_edge_torch.generative.utilities import model_builder
import ai_edge_torch.generative.utilities.loader as loading_utils
import torch
@@ -103,17 +104,12 @@ def __init__(self, config: cfg.ModelConfig):
config.embedding_dim,
config.final_norm_config,
)
# Gemma2 has same hyper parameters for each layer except for attention
# types. Use the first layer.
attn_config = config.block_config(0).attn_config
self.rope_cache = attn_utils.build_rope_cache(
size=config.kv_cache_max,
dim=int(attn_config.rotary_percentage * attn_config.head_dim),
base=attn_config.rotary_base,
)
self.mask_cache = attn_utils.build_causal_mask_cache(
size=config.kv_cache_max,
)
# Gemma2 has same hyper parameters for each layer except for attention
# types. Use the first layer.
attn_config = config.block_config(0).attn_config
self.sliding_window_mask_cache = attn_utils.build_sliding_window_mask_cache(
size=config.kv_cache_max,
window_size=attn_config.sliding_window_size,
@@ -140,29 +136,51 @@ def forward(
f"Cannot forward sequence of length {seq_len}, max seq length is only"
f" {self.config.max_seq_len}"
)

# token embeddings of shape (b, t, n_embd)
input_embeds = self.tok_embedding(tokens)
# RoPE parameters are the same for all blocks. Use the first layer.
attn_config = self.config.block_config(0).attn_config
n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
rope = rotary_pos_emb.build_rope(
input_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
)
mask = [
self.get_attention_mask(
self.config.block_config(i).attn_config.attn_type, input_pos
)
for i in range(self.config.num_layers)
]

return self._forward_with_embeds(
input_embeds, rope, mask, input_pos, kv_cache, export_config
)

def _forward_with_embeds(
self,
input_embeds: torch.Tensor,
rope: Tuple[torch.Tensor, torch.Tensor],
mask: List[torch.Tensor],
input_pos: torch.Tensor,
kv_cache: kv_utils.KVCache,
export_config: Optional[model_builder.ExportConfig] = None,
) -> dict[torch.Tensor, kv_utils.KVCache]:
"""Forwards the model with input embeddings."""
assert len(self.transformer_blocks) == len(kv_cache.caches), (
"The number of transformer blocks and the number of KV cache entries"
" must be the same."
)

cos, sin = self.rope_cache
cos = cos.index_select(0, input_pos)
sin = sin.index_select(0, input_pos)

# token embeddings of shape (b, t, n_embd)
x = self.tok_embedding(tokens)
x = x * (self.config.embedding_dim**0.5)

updated_kv_entires = []
if self.config.embedding_scale is not None:
input_embeds = input_embeds * self.config.embedding_scale
x = input_embeds
updated_kv_entries = []
for i, block in enumerate(self.transformer_blocks):
mask = self.get_attention_mask(
block.config.attn_config.attn_type, input_pos
)
kv_entry = kv_cache.caches[i] if kv_cache else None
x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
x, kv_entry = block(x, rope, mask[i], input_pos, kv_entry)
if kv_entry:
updated_kv_entires.append(kv_entry)
updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
updated_kv_entries.append(kv_entry)
updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))

if export_config is not None:
if (
@@ -228,11 +246,13 @@ def get_block_config(idx: int) -> cfg.TransformerBlockConfig:
)

num_layers = 26
embedding_dim = 2304
config = cfg.ModelConfig(
vocab_size=256000,
num_layers=num_layers,
max_seq_len=8192,
embedding_dim=2304,
embedding_dim=embedding_dim,
embedding_scale=embedding_dim**0.5,
kv_cache_max_len=kv_cache_max_len,
block_configs=[get_block_config(i) for i in range(num_layers)],
final_norm_config=norm_config,
@@ -249,6 +269,7 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
config.num_layers = 2
config.max_seq_len = 2 * kv_cache_max_len
config.embedding_dim = 128
config.embedding_scale = config.embedding_dim**0.5
config.block_configs = config.block_configs[: config.num_layers]
for block_config in config.block_configs:
block_config.attn_config.num_heads = 4
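The gemma2.py changes above drop the precomputed rope_cache and its per-call index_select in favor of building RoPE inline for just the requested positions, with forward delegating the block loop to _forward_with_embeds. A before/after sketch of the RoPE portion (sizes and base are hypothetical; build_rope's internals are not part of this diff):

import torch
import ai_edge_torch.generative.layers.attention_utils as attn_utils
import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb

input_pos = torch.arange(0, 8)  # positions for this forward pass
n_elem, head_dim, base = 256, 256, 10_000  # hypothetical Gemma2-like values

# Before: a full-length cache built once at init, then sliced on every call.
cos_cache, sin_cache = attn_utils.build_rope_cache(size=1024, dim=n_elem, base=base)
cos = cos_cache.index_select(0, input_pos)
sin = sin_cache.index_select(0, input_pos)

# After: cos/sin are built inline, only for the positions actually needed.
rope = rotary_pos_emb.build_rope(input_pos, n_elem, head_dim, base)

The _forward_with_embeds split also lets a caller pass precomputed input embeddings directly, with the RoPE tuple and the per-layer masks supplied alongside them.
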
52 changes: 28 additions & 24 deletions ai_edge_torch/generative/examples/llama/llama.py
@@ -15,6 +15,7 @@

"""Example of building Llama 3.2 models."""

from functools import partial
import math
from typing import Tuple

@@ -26,8 +27,8 @@


def _build_llama3_rope_cache(
size: int,
dim: int,
input_pos: torch.Tensor,
n_elem: int,
base: int,
condense_ratio: int,
dtype: torch.dtype,
@@ -36,8 +37,9 @@ def _build_llama3_rope_cache(
low_freq_factor: float,
high_freq_factor: float,
max_seq_len: int,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Precomputes Rotary Positional Embeddings for Llama 3.2 model.
"""Computes Rotary Positional Embeddings for Llama 3.2 model.
It's a modified version of attn_utils.build_rope_cache with additional
arguments for Llama 3.2 model. It precomputes Rotary Positional Embedding Sin
@@ -47,13 +49,12 @@ def _build_llama3_rope_cache(
https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_rope_utils.py#L307
Args:
size (int): The size of the built cache.
dim (int): Each sequence's dimmension.
base (int, optional): Rope base value.
condense_ratio (int, optional): The ratio by which sequence indicies are
condensed.
dtype (torch.dtype, optional): Output tensor's data type.
device (torch.device, optional): Output tensor's data type.
input_pos (torch.Tensor): The given input sequence positions.
n_elem (int): Each sequence's dimension.
base (int): Rope base value.
condense_ratio (int): The ratio by which sequence indices are condensed.
dtype (torch.dtype): Output tensor's data type.
device (torch.device): Output tensor's device.
factor (float): Factor to scale theta down for tokens in long range in the
sequence.
low_freq_factor (float): Factor to determine if tokens are in long range
@@ -66,7 +67,7 @@ def _build_llama3_rope_cache(
Returns:
Tuple[torch.Tensor, torch.Tensor]: Rope's Cosine and Sine waves.
"""
theta = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
low_freq_wavelen = max_seq_len / low_freq_factor
high_freq_wavelen = max_seq_len / high_freq_factor
wavelen = 2 * math.pi / theta
@@ -81,7 +82,7 @@ def _build_llama3_rope_cache(
is_medium = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
theta = torch.where(is_medium, smoothed_theta, theta)

seq_idx = torch.arange(size) / condense_ratio
seq_idx = input_pos / condense_ratio
idx_theta = torch.outer(seq_idx, theta)
cos = torch.cos(idx_theta).to(dtype=dtype, device=device)
sin = torch.sin(idx_theta).to(dtype=dtype, device=device)
@@ -97,18 +98,6 @@ class Llama(model_builder.DecoderOnlyModel):
def __init__(self, config: cfg.ModelConfig):
super().__init__(config)
attn_config = self.config.block_config(0).attn_config
self.rope_cache = _build_llama3_rope_cache(
size=self.config.kv_cache_max,
dim=int(attn_config.rotary_percentage * attn_config.head_dim),
base=attn_config.rotary_base,
condense_ratio=1,
dtype=torch.float32,
device=torch.device("cpu"),
factor=32.0,
low_freq_factor=1.0,
high_freq_factor=4.0,
max_seq_len=self.config.max_seq_len,
)


def get_1b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -140,6 +129,20 @@ def get_1b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
pre_attention_norm_config=norm_config,
post_attention_norm_config=norm_config,
)

max_seq_len = 8192
# Create the RoPE callable
build_rope = partial(
_build_llama3_rope_cache,
condense_ratio=1,
dtype=torch.float32,
device=torch.device("cpu"),
factor=32.0,
low_freq_factor=1.0,
high_freq_factor=4.0,
max_seq_len=max_seq_len,
)

config = cfg.ModelConfig(
vocab_size=128256,
num_layers=16,
@@ -149,6 +152,7 @@ def get_1b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
block_configs=block_config,
final_norm_config=norm_config,
enable_hlfb=True,
build_rope=build_rope,
)
return config

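In llama.py above, the Llama 3.2 frequency-scaling arguments (factor, low_freq_factor, high_freq_factor, max_seq_len) are bound into the RoPE builder with functools.partial and handed to the model through ModelConfig's build_rope field, so only the per-call arguments remain unbound. A usage sketch, assuming ModelConfig exposes the callable as config.build_rope (the argument values here are hypothetical):

import torch
from ai_edge_torch.generative.examples.llama import llama

config = llama.get_1b_model_config(kv_cache_max_len=1024)
# Everything Llama-3.2-specific is already baked into the partial; only the
# per-call arguments are supplied here.
cos, sin = config.build_rope(input_pos=torch.arange(16), n_elem=64, base=500_000)
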
49 changes: 26 additions & 23 deletions ai_edge_torch/generative/examples/phi/phi3.py
@@ -15,6 +15,7 @@

"""Example of building a Phi-3.5 model up to 4K tokens, not to 128K tokens."""

from functools import partial
import math
from typing import Tuple

@@ -93,40 +94,41 @@
]


def _build_rope_cache(
size: int,
dim: int,
def _build_phi3_rope(
input_pos: torch.Tensor,
n_elem: int,
base: int,
condense_ratio: int,
dtype: torch.dtype,
device: torch.device,
theta_factors: torch.Tensor,
scale: float,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Precomputes Rotary Positional Embeddings for Phi-3.5 model.
"""Computes Rotary Positional Embeddings for Phi-3.5 model.
It's a modified version of attn_utils.build_rope_cache with additional
arguments for the Phi-3.5 model. It precomputes Rotary Positional Embedding Sin and
Cos values with scaling factors for quick lookup during the inference.
Args:
size (int): The size of the built cache.
dim (int): Each sequence's dimmension.
input_pos (torch.Tensor): The given input sequence positions.
n_elem (int): Each sequence's dimension.
base (int, optional): Rope base value.
condense_ratio (int, optional): The ratio by which sequence indices are
condensed.
dtype (torch.dtype, optional): Output tensor's data type.
device (torch.device, optional): Output tensor's device.
theta_factors (torch.Tensor, optional): A tensor of shape (dim,) used to
scale the theta values.
theta_factors (torch.Tensor, optional): A tensor of shape (n_elem // 2,)
used to scale the theta values.
scale (float, optional): A float used to scale the rope values.
Returns:
Tuple[torch.Tensor, torch.Tensor]: Rope's Cosine and Sine waves.
"""
theta = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
theta = theta / theta_factors
seq_idx = torch.arange(size) / condense_ratio
seq_idx = input_pos / condense_ratio
idx_theta = torch.outer(seq_idx, theta)
cos = torch.cos(idx_theta).to(dtype=dtype, device=device) * scale
sin = torch.sin(idx_theta).to(dtype=dtype, device=device) * scale
@@ -139,18 +141,6 @@ class Phi3_5Mini(model_builder.DecoderOnlyModel):
def __init__(self, config: cfg.ModelConfig):
super().__init__(config)
attn_config = self.config.block_config(0).attn_config
self.rope_cache = _build_rope_cache(
size=self.config.kv_cache_max,
dim=int(attn_config.rotary_percentage * attn_config.head_dim),
base=attn_config.rotary_base,
condense_ratio=1,
dtype=torch.float32,
device=torch.device("cpu"),
theta_factors=torch.tensor(ROPE_SHORT_FACTOR),
scale=math.sqrt(
1 + math.log(ROPE_SCALE_FACTOR) / math.log(config.max_seq_len)
),
)


def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -183,16 +173,29 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
pre_attention_norm_config=norm_config,
post_attention_norm_config=norm_config,
)
max_seq_len = 4096
# Create the RoPE callable
build_rope = partial(
_build_phi3_rope,
condense_ratio=1,
dtype=torch.float32,
device=torch.device("cpu"),
theta_factors=torch.tensor(ROPE_SHORT_FACTOR),
scale=math.sqrt(1 + math.log(ROPE_SCALE_FACTOR) / math.log(max_seq_len)),
max_seq_len=max_seq_len,
)

config = cfg.ModelConfig(
vocab_size=32064,
num_layers=32,
max_seq_len=4096,
max_seq_len=max_seq_len,
kv_cache_max_len=kv_cache_max_len,
embedding_dim=3072,
block_configs=block_config,
final_norm_config=norm_config,
lm_head_share_weight_with_embedding=False,
enable_hlfb=True,
build_rope=build_rope,
)
return config

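phi3.py above follows the same pattern: the per-frequency ROPE_SHORT_FACTOR tensor and an attention scale are bound into _build_phi3_rope, which divides theta elementwise by theta_factors and multiplies the resulting cos/sin values by scale = sqrt(1 + ln(ROPE_SCALE_FACTOR) / ln(max_seq_len)). A small worked example of that scale, using a hypothetical stand-in for ROPE_SCALE_FACTOR (the real constant is defined earlier in phi3.py and is not shown in this diff):

import math

max_seq_len = 4096
rope_scale_factor = 32.0  # hypothetical stand-in for ROPE_SCALE_FACTOR
scale = math.sqrt(1 + math.log(rope_scale_factor) / math.log(max_seq_len))
print(round(scale, 4))  # ~1.1902 with these example numbers
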
@@ -72,14 +72,14 @@ def forward(
mask = self.mask_cache.index_select(2, input_pos)
mask = mask[:, :, :, : self.config.max_seq_len]

updated_kv_entires = []
updated_kv_entries = []
for i, block in enumerate(self.transformer_blocks):
kv_entry = kv_cache.caches[i] if kv_cache else None
x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
if kv_entry:
updated_kv_entires.append(kv_entry)
updated_kv_entries.append(kv_entry)

updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))

if export_config is not None:
if (