Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix rope calculation of PaliGemma decoder when pixel_embeds is passed. #433

Merged
merged 1 commit into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions ai_edge_torch/generative/examples/paligemma/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from ai_edge_torch.generative.layers import kv_cache as kv_utils
import ai_edge_torch.generative.layers.model_config as cfg
import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
from ai_edge_torch.generative.utilities import model_builder
import ai_edge_torch.generative.utilities.loader as loading_utils
import torch
Expand Down Expand Up @@ -61,8 +62,12 @@ def forward(
assert input_embeds is not None

repo_pos = input_pos + 1 # PaliGemma position is 1-based.
cos, sin = self.rope_cache
rope = (cos.index_select(0, repo_pos), sin.index_select(0, repo_pos))
# ROPE parameters for all attn_configs are the same. Take the first one.
attn_config = self.config.block_config(0).attn_config
n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
rope = rotary_pos_emb.build_rope(
repo_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
)

# The first part of input_embeds are image embeddings. Diagonal causal mask
# doesn't work here.
Expand Down
5 changes: 3 additions & 2 deletions ai_edge_torch/generative/utilities/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,6 @@ def forward(

# token embeddings of shape (b, t, n_embd)
input_embeds = self.tok_embedding(tokens)
mask = self.mask_cache.index_select(2, input_pos)
mask = mask[:, :, :, : self.config.kv_cache_max]

# ROPE parameters for all attn_configs are the same. Take the first one.
attn_config = self.config.block_config(0).attn_config
Expand All @@ -117,6 +115,9 @@ def forward(
input_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
)

mask = self.mask_cache.index_select(2, input_pos)
mask = mask[:, :, :, : self.config.kv_cache_max]

return self.forward_with_embeds(
input_embeds, rope, mask, input_pos, kv_cache, export_config
)
Expand Down
Loading