Skip to content

Commit

Permalink
Define a dedicated class for each Gen AI example to provide useful debugging info.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 704504130
  • Loading branch information
ai-edge-bot authored and copybara-github committed Dec 10, 2024
1 parent e029f9b commit 16c3f41
Show file tree
Hide file tree
Showing 12 changed files with 96 additions and 74 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,16 @@

import ai_edge_torch.generative.layers.model_config as cfg
from ai_edge_torch.generative.utilities import model_builder
from torch import nn

TENSOR_NAMES = model_builder.TENSOR_NAMES_WITH_SEPARATE_LM_HEAD


class AmdLlama(model_builder.DecoderOnlyModel):
"""An AMD-Llama-135m model built from the Edge Generative API layers."""
pass


def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
"""Returns the model config for an AMD-Llama-135m model.
Expand Down Expand Up @@ -72,11 +78,10 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
return config


def build_model(
checkpoint_path: str, **kwargs
) -> model_builder.DecoderOnlyModel:
def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
return model_builder.build_decoder_only_model(
checkpoint_path=checkpoint_path,
config=get_model_config(**kwargs),
tensor_names=TENSOR_NAMES,
model_class=AmdLlama
)
11 changes: 8 additions & 3 deletions ai_edge_torch/generative/examples/gemma/gemma1.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import ai_edge_torch.generative.layers.model_config as cfg
from ai_edge_torch.generative.utilities import model_builder
import ai_edge_torch.generative.utilities.loader as loading_utils
from torch import nn

TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
ff_up_proj="model.layers.{}.mlp.up_proj",
Expand All @@ -33,6 +34,11 @@
)


class Gemma1(model_builder.DecoderOnlyModel):
"""A Gemma1 model built from the Edge Generative API layers."""
pass


def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
"""Returns the model config for a Gemma 2B model.
Expand Down Expand Up @@ -91,11 +97,10 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
return config


def build_2b_model(
checkpoint_path: str, **kwargs
) -> model_builder.DecoderOnlyModel:
def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
return model_builder.build_decoder_only_model(
checkpoint_path=checkpoint_path,
config=get_model_config_2b(**kwargs),
tensor_names=TENSOR_NAMES,
model_class=Gemma1,
)
18 changes: 8 additions & 10 deletions ai_edge_torch/generative/examples/gemma/gemma2.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
from ai_edge_torch.generative.layers import kv_cache as kv_utils
import ai_edge_torch.generative.layers.attention_utils as attn_utils
import ai_edge_torch.generative.layers.model_config as cfg
from ai_edge_torch.generative.utilities import model_builder
import ai_edge_torch.generative.utilities.loader as loading_utils
from ai_edge_torch.generative.utilities.model_builder import ExportConfig
import torch
from torch import nn

Expand Down Expand Up @@ -133,7 +133,7 @@ def forward(
tokens: torch.Tensor,
input_pos: torch.Tensor,
kv_cache: kv_utils.KVCache,
export_config: Optional[ExportConfig] = None,
export_config: Optional[model_builder.ExportConfig] = None,
) -> dict[torch.Tensor, kv_utils.KVCache]:
_, seq_len = tokens.size()
assert self.config.max_seq_len >= seq_len, (
Expand Down Expand Up @@ -259,11 +259,9 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:


def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
config = get_model_config_2b(**kwargs)
model = Gemma2(config)
loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
# Since embedding and lm-head use the same weight, we need to set strict
# to False.
loader.load(model, strict=False)
model.eval()
return model
return model_builder.build_decoder_only_model(
checkpoint_path=checkpoint_path,
config=get_model_config_2b(**kwargs),
tensor_names=TENSOR_NAMES,
model_class=Gemma2,
)
28 changes: 11 additions & 17 deletions ai_edge_torch/generative/examples/llama/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

import ai_edge_torch.generative.layers.model_config as cfg
from ai_edge_torch.generative.utilities import model_builder
import ai_edge_torch.generative.utilities.loader as loading_utils
import torch

TENSOR_NAMES = model_builder.TENSOR_NAMES
Expand Down Expand Up @@ -177,23 +176,18 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:

def _build_model(
checkpoint_path: str, config: cfg.ModelConfig
) -> model_builder.DecoderOnlyModel:
model = Llama(config)
loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
# Since embedding and lm-head use the same weight, we need to set strict
# to False.
loader.load(model, strict=False)
model.eval()
return model


def build_1b_model(
checkpoint_path: str, **kwargs
) -> model_builder.DecoderOnlyModel:
) -> torch.nn.Module:
return model_builder.build_decoder_only_model(
checkpoint_path=checkpoint_path,
config=config,
tensor_names=TENSOR_NAMES,
model_class=Llama,
)


