From 4e38df691efb2df81e982b21b3c81d39eb468351 Mon Sep 17 00:00:00 2001 From: "T.J. Alumbaugh" Date: Tue, 7 Jan 2025 10:04:37 -0800 Subject: [PATCH] Build RoPE Inline with configurable function PiperOrigin-RevId: 712951237 --- .../generative/examples/gemma/gemma2.py | 71 ++++++++++++------- .../generative/examples/llama/llama.py | 52 +++++++------- ai_edge_torch/generative/examples/phi/phi3.py | 49 +++++++------ .../test_models/toy_model_with_kv_cache.py | 6 +- ai_edge_torch/generative/layers/attention.py | 33 ++------- .../generative/layers/model_config.py | 8 ++- .../layers/rotary_position_embedding.py | 62 ++++++++-------- .../generative/utilities/model_builder.py | 28 ++++---- 8 files changed, 163 insertions(+), 146 deletions(-) diff --git a/ai_edge_torch/generative/examples/gemma/gemma2.py b/ai_edge_torch/generative/examples/gemma/gemma2.py index a934c835..2d5b3692 100644 --- a/ai_edge_torch/generative/examples/gemma/gemma2.py +++ b/ai_edge_torch/generative/examples/gemma/gemma2.py @@ -15,13 +15,14 @@ """Example of building a Gemma2 model.""" -from typing import Optional, Tuple +from typing import List, Optional, Tuple from ai_edge_torch.generative.layers import attention from ai_edge_torch.generative.layers import builder from ai_edge_torch.generative.layers import kv_cache as kv_utils import ai_edge_torch.generative.layers.attention_utils as attn_utils import ai_edge_torch.generative.layers.model_config as cfg +import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb from ai_edge_torch.generative.utilities import model_builder import ai_edge_torch.generative.utilities.loader as loading_utils import torch @@ -103,17 +104,12 @@ def __init__(self, config: cfg.ModelConfig): config.embedding_dim, config.final_norm_config, ) - # Gemma2 has same hyper parameters for each layer except for attention - # types. Use the first layer. - attn_config = config.block_config(0).attn_config - self.rope_cache = attn_utils.build_rope_cache( - size=config.kv_cache_max, - dim=int(attn_config.rotary_percentage * attn_config.head_dim), - base=attn_config.rotary_base, - ) self.mask_cache = attn_utils.build_causal_mask_cache( size=config.kv_cache_max, ) + # Gemma2 has same hyper parameters for each layer except for attention + # types. Use the first layer. + attn_config = config.block_config(0).attn_config self.sliding_window_mask_cache = attn_utils.build_sliding_window_mask_cache( size=config.kv_cache_max, window_size=attn_config.sliding_window_size, @@ -140,29 +136,51 @@ def forward( f"Cannot forward sequence of length {seq_len}, max seq length is only" f" {self.config.max_seq_len}" ) + + # token embeddings of shape (b, t, n_embd) + input_embeds = self.tok_embedding(tokens) + # RoPE parameters are the same for all blocks. Use the first layer. 
+ attn_config = self.config.block_config(0).attn_config + n_elem = int(attn_config.rotary_percentage * attn_config.head_dim) + rope = rotary_pos_emb.build_rope( + input_pos, n_elem, attn_config.head_dim, attn_config.rotary_base + ) + mask = [ + self.get_attention_mask( + self.config.block_config(i).attn_config.attn_type, input_pos + ) + for i in range(self.config.num_layers) + ] + + return self._forward_with_embeds( + input_embeds, rope, mask, input_pos, kv_cache, export_config + ) + + def _forward_with_embeds( + self, + input_embeds: torch.Tensor, + rope: Tuple[torch.Tensor, torch.Tensor], + mask: List[torch.Tensor], + input_pos: torch.Tensor, + kv_cache: kv_utils.KVCache, + export_config: Optional[model_builder.ExportConfig] = None, + ) -> dict[torch.Tensor, kv_utils.KVCache]: + """Forwards the model with input embeddings.""" assert len(self.transformer_blocks) == len(kv_cache.caches), ( "The number of transformer blocks and the number of KV cache entries" " must be the same." ) - cos, sin = self.rope_cache - cos = cos.index_select(0, input_pos) - sin = sin.index_select(0, input_pos) - - # token embeddings of shape (b, t, n_embd) - x = self.tok_embedding(tokens) - x = x * (self.config.embedding_dim**0.5) - - updated_kv_entires = [] + if self.config.embedding_scale is not None: + input_embeds = input_embeds * self.config.embedding_scale + x = input_embeds + updated_kv_entries = [] for i, block in enumerate(self.transformer_blocks): - mask = self.get_attention_mask( - block.config.attn_config.attn_type, input_pos - ) kv_entry = kv_cache.caches[i] if kv_cache else None - x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry) + x, kv_entry = block(x, rope, mask[i], input_pos, kv_entry) if kv_entry: - updated_kv_entires.append(kv_entry) - updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires)) + updated_kv_entries.append(kv_entry) + updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries)) if export_config is not None: if ( @@ -228,11 +246,13 @@ def get_block_config(idx: int) -> cfg.TransformerBlockConfig: ) num_layers = 26 + embedding_dim = 2304 config = cfg.ModelConfig( vocab_size=256000, num_layers=num_layers, max_seq_len=8192, - embedding_dim=2304, + embedding_dim=embedding_dim, + embedding_scale=embedding_dim**0.5, kv_cache_max_len=kv_cache_max_len, block_configs=[get_block_config(i) for i in range(num_layers)], final_norm_config=norm_config, @@ -249,6 +269,7 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig: config.num_layers = 2 config.max_seq_len = 2 * kv_cache_max_len config.embedding_dim = 128 + config.embedding_scale = config.embedding_dim**0.5 config.block_configs = config.block_configs[: config.num_layers] for block_config in config.block_configs: block_config.attn_config.num_heads = 4 diff --git a/ai_edge_torch/generative/examples/llama/llama.py b/ai_edge_torch/generative/examples/llama/llama.py index 4d2dc10e..762353b3 100644 --- a/ai_edge_torch/generative/examples/llama/llama.py +++ b/ai_edge_torch/generative/examples/llama/llama.py @@ -15,6 +15,7 @@ """Example of building Llama 3.2 models.""" +from functools import partial import math from typing import Tuple @@ -26,8 +27,8 @@ def _build_llama3_rope_cache( - size: int, - dim: int, + input_pos: torch.Tensor, + n_elem: int, base: int, condense_ratio: int, dtype: torch.dtype, @@ -36,8 +37,9 @@ def _build_llama3_rope_cache( low_freq_factor: float, high_freq_factor: float, max_seq_len: int, + **kwargs, ) -> Tuple[torch.Tensor, torch.Tensor]: - """Precomputes Rotary Positional 
Embeddings for Llama 3.2 model.
+  """Computes Rotary Positional Embeddings for Llama 3.2 model.
 
   It's a modified version of attn_utils.build_rope_cache with additional
   arguments for Llama 3.2 model. It precomputes Rotary Positional Embedding Sin
@@ -47,13 +49,12 @@
   https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_rope_utils.py#L307
 
   Args:
-    size (int): The size of the built cache.
-    dim (int): Each sequence's dimmension.
-    base (int, optional): Rope base value.
-    condense_ratio (int, optional): The ratio by which sequence indicies are
-      condensed.
-    dtype (torch.dtype, optional): Output tensor's data type.
-    device (torch.device, optional): Output tensor's data type.
+    input_pos (torch.Tensor): The given input sequence positions.
+    n_elem (int): Number of elements of the head dimension for RoPE.
+    base (int): Rope base value.
+    condense_ratio (int): The ratio by which sequence indices are condensed.
+    dtype (torch.dtype): Output tensor's data type.
+    device (torch.device): Output tensor's device.
     factor (float): Factor to scale theta down for tokens in long range in the
       sequence.
     low_freq_factor (float): Factor to determine if tokens are in long range
@@ -66,7 +67,7 @@
   Returns:
     Tuple[torch.Tensor, torch.Tensor]: Rope's Cosine and Sine waves.
   """
