Fix rope calculation of PaliGemma decoder when pixel_embeds is passed.
PiperOrigin-RevId: 708404348
ai-edge-bot authored and copybara-github committed Dec 20, 2024
1 parent 9d387ec commit eafcb12
Showing 2 changed files with 10 additions and 4 deletions.
9 changes: 7 additions & 2 deletions ai_edge_torch/generative/examples/paligemma/decoder.py
@@ -19,6 +19,7 @@
 
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
 from ai_edge_torch.generative.utilities import model_builder
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import torch
@@ -61,8 +62,12 @@ def forward(
     assert input_embeds is not None
 
     repo_pos = input_pos + 1  # PaliGemma position is 1-based.
-    cos, sin = self.rope_cache
-    rope = (cos.index_select(0, repo_pos), sin.index_select(0, repo_pos))
+    # ROPE parameters for all attn_configs are the same. Take the first one.
+    attn_config = self.config.block_config(0).attn_config
+    n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
+    rope = rotary_pos_emb.build_rope(
+        repo_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
+    )
 
     # The first part of input_embeds are image embeddings. Diagonal causal mask
     # doesn't work here.
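For context on the decoder.py change: the decoder now builds the rotary embedding tables on the fly from the shifted, 1-based positions instead of indexing a precomputed rope_cache. The sketch below shows a generic RoPE cos/sin computation for illustration only; it is not the body of rotary_pos_emb.build_rope, and the head_dim handling and any channel interleaving the library performs are omitted. The helper name rope_tables_sketch is made up for this example.

import torch

def rope_tables_sketch(positions: torch.Tensor, n_elem: int, base: float = 10_000.0):
  """Illustrative cos/sin tables for rotary position embedding (generic form)."""
  # One inverse frequency per pair of rotated channels.
  inv_freq = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
  # One angle per (position, frequency) pair.
  angles = torch.outer(positions.float(), inv_freq)
  return torch.cos(angles), torch.sin(angles)

# PaliGemma positions are 1-based, hence the `input_pos + 1` shift in the decoder.
input_pos = torch.arange(0, 4)   # [0, 1, 2, 3]
repo_pos = input_pos + 1         # [1, 2, 3, 4]
cos, sin = rope_tables_sketch(repo_pos, n_elem=256)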
5 changes: 3 additions & 2 deletions ai_edge_torch/generative/utilities/model_builder.py
@@ -107,8 +107,6 @@ def forward(
 
     # token embeddings of shape (b, t, n_embd)
     input_embeds = self.tok_embedding(tokens)
-    mask = self.mask_cache.index_select(2, input_pos)
-    mask = mask[:, :, :, : self.config.kv_cache_max]
 
     # ROPE parameters for all attn_configs are the same. Take the first one.
     attn_config = self.config.block_config(0).attn_config
@@ -117,6 +115,9 @@ def forward(
         input_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
     )
 
+    mask = self.mask_cache.index_select(2, input_pos)
+    mask = mask[:, :, :, : self.config.kv_cache_max]
+
     return self.forward_with_embeds(
         input_embeds, rope, mask, input_pos, kv_cache, export_config
     )
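As a reference for the model_builder.py reordering above, here is a toy illustration of the mask slicing that now runs after the RoPE build. The mask cache shape, sequence lengths, and kv_cache_max value are made up for the example and do not come from the library.

import torch

max_seq_len = 16     # assumed cache size for the example
kv_cache_max = 8     # assumed KV-cache limit for the example

# A causal mask cache shaped (1, 1, max_seq_len, max_seq_len): 0 where attention
# is allowed, -inf above the diagonal.
mask_cache = torch.triu(
    torch.full((max_seq_len, max_seq_len), float("-inf")), diagonal=1
)[None, None]

input_pos = torch.arange(0, 4)                 # positions of the current tokens
mask = mask_cache.index_select(2, input_pos)   # pick the rows for those positions
mask = mask[:, :, :, :kv_cache_max]            # keep only columns inside the KV cache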
