Skip to content

Commit

Permalink
Transition to multiple prefill sequence lengths conversion as default.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 702481271
  • Loading branch information
hheydary authored and copybara-github committed Dec 3, 2024
1 parent 10e9c59 commit b517bd0
Show file tree
Hide file tree
Showing 12 changed files with 75 additions and 133 deletions.
3 changes: 3 additions & 0 deletions ai_edge_torch/generative/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ Then export the model to TFLite with:
https://github.com/google-ai-edge/ai-edge-torch/blob/853301630f2b2455bd2e2f73d8a47e1a1534c91c/ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py#L133-L139

Please note that using the `prefill` and `decode` method conventions is required for easy integration into the Mediapipe LLM Inference API.

To further optimize on-device execution, a model can be exported with more than one prefill signature. As such, we use `prefill_{SEQ-LENS}` to export models with multiple prefill sequence lengths. During inference, the signature closest to the input sequence length is used to minimize throwaway results.

<br/>

### End-to-End Inference Pipeline
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,10 @@
'/tmp/',
'The tflite file path to export.',
)
_PREFILL_SEQ_LEN = flags.DEFINE_integer(
'prefill_seq_len',
1024,
'The maximum size of prefill input tensor.',
_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
'prefill_seq_lens',
(8, 64, 128, 256, 512, 1024),
'List of the maximum sizes of prefill input tensors.',
)
_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
'kv_cache_max_len',
Expand All @@ -55,11 +55,11 @@ def main(_):
_CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
)
quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
output_filename = f'gemma_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
output_filename = f'gemma_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
converter.convert_to_tflite(
pytorch_model,
tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
prefill_seq_len=_PREFILL_SEQ_LEN.value,
prefill_seq_len=_PREFILL_SEQ_LENS.value,
quantize=_QUANTIZE.value,
)

Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,10 @@
'/tmp/',
'The tflite file path to export.',
)
_PREFILL_SEQ_LEN = flags.DEFINE_integer(
'prefill_seq_len',
1024,
'The maximum size of prefill input tensor.',
_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
'prefill_seq_lens',
(8, 64, 128, 256, 512, 1024),
'List of the maximum sizes of prefill input tensors.',
)
_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
'kv_cache_max_len',
Expand All @@ -55,11 +55,11 @@ def main(_):
_CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
)
quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
output_filename = f'gemma2_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
output_filename = f'gemma2_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
converter.convert_to_tflite(
pytorch_model,
tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
prefill_seq_len=_PREFILL_SEQ_LEN.value,
prefill_seq_len=_PREFILL_SEQ_LENS.value,
quantize=_QUANTIZE.value,
)

Expand Down
12 changes: 6 additions & 6 deletions ai_edge_torch/generative/examples/llama/convert_to_tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,10 @@
'/tmp/',
'The tflite file path to export.',
)
_PREFILL_SEQ_LEN = flags.DEFINE_integer(
'prefill_seq_len',
1024,
'The maximum size of prefill input tensor.',
_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
'prefill_seq_lens',
(8, 64, 128, 256, 512, 1024),
'List of the maximum sizes of prefill input tensors.',
)
_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
'kv_cache_max_len',
Expand All @@ -66,11 +66,11 @@ def main(_):
_CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
)
quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
output_filename = f'llama_{_MODEL_SIZE.value}_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
output_filename = f'llama_{_MODEL_SIZE.value}_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
converter.convert_to_tflite(
pytorch_model,
tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
prefill_seq_len=_PREFILL_SEQ_LEN.value,
prefill_seq_len=_PREFILL_SEQ_LENS.value,
quantize=_QUANTIZE.value,
)

Expand Down
15 changes: 9 additions & 6 deletions ai_edge_torch/generative/examples/openelm/convert_to_tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,10 @@
'/tmp/',
'The tflite file path to export.',
)
_PREFILL_SEQ_LEN = flags.DEFINE_integer(
'prefill_seq_len',
1024,
'The maximum size of prefill input tensor.',
_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
'prefill_seq_lens',
(8, 64, 128, 256, 512, 1024),
'List of the maximum sizes of prefill input tensors.',
)
_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
'kv_cache_max_len',
Expand All @@ -55,11 +55,14 @@ def main(_):
_CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
)
quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
output_filename = f'openelm_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
output_filename = (
f'openelm_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
)

converter.convert_to_tflite(
pytorch_model,
tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
prefill_seq_len=_PREFILL_SEQ_LEN.value,
prefill_seq_len=_PREFILL_SEQ_LENS.value,
quantize=_QUANTIZE.value,
)

