Convert PaliGemma2
- Set version 2 as the default in conversion and verification scripts
- Reduce parameters of the fake model configs
- PaliGemma unittests still crash in fewer than 50% of runs, down from about 70% before; reducing some parameters appears to help
- Updated README.md accordingly

PiperOrigin-RevId: 708838151
ai-edge-bot authored and copybara-github committed Dec 22, 2024
1 parent fa6b74d commit 562c93d
Showing 8 changed files with 55 additions and 22 deletions.
14 changes: 8 additions & 6 deletions ai_edge_torch/generative/examples/README.md
@@ -7,12 +7,14 @@ Gemma is Google's open-source LLM. The model has both a 2B and 7B versions. See
 ## PaliGemma
 PaliGemma is a multimodal LLM which gets images and text as input, then
 generates text as output. See
-[model's Kaggle page](https://www.kaggle.com/models/google/paligemma).
-The example we provide is PaliGemma 3B with 224 image size. Since Kaggle has
-only Jax-version of PaliGemma, PyTorch model can be download from
-[here](https://huggingface.co/google/paligemma-3b-mix-224/tree/main).
+[model's Kaggle page](https://www.kaggle.com/models/google/paligemma2).
+The examples we provide are PaliGemma2 and PaliGemma1, both 3B with 224 image size.
+The checkpoint for PaliGemma2 can be downloaded from
+[here](https://www.kaggle.com/models/google/paligemma-2/transformers/paligemma2-3b-pt-224).
+Since Kaggle hosts only the JAX version of PaliGemma1, the PyTorch model of PaliGemma1
+can be downloaded from [here](https://huggingface.co/google/paligemma-3b-mix-224/tree/main).
 
-Note that PaliGemma can be converted to TfLite only with [ODML Torch conversion
+Note that PaliGemma models can be converted to TfLite only with the [ODML Torch conversion
 backend](https://github.com/google-ai-edge/ai-edge-torch/blob/main/docs/pytorch_converter/README.md#use-odml-torch-conversion-backend-experimental)
 
 ## Llama
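
With the above in place, conversion is driven by convert_to_tflite.py (diffed next): version 2 is now the default, so converting the original PaliGemma requires passing --version=1 explicitly, and the default --checkpoint_path points at ~/Downloads/llm_data/paligemma2-3b-224.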
14 changes: 11 additions & 3 deletions ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py
@@ -29,9 +29,15 @@
 from ai_edge_torch.generative.utilities.model_builder import ExportConfig
 import torch
 
+_VERSION = flags.DEFINE_enum(
+    'version',
+    '2',
+    ['1', '2'],
+    'The version of PaliGemma model to convert.',
+)
 _CHECKPOINT_PATH = flags.DEFINE_string(
     'checkpoint_path',
-    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/paligemma-3b-224'),
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/paligemma2-3b-224'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
 _TFLITE_PATH = flags.DEFINE_string(
@@ -63,10 +69,12 @@

 def main(_):
   pytorch_model = paligemma.build_model(
-      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
+      _CHECKPOINT_PATH.value,
+      version=int(_VERSION.value),
+      kv_cache_max_len=_KV_CACHE_MAX_LEN.value,
   )
   quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'paligemma_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+  output_filename = f'paligemma{_VERSION.value}_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
       tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
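
A minimal sketch of what main() now does with the version flag, for readers skimming the diff; the flag defaults and the converter's remaining arguments are hidden behind the collapsed lines, so the literal values below are illustrative assumptions:

import os
import pathlib

from ai_edge_torch.generative.examples.paligemma import paligemma

# Build the PyTorch model as main() does after this change; '2' is now the
# default value of the --version flag.
pytorch_model = paligemma.build_model(
    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/paligemma2-3b-224'),
    version=2,
    kv_cache_max_len=1024,  # illustrative; the flag's default is collapsed here
)

# The output name now embeds the version: an unquantized model with
# prefill_seq_len=512 would be written as paligemma2_f32_seq512_ekv1024.tflite.
# converter.convert_to_tflite(pytorch_model, tflite_path=...) then performs the
# conversion; its remaining arguments are outside this diff.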
2 changes: 2 additions & 0 deletions ai_edge_torch/generative/examples/paligemma/decoder.py
@@ -137,6 +137,8 @@ def get_fake_decoder_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   config.vocab_size = 128
   config.num_layers = 2
   config.max_seq_len = 2 * kv_cache_max_len
+  config.embedding_dim = 128
+  config.embedding_scale = 128**0.5
   return config


2 changes: 2 additions & 0 deletions ai_edge_torch/generative/examples/paligemma/decoder2.py
@@ -160,6 +160,8 @@ def get_fake_decoder2_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   config.vocab_size = 128
   config.num_layers = 2
   config.max_seq_len = 2 * kv_cache_max_len
+  config.embedding_dim = 128
+  config.embedding_scale = 128**0.5
   return config


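
The two added lines keep the fake configs self-consistent: a brief sketch, assuming the Gemma-style convention that token embeddings are multiplied by the square root of the embedding dimension.

# Shrinking embedding_dim to 128 means the matching scale is 128**0.5.
embedding_dim = 128
embedding_scale = embedding_dim**0.5  # ~11.31, i.e. sqrt(embedding_dim)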
4 changes: 2 additions & 2 deletions ai_edge_torch/generative/examples/paligemma/paligemma.py
@@ -136,8 +136,8 @@ def get_fake_model_config(get_decoder_config, **kwargs) -> PaliGemmaConfig:
   return PaliGemmaConfig(
       image_encoder_config=image_encoder.get_fake_image_encoder_config(),
       decoder_config=get_decoder_config(**kwargs),
-      image_token_id=257152,
-      image_projection_scale=2048**0.5,
+      image_token_id=127,
+      image_projection_scale=128**0.5,
       image_projection_use_bias=True,
   )

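
The smaller image_token_id follows from the fake decoder configs above: their vocab_size is 128, so the fake image token must be a valid id below it (the previous value, 257152, matched the real vocabulary). A hypothetical sanity check, assuming PaliGemmaConfig keeps the constructor arguments shown above as attributes:

from ai_edge_torch.generative.examples.paligemma import decoder, paligemma

config = paligemma.get_fake_model_config(decoder.get_fake_decoder_config)
assert config.image_token_id < config.decoder_config.vocab_size  # 127 < 128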
2 changes: 1 addition & 1 deletion ai_edge_torch/generative/examples/paligemma/verify.py
@@ -30,7 +30,7 @@

 _VERSION = flags.DEFINE_enum(
     "version",
-    "1",
+    "2",
     ["1", "2"],
     "The version of PaliGemma model to verify.",
 )
ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py
@@ -28,7 +28,7 @@

 _VERSION = flags.DEFINE_enum(
     "version",
-    "1",
+    "2",
     ["1", "2"],
     "The version of PaliGemma vision model to verify.",
 )
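
Both verify scripts now default to version 2; checking the original PaliGemma against its reference outputs therefore requires passing --version=1 explicitly.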
37 changes: 28 additions & 9 deletions ai_edge_torch/generative/test/test_model_conversion_large.py
@@ -21,6 +21,8 @@
 from ai_edge_torch.generative.examples.gemma import gemma2
 from ai_edge_torch.generative.examples.llama import llama
 from ai_edge_torch.generative.examples.openelm import openelm
+from ai_edge_torch.generative.examples.paligemma import decoder
+from ai_edge_torch.generative.examples.paligemma import decoder2
 from ai_edge_torch.generative.examples.paligemma import paligemma
 from ai_edge_torch.generative.examples.phi import phi2
 from ai_edge_torch.generative.examples.phi import phi3
@@ -171,13 +173,9 @@ def test_amd_llama_135m(self):
     pytorch_model = amd_llama_135m.AmdLlama(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5)
 
-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
-  def disabled_test_paligemma(self):
-    config = paligemma.get_fake_model_config()
-    pytorch_model = paligemma.PaliGemma(config).eval()
+  def _test_paligemma_model(self, decoder_class, decoder_config, atol, rtol):
+    config = paligemma.get_fake_model_config(decoder_config)
+    pytorch_model = paligemma.PaliGemma(config, decoder_class).eval()
 
     image_embedding_config = config.image_encoder_config.image_embedding
     num_patches = (
@@ -215,11 +213,32 @@ def disabled_test_paligemma(self):
             kv,
             pixel_values=pixel_values,
             signature_name="prefill_pixel",
-            atol=1e-3,
-            rtol=1e-5,
+            atol=atol,
+            rtol=rtol,
         )
     )
 
+  @googletest.skipIf(
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
+  )
+  def disabled_test_paligemma1(self):
+    self._test_paligemma_model(
+        decoder.Decoder, decoder.get_fake_decoder_config, atol=1e-3, rtol=1e-5
+    )
+
+  @googletest.skipIf(
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
+  )
+  def disabled_test_paligemma2(self):
+    self._test_paligemma_model(
+        decoder2.Decoder2,
+        decoder2.get_fake_decoder2_config,
+        atol=1e-3,
+        rtol=1e-5,
+    )
+
   @googletest.skipIf(
       ai_edge_torch.config.in_oss,
       reason="tests with custom ops are not supported in oss",
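
Extracting _test_paligemma_model parameterizes the conversion test over the decoder class and fake-config factory, so both PaliGemma versions share the prefill_pixel comparison; keeping separate disabled_test_paligemma1/2 methods preserves the per-version OSS skip annotations.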
