Commit 8593517

Move common converting logic of examples into utilities.converter.

It would make it easier to add more examples.

PiperOrigin-RevId: 675750433

ai-edge-bot authored and copybara-github committed Sep 17, 2024
1 parent 47e20da commit 8593517
Showing 7 changed files with 298 additions and 336 deletions.
ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py (92 changes: 36 additions & 56 deletions)

@@ -18,69 +18,49 @@
 import os
 import pathlib
 
-import ai_edge_torch
+from absl import app
+from absl import flags
 from ai_edge_torch.generative.examples.gemma import gemma2
-from ai_edge_torch.generative.layers import kv_cache as kv_utils
-from ai_edge_torch.generative.quantize import quant_recipes
-import torch
+from ai_edge_torch.generative.utilities import converter
 
+_CHECKPOINT_PATH = flags.DEFINE_string(
+    'checkpoint_path',
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma2-2b'),
+    'The path to the model checkpoint, or directory holding the checkpoint.',
+)
+_TFLITE_PATH = flags.DEFINE_string(
+    'tflite_path',
+    '/tmp/gemma2_q8_seq512_ekv1024.tflite',
+    'The tflite file path to export.',
+)
+_PREFILL_SEQ_LEN = flags.DEFINE_integer(
+    'prefill_seq_len',
+    512,
+    'The maximum size of prefill input tensor.',
+)
+_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
+    'kv_cache_max_len',
+    1024,
+    'The maximum size of KV cache buffer, including both prefill and decode.',
+)
+_QUANTIZE = flags.DEFINE_bool(
+    'quantize',
+    True,
+    'Whether the model should be quantized.',
+)
 
 
-def convert_gemma2_to_tflite(
-    checkpoint_path: str,
-    prefill_seq_len: int = 512,
-    kv_cache_max_len: int = 1024,
-    quantize: bool = True,
-):
-  """Converts a Gemma2 2B model to multi-signature tflite model.
-
-  Args:
-    checkpoint_path (str): The filepath to the model checkpoint, or directory
-      holding the checkpoint.
-    prefill_seq_len (int, optional): The maximum size of prefill input tensor.
-      Defaults to 512.
-    kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
-      including both prefill and decode. Defaults to 1024.
-    quantize (bool, optional): Whether the model should be quantized. Defaults
-      to True.
-  """
+def main(_):
   pytorch_model = gemma2.build_2b_model(
-      checkpoint_path, kv_cache_max_len=kv_cache_max_len
+      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  # Tensors used to trace the model graph during conversion.
-  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
-  prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
-  decode_token = torch.tensor([[0]], dtype=torch.int)
-  decode_input_pos = torch.tensor([0], dtype=torch.int)
-  kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
-
-  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
-  edge_model = (
-      ai_edge_torch.signature(
-          'prefill',
-          pytorch_model,
-          sample_kwargs={
-              'tokens': prefill_tokens,
-              'input_pos': prefill_input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .signature(
-          'decode',
-          pytorch_model,
-          sample_kwargs={
-              'tokens': decode_token,
-              'input_pos': decode_input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .convert(quant_config=quant_config)
-  )
-  quant_suffix = 'q8' if quantize else 'f32'
-  edge_model.export(
-      f'/tmp/gemma2_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
+  converter.convert_to_tflite(
+      pytorch_model,
+      tflite_path=_TFLITE_PATH.value,
+      prefill_seq_len=_PREFILL_SEQ_LEN.value,
+      quantize=_QUANTIZE.value,
   )
 
 
 if __name__ == '__main__':
-  path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma2-2b')
-  convert_gemma2_to_tflite(path)
+  app.run(main)
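The shared helper lives in ai_edge_torch/generative/utilities/converter.py, which is among the seven changed files but is not rendered on this page. Judging from the call sites and from the per-example logic deleted above, it is plausibly a direct repackaging of that logic; the following sketch is a reconstruction under that assumption, not the file's verbatim contents.

"""Sketch of ai_edge_torch/generative/utilities/converter.py.

Reconstructed from the per-example logic removed in this commit; the actual
file may differ in details.
"""

import ai_edge_torch
from ai_edge_torch.generative.layers import kv_cache as kv_utils
from ai_edge_torch.generative.quantize import quant_recipes
import torch


def convert_to_tflite(
    pytorch_model: torch.nn.Module,
    tflite_path: str,
    prefill_seq_len: int = 512,
    quantize: bool = True,
):
  """Converts a PyTorch LLM to a multi-signature tflite model."""
  # Tensors used to trace the model graph during conversion (same shapes the
  # examples used before this commit).
  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
  prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
  decode_token = torch.tensor([[0]], dtype=torch.int)
  decode_input_pos = torch.tensor([0], dtype=torch.int)
  kv = kv_utils.KVCache.from_model_config(pytorch_model.config)

  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
  edge_model = (
      ai_edge_torch.signature(
          'prefill',
          pytorch_model,
          sample_kwargs={
              'tokens': prefill_tokens,
              'input_pos': prefill_input_pos,
              'kv_cache': kv,
          },
      )
      .signature(
          'decode',
          pytorch_model,
          sample_kwargs={
              'tokens': decode_token,
              'input_pos': decode_input_pos,
              'kv_cache': kv,
          },
      )
      .convert(quant_config=quant_config)
  )
  edge_model.export(tflite_path)

Note that the output filename is now the caller's responsibility (each example defaults its tflite_path flag to the name the old code computed), so the helper no longer needs the kv_cache_max_len or the q8/f32 suffix logic.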
ai_edge_torch/generative/examples/gemma/convert_to_tflite.py (92 changes: 36 additions & 56 deletions)

@@ -18,69 +18,49 @@
 import os
 import pathlib
 
-import ai_edge_torch
+from absl import app
+from absl import flags
 from ai_edge_torch.generative.examples.gemma import gemma
-from ai_edge_torch.generative.layers import kv_cache as kv_utils
-from ai_edge_torch.generative.quantize import quant_recipes
-import torch
+from ai_edge_torch.generative.utilities import converter
 
+_CHECKPOINT_PATH = flags.DEFINE_string(
+    'checkpoint_path',
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma-2b'),
+    'The path to the model checkpoint, or directory holding the checkpoint.',
+)
+_TFLITE_PATH = flags.DEFINE_string(
+    'tflite_path',
+    '/tmp/gemma_q8_seq512_ekv1024.tflite',
+    'The tflite file path to export.',
+)
+_PREFILL_SEQ_LEN = flags.DEFINE_integer(
+    'prefill_seq_len',
+    512,
+    'The maximum size of prefill input tensor.',
+)
+_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
+    'kv_cache_max_len',
+    1024,
+    'The maximum size of KV cache buffer, including both prefill and decode.',
+)
+_QUANTIZE = flags.DEFINE_bool(
+    'quantize',
+    True,
+    'Whether the model should be quantized.',
+)
 
 
-def convert_gemma_to_tflite(
-    checkpoint_path: str,
-    prefill_seq_len: int = 512,
-    kv_cache_max_len: int = 1024,
-    quantize: bool = True,
-):
-  """Converts a Gemma 2B model to multi-signature tflite model.
-
-  Args:
-    checkpoint_path (str): The filepath to the model checkpoint, or directory
-      holding the checkpoint.
-    prefill_seq_len (int, optional): The maximum size of prefill input tensor.
-      Defaults to 512.
-    kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
-      including both prefill and decode. Defaults to 1024.
-    quantize (bool, optional): Whether the model should be quantized. Defaults
-      to True.
-  """
+def main(_):
   pytorch_model = gemma.build_2b_model(
-      checkpoint_path, kv_cache_max_len=kv_cache_max_len
+      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  # Tensors used to trace the model graph during conversion.
-  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
-  prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
-  decode_token = torch.tensor([[0]], dtype=torch.int)
-  decode_input_pos = torch.tensor([0], dtype=torch.int)
-  kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
-
-  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
-  edge_model = (
-      ai_edge_torch.signature(
-          'prefill',
-          pytorch_model,
-          sample_kwargs={
-              'tokens': prefill_tokens,
-              'input_pos': prefill_input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .signature(
-          'decode',
-          pytorch_model,
-          sample_kwargs={
-              'tokens': decode_token,
-              'input_pos': decode_input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .convert(quant_config=quant_config)
-  )
-  quant_suffix = 'q8' if quantize else 'f32'
-  edge_model.export(
-      f'/tmp/gemma_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
+  converter.convert_to_tflite(
+      pytorch_model,
+      tflite_path=_TFLITE_PATH.value,
+      prefill_seq_len=_PREFILL_SEQ_LEN.value,
+      quantize=_QUANTIZE.value,
   )
 
 
 if __name__ == '__main__':
-  path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma-2b')
-  convert_gemma_to_tflite(path)
+  app.run(main)
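As a quick sanity check on any of these exports, the two signatures registered during conversion ('prefill' and 'decode') can be listed with the stock TFLite interpreter. A minimal sketch, assuming TensorFlow is installed and the gemma example above ran with default flags:

import tensorflow as tf

# Default output path of the gemma example above.
interpreter = tf.lite.Interpreter(
    model_path='/tmp/gemma_q8_seq512_ekv1024.tflite'
)

# Should list the 'prefill' and 'decode' signatures added during conversion.
print(interpreter.get_signature_list())

# Input names for the KV cache depend on how the KVCache pytree is flattened
# at export time, so inspect them rather than hard-coding.
prefill_runner = interpreter.get_signature_runner('prefill')
print(prefill_runner.get_input_details())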
ai_edge_torch/generative/examples/openelm/convert_to_tflite.py (92 changes: 36 additions & 56 deletions)

@@ -18,69 +18,49 @@
 import os
 import pathlib
 
-import ai_edge_torch
+from absl import app
+from absl import flags
 from ai_edge_torch.generative.examples.openelm import openelm
-from ai_edge_torch.generative.layers import kv_cache as kv_utils
-from ai_edge_torch.generative.quantize import quant_recipes
-import torch
+from ai_edge_torch.generative.utilities import converter
 
+_CHECKPOINT_PATH = flags.DEFINE_string(
+    'checkpoint_path',
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/openelm'),
+    'The path to the model checkpoint, or directory holding the checkpoint.',
+)
+_TFLITE_PATH = flags.DEFINE_string(
+    'tflite_path',
+    '/tmp/openelm_q8_seq512_ekv1024.tflite',
+    'The tflite file path to export.',
+)
+_PREFILL_SEQ_LEN = flags.DEFINE_integer(
+    'prefill_seq_len',
+    512,
+    'The maximum size of prefill input tensor.',
+)
+_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
+    'kv_cache_max_len',
+    1024,
+    'The maximum size of KV cache buffer, including both prefill and decode.',
+)
+_QUANTIZE = flags.DEFINE_bool(
+    'quantize',
+    True,
+    'Whether the model should be quantized.',
+)
 
 
-def convert_openelm_to_tflite(
-    checkpoint_path: str,
-    prefill_seq_len: int = 512,
-    kv_cache_max_len: int = 1024,
-    quantize: bool = True,
-):
-  """Converts OpenELM model to multi-signature tflite model.
-
-  Args:
-    checkpoint_path (str): The filepath to the model checkpoint, or directory
-      holding the checkpoint.
-    prefill_seq_len (int, optional): The maximum size of prefill input tensor.
-      Defaults to 512.
-    kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
-      including both prefill and decode. Defaults to 1024.
-    quantize (bool, optional): Whether the model should be quantized. Defaults
-      to True.
-  """
+def main(_):
   pytorch_model = openelm.build_model(
-      checkpoint_path, kv_cache_max_len=kv_cache_max_len
+      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  # Tensors used to trace the model graph during conversion.
-  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
-  prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
-  decode_token = torch.tensor([[0]], dtype=torch.int)
-  decode_input_pos = torch.tensor([0], dtype=torch.int)
-  kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
-
-  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
-  edge_model = (
-      ai_edge_torch.signature(
-          'prefill',
-          pytorch_model,
-          sample_kwargs={
-              'tokens': prefill_tokens,
-              'input_pos': prefill_input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .signature(
-          'decode',
-          pytorch_model,
-          sample_kwargs={
-              'tokens': decode_token,
-              'input_pos': decode_input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .convert(quant_config=quant_config)
-  )
-  quant_suffix = 'q8' if quantize else 'f32'
-  edge_model.export(
-      f'/tmp/openelm_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
+  converter.convert_to_tflite(
+      pytorch_model,
+      tflite_path=_TFLITE_PATH.value,
+      prefill_seq_len=_PREFILL_SEQ_LEN.value,
+      quantize=_QUANTIZE.value,
+  )
 
 
 if __name__ == '__main__':
-  path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/openelm')
-  convert_openelm_to_tflite(path)
+  app.run(main)
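The commit message's point, that this makes adding more examples easier, follows from the pattern above: a new example is now just flag definitions plus two calls. A skeleton for a hypothetical future model (the new_model module and its build_model function are placeholders for illustration, not part of this commit):

import os
import pathlib

from absl import app
from absl import flags
# Hypothetical module; any model with a build function that accepts
# kv_cache_max_len would slot in the same way.
from ai_edge_torch.generative.examples.new_model import new_model
from ai_edge_torch.generative.utilities import converter

_CHECKPOINT_PATH = flags.DEFINE_string(
    'checkpoint_path',
    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/new_model'),
    'The path to the model checkpoint, or directory holding the checkpoint.',
)
_TFLITE_PATH = flags.DEFINE_string(
    'tflite_path',
    '/tmp/new_model_q8_seq512_ekv1024.tflite',
    'The tflite file path to export.',
)
_PREFILL_SEQ_LEN = flags.DEFINE_integer(
    'prefill_seq_len', 512, 'The maximum size of prefill input tensor.'
)
_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
    'kv_cache_max_len',
    1024,
    'The maximum size of KV cache buffer, including both prefill and decode.',
)
_QUANTIZE = flags.DEFINE_bool(
    'quantize', True, 'Whether the model should be quantized.'
)


def main(_):
  pytorch_model = new_model.build_model(
      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
  )
  converter.convert_to_tflite(
      pytorch_model,
      tflite_path=_TFLITE_PATH.value,
      prefill_seq_len=_PREFILL_SEQ_LEN.value,
      quantize=_QUANTIZE.value,
  )


if __name__ == '__main__':
  app.run(main)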