From f70448e67e8c67137c7f89cc50e4c263d709cc9b Mon Sep 17 00:00:00 2001
From: Google AI Edge
Date: Fri, 8 Nov 2024 17:41:14 -0800
Subject: [PATCH] Add PaliGemma decoder example to ai_edge_torch.

- Kaggle has only the JAX model, so the PyTorch model is downloaded from HF.
- ImageProcessing/Encoder will be added in a follow-up change.

PiperOrigin-RevId: 694701725
---
 .../generative/examples/paligemma/__init__.py |  14 +++
 .../generative/examples/paligemma/decoder.py  | 103 ++++++++++++++++++
 .../examples/paligemma/verify_decoder.py      |  75 +++++++++++++
 3 files changed, 192 insertions(+)
 create mode 100644 ai_edge_torch/generative/examples/paligemma/__init__.py
 create mode 100644 ai_edge_torch/generative/examples/paligemma/decoder.py
 create mode 100644 ai_edge_torch/generative/examples/paligemma/verify_decoder.py

diff --git a/ai_edge_torch/generative/examples/paligemma/__init__.py b/ai_edge_torch/generative/examples/paligemma/__init__.py
new file mode 100644
index 00000000..57b12003
--- /dev/null
+++ b/ai_edge_torch/generative/examples/paligemma/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
diff --git a/ai_edge_torch/generative/examples/paligemma/decoder.py b/ai_edge_torch/generative/examples/paligemma/decoder.py
new file mode 100644
index 00000000..afeaa15a
--- /dev/null
+++ b/ai_edge_torch/generative/examples/paligemma/decoder.py
@@ -0,0 +1,103 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of building the decoder of a PaliGemma 3B model, which is Gemma1."""
+
+import ai_edge_torch.generative.layers.model_config as cfg
+from ai_edge_torch.generative.utilities import model_builder
+import ai_edge_torch.generative.utilities.loader as loading_utils
+
+TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
+    ff_up_proj="language_model.model.layers.{}.mlp.up_proj",
+    ff_down_proj="language_model.model.layers.{}.mlp.down_proj",
+    ff_gate_proj="language_model.model.layers.{}.mlp.gate_proj",
+    attn_query_proj="language_model.model.layers.{}.self_attn.q_proj",
+    attn_key_proj="language_model.model.layers.{}.self_attn.k_proj",
+    attn_value_proj="language_model.model.layers.{}.self_attn.v_proj",
+    attn_output_proj="language_model.model.layers.{}.self_attn.o_proj",
+    pre_attn_norm="language_model.model.layers.{}.input_layernorm",
+    post_attn_norm="language_model.model.layers.{}.post_attention_layernorm",
+    embedding="language_model.model.embed_tokens",
+    final_norm="language_model.model.norm",
+    lm_head=None,
+)
+
+
+def get_decoder_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+  """Returns the model config for the decoder of a PaliGemma 3B model.
+
+  Args:
+    kv_cache_max_len (int): The maximum sequence length of the KV cache.
+      Default is 1024.
+
+  Returns:
+    The model config for the decoder of a PaliGemma 3B model.
+  """
+  attn_config = cfg.AttentionConfig(
+      num_heads=8,
+      head_dim=256,
+      num_query_groups=1,
+      rotary_base=10000,
+      rotary_percentage=1.0,
+  )
+  ff_config = cfg.FeedForwardConfig(
+      type=cfg.FeedForwardType.GATED,
+      activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
+      intermediate_size=16384,
+  )
+  norm_config = cfg.NormalizationConfig(
+      type=cfg.NormalizationType.RMS_NORM,
+      epsilon=1e-6,
+      zero_centered=True,
+  )
+  block_config = cfg.TransformerBlockConfig(
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      post_attention_norm_config=norm_config,
+  )
+  config = cfg.ModelConfig(
+      vocab_size=257216,
+      num_layers=18,
+      max_seq_len=8192,
+      embedding_dim=2048,
+      embedding_scale=2048**0.5,
+      kv_cache_max_len=kv_cache_max_len,
+      block_configs=block_config,
+      final_norm_config=norm_config,
+      lm_head_use_bias=False,
+      enable_hlfb=True,
+  )
+  return config
+
+
+def get_fake_decoder_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
+  config = get_decoder_config(kv_cache_max_len)
+  # PaliGemma decoder has only one block config.
+  config.block_config(0).ff_config.intermediate_size = 128
+  config.vocab_size = 128
+  config.num_layers = 2
+  config.max_seq_len = 2 * kv_cache_max_len
+  return config
+
+
+def build_decoder(
+    checkpoint_path: str, **kwargs
+) -> model_builder.DecoderOnlyModel:
+  return model_builder.build_decoder_only_model(
+      checkpoint_path=checkpoint_path,
+      config=get_decoder_config(**kwargs),
+      tensor_names=TENSOR_NAMES,
+  )
diff --git a/ai_edge_torch/generative/examples/paligemma/verify_decoder.py b/ai_edge_torch/generative/examples/paligemma/verify_decoder.py
new file mode 100644
index 00000000..6241cdc0
--- /dev/null
+++ b/ai_edge_torch/generative/examples/paligemma/verify_decoder.py
@@ -0,0 +1,75 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Verifies the reauthored decoder of a PaliGemma 3B model."""
+
+import logging
+import pathlib
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.paligemma import decoder
+from ai_edge_torch.generative.utilities import transformers_verifier
+from ai_edge_torch.generative.utilities import verifier
+import transformers
+
+_PROMPTS = flags.DEFINE_multi_string(
+    "prompts",
+    "What is the meaning of life?",
+    "The input prompts to generate answers.",
+)
+_MAX_NEW_TOKENS = flags.DEFINE_integer(
+    "max_new_tokens",
+    30,
+    "The maximum number of tokens to generate.",
+)
+
+
+def main(_):
+  checkpoint = "google/paligemma-3b-mix-224"
+  logging.info("Loading the original model from: %s", checkpoint)
+  original_full_model = (
+      transformers.PaliGemmaForConditionalGeneration.from_pretrained(checkpoint)
+  )
+  original_language_model = original_full_model.eval().language_model
+
+  # Locate the cached dir.
+  cached_config_file = transformers.utils.cached_file(
+      checkpoint, transformers.utils.CONFIG_NAME
+  )
+  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
+  reauthored_model = decoder.build_decoder(reauthored_checkpoint)
+
+  logging.info("Loading the tokenizer from: %s", checkpoint)
+  # This works only when GemmaTokenizerFast is available. In some
+  # environments, use_fast=False doesn't work either, if the tokenizer cannot
+  # load the sentencepiece model file properly.
+  processor = transformers.AutoProcessor.from_pretrained(checkpoint)
+
+  verifier.verify_reauthored_model(
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_language_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(processor.tokenizer),
+      generate_prompts=_PROMPTS.value,
+      max_new_tokens=_MAX_NEW_TOKENS.value,
+      atol=1e-04,
+  )
+
+
+if __name__ == "__main__":
+  app.run(main)
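
Usage sketch (reviewer note, not part of the patch): how the new
decoder.build_decoder entry point can be driven directly, mirroring the steps
in verify_decoder.py above. It reuses only calls that already appear in this
patch; the kv_cache_max_len override relies on build_decoder forwarding
**kwargs to get_decoder_config, and the sketch assumes the HF checkpoint
(weights included) is already in the local cache, e.g. from a prior
from_pretrained() call.

    import pathlib

    import transformers

    from ai_edge_torch.generative.examples.paligemma import decoder

    checkpoint = "google/paligemma-3b-mix-224"

    # Locate the local cache dir of the HF checkpoint via its config file.
    # Assumption: the weights were already downloaded, e.g. by
    # transformers.PaliGemmaForConditionalGeneration.from_pretrained().
    cached_config_file = transformers.utils.cached_file(
        checkpoint, transformers.utils.CONFIG_NAME
    )
    reauthored_checkpoint = pathlib.Path(cached_config_file).parent

    # Build the reauthored decoder with a smaller KV cache than the default;
    # kv_cache_max_len is forwarded to get_decoder_config via **kwargs.
    model = decoder.build_decoder(reauthored_checkpoint, kv_cache_max_len=512)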