def build_1b_model(checkpoint_path: str, **kwargs) -> torch.nn.Module:
return _build_model(checkpoint_path, get_1b_model_config(**kwargs))


def build_3b_model(
checkpoint_path: str, **kwargs
) -> model_builder.DecoderOnlyModel:
def build_3b_model(checkpoint_path: str, **kwargs) -> torch.nn.Module:
return _build_model(checkpoint_path, get_3b_model_config(**kwargs))
11 changes: 8 additions & 3 deletions ai_edge_torch/generative/examples/openelm/openelm.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import ai_edge_torch.generative.layers.model_config as cfg
from ai_edge_torch.generative.utilities import model_builder
import ai_edge_torch.generative.utilities.loader as loading_utils
from torch import nn

TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
ff_up_proj="transformer.layers.{}.ffn.proj_1",
Expand All @@ -34,6 +35,11 @@
)


class OpenELM(model_builder.DecoderOnlyModel):
"""An OpenELM model built from the Edge Generative API layers."""
pass


def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
"""Returns the model config for an OpenELM model.
Expand Down Expand Up @@ -112,11 +118,10 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
return config


def build_model(
checkpoint_path: str, **kwargs
) -> model_builder.DecoderOnlyModel:
def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
return model_builder.build_decoder_only_model(
checkpoint_path=checkpoint_path,
config=get_model_config(**kwargs),
tensor_names=TENSOR_NAMES,
model_class=OpenELM,
)
16 changes: 7 additions & 9 deletions ai_edge_torch/generative/examples/paligemma/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,12 +130,10 @@ def get_fake_decoder_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
return config


def build_decoder(
checkpoint_path: str, **kwargs
) -> model_builder.DecoderOnlyModel:
decoder = Decoder(get_decoder_config(**kwargs))
loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
# Loosen the strictness because only the decoder is being loaded.
loader.load(decoder, strict=False)
decoder.eval()
return decoder
def build_decoder(checkpoint_path: str, **kwargs) -> nn.Module:
return model_builder.build_decoder_only_model(
checkpoint_path=checkpoint_path,
config=get_decoder_config(**kwargs),
tensor_names=TENSOR_NAMES,
model_class=Decoder,
)
11 changes: 8 additions & 3 deletions ai_edge_torch/generative/examples/phi/phi2.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import ai_edge_torch.generative.layers.model_config as cfg
from ai_edge_torch.generative.utilities import model_builder
import ai_edge_torch.generative.utilities.loader as loading_utils
from torch import nn

TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
ff_up_proj="model.layers.{}.mlp.fc1",
Expand All @@ -33,6 +34,11 @@
)


class Phi2(model_builder.DecoderOnlyModel):
"""A Phi-2 model built from the Edge Generative API layers."""
pass


def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
"""Returns the model config for a Phi-2 model.
Expand Down Expand Up @@ -92,11 +98,10 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
return config


def build_model(
checkpoint_path: str, **kwargs
) -> model_builder.DecoderOnlyModel:
def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
return model_builder.build_decoder_only_model(
checkpoint_path=checkpoint_path,
config=get_model_config(**kwargs),
tensor_names=TENSOR_NAMES,
model_class=Phi2,
)
16 changes: 7 additions & 9 deletions ai_edge_torch/generative/examples/phi/phi3.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,13 +207,11 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
return config


def build_model(
checkpoint_path: str, **kwargs
) -> model_builder.DecoderOnlyModel:
def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
"""Instantiates the model instance and loads the checkpoint if provided."""
config = get_model_config(**kwargs)
model = Phi3_5Mini(config)
loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
loader.load(model)
model.eval()
return model
return model_builder.build_decoder_only_model(
checkpoint_path=checkpoint_path,
config=get_model_config(**kwargs),
tensor_names=TENSOR_NAMES,
model_class=Phi3_5Mini,
)
21 changes: 12 additions & 9 deletions ai_edge_torch/generative/examples/qwen/qwen.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,16 @@

