Build RoPE Inline with configurable function #451

Merged · 1 commit · Jan 8, 2025
71 changes: 46 additions & 25 deletions ai_edge_torch/generative/examples/gemma/gemma2.py
@@ -15,13 +15,14 @@

"""Example of building a Gemma2 model."""

from typing import Optional, Tuple
from typing import List, Optional, Tuple

from ai_edge_torch.generative.layers import attention
from ai_edge_torch.generative.layers import builder
from ai_edge_torch.generative.layers import kv_cache as kv_utils
import ai_edge_torch.generative.layers.attention_utils as attn_utils
import ai_edge_torch.generative.layers.model_config as cfg
import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
from ai_edge_torch.generative.utilities import model_builder
import ai_edge_torch.generative.utilities.loader as loading_utils
import torch
@@ -103,17 +104,12 @@ def __init__(self, config: cfg.ModelConfig):
config.embedding_dim,
config.final_norm_config,
)
# Gemma2 has same hyper parameters for each layer except for attention
# types. Use the first layer.
attn_config = config.block_config(0).attn_config
self.rope_cache = attn_utils.build_rope_cache(
size=config.kv_cache_max,
dim=int(attn_config.rotary_percentage * attn_config.head_dim),
base=attn_config.rotary_base,
)
self.mask_cache = attn_utils.build_causal_mask_cache(
size=config.kv_cache_max,
)
# Gemma2 has the same hyperparameters for each layer except for attention
# types. Use the first layer.
attn_config = config.block_config(0).attn_config
self.sliding_window_mask_cache = attn_utils.build_sliding_window_mask_cache(
size=config.kv_cache_max,
window_size=attn_config.sliding_window_size,
@@ -140,29 +136,51 @@ def forward(
f"Cannot forward sequence of length {seq_len}, max seq length is only"
f" {self.config.max_seq_len}"
)

# token embeddings of shape (b, t, n_embd)
input_embeds = self.tok_embedding(tokens)
# RoPE parameters are the same for all blocks. Use the first layer.
attn_config = self.config.block_config(0).attn_config
n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
rope = rotary_pos_emb.build_rope(
input_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
)
mask = [
self.get_attention_mask(
self.config.block_config(i).attn_config.attn_type, input_pos
)
for i in range(self.config.num_layers)
]

return self._forward_with_embeds(
input_embeds, rope, mask, input_pos, kv_cache, export_config
)

def _forward_with_embeds(
self,
input_embeds: torch.Tensor,
rope: Tuple[torch.Tensor, torch.Tensor],
mask: List[torch.Tensor],
input_pos: torch.Tensor,
kv_cache: kv_utils.KVCache,
export_config: Optional[model_builder.ExportConfig] = None,
) -> dict[torch.Tensor, kv_utils.KVCache]:
"""Forwards the model with input embeddings."""
assert len(self.transformer_blocks) == len(kv_cache.caches), (
"The number of transformer blocks and the number of KV cache entries"
" must be the same."
)

cos, sin = self.rope_cache
cos = cos.index_select(0, input_pos)
sin = sin.index_select(0, input_pos)

# token embeddings of shape (b, t, n_embd)
x = self.tok_embedding(tokens)
x = x * (self.config.embedding_dim**0.5)

updated_kv_entires = []
if self.config.embedding_scale is not None:
input_embeds = input_embeds * self.config.embedding_scale
x = input_embeds
updated_kv_entries = []
for i, block in enumerate(self.transformer_blocks):
mask = self.get_attention_mask(
block.config.attn_config.attn_type, input_pos
)
kv_entry = kv_cache.caches[i] if kv_cache else None
x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
x, kv_entry = block(x, rope, mask[i], input_pos, kv_entry)
if kv_entry:
updated_kv_entires.append(kv_entry)
updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
updated_kv_entries.append(kv_entry)
updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))

if export_config is not None:
if (
@@ -228,11 +246,13 @@ def get_block_config(idx: int) -> cfg.TransformerBlockConfig:
)

num_layers = 26
embedding_dim = 2304
config = cfg.ModelConfig(
vocab_size=256000,
num_layers=num_layers,
max_seq_len=8192,
embedding_dim=2304,
embedding_dim=embedding_dim,
embedding_scale=embedding_dim**0.5,
kv_cache_max_len=kv_cache_max_len,
block_configs=[get_block_config(i) for i in range(num_layers)],
final_norm_config=norm_config,
@@ -249,6 +269,7 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
config.num_layers = 2
config.max_seq_len = 2 * kv_cache_max_len
config.embedding_dim = 128
config.embedding_scale = config.embedding_dim**0.5
config.block_configs = config.block_configs[: config.num_layers]
for block_config in config.block_configs:
block_config.attn_config.num_heads = 4
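The gemma2.py change drops the precomputed self.rope_cache and instead calls rotary_pos_emb.build_rope(input_pos, n_elem, head_dim, base) once per forward pass, producing cos/sin only for the requested positions. The library's build_rope implementation is not part of this diff; the sketch below is a hypothetical minimal version that matches the call site above, shown only to illustrate the expected signature and return shape.

# Hypothetical sketch of a build_rope-style helper matching the call site
# rotary_pos_emb.build_rope(input_pos, n_elem, head_dim, base) above.
# The real ai_edge_torch implementation may differ.
from typing import Tuple

import torch


def build_rope_sketch(
    input_pos: torch.Tensor,  # positions to embed, shape (seq_len,)
    n_elem: int,  # rotary dims, int(rotary_percentage * head_dim)
    head_dim: int,  # full head dimension; unused in this minimal version
    base: int = 10_000,
) -> Tuple[torch.Tensor, torch.Tensor]:
  """Returns (cos, sin), each of shape (seq_len, n_elem // 2)."""
  theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
  idx_theta = torch.outer(input_pos.float(), theta)
  return torch.cos(idx_theta), torch.sin(idx_theta)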
54 changes: 29 additions & 25 deletions ai_edge_torch/generative/examples/llama/llama.py
@@ -15,6 +15,7 @@

"""Example of building Llama 3.2 models."""

from functools import partial
import math
from typing import Tuple

@@ -26,8 +27,8 @@


def _build_llama3_rope_cache(
size: int,
dim: int,
input_pos: torch.Tensor,
n_elem: int,
base: int,
condense_ratio: int,
dtype: torch.dtype,
@@ -36,8 +37,9 @@ def _build_llama3_rope_cache(
low_freq_factor: float,
high_freq_factor: float,
max_seq_len: int,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Precomputes Rotary Positional Embeddings for Llama 3.2 model.
"""Computes Rotary Positional Embeddings for Llama 3.2 model.

It's a modified version of attn_utils.build_rope_cache with additional
arguments for Llama 3.2 model. It precomputes Rotary Positional Embedding Sin
@@ -47,13 +49,12 @@ def _build_llama3_rope_cache(
https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_rope_utils.py#L307

Args:
size (int): The size of the built cache.
dim (int): Each sequence's dimmension.
base (int, optional): Rope base value.
condense_ratio (int, optional): The ratio by which sequence indicies are
condensed.
dtype (torch.dtype, optional): Output tensor's data type.
device (torch.device, optional): Output tensor's data type.
input_pos (torch.Tensor): The input sequence positions.
n_elem (int): Each sequence's dimension.
base (int): Rope base value.
condense_ratio (int): The ratio by which sequence indices are condensed.
dtype (torch.dtype): Output tensor's data type.
device (torch.device): Output tensor's device.
factor (float): Factor to scale theta down for tokens in long range in the
sequence.
low_freq_factor (float): Factor to determine if tokens are in long range
@@ -66,7 +67,7 @@
Returns:
Tuple[torch.Tensor, torch.Tensor]: Rope's Cosine and Sine waves.
"""
theta = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
low_freq_wavelen = max_seq_len / low_freq_factor
high_freq_wavelen = max_seq_len / high_freq_factor
wavelen = 2 * math.pi / theta
@@ -81,7 +82,7 @@
is_medium = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
theta = torch.where(is_medium, smoothed_theta, theta)

seq_idx = torch.arange(size) / condense_ratio
seq_idx = input_pos / condense_ratio
idx_theta = torch.outer(seq_idx, theta)
cos = torch.cos(idx_theta).to(dtype=dtype, device=device)
sin = torch.sin(idx_theta).to(dtype=dtype, device=device)
@@ -97,18 +98,6 @@ class Llama(model_builder.DecoderOnlyModel):
def __init__(self, config: cfg.ModelConfig):
super().__init__(config)
attn_config = self.config.block_config(0).attn_config
self.rope_cache = _build_llama3_rope_cache(
size=self.config.kv_cache_max,
dim=int(attn_config.rotary_percentage * attn_config.head_dim),
base=attn_config.rotary_base,
condense_ratio=1,
dtype=torch.float32,
device=torch.device("cpu"),
factor=32.0,
low_freq_factor=1.0,
high_freq_factor=4.0,
max_seq_len=self.config.max_seq_len,
)


def get_1b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -140,15 +129,30 @@ def get_1b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
pre_attention_norm_config=norm_config,
post_attention_norm_config=norm_config,
)

max_seq_len = 8192
# Create the RoPE callable
build_rope = partial(
_build_llama3_rope_cache,
condense_ratio=1,
dtype=torch.float32,
device=torch.device("cpu"),
factor=32.0,
low_freq_factor=1.0,
high_freq_factor=4.0,
max_seq_len=max_seq_len,
)

config = cfg.ModelConfig(
vocab_size=128256,
num_layers=16,
max_seq_len=8192,
max_seq_len=max_seq_len,
embedding_dim=2048,
kv_cache_max_len=kv_cache_max_len,
block_configs=block_config,
final_norm_config=norm_config,
enable_hlfb=True,
build_rope=build_rope,
)
return config

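In llama.py, the Llama 3.2 frequency-scaling logic keeps its model-specific parameters, but they are now bound up front with functools.partial and handed to cfg.ModelConfig as build_rope, so the callable only needs the per-call arguments (input_pos, n_elem, base, plus anything absorbed by **kwargs) at forward time. How the shared decoder consumes that field is not shown in this diff; the following is a hedged sketch of the dispatch, with the attribute name, fallback, and keyword parameter names assumed for illustration.

# Hedged sketch: dispatching to a configured RoPE builder in a shared decoder.
# The attribute name self.config.build_rope, the default fallback, and the
# keyword names are assumptions; the real DecoderOnlyModel plumbing is not
# part of this diff.
import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb


def _compute_rope(self, input_pos, attn_config):
  n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
  builder = self.config.build_rope or rotary_pos_emb.build_rope
  # Keyword call: model-specific builders such as _build_llama3_rope_cache
  # declare **kwargs, so extra arguments they do not use are ignored.
  return builder(
      input_pos=input_pos,
      n_elem=n_elem,
      head_dim=attn_config.head_dim,
      base=attn_config.rotary_base,
  )

This keeps the per-model scaling rules (factor, low/high frequency factors, condense ratio) out of the shared forward path while still letting every model compute cos/sin inline for only the positions it is given.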
49 changes: 26 additions & 23 deletions ai_edge_torch/generative/examples/phi/phi3.py
@@ -15,6 +15,7 @@

"""Example of building a Phi-3.5 model up to 4K tokens, not to 128K tokens."""

from functools import partial
import math
from typing import Tuple

@@ -93,40 +94,41 @@
]


def _build_rope_cache(
size: int,
dim: int,
def _build_phi3_rope(
input_pos: torch.Tensor,
n_elem: int,
base: int,
condense_ratio: int,
dtype: torch.dtype,
device: torch.device,
theta_factors: torch.Tensor,
scale: float,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Precomputes Rotary Positional Embeddings for Phi-3.5 model.
"""Computes Rotary Positional Embeddings for Phi-3.5 model.

It's a modified version of attn_utils.build_rope_cache with additional
arguments for the Phi-3.5 model. It computes Rotary Positional Embedding Sin
and Cos values with scaling factors for the given input positions.

Args:
size (int): The size of the built cache.
dim (int): Each sequence's dimmension.
input_pos (torch.Tensor): The input sequence positions.
n_elem (int): Each sequence's dimension.
base (int, optional): Rope base value.
condense_ratio (int, optional): The ratio by which sequence indices are
condensed.
dtype (torch.dtype, optional): Output tensor's data type.
device (torch.device, optional): Output tensor's device.
theta_factors (torch.Tensor, optional): A tensor of shape (dim,) used to
scale the theta values.
theta_factors (torch.Tensor, optional): A tensor of shape (n_elem,) used
to scale the theta values.
scale (float, optional): A float used to scale the rope values.

Returns:
Tuple[torch.Tensor, torch.Tensor]: Rope's Cosine and Sine waves.
"""
theta = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
theta = theta / theta_factors
seq_idx = torch.arange(size) / condense_ratio
seq_idx = input_pos / condense_ratio
idx_theta = torch.outer(seq_idx, theta)
cos = torch.cos(idx_theta).to(dtype=dtype, device=device) * scale
sin = torch.sin(idx_theta).to(dtype=dtype, device=device) * scale
@@ -139,18 +141,6 @@ class Phi3_5Mini(model_builder.DecoderOnlyModel):
def __init__(self, config: cfg.ModelConfig):
super().__init__(config)
attn_config = self.config.block_config(0).attn_config
self.rope_cache = _build_rope_cache(
size=self.config.kv_cache_max,
dim=int(attn_config.rotary_percentage * attn_config.head_dim),
base=attn_config.rotary_base,
condense_ratio=1,
dtype=torch.float32,
device=torch.device("cpu"),
theta_factors=torch.tensor(ROPE_SHORT_FACTOR),
scale=math.sqrt(
1 + math.log(ROPE_SCALE_FACTOR) / math.log(config.max_seq_len)
),
)


def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -183,16 +173,29 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
pre_attention_norm_config=norm_config,
post_attention_norm_config=norm_config,
)
max_seq_len = 4096
# Create the RoPE callable
build_rope = partial(
_build_phi3_rope,
condense_ratio=1,
dtype=torch.float32,
device=torch.device("cpu"),
theta_factors=torch.tensor(ROPE_SHORT_FACTOR),
scale=math.sqrt(1 + math.log(ROPE_SCALE_FACTOR) / math.log(max_seq_len)),
max_seq_len=max_seq_len,
)

config = cfg.ModelConfig(
vocab_size=32064,
num_layers=32,
max_seq_len=4096,
max_seq_len=max_seq_len,
kv_cache_max_len=kv_cache_max_len,
embedding_dim=3072,
block_configs=block_config,
final_norm_config=norm_config,
lm_head_share_weight_with_embedding=False,
enable_hlfb=True,
build_rope=build_rope,
)
return config

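phi3.py follows the same pattern: _build_phi3_rope keeps the short-factor theta scaling and the log-based magnitude scale, and the non-per-call arguments are frozen into a partial that the config carries as build_rope. Below is an illustrative call of that configured partial, under the assumption that build_rope is exposed as a plain attribute on the returned config and that block_config/attn_config expose the fields used in this diff.

# Illustrative only: exercising the configured RoPE callable from phi3.py.
# Assumes config.build_rope is the partial built in get_model_config above.
import torch

config = get_model_config(kv_cache_max_len=1024)
attn_config = config.block_config(0).attn_config
n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)

input_pos = torch.arange(0, 8)  # an 8-token prefill
cos, sin = config.build_rope(
    input_pos=input_pos,
    n_elem=n_elem,
    base=attn_config.rotary_base,
)
# cos and sin each have one row per requested position and n_elem // 2
# columns, already multiplied by the Phi-3.5 long-context scale baked into
# the partial.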
(additional changed file; filename not shown in this view)
@@ -72,14 +72,14 @@ def forward(
mask = self.mask_cache.index_select(2, input_pos)
mask = mask[:, :, :, : self.config.max_seq_len]

updated_kv_entires = []
updated_kv_entries = []
for i, block in enumerate(self.transformer_blocks):
kv_entry = kv_cache.caches[i] if kv_cache else None
x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
if kv_entry:
updated_kv_entires.append(kv_entry)
updated_kv_entries.append(kv_entry)

updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))

if export_config is not None:
if (