Skip to content

Commit

Permalink
Prefill signatures don't do head FC or output logits
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 704378247
  • Loading branch information
talumbau authored and copybara-github committed Dec 9, 2024
1 parent 86e07ca commit e029f9b
Show file tree
Hide file tree
Showing 18 changed files with 111 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from absl import flags
from ai_edge_torch.generative.examples.amd_llama_135m import amd_llama_135m
from ai_edge_torch.generative.utilities import converter
from ai_edge_torch.generative.utilities.model_builder import ExportConfig

_CHECKPOINT_PATH = flags.DEFINE_string(
'checkpoint_path',
Expand Down Expand Up @@ -61,6 +62,7 @@ def main(_):
tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
prefill_seq_len=_PREFILL_SEQ_LEN.value,
quantize=_QUANTIZE.value,
export_config=ExportConfig(),
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from absl import flags
from ai_edge_torch.generative.examples.gemma import gemma1
from ai_edge_torch.generative.utilities import converter
from ai_edge_torch.generative.utilities.model_builder import ExportConfig

_CHECKPOINT_PATH = flags.DEFINE_string(
'checkpoint_path',
Expand Down Expand Up @@ -61,6 +62,7 @@ def main(_):
tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
prefill_seq_len=_PREFILL_SEQ_LENS.value,
quantize=_QUANTIZE.value,
export_config=ExportConfig(),
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from absl import flags
from ai_edge_torch.generative.examples.gemma import gemma2
from ai_edge_torch.generative.utilities import converter
from ai_edge_torch.generative.utilities.model_builder import ExportConfig

_CHECKPOINT_PATH = flags.DEFINE_string(
'checkpoint_path',
Expand Down Expand Up @@ -61,6 +62,7 @@ def main(_):
tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
prefill_seq_len=_PREFILL_SEQ_LENS.value,
quantize=_QUANTIZE.value,
export_config=ExportConfig(),
)


Expand Down
9 changes: 9 additions & 0 deletions ai_edge_torch/generative/examples/gemma/gemma2.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import ai_edge_torch.generative.layers.attention_utils as attn_utils
import ai_edge_torch.generative.layers.model_config as cfg
import ai_edge_torch.generative.utilities.loader as loading_utils
from ai_edge_torch.generative.utilities.model_builder import ExportConfig
import torch
from torch import nn

Expand Down Expand Up @@ -132,6 +133,7 @@ def forward(
tokens: torch.Tensor,
input_pos: torch.Tensor,
kv_cache: kv_utils.KVCache,
export_config: Optional[ExportConfig] = None,
) -> dict[torch.Tensor, kv_utils.KVCache]:
_, seq_len = tokens.size()
assert self.config.max_seq_len >= seq_len, (
Expand Down Expand Up @@ -162,6 +164,13 @@ def forward(
updated_kv_entires.append(kv_entry)
updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))

if export_config is not None:
if (
torch.numel(input_pos) > 1
and not export_config.output_logits_on_prefill
):
return {"kv_cache": updated_kv_cache}

x = self.final_norm(x)
res = self.lm_head(x) # (b, t, vocab_size)
if self.config.final_logit_softcap is not None:
Expand Down
2 changes: 2 additions & 0 deletions ai_edge_torch/generative/examples/llama/convert_to_tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from absl import flags
from ai_edge_torch.generative.examples.llama import llama
from ai_edge_torch.generative.utilities import converter
from ai_edge_torch.generative.utilities.model_builder import ExportConfig

_MODEL_SIZE = flags.DEFINE_enum(
'model_size',
Expand Down Expand Up @@ -72,6 +73,7 @@ def main(_):
tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
prefill_seq_len=_PREFILL_SEQ_LENS.value,
quantize=_QUANTIZE.value,
export_config=ExportConfig(),
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from absl import flags
from ai_edge_torch.generative.examples.openelm import openelm
from ai_edge_torch.generative.utilities import converter
from ai_edge_torch.generative.utilities.model_builder import ExportConfig

_CHECKPOINT_PATH = flags.DEFINE_string(
'checkpoint_path',
Expand Down Expand Up @@ -64,6 +65,7 @@ def main(_):
tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
prefill_seq_len=_PREFILL_SEQ_LENS.value,
quantize=_QUANTIZE.value,
export_config=ExportConfig(),
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from absl import flags
from ai_edge_torch.generative.examples.paligemma import paligemma
from ai_edge_torch.generative.utilities import converter
from ai_edge_torch.generative.utilities.model_builder import ExportConfig
import torch

_CHECKPOINT_PATH = flags.DEFINE_string(
Expand Down Expand Up @@ -73,6 +74,7 @@ def main(_):
pixel_values_size=torch.Size(_PIXEL_VALUES_SIZE.value),
quantize=_QUANTIZE.value,
config=pytorch_model.config.decoder_config,
export_config=ExportConfig(),
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from absl import flags
from ai_edge_torch.generative.examples.phi import phi3
from ai_edge_torch.generative.utilities import converter
from ai_edge_torch.generative.utilities.model_builder import ExportConfig

_CHECKPOINT_PATH = flags.DEFINE_string(
'checkpoint_path',
Expand Down Expand Up @@ -61,6 +62,7 @@ def main(_):
tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
prefill_seq_len=_PREFILL_SEQ_LENS.value,
quantize=_QUANTIZE.value,
export_config=ExportConfig(),
)


Expand Down
2 changes: 2 additions & 0 deletions ai_edge_torch/generative/examples/phi/convert_to_tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from absl import flags
from ai_edge_torch.generative.examples.phi import phi2
from ai_edge_torch.generative.utilities import converter
from ai_edge_torch.generative.utilities.model_builder import ExportConfig

_CHECKPOINT_PATH = flags.DEFINE_string(
'checkpoint_path',
Expand Down Expand Up @@ -61,6 +62,7 @@ def main(_):
tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
prefill_seq_len=_PREFILL_SEQ_LENS.value,
quantize=_QUANTIZE.value,
export_config=ExportConfig(),
)


Expand Down
2 changes: 2 additions & 0 deletions ai_edge_torch/generative/examples/qwen/convert_to_tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from absl import flags
from ai_edge_torch.generative.examples.qwen import qwen
from ai_edge_torch.generative.utilities import converter
from ai_edge_torch.generative.utilities.model_builder import ExportConfig

_MODEL_SIZE = flags.DEFINE_enum(
'model_size',
Expand Down Expand Up @@ -76,6 +77,7 @@ def main(_):
tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
prefill_seq_len=_PREFILL_SEQ_LENS.value,
quantize=_QUANTIZE.value,
export_config=ExportConfig(),
)


Expand Down
3 changes: 3 additions & 0 deletions ai_edge_torch/generative/examples/smollm/convert_to_tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from absl import flags
from ai_edge_torch.generative.examples.smollm import smollm
from ai_edge_torch.generative.utilities import converter
from ai_edge_torch.generative.utilities.model_builder import ExportConfig

_CHECKPOINT_PATH = flags.DEFINE_string(
'checkpoint_path',
Expand Down Expand Up @@ -54,13 +55,15 @@ def main(_):
pytorch_model = smollm.build_model(
_CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
)

quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
output_filename = f'smollm_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
converter.convert_to_tflite(
pytorch_model,
tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
prefill_seq_len=_PREFILL_SEQ_LENS.value,
quantize=_QUANTIZE.value,
export_config=ExportConfig(),
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,14 @@

"""A toy example which has basic transformer block (w/ externalized KV-Cache)."""

from typing import Tuple
from typing import Optional, Tuple

from ai_edge_torch.generative.layers import attention
from ai_edge_torch.generative.layers import builder
from ai_edge_torch.generative.layers import kv_cache as kv_utils
import ai_edge_torch.generative.layers.attention_utils as attn_utils
import ai_edge_torch.generative.layers.model_config as cfg
from ai_edge_torch.generative.utilities.model_builder import ExportConfig
import torch
from torch import nn

Expand Down Expand Up @@ -62,6 +63,7 @@ def forward(
tokens: torch.Tensor,
input_pos: torch.Tensor,
kv_cache: kv_utils.KVCache,
export_config: Optional[ExportConfig] = None,
) -> Tuple[torch.Tensor, kv_utils.KVCache]:
x = self.tok_embedding(tokens)
cos, sin = self.rope_cache
Expand All @@ -77,8 +79,16 @@ def forward(
if kv_entry:
updated_kv_entires.append(kv_entry)

x = self.final_norm(x)
updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))

if export_config is not None:
if (
torch.numel(input_pos) > 1
and not export_config.output_logits_on_prefill
):
return {'kv_cache': updated_kv_cache}

x = self.final_norm(x)
return {'logits': self.lm_head(x), 'kv_cache': updated_kv_cache}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from absl import flags
from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
from ai_edge_torch.generative.utilities import converter
from ai_edge_torch.generative.utilities.model_builder import ExportConfig

_CHECKPOINT_PATH = flags.DEFINE_string(
'checkpoint_path',
Expand Down Expand Up @@ -63,6 +64,7 @@ def main(_):
tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
prefill_seq_len=_PREFILL_SEQ_LENS.value,
quantize=_QUANTIZE.value,
export_config=ExportConfig(),
)


Expand Down
2 changes: 2 additions & 0 deletions ai_edge_torch/generative/tools/batch_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from ai_edge_torch.generative.examples.smollm import smollm
from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
from ai_edge_torch.generative.utilities import converter
from ai_edge_torch.generative.utilities.model_builder import ExportConfig
import torch

_CHECKPOINT_ROOT_PATH = flags.DEFINE_string(
Expand Down Expand Up @@ -281,6 +282,7 @@ def convert_models(conversion_configs: Sequence[ConversionConfig]) -> None:
tflite_path=os.path.join(config.tflite_output_path, output_filename),
prefill_seq_len=config.prefill_seq_lens,
quantize=True if precision == ExportPrecision.INT8 else False,
export_config=ExportConfig(),
)
logging.info("Successfully converted model: %s", output_filename)

Expand Down
29 changes: 25 additions & 4 deletions ai_edge_torch/generative/utilities/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,28 @@

"""Common utility functions for model conversion."""

from typing import Union
from functools import partial
from typing import Any, Union

from ai_edge_torch._convert import converter as converter_utils
import ai_edge_torch.generative.layers.kv_cache as kv_utils
import ai_edge_torch.generative.layers.model_config as cfg
from ai_edge_torch.generative.quantize import quant_recipes
from ai_edge_torch.generative.utilities.model_builder import ExportConfig
import torch
import torch.nn as nn


class ExportableModule(torch.nn.Module):
  """Wraps a module so fixed keyword arguments ride along on every call.

  Used at export time to bind non-exportable arguments (e.g. an
  ExportConfig) to the wrapped model, since the converter only supplies
  tensor sample inputs when tracing signatures.
  """

  def __init__(self, module, **extra_kwargs):
    super().__init__()
    # Keep both the wrapped module and the bound kwargs as plain
    # attributes; they are part of the wrapper's public surface.
    self.module = module
    self.extra_kwargs = extra_kwargs

  def forward(self, *export_args, **export_kwargs):
    # Bound kwargs are applied last, so they win over any call-time
    # kwargs with the same name.
    merged = dict(export_kwargs)
    merged.update(self.extra_kwargs)
    return self.module(*export_args, **merged)


def convert_to_tflite(
Expand All @@ -31,6 +46,7 @@ def convert_to_tflite(
pixel_values_size: torch.Size = None,
quantize: bool = True,
config: cfg.ModelConfig = None,
export_config: ExportConfig = None,
):
"""Converts a nn.Module model to multi-signature tflite model.
Expand Down Expand Up @@ -97,6 +113,11 @@ def convert_to_tflite(
)

quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None

  # For export, we create a module that captures any non-exportable
  # arguments, e.g. the generation config object.
mod = ExportableModule(pytorch_model, export_config=export_config)

converter = converter_utils.Converter()
for i in range(len(prefill_seq_lens)):
prefill_seq_len = prefill_seq_lens[i]
Expand All @@ -108,7 +129,7 @@ def convert_to_tflite(
prefill_signature_name = f'prefill_{prefill_seq_len}'
converter.add_signature(
prefill_signature_name,
pytorch_model,
mod,
sample_kwargs={
'tokens': prefill_tokens,
'input_pos': prefill_input_pos,
Expand All @@ -118,7 +139,7 @@ def convert_to_tflite(
if prefill_pixel_values is not None:
converter.add_signature(
prefill_signature_name + '_pixel',
pytorch_model,
mod,
sample_kwargs={
'tokens': prefill_tokens,
'input_pos': prefill_input_pos,
Expand All @@ -129,7 +150,7 @@ def convert_to_tflite(

converter.add_signature(
'decode',
pytorch_model,
mod,
sample_kwargs={
'tokens': decode_token,
'input_pos': decode_input_pos,
Expand Down
Loading

0 comments on commit e029f9b

Please sign in to comment.