[sharktank] Revert "[llama] Added the fused rotary embedding kernel (#719)" #752

Open · wants to merge 2 commits into main
9 changes: 8 additions & 1 deletion sharktank/sharktank/examples/export_paged_llm_v1.py
@@ -79,6 +79,7 @@ def main():
hp,
tensor_parallelism_size=tensor_parallelism_size,
use_hf=False,
static_tables=False, # Rely on the compiler for hoisting tables.
kv_cache_type="direct" if args.bs == [1] else "paged",
attention_kernel=args.attention_kernel,
block_seq_stride=args.block_seq_stride,
@@ -218,16 +219,22 @@ def _(model, tokens, seq_lens, seq_block_ids, cs):
else:
cache_tensors = cs

sl = tokens.shape[1]
input_mask = model.input_mask(seq_lens, sl)
attention_mask = model.attention_mask(input_mask)

if llama_config.tensor_parallelism_size != 1:
shard_count = llama_config.tensor_parallelism_size

tokens = ops.replicate(tokens, count=shard_count)
attention_mask = ops.replicate(attention_mask, count=shard_count)
seq_block_ids = ops.replicate(seq_block_ids, count=shard_count)

cache_tensors = repack_cache(cs, cache_shard_dim)

logits = model.prefill(
tokens,
attention_mask=None, # We rely on causal attention
attention_mask=attention_mask,
seq_block_ids=seq_block_ids,
cache_state=cache_tensors,
)
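Note on this hunk: the revert goes back to materializing an explicit attention mask from `seq_lens` at export time instead of relying on causal attention inside the kernel. Below is a minimal sketch of what such a mask can look like, assuming `model.input_mask` marks padding positions and `model.attention_mask` combines that with a causal bias; the real sharktank helpers may differ in dtype, sign convention, and layout.

```python
import torch

def make_input_mask(seq_lens: torch.Tensor, batch_seq_len: int) -> torch.Tensor:
    # True where a position is padding (at or beyond the sequence's real length).
    positions = torch.arange(batch_seq_len)[None, :]   # [1, sl]
    return positions >= seq_lens[:, None]               # [bs, sl]

def make_attention_mask(input_mask: torch.Tensor, dtype=torch.float32) -> torch.Tensor:
    bs, sl = input_mask.shape
    # Future positions (strict upper triangle) plus padded key positions are masked.
    causal = torch.triu(torch.ones(sl, sl, dtype=torch.bool), diagonal=1)
    masked = causal[None, None, :, :] | input_mask[:, None, None, :]   # [bs, 1, sl, sl]
    mask = torch.zeros(bs, 1, sl, sl, dtype=dtype)
    return mask.masked_fill(masked, torch.finfo(dtype).min)

seq_lens = torch.tensor([3, 5])
attention_mask = make_attention_mask(make_input_mask(seq_lens, batch_seq_len=8))
print(attention_mask.shape)  # torch.Size([2, 1, 8, 8])
```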
1 change: 0 additions & 1 deletion sharktank/sharktank/kernels/__init__.py
@@ -10,7 +10,6 @@
from .mmt_block_scaled_offset_q4 import *
from .mmt_block_scaled_q8 import *
from .mmt_super_block_scaled_offset_q4 import *
from .rotary import *
from .batch_matmul_transpose_b import *
from .conv_2d_nchw_fchw import *
from .pooling_nchw_sum import *
70 changes: 0 additions & 70 deletions sharktank/sharktank/kernels/rotary.py

This file was deleted.

63 changes: 0 additions & 63 deletions sharktank/sharktank/kernels/templates/rotary_embedding.mlir

This file was deleted.

2 changes: 1 addition & 1 deletion sharktank/sharktank/layers/paged_llama_attention_block.py
@@ -221,7 +221,7 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor:
k=keys, # [bs, ..., sl, dim]
v=values, # [bs, ..., sl, dim]
a=attention_mask, # [bs, ..., sl, sl]
is_causal=attention_mask is None, # assumes causal masking when true
is_causal=False, # assumes causal masking when true
scale=None, # defaults to 1/sqrt(dim)
)

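For context on the `is_causal` flip above: with no mask tensor, `is_causal=True` lets the attention implementation apply the triangular mask internally; the revert goes back to passing the mask explicitly and turning the flag off. A small self-contained check in plain PyTorch (not the sharktank `ops` wrapper) showing the two are numerically equivalent in the unpadded case:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
bs, heads, sl, dim = 2, 4, 8, 16
q = torch.randn(bs, heads, sl, dim)
k = torch.randn(bs, heads, sl, dim)
v = torch.randn(bs, heads, sl, dim)

# Implicit causal masking: no mask tensor is materialized.
out_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Equivalent explicit additive mask, which is what gets passed after the revert.
causal_bias = torch.zeros(sl, sl)
causal_bias.masked_fill_(torch.triu(torch.ones(sl, sl, dtype=torch.bool), 1), float("-inf"))
out_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_bias, is_causal=False)

print(torch.allclose(out_causal, out_masked, atol=1e-6))  # True
```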
88 changes: 57 additions & 31 deletions sharktank/sharktank/layers/rotary_embedding.py
@@ -11,7 +11,6 @@

from .base import BaseLayer
from .. import ops
from .. import kernels
from ..types import SplitPrimitiveTensor, ReplicatedTensor, unbox_tensor


@@ -26,6 +25,7 @@ def __init__(
rope_freq_base: Optional[float],
device: Optional[torch.device] = None,
use_hf: bool = False,
static_tables: bool = False,
use_table: bool = True,
tensor_parallelism_size: int = 1,
):
@@ -34,44 +34,60 @@
self.rope_dimension_count = rope_dimension_count
self.max_seqlen = max_seqlen
self.use_hf = use_hf
self.static_tables = static_tables
self.use_table = use_table

self.rope_freq_base = rope_freq_base if rope_freq_base is not None else 10000.0
self.tensor_parallelism_size = tensor_parallelism_size
if static_tables:
ops.module_register_buffer(
self, "static_rotary_embed_table", self._create_rotary_embed_table()
)
else:
self.static_rotary_embed_table = None

@property
def rotary_embed_table(self):
return self._create_rotary_embed_table()
if self.use_table:
if self.static_tables:
return self.static_rotary_embed_table
return self._create_rotary_embed_table()

return None

def forward(
self,
*,
xt: Union[torch.Tensor, SplitPrimitiveTensor],
start_index: int,
):
table = self.rotary_embed_table
if not isinstance(xt, SplitPrimitiveTensor):
if isinstance(xt, SplitPrimitiveTensor):
rotary_shards = [None] * xt.shard_count
if self.rotary_embed_table is not None:
assert (
isinstance(self.rotary_embed_table, ReplicatedTensor)
and xt.shard_count == self.rotary_embed_table.shard_count
)
rotary_shards = [
unbox_tensor(shard) for shard in self.rotary_embed_table.shards
]

xt_shards = [
self.forward_unsharded(
xt=unbox_tensor(xt_shard),
start_index=start_index,
rotary_embed_table=rotary_shard,
)
for xt_shard, rotary_shard in zip(xt.shards, rotary_shards)
]
xt = SplitPrimitiveTensor(ts=xt_shards, shard_dim=xt.shard_dim)
return xt
else:
return self.forward_unsharded(
xt=xt,
start_index=start_index,
rotary_embed_table=table,
)

assert (
isinstance(table, ReplicatedTensor) and xt.shard_count == table.shard_count
)
rotary_shards = [unbox_tensor(shard) for shard in table.shards]

xt_shards = [
self.forward_unsharded(
xt=unbox_tensor(xt_shard),
start_index=start_index,
rotary_embed_table=rotary_shard,
rotary_embed_table=self.rotary_embed_table,
)
for xt_shard, rotary_shard in zip(xt.shards, rotary_shards)
]
xt = SplitPrimitiveTensor(ts=xt_shards, shard_dim=xt.shard_dim)
return xt

def _create_interleaved_tensor(_, dim):
"""Creates a tensor which indexes an tensor such that
@@ -127,17 +143,18 @@ def forward_unsharded(
# Offset the table based on starting position.
if self.use_table:
freqs_cis = rotary_embed_table[start_index : start_index + sl, :]
freqs_cis = freqs_cis[0:sl, :]
freqs_cis = freqs_cis[None, 0:sl, None, :]
else:
freqs_cis = torch.arange(sl, device=xt.device) + start_index
freqs_cis = self._compute_rotary_embed_table(freqs_cis)
freqs_cis = self._compute_rotary_embed_table(freqs_cis)[None, :, None, :]

assert (
freqs_cis.shape[0] >= sl
freqs_cis.shape[1] >= sl
), f"Sequence length longer than embedding table ({sl} vs {freqs_cis.shape[0]})"

freqs_cis = ops.repeat(freqs_cis[None, :, :], (xt_.shape[0], 1, 1))
xt_out = kernels.apply_rotary_embedding(xt_.to(freqs_cis.dtype), freqs_cis)
xt_ = ops.view_as_complex(xt_)
xt_ = xt_ * freqs_cis
xt_out = ops.view_as_real(xt_)

if self.use_hf:
xt_out = xt_out[..., self._create_ordering_tensor(xt_out.shape[-1])]
@@ -164,7 +181,7 @@ def compute_batch_mask(
self.trace_tensor("rope.positions_seq", positions_seq)

if self.use_table:
freqs_cis = self.rotary_embed_table[positions_seq.flatten()]
freqs_cis = self.rotary_embed_table[positions_seq]
else:
shape = positions_seq.shape
if isinstance(positions_seq, ReplicatedTensor):
@@ -175,8 +192,11 @@
freqs_cis = ReplicatedTensor(ts=ts)
else:
freqs_cis = self._compute_rotary_embed_table(positions_seq.flatten())
freqs_cis = freqs_cis.unflatten(0, shape)

return freqs_cis.unsqueeze(1)
# Unsqueeze a unit dim for attention heads.
broadcast_freqs_cis = freqs_cis.unsqueeze(2)
return broadcast_freqs_cis

def apply_batched_mask(
self,
@@ -212,7 +232,9 @@ def apply_batched_mask_unsharded(self, *, xt: torch.Tensor, mask: torch.Tensor):
if self.use_hf:
xt = xt[..., self._create_interleaved_tensor(xt.shape[-1])]

xt_out = kernels.apply_rotary_embedding(xt.to(mask.dtype), mask)
xt_ = ops.view_as_complex(xt)
xt_ = xt_ * mask
xt_out = ops.view_as_real(xt_)

if self.use_hf:
xt_out = xt_out[..., self._create_ordering_tensor(xt_out.shape[-1])]
@@ -222,10 +244,14 @@ def apply_batched_mask_unsharded(self, *, xt: torch.Tensor, mask: torch.Tensor):
def _compute_rotary_embed_table(self, t):
dim = self.rope_dimension_count
freqs = 1.0 / (
self.rope_freq_base ** ((torch.arange(0, dim) // 2).float() / dim * 2.0)
self.rope_freq_base ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
)
freqs = torch.outer(t, freqs).float()
return freqs

cos = torch.cos(freqs)
sin = torch.sin(freqs)
complex = torch.complex(cos, sin)
return complex

def _create_rotary_embed_table(self):
t = torch.arange(self.max_seqlen, device=self.device)
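What this file goes back to, in short: precompute a complex table of e^{i·m·θ_j}, view the activations as complex pairs, multiply, and view back as real, with `static_tables` optionally registering the table as a module buffer instead of recomputing it per call. A rough sketch using plain torch in place of the `ops.view_as_complex` / `ops.view_as_real` wrappers; the interleaved pairing shown here is an assumption, and the layer's HF ordering helpers handle the alternate layout separately.

```python
import torch

def rotary_table(max_seqlen: int, dim: int, theta: float = 10000.0) -> torch.Tensor:
    # Per-pair frequencies, then one complex rotation per (position, pair).
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: dim // 2].float() / dim))
    angles = torch.outer(torch.arange(max_seqlen).float(), freqs)   # [sl, dim//2]
    return torch.complex(torch.cos(angles), torch.sin(angles))      # e^{i*m*theta_j}

def apply_rotary(xt: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    # xt: [bs, sl, heads, dim]; pair adjacent features into complex numbers.
    xt_c = torch.view_as_complex(xt.float().reshape(*xt.shape[:-1], -1, 2))
    xt_c = xt_c * freqs_cis[None, :, None, :]    # broadcast over batch and heads
    return torch.view_as_real(xt_c).flatten(-2)  # back to [bs, sl, heads, dim]

xt = torch.randn(2, 8, 4, 64)
out = apply_rotary(xt, rotary_table(max_seqlen=8, dim=64))
print(out.shape)  # torch.Size([2, 8, 4, 64])
```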
4 changes: 3 additions & 1 deletion sharktank/sharktank/models/llama/llama.py
@@ -67,6 +67,7 @@ def __init__(self, theta: Theta, config: LlamaModelConfig):
super().__init__(
theta,
context_length=config.hp.context_length,
static_tables=config.static_tables,
device=config.device,
activation_dtype=config.activation_dtype,
attention_dtype=config.attention_dtype,
@@ -91,6 +92,7 @@ def __init__(self, theta: Theta, config: LlamaModelConfig):
max_seqlen=hp.context_length,
device=self.device,
use_hf=self.use_hf,
static_tables=config.static_tables,
tensor_parallelism_size=config.tensor_parallelism_size,
),
)
@@ -124,7 +126,7 @@ def prefill(
tokens: Union[torch.Tensor, ReplicatedTensor],
*,
# [1, 1, batch_seq_len, batch_seq_len]
attention_mask: Optional[Union[torch.Tensor, ReplicatedTensor]],
attention_mask: Union[torch.Tensor, ReplicatedTensor],
# [bs, batch_seq_len // block_seq_stride]
seq_block_ids: Union[torch.Tensor, ReplicatedTensor],
cache_state: list[Union[torch.Tensor, SplitPrimitiveTensor]],