Skip to content

Commit

Permalink
KV Cache updates use dynamic_update_slice
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 705660857
  • Loading branch information
talumbau authored and copybara-github committed Dec 13, 2024
1 parent 5a93316 commit f142f4a
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 30 deletions.
8 changes: 2 additions & 6 deletions ai_edge_torch/generative/layers/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,9 +241,7 @@ def forward(
q, k = _embed_rope(q, k, n_elem, rope)

if kv_cache is not None:
kv_cache = kv_utils.update(
kv_cache, input_pos, k, v, enable_hlfb=self.enable_hlfb
)
kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
k, v = kv_cache.k_cache, kv_cache.v_cache

y = self.sdpa_func(
Expand Down Expand Up @@ -379,9 +377,7 @@ def forward(
q, k = _embed_rope(q, k, n_elem, rope)

if kv_cache is not None:
kv_cache = kv_utils.update(
kv_cache, input_pos, k, v, enable_hlfb=self.enable_hlfb
)
kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
k, v = kv_cache.k_cache, kv_cache.v_cache
if mask is None:
mask = torch.zeros(
Expand Down
43 changes: 25 additions & 18 deletions ai_edge_torch/generative/layers/kv_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def update(
input_pos: torch.Tensor,
k_slice: torch.Tensor,
v_slice: torch.Tensor,
enable_hlfb: bool = True,
use_dus: bool = True,
) -> KVCacheEntry:
"""Out of place update of Cache buffer.
Expand All @@ -155,17 +155,14 @@ def update(
input_pos (torch.Tensor): The update slice positions.
k_slice (torch.Tensor): The K slice to be updated in the new cache.
v_slice (torch.Tensor): The V slice to be updated in the new cache.
enable_hlfb (bool, optional): Whether the op is annotated for export with
High Level Function Boundary. Defaults to True.
Returns:
KVCacheEntry: The updated KVCache entry based on the passed inputs.
"""
# Don't enable HLFB for kv cache op for now, since it won't work with LLM
# inference engine. Remove this part once we ship a new LLM inference engine.
enable_hlfb=False
update_func = _update_kv_hlfb_impl if enable_hlfb else _update_kv_base_impl
return update_func(cache, input_pos, k_slice, v_slice)
# Turn dynamic_update_slice updates off for now.
use_dus=False
update_kv_cache = _update_kv_impl if use_dus else _update_kv_base_impl
return update_kv_cache(cache, input_pos, k_slice, v_slice)


def _update_kv_base_impl(
Expand All @@ -181,18 +178,28 @@ def _update_kv_base_impl(
return updated_cache


def _update_kv_hlfb_impl(
def _get_slice_indices(positions: torch.Tensor) -> torch.Tensor:
"""Dynamic Update Slice updates are a variadic sequence of 0-rank tensors."""

zero = torch.zeros([]).int()
positions = positions.int()[0].reshape([])
return [zero, positions, zero, zero]


def _update_kv_impl(
cache: KVCacheEntry,
input_pos: torch.Tensor,
k_slice: torch.Tensor,
v_slice: torch.Tensor,
) -> KVCacheEntry:
"""Update the cache buffer with High Level Function Boundary annotation."""
builder = hlfb.StableHLOCompositeBuilder(name="odml.update_external_kv_cache")
k_cache, v_cache, input_pos, k_slice, v_slice = builder.mark_inputs(
cache.k_cache, cache.v_cache, input_pos, k_slice, v_slice
)
k = k_cache.index_copy(1, input_pos.to(torch.long), k_slice)
v = v_cache.index_copy(1, input_pos.to(torch.long), v_slice)
k, v = builder.mark_outputs(k, v)
return KVCacheEntry(k, v)
"""Update the cache buffer for K and V caches."""
# NB: This assumes input_pos is contiguous, i.e.
# input_pos == range(input_pos[0], input_pos[0] + len(input_pos)).

k_slice_indices = _get_slice_indices(input_pos)
v_slice_indices = _get_slice_indices(input_pos)

k = dynamic_update_slice(cache.k_cache, k_slice, k_slice_indices)
v = dynamic_update_slice(cache.v_cache, v_slice, v_slice_indices)

updated_cache = KVCacheEntry(k, v)
return updated_cache
6 changes: 3 additions & 3 deletions ai_edge_torch/generative/test/test_kv_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,18 +71,18 @@ def test_cache_udpate(self):
[0, 0, 5, 5, 0, 0, 0, 0],
)
# multi-slice update
input_pos = torch.tensor([0, 3])
input_pos = torch.tensor([0, 1])
k_slice = v_slice = torch.full(
(1, 2, NUM_QG, HEAD_DIM), 7, dtype=torch.float
)
updated_entry = kv_utils.update(entry, input_pos, k_slice, v_slice)
self.assertEqual(
updated_entry.k_cache.numpy().flatten().tolist(),
[7, 7, 0, 0, 0, 0, 7, 7],
[7, 7, 7, 7, 0, 0, 0, 0],
)
self.assertEqual(
updated_entry.v_cache.numpy().flatten().tolist(),
[7, 7, 0, 0, 0, 0, 7, 7],
[7, 7, 7, 7, 0, 0, 0, 0],
)

def test_serialization(self):
Expand Down
6 changes: 3 additions & 3 deletions ai_edge_torch/generative/test/test_model_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,8 @@ def test_toy_model_with_kv_cache_with_hlfb(self):
ai_edge_config.Config.use_torch_xla,
reason="tests with custom ops are not supported on oss",
)
def test_toy_model_has_ekv_op(self):
"""Tests that the model has the external kv cache op."""
def test_toy_model_has_dus_op(self):
"""Tests that the model has the dynamic update slice op."""
_, edge_model, _ = self._get_params(enable_hlfb=True)
interpreter_ = interpreter.InterpreterWithCustomOps(
custom_op_registerers=["GenAIOpsRegisterer"],
Expand All @@ -111,7 +111,7 @@ def test_toy_model_has_ekv_op(self):

# pylint: disable=protected-access
op_names = [op["op_name"] for op in interpreter_._get_ops_details()]
self.assertIn("odml.update_external_kv_cache", op_names)
self.assertIn("DYNAMIC_UPDATE_SLICE", op_names)

def _test_multisig_model(self, config, pytorch_model, atol, rtol):
# prefill
Expand Down

0 comments on commit f142f4a

Please sign in to comment.