Add PaliGemma decoder example to ai_edge_torch.
- Kaggle has only the JAX model, so the PyTorch model is downloaded from HF.
- ImageProcessing/Encoder will be added in a following change.

PiperOrigin-RevId: 694701725
1 parent 56a9b17 · commit f70448e
Showing 3 changed files with 192 additions and 0 deletions.
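Since Kaggle hosts only the JAX variant, the PyTorch weights come from Hugging Face. A condensed sketch of that download path, mirroring what verify_decoder.py below does (the checkpoint name is taken from that file):

import transformers

# Kaggle hosts only the JAX variant of PaliGemma, so the PyTorch
# weights are fetched from Hugging Face instead.
checkpoint = "google/paligemma-3b-mix-224"
model = transformers.PaliGemmaForConditionalGeneration.from_pretrained(
    checkpoint
)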
14 changes: 14 additions & 0 deletions
ai_edge_torch/generative/examples/paligemma/__init__.py
@@ -0,0 +1,14 @@
# Copyright 2024 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
103 changes: 103 additions & 0 deletions
ai_edge_torch/generative/examples/paligemma/decoder.py
@@ -0,0 +1,103 @@
# Copyright 2024 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Example of building the decoder of a PaliGemma 3B model, which is Gemma1."""

import ai_edge_torch.generative.layers.model_config as cfg
from ai_edge_torch.generative.utilities import model_builder
import ai_edge_torch.generative.utilities.loader as loading_utils

# Maps the reauthored model's tensors to their names in the Hugging Face
# PyTorch checkpoint; "{}" is filled in with the layer index by the loader.
TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
    ff_up_proj="language_model.model.layers.{}.mlp.up_proj",
    ff_down_proj="language_model.model.layers.{}.mlp.down_proj",
    ff_gate_proj="language_model.model.layers.{}.mlp.gate_proj",
    attn_query_proj="language_model.model.layers.{}.self_attn.q_proj",
    attn_key_proj="language_model.model.layers.{}.self_attn.k_proj",
    attn_value_proj="language_model.model.layers.{}.self_attn.v_proj",
    attn_output_proj="language_model.model.layers.{}.self_attn.o_proj",
    pre_attn_norm="language_model.model.layers.{}.input_layernorm",
    post_attn_norm="language_model.model.layers.{}.post_attention_layernorm",
    embedding="language_model.model.embed_tokens",
    final_norm="language_model.model.norm",
    # No separate LM-head tensor in the checkpoint; the output projection
    # is tied to the embedding weights, as in Gemma.
    lm_head=None,
)

def get_decoder_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
  """Returns the model config for the decoder of a PaliGemma 3B model.

  Args:
    kv_cache_max_len (int): The maximum sequence length of the KV cache.
      Default is 1024.

  Returns:
    The model config for the decoder of a PaliGemma 3B model.
  """
  attn_config = cfg.AttentionConfig(
      num_heads=8,
      head_dim=256,
      # One KV group shared by all query heads, i.e. multi-query attention.
      num_query_groups=1,
      rotary_base=10000,
      rotary_percentage=1.0,
  )
  ff_config = cfg.FeedForwardConfig(
      type=cfg.FeedForwardType.GATED,
      activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
      intermediate_size=16384,
  )
  norm_config = cfg.NormalizationConfig(
      type=cfg.NormalizationType.RMS_NORM,
      epsilon=1e-6,
      zero_centered=True,
  )
  block_config = cfg.TransformerBlockConfig(
      attn_config=attn_config,
      ff_config=ff_config,
      pre_attention_norm_config=norm_config,
      post_attention_norm_config=norm_config,
  )
  config = cfg.ModelConfig(
      vocab_size=257216,
      num_layers=18,
      max_seq_len=8192,
      embedding_dim=2048,
      embedding_scale=2048**0.5,
      kv_cache_max_len=kv_cache_max_len,
      block_configs=block_config,
      final_norm_config=norm_config,
      lm_head_use_bias=False,
      enable_hlfb=True,
  )
  return config

def get_fake_decoder_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
  config = get_decoder_config(kv_cache_max_len)
  # PaliGemma decoder has only one block config.
  config.block_config(0).ff_config.intermediate_size = 128
  config.vocab_size = 128
  config.num_layers = 2
  config.max_seq_len = 2 * kv_cache_max_len
  return config


def build_decoder(
    checkpoint_path: str, **kwargs
) -> model_builder.DecoderOnlyModel:
  return model_builder.build_decoder_only_model(
      checkpoint_path=checkpoint_path,
      config=get_decoder_config(**kwargs),
      tensor_names=TENSOR_NAMES,
  )
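A minimal usage sketch for the builder above. The checkpoint directory is hypothetical; any local directory holding the Hugging Face PaliGemma weights and config should work, e.g. the transformers cache directory located as in verify_decoder.py below:

from ai_edge_torch.generative.examples.paligemma import decoder

# Hypothetical local copy of the google/paligemma-3b-mix-224 checkpoint;
# not part of this commit.
checkpoint_dir = "/tmp/paligemma-3b-mix-224"

# Extra kwargs are forwarded to get_decoder_config, so the KV cache
# length is tunable at build time.
model = decoder.build_decoder(checkpoint_dir, kv_cache_max_len=1024)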
75 changes: 75 additions & 0 deletions
ai_edge_torch/generative/examples/paligemma/verify_decoder.py
@@ -0,0 +1,75 @@
# Copyright 2024 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Verifies the reauthored decoder of PaliGemma 3B model.""" | ||
|
||
import logging | ||
import pathlib | ||
|
||
from absl import app | ||
from absl import flags | ||
from ai_edge_torch.generative.examples.paligemma import decoder | ||
from ai_edge_torch.generative.utilities import transformers_verifier | ||
from ai_edge_torch.generative.utilities import verifier | ||
import transformers | ||
|
||
_PROMPTS = flags.DEFINE_multi_string( | ||
"prompts", | ||
"What is the meaning of life?", | ||
"The input prompts to generate answers.", | ||
) | ||
_MAX_NEW_TOKENS = flags.DEFINE_integer( | ||
"max_new_tokens", | ||
30, | ||
"The maximum size of the generated tokens.", | ||
) | ||
def main(_):
  checkpoint = "google/paligemma-3b-mix-224"
  logging.info("Loading the original model from: %s", checkpoint)
  original_full_model = (
      transformers.PaliGemmaForConditionalGeneration.from_pretrained(checkpoint)
  )
  original_language_model = original_full_model.eval().language_model

  # Locate the cached checkpoint dir so the reauthored model is built from
  # the same downloaded weights.
  cached_config_file = transformers.utils.cached_file(
      checkpoint, transformers.utils.CONFIG_NAME
  )
  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
  reauthored_model = decoder.build_decoder(reauthored_checkpoint)

  logging.info("Loading the tokenizer from: %s", checkpoint)
  # This works only when GemmaTokenizerFast is available. In some
  # environments, use_fast=False doesn't work either if the tokenizer cannot
  # load the sentencepiece model file properly.
  processor = transformers.AutoProcessor.from_pretrained(checkpoint)

  verifier.verify_reauthored_model(
      original_model=transformers_verifier.TransformersModelWrapper(
          original_language_model
      ),
      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
      tokenizer=verifier.TokenizerWrapper(processor.tokenizer),
      generate_prompts=_PROMPTS.value,
      max_new_tokens=_MAX_NEW_TOKENS.value,
      atol=1e-04,
  )


if __name__ == "__main__":
  app.run(main)
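With ai_edge_torch and transformers installed, a typical invocation of the script with its two flags (the values shown are the defaults):

python ai_edge_torch/generative/examples/paligemma/verify_decoder.py \
    --prompts="What is the meaning of life?" \
    --max_new_tokens=30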