Expand Down
12 changes: 6 additions & 6 deletions ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,10 @@
'/tmp/',
'The tflite file path to export.',
)
_PREFILL_SEQ_LEN = flags.DEFINE_integer(
'prefill_seq_len',
1024,
'The maximum size of prefill input tensor.',
_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
'prefill_seq_lens',
(8, 64, 128, 256, 512, 1024),
'List of the maximum sizes of prefill input tensors.',
)
_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
'kv_cache_max_len',
Expand All @@ -55,11 +55,11 @@ def main(_):
_CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
)
quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
output_filename = f'phi3_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
output_filename = f'phi3_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
converter.convert_to_tflite(
pytorch_model,
tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
prefill_seq_len=_PREFILL_SEQ_LEN.value,
prefill_seq_len=_PREFILL_SEQ_LENS.value,
quantize=_QUANTIZE.value,
)

Expand Down
12 changes: 6 additions & 6 deletions ai_edge_torch/generative/examples/phi/convert_to_tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,10 @@
'/tmp/',
'The tflite file path to export.',
)
_PREFILL_SEQ_LEN = flags.DEFINE_integer(
'prefill_seq_len',
1024,
'The maximum size of prefill input tensor.',
_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
'prefill_seq_lens',
(8, 64, 128, 256, 512, 1024),
'List of the maximum sizes of prefill input tensors.',
)
_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
'kv_cache_max_len',
Expand All @@ -55,11 +55,11 @@ def main(_):
_CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
)
quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
output_filename = f'phi2_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
output_filename = f'phi2_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
converter.convert_to_tflite(
pytorch_model,
tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
prefill_seq_len=_PREFILL_SEQ_LEN.value,
prefill_seq_len=_PREFILL_SEQ_LENS.value,
quantize=_QUANTIZE.value,
)

Expand Down
14 changes: 8 additions & 6 deletions ai_edge_torch/generative/examples/qwen/convert_to_tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,10 @@
'/tmp/',
'The tflite file path to export.',
)
_PREFILL_SEQ_LEN = flags.DEFINE_integer(
'prefill_seq_len',
1024,
'The maximum size of prefill input tensor.',
_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
'prefill_seq_lens',
(8, 64, 128, 256, 512, 1024),
'List of the maximum sizes of prefill input tensors.',
)
_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
'kv_cache_max_len',
Expand All @@ -68,11 +68,13 @@ def main(_):
)
quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
model_size = _MODEL_SIZE.value.replace('.', '_')
output_filename = f'qwen_{model_size}_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
output_filename = (
f'qwen_{model_size}_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
)
converter.convert_to_tflite(
pytorch_model,
tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
prefill_seq_len=_PREFILL_SEQ_LEN.value,
prefill_seq_len=_PREFILL_SEQ_LENS.value,
quantize=_QUANTIZE.value,
)

Expand Down
12 changes: 6 additions & 6 deletions ai_edge_torch/generative/examples/smollm/convert_to_tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,10 @@
'/tmp/',
'The tflite file path to export.',
)
_PREFILL_SEQ_LEN = flags.DEFINE_integer(
'prefill_seq_len',
1024,
'The maximum size of prefill input tensor.',
_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
'prefill_seq_lens',
(8, 64, 128, 256, 512, 1024),
'List of the maximum sizes of prefill input tensors.',
)
_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
'kv_cache_max_len',
Expand All @@ -55,11 +55,11 @@ def main(_):
_CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
)
quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
output_filename = f'smollm_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
output_filename = f'smollm_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
converter.convert_to_tflite(
pytorch_model,
tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
prefill_seq_len=_PREFILL_SEQ_LEN.value,
prefill_seq_len=_PREFILL_SEQ_LENS.value,
quantize=_QUANTIZE.value,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,10 @@
'/tmp/',
'The tflite file path to export.',
)
_PREFILL_SEQ_LEN = flags.DEFINE_integer(
'prefill_seq_len',
1024,
'The maximum size of prefill input tensor.',
_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
'prefill_seq_lens',
(8, 64, 128, 256, 512, 1024),
'List of the maximum sizes of prefill input tensors.',
)
_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
'kv_cache_max_len',
Expand All @@ -55,11 +55,13 @@ def main(_):
_CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
)
quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
output_filename = f'tinyllama_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
output_filename = (
f'tinyllama_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
)
converter.convert_to_tflite(
pytorch_model,
tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
prefill_seq_len=_PREFILL_SEQ_LEN.value,
prefill_seq_len=_PREFILL_SEQ_LENS.value,
quantize=_QUANTIZE.value,
)

Expand Down
Loading

0 comments on commit b517bd0

Please sign in to comment.