diff --git a/ai_edge_torch/generative/layers/attention.py b/ai_edge_torch/generative/layers/attention.py
index 189d9937..f4c43e40 100644
--- a/ai_edge_torch/generative/layers/attention.py
+++ b/ai_edge_torch/generative/layers/attention.py
@@ -241,9 +241,7 @@ def forward(
       q, k = _embed_rope(q, k, n_elem, rope)
 
     if kv_cache is not None:
-      kv_cache = kv_utils.update(
-          kv_cache, input_pos, k, v, enable_hlfb=self.enable_hlfb
-      )
+      kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
       k, v = kv_cache.k_cache, kv_cache.v_cache
 
     y = self.sdpa_func(
@@ -379,9 +377,7 @@ def forward(
       q, k = _embed_rope(q, k, n_elem, rope)
 
     if kv_cache is not None:
-      kv_cache = kv_utils.update(
-          kv_cache, input_pos, k, v, enable_hlfb=self.enable_hlfb
-      )
+      kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
       k, v = kv_cache.k_cache, kv_cache.v_cache
     if mask is None:
       mask = torch.zeros(
diff --git a/ai_edge_torch/generative/layers/kv_cache.py b/ai_edge_torch/generative/layers/kv_cache.py
index 2b192e69..50433a98 100644
--- a/ai_edge_torch/generative/layers/kv_cache.py
+++ b/ai_edge_torch/generative/layers/kv_cache.py
@@ -146,7 +146,7 @@ def update(
     input_pos: torch.Tensor,
     k_slice: torch.Tensor,
     v_slice: torch.Tensor,
-    enable_hlfb: bool = True,
+    use_dus: bool = True,
 ) -> KVCacheEntry:
   """Out of place update of Cache buffer.
 
@@ -155,17 +155,14 @@ def update(
     input_pos (torch.Tensor): The update slice positions.
     k_slice (torch.Tensor): The K slice to be updated in the new cache.
     v_slice (torch.Tensor): The V slice to be updated in the new cache.
-    enable_hlfb (bool, optional): Whether the op is annotated for export with
-      High Level Function Boundary. Defaults to True.
 
   Returns:
     KVCacheEntry: The updated KVCache entry based on the passed inputs.
   """
-  # Don't enable HLFB for kv cache op for now, since it won't work with LLM
-  # inference engine. Remove this part once we ship a new LLM inference engine.
-  enable_hlfb=False
-  update_func = _update_kv_hlfb_impl if enable_hlfb else _update_kv_base_impl
-  return update_func(cache, input_pos, k_slice, v_slice)
+  # Turn dynamic_update_slice updates off for now.
+  use_dus = False
+  update_kv_cache = _update_kv_impl if use_dus else _update_kv_base_impl
+  return update_kv_cache(cache, input_pos, k_slice, v_slice)
 
 
 def _update_kv_base_impl(
@@ -181,18 +178,28 @@ def _update_kv_base_impl(
   return updated_cache
 
 
-def _update_kv_hlfb_impl(
+def _get_slice_indices(positions: torch.Tensor) -> list[torch.Tensor]:
+  """Returns the variadic start indices for dynamic_update_slice as 0-rank tensors."""
+
+  zero = torch.zeros([]).int()
+  positions = positions.int()[0].reshape([])
+  return [zero, positions, zero, zero]
+
+
+def _update_kv_impl(
    cache: KVCacheEntry,
    input_pos: torch.Tensor,
    k_slice: torch.Tensor,
    v_slice: torch.Tensor,
 ) -> KVCacheEntry:
-  """Update the cache buffer with High Level Function Boundary annotation."""
-  builder = hlfb.StableHLOCompositeBuilder(name="odml.update_external_kv_cache")
-  k_cache, v_cache, input_pos, k_slice, v_slice = builder.mark_inputs(
-      cache.k_cache, cache.v_cache, input_pos, k_slice, v_slice
-  )
-  k = k_cache.index_copy(1, input_pos.to(torch.long), k_slice)
-  v = v_cache.index_copy(1, input_pos.to(torch.long), v_slice)
-  k, v = builder.mark_outputs(k, v)
-  return KVCacheEntry(k, v)
+  """Updates the K and V cache buffers with dynamic_update_slice."""
+  # NB: Assumes input_pos == range(input_pos[0], input_pos[0] + len(input_pos)).
+
+  k_slice_indices = _get_slice_indices(input_pos)
+  v_slice_indices = _get_slice_indices(input_pos)
+
+  k = dynamic_update_slice(cache.k_cache, k_slice, k_slice_indices)
+  v = dynamic_update_slice(cache.v_cache, v_slice, v_slice_indices)
+
+  updated_cache = KVCacheEntry(k, v)
+  return updated_cache
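As a reference for the new _update_kv_impl above: dynamic_update_slice(cache, update, indices) writes update into a copy of cache at the per-dimension start offsets given by indices, which is why _get_slice_indices offsets only the sequence dimension (dim 1). A minimal pure-torch sketch of those semantics; dus_like_update is a hypothetical stand-in, not the op the patch calls:

import torch

def dus_like_update(cache, update, indices):
  # Hypothetical stand-in for dynamic_update_slice: copy `cache` and write
  # `update` at the per-dimension start offsets given by `indices`.
  out = cache.clone()
  region = tuple(
      slice(int(i), int(i) + size) for i, size in zip(indices, update.shape)
  )
  out[region] = update
  return out

# Mirrors _get_slice_indices: only the sequence dimension (dim 1) is offset.
cache = torch.zeros(1, 4, 1, 2)         # (batch, seq_len, heads, head_dim)
update = torch.full((1, 2, 1, 2), 7.0)  # a two-position K or V slice
zero = torch.zeros([]).int()
start = torch.tensor([1, 2]).int()[0].reshape([])  # input_pos[0]
out = dus_like_update(cache, update, [zero, start, zero, zero])
print(out[0, :, 0, 0].tolist())         # [0.0, 7.0, 7.0, 0.0]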
diff --git a/ai_edge_torch/generative/test/test_kv_cache.py b/ai_edge_torch/generative/test/test_kv_cache.py
index 8a6ff088..0cbe08b4 100644
--- a/ai_edge_torch/generative/test/test_kv_cache.py
+++ b/ai_edge_torch/generative/test/test_kv_cache.py
@@ -71,18 +71,18 @@ def test_cache_udpate(self):
         [0, 0, 5, 5, 0, 0, 0, 0],
     )
     # multi-slice update
-    input_pos = torch.tensor([0, 3])
+    input_pos = torch.tensor([0, 1])
     k_slice = v_slice = torch.full(
         (1, 2, NUM_QG, HEAD_DIM), 7, dtype=torch.float
     )
     updated_entry = kv_utils.update(entry, input_pos, k_slice, v_slice)
     self.assertEqual(
         updated_entry.k_cache.numpy().flatten().tolist(),
-        [7, 7, 0, 0, 0, 0, 7, 7],
+        [7, 7, 7, 7, 0, 0, 0, 0],
     )
     self.assertEqual(
         updated_entry.v_cache.numpy().flatten().tolist(),
-        [7, 7, 0, 0, 0, 0, 7, 7],
+        [7, 7, 7, 7, 0, 0, 0, 0],
     )
 
   def test_serialization(self):
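The multi-slice expectations change because the DUS path reads only input_pos[0]: it can express a contiguous run of positions but not a scattered update, so the old input [0, 3] (which index_copy handled) would now land at rows 0 and 1. A small guard one could add, using the hypothetical helper _assert_contiguous (not part of the patch):

import torch

def _assert_contiguous(input_pos: torch.Tensor) -> None:
  # The DUS path writes len(input_pos) rows starting at input_pos[0], so the
  # positions must form a contiguous run for the update to land correctly.
  start = int(input_pos[0])
  expected = torch.arange(start, start + input_pos.numel())
  assert torch.equal(input_pos, expected), (
      f"positions {input_pos.tolist()} would be written at rows "
      f"{expected.tolist()} by dynamic_update_slice"
  )

_assert_contiguous(torch.tensor([0, 1]))    # passes: the updated test input
# _assert_contiguous(torch.tensor([0, 3]))  # fails: the old scattered input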
diff --git a/ai_edge_torch/generative/test/test_model_conversion.py b/ai_edge_torch/generative/test/test_model_conversion.py
index 2f819dfa..a21be58b 100644
--- a/ai_edge_torch/generative/test/test_model_conversion.py
+++ b/ai_edge_torch/generative/test/test_model_conversion.py
@@ -100,8 +100,8 @@ def test_toy_model_with_kv_cache_with_hlfb(self):
       ai_edge_config.Config.use_torch_xla,
       reason="tests with custom ops are not supported on oss",
   )
-  def test_toy_model_has_ekv_op(self):
-    """Tests that the model has the external kv cache op."""
+  def test_toy_model_has_dus_op(self):
+    """Tests that the model has the dynamic update slice op."""
     _, edge_model, _ = self._get_params(enable_hlfb=True)
     interpreter_ = interpreter.InterpreterWithCustomOps(
         custom_op_registerers=["GenAIOpsRegisterer"],
@@ -111,7 +111,7 @@
 
     # pylint: disable=protected-access
     op_names = [op["op_name"] for op in interpreter_._get_ops_details()]
-    self.assertIn("odml.update_external_kv_cache", op_names)
+    self.assertIn("DYNAMIC_UPDATE_SLICE", op_names)
 
   def _test_multisig_model(self, config, pytorch_model, atol, rtol):
     # prefill
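With the custom op gone, the same check can be run against any converted model. A sketch mirroring the updated assertion, assuming the TFLite interpreter class from tensorflow.lite.python (this hunk does not show the test's own interpreter import) and an illustrative model path:

from tensorflow.lite.python import interpreter

interp = interpreter.InterpreterWithCustomOps(
    custom_op_registerers=["GenAIOpsRegisterer"],
    model_path="/tmp/toy_model.tflite",  # illustrative path
)
# _get_ops_details() is a private API, hence the pylint waiver in the test.
op_names = [op["op_name"] for op in interp._get_ops_details()]
assert "DYNAMIC_UPDATE_SLICE" in op_names               # new builtin lowering
assert "odml.update_external_kv_cache" not in op_names  # old custom op is gone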