-  theta = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
+  theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
   low_freq_wavelen = max_seq_len / low_freq_factor
   high_freq_wavelen = max_seq_len / high_freq_factor
   wavelen = 2 * math.pi / theta
@@ -81,7 +82,7 @@
   is_medium = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
   theta = torch.where(is_medium, smoothed_theta, theta)
 
-  seq_idx = torch.arange(size) / condense_ratio
+  seq_idx = input_pos / condense_ratio
   idx_theta = torch.outer(seq_idx, theta)
   cos = torch.cos(idx_theta).to(dtype=dtype, device=device)
   sin = torch.sin(idx_theta).to(dtype=dtype, device=device)
@@ -97,18 +98,6 @@ class Llama(model_builder.DecoderOnlyModel):
   def __init__(self, config: cfg.ModelConfig):
     super().__init__(config)
     attn_config = self.config.block_config(0).attn_config
-    self.rope_cache = _build_llama3_rope_cache(
-        size=self.config.kv_cache_max,
-        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=attn_config.rotary_base,
-        condense_ratio=1,
-        dtype=torch.float32,
-        device=torch.device("cpu"),
-        factor=32.0,
-        low_freq_factor=1.0,
-        high_freq_factor=4.0,
-        max_seq_len=self.config.max_seq_len,
-    )
 
 
 def get_1b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -140,6 +129,20 @@ def get_1b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       pre_attention_norm_config=norm_config,
       post_attention_norm_config=norm_config,
   )
