From 16c3f41292e34b61267f8e7ab7cc0425fd827898 Mon Sep 17 00:00:00 2001
From: Google AI Edge
Date: Mon, 9 Dec 2024 19:05:09 -0800
Subject: [PATCH] Define a dedicated class per Gen AI example for more useful
 debugging info.

PiperOrigin-RevId: 704504130
---
 .../examples/amd_llama_135m/amd_llama_135m.py | 11 ++++++--
 .../generative/examples/gemma/gemma1.py       | 11 ++++++--
 .../generative/examples/gemma/gemma2.py       | 18 ++++++------
 .../generative/examples/llama/llama.py        | 28 ++++++++-----------
 .../generative/examples/openelm/openelm.py    | 11 ++++++--
 .../generative/examples/paligemma/decoder.py  | 16 +++++------
 ai_edge_torch/generative/examples/phi/phi2.py | 11 ++++++--
 ai_edge_torch/generative/examples/phi/phi3.py | 16 +++++------
 .../generative/examples/qwen/qwen.py          | 21 ++++++++------
 .../generative/examples/smollm/smollm.py      | 11 ++++++--
 .../examples/tiny_llama/tiny_llama.py         | 11 ++++++--
 .../generative/utilities/model_builder.py     |  5 ++--
 12 files changed, 96 insertions(+), 74 deletions(-)

diff --git a/ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py b/ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py
index d08f9b09..bb871f1c 100644
--- a/ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py
+++ b/ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py
@@ -17,10 +17,16 @@
 
 import ai_edge_torch.generative.layers.model_config as cfg
 from ai_edge_torch.generative.utilities import model_builder
+from torch import nn
 
 TENSOR_NAMES = model_builder.TENSOR_NAMES_WITH_SEPARATE_LM_HEAD
 
 
+class AmdLlama(model_builder.DecoderOnlyModel):
+  """An AMD-Llama-135m model built from the Edge Generative API layers."""
+  pass
+
+
 def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   """Returns the model config for an AMD-Llama-135m model.
 
@@ -72,11 +78,10 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
   return config
 
 
-def build_model(
-    checkpoint_path: str, **kwargs
-) -> model_builder.DecoderOnlyModel:
+def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
       config=get_model_config(**kwargs),
       tensor_names=TENSOR_NAMES,
+      model_class=AmdLlama,
   )
diff --git a/ai_edge_torch/generative/examples/gemma/gemma1.py b/ai_edge_torch/generative/examples/gemma/gemma1.py
index 719279ac..ce35dfd3 100644
--- a/ai_edge_torch/generative/examples/gemma/gemma1.py
+++ b/ai_edge_torch/generative/examples/gemma/gemma1.py
@@ -18,6 +18,7 @@
 import ai_edge_torch.generative.layers.model_config as cfg
 from ai_edge_torch.generative.utilities import model_builder
 import ai_edge_torch.generative.utilities.loader as loading_utils
+from torch import nn
 
 TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
     ff_up_proj="model.layers.{}.mlp.up_proj",
@@ -33,6 +34,11 @@
 )
 
 
+class Gemma1(model_builder.DecoderOnlyModel):
+  """A Gemma1 model built from the Edge Generative API layers."""
+  pass
+
+
 def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   """Returns the model config for a Gemma 2B model.
 
@@ -91,11 +97,10 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig: return config -def build_2b_model( - checkpoint_path: str, **kwargs -) -> model_builder.DecoderOnlyModel: +def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module: return model_builder.build_decoder_only_model( checkpoint_path=checkpoint_path, config=get_model_config_2b(**kwargs), tensor_names=TENSOR_NAMES, + model_class=Gemma1, ) diff --git a/ai_edge_torch/generative/examples/gemma/gemma2.py b/ai_edge_torch/generative/examples/gemma/gemma2.py index 469caf0b..a934c835 100644 --- a/ai_edge_torch/generative/examples/gemma/gemma2.py +++ b/ai_edge_torch/generative/examples/gemma/gemma2.py @@ -22,8 +22,8 @@ from ai_edge_torch.generative.layers import kv_cache as kv_utils import ai_edge_torch.generative.layers.attention_utils as attn_utils import ai_edge_torch.generative.layers.model_config as cfg +from ai_edge_torch.generative.utilities import model_builder import ai_edge_torch.generative.utilities.loader as loading_utils -from ai_edge_torch.generative.utilities.model_builder import ExportConfig import torch from torch import nn @@ -133,7 +133,7 @@ def forward( tokens: torch.Tensor, input_pos: torch.Tensor, kv_cache: kv_utils.KVCache, - export_config: Optional[ExportConfig] = None, + export_config: Optional[model_builder.ExportConfig] = None, ) -> dict[torch.Tensor, kv_utils.KVCache]: _, seq_len = tokens.size() assert self.config.max_seq_len >= seq_len, ( @@ -259,11 +259,9 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig: def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module: - config = get_model_config_2b(**kwargs) - model = Gemma2(config) - loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES) - # Since embedding and lm-head use the same weight, we need to set strict - # to False. - loader.load(model, strict=False) - model.eval() - return model + return model_builder.build_decoder_only_model( + checkpoint_path=checkpoint_path, + config=get_model_config_2b(**kwargs), + tensor_names=TENSOR_NAMES, + model_class=Gemma2, + ) diff --git a/ai_edge_torch/generative/examples/llama/llama.py b/ai_edge_torch/generative/examples/llama/llama.py index 1fde9f69..4d2dc10e 100644 --- a/ai_edge_torch/generative/examples/llama/llama.py +++ b/ai_edge_torch/generative/examples/llama/llama.py @@ -20,7 +20,6 @@ import ai_edge_torch.generative.layers.model_config as cfg from ai_edge_torch.generative.utilities import model_builder -import ai_edge_torch.generative.utilities.loader as loading_utils import torch TENSOR_NAMES = model_builder.TENSOR_NAMES @@ -177,23 +176,18 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig: def _build_model( checkpoint_path: str, config: cfg.ModelConfig -) -> model_builder.DecoderOnlyModel: - model = Llama(config) - loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES) - # Since embedding and lm-head use the same weight, we need to set strict - # to False. 
- loader.load(model, strict=False) - model.eval() - return model - - -def build_1b_model( - checkpoint_path: str, **kwargs -) -> model_builder.DecoderOnlyModel: +) -> torch.nn.Module: + return model_builder.build_decoder_only_model( + checkpoint_path=checkpoint_path, + config=config, + tensor_names=TENSOR_NAMES, + model_class=Llama, + ) + + +def build_1b_model(checkpoint_path: str, **kwargs) -> torch.nn.Module: return _build_model(checkpoint_path, get_1b_model_config(**kwargs)) -def build_3b_model( - checkpoint_path: str, **kwargs -) -> model_builder.DecoderOnlyModel: +def build_3b_model(checkpoint_path: str, **kwargs) -> torch.nn.Module: return _build_model(checkpoint_path, get_3b_model_config(**kwargs)) diff --git a/ai_edge_torch/generative/examples/openelm/openelm.py b/ai_edge_torch/generative/examples/openelm/openelm.py index ade62506..2329a996 100644 --- a/ai_edge_torch/generative/examples/openelm/openelm.py +++ b/ai_edge_torch/generative/examples/openelm/openelm.py @@ -18,6 +18,7 @@ import ai_edge_torch.generative.layers.model_config as cfg from ai_edge_torch.generative.utilities import model_builder import ai_edge_torch.generative.utilities.loader as loading_utils +from torch import nn TENSOR_NAMES = loading_utils.ModelLoader.TensorNames( ff_up_proj="transformer.layers.{}.ffn.proj_1", @@ -34,6 +35,11 @@ ) +class OpenELM(model_builder.DecoderOnlyModel): + """An OpenELM model built from the Edge Generative API layers.""" + pass + + def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig: """Returns the model config for an OpenELM model. @@ -112,11 +118,10 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig: return config -def build_model( - checkpoint_path: str, **kwargs -) -> model_builder.DecoderOnlyModel: +def build_model(checkpoint_path: str, **kwargs) -> nn.Module: return model_builder.build_decoder_only_model( checkpoint_path=checkpoint_path, config=get_model_config(**kwargs), tensor_names=TENSOR_NAMES, + model_class=OpenELM, ) diff --git a/ai_edge_torch/generative/examples/paligemma/decoder.py b/ai_edge_torch/generative/examples/paligemma/decoder.py index c90e040d..58379798 100644 --- a/ai_edge_torch/generative/examples/paligemma/decoder.py +++ b/ai_edge_torch/generative/examples/paligemma/decoder.py @@ -130,12 +130,10 @@ def get_fake_decoder_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig: return config -def build_decoder( - checkpoint_path: str, **kwargs -) -> model_builder.DecoderOnlyModel: - decoder = Decoder(get_decoder_config(**kwargs)) - loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES) - # Loose the strictness because only decoder is being loaded. 
- loader.load(decoder, strict=False) - decoder.eval() - return decoder +def build_decoder(checkpoint_path: str, **kwargs) -> nn.Module: + return model_builder.build_decoder_only_model( + checkpoint_path=checkpoint_path, + config=get_decoder_config(**kwargs), + tensor_names=TENSOR_NAMES, + model_class=Decoder, + ) diff --git a/ai_edge_torch/generative/examples/phi/phi2.py b/ai_edge_torch/generative/examples/phi/phi2.py index 25cfcf91..fb2b0e5d 100644 --- a/ai_edge_torch/generative/examples/phi/phi2.py +++ b/ai_edge_torch/generative/examples/phi/phi2.py @@ -18,6 +18,7 @@ import ai_edge_torch.generative.layers.model_config as cfg from ai_edge_torch.generative.utilities import model_builder import ai_edge_torch.generative.utilities.loader as loading_utils +from torch import nn TENSOR_NAMES = loading_utils.ModelLoader.TensorNames( ff_up_proj="model.layers.{}.mlp.fc1", @@ -33,6 +34,11 @@ ) +class Phi2(model_builder.DecoderOnlyModel): + """A Phi-2 model built from the Edge Generative API layers.""" + pass + + def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig: """Returns the model config for a Phi-2 model. @@ -92,11 +98,10 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig: return config -def build_model( - checkpoint_path: str, **kwargs -) -> model_builder.DecoderOnlyModel: +def build_model(checkpoint_path: str, **kwargs) -> nn.Module: return model_builder.build_decoder_only_model( checkpoint_path=checkpoint_path, config=get_model_config(**kwargs), tensor_names=TENSOR_NAMES, + model_class=Phi2, ) diff --git a/ai_edge_torch/generative/examples/phi/phi3.py b/ai_edge_torch/generative/examples/phi/phi3.py index 6958dad9..ae14ccdd 100644 --- a/ai_edge_torch/generative/examples/phi/phi3.py +++ b/ai_edge_torch/generative/examples/phi/phi3.py @@ -207,13 +207,11 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig: return config -def build_model( - checkpoint_path: str, **kwargs -) -> model_builder.DecoderOnlyModel: +def build_model(checkpoint_path: str, **kwargs) -> nn.Module: """Instantiates the model instance and load checkpoint if provided.""" - config = get_model_config(**kwargs) - model = Phi3_5Mini(config) - loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES) - loader.load(model) - model.eval() - return model + return model_builder.build_decoder_only_model( + checkpoint_path=checkpoint_path, + config=get_model_config(**kwargs), + tensor_names=TENSOR_NAMES, + model_class=Phi3_5Mini, + ) diff --git a/ai_edge_torch/generative/examples/qwen/qwen.py b/ai_edge_torch/generative/examples/qwen/qwen.py index 0347758b..76202b88 100644 --- a/ai_edge_torch/generative/examples/qwen/qwen.py +++ b/ai_edge_torch/generative/examples/qwen/qwen.py @@ -17,10 +17,16 @@ import ai_edge_torch.generative.layers.model_config as cfg from ai_edge_torch.generative.utilities import model_builder +from torch import nn TENSOR_NAMES = model_builder.TENSOR_NAMES +class Qwen(model_builder.DecoderOnlyModel): + """A Qwen model built from the Edge Generative API layers.""" + pass + + def get_3b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig: """Returns the model config for a Qwen 2.5 3B model. 
@@ -101,31 +107,28 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig: return config -def build_3b_model( - checkpoint_path: str, **kwargs -) -> model_builder.DecoderOnlyModel: +def build_3b_model(checkpoint_path: str, **kwargs) -> nn.Module: return model_builder.build_decoder_only_model( checkpoint_path=checkpoint_path, config=get_3b_model_config(**kwargs), tensor_names=TENSOR_NAMES, + model_class=Qwen, ) -def build_1_5b_model( - checkpoint_path: str, **kwargs -) -> model_builder.DecoderOnlyModel: +def build_1_5b_model(checkpoint_path: str, **kwargs) -> nn.Module: return model_builder.build_decoder_only_model( checkpoint_path=checkpoint_path, config=get_1_5b_model_config(**kwargs), tensor_names=TENSOR_NAMES, + model_class=Qwen, ) -def build_0_5b_model( - checkpoint_path: str, **kwargs -) -> model_builder.DecoderOnlyModel: +def build_0_5b_model(checkpoint_path: str, **kwargs) -> nn.Module: return model_builder.build_decoder_only_model( checkpoint_path=checkpoint_path, config=get_0_5b_model_config(**kwargs), tensor_names=TENSOR_NAMES, + model_class=Qwen, ) diff --git a/ai_edge_torch/generative/examples/smollm/smollm.py b/ai_edge_torch/generative/examples/smollm/smollm.py index 2e5942f8..1c03005a 100644 --- a/ai_edge_torch/generative/examples/smollm/smollm.py +++ b/ai_edge_torch/generative/examples/smollm/smollm.py @@ -17,10 +17,16 @@ import ai_edge_torch.generative.layers.model_config as cfg from ai_edge_torch.generative.utilities import model_builder +from torch import nn TENSOR_NAMES = model_builder.TENSOR_NAMES +class SmolLM(model_builder.DecoderOnlyModel): + """A SmolLM model built from the Edge Generative API layers.""" + pass + + def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig: """Returns the model config for a SmolLM 135M model. @@ -72,11 +78,10 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig: return config -def build_model( - checkpoint_path: str, **kwargs -) -> model_builder.DecoderOnlyModel: +def build_model(checkpoint_path: str, **kwargs) -> nn.Module: return model_builder.build_decoder_only_model( checkpoint_path=checkpoint_path, config=get_model_config(**kwargs), tensor_names=TENSOR_NAMES, + model_class=SmolLM, ) diff --git a/ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py b/ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py index 5de9c237..bdc029b4 100644 --- a/ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +++ b/ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py @@ -17,10 +17,16 @@ import ai_edge_torch.generative.layers.model_config as cfg from ai_edge_torch.generative.utilities import model_builder +from torch import nn TENSOR_NAMES = model_builder.TENSOR_NAMES_WITH_SEPARATE_LM_HEAD +class TinyLlama(model_builder.DecoderOnlyModel): + """A TinyLlama model built from the Edge Generative API layers.""" + pass + + def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig: """Returns the model config for a TinyLlama model. 
@@ -73,11 +79,10 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig: return config -def build_model( - checkpoint_path: str, **kwargs -) -> model_builder.DecoderOnlyModel: +def build_model(checkpoint_path: str, **kwargs) -> nn.Module: return model_builder.build_decoder_only_model( checkpoint_path=checkpoint_path, config=get_model_config(**kwargs), tensor_names=TENSOR_NAMES, + model_class=TinyLlama, ) diff --git a/ai_edge_torch/generative/utilities/model_builder.py b/ai_edge_torch/generative/utilities/model_builder.py index 96921770..2714eeaa 100644 --- a/ai_edge_torch/generative/utilities/model_builder.py +++ b/ai_edge_torch/generative/utilities/model_builder.py @@ -165,8 +165,9 @@ def build_decoder_only_model( checkpoint_path: str, config: cfg.ModelConfig, tensor_names: loading_utils.ModelLoader.TensorNames, -) -> DecoderOnlyModel: - transformer = DecoderOnlyModel(config) + model_class: type[nn.Module] = DecoderOnlyModel, +) -> nn.Module: + transformer = model_class(config) loader = loading_utils.ModelLoader(checkpoint_path, tensor_names) loader.load( transformer, strict=not config.lm_head_share_weight_with_embedding
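
The change is mechanical across all twelve files: each example gains a thin
subclass of DecoderOnlyModel whose only job is to carry a distinct class name,
and build_decoder_only_model grows a model_class parameter (defaulting to
DecoderOnlyModel, so existing callers keep working) that instantiates it. A
minimal Python sketch of the payoff, assuming a local TinyLlama checkpoint
(the path below is a placeholder, not part of the patch):

    from ai_edge_torch.generative.examples.tiny_llama import tiny_llama

    # Placeholder checkpoint location; substitute a real download.
    model = tiny_llama.build_model("/tmp/tiny_llama_checkpoint")

    # Before this patch every builder returned a plain DecoderOnlyModel, so
    # all examples looked identical in stack traces, logs, and repr() output.
    # The per-example subclass makes a module's origin self-evident:
    print(type(model).__name__)  # TinyLlama
    assert isinstance(model, tiny_llama.TinyLlama)

Since the subclasses add no behavior, isinstance(model,
model_builder.DecoderOnlyModel) still holds as well, so code written against
the base class is unaffected.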