Skip to content

Commit

Permalink
KV Cache updates use dynamic_update_slice
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 698172808
  • Loading branch information
talumbau authored and copybara-github committed Dec 11, 2024
1 parent 924801e commit 1fa9d72
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 34 deletions.
8 changes: 2 additions & 6 deletions ai_edge_torch/generative/layers/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,9 +241,7 @@ def forward(
q, k = _embed_rope(q, k, n_elem, rope)

if kv_cache is not None:
kv_cache = kv_utils.update(
kv_cache, input_pos, k, v, enable_hlfb=self.enable_hlfb
)
kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
k, v = kv_cache.k_cache, kv_cache.v_cache

y = self.sdpa_func(
Expand Down Expand Up @@ -379,9 +377,7 @@ def forward(
q, k = _embed_rope(q, k, n_elem, rope)

if kv_cache is not None:
kv_cache = kv_utils.update(
kv_cache, input_pos, k, v, enable_hlfb=self.enable_hlfb
)
kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
k, v = kv_cache.k_cache, kv_cache.v_cache
if mask is None:
mask = torch.zeros(
Expand Down
49 changes: 21 additions & 28 deletions ai_edge_torch/generative/layers/kv_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from ai_edge_torch import hlfb
from ai_edge_torch.generative.layers import model_config
from ai_edge_torch.generative.utilities.dynamic_update_slice import dynamic_update_slice
import torch
import torch.utils._pytree as pytree

Expand Down Expand Up @@ -146,7 +147,6 @@ def update(
input_pos: torch.Tensor,
k_slice: torch.Tensor,
v_slice: torch.Tensor,
enable_hlfb: bool = True,
) -> KVCacheEntry:
"""Out of place update of Cache buffer.
Expand All @@ -155,17 +155,21 @@ def update(
input_pos (torch.Tensor): The update slice positions.
k_slice (torch.Tensor): The K slice to be updated in the new cache.
v_slice (torch.Tensor): The V slice to be updated in the new cache.
enable_hlfb (bool, optional): Whether the op is annotated for export with
High Level Function Boundary. Defaults to True.
Returns:
KVCacheEntry: The updated KVCache entry based on the passed inputs.
"""
# Don't enable HLFB for kv cache op for now, since it won't work with LLM
# inference engine. Remove this part once we ship a new LLM inference engine.
enable_hlfb=False
update_func = _update_kv_hlfb_impl if enable_hlfb else _update_kv_base_impl
return update_func(cache, input_pos, k_slice, v_slice)
return _update_kv_base_impl(cache, input_pos, k_slice, v_slice)


def _get_slice_indices(positions: torch.Tensor) -> torch.Tensor:
"""Dynamic Update Slice updates are a variadic sequence of 0-rank tensors."""

def _zero():
return torch.zeros(()).int()

positions = positions.int()[0].reshape([])
return [_zero(), positions, _zero(), _zero()]


def _update_kv_base_impl(
Expand All @@ -174,25 +178,14 @@ def _update_kv_base_impl(
k_slice: torch.Tensor,
v_slice: torch.Tensor,
) -> KVCacheEntry:
"""Update the cache buffer without High Level Function Boundary annotation."""
k = cache.k_cache.index_copy(1, input_pos.to(torch.long), k_slice)
v = cache.v_cache.index_copy(1, input_pos.to(torch.long), v_slice)
updated_cache = KVCacheEntry(k, v)
return updated_cache
"""Update the cache buffer for K and V caches."""
# NB: Here assume that input_pos is contiguous, i.e.
# input_pos == range(input_pos[0], input_pos[0] + len(input_pos))

k_slice_indices = _get_slice_indices(input_pos)
v_slice_indices = _get_slice_indices(input_pos)

def _update_kv_hlfb_impl(
cache: KVCacheEntry,
input_pos: torch.Tensor,
k_slice: torch.Tensor,
v_slice: torch.Tensor,
) -> KVCacheEntry:
"""Update the cache buffer with High Level Function Boundary annotation.

The K and V updates are bracketed by a StableHLO composite builder named
"odml.update_external_kv_cache": the cache tensors, the positions and both
slices are marked as its inputs, and the two updated tensors as its outputs.

Args:
cache (KVCacheEntry): The cache entry whose k_cache and v_cache are updated.
input_pos (torch.Tensor): The update slice positions, used as indices along
dimension 1 after conversion to torch.long.
k_slice (torch.Tensor): The K slice to be written into the new K cache.
v_slice (torch.Tensor): The V slice to be written into the new V cache.

Returns:
KVCacheEntry: A new entry built from the marked output tensors.
"""
builder = hlfb.StableHLOCompositeBuilder(name="odml.update_external_kv_cache")
# Every tensor used below is rebound to the value returned by mark_inputs, so
# the update reads only from the marked tensors.
k_cache, v_cache, input_pos, k_slice, v_slice = builder.mark_inputs(
cache.k_cache, cache.v_cache, input_pos, k_slice, v_slice
)
# index_copy (no trailing underscore) returns a new tensor rather than
# writing into k_cache / v_cache in place.
k = k_cache.index_copy(1, input_pos.to(torch.long), k_slice)
v = v_cache.index_copy(1, input_pos.to(torch.long), v_slice)
k, v = builder.mark_outputs(k, v)
return KVCacheEntry(k, v)
k = dynamic_update_slice(cache.k_cache, k_slice, k_slice_indices)
v = dynamic_update_slice(cache.v_cache, v_slice, v_slice_indices)

updated_cache = KVCacheEntry(k, v)
return updated_cache

0 comments on commit 1fa9d72

Please sign in to comment.