diff --git a/ai_edge_torch/generative/examples/paligemma/decoder.py b/ai_edge_torch/generative/examples/paligemma/decoder.py index afeaa15a..c90e040d 100644 --- a/ai_edge_torch/generative/examples/paligemma/decoder.py +++ b/ai_edge_torch/generative/examples/paligemma/decoder.py @@ -15,9 +15,11 @@ """Example of building a decoder of PaliGemma 3B model which is Gemma1.""" +from ai_edge_torch.generative.layers import kv_cache as kv_utils import ai_edge_torch.generative.layers.model_config as cfg from ai_edge_torch.generative.utilities import model_builder import ai_edge_torch.generative.utilities.loader as loading_utils +import torch TENSOR_NAMES = loading_utils.ModelLoader.TensorNames( ff_up_proj="language_model.model.layers.{}.mlp.up_proj", @@ -35,6 +37,41 @@ ) +class Decoder(model_builder.DecoderOnlyModel): + """A decoder of PaliGemma 3B model which is Gemma1. + + Besides a tensor of text token IDs, forward() can also take a tensor of + embeddings which may include text or image or both. + """ + + @torch.inference_mode + def forward( + self, + tokens: torch.Tensor, + input_pos: torch.Tensor, + kv_cache: kv_utils.KVCache, + input_embeds: torch.Tensor = None, + ) -> dict[torch.Tensor, kv_utils.KVCache]: + if input_embeds is None: + return super().forward(tokens, input_pos, kv_cache) + + assert input_embeds is not None + + repo_pos = input_pos + 1 # PaliGemma position is 1-based. + cos, sin = self.rope_cache + rope = (cos.index_select(0, repo_pos), sin.index_select(0, repo_pos)) + + # The first part of input_embeds are image embeddings. Diagonal causal mask + # doesn't work here. + embeds_len = input_embeds.shape[1] + mask = torch.zeros(embeds_len, self.config.kv_cache_max) + mask[:, embeds_len:] = float("-inf") + + return self.forward_with_embeds( + input_embeds, rope, mask, input_pos, kv_cache + ) + + def get_decoder_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig: """Returns the model config for the decoder of a PaliGemma 3B model. @@ -96,8 +133,9 @@ def get_fake_decoder_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig: def build_decoder( checkpoint_path: str, **kwargs ) -> model_builder.DecoderOnlyModel: - return model_builder.build_decoder_only_model( - checkpoint_path=checkpoint_path, - config=get_decoder_config(**kwargs), - tensor_names=TENSOR_NAMES, - ) + decoder = Decoder(get_decoder_config(**kwargs)) + loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES) + # Loose the strictness because only decoder is being loaded. + loader.load(decoder, strict=False) + decoder.eval() + return decoder diff --git a/ai_edge_torch/generative/examples/paligemma/paligemma.py b/ai_edge_torch/generative/examples/paligemma/paligemma.py new file mode 100644 index 00000000..664e47e1 --- /dev/null +++ b/ai_edge_torch/generative/examples/paligemma/paligemma.py @@ -0,0 +1,135 @@ +# Copyright 2024 The AI Edge Torch Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Example of building a full-stack of PaliGemma model.""" + +from dataclasses import dataclass + +from ai_edge_torch.generative.examples.paligemma import decoder +from ai_edge_torch.generative.examples.paligemma import image_encoder +import ai_edge_torch.generative.layers.kv_cache as kv_utils +import ai_edge_torch.generative.layers.model_config as cfg +import ai_edge_torch.generative.utilities.loader as loading_utils +import torch +from torch import nn + +PROJECTION_TENSOR_NAME = "multi_modal_projector.linear" + + +@dataclass +class PaliGemmaConfig: + """PaliGemma model configurations.""" + + image_encoder_config: cfg.ModelConfig + decoder_config: cfg.ModelConfig + + image_token_id: int + image_projection_use_bias: bool = False + + +class PaliGemma(nn.Module): + """PaliGemma model from the Edge Generative API.""" + + def __init__(self, config: PaliGemmaConfig): + super().__init__() + + self.image_encoder = image_encoder.SiglipVisionEncoder( + config.image_encoder_config + ) + self.image_projection = nn.Linear( + config.image_encoder_config.embedding_dim, + config.decoder_config.embedding_dim, + bias=config.image_projection_use_bias, + ) + self.decoder = decoder.Decoder(config.decoder_config) + self.config = config + + @torch.inference_mode + def forward( + self, + tokens: torch.Tensor, + input_pos: torch.Tensor, + kv_cache: kv_utils.KVCache, + pixel_values: torch.Tensor = None, + ) -> dict[torch.Tensor, kv_utils.KVCache]: + if pixel_values is None: + return self.decoder(tokens, input_pos, kv_cache) + + input_embeds = self.decoder.tok_embedding(tokens) + + image_encoded = self.image_encoder(pixel_values=pixel_values) + image_embeds = self.image_projection(image_encoded) + if self.config.decoder_config.embedding_scale is not None: + image_embeds = image_embeds / self.config.decoder_config.embedding_scale + + # Merge image_embeds into text_embeds as PaliGemmaForConditionalGeneration. + image_mask = tokens == self.config.image_token_id + image_mask = image_mask.unsqueeze(-1).expand_as(input_embeds) + input_embeds = input_embeds.masked_scatter(image_mask, image_embeds) + + return self.decoder( + tokens=None, + input_pos=input_pos, + kv_cache=kv_cache, + input_embeds=input_embeds, + ) + + +def get_model_config() -> PaliGemmaConfig: + """Returns the model config for a PaliGemma 3B-224 model. + + Returns: + The model config for a PaliGemma 3B model. + """ + return PaliGemmaConfig( + image_encoder_config=image_encoder.get_image_encoder_config(), + decoder_config=decoder.get_decoder_config(), + image_projection_use_bias=True, + image_token_id=257152, + ) + + +def get_fake_image_encoder_config() -> PaliGemmaConfig: + return PaliGemmaConfig( + image_encoder_config=image_encoder.get_fake_image_encoder_config(), + decoder_config=decoder.get_fake_decoder_config(), + image_projection_use_bias=True, + image_token_id=257152, + ) + + +def build_model(checkpoint_path: str) -> PaliGemma: + config = get_model_config() + model = PaliGemma(config) + # Load the parameters of image encoder. + loader = loading_utils.ModelLoader( + checkpoint_path, image_encoder.TENSOR_NAMES + ) + loader.load(model.image_encoder, strict=False) + # Load the parameters of decoder. + loader = loading_utils.ModelLoader(checkpoint_path, decoder.TENSOR_NAMES) + loader.load(model.decoder, strict=False) + + # Load the parameters of image projection. + loader = loading_utils.ModelLoader(checkpoint_path, None) + state = loader.get_state() + converted_state = dict() + converted_state["weight"] = state.pop(f"{PROJECTION_TENSOR_NAME}.weight") + if config.image_projection_use_bias: + converted_state["bias"] = state.pop(f"{PROJECTION_TENSOR_NAME}.bias") + model.image_projection.load_state_dict(converted_state) + + model.eval() + return model diff --git a/ai_edge_torch/generative/examples/paligemma/verify.py b/ai_edge_torch/generative/examples/paligemma/verify.py new file mode 100644 index 00000000..6ff4acd6 --- /dev/null +++ b/ai_edge_torch/generative/examples/paligemma/verify.py @@ -0,0 +1,134 @@ +# Copyright 2024 The AI Edge Torch Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Verifies the reauthored PaliGemma 3B model.""" + +import logging +import pathlib +from absl import app +from absl import flags +from ai_edge_torch.generative.examples.paligemma import paligemma +from ai_edge_torch.generative.layers import kv_cache +from ai_edge_torch.generative.utilities import verifier +from PIL import Image +import requests +import torch +import transformers + +_IMAGE_URL = flags.DEFINE_string( + "image_url", + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true", + "The image URI to encode.", +) +_PROMPTS = flags.DEFINE_string( + "prompts", + "Caption en", + "The input prompts to generate answers.", +) +_MAX_NEW_TOKENS = flags.DEFINE_integer( + "max_new_tokens", + 30, + "The maximum size of the generated tokens.", +) + + +class ReauthoredPaliGemmaWrapper(verifier.ReauthoredModelWrapper): + """Reauthored PaliGemma model wrapper.""" + + def _init_kv_cache(self): + return kv_cache.KVCache.from_model_config(self.model.config.decoder_config) + + +def main(_): + checkpoint = "google/paligemma-3b-mix-224" + logging.info("Loading the original model from: %s", checkpoint) + original_model = ( + transformers.PaliGemmaForConditionalGeneration.from_pretrained(checkpoint) + ) + + # Locate the cached dir. + cached_config_file = transformers.utils.cached_file( + checkpoint, transformers.utils.CONFIG_NAME + ) + reauthored_checkpoint = pathlib.Path(cached_config_file).parent + logging.info("Building the reauthored model from: %s", reauthored_checkpoint) + reauthored_model = paligemma.build_model(reauthored_checkpoint) + + logging.info("Loading the processor from: %s", checkpoint) + # It works only when GemmaTokenizerFast is available. In some environments, + # use_fast=False doeesn't work either if the tokenizer cannot load the + # sentencepiece model file properly. + processor = transformers.AutoProcessor.from_pretrained(checkpoint) + + logging.info("Loading the image from: %s", _IMAGE_URL.value) + image = Image.open(requests.get(_IMAGE_URL.value, stream=True).raw) + inputs = processor(text=_PROMPTS.value, images=image, return_tensors="pt") + + logging.info("Verifying the reauthored model with model.forward()...") + logging.info("Forwarding the original model...") + outputs_original = original_model.forward( + input_ids=inputs["input_ids"], pixel_values=inputs["pixel_values"] + ) + outputs_original = outputs_original.logits + logging.info("outputs_original: %s", outputs_original) + + logging.info("Forwarding the reauthored model...") + wrapped_reauthored_model = ReauthoredPaliGemmaWrapper(reauthored_model) + outputs_reauthored = wrapped_reauthored_model.forward( + tokens=inputs["input_ids"], + pixel_values=inputs["pixel_values"], + ) + logging.info("outputs_reauthored: %s", outputs_reauthored) + + try: + assert torch.allclose(outputs_original, outputs_reauthored, atol=1e-03) + except AssertionError as e: + logging.error("*** FAILED *** verify with forward()") + raise e + else: + logging.info("*** PASSED *** verify with forward()") + + logging.info("Verifying the reauthored model with model.generate()...") + logging.info("Generating answer with the original model...") + outputs_original = original_model.generate( + **inputs, max_new_tokens=_MAX_NEW_TOKENS.value, do_sample=False + ) + response_original = processor.decode( + outputs_original[0], skip_special_tokens=True + ) + logging.info("outputs_from_original_model: [[%s]]", response_original) + + logging.info("Generating answer with the reauthored model...") + outputs_reauthored = wrapped_reauthored_model.generate( + prompts=inputs["input_ids"], + pixel_values=inputs["pixel_values"], + max_new_tokens=_MAX_NEW_TOKENS.value, + ) + response_reauthored = processor.decode( + outputs_reauthored[0], skip_special_tokens=True + ) + logging.info("outputs from reauthored model: [[%s]]", response_reauthored) + + try: + assert response_original == response_reauthored + except AssertionError as e: + logging.error("*** FAILED *** verify with generate()") + raise e + else: + logging.info("*** PASSED *** verify with generate()") + + +if __name__ == "__main__": + app.run(main) diff --git a/ai_edge_torch/generative/utilities/loader.py b/ai_edge_torch/generative/utilities/loader.py index 064d5fca..bd6699d9 100644 --- a/ai_edge_torch/generative/utilities/loader.py +++ b/ai_edge_torch/generative/utilities/loader.py @@ -131,6 +131,9 @@ def __init__(self, file_name: str, names: TensorNames) -> None: self._names = names self._loader = self._get_loader() + def get_state(self) -> Dict[str, torch.Tensor]: + return self._loader(self._file_name) + def load( self, model: torch.nn.Module, strict: bool = True ) -> Tuple[List[str], List[str]]: @@ -150,7 +153,7 @@ def load( ValueError: If conversion results in unmapped tensors and strict mode is enabled. """ - state = self._loader(self._file_name) + state = self.get_state() state = state["model_state_dict"] if "model_state_dict" in state else state converted_state = dict() if self._names.embedding is not None: diff --git a/ai_edge_torch/generative/utilities/model_builder.py b/ai_edge_torch/generative/utilities/model_builder.py index 9565fad4..463e753a 100644 --- a/ai_edge_torch/generative/utilities/model_builder.py +++ b/ai_edge_torch/generative/utilities/model_builder.py @@ -16,6 +16,7 @@ """Utilities to be used for re-authoring transformer models.""" import copy +from typing import Tuple from ai_edge_torch.generative.layers import attention from ai_edge_torch.generative.layers import builder @@ -98,26 +99,40 @@ def forward( f"Cannot forward sequence of length {seq_len}, max seq length is only" f" {self.config.max_seq_len}" ) - assert len(self.transformer_blocks) == len(kv_cache.caches), ( - "The number of transformer blocks and the number of KV cache entries" - " must be the same." - ) + # token embeddings of shape (b, t, n_embd) + input_embeds = self.tok_embedding(tokens) cos, sin = self.rope_cache - cos = cos.index_select(0, input_pos) - sin = sin.index_select(0, input_pos) + rope = (cos.index_select(0, input_pos), sin.index_select(0, input_pos)) mask = self.mask_cache.index_select(2, input_pos) mask = mask[:, :, :, : self.config.kv_cache_max] - # token embeddings of shape (b, t, n_embd) - x = self.tok_embedding(tokens) + return self.forward_with_embeds( + input_embeds, rope, mask, input_pos, kv_cache + ) + + def forward_with_embeds( + self, + input_embeds: torch.Tensor, + rope: Tuple[torch.Tensor, torch.Tensor], + mask: torch.Tensor, + input_pos: torch.Tensor, + kv_cache: kv_utils.KVCache, + ) -> dict[torch.Tensor, kv_utils.KVCache]: + """Forwards the model with input embeddings.""" + assert len(self.transformer_blocks) == len(kv_cache.caches), ( + "The number of transformer blocks and the number of KV cache entries" + " must be the same." + ) + + x = input_embeds if self.config.embedding_scale is not None: x = x * self.config.embedding_scale updated_kv_entires = [] for i, block in enumerate(self.transformer_blocks): kv_entry = kv_cache.caches[i] if kv_cache else None - x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry) + x, kv_entry = block(x, rope, mask, input_pos, kv_entry) if kv_entry: updated_kv_entires.append(kv_entry) updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires)) diff --git a/ai_edge_torch/generative/utilities/verifier.py b/ai_edge_torch/generative/utilities/verifier.py index 2b7cdc08..189da7bb 100644 --- a/ai_edge_torch/generative/utilities/verifier.py +++ b/ai_edge_torch/generative/utilities/verifier.py @@ -41,7 +41,9 @@ def __init__(self, model: torch.nn.Module): super().__init__() self.model = model - def forward(self, tokens: torch.Tensor) -> torch.Tensor: + def forward( + self, tokens: torch.Tensor, pixel_values: torch.Tensor = None + ) -> torch.Tensor: """Gets output logits by forwarding the input tokens. Args: @@ -54,7 +56,10 @@ def forward(self, tokens: torch.Tensor) -> torch.Tensor: raise NotImplementedError("forward() is not implemented.") def generate( - self, prompts: torch.Tensor, max_new_tokens: int + self, + prompts: torch.Tensor, + max_new_tokens: int, + pixel_values: torch.Tensor = None, ) -> torch.IntTensor: """Returns the response token IDs to the given prompts tensor. @@ -83,35 +88,59 @@ def _init_kv_cache(self): def _forward_with_kv_cache( self, tokens: torch.Tensor, + input_pos: torch.Tensor, kv_cache: kv_utils.KVCache, + pixel_values: torch.Tensor, ) -> tuple[torch.Tensor, kv_utils.KVCache]: """Forwards the model and updates an external KV cache. Args: tokens (torch.Tensor): The input tokens to forward. + input_pos (torch.Tensor): The input positions to forward. kv_cache (KVCache): The KV cache to forward. + pixel_values (torch.Tensor): The input pixel values to forward. Returns: The output logits and the updated KV cache. """ - input_pos = torch.arange(0, tokens.shape[1], dtype=torch.int) - output = self.model.forward(tokens, input_pos, kv_cache) + # Since the reauthored model doesn't include keyword arguments, pass + # pixel_values only when it is not None. Otherwise, it may raise an error. + if pixel_values is None: + output = self.model.forward(tokens, input_pos, kv_cache) + else: + output = self.model.forward( + tokens, input_pos, kv_cache, pixel_values=pixel_values + ) return output["logits"], output["kv_cache"] - def forward(self, tokens: torch.Tensor) -> torch.Tensor: - logits, _ = self._forward_with_kv_cache(tokens, self._init_kv_cache()) + def forward( + self, tokens: torch.Tensor, pixel_values: torch.Tensor = None + ) -> torch.Tensor: + input_pos = torch.arange(0, tokens.shape[1], dtype=torch.int) + logits, _ = self._forward_with_kv_cache( + tokens, input_pos, self._init_kv_cache(), pixel_values + ) return logits def generate( - self, prompts: torch.Tensor, max_new_tokens: int + self, + prompts: torch.Tensor, + max_new_tokens: int, + pixel_values: torch.Tensor = None, ) -> torch.IntTensor: input_ids = prompts[0].int().tolist() + tokens = torch.tensor([input_ids]) + input_pos = torch.arange(0, tokens.shape[1], dtype=torch.int) kv_cache = self._init_kv_cache() for _ in range(max_new_tokens): - tokens = torch.tensor([input_ids]) - logits, kv_cache = self._forward_with_kv_cache(tokens, kv_cache) + logits, kv_cache = self._forward_with_kv_cache( + tokens, input_pos, kv_cache, pixel_values + ) generated_token = logits[0][-1].argmax().item() input_ids.append(generated_token) + tokens = torch.tensor([[generated_token]]) + input_pos = torch.tensor([len(input_ids) - 1]) + pixel_values = None # Pass only for the first time. return torch.tensor([input_ids])