import ai_edge_torch.generative.layers.model_config as cfg
from ai_edge_torch.generative.utilities import model_builder
from torch import nn

TENSOR_NAMES = model_builder.TENSOR_NAMES


class Qwen(model_builder.DecoderOnlyModel):
"""A Qwen model built from the Edge Generative API layers."""
pass


def get_3b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
"""Returns the model config for a Qwen 2.5 3B model.
Expand Down Expand Up @@ -101,31 +107,28 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
return config


def build_3b_model(
checkpoint_path: str, **kwargs
) -> model_builder.DecoderOnlyModel:
def build_3b_model(checkpoint_path: str, **kwargs) -> nn.Module:
return model_builder.build_decoder_only_model(
checkpoint_path=checkpoint_path,
config=get_3b_model_config(**kwargs),
tensor_names=TENSOR_NAMES,
model_class=Qwen,
)


def build_1_5b_model(
checkpoint_path: str, **kwargs
) -> model_builder.DecoderOnlyModel:
def build_1_5b_model(checkpoint_path: str, **kwargs) -> nn.Module:
return model_builder.build_decoder_only_model(
checkpoint_path=checkpoint_path,
config=get_1_5b_model_config(**kwargs),
tensor_names=TENSOR_NAMES,
model_class=Qwen,
)


def build_0_5b_model(
checkpoint_path: str, **kwargs
) -> model_builder.DecoderOnlyModel:
def build_0_5b_model(checkpoint_path: str, **kwargs) -> nn.Module:
return model_builder.build_decoder_only_model(
checkpoint_path=checkpoint_path,
config=get_0_5b_model_config(**kwargs),
tensor_names=TENSOR_NAMES,
model_class=Qwen,
)
11 changes: 8 additions & 3 deletions ai_edge_torch/generative/examples/smollm/smollm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,16 @@

import ai_edge_torch.generative.layers.model_config as cfg
from ai_edge_torch.generative.utilities import model_builder
from torch import nn

TENSOR_NAMES = model_builder.TENSOR_NAMES


class SmolLM(model_builder.DecoderOnlyModel):
"""A SmolLM model built from the Edge Generative API layers."""
pass


def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
"""Returns the model config for a SmolLM 135M model.
Expand Down Expand Up @@ -72,11 +78,10 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
return config


def build_model(
checkpoint_path: str, **kwargs
) -> model_builder.DecoderOnlyModel:
def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
return model_builder.build_decoder_only_model(
checkpoint_path=checkpoint_path,
config=get_model_config(**kwargs),
tensor_names=TENSOR_NAMES,
model_class=SmolLM,
)
11 changes: 8 additions & 3 deletions ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,16 @@

import ai_edge_torch.generative.layers.model_config as cfg
from ai_edge_torch.generative.utilities import model_builder
from torch import nn

TENSOR_NAMES = model_builder.TENSOR_NAMES_WITH_SEPARATE_LM_HEAD


class TinyLlama(model_builder.DecoderOnlyModel):
"""A TinyLlama model built from the Edge Generative API layers."""
pass


def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
"""Returns the model config for a TinyLlama model.
Expand Down Expand Up @@ -73,11 +79,10 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
return config


def build_model(
checkpoint_path: str, **kwargs
) -> model_builder.DecoderOnlyModel:
def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
return model_builder.build_decoder_only_model(
checkpoint_path=checkpoint_path,
config=get_model_config(**kwargs),
tensor_names=TENSOR_NAMES,
model_class=TinyLlama,
)
5 changes: 3 additions & 2 deletions ai_edge_torch/generative/utilities/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,9 @@ def build_decoder_only_model(
checkpoint_path: str,
config: cfg.ModelConfig,
tensor_names: loading_utils.ModelLoader.TensorNames,
) -> DecoderOnlyModel:
transformer = DecoderOnlyModel(config)
model_class: type[nn.Module] = DecoderOnlyModel,
) -> nn.Module:
transformer = model_class(config)
loader = loading_utils.ModelLoader(checkpoint_path, tensor_names)
loader.load(
transformer, strict=not config.lm_head_share_weight_with_embedding
Expand Down

0 comments on commit 16c3f41

Please sign in to comment.