From eafcb12e2dddb74a1f6e43302f5524c16a1860b1 Mon Sep 17 00:00:00 2001
From: Google AI Edge
Date: Fri, 20 Dec 2024 12:58:15 -0800
Subject: [PATCH] Fix rope calculation of PaliGemma decoder when pixel_embeds
 is passed.

PiperOrigin-RevId: 708404348
---
 ai_edge_torch/generative/examples/paligemma/decoder.py | 9 +++++++--
 ai_edge_torch/generative/utilities/model_builder.py    | 5 +++--
 2 files changed, 10 insertions(+), 4 deletions(-)

diff --git a/ai_edge_torch/generative/examples/paligemma/decoder.py b/ai_edge_torch/generative/examples/paligemma/decoder.py
index df48eda4..760e7aad 100644
--- a/ai_edge_torch/generative/examples/paligemma/decoder.py
+++ b/ai_edge_torch/generative/examples/paligemma/decoder.py
@@ -19,6 +19,7 @@
 
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
 from ai_edge_torch.generative.utilities import model_builder
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import torch
@@ -61,8 +62,12 @@ def forward(
     assert input_embeds is not None
 
     repo_pos = input_pos + 1  # PaliGemma position is 1-based.
-    cos, sin = self.rope_cache
-    rope = (cos.index_select(0, repo_pos), sin.index_select(0, repo_pos))
+    # ROPE parameters for all attn_configs are the same. Take the first one.
+    attn_config = self.config.block_config(0).attn_config
+    n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
+    rope = rotary_pos_emb.build_rope(
+        repo_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
+    )
 
     # The first part of input_embeds are image embeddings. Diagonal causal mask
     # doesn't work here.
diff --git a/ai_edge_torch/generative/utilities/model_builder.py b/ai_edge_torch/generative/utilities/model_builder.py
index d259a5be..eece2d08 100644
--- a/ai_edge_torch/generative/utilities/model_builder.py
+++ b/ai_edge_torch/generative/utilities/model_builder.py
@@ -107,8 +107,6 @@ def forward(
 
     # token embeddings of shape (b, t, n_embd)
     input_embeds = self.tok_embedding(tokens)
-    mask = self.mask_cache.index_select(2, input_pos)
-    mask = mask[:, :, :, : self.config.kv_cache_max]
 
     # ROPE parameters for all attn_configs are the same. Take the first one.
     attn_config = self.config.block_config(0).attn_config
@@ -117,6 +115,9 @@ def forward(
         input_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
     )
 
+    mask = self.mask_cache.index_select(2, input_pos)
+    mask = mask[:, :, :, : self.config.kv_cache_max]
+
     return self.forward_with_embeds(
         input_embeds, rope, mask, input_pos, kv_cache, export_config
     )
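
For readers unfamiliar with the rotary_position_embedding module, below is a minimal, self-contained sketch of the kind of computation a build_rope-style helper performs, assuming the standard RoPE formulation. The name build_rope_sketch, its return shapes, and the example parameter values are illustrative stand-ins, not the library's actual implementation.

import torch


def build_rope_sketch(
    input_pos: torch.Tensor,  # 1-D tensor of absolute token positions
    n_elem: int,  # rotary_percentage * head_dim, i.e. how many channels rotate
    head_dim: int,  # full head dimension (kept to mirror the call in the patch)
    base: int = 10_000,  # rotary base frequency
) -> tuple[torch.Tensor, torch.Tensor]:
  """Returns (cos, sin) tables for exactly the requested positions."""
  # Inverse frequency for each rotated channel pair: base^(-2i / n_elem).
  freq = 1.0 / (
      base ** (torch.arange(0, n_elem, 2, dtype=torch.float32) / n_elem)
  )
  # One rotation angle per (position, frequency) pair.
  angles = torch.outer(input_pos.float(), freq)
  # A full implementation would also use head_dim to leave the remaining
  # head_dim - n_elem channels unrotated when rotary_percentage < 1.0.
  return torch.cos(angles), torch.sin(angles)


# The patch replaces indexing into a precomputed self.rope_cache with this kind
# of on-the-fly computation from the decoder's 1-based positions, mirroring how
# model_builder.py already builds rope inside forward(). Example usage
# (parameter values are illustrative):
positions = torch.arange(0, 8) + 1  # PaliGemma positions are 1-based.
cos, sin = build_rope_sketch(positions, n_elem=256, head_dim=256, base=10_000)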