diff --git a/ai_edge_torch/generative/layers/kv_cache.py b/ai_edge_torch/generative/layers/kv_cache.py index 50433a98..66f906c4 100644 --- a/ai_edge_torch/generative/layers/kv_cache.py +++ b/ai_edge_torch/generative/layers/kv_cache.py @@ -20,6 +20,7 @@ from ai_edge_torch import hlfb from ai_edge_torch.generative.layers import model_config +from ai_edge_torch.generative.utilities.dynamic_update_slice import dynamic_update_slice import torch import torch.utils._pytree as pytree @@ -159,8 +160,6 @@ def update( Returns: KVCacheEntry: The updated KVCache entry based on the passed inputs. """ - # Turn dynamic_update_slice updates off for now. - use_dus=False update_kv_cache = _update_kv_impl if use_dus else _update_kv_base_impl return update_kv_cache(cache, input_pos, k_slice, v_slice)