From 2d5c0ff85517a02334fc253a4e9b754bd48a6c32 Mon Sep 17 00:00:00 2001
From: Google AI Edge
Date: Sat, 21 Dec 2024 14:00:32 -0800
Subject: [PATCH] Convert PaliGemma2

- Set version 2 as the default in conversion and verification scripts
- Reduce parameters of fake model configs
  - PaliGemma unit tests still crash less than 50% of the time, down from
    70% before. Reducing some parameters seems to help.
- Update README.md accordingly

PiperOrigin-RevId: 708657532
---
 ai_edge_torch/generative/examples/README.md   | 14 ++++---
 .../examples/paligemma/convert_to_tflite.py   | 14 +++++--
 .../generative/examples/paligemma/decoder.py  |  2 +
 .../generative/examples/paligemma/decoder2.py |  2 +
 .../examples/paligemma/paligemma.py           |  4 +-
 .../generative/examples/paligemma/verify.py   |  2 +-
 .../paligemma/verify_image_encoder.py         |  2 +-
 .../test/test_model_conversion_large.py       | 37 ++++++++++++++-----
 8 files changed, 55 insertions(+), 22 deletions(-)

diff --git a/ai_edge_torch/generative/examples/README.md b/ai_edge_torch/generative/examples/README.md
index a8fc7888..d902061f 100644
--- a/ai_edge_torch/generative/examples/README.md
+++ b/ai_edge_torch/generative/examples/README.md
@@ -7,12 +7,14 @@ Gemma is Google's open-source LLM. The model has both a 2B and 7B versions. See
 ## PaliGemma
 
 PaliGemma is a multimodal LLM which gets images and text as input, then
 generates text as output. See
-[model's Kaggle page](https://www.kaggle.com/models/google/paligemma).
-The example we provide is PaliGemma 3B with 224 image size. Since Kaggle has
-only Jax-version of PaliGemma, PyTorch model can be download from
-[here](https://huggingface.co/google/paligemma-3b-mix-224/tree/main).
-
-Note that PaliGemma can be converted to TfLite only with [ODML Torch conversion
+[model's Kaggle page](https://www.kaggle.com/models/google/paligemma2).
+The examples we provide are PaliGemma2 and PaliGemma1, both 3B with 224 image size.
+The checkpoint for PaliGemma2 can be downloaded from
+[here](https://www.kaggle.com/models/google/paligemma-2/transformers/paligemma2-3b-pt-224).
+Since Kaggle hosts only the JAX version of PaliGemma1, the PyTorch model of
+PaliGemma1 can be downloaded from [here](https://huggingface.co/google/paligemma-3b-mix-224/tree/main).
+
+Note that PaliGemma models can be converted to TfLite only with [ODML Torch conversion
 backend](https://github.com/google-ai-edge/ai-edge-torch/blob/main/docs/pytorch_converter/README.md#use-odml-torch-conversion-backend-experimental)
 
 ## Llama
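For orientation before the next hunk: with this change the converter defaults to
PaliGemma2, so converting a PaliGemma1 checkpoint now requires requesting version 1
explicitly (via --version=1, which feeds build_model below). A minimal sketch of the
equivalent direct call; the checkpoint path and kv_cache_max_len value are
hypothetical placeholders, not part of the patch:

    # Minimal sketch, assuming a locally downloaded PaliGemma1 checkpoint.
    # The path and kv_cache_max_len here are hypothetical placeholders.
    from ai_edge_torch.generative.examples.paligemma import paligemma

    pytorch_model = paligemma.build_model(
        '/tmp/paligemma-3b-224',  # hypothetical checkpoint directory
        version=1,  # the new flag defaults to '2'; pass 1 for PaliGemma1
        kv_cache_max_len=1024,
    )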
diff --git a/ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py b/ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py
index c60ed743..47f2cab3 100644
--- a/ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py
+++ b/ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py
@@ -29,9 +29,15 @@ from ai_edge_torch.generative.utilities.model_builder import ExportConfig
 import torch
 
+_VERSION = flags.DEFINE_enum(
+    'version',
+    '2',
+    ['1', '2'],
+    'The version of PaliGemma model to convert.',
+)
 _CHECKPOINT_PATH = flags.DEFINE_string(
     'checkpoint_path',
-    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/paligemma-3b-224'),
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/paligemma2-3b-224'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
 _TFLITE_PATH = flags.DEFINE_string(
@@ -63,10 +69,12 @@ def main(_):
   pytorch_model = paligemma.build_model(
-      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
+      _CHECKPOINT_PATH.value,
+      version=int(_VERSION.value),
+      kv_cache_max_len=_KV_CACHE_MAX_LEN.value,
   )
   quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'paligemma_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+  output_filename = f'paligemma{_VERSION.value}_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
       tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
diff --git a/ai_edge_torch/generative/examples/paligemma/decoder.py b/ai_edge_torch/generative/examples/paligemma/decoder.py
index 4be4969e..79f71f16 100644
--- a/ai_edge_torch/generative/examples/paligemma/decoder.py
+++ b/ai_edge_torch/generative/examples/paligemma/decoder.py
@@ -137,6 +137,8 @@ def get_fake_decoder_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   config.vocab_size = 128
   config.num_layers = 2
   config.max_seq_len = 2 * kv_cache_max_len
+  config.embedding_dim = 16
+  config.embedding_scale = 16**0.5
   return config
diff --git a/ai_edge_torch/generative/examples/paligemma/decoder2.py b/ai_edge_torch/generative/examples/paligemma/decoder2.py
index 87835632..2976355d 100644
--- a/ai_edge_torch/generative/examples/paligemma/decoder2.py
+++ b/ai_edge_torch/generative/examples/paligemma/decoder2.py
@@ -160,6 +160,8 @@ def get_fake_decoder2_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   config.vocab_size = 128
   config.num_layers = 2
   config.max_seq_len = 2 * kv_cache_max_len
+  config.embedding_dim = 16
+  config.embedding_scale = 16**0.5
   return config
diff --git a/ai_edge_torch/generative/examples/paligemma/paligemma.py b/ai_edge_torch/generative/examples/paligemma/paligemma.py
index 08facf94..e117df10 100644
--- a/ai_edge_torch/generative/examples/paligemma/paligemma.py
+++ b/ai_edge_torch/generative/examples/paligemma/paligemma.py
@@ -136,8 +136,8 @@ def get_fake_model_config(get_decoder_config, **kwargs) -> PaliGemmaConfig:
   return PaliGemmaConfig(
       image_encoder_config=image_encoder.get_fake_image_encoder_config(),
       decoder_config=get_decoder_config(**kwargs),
-      image_token_id=257152,
-      image_projection_scale=2048**0.5,
+      image_token_id=127,
+      image_projection_scale=128**0.5,
       image_projection_use_bias=True,
   )
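A note on the reduced fake configs above: embedding_scale stays at
sqrt(embedding_dim), hence 16**0.5 alongside the new embedding_dim of 16, and the
new image_token_id of 127 fits inside the fake decoders' vocab_size of 128. A
minimal sketch of how these pieces are wired together, mirroring the test helper
added below; nothing in it is part of the diff itself:

    # Minimal sketch mirroring the new _test_paligemma_model helper below:
    # build a reduced fake PaliGemma2 model for conversion testing.
    from ai_edge_torch.generative.examples.paligemma import decoder2
    from ai_edge_torch.generative.examples.paligemma import paligemma

    # get_fake_model_config takes the decoder-config getter, so one helper
    # serves both PaliGemma versions.
    config = paligemma.get_fake_model_config(decoder2.get_fake_decoder2_config)
    pytorch_model = paligemma.PaliGemma(config, decoder2.Decoder2).eval()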
diff --git a/ai_edge_torch/generative/examples/paligemma/verify.py b/ai_edge_torch/generative/examples/paligemma/verify.py
index 8873a54f..87764f32 100644
--- a/ai_edge_torch/generative/examples/paligemma/verify.py
+++ b/ai_edge_torch/generative/examples/paligemma/verify.py
@@ -30,7 +30,7 @@
 _VERSION = flags.DEFINE_enum(
     "version",
-    "1",
+    "2",
     ["1", "2"],
     "The version of PaliGemma model to verify.",
 )
diff --git a/ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py b/ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py
index 5da4acdd..435bcdf0 100644
--- a/ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py
+++ b/ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py
@@ -28,7 +28,7 @@
 _VERSION = flags.DEFINE_enum(
     "version",
-    "1",
+    "2",
     ["1", "2"],
     "The version of PaliGemma vision model to verify.",
 )
diff --git a/ai_edge_torch/generative/test/test_model_conversion_large.py b/ai_edge_torch/generative/test/test_model_conversion_large.py
index 94ca776e..1f7afc83 100644
--- a/ai_edge_torch/generative/test/test_model_conversion_large.py
+++ b/ai_edge_torch/generative/test/test_model_conversion_large.py
@@ -21,6 +21,8 @@
 from ai_edge_torch.generative.examples.gemma import gemma2
 from ai_edge_torch.generative.examples.llama import llama
 from ai_edge_torch.generative.examples.openelm import openelm
+from ai_edge_torch.generative.examples.paligemma import decoder
+from ai_edge_torch.generative.examples.paligemma import decoder2
 from ai_edge_torch.generative.examples.paligemma import paligemma
 from ai_edge_torch.generative.examples.phi import phi2
 from ai_edge_torch.generative.examples.phi import phi3
@@ -171,13 +173,9 @@ def test_amd_llama_135m(self):
     pytorch_model = amd_llama_135m.AmdLlama(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5)
 
-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
-  def disabled_test_paligemma(self):
-    config = paligemma.get_fake_model_config()
-    pytorch_model = paligemma.PaliGemma(config).eval()
+  def _test_paligemma_model(self, decoder_class, decoder_config, atol, rtol):
+    config = paligemma.get_fake_model_config(decoder_config)
+    pytorch_model = paligemma.PaliGemma(config, decoder_class).eval()
 
     image_embedding_config = config.image_encoder_config.image_embedding
     num_patches = (
@@ -215,11 +213,32 @@
         kv,
         pixel_values=pixel_values,
         signature_name="prefill_pixel",
-        atol=1e-3,
-        rtol=1e-5,
+        atol=atol,
+        rtol=rtol,
       )
     )
 
+  @googletest.skipIf(
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
+  )
+  def test_paligemma1(self):
+    self._test_paligemma_model(
+        decoder.Decoder, decoder.get_fake_decoder_config, atol=1e-3, rtol=1e-5
+    )
+
+  @googletest.skipIf(
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
+  )
+  def test_paligemma2(self):
+    self._test_paligemma_model(
+        decoder2.Decoder2,
+        decoder2.get_fake_decoder2_config,
+        atol=1e-3,
+        rtol=1e-5,
+    )
+
   @googletest.skipIf(
       ai_edge_torch.config.in_oss,
       reason="tests with custom ops are not supported in oss",