+
+  max_seq_len = 8192
+  # Create the RoPE callable
+  build_rope = partial(
+      _build_llama3_rope_cache,
+      condense_ratio=1,
+      dtype=torch.float32,
+      device=torch.device("cpu"),
+      factor=32.0,
+      low_freq_factor=1.0,
+      high_freq_factor=4.0,
+      max_seq_len=max_seq_len,
+  )
+
   config = cfg.ModelConfig(
       vocab_size=128256,
       num_layers=16,
@@ -149,6 +152,7 @@ def get_1b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       block_configs=block_config,
       final_norm_config=norm_config,
       enable_hlfb=True,
+      build_rope=build_rope,
   )
   return config
 
diff --git a/ai_edge_torch/generative/examples/phi/phi3.py b/ai_edge_torch/generative/examples/phi/phi3.py
index a8e1620b..5cd01d61 100644
--- a/ai_edge_torch/generative/examples/phi/phi3.py
+++ 
b/ai_edge_torch/generative/examples/phi/phi3.py
@@ -15,6 +15,7 @@
 
 """Example of building a Phi-3.5 model up to 4K tokens, not to 128K tokens."""
 
+from functools import partial
 import math
 from typing import Tuple
 
@@ -93,40 +94,41 @@
 ]
 
 
-def _build_rope_cache(
-    size: int,
-    dim: int,
+def _build_phi3_rope(
+    input_pos: torch.Tensor,
+    n_elem: int,
     base: int,
     condense_ratio: int,
     dtype: torch.dtype,
     device: torch.device,
     theta_factors: torch.Tensor,
     scale: float,
+    **kwargs,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-  """Precomputes Rotary Positional Embeddings for Phi-3.5 model.
+  """Computes Rotary Positional Embeddings for Phi-3.5 model.
 
   It's a modified version of attn_utils.build_rope_cache with additional
   arguments for Phi-3.5 model. It precompute Rotary Positional Embedding Sin
   and Cos values with scaling factors for quick lookup during the inference.
 
   Args:
-    size (int): The size of the built cache.
-    dim (int): Each sequence's dimmension.
+    input_pos (torch.Tensor): The given input sequence positions.
+    n_elem (int): Number of elements of the head dimension for RoPE.
     base (int, optional): Rope base value.
     condense_ratio (int, optional): The ratio by which sequence indicies are
      condensed.
     dtype (torch.dtype, optional): Output tensor's data type.
     device (torch.device, optional): Output tensor's data type.
-    theta_factors (torch.Tensor, optional): A tensor of shape (dim,) used to
-      scale the theta values.
+    theta_factors (torch.Tensor, optional): A tensor of shape (n_elem,) used
+      to scale the theta values.
     scale (float, optional): A float used to scale the rope values.
 
   Returns:
     Tuple[torch.Tensor, torch.Tensor]: Rope's Cosine and Sine waves.
   """
-  theta = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
+  theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
   theta = theta / theta_factors
-  seq_idx = torch.arange(size) / condense_ratio
+  seq_idx = input_pos / condense_ratio
   idx_theta = torch.outer(seq_idx, theta)
   cos = torch.cos(idx_theta).to(dtype=dtype, device=device) * scale
   sin = torch.sin(idx_theta).to(dtype=dtype, device=device) * scale
@@ -139,18 +141,6 @@ class Phi3_5Mini(model_builder.DecoderOnlyModel):
   def __init__(self, config: cfg.ModelConfig):
     super().__init__(config)
     attn_config = self.config.block_config(0).attn_config
-    self.rope_cache = _build_rope_cache(
-        size=self.config.kv_cache_max,
-        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=attn_config.rotary_base,
-        condense_ratio=1,
-        dtype=torch.float32,
-        device=torch.device("cpu"),
-        theta_factors=torch.tensor(ROPE_SHORT_FACTOR),
-        scale=math.sqrt(
-            1 + math.log(ROPE_SCALE_FACTOR) / math.log(config.max_seq_len)
-        ),
-    )
 
 
 def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -183,16 +173,29 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       pre_attention_norm_config=norm_config,
       post_attention_norm_config=norm_config,
   )
+  max_seq_len = 4096
+  # Create the RoPE callable
+  build_rope = partial(
+      _build_phi3_rope,
+      condense_ratio=1,
+      dtype=torch.float32,
+      device=torch.device("cpu"),
+      theta_factors=torch.tensor(ROPE_SHORT_FACTOR),
+      scale=math.sqrt(1 + math.log(ROPE_SCALE_FACTOR) / math.log(max_seq_len)),
+      max_seq_len=max_seq_len,
+  )
+
   config = cfg.ModelConfig(
       vocab_size=32064,
       num_layers=32,
-      max_seq_len=4096,
+      max_seq_len=max_seq_len,
       kv_cache_max_len=kv_cache_max_len,
       embedding_dim=3072,
       block_configs=block_config,
       final_norm_config=norm_config,
       lm_head_share_weight_with_embedding=False,
       enable_hlfb=True,
+      build_rope=build_rope,
   )
   return config
 
diff --git 
a/ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py b/ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py index d4d44845..aeb8431e 100644 --- a/ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +++ b/ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py @@ -72,14 +72,14 @@ def forward( mask = self.mask_cache.index_select(2, input_pos) mask = mask[:, :, :, : self.config.max_seq_len] - updated_kv_entires = [] + updated_kv_entries = [] for i, block in enumerate(self.transformer_blocks): kv_entry = kv_cache.caches[i] if kv_cache else None x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry) if kv_entry: - updated_kv_entires.append(kv_entry) + updated_kv_entries.append(kv_entry) - updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires)) + updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries)) if export_config is not None: if ( diff --git a/ai_edge_torch/generative/layers/attention.py b/ai_edge_torch/generative/layers/attention.py index f4c43e40..b52e04b4 100644 --- a/ai_edge_torch/generative/layers/attention.py +++ b/ai_edge_torch/generative/layers/attention.py @@ -26,33 +26,6 @@ from torch import nn -def _embed_rope( - q: torch.Tensor, - k: torch.Tensor, - n_elem: int, - rope: Tuple[torch.Tensor, torch.Tensor], -) -> Tuple[torch.Tensor, torch.Tensor]: - """Embed rotary positional embedding for query and key. - - Args: - q (torch.Tensor): query tensor. - k (torch.Tensor): key tensor. - n_elem (int): number of elements to embed rotarty positional embedding. - rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor. - """ - if n_elem > 0: - cos, sin = rope - q_roped = rotary_pos_emb.apply_rope( - q[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2) - ) - k_roped = rotary_pos_emb.apply_rope( - k[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2) - ) - q = torch.cat((q_roped, q[..., n_elem:]), dim=-1) - k = torch.cat((k_roped, k[..., n_elem:]), dim=-1) - return q, k - - class TransformerBlock(nn.Module): def __init__( @@ -238,7 +211,8 @@ def forward( if rope is not None: # Compute rotary positional embedding for query and key. n_elem = int(self.config.rotary_percentage * self.config.head_dim) - q, k = _embed_rope(q, k, n_elem, rope) + cos, sin = rope + q, k = rotary_pos_emb.apply_rope_inline(q, k, cos, sin) if kv_cache is not None: kv_cache = kv_utils.update(kv_cache, input_pos, k, v) @@ -374,7 +348,8 @@ def forward( if rope is not None: # Compute rotary positional embedding for query and key. n_elem = int(self.config.rotary_percentage * self.config.head_dim) - q, k = _embed_rope(q, k, n_elem, rope) + cos, sin = rope + q, k = rotary_pos_emb.apply_rope_inline(q, k, cos, sin) if kv_cache is not None: kv_cache = kv_utils.update(kv_cache, input_pos, k, v) diff --git a/ai_edge_torch/generative/layers/model_config.py b/ai_edge_torch/generative/layers/model_config.py index 78d62701..2eedce92 100644 --- a/ai_edge_torch/generative/layers/model_config.py +++ b/ai_edge_torch/generative/layers/model_config.py @@ -17,8 +17,8 @@ import dataclasses import enum -from typing import Optional, Sequence, Union - +from typing import Callable, Optional, Sequence, Union +import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb @enum.unique class ActivationType(enum.Enum): @@ -218,6 +218,10 @@ class ModelConfig: # Softcap on the model output logits. final_logit_softcap: Optional[float] = None + # The function to call to create the RoPE sin and cos vectors during the + # forward pass. 
Defaults to a standard implementation. + build_rope: Callable = rotary_pos_emb.build_rope + @property def kv_cache_max(self) -> int: if self.kv_cache_max_len > 0: diff --git a/ai_edge_torch/generative/layers/rotary_position_embedding.py b/ai_edge_torch/generative/layers/rotary_position_embedding.py index c06dc818..ddf1ec79 100644 --- a/ai_edge_torch/generative/layers/rotary_position_embedding.py +++ b/ai_edge_torch/generative/layers/rotary_position_embedding.py @@ -32,57 +32,63 @@ def apply_rope( """ x = x.transpose(1, 2) head_size = x.size(-1) - x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) - x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) - rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs) - roped = (x * cos) + (rotated * sin) + x1, x2 = torch.split(x, head_size // 2, dim=-1) + left = x1 * cos - x2 * sin + right = x2 * cos + x1 * sin + roped = torch.cat([left, right], dim=-1) return roped.transpose(1, 2).type_as(x) -def apply_rope_inline( - q: torch.Tensor, - k: torch.Tensor, +def build_rope( input_pos: torch.Tensor, n_elem: int, + head_dim: int, base: int = 10_000, ) -> Tuple[torch.Tensor, torch.Tensor]: - """Computes rotary positional embedding inline for a query and key. + """Computes rotary positional embedding cosine and sine tensors. Args: - q: the query tensor. - k: the key tensor. input_pos: the sequence indices for the query and key n_elem: number of elements of the head dimension for RoPE computation + base: the base of the exponentiated value for RoPE. Returns: - output the RoPE'd query and key. + cos, sin tensors """ if n_elem <= 0: - return q, k + return None, None - theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem)) freq_exponents = (2.0 / n_elem) * torch.arange( - q.shape[-1] // 2, dtype=torch.float32 + head_dim // 2, dtype=torch.float32 ) timescale = float(base) ** freq_exponents radians = input_pos.clone().unsqueeze(0).unsqueeze(-1) / timescale.unsqueeze( 0 ).unsqueeze(0) - cos = torch.cos(radians).type_as(q) - sin = torch.sin(radians).type_as(q) + cos = torch.cos(radians) + sin = torch.sin(radians) + return cos, sin + - def apply(x, sin, cos): - x = x.transpose(1, 2) - b, h, s, d = x.shape - ans = torch.split(x, d // 2, dim=-1) - x1, x2 = ans - left = x1 * cos - x2 * sin - right = x2 * cos + x1 * sin - res = torch.cat([left, right], dim=-1) - res = res.transpose(1, 2) - return res +def apply_rope_inline( + q: torch.Tensor, + k: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Computes rotary positional embedding inline for a query and key. + + Args: + q: the query tensor. + k: the key tensor. + cos: the cosine tensor. + sin: the sine tensor. + + Returns: + output the RoPE'd query and key. 
+  """
-  q_roped = apply(q, sin, cos)
-  k_roped = apply(k, sin, cos)
+  q_roped = apply_rope(q, cos, sin)
+  k_roped = apply_rope(k, cos, sin)
   return q_roped, k_roped
diff --git a/ai_edge_torch/generative/utilities/model_builder.py b/ai_edge_torch/generative/utilities/model_builder.py
index 2714eeaa..d99116b6 100644
--- a/ai_edge_torch/generative/utilities/model_builder.py
+++ b/ai_edge_torch/generative/utilities/model_builder.py
@@ -24,6 +24,7 @@
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import torch
 from torch import nn
@@ -85,13 +86,6 @@ def __init__(self, config: cfg.ModelConfig):
         config.embedding_dim,
         config.final_norm_config,
     )
-    # ROPE parameters for all attn_configs are the same. Take the first one.
-    attn_config = config.block_config(0).attn_config
-    self.rope_cache = attn_utils.build_rope_cache(
-        size=config.kv_cache_max,
-        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=attn_config.rotary_base,
-    )
     self.mask_cache = attn_utils.build_causal_mask_cache(
         size=config.kv_cache_max,
     )
@@ -113,8 +107,18 @@ def forward(
 
     # token embeddings of shape (b, t, n_embd)
     input_embeds = self.tok_embedding(tokens)
-    cos, sin = self.rope_cache
-    rope = (cos.index_select(0, input_pos), sin.index_select(0, input_pos))
+
+    # ROPE parameters for all attn_configs are the same. Take the first one.
+    attn_config = self.config.block_config(0).attn_config
+    n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
+    rope = self.config.build_rope(
+        input_pos=input_pos,
+        n_elem=n_elem,
+        base=attn_config.rotary_base,
+        head_dim=attn_config.head_dim,
+        # Extra args for custom builders are bound via functools.partial.
+    )
+
     mask = self.mask_cache.index_select(2, input_pos)
     mask = mask[:, :, :, : self.config.kv_cache_max]
 
@@ -141,13 +145,13 @@ def forward_with_embeds(
     if self.config.embedding_scale is not None:
       x = x * self.config.embedding_scale
 
-    updated_kv_entires = []
+    updated_kv_entries = []
     for i, block in enumerate(self.transformer_blocks):
       kv_entry = kv_cache.caches[i] if kv_cache else None
       x, kv_entry = block(x, rope, mask, input_pos, kv_entry)
       if kv_entry:
-        updated_kv_entires.append(kv_entry)
-    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
+        updated_kv_entries.append(kv_entry)
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))
 
     if export_config is not None:
       if (