diff --git a/ai_edge_torch/generative/examples/openelm/openelm.py b/ai_edge_torch/generative/examples/openelm/openelm.py index 73ea678e..94d800c2 100644 --- a/ai_edge_torch/generative/examples/openelm/openelm.py +++ b/ai_edge_torch/generative/examples/openelm/openelm.py @@ -15,16 +15,12 @@ """Example of building an OpenELM model.""" -import os -import pathlib - from ai_edge_torch.generative.layers import attention from ai_edge_torch.generative.layers import builder from ai_edge_torch.generative.layers import kv_cache as kv_utils import ai_edge_torch.generative.layers.attention_utils as attn_utils import ai_edge_torch.generative.layers.model_config as cfg import ai_edge_torch.generative.utilities.loader as loading_utils -import numpy as np import torch from torch import nn @@ -210,28 +206,3 @@ def build_model(checkpoint_path: str, **kwargs) -> nn.Module: loader.load(model, strict=False) model.eval() return model - - -def define_and_run(checkpoint_path: str) -> None: - """Instantiates and runs an OpenELM model.""" - - current_dir = pathlib.Path(__file__).parent.resolve() - openelm_goldens = torch.load(current_dir / "openelm_lm_logits.pt") - kv_cache_max_len = 1024 - model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len) - idx = torch.from_numpy(np.array([[1, 2, 3, 4]])) - tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu") - tokens[0, :4] = idx - input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int) - kv = kv_utils.KVCache.from_model_config(model.config) - output = model.forward(tokens, input_pos, kv) - assert torch.allclose( - openelm_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-05 - ) - - -if __name__ == "__main__": - input_checkpoint_path = os.path.join( - pathlib.Path.home(), "Downloads/llm_data/openelm" - ) - define_and_run(input_checkpoint_path) diff --git a/ai_edge_torch/generative/examples/openelm/openelm_lm_logits.pt b/ai_edge_torch/generative/examples/openelm/openelm_lm_logits.pt deleted file mode 100644 index 6957c1ae..00000000 Binary files a/ai_edge_torch/generative/examples/openelm/openelm_lm_logits.pt and /dev/null differ diff --git a/ai_edge_torch/generative/examples/openelm/verify.py b/ai_edge_torch/generative/examples/openelm/verify.py new file mode 100644 index 00000000..c28579fd --- /dev/null +++ b/ai_edge_torch/generative/examples/openelm/verify.py @@ -0,0 +1,61 @@ +# Copyright 2024 The AI Edge Torch Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +"""Verifies the reauthored OpenELM-3B model.""" + +import pathlib + +from absl import app +from absl import flags +from ai_edge_torch.generative.examples.openelm import openelm +from ai_edge_torch.generative.utilities import verifier +import transformers + +_PROMPTS = flags.DEFINE_multi_string( + "prompts", + "What is the meaning of life?", + "The input prompts to generate answers.", +) + + +def main(_): + checkpoint = "apple/OpenELM-3B" + verifier.log_msg("Loading the original model from", checkpoint) + original_model = transformers.AutoModelForCausalLM.from_pretrained( + checkpoint, trust_remote_code=True + ) + + # Locate the cached dir. + cached_config_file = transformers.utils.cached_file( + checkpoint, transformers.utils.CONFIG_NAME + ) + reauthored_checkpoint = pathlib.Path(cached_config_file).parent + verifier.log_msg("Building the reauthored model from", reauthored_checkpoint) + reauthored_model = openelm.build_model(reauthored_checkpoint) + + tokenizer_checkpoint = "meta-llama/Llama-2-7b-hf" + verifier.log_msg("Loading the tokenizer from", tokenizer_checkpoint) + tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_checkpoint) + + verifier.verify_reauthored_model( + original_model=original_model, + reauthored_model=reauthored_model, + tokenizer=tokenizer, + prompts=_PROMPTS.value, + ) + + +if __name__ == "__main__": + app.run(main) diff --git a/ai_edge_torch/generative/examples/phi/phi2.py b/ai_edge_torch/generative/examples/phi/phi2.py index 7879077b..4d753479 100644 --- a/ai_edge_torch/generative/examples/phi/phi2.py +++ b/ai_edge_torch/generative/examples/phi/phi2.py @@ -15,16 +15,12 @@ """Example of building a Phi-2 model.""" -import os -import pathlib - from ai_edge_torch.generative.layers import attention from ai_edge_torch.generative.layers import builder from ai_edge_torch.generative.layers import kv_cache as kv_utils import ai_edge_torch.generative.layers.attention_utils as attn_utils import ai_edge_torch.generative.layers.model_config as cfg import ai_edge_torch.generative.utilities.loader as loading_utils -import numpy as np import torch from torch import nn @@ -143,7 +139,10 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig: intermediate_size=10240, use_bias=True, ) - norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.LAYER_NORM) + norm_config = cfg.NormalizationConfig( + type=cfg.NormalizationType.LAYER_NORM, + use_input_shape=False, # Phi-2 does layer-norm with the weight shape. 
+ ) block_config = cfg.TransformerBlockConfig( attn_config=attn_config, ff_config=ff_config, @@ -182,29 +181,3 @@ def build_model(checkpoint_path: str, **kwargs) -> nn.Module: loader.load(model) model.eval() return model - - -def define_and_run(checkpoint_path: str) -> None: - """Instantiates and runs a Phi-2 model.""" - - current_dir = pathlib.Path(__file__).parent.resolve() - phi2_goldens = torch.load(current_dir / "phi2_lm_logits.pt") - kv_cache_max_len = 1024 - model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len) - idx = torch.from_numpy(np.array([[1, 2, 3, 4]])) - tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu") - tokens[0, :4] = idx - input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int) - kv = kv_utils.KVCache.from_model_config(model.config) - output = model.forward(tokens, input_pos, kv) - print("comparing with goldens..") - assert torch.allclose( - phi2_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-02 - ) - - -if __name__ == "__main__": - input_checkpoint_path = os.path.join( - pathlib.Path.home(), "Downloads/llm_data/phi2" - ) - define_and_run(input_checkpoint_path) diff --git a/ai_edge_torch/generative/examples/phi/phi2_lm_logits.pt b/ai_edge_torch/generative/examples/phi/phi2_lm_logits.pt deleted file mode 100644 index d45e8b76..00000000 Binary files a/ai_edge_torch/generative/examples/phi/phi2_lm_logits.pt and /dev/null differ diff --git a/ai_edge_torch/generative/examples/phi/verify.py b/ai_edge_torch/generative/examples/phi/verify.py new file mode 100644 index 00000000..ddecf75e --- /dev/null +++ b/ai_edge_torch/generative/examples/phi/verify.py @@ -0,0 +1,53 @@ +# Copyright 2024 The AI Edge Torch Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +"""Verifies the reauthored Phi-2 model.""" + +from absl import app +from absl import flags +from ai_edge_torch.generative.examples.phi import phi2 +from ai_edge_torch.generative.utilities import verifier +import kagglehub +import transformers + +_PROMPTS = flags.DEFINE_multi_string( + "prompts", + "What is the meaning of life?", + "The input prompts to generate answers.", +) + + +def main(_): + checkpoint = kagglehub.model_download("Microsoft/phi/transformers/2") + verifier.log_msg("Loading the original model from", checkpoint) + original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint) + + verifier.log_msg("Building the reauthored model from", checkpoint) + reauthored_model = phi2.build_model(checkpoint) + + verifier.log_msg("Loading the tokenizer from", checkpoint) + tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint) + + verifier.verify_reauthored_model( + original_model=original_model, + reauthored_model=reauthored_model, + tokenizer=tokenizer, + prompts=_PROMPTS.value, + atol=1e-03, + ) + + +if __name__ == "__main__": + app.run(main) diff --git a/ai_edge_torch/generative/examples/smollm/smollm.py b/ai_edge_torch/generative/examples/smollm/smollm.py index 1f4d1cbd..7f6ac53e 100644 --- a/ai_edge_torch/generative/examples/smollm/smollm.py +++ b/ai_edge_torch/generative/examples/smollm/smollm.py @@ -16,15 +16,10 @@ """Example of building a SmolLM model.""" import copy -import os -import pathlib from ai_edge_torch.generative.examples.tiny_llama import tiny_llama -from ai_edge_torch.generative.layers import kv_cache as kv_utils import ai_edge_torch.generative.layers.model_config as cfg import ai_edge_torch.generative.utilities.loader as loading_utils -import numpy as np -import torch from torch import nn TENSOR_NAMES = copy.copy(tiny_llama.TENSOR_NAMES) @@ -104,28 +99,3 @@ def build_model(checkpoint_path: str, **kwargs) -> nn.Module: loader.load(model, strict=False) model.eval() return model - - -def define_and_run(checkpoint_path: str) -> None: - """Instantiates and runs a SmolLM model.""" - - current_dir = pathlib.Path(__file__).parent.resolve() - smollm_goldens = torch.load(current_dir / "smollm_lm_logits.pt") - kv_cache_max_len = 1024 - model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len) - idx = torch.from_numpy(np.array([[1, 2, 3, 4]])) - tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu") - tokens[0, :4] = idx - input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int) - kv = kv_utils.KVCache.from_model_config(model.config) - output = model.forward(tokens, input_pos, kv) - assert torch.allclose( - smollm_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-05 - ) - - -if __name__ == "__main__": - input_checkpoint_path = os.path.join( - pathlib.Path.home(), "Downloads/llm_data/smollm" - ) - define_and_run(input_checkpoint_path) diff --git a/ai_edge_torch/generative/examples/smollm/smollm_lm_logits.pt b/ai_edge_torch/generative/examples/smollm/smollm_lm_logits.pt deleted file mode 100644 index 74eeda5a..00000000 Binary files a/ai_edge_torch/generative/examples/smollm/smollm_lm_logits.pt and /dev/null differ diff --git a/ai_edge_torch/generative/examples/smollm/verify.py b/ai_edge_torch/generative/examples/smollm/verify.py new file mode 100644 index 00000000..7a578841 --- /dev/null +++ b/ai_edge_torch/generative/examples/smollm/verify.py @@ -0,0 +1,59 @@ +# Copyright 2024 The AI Edge Torch Authors. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Verifies the reauthored SmolLM-135M model.""" + +import pathlib + +from absl import app +from absl import flags +from ai_edge_torch.generative.examples.smollm import smollm +from ai_edge_torch.generative.utilities import verifier +import transformers + +_PROMPTS = flags.DEFINE_multi_string( + "prompts", + "What is the meaning of life?", + "The input prompts to generate answers.", +) + + +def main(_): + checkpoint = "HuggingFaceTB/SmolLM-135M" + verifier.log_msg("Loading the original model from", checkpoint) + original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint) + + # Locate the cached dir. + cached_config_file = transformers.utils.cached_file( + checkpoint, transformers.utils.CONFIG_NAME + ) + reauthored_checkpoint = pathlib.Path(cached_config_file).parent + verifier.log_msg("Building the reauthored model from", reauthored_checkpoint) + reauthored_model = smollm.build_model(reauthored_checkpoint) + + verifier.log_msg("Loading the tokenizer from", checkpoint) + tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint) + + verifier.verify_reauthored_model( + original_model=original_model, + reauthored_model=reauthored_model, + tokenizer=tokenizer, + prompts=_PROMPTS.value, + atol=1e-04, + ) + + +if __name__ == "__main__": + app.run(main) diff --git a/ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py b/ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py index 53ce4a15..f8368ee4 100644 --- a/ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +++ b/ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py @@ -15,16 +15,12 @@ """Example of building a TinyLlama model.""" -import os -import pathlib - from ai_edge_torch.generative.layers import attention from ai_edge_torch.generative.layers import builder from ai_edge_torch.generative.layers import kv_cache as kv_utils import ai_edge_torch.generative.layers.attention_utils as attn_utils import ai_edge_torch.generative.layers.model_config as cfg import ai_edge_torch.generative.utilities.loader as loading_utils -import numpy as np import torch from torch import nn @@ -179,28 +175,3 @@ def build_model(checkpoint_path: str, **kwargs) -> nn.Module: loader.load(model) model.eval() return model - - -def define_and_run(checkpoint_path: str) -> None: - """Instantiates and runs a TinyLlama model.""" - - current_dir = pathlib.Path(__file__).parent.resolve() - tiny_llama_goldens = torch.load(current_dir / "tiny_llama_lm_logits.pt") - kv_cache_max_len = 1024 - model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len) - idx = torch.from_numpy(np.array([[1, 2, 3, 4]])) - tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu") - tokens[0, :4] = idx - input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int) - kv = kv_utils.KVCache.from_model_config(model.config) - output = model.forward(tokens, input_pos, kv) - assert torch.allclose( - 
tiny_llama_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-02 - ) - - -if __name__ == "__main__": - input_checkpoint_path = os.path.join( - pathlib.Path.home(), "Downloads/llm_data/tiny_llama" - ) - define_and_run(input_checkpoint_path) diff --git a/ai_edge_torch/generative/examples/tiny_llama/tiny_llama_lm_logits.pt b/ai_edge_torch/generative/examples/tiny_llama/tiny_llama_lm_logits.pt deleted file mode 100644 index c92a7861..00000000 Binary files a/ai_edge_torch/generative/examples/tiny_llama/tiny_llama_lm_logits.pt and /dev/null differ diff --git a/ai_edge_torch/generative/examples/tiny_llama/verify.py b/ai_edge_torch/generative/examples/tiny_llama/verify.py new file mode 100644 index 00000000..990ce9d6 --- /dev/null +++ b/ai_edge_torch/generative/examples/tiny_llama/verify.py @@ -0,0 +1,61 @@ +# Copyright 2024 The AI Edge Torch Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Verifies the reauthored TinyLlama-1.1B model.""" + +import pathlib + +from absl import app +from absl import flags +from ai_edge_torch.generative.examples.tiny_llama import tiny_llama +from ai_edge_torch.generative.utilities import verifier +import transformers + +_PROMPTS = flags.DEFINE_multi_string( + "prompts", + "Show me the program to add 2 and 3.", + "The input prompts to generate answers.", +) + + +def main(_): + checkpoint = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" + verifier.log_msg("Loading the original model from", checkpoint) + original_model = transformers.AutoModelForCausalLM.from_pretrained( + checkpoint, trust_remote_code=True + ) + + # Locate the cached dir. 
+ cached_config_file = transformers.utils.cached_file( + checkpoint, transformers.utils.CONFIG_NAME + ) + reauthored_checkpoint = pathlib.Path(cached_config_file).parent + verifier.log_msg("Building the reauthored model from", reauthored_checkpoint) + reauthored_model = tiny_llama.build_model(reauthored_checkpoint) + + verifier.log_msg("Loading the tokenizer from", checkpoint) + tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint) + + verifier.verify_reauthored_model( + original_model=original_model, + reauthored_model=reauthored_model, + tokenizer=tokenizer, + prompts=_PROMPTS.value, + atol=1e-04, + ) + + +if __name__ == "__main__": + app.run(main) diff --git a/ai_edge_torch/generative/layers/builder.py b/ai_edge_torch/generative/layers/builder.py index 28ff402b..10d64c7b 100644 --- a/ai_edge_torch/generative/layers/builder.py +++ b/ai_edge_torch/generative/layers/builder.py @@ -75,7 +75,9 @@ def build_norm(dim: int, config: cfg.NormalizationConfig): zero_centered_gamma=config.zero_centered, ) elif config.type == cfg.NormalizationType.LAYER_NORM: - return normalization.LayerNorm(dim, config.epsilon, config.enable_hlfb) + return normalization.LayerNorm( + dim, config.epsilon, config.enable_hlfb, config.use_input_shape + ) elif config.type == cfg.NormalizationType.GROUP_NORM: return normalization.GroupNorm( config.group_num, dim, config.epsilon, config.enable_hlfb diff --git a/ai_edge_torch/generative/layers/model_config.py b/ai_edge_torch/generative/layers/model_config.py index 07bb8da2..ab9a7aa1 100644 --- a/ai_edge_torch/generative/layers/model_config.py +++ b/ai_edge_torch/generative/layers/model_config.py @@ -69,6 +69,9 @@ class NormalizationConfig: zero_centered: bool = False # Number of groups used in group normalization. group_num: Optional[float] = None + # Whether to use the input shape to determine the dimension of normalization + # when type is LAYER_NORM. + use_input_shape: bool = True @dataclass diff --git a/ai_edge_torch/generative/layers/normalization.py b/ai_edge_torch/generative/layers/normalization.py index 3a4c90cc..a506135b 100644 --- a/ai_edge_torch/generative/layers/normalization.py +++ b/ai_edge_torch/generative/layers/normalization.py @@ -78,7 +78,7 @@ def __init__( group_num (int): Number of groups to separate the channels into. dim (int): Dimension of the input tensor. eps (float): A small float value to ensure numerical stability (default: - 1e-6). + 1e-5). enable_hlfb (bool): Whether to convert this normalization into a single op. """ @@ -112,7 +112,13 @@ def forward(self, x): class LayerNorm(torch.nn.Module): - def __init__(self, dim: int, eps: float = 1e-5, enable_hlfb: bool = False): + def __init__( + self, + dim: int, + eps: float = 1e-5, + enable_hlfb: bool = False, + use_input_shape: bool = True, + ): """Initialize the LayerNorm layer. Args: @@ -121,9 +127,12 @@ def __init__(self, dim: int, eps: float = 1e-5, enable_hlfb: bool = False): 1e-6). enable_hlfb (bool): Whether to convert this normalization into a single op. + use_input_shape (bool): Whether to use the input shape to determine the + dimension of normalization (default: True). 
""" super().__init__() self.enable_hlfb = enable_hlfb + self.use_input_shape = use_input_shape self.eps = eps self.weight = torch.nn.Parameter(torch.ones(dim)) self.bias = torch.nn.Parameter(torch.ones(dim)) @@ -139,19 +148,18 @@ def forward(self, x): """ if self.enable_hlfb: return layer_norm_with_hlfb( - x, - self.weight, - self.bias, - self.eps, + x, self.weight, self.bias, self.eps, self.use_input_shape ) + + if self.use_input_shape: + normalized_shape = x.shape + weight = self.weight.broadcast_to(x.shape) + bias = self.bias.broadcast_to(x.shape) else: - return F.layer_norm( - x, - x.shape, - self.weight.broadcast_to(x.shape), - self.bias.broadcast_to(x.shape), - self.eps, - ) + normalized_shape = self.weight.shape + weight = self.weight + bias = self.bias + return F.layer_norm(x, normalized_shape, weight, bias, self.eps) def group_norm_with_hlfb( @@ -193,6 +201,7 @@ def layer_norm_with_hlfb( w: torch.Tensor, b: torch.Tensor, eps: float, + use_input_shape: bool, ): """Layer Normalization with high-level function boundary enabled. @@ -201,18 +210,20 @@ def layer_norm_with_hlfb( w (torch.Tensor): The weight tensor for the normalization. b (torch.Tensor): The bias tensor for the normalization. eps (float): A small float value to ensure numerical stability. + use_input_shape (bool): Whether to use the input shape to determine the + dimension of normalization. Returns: The output tensor of Layer Normalization. """ builder = StableHLOCompositeBuilder(name="odml.layer_norm", attr={"eps": eps}) x, w, b = builder.mark_inputs(x, w, b) - y = F.layer_norm( - x, - x.shape, - weight=w.broadcast_to(x.shape), - bias=b.broadcast_to(x.shape), - eps=eps, - ) + if use_input_shape: + normalized_shape = x.shape + w = w.broadcast_to(x.shape) + b = b.broadcast_to(x.shape) + else: + normalized_shape = w.shape + y = F.layer_norm(x, normalized_shape, w, b, eps=eps) y = builder.mark_outputs(y) return y diff --git a/ai_edge_torch/generative/utilities/verifier.py b/ai_edge_torch/generative/utilities/verifier.py new file mode 100644 index 00000000..66f62603 --- /dev/null +++ b/ai_edge_torch/generative/utilities/verifier.py @@ -0,0 +1,200 @@ +# Copyright 2024 The AI Edge Torch Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Common utility functions to verify the reauthored models.""" + +import datetime +from typing import List + +from ai_edge_torch.generative.layers import kv_cache as kv_utils +import numpy as np +import torch + + +def log_msg(*args): + print("[%s]" % datetime.datetime.now(), *args) + + +def forward( + model: torch.nn.Module, + tokens: torch.Tensor, + kv_cache: kv_utils.KVCache, +) -> tuple[torch.Tensor, kv_utils.KVCache]: + """Forwards the model reauthored with ai_edge_torch Generative API. + + Args: + model (torch.nn.Module): The model to forward. It should be a model built + with ai_edge_torch Generative API. + tokens (torch.Tensor): The input tokens to forward. 
+    kv_cache (KVCache): The KV cache to forward.
+
+  Returns:
+    The output logits and the updated KV cache.
+  """
+  input_pos = torch.arange(0, tokens.shape[1], dtype=torch.int)
+  output = model.forward(tokens, input_pos, kv_cache)
+  return output["logits"], output["kv_cache"]
+
+
+def generate(
+    model: torch.nn.Module, prompts: torch.Tensor, response_len: int
+) -> torch.Tensor:
+  """Generates the response to the prompts.
+
+  It appends tokens output by the model to the prompts and feeds them back to
+  the model up to response_len.
+
+  Args:
+    model (torch.nn.Module): The model to generate. It should be a model built
+      with ai_edge_torch Generative API.
+    prompts (torch.Tensor): The prompts to generate.
+    response_len (int): The number of tokens to generate.
+
+  Returns:
+    The generated tokens.
+  """
+  input_ids = prompts[0].int().tolist()
+  kv_cache = kv_utils.KVCache.from_model_config(model.config)
+  for _ in range(response_len - len(input_ids)):
+    logits, kv_cache = forward(model, torch.tensor([input_ids]), kv_cache)
+    generated_token = logits[0][-1].argmax().item()
+    input_ids.append(generated_token)
+  return torch.tensor([input_ids])
+
+
+def verify_with_input_ids(
+    original_model: torch.nn.Module,
+    reauthored_model: torch.nn.Module,
+    input_ids: torch.Tensor = torch.from_numpy(np.array([[1, 2, 3, 4]])).int(),
+    kv_cache_max_len: int = 1024,
+    rtol: float = 1e-05,
+    atol: float = 1e-05,
+) -> bool:
+  """Verifies if the reauthored model generates the same output as the original.
+
+  It compares only one output from the original and the reauthored model.
+
+  Args:
+    original_model (torch.nn.Module): The original model.
+    reauthored_model (torch.nn.Module): The model reauthored with ai_edge_torch
+      Generative API.
+    input_ids (torch.Tensor): The input token IDs to forward.
+    kv_cache_max_len (int): The maximum sequence length of the KV cache.
+    rtol (float): The relative tolerance for the comparison.
+    atol (float): The absolute tolerance for the comparison.
+
+  Returns:
+    True if the reauthored model generates the same output as the original.
+  """
+  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
+  input_ids_len = input_ids.shape[1]
+  tokens[0, :input_ids_len] = input_ids
+
+  log_msg("Forwarding the original model...")
+  outputs_original = original_model.forward(tokens)
+  logits_original = outputs_original.logits[0, input_ids_len - 1, :]
+  log_msg("logits_original: ", logits_original)
+
+  log_msg("Forwarding the reauthored model...")
+  kv_cache = kv_utils.KVCache.from_model_config(reauthored_model.config)
+  outputs_reauthored = forward(reauthored_model, tokens, kv_cache)
+  logits_reauthored = outputs_reauthored[0][0, input_ids_len - 1, :]
+  log_msg("logits_reauthored:", logits_reauthored)
+
+  return torch.allclose(
+      logits_original, logits_reauthored, rtol=rtol, atol=atol
+  )
+
+
+def verify_model_with_prompts(
+    original_model: torch.nn.Module,
+    reauthored_model: torch.nn.Module,
+    tokenizer: torch.nn.Module,
+    prompts: str,
+) -> bool:
+  """Verifies if the reauthored model generates the same answer as the original.
+
+  It compares an answer, i.e. multiple consecutive outputs generated by the
+  original and the reauthored model.
+
+  Args:
+    original_model (torch.nn.Module): The original model.
+    reauthored_model (torch.nn.Module): The model reauthored with ai_edge_torch
+      Generative API.
+    tokenizer (torch.nn.Module): The tokenizer.
+    prompts (str): The input prompts to generate answers.
+
+  Returns:
+    True if the reauthored model generates the same answer as the original.
+  """
+  prompt_tokens = tokenizer.encode(prompts, return_tensors="pt")
+
+  log_msg("Generating answer with the original model...")
+  outputs_original = original_model.generate(prompt_tokens)
+  response_original = tokenizer.decode(outputs_original[0])
+  log_msg("outputs_from_original_model: [[", response_original, "]]")
+
+  log_msg("Generating answer with the reauthored model...")
+  generate_len = len(outputs_original[0])
+  outputs_reauthored = generate(reauthored_model, prompt_tokens, generate_len)
+  response_reauthored = tokenizer.decode(outputs_reauthored[0])
+  log_msg("outputs from reauthored model: [[", response_reauthored, "]]")
+
+  return response_original == response_reauthored
+
+
+def verify_reauthored_model(
+    original_model: torch.nn.Module,
+    reauthored_model: torch.nn.Module,
+    tokenizer: torch.nn.Module,
+    prompts: List[str],
+    rtol: float = 1e-05,
+    atol: float = 1e-05,
+):
+  """Verifies the reauthored model against the original model.
+
+  It verifies the reauthored model with two methods:
+  1. It compares the output of the original and the reauthored model with an
+     arbitrary input.
+  2. It compares the answer generated by the original and the reauthored model
+     with each prompt.
+
+  It prints out "PASS" or "FAILED" to the console.
+
+  Args:
+    original_model (torch.nn.Module): The original model.
+    reauthored_model (torch.nn.Module): The model reauthored with ai_edge_torch
+      Generative API.
+    tokenizer (torch.nn.Module): The tokenizer.
+    prompts (List[str]): List of the input prompts to generate answers.
+    rtol (float): The relative tolerance for the comparison.
+    atol (float): The absolute tolerance for the comparison.
+  """
+  log_msg("Verifying the reauthored model with an arbitrary input...")
+  if verify_with_input_ids(
+      original_model, reauthored_model, rtol=rtol, atol=atol
+  ):
+    log_msg("PASS")
+  else:
+    log_msg("FAILED")
+
+  for p in prompts:
+    log_msg("Verifying the reauthored model with prompts:", p)
+    if verify_model_with_prompts(
+        original_model, reauthored_model, tokenizer, p
+    ):
+      log_msg("PASS")
+    else:
+      log_msg("FAILED")
diff --git a/requirements.txt b/requirements.txt
index 9e38042b..4a73db07 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -11,4 +11,6 @@ ai-edge-quantizer-nightly
 scipy
 numpy
 tabulate
-safetensors
\ No newline at end of file
+safetensors
+kagglehub
+transformers
\ No newline at end of file
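
Usage note (not part of the diff): each per-model verify.py added above is an absl script and can be run directly, e.g. `python ai_edge_torch/generative/examples/smollm/verify.py --prompts="What is the meaning of life?"`. A minimal sketch of driving the new verifier utilities programmatically, mirroring the added smollm/verify.py, might look like the following; the checkpoint-resolution step is the same cached-file lookup used in the diff, and everything else is taken from the added code.

# Minimal usage sketch, mirroring smollm/verify.py: load the original HF model,
# build the reauthored model from the same cached checkpoint dir, then compare.
import pathlib

import transformers

from ai_edge_torch.generative.examples.smollm import smollm
from ai_edge_torch.generative.utilities import verifier

checkpoint = "HuggingFaceTB/SmolLM-135M"
original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)

# Reuse the transformers cache dir as the checkpoint path for the reauthored model.
cached_config_file = transformers.utils.cached_file(
    checkpoint, transformers.utils.CONFIG_NAME
)
reauthored_model = smollm.build_model(pathlib.Path(cached_config_file).parent)

verifier.verify_reauthored_model(
    original_model=original_model,
    reauthored_model=reauthored_model,
    tokenizer=tokenizer,
    prompts=["What is the meaning of life?"],
    atol=1e-04,
)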