diff --git a/ai_edge_torch/generative/examples/paligemma/decoder.py b/ai_edge_torch/generative/examples/paligemma/decoder.py
index df48eda4..760e7aad 100644
--- a/ai_edge_torch/generative/examples/paligemma/decoder.py
+++ b/ai_edge_torch/generative/examples/paligemma/decoder.py
@@ -19,6 +19,7 @@
 
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
 from ai_edge_torch.generative.utilities import model_builder
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import torch
@@ -61,8 +62,12 @@ def forward(
     assert input_embeds is not None
 
     repo_pos = input_pos + 1  # PaliGemma position is 1-based.
-    cos, sin = self.rope_cache
-    rope = (cos.index_select(0, repo_pos), sin.index_select(0, repo_pos))
+    # ROPE parameters for all attn_configs are the same. Take the first one.
+    attn_config = self.config.block_config(0).attn_config
+    n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
+    rope = rotary_pos_emb.build_rope(
+        repo_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
+    )
 
     # The first part of input_embeds are image embeddings. Diagonal causal mask
     # doesn't work here.
diff --git a/ai_edge_torch/generative/utilities/model_builder.py b/ai_edge_torch/generative/utilities/model_builder.py
index d259a5be..eece2d08 100644
--- a/ai_edge_torch/generative/utilities/model_builder.py
+++ b/ai_edge_torch/generative/utilities/model_builder.py
@@ -107,8 +107,6 @@ def forward(
     # token embeddings of shape (b, t, n_embd)
     input_embeds = self.tok_embedding(tokens)
 
-    mask = self.mask_cache.index_select(2, input_pos)
-    mask = mask[:, :, :, : self.config.kv_cache_max]
 
     # ROPE parameters for all attn_configs are the same. Take the first one.
     attn_config = self.config.block_config(0).attn_config
@@ -117,6 +115,9 @@
         input_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
     )
 
+    mask = self.mask_cache.index_select(2, input_pos)
+    mask = mask[:, :, :, : self.config.kv_cache_max]
+
     return self.forward_with_embeds(
         input_embeds, rope, mask, input_pos, kv_cache, export_config
     )