Verify some examples re-authored from the transformers package.
- Compare outputs from the re-authored model against those from the original model on the fly (see the sketch below):
  * logits from an arbitrary input token
  * answer from a prompt string
- Add a flag to change the normalized shape used by layer normalization
- Update re-authored phi-2 model
- Use absl.app to run verify.py files
- Remove unnecessary lm_logits.pt files

PiperOrigin-RevId: 676022650
ai-edge-bot authored and copybara-github committed Sep 18, 2024
1 parent 8593517 commit 586e9ff
Showing 17 changed files with 478 additions and 141 deletions.
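
The comparisons described in the commit message are driven by verifier.verify_reauthored_model(), whose implementation is not part of this diff. Below is a minimal sketch of the two checks, assembled from the new verify.py scripts and the removed define_and_run() helpers further down; the checkpoint, kv_cache_max_len, and tolerance are illustrative values, not the verifier's actual code.

# A minimal sketch, not verifier.verify_reauthored_model() itself; checkpoint,
# kv_cache_max_len, and atol are illustrative values taken from the diffs below.
import pathlib

import torch
import transformers
from ai_edge_torch.generative.examples.smollm import smollm
from ai_edge_torch.generative.layers import kv_cache as kv_utils

checkpoint = "HuggingFaceTB/SmolLM-135M"
original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)

# Build the re-authored model from the locally cached Hugging Face snapshot.
config_file = transformers.utils.cached_file(checkpoint, transformers.utils.CONFIG_NAME)
kv_cache_max_len = 1024
reauthored_model = smollm.build_model(
    str(pathlib.Path(config_file).parent), kv_cache_max_len=kv_cache_max_len
)

# 1) Logits from an arbitrary input token sequence, compared at the last token.
idx = torch.tensor([[1, 2, 3, 4]])
golden_logits = original_model(idx).logits[0, -1, :]
tokens = torch.zeros((1, kv_cache_max_len), dtype=torch.int)
tokens[0, :4] = idx
input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
kv = kv_utils.KVCache.from_model_config(reauthored_model.config)
output = reauthored_model.forward(tokens, input_pos, kv)
assert torch.allclose(golden_logits, output["logits"][0, 3, :], atol=1e-04)

# 2) Answer from a prompt string: greedily decode with the original model and
# check the re-authored model reproduces it (the decode loop is omitted here).
prompt_ids = tokenizer("What is the meaning of life?", return_tensors="pt").input_ids
golden_answer = original_model.generate(prompt_ids, max_new_tokens=30, do_sample=False)
print(tokenizer.decode(golden_answer[0], skip_special_tokens=True))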
29 changes: 0 additions & 29 deletions ai_edge_torch/generative/examples/openelm/openelm.py
@@ -15,16 +15,12 @@

"""Example of building an OpenELM model."""

import os
import pathlib

from ai_edge_torch.generative.layers import attention
from ai_edge_torch.generative.layers import builder
from ai_edge_torch.generative.layers import kv_cache as kv_utils
import ai_edge_torch.generative.layers.attention_utils as attn_utils
import ai_edge_torch.generative.layers.model_config as cfg
import ai_edge_torch.generative.utilities.loader as loading_utils
import numpy as np
import torch
from torch import nn

@@ -210,28 +206,3 @@ def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
loader.load(model, strict=False)
model.eval()
return model


def define_and_run(checkpoint_path: str) -> None:
"""Instantiates and runs an OpenELM model."""

current_dir = pathlib.Path(__file__).parent.resolve()
openelm_goldens = torch.load(current_dir / "openelm_lm_logits.pt")
kv_cache_max_len = 1024
model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
tokens[0, :4] = idx
input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
kv = kv_utils.KVCache.from_model_config(model.config)
output = model.forward(tokens, input_pos, kv)
assert torch.allclose(
openelm_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-05
)


if __name__ == "__main__":
input_checkpoint_path = os.path.join(
pathlib.Path.home(), "Downloads/llm_data/openelm"
)
define_and_run(input_checkpoint_path)
Binary file not shown.
61 changes: 61 additions & 0 deletions ai_edge_torch/generative/examples/openelm/verify.py
@@ -0,0 +1,61 @@
# Copyright 2024 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Verifies the reauthored OpenELM-3B model."""

import pathlib

from absl import app
from absl import flags
from ai_edge_torch.generative.examples.openelm import openelm
from ai_edge_torch.generative.utilities import verifier
import transformers

_PROMPTS = flags.DEFINE_multi_string(
"prompts",
"What is the meaning of life?",
"The input prompts to generate answers.",
)


def main(_):
checkpoint = "apple/OpenELM-3B"
verifier.log_msg("Loading the original model from", checkpoint)
original_model = transformers.AutoModelForCausalLM.from_pretrained(
checkpoint, trust_remote_code=True
)

# Locate the cached dir.
cached_config_file = transformers.utils.cached_file(
checkpoint, transformers.utils.CONFIG_NAME
)
reauthored_checkpoint = pathlib.Path(cached_config_file).parent
verifier.log_msg("Building the reauthored model from", reauthored_checkpoint)
reauthored_model = openelm.build_model(reauthored_checkpoint)

tokenizer_checkpoint = "meta-llama/Llama-2-7b-hf"
verifier.log_msg("Loading the tokenizer from", tokenizer_checkpoint)
tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_checkpoint)

verifier.verify_reauthored_model(
original_model=original_model,
reauthored_model=reauthored_model,
tokenizer=tokenizer,
prompts=_PROMPTS.value,
)


if __name__ == "__main__":
app.run(main)
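
A note on the "Locate the cached dir." step above: transformers.utils.cached_file() resolves (and downloads if needed) config.json inside the local Hugging Face cache, and the parent directory of that file is the snapshot folder that also holds the weight files, which is the local path openelm.build_model() needs. A minimal sketch, with an illustrative cache path:

# Sketch of the cached-directory lookup used by the verify scripts; the printed
# path below is illustrative and depends on the local Hugging Face cache layout.
import pathlib
import transformers

checkpoint = "apple/OpenELM-3B"
cached_config_file = transformers.utils.cached_file(
    checkpoint, transformers.utils.CONFIG_NAME
)
snapshot_dir = pathlib.Path(cached_config_file).parent
# Typically something like:
#   ~/.cache/huggingface/hub/models--apple--OpenELM-3B/snapshots/<revision>/
print(snapshot_dir)

The tokenizer is loaded from a separate checkpoint because OpenELM uses the Llama 2 tokenizer. Also, since _PROMPTS is a multi-string flag, --prompts can be repeated on the command line to verify several prompts in one run.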
35 changes: 4 additions & 31 deletions ai_edge_torch/generative/examples/phi/phi2.py
@@ -15,16 +15,12 @@

"""Example of building a Phi-2 model."""

import os
import pathlib

from ai_edge_torch.generative.layers import attention
from ai_edge_torch.generative.layers import builder
from ai_edge_torch.generative.layers import kv_cache as kv_utils
import ai_edge_torch.generative.layers.attention_utils as attn_utils
import ai_edge_torch.generative.layers.model_config as cfg
import ai_edge_torch.generative.utilities.loader as loading_utils
import numpy as np
import torch
from torch import nn

@@ -143,7 +139,10 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
intermediate_size=10240,
use_bias=True,
)
norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.LAYER_NORM)
norm_config = cfg.NormalizationConfig(
type=cfg.NormalizationType.LAYER_NORM,
use_input_shape=False, # Phi-2 does layer-norm with the weight shape.
)
block_config = cfg.TransformerBlockConfig(
attn_config=attn_config,
ff_config=ff_config,
@@ -182,29 +181,3 @@ def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
loader.load(model)
model.eval()
return model


def define_and_run(checkpoint_path: str) -> None:
"""Instantiates and runs a Phi-2 model."""

current_dir = pathlib.Path(__file__).parent.resolve()
phi2_goldens = torch.load(current_dir / "phi2_lm_logits.pt")
kv_cache_max_len = 1024
model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
tokens[0, :4] = idx
input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
kv = kv_utils.KVCache.from_model_config(model.config)
output = model.forward(tokens, input_pos, kv)
print("comparing with goldens..")
assert torch.allclose(
phi2_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-02
)


if __name__ == "__main__":
input_checkpoint_path = os.path.join(
pathlib.Path.home(), "Downloads/llm_data/phi2"
)
define_and_run(input_checkpoint_path)
Binary file not shown.
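
The use_input_shape=False argument added above is the "flag to change the normalized shape" from the commit message. Its exact semantics are defined by ai_edge_torch's normalization layers, which are not part of this diff; the sketch below only illustrates the distinction the inline comment points at, assuming that normalizing with the weight shape means standard per-token LayerNorm over the embedding dimension, while the input-shape variant computes statistics over the entire input.

# Illustrative sketch only; tensor sizes are made up and the real behavior is
# defined by ai_edge_torch's layer implementations, not by this snippet.
import torch
import torch.nn.functional as F

x = torch.randn(1, 4, 2560)        # (batch, sequence, embedding)
weight = torch.full((2560,), 1.5)
bias = torch.zeros(2560)

# use_input_shape=False: normalize over the weight's shape, i.e. per token over
# the embedding dimension (standard LayerNorm, what Phi-2 expects).
per_token = F.layer_norm(x, weight.shape, weight, bias, eps=1e-5)

# use_input_shape=True: normalize over the whole input shape, so every element
# shares one mean and variance; weight and bias are broadcast to match.
whole_input = F.layer_norm(
    x, x.shape, weight.broadcast_to(x.shape), bias.broadcast_to(x.shape), eps=1e-5
)

print(torch.allclose(per_token, whole_input))  # generally False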
53 changes: 53 additions & 0 deletions ai_edge_torch/generative/examples/phi/verify.py
@@ -0,0 +1,53 @@
# Copyright 2024 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Verifies the reauthored Phi-2 model."""

from absl import app
from absl import flags
from ai_edge_torch.generative.examples.phi import phi2
from ai_edge_torch.generative.utilities import verifier
import kagglehub
import transformers

_PROMPTS = flags.DEFINE_multi_string(
"prompts",
"What is the meaning of life?",
"The input prompts to generate answers.",
)


def main(_):
checkpoint = kagglehub.model_download("Microsoft/phi/transformers/2")
verifier.log_msg("Loading the original model from", checkpoint)
original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)

verifier.log_msg("Building the reauthored model from", checkpoint)
reauthored_model = phi2.build_model(checkpoint)

verifier.log_msg("Loading the tokenizer from", checkpoint)
tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)

verifier.verify_reauthored_model(
original_model=original_model,
reauthored_model=reauthored_model,
tokenizer=tokenizer,
prompts=_PROMPTS.value,
atol=1e-03,
)


if __name__ == "__main__":
app.run(main)
30 changes: 0 additions & 30 deletions ai_edge_torch/generative/examples/smollm/smollm.py
@@ -16,15 +16,10 @@
"""Example of building a SmolLM model."""

import copy
import os
import pathlib

from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
from ai_edge_torch.generative.layers import kv_cache as kv_utils
import ai_edge_torch.generative.layers.model_config as cfg
import ai_edge_torch.generative.utilities.loader as loading_utils
import numpy as np
import torch
from torch import nn

TENSOR_NAMES = copy.copy(tiny_llama.TENSOR_NAMES)
@@ -104,28 +99,3 @@ def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
loader.load(model, strict=False)
model.eval()
return model


def define_and_run(checkpoint_path: str) -> None:
"""Instantiates and runs a SmolLM model."""

current_dir = pathlib.Path(__file__).parent.resolve()
smollm_goldens = torch.load(current_dir / "smollm_lm_logits.pt")
kv_cache_max_len = 1024
model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
tokens[0, :4] = idx
input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
kv = kv_utils.KVCache.from_model_config(model.config)
output = model.forward(tokens, input_pos, kv)
assert torch.allclose(
smollm_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-05
)


if __name__ == "__main__":
input_checkpoint_path = os.path.join(
pathlib.Path.home(), "Downloads/llm_data/smollm"
)
define_and_run(input_checkpoint_path)
Binary file not shown.
59 changes: 59 additions & 0 deletions ai_edge_torch/generative/examples/smollm/verify.py
@@ -0,0 +1,59 @@
# Copyright 2024 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Verifies the reauthored SmolLM-135M model."""

import pathlib

from absl import app
from absl import flags
from ai_edge_torch.generative.examples.smollm import smollm
from ai_edge_torch.generative.utilities import verifier
import transformers

_PROMPTS = flags.DEFINE_multi_string(
"prompts",
"What is the meaning of life?",
"The input prompts to generate answers.",
)


def main(_):
checkpoint = "HuggingFaceTB/SmolLM-135M"
verifier.log_msg("Loading the original model from", checkpoint)
original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)

# Locate the cached dir.
cached_config_file = transformers.utils.cached_file(
checkpoint, transformers.utils.CONFIG_NAME
)
reauthored_checkpoint = pathlib.Path(cached_config_file).parent
verifier.log_msg("Building the reauthored model from", reauthored_checkpoint)
reauthored_model = smollm.build_model(reauthored_checkpoint)

verifier.log_msg("Loading the tokenizer from", checkpoint)
tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)

verifier.verify_reauthored_model(
original_model=original_model,
reauthored_model=reauthored_model,
tokenizer=tokenizer,
prompts=_PROMPTS.value,
atol=1e-04,
)


if __name__ == "__main__":
app.run(main)
29 changes: 0 additions & 29 deletions ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py
@@ -15,16 +15,12 @@

"""Example of building a TinyLlama model."""

import os
import pathlib

from ai_edge_torch.generative.layers import attention
from ai_edge_torch.generative.layers import builder
from ai_edge_torch.generative.layers import kv_cache as kv_utils
import ai_edge_torch.generative.layers.attention_utils as attn_utils
import ai_edge_torch.generative.layers.model_config as cfg
import ai_edge_torch.generative.utilities.loader as loading_utils
import numpy as np
import torch
from torch import nn

@@ -179,28 +175,3 @@ def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
loader.load(model)
model.eval()
return model


def define_and_run(checkpoint_path: str) -> None:
"""Instantiates and runs a TinyLlama model."""

current_dir = pathlib.Path(__file__).parent.resolve()
tiny_llama_goldens = torch.load(current_dir / "tiny_llama_lm_logits.pt")
kv_cache_max_len = 1024
model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
tokens[0, :4] = idx
input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
kv = kv_utils.KVCache.from_model_config(model.config)
output = model.forward(tokens, input_pos, kv)
assert torch.allclose(
tiny_llama_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-02
)


if __name__ == "__main__":
input_checkpoint_path = os.path.join(
pathlib.Path.home(), "Downloads/llm_data/tiny_llama"
)
define_and_run(input_checkpoint_path)
Binary file not shown.
(Remaining changed files not shown.)
