diff --git a/sharktank/sharktank/examples/export_paged_llm_v1.py b/sharktank/sharktank/examples/export_paged_llm_v1.py
index f93e58e61..312a53d33 100644
--- a/sharktank/sharktank/examples/export_paged_llm_v1.py
+++ b/sharktank/sharktank/examples/export_paged_llm_v1.py
@@ -79,6 +79,7 @@ def main():
         hp,
         tensor_parallelism_size=tensor_parallelism_size,
         use_hf=False,
+        static_tables=False,  # Rely on the compiler for hoisting tables.
         kv_cache_type="direct" if args.bs == [1] else "paged",
         attention_kernel=args.attention_kernel,
         block_seq_stride=args.block_seq_stride,
@@ -218,16 +219,22 @@ def _(model, tokens, seq_lens, seq_block_ids, cs):
             else:
                 cache_tensors = cs
 
+            sl = tokens.shape[1]
+            input_mask = model.input_mask(seq_lens, sl)
+            attention_mask = model.attention_mask(input_mask)
+
             if llama_config.tensor_parallelism_size != 1:
                 shard_count = llama_config.tensor_parallelism_size
 
                 tokens = ops.replicate(tokens, count=shard_count)
+                attention_mask = ops.replicate(attention_mask, count=shard_count)
                 seq_block_ids = ops.replicate(seq_block_ids, count=shard_count)
+
                 cache_tensors = repack_cache(cs, cache_shard_dim)
 
             logits = model.prefill(
                 tokens,
-                attention_mask=None,  # We rely on causal attention
+                attention_mask=attention_mask,
                 seq_block_ids=seq_block_ids,
                 cache_state=cache_tensors,
            )
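For orientation (not part of the patch): prefill now builds its mask from `seq_lens` at export time instead of relying on causal attention inside the SDPA call. Below is a rough plain-PyTorch sketch of the kind of combined causal-plus-padding mask such helpers conventionally produce. The exact shapes and semantics of sharktank's `model.input_mask` / `model.attention_mask` are defined by the model code, so treat `example_attention_mask` as an assumption for illustration only.

```python
import torch

# Illustrative only -- not sharktank code. Assumes the usual decoder semantics:
# a query may attend to keys at earlier-or-equal positions that are not padding.
def example_attention_mask(seq_lens: torch.Tensor, batch_seq_len: int) -> torch.Tensor:
    positions = torch.arange(batch_seq_len)
    # True where a key position is padding for each sequence: [bs, sl]
    padding = positions[None, :] >= seq_lens[:, None]
    # True strictly above the diagonal (future positions): [sl, sl]
    future = positions[None, :] > positions[:, None]
    blocked = padding[:, None, None, :] | future[None, None, :, :]
    # Additive mask [bs, 1, sl, sl]: 0.0 where attention is allowed, -inf where blocked.
    mask = torch.zeros(blocked.shape, dtype=torch.float32)
    mask[blocked] = float("-inf")
    return mask

print(example_attention_mask(torch.tensor([3, 5]), batch_seq_len=8).shape)
# torch.Size([2, 1, 8, 8])
```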
diff --git a/sharktank/sharktank/kernels/__init__.py b/sharktank/sharktank/kernels/__init__.py
index 1b84f0bee..445f44852 100644
--- a/sharktank/sharktank/kernels/__init__.py
+++ b/sharktank/sharktank/kernels/__init__.py
@@ -10,7 +10,6 @@
 from .mmt_block_scaled_offset_q4 import *
 from .mmt_block_scaled_q8 import *
 from .mmt_super_block_scaled_offset_q4 import *
-from .rotary import *
 from .batch_matmul_transpose_b import *
 from .conv_2d_nchw_fchw import *
 from .pooling_nchw_sum import *
diff --git a/sharktank/sharktank/kernels/rotary.py b/sharktank/sharktank/kernels/rotary.py
deleted file mode 100644
index 196fc32c2..000000000
--- a/sharktank/sharktank/kernels/rotary.py
+++ /dev/null
@@ -1,70 +0,0 @@
-# Copyright 2024 Advanced Micro Devices, Inc.
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-from sharktank.kernels.base import *
-
-__all__ = [
-    "apply_rotary_embedding",
-]
-
-
-@CustomOp.register(library=LIBRARY)
-class apply_rotary_embedding(CustomOp):
-
-    signature = "apply_rotary_embedding(Tensor input, Tensor table) -> (Tensor)"
-
-    def select(self, ksel: KernelSelection):
-        inputs_desc = ksel.arg_tensor(0)
-        table_desc = ksel.arg_tensor(1)
-        out_desc = ksel.return_new_tensor(
-            inputs_desc.t.shape, dtype=inputs_desc.t.dtype
-        )
-        specialize_all_known_dims(inputs_desc)
-        specialize_all_known_dims(table_desc)
-        specialize_all_known_dims(out_desc)
-
-    def generate(self, ksel: KernelSelection, kb: KernelBuilder):
-
-        input = kb.arg_value(0)
-        table = kb.arg_value(1)
-
-        input_tensor_type = RankedTensorType(input.type)
-        table_tensor_type = RankedTensorType(table.type)
-
-        input_asm_type, input_ident, input_dtype = unpack_tensor_type(input.type)
-        table_asm_type, table_ident, table_dtype = unpack_tensor_type(table.type)
-
-        assert input_dtype == table_dtype
-
-        # Generate specialization signature and types.
-        bs = input.type.shape[0]
-        sl = input.type.shape[1]
-        sl = "D" if sl < 0 else sl
-        heads = input.type.shape[2]
-        dims = input.type.shape[3]
-
-        template_file = "rotary_embedding.mlir"
-        target_function_name = (
-            f"sharktank_rotary_embedding_{bs}_{sl}_{heads}_{dims}_{input_dtype}"
-        )
-
-        # Template params.
-        input_tensor_type = input_asm_type
-        table_tensor_type = table_asm_type
-
-        target_function = inline_template_function(
-            kb,
-            template_file,
-            target_function_name,
-            input_tensor_type=input_tensor_type,
-            table_tensor_type=table_tensor_type,
-            bs=bs,
-            sl=sl,
-            heads=heads,
-            dims=dims,
-            dtype=str(input_dtype),
-        )
-        kb.yield_results(*call_function(target_function, *kb.arg_bindings))
diff --git a/sharktank/sharktank/kernels/templates/rotary_embedding.mlir b/sharktank/sharktank/kernels/templates/rotary_embedding.mlir
deleted file mode 100644
index adec6805b..000000000
--- a/sharktank/sharktank/kernels/templates/rotary_embedding.mlir
+++ /dev/null
@@ -1,63 +0,0 @@
-// Copyright 2024 Advanced Micro Devices, Inc.
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-!input_tensor_type = {{input_tensor_type}}
-!table_tensor_type = {{table_tensor_type}}
-
-module {
-
-util.func private @sharktank_rotary_embedding_{{bs}}_{{sl}}_{{heads}}_{{dims}}_{{dtype}}(%input: !input_tensor_type, %table: !table_tensor_type) -> !input_tensor_type {
-
-  %c0 = arith.constant 0 : index
-  %c1 = arith.constant 1 : index
-  %c2 = arith.constant 2 : index
-  %c3 = arith.constant 3 : index
-
-
-  %d0 = tensor.dim %input, %c0 : !input_tensor_type
-  %d1 = tensor.dim %input, %c1 : !input_tensor_type
-  %d2 = tensor.dim %input, %c2 : !input_tensor_type
-  %d3 = tensor.dim %input, %c3 : !input_tensor_type
-
-  %empty_dyn = tensor.empty(%d0, %d1, %d2, %d3) : tensor<?x?x?x?x{{dtype}}>
-  %empty = tensor.cast %empty_dyn : tensor<?x?x?x?x{{dtype}}> to {{input_tensor_type}}
-
-  %result = linalg.generic {
-      indexing_maps = [
-          affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
-          affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-      ],
-      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
-      ins(%table : !table_tensor_type )
-      outs(%empty : !input_tensor_type) {
-    ^bb0(%b0 : {{dtype}} , %b1 : {{dtype}}):
-      %0 = linalg.index 0 : index
-      %1 = linalg.index 1 : index
-      %2 = linalg.index 2 : index
-      %3 = linalg.index 3 : index
-      %div = arith.divui %3, %c2 : index
-      %mod = arith.remui %3, %c2 : index
-      %a_cosb = math.cos %b0 : {{dtype}}
-      %a_sinb = math.sin %b0 : {{dtype}}
-      %real_index = arith.muli %div, %c2 : index
-      %imag_index = arith.addi %real_index, %c1 : index
-      %real = tensor.extract %input[%0, %1, %2, %real_index] : !input_tensor_type
-      %imag = tensor.extract %input[%0, %1, %2, %imag_index] : !input_tensor_type
-      %cmp = arith.cmpi eq, %mod, %c0 : index
-      %real_t0 = arith.mulf %real, %a_cosb : {{dtype}}
-      %real_t1 = arith.mulf %imag, %a_sinb : {{dtype}}
-      %real_t2 = arith.subf %real_t0, %real_t1 : {{dtype}}
-      %imag_t0 = arith.mulf %imag, %a_cosb : {{dtype}}
-      %imag_t1 = arith.mulf %real, %a_sinb : {{dtype}}
-      %imag_t2 = arith.addf %imag_t0, %imag_t1 : {{dtype}}
-      %val = arith.select %cmp, %real_t2, %imag_t2 : {{dtype}}
-      linalg.yield %val : {{dtype}}
-  } -> !input_tensor_type
-
-  util.return %result : !input_tensor_type
-}
-
-}
diff --git a/sharktank/sharktank/layers/paged_llama_attention_block.py b/sharktank/sharktank/layers/paged_llama_attention_block.py
index d74e2a92d..6bd33c93f 100644
--- a/sharktank/sharktank/layers/paged_llama_attention_block.py
+++ b/sharktank/sharktank/layers/paged_llama_attention_block.py
@@ -221,7 +221,7 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor:
                 k=keys,  # [bs, ..., sl, dim]
                 v=values,  # [bs, ..., sl, dim]
                 a=attention_mask,  # [bs, ..., sl, sl]
-                is_causal=attention_mask is None,  # assumes causal masking when true
+                is_causal=False,  # assumes causal masking when true
                 scale=None,  # defaults to 1/sqrt(dim)
             )
 
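The `is_causal` flip above goes hand in hand with the export change that now always passes an explicit mask. As a sanity sketch in plain `torch.nn.functional` (not the sharktank op wrapper, and not part of this patch): for unpadded sequences, an explicit additive causal mask with `is_causal=False` reproduces the builtin `is_causal=True` path; the mask built from `seq_lens` additionally blocks padding positions.

```python
import torch
import torch.nn.functional as F

# Illustrative check, not part of this patch.
bs, heads, sl, dim = 2, 4, 16, 32
q, k, v = (torch.randn(bs, heads, sl, dim) for _ in range(3))

# Additive causal mask: 0 on/below the diagonal, -inf strictly above it.
causal = torch.full((sl, sl), float("-inf")).triu(diagonal=1)

out_explicit = F.scaled_dot_product_attention(q, k, v, attn_mask=causal, is_causal=False)
out_builtin = F.scaled_dot_product_attention(q, k, v, is_causal=True)
torch.testing.assert_close(out_explicit, out_builtin)
```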
diff --git a/sharktank/sharktank/layers/rotary_embedding.py b/sharktank/sharktank/layers/rotary_embedding.py
index 623c02ea6..99ecf5057 100644
--- a/sharktank/sharktank/layers/rotary_embedding.py
+++ b/sharktank/sharktank/layers/rotary_embedding.py
@@ -11,7 +11,6 @@
 
 from .base import BaseLayer
 from .. import ops
-from .. import kernels
 from ..types import SplitPrimitiveTensor, ReplicatedTensor, unbox_tensor
 
 
@@ -26,6 +25,7 @@ def __init__(
         rope_freq_base: Optional[float],
         device: Optional[torch.device] = None,
         use_hf: bool = False,
+        static_tables: bool = False,
         use_table: bool = True,
         tensor_parallelism_size: int = 1,
     ):
@@ -34,14 +34,26 @@ def __init__(
         self.rope_dimension_count = rope_dimension_count
         self.max_seqlen = max_seqlen
         self.use_hf = use_hf
+        self.static_tables = static_tables
         self.use_table = use_table
 
         self.rope_freq_base = rope_freq_base if rope_freq_base is not None else 10000.0
         self.tensor_parallelism_size = tensor_parallelism_size
+        if static_tables:
+            ops.module_register_buffer(
+                self, "static_rotary_embed_table", self._create_rotary_embed_table()
+            )
+        else:
+            self.static_rotary_embed_table = None
 
     @property
     def rotary_embed_table(self):
-        return self._create_rotary_embed_table()
+        if self.use_table:
+            if self.static_tables:
+                return self.static_rotary_embed_table
+            return self._create_rotary_embed_table()
+
+        return None
 
     def forward(
         self,
@@ -49,29 +61,33 @@ def forward(
         xt: Union[torch.Tensor, SplitPrimitiveTensor],
         start_index: int,
     ):
-        table = self.rotary_embed_table
-        if not isinstance(xt, SplitPrimitiveTensor):
+        if isinstance(xt, SplitPrimitiveTensor):
+            rotary_shards = [None] * xt.shard_count
+            if self.rotary_embed_table is not None:
+                assert (
+                    isinstance(self.rotary_embed_table, ReplicatedTensor)
+                    and xt.shard_count == self.rotary_embed_table.shard_count
+                )
+                rotary_shards = [
+                    unbox_tensor(shard) for shard in self.rotary_embed_table.shards
+                ]
+
+            xt_shards = [
+                self.forward_unsharded(
+                    xt=unbox_tensor(xt_shard),
+                    start_index=start_index,
+                    rotary_embed_table=rotary_shard,
+                )
+                for xt_shard, rotary_shard in zip(xt.shards, rotary_shards)
+            ]
+            xt = SplitPrimitiveTensor(ts=xt_shards, shard_dim=xt.shard_dim)
+            return xt
+        else:
             return self.forward_unsharded(
                 xt=xt,
                 start_index=start_index,
-                rotary_embed_table=table,
-            )
-
-        assert (
-            isinstance(table, ReplicatedTensor) and xt.shard_count == table.shard_count
-        )
-        rotary_shards = [unbox_tensor(shard) for shard in table.shards]
-
-        xt_shards = [
-            self.forward_unsharded(
-                xt=unbox_tensor(xt_shard),
-                start_index=start_index,
-                rotary_embed_table=rotary_shard,
+                rotary_embed_table=self.rotary_embed_table,
             )
-            for xt_shard, rotary_shard in zip(xt.shards, rotary_shards)
-        ]
-        xt = SplitPrimitiveTensor(ts=xt_shards, shard_dim=xt.shard_dim)
-        return xt
 
     def _create_interleaved_tensor(_, dim):
         """Creates a tensor which indexes an tensor such that
@@ -127,17 +143,18 @@ def forward_unsharded(
         # Offset the table based on starting position.
         if self.use_table:
             freqs_cis = rotary_embed_table[start_index : start_index + sl, :]
-            freqs_cis = freqs_cis[0:sl, :]
+            freqs_cis = freqs_cis[None, 0:sl, None, :]
         else:
             freqs_cis = torch.arange(sl, device=xt.device) + start_index
-            freqs_cis = self._compute_rotary_embed_table(freqs_cis)
+            freqs_cis = self._compute_rotary_embed_table(freqs_cis)[None, :, None, :]
 
         assert (
-            freqs_cis.shape[0] >= sl
+            freqs_cis.shape[1] >= sl
         ), f"Sequence length longer than embedding table ({sl} vs {freqs_cis.shape[0]})"
 
-        freqs_cis = ops.repeat(freqs_cis[None, :, :], (xt_.shape[0], 1, 1))
-        xt_out = kernels.apply_rotary_embedding(xt_.to(freqs_cis.dtype), freqs_cis)
+        xt_ = ops.view_as_complex(xt_)
+        xt_ = xt_ * freqs_cis
+        xt_out = ops.view_as_real(xt_)
 
         if self.use_hf:
             xt_out = xt_out[..., self._create_ordering_tensor(xt_out.shape[-1])]
@@ -164,7 +181,7 @@ def compute_batch_mask(
         self.trace_tensor("rope.positions_seq", positions_seq)
 
         if self.use_table:
-            freqs_cis = self.rotary_embed_table[positions_seq.flatten()]
+            freqs_cis = self.rotary_embed_table[positions_seq]
         else:
             shape = positions_seq.shape
             if isinstance(positions_seq, ReplicatedTensor):
@@ -175,8 +192,11 @@
                 freqs_cis = ReplicatedTensor(ts=ts)
             else:
                 freqs_cis = self._compute_rotary_embed_table(positions_seq.flatten())
+                freqs_cis = freqs_cis.unflatten(0, shape)
 
-        return freqs_cis.unsqueeze(1)
+        # Unsqueeze a unit dim for attention heads.
+        broadcast_freqs_cis = freqs_cis.unsqueeze(2)
+        return broadcast_freqs_cis
 
     def apply_batched_mask(
         self,
@@ -212,7 +232,9 @@ def apply_batched_mask_unsharded(self, *, xt: torch.Tensor, mask: torch.Tensor):
         if self.use_hf:
             xt = xt[..., self._create_interleaved_tensor(xt.shape[-1])]
 
-        xt_out = kernels.apply_rotary_embedding(xt.to(mask.dtype), mask)
+        xt_ = ops.view_as_complex(xt)
+        xt_ = xt_ * mask
+        xt_out = ops.view_as_real(xt_)
 
         if self.use_hf:
             xt_out = xt_out[..., self._create_ordering_tensor(xt_out.shape[-1])]
@@ -222,10 +244,14 @@ def apply_batched_mask_unsharded(self, *, xt: torch.Tensor, mask: torch.Tensor):
     def _compute_rotary_embed_table(self, t):
         dim = self.rope_dimension_count
         freqs = 1.0 / (
-            self.rope_freq_base ** ((torch.arange(0, dim) // 2).float() / dim * 2.0)
+            self.rope_freq_base ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
        )
         freqs = torch.outer(t, freqs).float()
-        return freqs
+
+        cos = torch.cos(freqs)
+        sin = torch.sin(freqs)
+        complex = torch.complex(cos, sin)
+        return complex
 
     def _create_rotary_embed_table(self):
         t = torch.arange(self.max_seqlen, device=self.device)
diff --git a/sharktank/sharktank/models/llama/llama.py b/sharktank/sharktank/models/llama/llama.py
index 6fef6704e..0a9a6f1c3 100644
--- a/sharktank/sharktank/models/llama/llama.py
+++ b/sharktank/sharktank/models/llama/llama.py
@@ -67,6 +67,7 @@ def __init__(self, theta: Theta, config: LlamaModelConfig):
         super().__init__(
             theta,
             context_length=config.hp.context_length,
+            static_tables=config.static_tables,
             device=config.device,
             activation_dtype=config.activation_dtype,
             attention_dtype=config.attention_dtype,
@@ -91,6 +92,7 @@ def __init__(self, theta: Theta, config: LlamaModelConfig):
                 max_seqlen=hp.context_length,
                 device=self.device,
                 use_hf=self.use_hf,
+                static_tables=config.static_tables,
                 tensor_parallelism_size=config.tensor_parallelism_size,
             ),
         )
@@ -124,7 +126,7 @@ def prefill(
         tokens: Union[torch.Tensor, ReplicatedTensor],
         *,
         # [1, 1, batch_seq_len, batch_seq_len]
-        attention_mask: Optional[Union[torch.Tensor, ReplicatedTensor]],
+        attention_mask: Union[torch.Tensor, ReplicatedTensor],
         # [bs, batch_seq_len // block_seq_stride]
         seq_block_ids: Union[torch.Tensor, ReplicatedTensor],
         cache_state: list[Union[torch.Tensor, SplitPrimitiveTensor]],
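A rough plain-`nn.Module` sketch (not sharktank code; `ExampleRotaryTable` and its arguments are made up for illustration) of the trade-off behind the new `static_tables` flag: registering the table as a buffer bakes it into the module, and hence the exported program, as a constant, while the default now recomputes it per call and, per the comment in `export_paged_llm_v1.py`, relies on the compiler to hoist it. The table math mirrors `_compute_rotary_embed_table` from the patch.

```python
import torch
from torch import nn

class ExampleRotaryTable(nn.Module):
    """Illustrative stand-in for the static_tables behaviour (not sharktank code)."""

    def __init__(self, max_seqlen: int, dim: int, base: float = 10000.0, static_tables: bool = False):
        super().__init__()
        self.max_seqlen, self.dim, self.base, self.static_tables = max_seqlen, dim, base, static_tables
        if static_tables:
            # Baked into the module as a constant buffer.
            self.register_buffer("table", self._compute(torch.arange(max_seqlen)))

    def _compute(self, t: torch.Tensor) -> torch.Tensor:
        freqs = 1.0 / (self.base ** (torch.arange(0, self.dim, 2)[: self.dim // 2].float() / self.dim))
        angles = torch.outer(t.float(), freqs)
        return torch.complex(torch.cos(angles), torch.sin(angles))

    def forward(self) -> torch.Tensor:
        # Recomputed on every call when not static; hoisting is left to the compiler.
        return self.table if self.static_tables else self._compute(torch.arange(self.max_seqlen))

table = ExampleRotaryTable(max_seqlen=128, dim=64)()  # [128, 32], complex64
```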
diff --git a/sharktank/tests/kernels/rotary.py b/sharktank/tests/kernels/rotary.py
deleted file mode 100644
index 6c3d032a3..000000000
--- a/sharktank/tests/kernels/rotary.py
+++ /dev/null
@@ -1,31 +0,0 @@
-# Copyright 2024 Advanced Micro Devices, Inc.
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-import logging
-
-logging.basicConfig(level=logging.DEBUG)
-
-import torch
-import unittest
-
-from sharktank import kernels
-from sharktank import ops
-
-
-class rotary_test(unittest.TestCase):
-    def setUp(self):
-        torch.manual_seed(42)
-
-    def test_rotary(self):
-        dtype = torch.float32
-        a = torch.rand([1, 128, 1, 64], dtype=dtype)
-        rot = torch.rand([128, 32], dtype=dtype)
-        res_b = ops.view_as_real(torch.complex(rot, rot))
-        ref_b = torch.complex(torch.cos(rot), torch.sin(rot))
-
-        result = kernels.apply_rotary_embedding(a, res_b)
-        ref = ops.view_as_real(ops.view_as_complex(a) * ref_b[None, :, None, :])
-        torch.testing.assert_close(result, ref)
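With the custom kernel and its test gone, the rotation is expressed through complex multiplication. Below is a self-contained plain-PyTorch sketch of that path, using `torch.view_as_complex` / `torch.view_as_real` with an explicit pairwise reshape rather than the sharktank `ops` wrappers, and ignoring the optional `use_hf` reordering. Shapes follow the deleted test: `[1, 128, 1, 64]` activations against a `128 x 32` table of angles.

```python
import torch

def apply_rotary(x: torch.Tensor, angles: torch.Tensor) -> torch.Tensor:
    """Rotate interleaved (real, imag) channel pairs of x by the given angles.

    x:      [bs, sl, heads, dim] with dim even
    angles: [sl, dim // 2]
    """
    # Pair adjacent channels and reinterpret them as complex numbers.
    x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    freqs_cis = torch.complex(torch.cos(angles), torch.sin(angles))  # e^(i * angles)
    # Broadcast over batch and heads, rotate, then flatten back to real pairs.
    rotated = x_complex * freqs_cis[None, :, None, :]
    return torch.view_as_real(rotated).flatten(-2)

x = torch.rand(1, 128, 1, 64)
angles = torch.rand(128, 32)
out = apply_rotary(x, angles)
assert out.shape == x.shape
```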