Skip to content

Commit

Permalink
Fix PaliGemma conversion and verify.py of examples broken by ExportConfig.
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 704971311
  • Loading branch information
ai-edge-bot authored and copybara-github committed Dec 11, 2024
1 parent 27b7077 commit 06be52c
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 5 deletions.
3 changes: 3 additions & 0 deletions ai_edge_torch/generative/examples/paligemma/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

"""Example of building a decoder of PaliGemma 3B model which is Gemma1."""

from typing import Optional

from ai_edge_torch.generative.layers import kv_cache as kv_utils
import ai_edge_torch.generative.layers.model_config as cfg
from ai_edge_torch.generative.utilities import model_builder
Expand Down Expand Up @@ -51,6 +53,7 @@ def forward(
input_pos: torch.Tensor,
kv_cache: kv_utils.KVCache,
input_embeds: torch.Tensor = None,
export_config: Optional[model_builder.ExportConfig] = None,
) -> dict[torch.Tensor, kv_utils.KVCache]:
if input_embeds is None:
return super().forward(tokens, input_pos, kv_cache)
Expand Down
12 changes: 11 additions & 1 deletion ai_edge_torch/generative/examples/paligemma/paligemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@
"""Example of building a full-stack of PaliGemma model."""

from dataclasses import dataclass
from typing import Optional

from ai_edge_torch.generative.examples.paligemma import decoder
from ai_edge_torch.generative.examples.paligemma import image_encoder
import ai_edge_torch.generative.layers.kv_cache as kv_utils
import ai_edge_torch.generative.layers.model_config as cfg
from ai_edge_torch.generative.utilities import model_builder
import ai_edge_torch.generative.utilities.loader as loading_utils
import torch
from torch import nn
Expand Down Expand Up @@ -67,9 +69,16 @@ def forward(
input_pos: torch.Tensor,
kv_cache: kv_utils.KVCache,
pixel_values: torch.Tensor = None,
export_config: Optional[model_builder.ExportConfig] = None,
) -> dict[torch.Tensor, kv_utils.KVCache]:
if pixel_values is None:
return self.decoder(tokens, input_pos, kv_cache)
return self.decoder(
tokens=tokens,
input_pos=input_pos,
kv_cache=kv_cache,
input_embeds=None,
export_config=export_config
)

input_embeds = self.decoder.tok_embedding(tokens)

Expand Down Expand Up @@ -100,6 +109,7 @@ def forward(
input_pos=input_pos,
kv_cache=kv_cache,
input_embeds=input_embeds,
export_config=export_config,
)


Expand Down
6 changes: 3 additions & 3 deletions ai_edge_torch/generative/utilities/transformers_verifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class TransformersModelWrapper(verifier.ModelWrapper):
an object with `logits` field.
Transformers models get `max_new_tokens` settings for generate() via
ExportConfig.
GenerationConfig.
"""

def forward(self, tokens: torch.Tensor) -> torch.Tensor:
Expand All @@ -38,5 +38,5 @@ def forward(self, tokens: torch.Tensor) -> torch.Tensor:
def generate(
self, inputs: torch.Tensor, max_new_tokens: int
) -> torch.IntTensor:
export_config = transformers.ExportConfig(max_new_tokens=max_new_tokens)
return self.model.generate(inputs=inputs, generation_config=export_config)
gen_config = transformers.GenerationConfig(max_new_tokens=max_new_tokens)
return self.model.generate(inputs=inputs, generation_config=gen_config)
2 changes: 1 addition & 1 deletion ai_edge_torch/generative/utilities/verifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def _forward_with_kv_cache(
# pixel_values only when it is not None. Otherwise, it may raise an error.
if pixel_values is None:
output = self.model.forward(
tokens, input_pos, kv_cache, self.export_config
tokens, input_pos, kv_cache, export_config=self.export_config
)
else:
output = self.model.forward(
Expand Down

0 comments on commit 06be52c

Please sign in to comment.